aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/passes/ConstProp.scala28
1 files changed, 24 insertions, 4 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala
index 8a2d6ec6..f2aa1a03 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/passes/ConstProp.scala
@@ -234,21 +234,38 @@ object ConstProp extends Pass {
case _ => r
}
+ // Two pass process
+ // 1. Propagate constants in expressions and forward propagate references
+ // 2. Propagate references again for backwards reference (Wires)
+ // TODO Replacing all wires with nodes makes the second pass unnecessary
@tailrec
private def constPropModule(m: Module): Module = {
var nPropagated = 0L
val nodeMap = collection.mutable.HashMap[String, Expression]()
+ def backPropExpr(expr: Expression): Expression = {
+ val old = expr map backPropExpr
+ val propagated = old match {
+ case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
+ constPropNodeRef(ref, nodeMap(rname))
+ case x => x
+ }
+ if (old ne propagated) {
+ nPropagated += 1
+ }
+ propagated
+ }
+ def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr
+
def constPropExpression(e: Expression): Expression = {
val old = e map constPropExpression
val propagated = old match {
case p: DoPrim => constPropPrim(p)
case m: Mux => constPropMux(m)
- case r: WRef if nodeMap contains r.name => constPropNodeRef(r, nodeMap(r.name))
+ case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
+ constPropNodeRef(ref, nodeMap(rname))
case x => x
}
- if (old ne propagated)
- nPropagated += 1
propagated
}
@@ -256,12 +273,15 @@ object ConstProp extends Pass {
val stmtx = s map constPropStmt map constPropExpression
stmtx match {
case x: DefNode => nodeMap(x.name) = x.value
+ case Connect(_, WRef(wname, wtpe, WireKind, _), expr) =>
+ val exprx = constPropExpression(pad(expr, wtpe))
+ nodeMap(wname) = exprx
case _ =>
}
stmtx
}
- val res = Module(m.info, m.name, m.ports, constPropStmt(m.body))
+ val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body)))
if (nPropagated > 0) constPropModule(res) else res
}