aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala')
-rw-r--r--src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala13
1 files changed, 7 insertions, 6 deletions
diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
index 717d95e8..7d4c96b2 100644
--- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
+++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
@@ -28,6 +28,7 @@ MODIFICATIONS.
package firrtl.passes
import firrtl._
+import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
@@ -36,12 +37,12 @@ import annotation.tailrec
object CommonSubexpressionElimination extends Pass {
def name = "Common Subexpression Elimination"
- private def cseOnce(s: Stmt): (Stmt, Long) = {
+ private def cseOnce(s: Statement): (Statement, Long) = {
var nEliminated = 0L
val expressions = collection.mutable.HashMap[MemoizedHash[Expression], String]()
val nodes = collection.mutable.HashMap[String, Expression]()
- def recordNodes(s: Stmt): Stmt = s match {
+ def recordNodes(s: Statement): Statement = s match {
case x: DefNode =>
nodes(x.name) = x.value
expressions.getOrElseUpdate(x.value, x.name)
@@ -62,22 +63,22 @@ object CommonSubexpressionElimination extends Pass {
case _ => e map eliminateNodeRef
}
- def eliminateNodeRefs(s: Stmt): Stmt = s map eliminateNodeRefs map eliminateNodeRef
+ def eliminateNodeRefs(s: Statement): Statement = s map eliminateNodeRefs map eliminateNodeRef
recordNodes(s)
(eliminateNodeRefs(s), nEliminated)
}
@tailrec
- private def cse(s: Stmt): Stmt = {
+ private def cse(s: Statement): Statement = {
val (res, n) = cseOnce(s)
if (n > 0) cse(res) else res
}
def run(c: Circuit): Circuit = {
val modulesx = c.modules.map {
- case m: ExModule => m
- case m: InModule => InModule(m.info, m.name, m.ports, cse(m.body))
+ case m: ExtModule => m
+ case m: Module => Module(m.info, m.name, m.ports, cse(m.body))
}
Circuit(c.info, modulesx, c.main)
}