aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSchuyler Eldridge2020-05-11 14:51:34 -0400
committerGitHub2020-05-11 14:51:34 -0400
commit706fbd7e36d7810fd07b4648d6d9ab8c9e98c598 (patch)
tree2b379714431a9069059fc526aa03ef11e6311802 /src
parent73c5020919c6113b73521138aa3b6ac7728a9dee (diff)
parent7227dba83c971e7353991a0f3ed7d6dac0a795d1 (diff)
Merge pull request #1558 from freechipsproject/constant-prop-reduction-ops-1343
Constant Prop Reduction Operations of Literals
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala57
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala104
2 files changed, 161 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 0b21df21..29410c7f 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -142,6 +142,45 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
}
}
+ /** Interface for describing a simplification of a reduction primitive op */
+ sealed trait SimplifyReductionOp {
+
+ /** The initial value used in the reduction */
+ def identityValue: Boolean
+
+ /** The reduction function of the primitive op expressed */
+ def reduce: (Boolean, Boolean) => Boolean
+
+ /** Utility to simplify a reduction op of a literal, parameterized by identityValue and reduce methods. This will
+ * return the identityValue in the event of reducing a zero-width literal.
+ */
+ private def simplifyLiteral(a: Literal): Literal = {
+
+ val w: BigInt = getWidth(a) match {
+ case IntWidth(b) => b
+ }
+
+ val v: Seq[Boolean] = s"%${w}s".format(a.value.toString(2)).map(_ == '1')
+
+ (BigInt(0) until w).zip(v).foldLeft(identityValue) {
+ case (acc, (_, x)) => reduce(acc, x)
+ } match {
+ case false => zero
+ case true => one
+ }
+ }
+
+ /** Reduce a reduction primitive op to a simpler expression if possible
+ * @param prim the primitive op to reduce
+ * @return a simplified expression or the original primitive op
+ */
+ def apply(prim: DoPrim): Expression = prim.args.head match {
+ case a: Literal => simplifyLiteral(a)
+ case _ => prim
+ }
+
+ }
+
object FoldADD extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match {
case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
@@ -328,6 +367,21 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
foldIfZeroedArg(foldIfOutsideRange(foldIfMatchingArgs(e)))
}
+ final object FoldANDR extends SimplifyReductionOp {
+ override def identityValue = true
+ override def reduce = (a: Boolean, b: Boolean) => a & b
+ }
+
+ final object FoldORR extends SimplifyReductionOp {
+ override def identityValue = false
+ override def reduce = (a: Boolean, b: Boolean) => a | b
+ }
+
+ final object FoldXORR extends SimplifyReductionOp {
+ override def identityValue = false
+ override def reduce = (a: Boolean, b: Boolean) => a ^ b
+ }
+
private def constPropPrim(e: DoPrim): Expression = e.op match {
case Shl => foldShiftLeft(e)
case Dshl => foldDynamicShiftLeft(e)
@@ -343,6 +397,9 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case Xor => FoldXOR(e)
case Eq => FoldEqual(e)
case Neq => FoldNotEqual(e)
+ case Andr => FoldANDR(e)
+ case Orr => FoldORR(e)
+ case Xorr => FoldXORR(e)
case (Lt | Leq | Gt | Geq) => foldComparison(e)
case Not => e.args.head match {
case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index ba952c50..32303949 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -1404,6 +1404,110 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
matchingArgs("gt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ )
}
+ behavior of "Reduction operators"
+
+ it should "optimize andr of a literal" in {
+ val input =
+ s"""|circuit Foo:
+ | module Foo:
+ | output _4b0: UInt<1>
+ | output _4b15: UInt<1>
+ | output _4b7: UInt<1>
+ | output _4b1: UInt<1>
+ | output _0b0: UInt<1>
+ | _4b0 <= andr(UInt<4>(0))
+ | _4b15 <= andr(UInt<4>(15))
+ | _4b7 <= andr(UInt<4>(7))
+ | _4b1 <= andr(UInt<4>(1))
+ | wire _0bI: UInt<0>
+ | _0bI is invalid
+ | _0b0 <= andr(_0bI)
+ |""".stripMargin
+ val check =
+ s"""|circuit Foo:
+ | module Foo:
+ | output _4b0: UInt<1>
+ | output _4b15: UInt<1>
+ | output _4b7: UInt<1>
+ | output _4b1: UInt<1>
+ | output _0b0: UInt<1>
+ | _4b0 <= UInt<1>(0)
+ | _4b15 <= UInt<1>(1)
+ | _4b7 <= UInt<1>(0)
+ | _4b1 <= UInt<1>(0)
+ | _0b0 <= UInt<1>(1)
+ |""".stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+ it should "optimize orr of a literal" in {
+ val input =
+ s"""|circuit Foo:
+ | module Foo:
+ | output _4b0: UInt<1>
+ | output _4b15: UInt<1>
+ | output _4b7: UInt<1>
+ | output _4b1: UInt<1>
+ | output _0b0: UInt<1>
+ | _4b0 <= orr(UInt<4>(0))
+ | _4b15 <= orr(UInt<4>(15))
+ | _4b7 <= orr(UInt<4>(7))
+ | _4b1 <= orr(UInt<4>(1))
+ | wire _0bI: UInt<0>
+ | _0bI is invalid
+ | _0b0 <= orr(_0bI)
+ |""".stripMargin
+ val check =
+ s"""|circuit Foo:
+ | module Foo:
+ | output _4b0: UInt<1>
+ | output _4b15: UInt<1>
+ | output _4b7: UInt<1>
+ | output _4b1: UInt<1>
+ | output _0b0: UInt<1>
+ | _4b0 <= UInt<1>(0)
+ | _4b15 <= UInt<1>(1)
+ | _4b7 <= UInt<1>(1)
+ | _4b1 <= UInt<1>(1)
+ | _0b0 <= UInt<1>(0)
+ |""".stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+ it should "optimize xorr of a literal" in {
+ val input =
+ s"""|circuit Foo:
+ | module Foo:
+ | output _4b0: UInt<1>
+ | output _4b15: UInt<1>
+ | output _4b7: UInt<1>
+ | output _4b1: UInt<1>
+ | output _0b0: UInt<1>
+ | _4b0 <= xorr(UInt<4>(0))
+ | _4b15 <= xorr(UInt<4>(15))
+ | _4b7 <= xorr(UInt<4>(7))
+ | _4b1 <= xorr(UInt<4>(1))
+ | wire _0bI: UInt<0>
+ | _0bI is invalid
+ | _0b0 <= xorr(_0bI)
+ |""".stripMargin
+ val check =
+ s"""|circuit Foo:
+ | module Foo:
+ | output _4b0: UInt<1>
+ | output _4b15: UInt<1>
+ | output _4b7: UInt<1>
+ | output _4b1: UInt<1>
+ | output _0b0: UInt<1>
+ | _4b0 <= UInt<1>(0)
+ | _4b15 <= UInt<1>(0)
+ | _4b7 <= UInt<1>(1)
+ | _4b1 <= UInt<1>(1)
+ | _0b0 <= UInt<1>(0)
+ |""".stripMargin
+ execute(input, check, Seq.empty)
+ }
+
}