diff options
| author | Albert Magyar | 2020-03-11 15:14:45 -0700 |
|---|---|---|
| committer | GitHub | 2020-03-11 15:14:45 -0700 |
| commit | 3726fba89bb70f424ac8be4ad2d4b300c471d7e8 (patch) | |
| tree | 4e921e7e809fd3071f2a1213c5dbc45af8eb3a87 | |
| parent | 026c18dd76d4e2121c7f6c582d15e4d5a3ab842b (diff) | |
Don't const-prop a register's self-init (#1441)
* Fixes #1214
Co-authored-by: Jack Koenig <koenig@sifive.com>
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 29 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 29 |
2 files changed, 45 insertions, 13 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index c11bc44d..18577147 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -501,17 +501,24 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { propagated } - def backPropStmt(stmt: Statement): Statement = stmt map backPropExpr match { - case decl: IsDeclaration if swapMap.contains(decl.name) => - val newName = swapMap(decl.name) - nPropagated += 1 - decl match { - case node: DefNode => node.copy(name = newName) - case wire: DefWire => wire.copy(name = newName) - case reg: DefRegister => reg.copy(name = newName) - case other => throwInternalError() - } - case other => other map backPropStmt + def backPropStmt(stmt: Statement): Statement = stmt match { + case reg: DefRegister if (WrappedExpression.weq(reg.init, WRef(reg))) => + // Self-init reset is an idiom for "no reset," and must be handled separately + swapMap.get(reg.name) + .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName))) + .getOrElse(reg) + case s => s map backPropExpr match { + case decl: IsDeclaration if swapMap.contains(decl.name) => + val newName = swapMap(decl.name) + nPropagated += 1 + decl match { + case node: DefNode => node.copy(name = newName) + case wire: DefWire => wire.copy(name = newName) + case reg: DefRegister => reg.copy(name = newName) + case other => throwInternalError() + } + case other => other map backPropStmt + } } // When propagating a reference, check if we want to keep the name that would be deleted diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 189809a6..3296b13b 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -5,6 +5,7 @@ package firrtlTests import firrtl._ import firrtl.passes._ import firrtl.transforms._ +import firrtl.annotations.Annotation class ConstantPropagationSpec extends FirrtlFlatSpec { val transforms = Seq( @@ -14,8 +15,8 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { ResolveFlows, new InferWidths, new ConstantPropagation) - protected def exec(input: String) = { - transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { + protected def exec(input: String, annos: Seq[Annotation] = Nil) = { + transforms.foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) => t.runTransform(c) }.circuit.serialize } @@ -751,6 +752,30 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { (parse(exec(input))) should be(parse(check)) } + "ConstProp" should "NOT touch self-inits" in { + val input = + """circuit Top : + | module Top : + | input clk : Clock + | input rst : UInt<1> + | output z : UInt<4> + | reg selfinit : UInt<1>, clk with : (reset => (UInt<1>(0), selfinit)) + | selfinit <= UInt<1>(0) + | z <= mux(UInt(1), UInt<2>(0), UInt<4>(0)) + """.stripMargin + val check = + """circuit Top : + | module Top : + | input clk : Clock + | input rst : UInt<1> + | output z : UInt<4> + | reg selfinit : UInt<1>, clk with : (reset => (UInt<1>(0), selfinit)) + | selfinit <= UInt<1>(0) + | z <= UInt<4>(0) + """.stripMargin + (parse(exec(input, Seq(NoDCEAnnotation)))) should be(parse(check)) + } + def castCheck(tpe: String, cast: String): Unit = { val input = s"""circuit Top : |
