diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/ExpandWhens.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandWhens.scala | 58 |
1 files changed, 45 insertions, 13 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 4d02e192..a2845f43 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -10,21 +10,27 @@ import firrtl.PrimOps._ import firrtl.WrappedExpression._ import annotation.tailrec +import collection.mutable +import collection.immutable.ListSet /** Expand Whens * -* @note This pass does three things: remove last connect semantics, -* remove conditional blocks, and eliminate concept of scoping. +* This pass does the following things: +* $ - Remove last connect semantics +* $ - Remove conditional blocks +* $ - Eliminate concept of scoping +* $ - Consolidate attaches +* * @note Assumes bulk connects and isInvalids have been expanded * @note Assumes all references are declared */ object ExpandWhens extends Pass { def name = "Expand Whens" - type NodeMap = collection.mutable.HashMap[MemoizedHash[Expression], String] - type Netlist = collection.mutable.LinkedHashMap[WrappedExpression, Expression] - type Simlist = collection.mutable.ArrayBuffer[Statement] - type Attachlist = collection.mutable.ArrayBuffer[Statement] - type Defaults = Seq[collection.mutable.Map[WrappedExpression, Expression]] + type NodeMap = mutable.HashMap[MemoizedHash[Expression], String] + type Netlist = mutable.LinkedHashMap[WrappedExpression, Expression] + type Simlist = mutable.ArrayBuffer[Statement] + // Defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow + type Defaults = Seq[mutable.Map[WrappedExpression, Expression]] // ========== Expand When Utilz ========== private def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = { @@ -45,6 +51,27 @@ object ExpandWhens extends Pass { case (k, WInvalid) => IsInvalid(NoInfo, k.e1) case (k, v) => Connect(NoInfo, k.e1, v) } + /** Combines Attaches + * @todo Preserve Info + */ + private def combineAttaches(attaches: Seq[Attach]): Seq[Attach] = { + // Helper type to add an ordering index to attached Expressions + case class AttachAcc(exprs: Seq[Expression], idx: Int) + // Map from every attached expression to its corresponding AttachAcc + // (many keys will point to same value) + val attachMap = mutable.HashMap.empty[WrappedExpression, AttachAcc] + for (Attach(_, exprs) <- attaches) { + val acc = exprs.map(attachMap.get(_)).flatten match { + case Seq() => // None of these expressions is present in the attachMap + AttachAcc(exprs, attachMap.size) + case accs => // At least one expression present in the attachMap + val sorted = accs sortBy (_.idx) + AttachAcc((sorted.map(_.exprs) :+ exprs).flatten.distinct, sorted.head.idx) + } + attachMap ++= acc.exprs.map(e => (we(e) -> acc)) + } + attachMap.values.toList.distinct.map(acc => Attach(NoInfo, acc.exprs)) + } // Searches nested scopes of defaults for lvalue // defaults uses mutable Map because we are searching LinkedHashMaps and conversion to immutable is VERY slow @tailrec @@ -65,12 +92,13 @@ object ExpandWhens extends Pass { // ------------ Pass ------------------- def run(c: Circuit): Circuit = { - def expandWhens(m: Module): (Netlist, Simlist, Statement) = { + def expandWhens(m: Module): (Netlist, Simlist, Seq[Attach], Statement) = { val namespace = Namespace(m) val simlist = new Simlist val nodes = new NodeMap + // Seq of attaches in order + lazy val attaches = mutable.ArrayBuffer.empty[Attach] - // defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow def expandWhens(netlist: Netlist, defaults: Defaults, p: Expression) @@ -90,7 +118,9 @@ object ExpandWhens extends Pass { case c: IsInvalid => netlist(c.expr) = WInvalid EmptyStmt - case c: Attach => c + case a: Attach => + attaches += a + EmptyStmt case sx: Conditionally => val conseqNetlist = new Netlist val altNetlist = new Netlist @@ -150,13 +180,15 @@ object ExpandWhens extends Pass { netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) => getFemaleRefs(name, tpe, to_gender(dir)) map (ref => we(ref) -> WVoid) }) - (netlist, simlist, expandWhens(netlist, Seq(netlist), one)(m.body)) + val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) + (netlist, simlist, attaches, bodyx) } val modulesx = c.modules map { case m: ExtModule => m case m: Module => - val (netlist, simlist, bodyx) = expandWhens(m) - val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++ simlist) + val (netlist, simlist, attaches, bodyx) = expandWhens(m) + val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++ + combineAttaches(attaches) ++ simlist) Module(m.info, m.name, m.ports, newBody) } Circuit(c.info, modulesx, c.main) |
