aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/RemoveAccesses.scala
diff options
context:
space:
mode:
authorchick2020-08-14 19:47:53 -0700
committerJack Koenig2020-08-14 19:47:53 -0700
commit6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch)
tree2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/passes/RemoveAccesses.scala
parentb516293f703c4de86397862fee1897aded2ae140 (diff)
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/passes/RemoveAccesses.scala')
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala79
1 files changed, 44 insertions, 35 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index 18db5939..015346ff 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -2,7 +2,7 @@
package firrtl.passes
-import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubIndex, WSubField}
+import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubField, WSubIndex}
import firrtl.PrimOps.{And, Eq}
import firrtl.ir._
import firrtl.Mappers._
@@ -17,10 +17,12 @@ import scala.collection.mutable
object RemoveAccesses extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ZeroLengthVecs),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects) ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ZeroLengthVecs),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects)
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform): Boolean = a match {
case Uniquify | ResolveKinds | ResolveFlows => true
@@ -28,8 +30,8 @@ object RemoveAccesses extends Pass {
}
private def AND(e1: Expression, e2: Expression) =
- if(e1 == one) e2
- else if(e2 == one) e1
+ if (e1 == one) e2
+ else if (e2 == one) e1
else DoPrim(And, Seq(e1, e2), Nil, BoolType)
private def EQV(e1: Expression, e2: Expression): Expression =
@@ -45,30 +47,35 @@ object RemoveAccesses extends Pass {
* Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1)))
*/
private def getLocations(e: Expression): Seq[Location] = e match {
- case e: WRef => create_exps(e).map(Location(_,one))
+ case e: WRef => create_exps(e).map(Location(_, one))
case e: WSubIndex =>
val ls = getLocations(e.expr)
val start = get_point(e)
val end = start + get_size(e.tpe)
val stride = get_size(e.expr.tpe)
- for ((l, i) <- ls.zipWithIndex
- if ((i % stride) >= start) & ((i % stride) < end)) yield l
+ for (
+ (l, i) <- ls.zipWithIndex
+ if ((i % stride) >= start) & ((i % stride) < end)
+ ) yield l
case e: WSubField =>
val ls = getLocations(e.expr)
val start = get_point(e)
val end = start + get_size(e.tpe)
val stride = get_size(e.expr.tpe)
- for ((l, i) <- ls.zipWithIndex
- if ((i % stride) >= start) & ((i % stride) < end)) yield l
+ for (
+ (l, i) <- ls.zipWithIndex
+ if ((i % stride) >= start) & ((i % stride) < end)
+ ) yield l
case e: WSubAccess =>
val ls = getLocations(e.expr)
val stride = get_size(e.tpe)
val wrap = e.expr.tpe.asInstanceOf[VectorType].size
- ls.zipWithIndex map {case (l, i) =>
- val c = (i / stride) % wrap
- val basex = l.base
- val guardx = AND(l.guard,EQV(UIntLiteral(c),e.index))
- Location(basex,guardx)
+ ls.zipWithIndex.map {
+ case (l, i) =>
+ val c = (i / stride) % wrap
+ val basex = l.base
+ val guardx = AND(l.guard, EQV(UIntLiteral(c), e.index))
+ Location(basex, guardx)
}
}
@@ -78,10 +85,10 @@ object RemoveAccesses extends Pass {
var ret: Boolean = false
def rec_has_access(e: Expression): Expression = {
e match {
- case _ : WSubAccess => ret = true
+ case _: WSubAccess => ret = true
case _ =>
}
- e map rec_has_access
+ e.map(rec_has_access)
}
rec_has_access(e)
ret
@@ -90,7 +97,7 @@ object RemoveAccesses extends Pass {
// This improves the performance of this pass
private val createExpsCache = mutable.HashMap[Expression, Seq[Expression]]()
private def create_exps(e: Expression) =
- createExpsCache getOrElseUpdate (e, firrtl.Utils.create_exps(e))
+ createExpsCache.getOrElseUpdate(e, firrtl.Utils.create_exps(e))
def run(c: Circuit): Circuit = {
def remove_m(m: Module): Module = {
@@ -105,21 +112,21 @@ object RemoveAccesses extends Pass {
*/
val stmts = mutable.ArrayBuffer[Statement]()
def removeSource(e: Expression): Expression = e match {
- case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(e) =>
+ case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(e) =>
val rs = getLocations(e)
- rs find (x => x.guard != one) match {
+ rs.find(x => x.guard != one) match {
case None => throwInternalError(s"removeSource: shouldn't be here - $e")
case Some(_) =>
val (wire, temp) = create_temp(e)
val temps = create_exps(temp)
def getTemp(i: Int) = temps(i % temps.size)
stmts += wire
- rs.zipWithIndex foreach {
+ rs.zipWithIndex.foreach {
case (x, i) if i < temps.size =>
- stmts += IsInvalid(get_info(s),getTemp(i))
- stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt)
+ stmts += IsInvalid(get_info(s), getTemp(i))
+ stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
case (x, i) =>
- stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt)
+ stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
}
temp
}
@@ -129,14 +136,16 @@ object RemoveAccesses extends Pass {
/** Replaces a subaccess in a given sink expression
*/
def removeSink(info: Info, loc: Expression): Expression = loc match {
- case (_: WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(loc) =>
+ case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(loc) =>
val ls = getLocations(loc)
- if (ls.size == 1 & weq(ls.head.guard,one)) loc
+ if (ls.size == 1 & weq(ls.head.guard, one)) loc
else {
val (wire, temp) = create_temp(loc)
stmts += wire
- ls foreach (x => stmts +=
- Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt))
+ ls.foreach(x =>
+ stmts +=
+ Conditionally(info, x.guard, Connect(info, x.base, temp), EmptyStmt)
+ )
temp
}
case _ => loc
@@ -150,7 +159,7 @@ object RemoveAccesses extends Pass {
case w: WSubAccess => removeSource(WSubAccess(w.expr, fixSource(w.index), w.tpe, w.flow))
//case w: WSubIndex => removeSource(w)
//case w: WSubField => removeSource(w)
- case x => x map fixSource
+ case x => x.map(fixSource)
}
/** Recursively walks a sink expression and fixes all subaccesses
@@ -159,13 +168,13 @@ object RemoveAccesses extends Pass {
*/
def fixSink(e: Expression): Expression = e match {
case w: WSubAccess => WSubAccess(fixSink(w.expr), fixSource(w.index), w.tpe, w.flow)
- case x => x map fixSink
+ case x => x.map(fixSink)
}
val sx = s match {
case Connect(info, loc, exp) =>
Connect(info, removeSink(info, fixSink(loc)), fixSource(exp))
- case sxx => sxx map fixSource map onStmt
+ case sxx => sxx.map(fixSource).map(onStmt)
}
stmts += sx
if (stmts.size != 1) Block(stmts.toSeq) else stmts(0)
@@ -173,9 +182,9 @@ object RemoveAccesses extends Pass {
Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body)))
}
- c copy (modules = c.modules map {
+ c.copy(modules = c.modules.map {
case m: ExtModule => m
- case m: Module => remove_m(m)
+ case m: Module => remove_m(m)
})
}
}