diff options
| author | Jack Koenig | 2017-06-22 17:51:39 -0700 |
|---|---|---|
| committer | Jack Koenig | 2017-06-26 17:17:49 -0700 |
| commit | f8572ba6532359e8a0f1bc34f3eb8241a29129ab (patch) | |
| tree | 4c3706c616a064fe01060b8a4c8d2e14b5544520 | |
| parent | 0fca90f951fa91c944c68bffef1c91f74d563028 (diff) | |
Add support for wires in ConstProp
This requires a quick second pass to back propagate constant wires but
the QoR win is substantial. We also only need to count back propagations
in determining whether to run ConstProp again which shaves off an
iteration in the common case.
| -rw-r--r-- | src/main/scala/firrtl/passes/ConstProp.scala | 28 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/AnnotationTests.scala | 7 |
2 files changed, 28 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index 8a2d6ec6..f2aa1a03 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -234,21 +234,38 @@ object ConstProp extends Pass { case _ => r } + // Two pass process + // 1. Propagate constants in expressions and forward propagate references + // 2. Propagate references again for backwards reference (Wires) + // TODO Replacing all wires with nodes makes the second pass unnecessary @tailrec private def constPropModule(m: Module): Module = { var nPropagated = 0L val nodeMap = collection.mutable.HashMap[String, Expression]() + def backPropExpr(expr: Expression): Expression = { + val old = expr map backPropExpr + val propagated = old match { + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case x => x + } + if (old ne propagated) { + nPropagated += 1 + } + propagated + } + def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr + def constPropExpression(e: Expression): Expression = { val old = e map constPropExpression val propagated = old match { case p: DoPrim => constPropPrim(p) case m: Mux => constPropMux(m) - case r: WRef if nodeMap contains r.name => constPropNodeRef(r, nodeMap(r.name)) + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) case x => x } - if (old ne propagated) - nPropagated += 1 propagated } @@ -256,12 +273,15 @@ object ConstProp extends Pass { val stmtx = s map constPropStmt map constPropExpression stmtx match { case x: DefNode => nodeMap(x.name) = x.value + case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => + val exprx = constPropExpression(pad(expr, wtpe)) + nodeMap(wname) = exprx case _ => } stmtx } - val res = Module(m.info, m.name, m.ports, constPropStmt(m.body)) + val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) if (nPropagated > 0) constPropModule(res) else res } diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index e3dd3dbd..3e93081e 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -272,7 +272,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { anno("n.a"), anno("n.b[0]"), anno("n.b[1]"), anno("r.a"), anno("r.b[0]"), anno("r.b[1]"), anno("write.a"), anno("write.b[0]"), anno("write.b[1]"), - dontTouch("Top.r") + dontTouch("Top.r"), dontTouch("Top.w") ) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations @@ -326,7 +326,8 @@ class AnnotationTests extends AnnotationSpec with Matchers { | out <= n | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r")) + val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"), + dontTouch("Top.w")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("in_a")) @@ -362,7 +363,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"), - dontTouch("Top.r")) + dontTouch("Top.r"), dontTouch("Top.w")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("in_b_0")) |
