diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/ConstProp.scala | 28 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/AnnotationTests.scala | 7 |
2 files changed, 28 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index 8a2d6ec6..f2aa1a03 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -234,21 +234,38 @@ object ConstProp extends Pass { case _ => r } + // Two pass process + // 1. Propagate constants in expressions and forward propagate references + // 2. Propagate references again for backwards reference (Wires) + // TODO Replacing all wires with nodes makes the second pass unnecessary @tailrec private def constPropModule(m: Module): Module = { var nPropagated = 0L val nodeMap = collection.mutable.HashMap[String, Expression]() + def backPropExpr(expr: Expression): Expression = { + val old = expr map backPropExpr + val propagated = old match { + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case x => x + } + if (old ne propagated) { + nPropagated += 1 + } + propagated + } + def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr + def constPropExpression(e: Expression): Expression = { val old = e map constPropExpression val propagated = old match { case p: DoPrim => constPropPrim(p) case m: Mux => constPropMux(m) - case r: WRef if nodeMap contains r.name => constPropNodeRef(r, nodeMap(r.name)) + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) case x => x } - if (old ne propagated) - nPropagated += 1 propagated } @@ -256,12 +273,15 @@ object ConstProp extends Pass { val stmtx = s map constPropStmt map constPropExpression stmtx match { case x: DefNode => nodeMap(x.name) = x.value + case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => + val exprx = constPropExpression(pad(expr, wtpe)) + nodeMap(wname) = exprx case _ => } stmtx } - val res = Module(m.info, m.name, m.ports, constPropStmt(m.body)) + val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) if (nPropagated > 0) constPropModule(res) else res } diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index e3dd3dbd..3e93081e 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -272,7 +272,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { anno("n.a"), anno("n.b[0]"), anno("n.b[1]"), anno("r.a"), anno("r.b[0]"), anno("r.b[1]"), anno("write.a"), anno("write.b[0]"), anno("write.b[1]"), - dontTouch("Top.r") + dontTouch("Top.r"), dontTouch("Top.w") ) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations @@ -326,7 +326,8 @@ class AnnotationTests extends AnnotationSpec with Matchers { | out <= n | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r")) + val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"), + dontTouch("Top.w")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("in_a")) @@ -362,7 +363,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"), - dontTouch("Top.r")) + dontTouch("Top.r"), dontTouch("Top.w")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("in_b_0")) |
