aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala16
1 files changed, 12 insertions, 4 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index e02e2bf0..7c013b51 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -45,6 +45,7 @@ import annotation.tailrec
*/
object ExpandWhens extends Pass {
def name = "Expand Whens"
+ type NodeMap = collection.mutable.HashMap[MemoizedHash[Expression], String]
type Netlist = collection.mutable.LinkedHashMap[WrappedExpression, Expression]
type Simlist = collection.mutable.ArrayBuffer[Statement]
type Defaults = Seq[collection.mutable.Map[WrappedExpression, Expression]]
@@ -88,6 +89,7 @@ object ExpandWhens extends Pass {
def expandWhens(m: Module): (Netlist, Simlist, Statement) = {
val namespace = Namespace(m)
val simlist = new Simlist
+ val nodes = new NodeMap
// defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow
def expandWhens(netlist: Netlist,
@@ -133,10 +135,16 @@ object ExpandWhens extends Pass {
conseqNetlist getOrElse (lvalue, altNetlist(lvalue))
}
- val memoNode = DefNode(s.info, namespace.newTemp, res)
- val memoExpr = WRef(memoNode.name, res.tpe, NodeKind(), MALE)
- netlist(lvalue) = memoExpr
- memoNode
+ nodes get res match {
+ case Some(name) =>
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE)
+ EmptyStmt
+ case None =>
+ val name = namespace.newTemp
+ nodes(res) = name
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE)
+ DefNode(s.info, name, res)
+ }
}
Block(Seq(conseqStmt, altStmt) ++ memos)
case s: Print =>