aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/ExpandWhens.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/ExpandWhens.scala')
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala58
1 files changed, 45 insertions, 13 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index 4d02e192..a2845f43 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -10,21 +10,27 @@ import firrtl.PrimOps._
import firrtl.WrappedExpression._
import annotation.tailrec
+import collection.mutable
+import collection.immutable.ListSet
/** Expand Whens
*
-* @note This pass does three things: remove last connect semantics,
-* remove conditional blocks, and eliminate concept of scoping.
+* This pass does the following things:
+* $ - Remove last connect semantics
+* $ - Remove conditional blocks
+* $ - Eliminate concept of scoping
+* $ - Consolidate attaches
+*
* @note Assumes bulk connects and isInvalids have been expanded
* @note Assumes all references are declared
*/
object ExpandWhens extends Pass {
def name = "Expand Whens"
- type NodeMap = collection.mutable.HashMap[MemoizedHash[Expression], String]
- type Netlist = collection.mutable.LinkedHashMap[WrappedExpression, Expression]
- type Simlist = collection.mutable.ArrayBuffer[Statement]
- type Attachlist = collection.mutable.ArrayBuffer[Statement]
- type Defaults = Seq[collection.mutable.Map[WrappedExpression, Expression]]
+ type NodeMap = mutable.HashMap[MemoizedHash[Expression], String]
+ type Netlist = mutable.LinkedHashMap[WrappedExpression, Expression]
+ type Simlist = mutable.ArrayBuffer[Statement]
+ // Defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow
+ type Defaults = Seq[mutable.Map[WrappedExpression, Expression]]
// ========== Expand When Utilz ==========
private def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = {
@@ -45,6 +51,27 @@ object ExpandWhens extends Pass {
case (k, WInvalid) => IsInvalid(NoInfo, k.e1)
case (k, v) => Connect(NoInfo, k.e1, v)
}
+ /** Combines Attaches
+ * @todo Preserve Info
+ */
+ private def combineAttaches(attaches: Seq[Attach]): Seq[Attach] = {
+ // Helper type to add an ordering index to attached Expressions
+ case class AttachAcc(exprs: Seq[Expression], idx: Int)
+ // Map from every attached expression to its corresponding AttachAcc
+ // (many keys will point to same value)
+ val attachMap = mutable.HashMap.empty[WrappedExpression, AttachAcc]
+ for (Attach(_, exprs) <- attaches) {
+ val acc = exprs.map(attachMap.get(_)).flatten match {
+ case Seq() => // None of these expressions is present in the attachMap
+ AttachAcc(exprs, attachMap.size)
+ case accs => // At least one expression present in the attachMap
+ val sorted = accs sortBy (_.idx)
+ AttachAcc((sorted.map(_.exprs) :+ exprs).flatten.distinct, sorted.head.idx)
+ }
+ attachMap ++= acc.exprs.map(e => (we(e) -> acc))
+ }
+ attachMap.values.toList.distinct.map(acc => Attach(NoInfo, acc.exprs))
+ }
// Searches nested scopes of defaults for lvalue
// defaults uses mutable Map because we are searching LinkedHashMaps and conversion to immutable is VERY slow
@tailrec
@@ -65,12 +92,13 @@ object ExpandWhens extends Pass {
// ------------ Pass -------------------
def run(c: Circuit): Circuit = {
- def expandWhens(m: Module): (Netlist, Simlist, Statement) = {
+ def expandWhens(m: Module): (Netlist, Simlist, Seq[Attach], Statement) = {
val namespace = Namespace(m)
val simlist = new Simlist
val nodes = new NodeMap
+ // Seq of attaches in order
+ lazy val attaches = mutable.ArrayBuffer.empty[Attach]
- // defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow
def expandWhens(netlist: Netlist,
defaults: Defaults,
p: Expression)
@@ -90,7 +118,9 @@ object ExpandWhens extends Pass {
case c: IsInvalid =>
netlist(c.expr) = WInvalid
EmptyStmt
- case c: Attach => c
+ case a: Attach =>
+ attaches += a
+ EmptyStmt
case sx: Conditionally =>
val conseqNetlist = new Netlist
val altNetlist = new Netlist
@@ -150,13 +180,15 @@ object ExpandWhens extends Pass {
netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) =>
getFemaleRefs(name, tpe, to_gender(dir)) map (ref => we(ref) -> WVoid)
})
- (netlist, simlist, expandWhens(netlist, Seq(netlist), one)(m.body))
+ val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body)
+ (netlist, simlist, attaches, bodyx)
}
val modulesx = c.modules map {
case m: ExtModule => m
case m: Module =>
- val (netlist, simlist, bodyx) = expandWhens(m)
- val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++ simlist)
+ val (netlist, simlist, attaches, bodyx) = expandWhens(m)
+ val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++
+ combineAttaches(attaches) ++ simlist)
Module(m.info, m.name, m.ports, newBody)
}
Circuit(c.info, modulesx, c.main)