From 1c42e87bae86992c3804bb438f7888838664cef7 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Thu, 13 Apr 2017 15:37:55 -0700 Subject: Speed up CSE by doing CSE on node expression before recording the node (#543) This means CSE need only be run once to get same QOR and prevents pathological cases. Fixes #448 --- .../passes/CommonSubexpressionElimination.scala | 31 ++++++++-------------- 1 file changed, 11 insertions(+), 20 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index 0abdaa36..2925bfd9 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -10,24 +10,14 @@ import firrtl.Mappers._ import annotation.tailrec object CommonSubexpressionElimination extends Pass { - private def cseOnce(s: Statement): (Statement, Long) = { - var nEliminated = 0L + private def cse(s: Statement): Statement = { val expressions = collection.mutable.HashMap[MemoizedHash[Expression], String]() val nodes = collection.mutable.HashMap[String, Expression]() - def recordNodes(s: Statement): Statement = s match { - case x: DefNode => - nodes(x.name) = x.value - expressions.getOrElseUpdate(x.value, x.name) - x - case _ => s map recordNodes - } - def eliminateNodeRef(e: Expression): Expression = e match { case WRef(name, tpe, kind, gender) => nodes get name match { case Some(expression) => expressions get expression match { case Some(cseName) if cseName != name => - nEliminated += 1 WRef(cseName, tpe, kind, gender) case _ => e } @@ -36,16 +26,17 @@ object CommonSubexpressionElimination extends Pass { case _ => e map eliminateNodeRef } - def eliminateNodeRefs(s: Statement): Statement = s map eliminateNodeRefs map eliminateNodeRef - - recordNodes(s) - (eliminateNodeRefs(s), nEliminated) - } + def eliminateNodeRefs(s: Statement): Statement = { + s map eliminateNodeRef match { + case x: DefNode => + nodes(x.name) = x.value + expressions.getOrElseUpdate(x.value, x.name) + x + case other => other map eliminateNodeRefs + } + } - @tailrec - private def cse(s: Statement): Statement = { - val (res, n) = cseOnce(s) - if (n > 0) cse(res) else res + eliminateNodeRefs(s) } def run(c: Circuit): Circuit = { -- cgit v1.2.3