aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlbert Magyar2020-02-13 10:57:12 -0700
committerGitHub2020-02-13 10:57:12 -0700
commitd0791c567aceb7ef7d78e6563b3322a9969895c9 (patch)
tree8b1d307fe1db309fd39f603157c8798b6b584315
parent8ef2ab50411642f9eaa2c1aac7147658fbab8368 (diff)
parent555d1e4397f9e750b186f4c07ef3172b7ee39c0d (diff)
Merge pull request #1361 from freechipsproject/const-prop-eq
Constant prop binary PrimOps with matching arguments
-rw-r--r--.travis.yml6
-rw-r--r--regress/Ops.fir54
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala49
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala44
4 files changed, 150 insertions, 3 deletions
diff --git a/.travis.yml b/.travis.yml
index d1d6cd64..2065cb56 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -77,5 +77,11 @@ jobs:
- "travis_wait 30 sleep 1800 &"
- ./.run_formal_checks.sh ICache
- stage: test
+ name: "Formal equivalence: Ops"
+ script:
+ - yosys -V
+ - "travis_wait 30 sleep 1800 &"
+ - ./.run_formal_checks.sh Ops
+ - stage: test
script:
- benchmark/scripts/benchmark_cold_compile.py -N 2 --designs regress/ICache.fir --versions HEAD
diff --git a/regress/Ops.fir b/regress/Ops.fir
new file mode 100644
index 00000000..51cffad5
--- /dev/null
+++ b/regress/Ops.fir
@@ -0,0 +1,54 @@
+circuit Ops:
+ module Ops:
+ input sel: UInt<4>
+ input is: SInt<8>
+ input iu: UInt<8>
+ output os: SInt<14>
+ output ou: UInt<13>
+ output obool: UInt<1>
+
+ os <= SInt(0)
+ ou <= UInt(0)
+ obool <= UInt(0)
+
+ when eq(sel, UInt(0)):
+ os <= add(is, is)
+ ou <= add(iu, iu)
+ else:
+ when eq(sel, UInt(1)):
+ os <= sub(is, is)
+ ou <= sub(iu, iu)
+ else:
+ when eq(sel, UInt(2)):
+ os <= mux(eq(is, SInt(0)), SInt(1), div(is, is))
+ ou <= mux(eq(iu, UInt(0)), UInt(1), div(iu, iu))
+ else:
+ when eq(sel, UInt(3)):
+ os <= rem(is, is)
+ ou <= rem(iu, iu)
+ else:
+ when eq(sel, UInt(4)):
+ ou <= add(and(is, is), and(iu, iu))
+ else:
+ when eq(sel, UInt(5)):
+ ou <= add(or(is, is), or(iu, iu))
+ else:
+ when eq(sel, UInt(4)):
+ ou <= add(xor(is, is), xor(iu, iu))
+ else:
+ when eq(sel, UInt(5)):
+ ou <= add(eq(is, is), eq(iu, iu))
+ else:
+ when eq(sel, UInt(4)):
+ ou <= add(neq(is, is), neq(iu, iu))
+ else:
+ when eq(sel, UInt(5)):
+ ou <= add(geq(is, is), geq(iu, iu))
+ else:
+ when eq(sel, UInt(4)):
+ ou <= add(leq(is, is), leq(iu, iu))
+ else:
+ when eq(sel, UInt(5)):
+ ou <= add(gt(is, is), gt(iu, iu))
+ else:
+ ou <= add(lt(is, is), lt(iu, iu))
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index f450f6a6..201c3325 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -18,6 +18,11 @@ import annotation.tailrec
import collection.mutable
object ConstantPropagation {
+ private def litOfType(value: BigInt, t: Type): Literal = t match {
+ case UIntType(w) => UIntLiteral(value, w)
+ case SIntType(w) => SIntLiteral(value, w)
+ }
+
private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t)
/** Pads e to the width of t */
@@ -99,14 +104,22 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DontTouchAnnotation])
- trait FoldCommutativeOp {
+ sealed trait SimplifyBinaryOp {
+ def matchingArgsValue(e: DoPrim, arg: Expression): Expression
+ def apply(e: DoPrim): Expression = {
+ if (e.args.head == e.args(1)) matchingArgsValue(e, e.args.head) else e
+ }
+ }
+
+ sealed trait FoldCommutativeOp extends SimplifyBinaryOp {
def fold(c1: Literal, c2: Literal): Expression
def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression
- def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match {
+ override def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match {
case (lhs: Literal, rhs: Literal) => fold(lhs, rhs)
case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe)
case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe)
+ case (lhs, rhs) if (lhs == rhs) => matchingArgsValue(e, lhs)
case _ => e
}
}
@@ -121,6 +134,19 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case SIntLiteral(v, w) if v == BigInt(0) => rhs
case _ => e
}
+ def matchingArgsValue(e: DoPrim, arg: Expression) = e
+ }
+
+ object SimplifySUB extends SimplifyBinaryOp {
+ def matchingArgsValue(e: DoPrim, arg: Expression) = litOfType(0, e.tpe)
+ }
+
+ object SimplifyDIV extends SimplifyBinaryOp {
+ def matchingArgsValue(e: DoPrim, arg: Expression) = litOfType(1, e.tpe)
+ }
+
+ object SimplifyREM extends SimplifyBinaryOp {
+ def matchingArgsValue(e: DoPrim, arg: Expression) = litOfType(0, e.tpe)
}
object FoldAND extends FoldCommutativeOp {
@@ -131,6 +157,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs
case _ => e
}
+ def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe)
}
object FoldOR extends FoldCommutativeOp {
@@ -141,6 +168,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs
case _ => e
}
+ def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe)
}
object FoldXOR extends FoldCommutativeOp {
@@ -150,6 +178,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
case _ => e
}
+ def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0, getWidth(arg.tpe))
}
object FoldEqual extends FoldCommutativeOp {
@@ -159,6 +188,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
+ def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(1)
}
object FoldNotEqual extends FoldCommutativeOp {
@@ -168,6 +198,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
+ def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0)
}
private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match {
@@ -267,7 +298,16 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case ex => ex
}
}
- foldIfZeroedArg(foldIfOutsideRange(e))
+
+ def foldIfMatchingArgs(x: Expression) = x match {
+ case DoPrim(op, Seq(a, b), _, _) if (a == b) => op match {
+ case (Lt | Gt) => zero
+ case (Leq | Geq) => one
+ case _ => x
+ }
+ case _ => x
+ }
+ foldIfZeroedArg(foldIfOutsideRange(foldIfMatchingArgs(e)))
}
private def constPropPrim(e: DoPrim): Expression = e.op match {
@@ -277,6 +317,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case Dshr => foldDynamicShiftRight(e)
case Cat => foldConcat(e)
case Add => FoldADD(e)
+ case Sub => SimplifySUB(e)
+ case Div => SimplifyDIV(e)
+ case Rem => SimplifyREM(e)
case And => FoldAND(e)
case Or => FoldOR(e)
case Xor => FoldXOR(e)
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index ef52507f..cc7a5e32 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -1393,6 +1393,50 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
""".stripMargin
execute(input, check, Seq.empty)
}
+
+ private def matchingArgs(op: String, iType: String, oType: String, result: String): Unit = {
+ val input =
+ s"""circuit Top :
+ | module Top :
+ | input i : ${iType}
+ | output o : ${oType}
+ | o <= ${op}(i, i)
+ """.stripMargin
+ val check =
+ s"""circuit Top :
+ | module Top :
+ | input i : ${iType}
+ | output o : ${oType}
+ | o <= ${result}
+ """.stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+ it should "optimize some binary operations when arguments match" in {
+ // Signedness matters
+ matchingArgs("sub", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ )
+ matchingArgs("sub", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """ )
+ matchingArgs("div", "UInt<8>", "UInt<8>", """ UInt<8>("h1") """ )
+ matchingArgs("div", "SInt<8>", "SInt<8>", """ SInt<8>("h1") """ )
+ matchingArgs("rem", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ )
+ matchingArgs("rem", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """ )
+ matchingArgs("and", "UInt<8>", "UInt<8>", """ i """ )
+ matchingArgs("and", "SInt<8>", "UInt<8>", """ asUInt(i) """ )
+ // Signedness doesn't matter
+ matchingArgs("or", "UInt<8>", "UInt<8>", """ i """ )
+ matchingArgs("or", "SInt<8>", "UInt<8>", """ asUInt(i) """ )
+ matchingArgs("xor", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ )
+ matchingArgs("xor", "SInt<8>", "UInt<8>", """ UInt<8>("h0") """ )
+ // Always true
+ matchingArgs("eq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ )
+ matchingArgs("leq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ )
+ matchingArgs("geq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ )
+ // Never true
+ matchingArgs("neq", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ )
+ matchingArgs("lt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ )
+ matchingArgs("gt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ )
+ }
+
}