From 864a3978cf94f336187831773dfc2c9f9ea064c8 Mon Sep 17 00:00:00 2001 From: Donggyu Kim Date: Fri, 26 Aug 2016 02:48:55 -0700 Subject: memoize nodes in ExpandWhens --- src/main/scala/firrtl/passes/ExpandWhens.scala | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index e02e2bf0..7c013b51 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -45,6 +45,7 @@ import annotation.tailrec */ object ExpandWhens extends Pass { def name = "Expand Whens" + type NodeMap = collection.mutable.HashMap[MemoizedHash[Expression], String] type Netlist = collection.mutable.LinkedHashMap[WrappedExpression, Expression] type Simlist = collection.mutable.ArrayBuffer[Statement] type Defaults = Seq[collection.mutable.Map[WrappedExpression, Expression]] @@ -88,6 +89,7 @@ object ExpandWhens extends Pass { def expandWhens(m: Module): (Netlist, Simlist, Statement) = { val namespace = Namespace(m) val simlist = new Simlist + val nodes = new NodeMap // defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow def expandWhens(netlist: Netlist, @@ -133,10 +135,16 @@ object ExpandWhens extends Pass { conseqNetlist getOrElse (lvalue, altNetlist(lvalue)) } - val memoNode = DefNode(s.info, namespace.newTemp, res) - val memoExpr = WRef(memoNode.name, res.tpe, NodeKind(), MALE) - netlist(lvalue) = memoExpr - memoNode + nodes get res match { + case Some(name) => + netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE) + EmptyStmt + case None => + val name = namespace.newTemp + nodes(res) = name + netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE) + DefNode(s.info, name, res) + } } Block(Seq(conseqStmt, altStmt) ++ memos) case s: Print => -- cgit v1.2.3