aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/transforms/RemoveReset.scala24
-rw-r--r--src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala115
2 files changed, 138 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala
index 0b8b907d..ed1baf7d 100644
--- a/src/main/scala/firrtl/transforms/RemoveReset.scala
+++ b/src/main/scala/firrtl/transforms/RemoveReset.scala
@@ -5,8 +5,10 @@ package transforms
import firrtl.ir._
import firrtl.Mappers._
+import firrtl.traversals.Foreachers._
+import firrtl.WrappedExpression.we
-import scala.collection.mutable
+import scala.collection.{immutable, mutable}
/** Remove Synchronous Reset
*
@@ -18,10 +20,30 @@ class RemoveReset extends Transform {
private case class Reset(cond: Expression, value: Expression)
+ /** Return an immutable set of all invalid expressions in a module
+ * @param m a module
+ */
+ private def computeInvalids(m: DefModule): immutable.Set[WrappedExpression] = {
+ val invalids = mutable.HashSet.empty[WrappedExpression]
+
+ def onStmt(s: Statement): Unit = s match {
+ case IsInvalid(_, expr) => invalids += we(expr)
+ case Connect(_, lhs, rhs) if invalids.contains(we(rhs)) => invalids += we(lhs)
+ case other => other.foreach(onStmt)
+ }
+
+ m.foreach(onStmt)
+ invalids.toSet
+ }
+
private def onModule(m: DefModule): DefModule = {
val resets = mutable.HashMap.empty[String, Reset]
+ val invalids = computeInvalids(m)
def onStmt(stmt: Statement): Statement = {
stmt match {
+ /* A register is initialized to an invalid expression */
+ case reg @ DefRegister(_, _, _, _, _, init) if invalids.contains(we(init)) =>
+ reg.copy(reset = Utils.zero, init = WRef(reg))
case reg @ DefRegister(_, rname, _, _, reset, init)
if reset != Utils.zero && reset.tpe != AsyncResetType =>
// Add register reset to map
diff --git a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala
new file mode 100644
index 00000000..b9d92a6a
--- /dev/null
+++ b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala
@@ -0,0 +1,115 @@
+// See LICENSE for license details.
+
+package firrtlTests.transforms
+
+import org.scalatest.GivenWhenThen
+
+import firrtlTests.FirrtlFlatSpec
+import firrtlTests.FirrtlCheckers._
+
+import firrtl.{CircuitState, WRef}
+import firrtl.ir.{Connect, Mux}
+import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlSourceAnnotation, FirrtlStage}
+
+class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen {
+
+ private def toLowFirrtl(string: String): CircuitState = {
+ When("the circuit is compiled to low FIRRTL")
+ (new FirrtlStage)
+ .execute(Array("-X", "low"), Seq(FirrtlSourceAnnotation(string)))
+ .collectFirst{ case FirrtlCircuitAnnotation(a) => a }
+ .map(a => firrtl.CircuitState(a, firrtl.UnknownForm))
+ .get
+ }
+
+ behavior of "RemoveReset"
+
+ it should "not generate a reset mux for an invalid init" in {
+ Given("a 1-bit register 'foo' initialized to invalid, 1-bit wire 'bar'")
+ val input =
+ """|circuit Example :
+ | module Example :
+ | input clock : Clock
+ | input rst : UInt<1>
+ | input in : UInt<1>
+ | output out : UInt<1>
+ |
+ | wire bar : UInt<1>
+ | bar is invalid
+ |
+ | reg foo : UInt<1>, clock with : (reset => (rst, bar))
+ | foo <= in
+ | out <= foo""".stripMargin
+
+ val outputState = toLowFirrtl(input)
+
+ Then("'foo' is NOT connected to a reset mux")
+ outputState shouldNot containTree { case Connect(_, WRef("foo",_,_,_), Mux(_,_,_,_)) => true }
+ }
+
+ it should "generate a reset mux for only the portion of an invalid aggregate that is reset" in {
+ Given("aggregate register 'foo' with 2-bit field 'a' and 1-bit field 'b'")
+ And("aggregate, invalid wire 'bar' with the same fields")
+ And("'foo' is initialized to 'bar'")
+ And("'bar.a[1]' connected to zero")
+ val input =
+ """|circuit Example :
+ | module Example :
+ | input clock : Clock
+ | input rst : UInt<1>
+ | input in : {a : UInt<1>[2], b : UInt<1>}
+ | output out : {a : UInt<1>[2], b : UInt<1>}
+ |
+ | wire bar : {a : UInt<1>[2], b : UInt<1>}
+ | bar is invalid
+ | bar.a[1] <= UInt<1>(0)
+ |
+ | reg foo : {a : UInt<1>[2], b : UInt<1>}, clock with : (reset => (rst, bar))
+ | foo <= in
+ | out <= foo""".stripMargin
+
+ val outputState = toLowFirrtl(input)
+
+ Then("foo.a[0] is NOT connected to a reset mux")
+ outputState shouldNot containTree { case Connect(_, WRef("foo_a_0",_,_,_), Mux(_,_,_,_)) => true }
+ And("foo.a[1] is connected to a reset mux")
+ outputState should containTree { case Connect(_, WRef("foo_a_1",_,_,_), Mux(_,_,_,_)) => true }
+ And("foo.b is NOT connected to a reset mux")
+ outputState shouldNot containTree { case Connect(_, WRef("foo_b",_,_,_), Mux(_,_,_,_)) => true }
+ }
+
+ it should "propagate invalidations across connects" in {
+ Given("aggregate register 'foo' with 1-bit field 'a' and 1-bit field 'b'")
+ And("aggregate, invalid wires 'bar' and 'baz' with the same fields")
+ And("'foo' is initialized to 'baz'")
+ And("'bar.a' is connected to zero")
+ And("'baz' is connected to 'bar'")
+ val input =
+ """|circuit Example :
+ | module Example :
+ | input clock : Clock
+ | input rst : UInt<1>
+ | input in : { a : UInt<1>, b : UInt<1> }
+ | output out : { a : UInt<1>, b : UInt<1> }
+ |
+ | wire bar : { a : UInt<1>, b : UInt<1> }
+ | bar is invalid
+ | bar.a <= UInt<1>(0)
+ |
+ | wire baz : { a : UInt<1>, b : UInt<1> }
+ | baz is invalid
+ | baz <= bar
+ |
+ | reg foo : { a : UInt<1>, b : UInt<1> }, clock with : (reset => (rst, baz))
+ | foo <= in
+ | out <= foo""".stripMargin
+
+ val outputState = toLowFirrtl(input)
+
+ Then("'foo.a' is connected to a reset mux")
+ outputState should containTree { case Connect(_, WRef("foo_a",_,_,_), Mux(_,_,_,_)) => true }
+ And("'foo.b' is NOT connected to a reset mux")
+ outputState shouldNot containTree { case Connect(_, WRef("foo_b",_,_,_), Mux(_,_,_,_)) => true }
+ }
+
+}