aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2017-04-13 15:37:55 -0700
committerGitHub2017-04-13 15:37:55 -0700
commit1c42e87bae86992c3804bb438f7888838664cef7 (patch)
treeca6584ff7e969876d12c7243fa977d80df8b599c /src
parent7d0b48708b05aba6d840cc4a9d4ab00abe31929b (diff)
Speed up CSE by doing CSE on node expression before recording the node (#543)
This means CSE need only be run once to get same QOR and prevents pathological cases. Fixes #448
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala31
1 files changed, 11 insertions, 20 deletions
diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
index 0abdaa36..2925bfd9 100644
--- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
+++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
@@ -10,24 +10,14 @@ import firrtl.Mappers._
import annotation.tailrec
object CommonSubexpressionElimination extends Pass {
- private def cseOnce(s: Statement): (Statement, Long) = {
- var nEliminated = 0L
+ private def cse(s: Statement): Statement = {
val expressions = collection.mutable.HashMap[MemoizedHash[Expression], String]()
val nodes = collection.mutable.HashMap[String, Expression]()
- def recordNodes(s: Statement): Statement = s match {
- case x: DefNode =>
- nodes(x.name) = x.value
- expressions.getOrElseUpdate(x.value, x.name)
- x
- case _ => s map recordNodes
- }
-
def eliminateNodeRef(e: Expression): Expression = e match {
case WRef(name, tpe, kind, gender) => nodes get name match {
case Some(expression) => expressions get expression match {
case Some(cseName) if cseName != name =>
- nEliminated += 1
WRef(cseName, tpe, kind, gender)
case _ => e
}
@@ -36,16 +26,17 @@ object CommonSubexpressionElimination extends Pass {
case _ => e map eliminateNodeRef
}
- def eliminateNodeRefs(s: Statement): Statement = s map eliminateNodeRefs map eliminateNodeRef
-
- recordNodes(s)
- (eliminateNodeRefs(s), nEliminated)
- }
+ def eliminateNodeRefs(s: Statement): Statement = {
+ s map eliminateNodeRef match {
+ case x: DefNode =>
+ nodes(x.name) = x.value
+ expressions.getOrElseUpdate(x.value, x.name)
+ x
+ case other => other map eliminateNodeRefs
+ }
+ }
- @tailrec
- private def cse(s: Statement): Statement = {
- val (res, n) = cseOnce(s)
- if (n > 0) cse(res) else res
+ eliminateNodeRefs(s)
}
def run(c: Circuit): Circuit = {