aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2018-03-27 10:58:45 -0700
committerGitHub2018-03-27 10:58:45 -0700
commit65454f5ff1a370d66202a073e18cdcd40180f051 (patch)
tree0d24d5e3152af1cea51a4de3fc8bd1036ec964df /src
parentae623fd24794bddc3ad8ab0849787fdf033af7b7 (diff)
Const prop improvement (#772)
Improve constant propagation of connections to references [skip formal checks] LEC fails on this PR because this PR actually changes the circuit. The change is that it constant propagates some additional registers. This is really just extending #621 to work on more registers that it was supposed to be propagating anyway.
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala9
-rw-r--r--src/test/scala/firrtlTests/AnnotationTests.scala19
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala27
3 files changed, 36 insertions, 19 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 57b88890..8217a9bd 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -358,6 +358,7 @@ class ConstantPropagation extends Transform {
def constPropStmt(s: Statement): Statement = {
val stmtx = s map constPropStmt map constPropExpression(nodeMap, instMap, constSubOutputs)
+ // Record things that should be propagated
stmtx match {
case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value)
case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) =>
@@ -387,7 +388,13 @@ class ConstantPropagation extends Transform {
portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty)
case _ =>
}
- stmtx
+ // Actually transform some statements
+ stmtx match {
+ // Propagate connections to references
+ case Connect(info, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouches.contains(rname) =>
+ Connect(info, lhs, nodeMap(rname))
+ case other => other
+ }
}
val modx = m.copy(body = backPropStmt(constPropStmt(m.body)))
diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala
index c1e36b67..9fb8dfd9 100644
--- a/src/test/scala/firrtlTests/AnnotationTests.scala
+++ b/src/test/scala/firrtlTests/AnnotationTests.scala
@@ -180,8 +180,7 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers {
| output out: {a: UInt<3>, b: UInt<3>[2]}
| wire w: {a: UInt<3>, b: UInt<3>[2]}
| w is invalid
- | node n = mux(pred, in, w)
- | out <= n
+ | out <= mux(pred, in, w)
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
| cmem mem: {a: UInt<3>, b: UInt<3>[2]}[8]
| write mport write = mem[pred], clk
@@ -191,7 +190,6 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers {
anno("in.a"), anno("in.b[0]"), anno("in.b[1]"),
anno("out.a"), anno("out.b[0]"), anno("out.b[1]"),
anno("w.a"), anno("w.b[0]"), anno("w.b[1]"),
- anno("n.a"), anno("n.b[0]"), anno("n.b[1]"),
anno("r.a"), anno("r.b[0]"), anno("r.b[1]"),
anno("write.a"), anno("write.b[0]"), anno("write.b[1]"),
dontTouch("Top.r"), dontTouch("Top.w")
@@ -222,9 +220,6 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers {
resultAnno should contain (anno("w_a"))
resultAnno should contain (anno("w_b_0"))
resultAnno should contain (anno("w_b_1"))
- resultAnno should contain (anno("n_a"))
- resultAnno should contain (anno("n_b_0"))
- resultAnno should contain (anno("n_b_1"))
resultAnno should contain (anno("r_a"))
resultAnno should contain (anno("r_b_0"))
resultAnno should contain (anno("r_b_1"))
@@ -244,11 +239,10 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers {
| output out: {a: UInt<3>, b: UInt<3>[2]}
| wire w: {a: UInt<3>, b: UInt<3>[2]}
| w is invalid
- | node n = mux(pred, in, w)
- | out <= n
+ | out <= mux(pred, in, w)
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
- val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"),
+ val annos = Seq(anno("in"), anno("out"), anno("w"), anno("r"), dontTouch("Top.r"),
dontTouch("Top.w"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, annos), Nil)
val resultAnno = result.annotations.toSeq
@@ -261,9 +255,6 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers {
resultAnno should contain (anno("w_a"))
resultAnno should contain (anno("w_b_0"))
resultAnno should contain (anno("w_b_1"))
- resultAnno should contain (anno("n_a"))
- resultAnno should contain (anno("n_b_0"))
- resultAnno should contain (anno("n_b_1"))
resultAnno should contain (anno("r_a"))
resultAnno should contain (anno("r_b_0"))
resultAnno should contain (anno("r_b_1"))
@@ -284,7 +275,7 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers {
| out <= n
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
- val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"),
+ val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("r.b"),
dontTouch("Top.r"), dontTouch("Top.w"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, annos), Nil)
val resultAnno = result.annotations.toSeq
@@ -294,8 +285,6 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers {
resultAnno should contain (anno("out_b_1"))
resultAnno should contain (anno("w_b_0"))
resultAnno should contain (anno("w_b_1"))
- resultAnno should contain (anno("n_b_0"))
- resultAnno should contain (anno("n_b_1"))
resultAnno should contain (anno("r_b_0"))
resultAnno should contain (anno("r_b_1"))
}
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index e143f853..079b4823 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -543,7 +543,7 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec {
output z : UInt<1>
node _T_1 = and(x, y)
node n = _T_1
- z <= n
+ z <= and(n, x)
"""
val check =
"""circuit Top :
@@ -553,7 +553,7 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec {
output z : UInt<1>
node n = and(x, y)
node _T_1 = n
- z <= n
+ z <= and(n, x)
"""
(parse(exec(input))) should be (parse(check))
}
@@ -663,7 +663,7 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec {
wire hit : UInt<1>
node _T_1 = or(x, y)
node _T_2 = _T_1
- hit <= _T_1
+ hit <= or(x, y)
z <= hit
"""
(parse(exec(input))) should be (parse(check))
@@ -950,4 +950,25 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
| z <= UInt<8>("hb")""".stripMargin
execute(input, check, Seq.empty)
}
+
+ "Connections to a node reference" should "be replaced with the rhs of that node" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<8>
+ | input b : UInt<8>
+ | input c : UInt<1>
+ | output z : UInt<8>
+ | node x = mux(c, a, b)
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<8>
+ | input b : UInt<8>
+ | input c : UInt<1>
+ | output z : UInt<8>
+ | z <= mux(c, a, b)""".stripMargin
+ execute(input, check, Seq.empty)
+ }
}