aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala41
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala55
2 files changed, 82 insertions, 14 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 8a273476..54338719 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -412,21 +412,34 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) =>
val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
constOutputs(pname) = paddedLit
- // Const prop registers that are fed only a constant or a mux between and constant and the
- // register itself
+ // Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns
// This requires that reset has been made explicit
- case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), expr) if !dontTouches.contains(lname) => expr match {
- case lit: Literal =>
- nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ltpe))
- case Mux(_, tval: WRef, fval: Literal, _) if weq(lref, tval) =>
- nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(fval, ltpe))
- case Mux(_, tval: Literal, fval: WRef, _) if weq(lref, fval) =>
- nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(tval, ltpe))
- case WRef(`lname`, _,_,_) => // If a register is connected to itself, propagate zero
- val zero = passes.RemoveValidIf.getGroundZero(ltpe)
- nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(zero, ltpe))
- case _ =>
- }
+ case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches.contains(lname) =>
+ /** Checks if an RHS expression e of a register assignment is convertible to a constant assignment.
+ * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of cases (1) and (2).
+ * In case (3), it also recursively checks that the two mux cases are convertible to constants and
+ * uses pattern matching on the returned options to check that they are convertible to the *same* constant.
+ * When encountering a node reference, it expands the node by to its RHS assignment and recurses.
+ *
+ * @return an option containing the literal or self-connect that e is convertible to, if any
+ */
+ def regConstant(e: Expression): Option[Expression] = e match {
+ case lit: Literal => Some(pad(lit, ltpe))
+ case WRef(regName, _, RegKind, _) if (regName == lname) => Some(e)
+ case WRef(nodeName, _, NodeKind, _) => nodeMap.get(nodeName).flatMap(regConstant(_))
+ case Mux(_, tval, fval, _) => (regConstant(tval), regConstant(fval)) match {
+ case (Some(wr: WRef), Some(x)) if weq(lref, wr) => Some(x) // Mux(_, selfassign, <constRHSCandidate>)
+ case (Some(x), Some(wr: WRef)) if weq(lref, wr) => Some(x) // Mux(_, <constRHSCandidate>, selfassign)
+ case (x, y) if (x == y) => x // No-op mux
+ case _ => None // At least one case not constant-convertible
+ }
+ case _ => None
+ }
+ def cpExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(e)
+ regConstant(rhs).foreach {
+ case wr: WRef => nodeMap(lname) = cpExp(pad(passes.RemoveValidIf.getGroundZero(ltpe), ltpe))
+ case e => nodeMap(lname) = cpExp(e)
+ }
// Mark instance inputs connected to a constant
case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) =>
val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 8a69fcaa..ee2540e0 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -1005,6 +1005,61 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
execute(input, check, Seq.empty)
}
+ "Registers with constant reset and connection to the same constant" should "be replaced with that constant" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clock : Clock
+ | input reset : UInt<1>
+ | input cond : UInt<1>
+ | output z : UInt<8>
+ | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb")))
+ | when cond :
+ | r <= UInt<4>("hb")
+ | z <= r""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clock : Clock
+ | input reset : UInt<1>
+ | input cond : UInt<1>
+ | output z : UInt<8>
+ | z <= UInt<8>("hb")""".stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+ "A register with constant reset and all connection to either itself or the same constant" should "be replaced with that constant" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clock : Clock
+ | input reset : UInt<1>
+ | input cmd : UInt<3>
+ | output z : UInt<8>
+ | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("h7")))
+ | r <= r
+ | when eq(cmd, UInt<3>("h0")) :
+ | r <= UInt<3>("h7")
+ | else :
+ | when eq(cmd, UInt<3>("h1")) :
+ | r <= r
+ | else :
+ | when eq(cmd, UInt<3>("h2")) :
+ | r <= UInt<4>("h7")
+ | else :
+ | r <= r
+ | z <= r""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clock : Clock
+ | input reset : UInt<1>
+ | input cmd : UInt<3>
+ | output z : UInt<8>
+ | z <= UInt<8>("h7")""".stripMargin
+ execute(input, check, Seq.empty)
+ }
+
"Registers with ONLY constant connection" should "be replaced with that constant" in {
val input =
"""circuit Top :