diff options
| author | Schuyler Eldridge | 2020-05-11 14:51:34 -0400 |
|---|---|---|
| committer | GitHub | 2020-05-11 14:51:34 -0400 |
| commit | 706fbd7e36d7810fd07b4648d6d9ab8c9e98c598 (patch) | |
| tree | 2b379714431a9069059fc526aa03ef11e6311802 /src | |
| parent | 73c5020919c6113b73521138aa3b6ac7728a9dee (diff) | |
| parent | 7227dba83c971e7353991a0f3ed7d6dac0a795d1 (diff) | |
Merge pull request #1558 from freechipsproject/constant-prop-reduction-ops-1343
Constant Prop Reduction Operations of Literals
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 57 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 104 |
2 files changed, 161 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 0b21df21..29410c7f 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -142,6 +142,45 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } } + /** Interface for describing a simplification of a reduction primitive op */ + sealed trait SimplifyReductionOp { + + /** The initial value used in the reduction */ + def identityValue: Boolean + + /** The reduction function of the primitive op expressed */ + def reduce: (Boolean, Boolean) => Boolean + + /** Utility to simplify a reduction op of a literal, parameterized by identityValue and reduce methods. This will + * return the identityValue in the event of reducing a zero-width literal. + */ + private def simplifyLiteral(a: Literal): Literal = { + + val w: BigInt = getWidth(a) match { + case IntWidth(b) => b + } + + val v: Seq[Boolean] = s"%${w}s".format(a.value.toString(2)).map(_ == '1') + + (BigInt(0) until w).zip(v).foldLeft(identityValue) { + case (acc, (_, x)) => reduce(acc, x) + } match { + case false => zero + case true => one + } + } + + /** Reduce a reduction primitive op to a simpler expression if possible + * @param prim the primitive op to reduce + * @return a simplified expression or the original primitive op + */ + def apply(prim: DoPrim): Expression = prim.args.head match { + case a: Literal => simplifyLiteral(a) + case _ => prim + } + + } + object FoldADD extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match { case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) @@ -328,6 +367,21 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res foldIfZeroedArg(foldIfOutsideRange(foldIfMatchingArgs(e))) } + final object FoldANDR extends SimplifyReductionOp { + override def identityValue = true + override def reduce = (a: Boolean, b: Boolean) => a & b + } + + final object FoldORR extends SimplifyReductionOp { + override def identityValue = false + override def reduce = (a: Boolean, b: Boolean) => a | b + } + + final object FoldXORR extends SimplifyReductionOp { + override def identityValue = false + override def reduce = (a: Boolean, b: Boolean) => a ^ b + } + private def constPropPrim(e: DoPrim): Expression = e.op match { case Shl => foldShiftLeft(e) case Dshl => foldDynamicShiftLeft(e) @@ -343,6 +397,9 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case Xor => FoldXOR(e) case Eq => FoldEqual(e) case Neq => FoldNotEqual(e) + case Andr => FoldANDR(e) + case Orr => FoldORR(e) + case Xorr => FoldXORR(e) case (Lt | Leq | Gt | Geq) => foldComparison(e) case Not => e.args.head match { case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index ba952c50..32303949 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -1404,6 +1404,110 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { matchingArgs("gt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) } + behavior of "Reduction operators" + + it should "optimize andr of a literal" in { + val input = + s"""|circuit Foo: + | module Foo: + | output _4b0: UInt<1> + | output _4b15: UInt<1> + | output _4b7: UInt<1> + | output _4b1: UInt<1> + | output _0b0: UInt<1> + | _4b0 <= andr(UInt<4>(0)) + | _4b15 <= andr(UInt<4>(15)) + | _4b7 <= andr(UInt<4>(7)) + | _4b1 <= andr(UInt<4>(1)) + | wire _0bI: UInt<0> + | _0bI is invalid + | _0b0 <= andr(_0bI) + |""".stripMargin + val check = + s"""|circuit Foo: + | module Foo: + | output _4b0: UInt<1> + | output _4b15: UInt<1> + | output _4b7: UInt<1> + | output _4b1: UInt<1> + | output _0b0: UInt<1> + | _4b0 <= UInt<1>(0) + | _4b15 <= UInt<1>(1) + | _4b7 <= UInt<1>(0) + | _4b1 <= UInt<1>(0) + | _0b0 <= UInt<1>(1) + |""".stripMargin + execute(input, check, Seq.empty) + } + + it should "optimize orr of a literal" in { + val input = + s"""|circuit Foo: + | module Foo: + | output _4b0: UInt<1> + | output _4b15: UInt<1> + | output _4b7: UInt<1> + | output _4b1: UInt<1> + | output _0b0: UInt<1> + | _4b0 <= orr(UInt<4>(0)) + | _4b15 <= orr(UInt<4>(15)) + | _4b7 <= orr(UInt<4>(7)) + | _4b1 <= orr(UInt<4>(1)) + | wire _0bI: UInt<0> + | _0bI is invalid + | _0b0 <= orr(_0bI) + |""".stripMargin + val check = + s"""|circuit Foo: + | module Foo: + | output _4b0: UInt<1> + | output _4b15: UInt<1> + | output _4b7: UInt<1> + | output _4b1: UInt<1> + | output _0b0: UInt<1> + | _4b0 <= UInt<1>(0) + | _4b15 <= UInt<1>(1) + | _4b7 <= UInt<1>(1) + | _4b1 <= UInt<1>(1) + | _0b0 <= UInt<1>(0) + |""".stripMargin + execute(input, check, Seq.empty) + } + + it should "optimize xorr of a literal" in { + val input = + s"""|circuit Foo: + | module Foo: + | output _4b0: UInt<1> + | output _4b15: UInt<1> + | output _4b7: UInt<1> + | output _4b1: UInt<1> + | output _0b0: UInt<1> + | _4b0 <= xorr(UInt<4>(0)) + | _4b15 <= xorr(UInt<4>(15)) + | _4b7 <= xorr(UInt<4>(7)) + | _4b1 <= xorr(UInt<4>(1)) + | wire _0bI: UInt<0> + | _0bI is invalid + | _0b0 <= xorr(_0bI) + |""".stripMargin + val check = + s"""|circuit Foo: + | module Foo: + | output _4b0: UInt<1> + | output _4b15: UInt<1> + | output _4b7: UInt<1> + | output _4b1: UInt<1> + | output _0b0: UInt<1> + | _4b0 <= UInt<1>(0) + | _4b15 <= UInt<1>(0) + | _4b7 <= UInt<1>(1) + | _4b1 <= UInt<1>(1) + | _0b0 <= UInt<1>(0) + |""".stripMargin + execute(input, check, Seq.empty) + } + } |
