From 5093da03083a37a0a7bdaf44f9867d7f7a0a5980 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Thu, 21 Apr 2022 20:20:47 -0700 Subject: Fix optimization of register with reset but invalid connection (#2520) Fixes #2516 Previously, reg r : UInt<8>, clock with : reset => (p, UInt<8>(3)) r is invalid would compile to: reg r : UInt<8>, clock r <= UInt<8>(0) now it compiles to: reg r : UInt<8>, clock wire r_1 : UInt<8> r_1 is invalid r <= mux(reset, UInt<8>(3), r_1) This is consistent with the behavior for a reset with an asynchronous reset.--- src/main/scala/firrtl/transforms/RemoveReset.scala | 12 ++++++++++ .../firrtlTests/transforms/RemoveResetSpec.scala | 28 +++++++++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index f1434ad2..7f2207af 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -52,6 +52,7 @@ object RemoveReset extends Transform with DependencyAPIMigration { val resets = mutable.HashMap.empty[String, Reset] val asyncResets = mutable.HashMap.empty[String, Reset] val invalids = computeInvalids(m) + lazy val namespace = Namespace(m) def onStmt(stmt: Statement): Statement = { stmt match { case reg @ DefRegister(_, name, _, _, reset, init) if isPreset(name) => @@ -93,6 +94,17 @@ object RemoveReset extends Transform with DependencyAPIMigration { // addUpdate(info, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), Seq.empty) val infox = MultiInfo(reset.info, reset.info, info) Connect(infox, ref, expr) + /* Synchronously reset register that has reset value but only an invalid connection */ + case IsInvalid(iinfo, ref @ WRef(rname, tpe, RegKind, _)) if resets.contains(rname) => + // We need to mux with the invalid value to be consistent with async reset registers + val dummyWire = DefWire(iinfo, namespace.newName(rname), tpe) + val wireRef = Reference(dummyWire).copy(flow = SourceFlow) + val invalid = IsInvalid(iinfo, wireRef) + // Now mux between the invalid wire and the reset value + val Reset(cond, init, info) = resets(rname) + val muxType = Utils.mux_type_and_widths(init, wireRef) + val connect = Connect(info, ref, Mux(cond, init, wireRef, muxType)) + Block(Seq(dummyWire, invalid, connect)) case other => other.map(onStmt) } } diff --git a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala index 1adeeed8..666320b7 100644 --- a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala +++ b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala @@ -8,7 +8,7 @@ import firrtl.testutils.FirrtlFlatSpec import firrtl.testutils.FirrtlCheckers._ import firrtl.{CircuitState, WRef} -import firrtl.ir.{Connect, DefRegister, Mux} +import firrtl.ir.{Connect, DefRegister, IsInvalid, Mux, UIntLiteral} import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlSourceAnnotation, FirrtlStage} class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { @@ -47,6 +47,32 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { outputState shouldNot containTree { case Connect(_, WRef("foo", _, _, _), Mux(_, _, _, _)) => true } } + it should "generate a reset mux for a sync reset register with an invalid connection" in { + Given("an 8-bit register 'foo' initialized to UInt(3) with an invalid connection") + val input = + """|circuit Example : + | module Example : + | input clock : Clock + | input rst : UInt<1> + | input in : UInt<8> + | output out : UInt<8> + | + | reg foo : UInt<8>, clock with : (reset => (rst, UInt(3))) + | foo is invalid + | out <= foo""".stripMargin + + val outputState = toLowFirrtl(input) + + Then("'foo' should not have a reset") + outputState should containTree { + case DefRegister(_, "foo", _, _, UIntLiteral(value, _), WRef("foo", _, _, _)) if value == 0 => true + } + And("'foo' is connected to a mux with its old reset value") + outputState should containTree { + case Connect(_, WRef("foo", _, _, _), Mux(_, UIntLiteral(value, _), _, _)) if value == 3 => true + } + } + it should "generate a reset mux for only the portion of an invalid aggregate that is reset" in { Given("aggregate register 'foo' with 2-bit field 'a' and 1-bit field 'b'") And("aggregate, invalid wire 'bar' with the same fields") -- cgit v1.2.3