aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAndrew Waterman2016-07-07 01:17:35 -0700
committerAndrew Waterman2016-07-07 02:22:22 -0700
commit359021cd1753cd8ad6e5315da6ef2638c5000323 (patch)
treec9c2d99543de6bf96babb2b8d6b2027d2d36d497 /src
parentdf989658f707d28916d03f37ad2bf65ec327a053 (diff)
Generalize and clean up constant propagation pass
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/ConstProp.scala102
1 files changed, 54 insertions, 48 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala
index 618a96c0..57782a3c 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/passes/ConstProp.scala
@@ -38,66 +38,66 @@ import annotation.tailrec
object ConstProp extends Pass {
def name = "Constant Propagation"
+ private def pad(e: Expression, t: Type) = (long_BANG(e.tpe), long_BANG(t)) match {
+ case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
+ case (we, wt) if we == wt => e
+ }
+
+ private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t)
+
trait FoldLogicalOp {
- def fold(c1: UIntLiteral, c2: UIntLiteral): UIntLiteral
- def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression): Expression
+ def fold(c1: Literal, c2: Literal): Expression
+ def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression
def apply(e: DoPrim): Expression = (e.args(0), e.args(1)) match {
- case (lhs: UIntLiteral, rhs: UIntLiteral) => fold(lhs, rhs)
- case (lhs: UIntLiteral, rhs) => simplify(e, lhs, rhs)
- case (lhs, rhs: UIntLiteral) => simplify(e, rhs, lhs)
+ case (lhs: Literal, rhs: Literal) => fold(lhs, rhs)
+ case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe)
+ case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe)
case _ => e
}
}
object FoldAND extends FoldLogicalOp {
- def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value & c2.value, c1.width max c2.width)
- def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match {
- case IntWidth(w) if long_BANG(tpe(rhs)) == w =>
- if (lhs.value == 0) lhs // and(x, 0) => 0
- else if (lhs.value == (BigInt(1) << w.toInt) - 1) rhs // and(x, 1) => x
- else e
+ def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width)
+ def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
+ case UIntLiteral(v, w) if v == 0 => UIntLiteral(0, w)
+ case SIntLiteral(v, w) if v == 0 => UIntLiteral(0, w)
+ case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << long_BANG(rhs.tpe).toInt) - 1 => rhs
case _ => e
}
}
object FoldOR extends FoldLogicalOp {
- def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value | c2.value, c1.width max c2.width)
- def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match {
- case IntWidth(w) if long_BANG(tpe(rhs)) == w =>
- if (lhs.value == 0) rhs // or(x, 0) => x
- else if (lhs.value == (BigInt(1) << w.toInt) - 1) lhs // or(x, 1) => 1
- else e
+ def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width)
+ def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
+ case UIntLiteral(v, _) if v == 0 => rhs
+ case SIntLiteral(v, _) if v == 0 => asUInt(rhs, e.tpe)
+ case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << long_BANG(rhs.tpe).toInt) - 1 => lhs
case _ => e
}
}
object FoldXOR extends FoldLogicalOp {
- def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width)
- def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match {
- case IntWidth(w) if long_BANG(tpe(rhs)) == w =>
- if (lhs.value == 0) rhs // xor(x, 0) => x
- else e
+ def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width)
+ def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
+ case UIntLiteral(v, _) if v == 0 => rhs
+ case SIntLiteral(v, _) if v == 0 => asUInt(rhs, e.tpe)
case _ => e
}
}
object FoldEqual extends FoldLogicalOp {
- def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
- def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match {
- case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == 1 =>
- if (lhs.value == 1) rhs // eq(x, 1) => x
- else e
+ def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
+ def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
+ case UIntLiteral(v, IntWidth(w)) if v == 1 && w == 1 && long_BANG(rhs.tpe) == 1 => rhs
case _ => e
}
}
object FoldNotEqual extends FoldLogicalOp {
- def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
- def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match {
- case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == w =>
- if (lhs.value == 0) rhs // neq(x, 0) => x
- else e
+ def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
+ def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
+ case UIntLiteral(v, IntWidth(w)) if v == 0 && w == 1 && long_BANG(rhs.tpe) == 1 => rhs
case _ => e
}
}
@@ -213,35 +213,41 @@ object ConstProp extends Pass {
case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
case _ => e
}
+ case AsUInt => e.args(0) match {
+ case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w))
+ case u: UIntLiteral => u
+ case _ => e
+ }
+ case AsSInt => e.args(0) match {
+ case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w))
+ case s: SIntLiteral => s
+ case _ => e
+ }
+ case Pad => e.args(0) match {
+ case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(e.consts(0)))
+ case SIntLiteral(v, _) => SIntLiteral(v, IntWidth(e.consts(0)))
+ case _ if long_BANG(tpe(e.args(0))) == e.consts(0) => e.args(0)
+ case _ => e
+ }
case Bits => e.args(0) match {
- case UIntLiteral(v, _) => {
+ case lit: Literal => {
val hi = e.consts(0).toInt
val lo = e.consts(1).toInt
require(hi >= lo)
- UIntLiteral((v >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e)))
+ UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e)))
}
case x if long_BANG(tpe(e)) == long_BANG(tpe(x)) => tpe(x) match {
case t: UIntType => x
- case _ => DoPrim(AsUInt, Seq(x), Seq(), tpe(e))
+ case _ => asUInt(x, e.tpe)
}
case _ => e
}
case _ => e
}
- private def constPropMuxCond(m: Mux) = {
- // Only propagate a value if its width matches the mux width
- def propagate(e: Expression, muxWidth: BigInt) = e match {
- case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(muxWidth))
- case _ => tpe(e) match {
- case UIntType(IntWidth(w)) if muxWidth == w => e
- case _ => m
- }
- }
- (m.cond, m.tpe) match {
- case (UIntLiteral(c, _), UIntType(IntWidth(w))) => propagate(if (c == 1) m.tval else m.fval, w)
- case _ => m
- }
+ private def constPropMuxCond(m: Mux) = m.cond match {
+ case UIntLiteral(c, _) => pad(if (c == 1) m.tval else m.fval, m.tpe)
+ case _ => m
}
private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match {