aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/ConstProp.scala28
-rw-r--r--src/test/scala/firrtlTests/AnnotationTests.scala7
2 files changed, 28 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala
index 8a2d6ec6..f2aa1a03 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/passes/ConstProp.scala
@@ -234,21 +234,38 @@ object ConstProp extends Pass {
case _ => r
}
+ // Two pass process
+ // 1. Propagate constants in expressions and forward propagate references
+ // 2. Propagate references again for backwards reference (Wires)
+ // TODO Replacing all wires with nodes makes the second pass unnecessary
@tailrec
private def constPropModule(m: Module): Module = {
var nPropagated = 0L
val nodeMap = collection.mutable.HashMap[String, Expression]()
+ def backPropExpr(expr: Expression): Expression = {
+ val old = expr map backPropExpr
+ val propagated = old match {
+ case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
+ constPropNodeRef(ref, nodeMap(rname))
+ case x => x
+ }
+ if (old ne propagated) {
+ nPropagated += 1
+ }
+ propagated
+ }
+ def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr
+
def constPropExpression(e: Expression): Expression = {
val old = e map constPropExpression
val propagated = old match {
case p: DoPrim => constPropPrim(p)
case m: Mux => constPropMux(m)
- case r: WRef if nodeMap contains r.name => constPropNodeRef(r, nodeMap(r.name))
+ case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
+ constPropNodeRef(ref, nodeMap(rname))
case x => x
}
- if (old ne propagated)
- nPropagated += 1
propagated
}
@@ -256,12 +273,15 @@ object ConstProp extends Pass {
val stmtx = s map constPropStmt map constPropExpression
stmtx match {
case x: DefNode => nodeMap(x.name) = x.value
+ case Connect(_, WRef(wname, wtpe, WireKind, _), expr) =>
+ val exprx = constPropExpression(pad(expr, wtpe))
+ nodeMap(wname) = exprx
case _ =>
}
stmtx
}
- val res = Module(m.info, m.name, m.ports, constPropStmt(m.body))
+ val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body)))
if (nPropagated > 0) constPropModule(res) else res
}
diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala
index e3dd3dbd..3e93081e 100644
--- a/src/test/scala/firrtlTests/AnnotationTests.scala
+++ b/src/test/scala/firrtlTests/AnnotationTests.scala
@@ -272,7 +272,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
anno("n.a"), anno("n.b[0]"), anno("n.b[1]"),
anno("r.a"), anno("r.b[0]"), anno("r.b[1]"),
anno("write.a"), anno("write.b[0]"), anno("write.b[1]"),
- dontTouch("Top.r")
+ dontTouch("Top.r"), dontTouch("Top.w")
)
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
@@ -326,7 +326,8 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| out <= n
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
- val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"))
+ val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"),
+ dontTouch("Top.w"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("in_a"))
@@ -362,7 +363,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"),
- dontTouch("Top.r"))
+ dontTouch("Top.r"), dontTouch("Top.w"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("in_b_0"))