aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/transforms/RemoveReset.scala12
-rw-r--r--src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala28
2 files changed, 39 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala
index f1434ad2..7f2207af 100644
--- a/src/main/scala/firrtl/transforms/RemoveReset.scala
+++ b/src/main/scala/firrtl/transforms/RemoveReset.scala
@@ -52,6 +52,7 @@ object RemoveReset extends Transform with DependencyAPIMigration {
val resets = mutable.HashMap.empty[String, Reset]
val asyncResets = mutable.HashMap.empty[String, Reset]
val invalids = computeInvalids(m)
+ lazy val namespace = Namespace(m)
def onStmt(stmt: Statement): Statement = {
stmt match {
case reg @ DefRegister(_, name, _, _, reset, init) if isPreset(name) =>
@@ -93,6 +94,17 @@ object RemoveReset extends Transform with DependencyAPIMigration {
// addUpdate(info, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), Seq.empty)
val infox = MultiInfo(reset.info, reset.info, info)
Connect(infox, ref, expr)
+ /* Synchronously reset register that has reset value but only an invalid connection */
+ case IsInvalid(iinfo, ref @ WRef(rname, tpe, RegKind, _)) if resets.contains(rname) =>
+ // We need to mux with the invalid value to be consistent with async reset registers
+ val dummyWire = DefWire(iinfo, namespace.newName(rname), tpe)
+ val wireRef = Reference(dummyWire).copy(flow = SourceFlow)
+ val invalid = IsInvalid(iinfo, wireRef)
+ // Now mux between the invalid wire and the reset value
+ val Reset(cond, init, info) = resets(rname)
+ val muxType = Utils.mux_type_and_widths(init, wireRef)
+ val connect = Connect(info, ref, Mux(cond, init, wireRef, muxType))
+ Block(Seq(dummyWire, invalid, connect))
case other => other.map(onStmt)
}
}
diff --git a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala
index 1adeeed8..666320b7 100644
--- a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala
+++ b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala
@@ -8,7 +8,7 @@ import firrtl.testutils.FirrtlFlatSpec
import firrtl.testutils.FirrtlCheckers._
import firrtl.{CircuitState, WRef}
-import firrtl.ir.{Connect, DefRegister, Mux}
+import firrtl.ir.{Connect, DefRegister, IsInvalid, Mux, UIntLiteral}
import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlSourceAnnotation, FirrtlStage}
class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen {
@@ -47,6 +47,32 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen {
outputState shouldNot containTree { case Connect(_, WRef("foo", _, _, _), Mux(_, _, _, _)) => true }
}
+ it should "generate a reset mux for a sync reset register with an invalid connection" in {
+ Given("an 8-bit register 'foo' initialized to UInt(3) with an invalid connection")
+ val input =
+ """|circuit Example :
+ | module Example :
+ | input clock : Clock
+ | input rst : UInt<1>
+ | input in : UInt<8>
+ | output out : UInt<8>
+ |
+ | reg foo : UInt<8>, clock with : (reset => (rst, UInt(3)))
+ | foo is invalid
+ | out <= foo""".stripMargin
+
+ val outputState = toLowFirrtl(input)
+
+ Then("'foo' should not have a reset")
+ outputState should containTree {
+ case DefRegister(_, "foo", _, _, UIntLiteral(value, _), WRef("foo", _, _, _)) if value == 0 => true
+ }
+ And("'foo' is connected to a mux with its old reset value")
+ outputState should containTree {
+ case Connect(_, WRef("foo", _, _, _), Mux(_, UIntLiteral(value, _), _, _)) if value == 3 => true
+ }
+ }
+
it should "generate a reset mux for only the portion of an invalid aggregate that is reset" in {
Given("aggregate register 'foo' with 2-bit field 'a' and 1-bit field 'b'")
And("aggregate, invalid wire 'bar' with the same fields")