diff options
| author | Albert Magyar | 2020-02-13 10:57:12 -0700 |
|---|---|---|
| committer | GitHub | 2020-02-13 10:57:12 -0700 |
| commit | d0791c567aceb7ef7d78e6563b3322a9969895c9 (patch) | |
| tree | 8b1d307fe1db309fd39f603157c8798b6b584315 | |
| parent | 8ef2ab50411642f9eaa2c1aac7147658fbab8368 (diff) | |
| parent | 555d1e4397f9e750b186f4c07ef3172b7ee39c0d (diff) | |
Merge pull request #1361 from freechipsproject/const-prop-eq
Constant prop binary PrimOps with matching arguments
| -rw-r--r-- | .travis.yml | 6 | ||||
| -rw-r--r-- | regress/Ops.fir | 54 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 49 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 44 |
4 files changed, 150 insertions, 3 deletions
diff --git a/.travis.yml b/.travis.yml index d1d6cd64..2065cb56 100644 --- a/.travis.yml +++ b/.travis.yml @@ -77,5 +77,11 @@ jobs: - "travis_wait 30 sleep 1800 &" - ./.run_formal_checks.sh ICache - stage: test + name: "Formal equivalence: Ops" + script: + - yosys -V + - "travis_wait 30 sleep 1800 &" + - ./.run_formal_checks.sh Ops + - stage: test script: - benchmark/scripts/benchmark_cold_compile.py -N 2 --designs regress/ICache.fir --versions HEAD diff --git a/regress/Ops.fir b/regress/Ops.fir new file mode 100644 index 00000000..51cffad5 --- /dev/null +++ b/regress/Ops.fir @@ -0,0 +1,54 @@ +circuit Ops: + module Ops: + input sel: UInt<4> + input is: SInt<8> + input iu: UInt<8> + output os: SInt<14> + output ou: UInt<13> + output obool: UInt<1> + + os <= SInt(0) + ou <= UInt(0) + obool <= UInt(0) + + when eq(sel, UInt(0)): + os <= add(is, is) + ou <= add(iu, iu) + else: + when eq(sel, UInt(1)): + os <= sub(is, is) + ou <= sub(iu, iu) + else: + when eq(sel, UInt(2)): + os <= mux(eq(is, SInt(0)), SInt(1), div(is, is)) + ou <= mux(eq(iu, UInt(0)), UInt(1), div(iu, iu)) + else: + when eq(sel, UInt(3)): + os <= rem(is, is) + ou <= rem(iu, iu) + else: + when eq(sel, UInt(4)): + ou <= add(and(is, is), and(iu, iu)) + else: + when eq(sel, UInt(5)): + ou <= add(or(is, is), or(iu, iu)) + else: + when eq(sel, UInt(4)): + ou <= add(xor(is, is), xor(iu, iu)) + else: + when eq(sel, UInt(5)): + ou <= add(eq(is, is), eq(iu, iu)) + else: + when eq(sel, UInt(4)): + ou <= add(neq(is, is), neq(iu, iu)) + else: + when eq(sel, UInt(5)): + ou <= add(geq(is, is), geq(iu, iu)) + else: + when eq(sel, UInt(4)): + ou <= add(leq(is, is), leq(iu, iu)) + else: + when eq(sel, UInt(5)): + ou <= add(gt(is, is), gt(iu, iu)) + else: + ou <= add(lt(is, is), lt(iu, iu)) diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index f450f6a6..201c3325 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -18,6 +18,11 @@ import annotation.tailrec import collection.mutable object ConstantPropagation { + private def litOfType(value: BigInt, t: Type): Literal = t match { + case UIntType(w) => UIntLiteral(value, w) + case SIntType(w) => SIntLiteral(value, w) + } + private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) /** Pads e to the width of t */ @@ -99,14 +104,22 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DontTouchAnnotation]) - trait FoldCommutativeOp { + sealed trait SimplifyBinaryOp { + def matchingArgsValue(e: DoPrim, arg: Expression): Expression + def apply(e: DoPrim): Expression = { + if (e.args.head == e.args(1)) matchingArgsValue(e, e.args.head) else e + } + } + + sealed trait FoldCommutativeOp extends SimplifyBinaryOp { def fold(c1: Literal, c2: Literal): Expression def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression - def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { + override def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { case (lhs: Literal, rhs: Literal) => fold(lhs, rhs) case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) + case (lhs, rhs) if (lhs == rhs) => matchingArgsValue(e, lhs) case _ => e } } @@ -121,6 +134,19 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case SIntLiteral(v, w) if v == BigInt(0) => rhs case _ => e } + def matchingArgsValue(e: DoPrim, arg: Expression) = e + } + + object SimplifySUB extends SimplifyBinaryOp { + def matchingArgsValue(e: DoPrim, arg: Expression) = litOfType(0, e.tpe) + } + + object SimplifyDIV extends SimplifyBinaryOp { + def matchingArgsValue(e: DoPrim, arg: Expression) = litOfType(1, e.tpe) + } + + object SimplifyREM extends SimplifyBinaryOp { + def matchingArgsValue(e: DoPrim, arg: Expression) = litOfType(0, e.tpe) } object FoldAND extends FoldCommutativeOp { @@ -131,6 +157,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs case _ => e } + def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe) } object FoldOR extends FoldCommutativeOp { @@ -141,6 +168,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs case _ => e } + def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe) } object FoldXOR extends FoldCommutativeOp { @@ -150,6 +178,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) case _ => e } + def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0, getWidth(arg.tpe)) } object FoldEqual extends FoldCommutativeOp { @@ -159,6 +188,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe) case _ => e } + def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(1) } object FoldNotEqual extends FoldCommutativeOp { @@ -168,6 +198,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe) case _ => e } + def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0) } private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match { @@ -267,7 +298,16 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case ex => ex } } - foldIfZeroedArg(foldIfOutsideRange(e)) + + def foldIfMatchingArgs(x: Expression) = x match { + case DoPrim(op, Seq(a, b), _, _) if (a == b) => op match { + case (Lt | Gt) => zero + case (Leq | Geq) => one + case _ => x + } + case _ => x + } + foldIfZeroedArg(foldIfOutsideRange(foldIfMatchingArgs(e))) } private def constPropPrim(e: DoPrim): Expression = e.op match { @@ -277,6 +317,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case Dshr => foldDynamicShiftRight(e) case Cat => foldConcat(e) case Add => FoldADD(e) + case Sub => SimplifySUB(e) + case Div => SimplifyDIV(e) + case Rem => SimplifyREM(e) case And => FoldAND(e) case Or => FoldOR(e) case Xor => FoldXOR(e) diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index ef52507f..cc7a5e32 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -1393,6 +1393,50 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { """.stripMargin execute(input, check, Seq.empty) } + + private def matchingArgs(op: String, iType: String, oType: String, result: String): Unit = { + val input = + s"""circuit Top : + | module Top : + | input i : ${iType} + | output o : ${oType} + | o <= ${op}(i, i) + """.stripMargin + val check = + s"""circuit Top : + | module Top : + | input i : ${iType} + | output o : ${oType} + | o <= ${result} + """.stripMargin + execute(input, check, Seq.empty) + } + + it should "optimize some binary operations when arguments match" in { + // Signedness matters + matchingArgs("sub", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) + matchingArgs("sub", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """ ) + matchingArgs("div", "UInt<8>", "UInt<8>", """ UInt<8>("h1") """ ) + matchingArgs("div", "SInt<8>", "SInt<8>", """ SInt<8>("h1") """ ) + matchingArgs("rem", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) + matchingArgs("rem", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """ ) + matchingArgs("and", "UInt<8>", "UInt<8>", """ i """ ) + matchingArgs("and", "SInt<8>", "UInt<8>", """ asUInt(i) """ ) + // Signedness doesn't matter + matchingArgs("or", "UInt<8>", "UInt<8>", """ i """ ) + matchingArgs("or", "SInt<8>", "UInt<8>", """ asUInt(i) """ ) + matchingArgs("xor", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) + matchingArgs("xor", "SInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) + // Always true + matchingArgs("eq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) + matchingArgs("leq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) + matchingArgs("geq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) + // Never true + matchingArgs("neq", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) + matchingArgs("lt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) + matchingArgs("gt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) + } + } |
