aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack Koenig2017-06-22 17:51:39 -0700
committerJack Koenig2017-06-26 17:17:49 -0700
commitf8572ba6532359e8a0f1bc34f3eb8241a29129ab (patch)
tree4c3706c616a064fe01060b8a4c8d2e14b5544520
parent0fca90f951fa91c944c68bffef1c91f74d563028 (diff)
Add support for wires in ConstProp
This requires a quick second pass to back propagate constant wires but the QoR win is substantial. We also only need to count back propagations in determining whether to run ConstProp again which shaves off an iteration in the common case.
-rw-r--r--src/main/scala/firrtl/passes/ConstProp.scala28
-rw-r--r--src/test/scala/firrtlTests/AnnotationTests.scala7
2 files changed, 28 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala
index 8a2d6ec6..f2aa1a03 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/passes/ConstProp.scala
@@ -234,21 +234,38 @@ object ConstProp extends Pass {
case _ => r
}
+ // Two pass process
+ // 1. Propagate constants in expressions and forward propagate references
+ // 2. Propagate references again for backwards reference (Wires)
+ // TODO Replacing all wires with nodes makes the second pass unnecessary
@tailrec
private def constPropModule(m: Module): Module = {
var nPropagated = 0L
val nodeMap = collection.mutable.HashMap[String, Expression]()
+ def backPropExpr(expr: Expression): Expression = {
+ val old = expr map backPropExpr
+ val propagated = old match {
+ case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
+ constPropNodeRef(ref, nodeMap(rname))
+ case x => x
+ }
+ if (old ne propagated) {
+ nPropagated += 1
+ }
+ propagated
+ }
+ def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr
+
def constPropExpression(e: Expression): Expression = {
val old = e map constPropExpression
val propagated = old match {
case p: DoPrim => constPropPrim(p)
case m: Mux => constPropMux(m)
- case r: WRef if nodeMap contains r.name => constPropNodeRef(r, nodeMap(r.name))
+ case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
+ constPropNodeRef(ref, nodeMap(rname))
case x => x
}
- if (old ne propagated)
- nPropagated += 1
propagated
}
@@ -256,12 +273,15 @@ object ConstProp extends Pass {
val stmtx = s map constPropStmt map constPropExpression
stmtx match {
case x: DefNode => nodeMap(x.name) = x.value
+ case Connect(_, WRef(wname, wtpe, WireKind, _), expr) =>
+ val exprx = constPropExpression(pad(expr, wtpe))
+ nodeMap(wname) = exprx
case _ =>
}
stmtx
}
- val res = Module(m.info, m.name, m.ports, constPropStmt(m.body))
+ val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body)))
if (nPropagated > 0) constPropModule(res) else res
}
diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala
index e3dd3dbd..3e93081e 100644
--- a/src/test/scala/firrtlTests/AnnotationTests.scala
+++ b/src/test/scala/firrtlTests/AnnotationTests.scala
@@ -272,7 +272,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
anno("n.a"), anno("n.b[0]"), anno("n.b[1]"),
anno("r.a"), anno("r.b[0]"), anno("r.b[1]"),
anno("write.a"), anno("write.b[0]"), anno("write.b[1]"),
- dontTouch("Top.r")
+ dontTouch("Top.r"), dontTouch("Top.w")
)
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
@@ -326,7 +326,8 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| out <= n
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
- val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"))
+ val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"),
+ dontTouch("Top.w"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("in_a"))
@@ -362,7 +363,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"),
- dontTouch("Top.r"))
+ dontTouch("Top.r"), dontTouch("Top.w"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("in_b_0"))