diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 41 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 55 |
2 files changed, 82 insertions, 14 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 8a273476..54338719 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -412,21 +412,34 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) => val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] constOutputs(pname) = paddedLit - // Const prop registers that are fed only a constant or a mux between and constant and the - // register itself + // Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns // This requires that reset has been made explicit - case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), expr) if !dontTouches.contains(lname) => expr match { - case lit: Literal => - nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ltpe)) - case Mux(_, tval: WRef, fval: Literal, _) if weq(lref, tval) => - nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(fval, ltpe)) - case Mux(_, tval: Literal, fval: WRef, _) if weq(lref, fval) => - nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(tval, ltpe)) - case WRef(`lname`, _,_,_) => // If a register is connected to itself, propagate zero - val zero = passes.RemoveValidIf.getGroundZero(ltpe) - nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(zero, ltpe)) - case _ => - } + case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches.contains(lname) => + /** Checks if an RHS expression e of a register assignment is convertible to a constant assignment. + * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of cases (1) and (2). + * In case (3), it also recursively checks that the two mux cases are convertible to constants and + * uses pattern matching on the returned options to check that they are convertible to the *same* constant. + * When encountering a node reference, it expands the node by to its RHS assignment and recurses. + * + * @return an option containing the literal or self-connect that e is convertible to, if any + */ + def regConstant(e: Expression): Option[Expression] = e match { + case lit: Literal => Some(pad(lit, ltpe)) + case WRef(regName, _, RegKind, _) if (regName == lname) => Some(e) + case WRef(nodeName, _, NodeKind, _) => nodeMap.get(nodeName).flatMap(regConstant(_)) + case Mux(_, tval, fval, _) => (regConstant(tval), regConstant(fval)) match { + case (Some(wr: WRef), Some(x)) if weq(lref, wr) => Some(x) // Mux(_, selfassign, <constRHSCandidate>) + case (Some(x), Some(wr: WRef)) if weq(lref, wr) => Some(x) // Mux(_, <constRHSCandidate>, selfassign) + case (x, y) if (x == y) => x // No-op mux + case _ => None // At least one case not constant-convertible + } + case _ => None + } + def cpExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(e) + regConstant(rhs).foreach { + case wr: WRef => nodeMap(lname) = cpExp(pad(passes.RemoveValidIf.getGroundZero(ltpe), ltpe)) + case e => nodeMap(lname) = cpExp(e) + } // Mark instance inputs connected to a constant case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) => val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 8a69fcaa..ee2540e0 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -1005,6 +1005,61 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input, check, Seq.empty) } + "Registers with constant reset and connection to the same constant" should "be replaced with that constant" in { + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cond : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | when cond : + | r <= UInt<4>("hb") + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cond : UInt<1> + | output z : UInt<8> + | z <= UInt<8>("hb")""".stripMargin + execute(input, check, Seq.empty) + } + + "A register with constant reset and all connection to either itself or the same constant" should "be replaced with that constant" in { + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cmd : UInt<3> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("h7"))) + | r <= r + | when eq(cmd, UInt<3>("h0")) : + | r <= UInt<3>("h7") + | else : + | when eq(cmd, UInt<3>("h1")) : + | r <= r + | else : + | when eq(cmd, UInt<3>("h2")) : + | r <= UInt<4>("h7") + | else : + | r <= r + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cmd : UInt<3> + | output z : UInt<8> + | z <= UInt<8>("h7")""".stripMargin + execute(input, check, Seq.empty) + } + "Registers with ONLY constant connection" should "be replaced with that constant" in { val input = """circuit Top : |
