diff options
| author | Jack Koenig | 2018-03-27 10:58:45 -0700 |
|---|---|---|
| committer | GitHub | 2018-03-27 10:58:45 -0700 |
| commit | 65454f5ff1a370d66202a073e18cdcd40180f051 (patch) | |
| tree | 0d24d5e3152af1cea51a4de3fc8bd1036ec964df | |
| parent | ae623fd24794bddc3ad8ab0849787fdf033af7b7 (diff) | |
Const prop improvement (#772)
Improve constant propagation of connections to references
[skip formal checks]
LEC fails on this PR because this PR actually changes the circuit. The
change is that it constant propagates some additional registers. This is
really just extending #621 to work on more registers that it was
supposed to be propagating anyway.
3 files changed, 36 insertions, 19 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 57b88890..8217a9bd 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -358,6 +358,7 @@ class ConstantPropagation extends Transform { def constPropStmt(s: Statement): Statement = { val stmtx = s map constPropStmt map constPropExpression(nodeMap, instMap, constSubOutputs) + // Record things that should be propagated stmtx match { case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value) case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => @@ -387,7 +388,13 @@ class ConstantPropagation extends Transform { portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty) case _ => } - stmtx + // Actually transform some statements + stmtx match { + // Propagate connections to references + case Connect(info, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouches.contains(rname) => + Connect(info, lhs, nodeMap(rname)) + case other => other + } } val modx = m.copy(body = backPropStmt(constPropStmt(m.body))) diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index c1e36b67..9fb8dfd9 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -180,8 +180,7 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers { | output out: {a: UInt<3>, b: UInt<3>[2]} | wire w: {a: UInt<3>, b: UInt<3>[2]} | w is invalid - | node n = mux(pred, in, w) - | out <= n + | out <= mux(pred, in, w) | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk | cmem mem: {a: UInt<3>, b: UInt<3>[2]}[8] | write mport write = mem[pred], clk @@ -191,7 +190,6 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers { anno("in.a"), anno("in.b[0]"), anno("in.b[1]"), anno("out.a"), anno("out.b[0]"), anno("out.b[1]"), anno("w.a"), anno("w.b[0]"), anno("w.b[1]"), - anno("n.a"), anno("n.b[0]"), anno("n.b[1]"), anno("r.a"), anno("r.b[0]"), anno("r.b[1]"), anno("write.a"), anno("write.b[0]"), anno("write.b[1]"), dontTouch("Top.r"), dontTouch("Top.w") @@ -222,9 +220,6 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers { resultAnno should contain (anno("w_a")) resultAnno should contain (anno("w_b_0")) resultAnno should contain (anno("w_b_1")) - resultAnno should contain (anno("n_a")) - resultAnno should contain (anno("n_b_0")) - resultAnno should contain (anno("n_b_1")) resultAnno should contain (anno("r_a")) resultAnno should contain (anno("r_b_0")) resultAnno should contain (anno("r_b_1")) @@ -244,11 +239,10 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers { | output out: {a: UInt<3>, b: UInt<3>[2]} | wire w: {a: UInt<3>, b: UInt<3>[2]} | w is invalid - | node n = mux(pred, in, w) - | out <= n + | out <= mux(pred, in, w) | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"), + val annos = Seq(anno("in"), anno("out"), anno("w"), anno("r"), dontTouch("Top.r"), dontTouch("Top.w")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, annos), Nil) val resultAnno = result.annotations.toSeq @@ -261,9 +255,6 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers { resultAnno should contain (anno("w_a")) resultAnno should contain (anno("w_b_0")) resultAnno should contain (anno("w_b_1")) - resultAnno should contain (anno("n_a")) - resultAnno should contain (anno("n_b_0")) - resultAnno should contain (anno("n_b_1")) resultAnno should contain (anno("r_a")) resultAnno should contain (anno("r_b_0")) resultAnno should contain (anno("r_b_1")) @@ -284,7 +275,7 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers { | out <= n | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"), + val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("r.b"), dontTouch("Top.r"), dontTouch("Top.w")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, annos), Nil) val resultAnno = result.annotations.toSeq @@ -294,8 +285,6 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers { resultAnno should contain (anno("out_b_1")) resultAnno should contain (anno("w_b_0")) resultAnno should contain (anno("w_b_1")) - resultAnno should contain (anno("n_b_0")) - resultAnno should contain (anno("n_b_1")) resultAnno should contain (anno("r_b_0")) resultAnno should contain (anno("r_b_1")) } diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index e143f853..079b4823 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -543,7 +543,7 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { output z : UInt<1> node _T_1 = and(x, y) node n = _T_1 - z <= n + z <= and(n, x) """ val check = """circuit Top : @@ -553,7 +553,7 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { output z : UInt<1> node n = and(x, y) node _T_1 = n - z <= n + z <= and(n, x) """ (parse(exec(input))) should be (parse(check)) } @@ -663,7 +663,7 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { wire hit : UInt<1> node _T_1 = or(x, y) node _T_2 = _T_1 - hit <= _T_1 + hit <= or(x, y) z <= hit """ (parse(exec(input))) should be (parse(check)) @@ -950,4 +950,25 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { | z <= UInt<8>("hb")""".stripMargin execute(input, check, Seq.empty) } + + "Connections to a node reference" should "be replaced with the rhs of that node" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<8> + | input b : UInt<8> + | input c : UInt<1> + | output z : UInt<8> + | node x = mux(c, a, b) + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<8> + | input b : UInt<8> + | input c : UInt<1> + | output z : UInt<8> + | z <= mux(c, a, b)""".stripMargin + execute(input, check, Seq.empty) + } } |
