diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/RemoveAccesses.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 97 |
1 files changed, 48 insertions, 49 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index d3340f2d..a3ce49f7 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -13,93 +13,90 @@ import scala.collection.mutable */ object RemoveAccesses extends Pass { def name = "Remove Accesses" - /** Container for a base expression and its corresponding guard */ - case class Location(base: Expression, guard: Expression) + private case class Location(base: Expression, guard: Expression) /** Walks a referencing expression and returns a list of valid references * (base) and the corresponding guard which, if true, returns that base. * E.g. if called on a[i] where a: UInt[2], we would return: * Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1))) */ - def getLocations(e: Expression): Seq[Location] = e match { + private def getLocations(e: Expression): Seq[Location] = e match { case e: WRef => create_exps(e).map(Location(_,one)) case e: WSubIndex => val ls = getLocations(e.exp) val start = get_point(e) - val end = start + get_size(tpe(e)) - val stride = get_size(tpe(e.exp)) - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { - lsx += ls(i) - } - } - lsx + val end = start + get_size(e.tpe) + val stride = get_size(e.exp.tpe) + for ((l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end)) yield l case e: WSubField => val ls = getLocations(e.exp) val start = get_point(e) - val end = start + get_size(tpe(e)) - val stride = get_size(tpe(e.exp)) - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { lsx += ls(i) } - } - lsx + val end = start + get_size(e.tpe) + val stride = get_size(e.exp.tpe) + for ((l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end)) yield l case e: WSubAccess => val ls = getLocations(e.exp) - val stride = get_size(tpe(e)) - val wrap = tpe(e.exp).asInstanceOf[VectorType].size - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { + val stride = get_size(e.tpe) + val wrap = e.exp.tpe.asInstanceOf[VectorType].size + ls.zipWithIndex map {case (l, i) => val c = (i / stride) % wrap - val basex = ls(i).base - val guardx = AND(ls(i).guard,EQV(uint(c),e.index)) - lsx += Location(basex,guardx) + val basex = l.base + val guardx = AND(l.guard,EQV(uint(c),e.index)) + Location(basex,guardx) } - lsx } + /** Returns true if e contains a [[firrtl.WSubAccess]] */ - def hasAccess(e: Expression): Boolean = { - var ret: Boolean = false - def rec_has_access(e: Expression): Expression = e match { - case (e:WSubAccess) => { ret = true; e } - case (e) => e map (rec_has_access) + private def hasAccess(e: Expression): Boolean = { + var ret: Boolean = false + def rec_has_access(e: Expression): Expression = { + e match { + case e : WSubAccess => ret = true + case e => + } + e map rec_has_access } rec_has_access(e) ret } + + // This improves the performance of this pass + private val createExpsCache = mutable.HashMap[Expression, Seq[Expression]]() + private def create_exps(e: Expression) = + createExpsCache getOrElseUpdate (e, firrtl.Utils.create_exps(e)) + def run(c: Circuit): Circuit = { def remove_m(m: Module): Module = { val namespace = Namespace(m) def onStmt(s: Statement): Statement = { - val stmts = mutable.ArrayBuffer[Statement]() - def create_temp(e: Expression): Expression = { + def create_temp(e: Expression): (Statement, Expression) = { val n = namespace.newTemp - stmts += DefWire(info(s), n, tpe(e)) - WRef(n, tpe(e), kind(e), gender(e)) + (DefWire(info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e))) } /** Replaces a subaccess in a given male expression */ + val stmts = mutable.ArrayBuffer[Statement]() def removeMale(e: Expression): Expression = e match { case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if (hasAccess(e)) => val rs = getLocations(e) - val foo = rs.find(x => {x.guard != one}) - foo match { + rs find (x => x.guard != one) match { case None => error("Shouldn't be here") - case foo: Some[Location] => - val temp = create_temp(e) + case Some(_) => + val (wire, temp) = create_temp(e) val temps = create_exps(temp) def getTemp(i: Int) = temps(i % temps.size) - for((x, i) <- rs.zipWithIndex) { - if (i < temps.size) { + stmts += wire + rs.zipWithIndex foreach { + case (x, i) if i < temps.size => stmts += Connect(info(s),getTemp(i),x.base) - } else { + case (x, i) => stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt) - } } temp } @@ -113,8 +110,10 @@ object RemoveAccesses extends Pass { val ls = getLocations(loc) if (ls.size == 1 & weq(ls(0).guard,one)) loc else { - val temp = create_temp(loc) - for (x <- ls) { stmts += Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt) } + val (wire, temp) = create_temp(loc) + stmts += wire + ls foreach (x => stmts += + Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt)) temp } case _ => loc @@ -148,13 +147,13 @@ object RemoveAccesses extends Pass { stmts += sx if (stmts.size != 1) Block(stmts) else stmts(0) } - Module(m.info, m.name, m.ports, onStmt(m.body)) + Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body))) } - val newModules = c.modules.map( _ match { + val newModules = c.modules.map { case m: ExtModule => m case m: Module => remove_m(m) - }) + } Circuit(c.info, newModules, c.main) } } |
