aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAndrew Waterman2016-04-07 14:07:39 -0700
committerAndrew Waterman2016-04-07 14:07:39 -0700
commit39974611ab6b70d9e86b4e6030bf31b6b80c4582 (patch)
treea05ce9ce0a22ed1693344ed39b1ab57cab9e903f /src
parentecc5c3d0934b11a9b727390853f84996c13dbb42 (diff)
Add basic constant propagation for logical operators
This is deliberately incomplete because I wanted to get feedback before plowing ahead. These passes handle constant propagation for bitwise and equality operators on UInt only, usually only when the widths match.
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala73
1 files changed, 73 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 72b96fb5..5c4fbaf6 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -1166,6 +1166,70 @@ object Legalize extends Pass {
object ConstProp extends Pass {
def name = "Constant Propogation"
+ trait FoldLogicalOp {
+ def fold(c1: UIntValue, c2: UIntValue): UIntValue
+ def simplify(e: Expression, lhs: UIntValue, rhs: Expression): Expression
+
+ def apply(e: DoPrim): Expression = (e.args(0), e.args(1)) match {
+ case (lhs: UIntValue, rhs: UIntValue) => fold(lhs, rhs)
+ case (lhs: UIntValue, rhs) => simplify(e, lhs, rhs)
+ case (lhs, rhs: UIntValue) => simplify(e, rhs, lhs)
+ case _ => e
+ }
+ }
+
+ object FoldAND extends FoldLogicalOp {
+ def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value & c2.value, c1.width max c2.width)
+ def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match {
+ case IntWidth(w) if long_BANG(tpe(rhs)) == w =>
+ if (lhs.value == 0) lhs // and(x, 0) => 0
+ else if (lhs.value == (BigInt(1) << w.toInt) - 1) rhs // and(x, 1) => x
+ else e
+ case _ => e
+ }
+ }
+
+ object FoldOR extends FoldLogicalOp {
+ def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value | c2.value, c1.width max c2.width)
+ def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match {
+ case IntWidth(w) if long_BANG(tpe(rhs)) == w =>
+ if (lhs.value == 0) rhs // or(x, 0) => x
+ else if (lhs.value == (BigInt(1) << w.toInt) - 1) lhs // or(x, 1) => 1
+ else e
+ case _ => e
+ }
+ }
+
+ object FoldXOR extends FoldLogicalOp {
+ def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value ^ c2.value, c1.width max c2.width)
+ def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match {
+ case IntWidth(w) if long_BANG(tpe(rhs)) == w =>
+ if (lhs.value == 0) rhs // xor(x, 0) => x
+ else e
+ case _ => e
+ }
+ }
+
+ object FoldEqual extends FoldLogicalOp {
+ def fold(c1: UIntValue, c2: UIntValue) = UIntValue(if (c1.value == c2.value) 1 else 0, IntWidth(1))
+ def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match {
+ case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == 1 =>
+ if (lhs.value == 1) rhs // eq(x, 1) => x
+ else e
+ case _ => e
+ }
+ }
+
+ object FoldNotEqual extends FoldLogicalOp {
+ def fold(c1: UIntValue, c2: UIntValue) = UIntValue(if (c1.value != c2.value) 1 else 0, IntWidth(1))
+ def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match {
+ case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == w =>
+ if (lhs.value == 0) rhs // neq(x, 0) => x
+ else e
+ case _ => e
+ }
+ }
+
private def constPropPrim(e: DoPrim): Expression = e.op match {
case SHIFT_RIGHT_OP => {
val amount = e.consts(0).toInt
@@ -1178,6 +1242,15 @@ object ConstProp extends Pass {
case _ => e
}
}
+ case AND_OP => FoldAND(e)
+ case OR_OP => FoldOR(e)
+ case XOR_OP => FoldXOR(e)
+ case EQUAL_OP => FoldEqual(e)
+ case NEQUAL_OP => FoldNotEqual(e)
+ case NOT_OP => e.args(0) match {
+ case UIntValue(v, IntWidth(w)) => UIntValue(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
+ case _ => e
+ }
case BITS_SELECT_OP => e.args(0) match {
case UIntValue(v, w) => {
val hi = e.consts(0).toInt