diff options
| -rw-r--r-- | src/main/scala/firrtl/transforms/RemoveReset.scala | 24 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala | 115 |
2 files changed, 138 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index 0b8b907d..ed1baf7d 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -5,8 +5,10 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ +import firrtl.WrappedExpression.we -import scala.collection.mutable +import scala.collection.{immutable, mutable} /** Remove Synchronous Reset * @@ -18,10 +20,30 @@ class RemoveReset extends Transform { private case class Reset(cond: Expression, value: Expression) + /** Return an immutable set of all invalid expressions in a module + * @param m a module + */ + private def computeInvalids(m: DefModule): immutable.Set[WrappedExpression] = { + val invalids = mutable.HashSet.empty[WrappedExpression] + + def onStmt(s: Statement): Unit = s match { + case IsInvalid(_, expr) => invalids += we(expr) + case Connect(_, lhs, rhs) if invalids.contains(we(rhs)) => invalids += we(lhs) + case other => other.foreach(onStmt) + } + + m.foreach(onStmt) + invalids.toSet + } + private def onModule(m: DefModule): DefModule = { val resets = mutable.HashMap.empty[String, Reset] + val invalids = computeInvalids(m) def onStmt(stmt: Statement): Statement = { stmt match { + /* A register is initialized to an invalid expression */ + case reg @ DefRegister(_, _, _, _, _, init) if invalids.contains(we(init)) => + reg.copy(reset = Utils.zero, init = WRef(reg)) case reg @ DefRegister(_, rname, _, _, reset, init) if reset != Utils.zero && reset.tpe != AsyncResetType => // Add register reset to map diff --git a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala new file mode 100644 index 00000000..b9d92a6a --- /dev/null +++ b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala @@ -0,0 +1,115 @@ +// See LICENSE for license details. + +package firrtlTests.transforms + +import org.scalatest.GivenWhenThen + +import firrtlTests.FirrtlFlatSpec +import firrtlTests.FirrtlCheckers._ + +import firrtl.{CircuitState, WRef} +import firrtl.ir.{Connect, Mux} +import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlSourceAnnotation, FirrtlStage} + +class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { + + private def toLowFirrtl(string: String): CircuitState = { + When("the circuit is compiled to low FIRRTL") + (new FirrtlStage) + .execute(Array("-X", "low"), Seq(FirrtlSourceAnnotation(string))) + .collectFirst{ case FirrtlCircuitAnnotation(a) => a } + .map(a => firrtl.CircuitState(a, firrtl.UnknownForm)) + .get + } + + behavior of "RemoveReset" + + it should "not generate a reset mux for an invalid init" in { + Given("a 1-bit register 'foo' initialized to invalid, 1-bit wire 'bar'") + val input = + """|circuit Example : + | module Example : + | input clock : Clock + | input rst : UInt<1> + | input in : UInt<1> + | output out : UInt<1> + | + | wire bar : UInt<1> + | bar is invalid + | + | reg foo : UInt<1>, clock with : (reset => (rst, bar)) + | foo <= in + | out <= foo""".stripMargin + + val outputState = toLowFirrtl(input) + + Then("'foo' is NOT connected to a reset mux") + outputState shouldNot containTree { case Connect(_, WRef("foo",_,_,_), Mux(_,_,_,_)) => true } + } + + it should "generate a reset mux for only the portion of an invalid aggregate that is reset" in { + Given("aggregate register 'foo' with 2-bit field 'a' and 1-bit field 'b'") + And("aggregate, invalid wire 'bar' with the same fields") + And("'foo' is initialized to 'bar'") + And("'bar.a[1]' connected to zero") + val input = + """|circuit Example : + | module Example : + | input clock : Clock + | input rst : UInt<1> + | input in : {a : UInt<1>[2], b : UInt<1>} + | output out : {a : UInt<1>[2], b : UInt<1>} + | + | wire bar : {a : UInt<1>[2], b : UInt<1>} + | bar is invalid + | bar.a[1] <= UInt<1>(0) + | + | reg foo : {a : UInt<1>[2], b : UInt<1>}, clock with : (reset => (rst, bar)) + | foo <= in + | out <= foo""".stripMargin + + val outputState = toLowFirrtl(input) + + Then("foo.a[0] is NOT connected to a reset mux") + outputState shouldNot containTree { case Connect(_, WRef("foo_a_0",_,_,_), Mux(_,_,_,_)) => true } + And("foo.a[1] is connected to a reset mux") + outputState should containTree { case Connect(_, WRef("foo_a_1",_,_,_), Mux(_,_,_,_)) => true } + And("foo.b is NOT connected to a reset mux") + outputState shouldNot containTree { case Connect(_, WRef("foo_b",_,_,_), Mux(_,_,_,_)) => true } + } + + it should "propagate invalidations across connects" in { + Given("aggregate register 'foo' with 1-bit field 'a' and 1-bit field 'b'") + And("aggregate, invalid wires 'bar' and 'baz' with the same fields") + And("'foo' is initialized to 'baz'") + And("'bar.a' is connected to zero") + And("'baz' is connected to 'bar'") + val input = + """|circuit Example : + | module Example : + | input clock : Clock + | input rst : UInt<1> + | input in : { a : UInt<1>, b : UInt<1> } + | output out : { a : UInt<1>, b : UInt<1> } + | + | wire bar : { a : UInt<1>, b : UInt<1> } + | bar is invalid + | bar.a <= UInt<1>(0) + | + | wire baz : { a : UInt<1>, b : UInt<1> } + | baz is invalid + | baz <= bar + | + | reg foo : { a : UInt<1>, b : UInt<1> }, clock with : (reset => (rst, baz)) + | foo <= in + | out <= foo""".stripMargin + + val outputState = toLowFirrtl(input) + + Then("'foo.a' is connected to a reset mux") + outputState should containTree { case Connect(_, WRef("foo_a",_,_,_), Mux(_,_,_,_)) => true } + And("'foo.b' is NOT connected to a reset mux") + outputState shouldNot containTree { case Connect(_, WRef("foo_b",_,_,_), Mux(_,_,_,_)) => true } + } + +} |
