diff options
| author | Andrew Waterman | 2016-07-07 01:17:35 -0700 |
|---|---|---|
| committer | Andrew Waterman | 2016-07-07 02:22:22 -0700 |
| commit | 359021cd1753cd8ad6e5315da6ef2638c5000323 (patch) | |
| tree | c9c2d99543de6bf96babb2b8d6b2027d2d36d497 /src | |
| parent | df989658f707d28916d03f37ad2bf65ec327a053 (diff) | |
Generalize and clean up constant propagation pass
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/ConstProp.scala | 102 |
1 files changed, 54 insertions, 48 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index 618a96c0..57782a3c 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -38,66 +38,66 @@ import annotation.tailrec object ConstProp extends Pass { def name = "Constant Propagation" + private def pad(e: Expression, t: Type) = (long_BANG(e.tpe), long_BANG(t)) match { + case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) + case (we, wt) if we == wt => e + } + + private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) + trait FoldLogicalOp { - def fold(c1: UIntLiteral, c2: UIntLiteral): UIntLiteral - def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression): Expression + def fold(c1: Literal, c2: Literal): Expression + def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression def apply(e: DoPrim): Expression = (e.args(0), e.args(1)) match { - case (lhs: UIntLiteral, rhs: UIntLiteral) => fold(lhs, rhs) - case (lhs: UIntLiteral, rhs) => simplify(e, lhs, rhs) - case (lhs, rhs: UIntLiteral) => simplify(e, rhs, lhs) + case (lhs: Literal, rhs: Literal) => fold(lhs, rhs) + case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) + case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) case _ => e } } object FoldAND extends FoldLogicalOp { - def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { - case IntWidth(w) if long_BANG(tpe(rhs)) == w => - if (lhs.value == 0) lhs // and(x, 0) => 0 - else if (lhs.value == (BigInt(1) << w.toInt) - 1) rhs // and(x, 1) => x - else e + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, w) if v == 0 => UIntLiteral(0, w) + case SIntLiteral(v, w) if v == 0 => UIntLiteral(0, w) + case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << long_BANG(rhs.tpe).toInt) - 1 => rhs case _ => e } } object FoldOR extends FoldLogicalOp { - def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { - case IntWidth(w) if long_BANG(tpe(rhs)) == w => - if (lhs.value == 0) rhs // or(x, 0) => x - else if (lhs.value == (BigInt(1) << w.toInt) - 1) lhs // or(x, 1) => 1 - else e + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, _) if v == 0 => rhs + case SIntLiteral(v, _) if v == 0 => asUInt(rhs, e.tpe) + case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << long_BANG(rhs.tpe).toInt) - 1 => lhs case _ => e } } object FoldXOR extends FoldLogicalOp { - def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { - case IntWidth(w) if long_BANG(tpe(rhs)) == w => - if (lhs.value == 0) rhs // xor(x, 0) => x - else e + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, _) if v == 0 => rhs + case SIntLiteral(v, _) if v == 0 => asUInt(rhs, e.tpe) case _ => e } } object FoldEqual extends FoldLogicalOp { - def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { - case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == 1 => - if (lhs.value == 1) rhs // eq(x, 1) => x - else e + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, IntWidth(w)) if v == 1 && w == 1 && long_BANG(rhs.tpe) == 1 => rhs case _ => e } } object FoldNotEqual extends FoldLogicalOp { - def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { - case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == w => - if (lhs.value == 0) rhs // neq(x, 0) => x - else e + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, IntWidth(w)) if v == 0 && w == 1 && long_BANG(rhs.tpe) == 1 => rhs case _ => e } } @@ -213,35 +213,41 @@ object ConstProp extends Pass { case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) case _ => e } + case AsUInt => e.args(0) match { + case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) + case u: UIntLiteral => u + case _ => e + } + case AsSInt => e.args(0) match { + case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) + case s: SIntLiteral => s + case _ => e + } + case Pad => e.args(0) match { + case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(e.consts(0))) + case SIntLiteral(v, _) => SIntLiteral(v, IntWidth(e.consts(0))) + case _ if long_BANG(tpe(e.args(0))) == e.consts(0) => e.args(0) + case _ => e + } case Bits => e.args(0) match { - case UIntLiteral(v, _) => { + case lit: Literal => { val hi = e.consts(0).toInt val lo = e.consts(1).toInt require(hi >= lo) - UIntLiteral((v >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e))) + UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e))) } case x if long_BANG(tpe(e)) == long_BANG(tpe(x)) => tpe(x) match { case t: UIntType => x - case _ => DoPrim(AsUInt, Seq(x), Seq(), tpe(e)) + case _ => asUInt(x, e.tpe) } case _ => e } case _ => e } - private def constPropMuxCond(m: Mux) = { - // Only propagate a value if its width matches the mux width - def propagate(e: Expression, muxWidth: BigInt) = e match { - case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(muxWidth)) - case _ => tpe(e) match { - case UIntType(IntWidth(w)) if muxWidth == w => e - case _ => m - } - } - (m.cond, m.tpe) match { - case (UIntLiteral(c, _), UIntType(IntWidth(w))) => propagate(if (c == 1) m.tval else m.fval, w) - case _ => m - } + private def constPropMuxCond(m: Mux) = m.cond match { + case UIntLiteral(c, _) => pad(if (c == 1) m.tval else m.fval, m.tpe) + case _ => m } private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { |
