diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index 717d95e8..7d4c96b2 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -28,6 +28,7 @@ MODIFICATIONS. package firrtl.passes import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ @@ -36,12 +37,12 @@ import annotation.tailrec object CommonSubexpressionElimination extends Pass { def name = "Common Subexpression Elimination" - private def cseOnce(s: Stmt): (Stmt, Long) = { + private def cseOnce(s: Statement): (Statement, Long) = { var nEliminated = 0L val expressions = collection.mutable.HashMap[MemoizedHash[Expression], String]() val nodes = collection.mutable.HashMap[String, Expression]() - def recordNodes(s: Stmt): Stmt = s match { + def recordNodes(s: Statement): Statement = s match { case x: DefNode => nodes(x.name) = x.value expressions.getOrElseUpdate(x.value, x.name) @@ -62,22 +63,22 @@ object CommonSubexpressionElimination extends Pass { case _ => e map eliminateNodeRef } - def eliminateNodeRefs(s: Stmt): Stmt = s map eliminateNodeRefs map eliminateNodeRef + def eliminateNodeRefs(s: Statement): Statement = s map eliminateNodeRefs map eliminateNodeRef recordNodes(s) (eliminateNodeRefs(s), nEliminated) } @tailrec - private def cse(s: Stmt): Stmt = { + private def cse(s: Statement): Statement = { val (res, n) = cseOnce(s) if (n > 0) cse(res) else res } def run(c: Circuit): Circuit = { val modulesx = c.modules.map { - case m: ExModule => m - case m: InModule => InModule(m.info, m.name, m.ports, cse(m.body)) + case m: ExtModule => m + case m: Module => Module(m.info, m.name, m.ports, cse(m.body)) } Circuit(c.info, modulesx, c.main) } |
