diff options
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/passes/ConstProp.scala | 28 |
1 files changed, 24 insertions, 4 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index 8a2d6ec6..f2aa1a03 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -234,21 +234,38 @@ object ConstProp extends Pass { case _ => r } + // Two pass process + // 1. Propagate constants in expressions and forward propagate references + // 2. Propagate references again for backwards reference (Wires) + // TODO Replacing all wires with nodes makes the second pass unnecessary @tailrec private def constPropModule(m: Module): Module = { var nPropagated = 0L val nodeMap = collection.mutable.HashMap[String, Expression]() + def backPropExpr(expr: Expression): Expression = { + val old = expr map backPropExpr + val propagated = old match { + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case x => x + } + if (old ne propagated) { + nPropagated += 1 + } + propagated + } + def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr + def constPropExpression(e: Expression): Expression = { val old = e map constPropExpression val propagated = old match { case p: DoPrim => constPropPrim(p) case m: Mux => constPropMux(m) - case r: WRef if nodeMap contains r.name => constPropNodeRef(r, nodeMap(r.name)) + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) case x => x } - if (old ne propagated) - nPropagated += 1 propagated } @@ -256,12 +273,15 @@ object ConstProp extends Pass { val stmtx = s map constPropStmt map constPropExpression stmtx match { case x: DefNode => nodeMap(x.name) = x.value + case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => + val exprx = constPropExpression(pad(expr, wtpe)) + nodeMap(wname) = exprx case _ => } stmtx } - val res = Module(m.info, m.name, m.ports, constPropStmt(m.body)) + val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) if (nPropagated > 0) constPropModule(res) else res } |
