From c1504e2179e509632fa8d9ab44d87191b46cf851 Mon Sep 17 00:00:00 2001 From: Jack Date: Mon, 9 May 2016 23:55:47 -0700 Subject: API Cleanup - Expression trait Expression -> abstract class Expression Ref -> Reference abbrev. exp -> expr Add abstract class Literal UIntValue -> UIntLiteral extends Literal SIntValue -> SIntLiteral extends Literal --- src/main/scala/firrtl/passes/ConstProp.scala | 54 ++++++++++++++-------------- 1 file changed, 27 insertions(+), 27 deletions(-) (limited to 'src/main/scala/firrtl/passes/ConstProp.scala') diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index 11c14e56..e562f71d 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -37,20 +37,20 @@ object ConstProp extends Pass { def name = "Constant Propagation" trait FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue): UIntValue - def simplify(e: Expression, lhs: UIntValue, rhs: Expression): Expression + def fold(c1: UIntLiteral, c2: UIntLiteral): UIntLiteral + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression): Expression def apply(e: DoPrim): Expression = (e.args(0), e.args(1)) match { - case (lhs: UIntValue, rhs: UIntValue) => fold(lhs, rhs) - case (lhs: UIntValue, rhs) => simplify(e, lhs, rhs) - case (lhs, rhs: UIntValue) => simplify(e, rhs, lhs) + case (lhs: UIntLiteral, rhs: UIntLiteral) => fold(lhs, rhs) + case (lhs: UIntLiteral, rhs) => simplify(e, lhs, rhs) + case (lhs, rhs: UIntLiteral) => simplify(e, rhs, lhs) case _ => e } } object FoldAND extends FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value & c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match { + def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { case IntWidth(w) if long_BANG(tpe(rhs)) == w => if (lhs.value == 0) lhs // and(x, 0) => 0 else if (lhs.value == (BigInt(1) << w.toInt) - 1) rhs // and(x, 1) => x @@ -60,8 +60,8 @@ object ConstProp extends Pass { } object FoldOR extends FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value | c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match { + def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { case IntWidth(w) if long_BANG(tpe(rhs)) == w => if (lhs.value == 0) rhs // or(x, 0) => x else if (lhs.value == (BigInt(1) << w.toInt) - 1) lhs // or(x, 1) => 1 @@ -71,8 +71,8 @@ object ConstProp extends Pass { } object FoldXOR extends FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value ^ c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match { + def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { case IntWidth(w) if long_BANG(tpe(rhs)) == w => if (lhs.value == 0) rhs // xor(x, 0) => x else e @@ -81,8 +81,8 @@ object ConstProp extends Pass { } object FoldEqual extends FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue) = UIntValue(if (c1.value == c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match { + def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == 1 => if (lhs.value == 1) rhs // eq(x, 1) => x else e @@ -91,8 +91,8 @@ object ConstProp extends Pass { } object FoldNotEqual extends FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue) = UIntValue(if (c1.value != c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match { + def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == w => if (lhs.value == 0) rhs // neq(x, 0) => x else e @@ -101,15 +101,15 @@ object ConstProp extends Pass { } private def foldConcat(e: DoPrim) = (e.args(0), e.args(1)) match { - case (UIntValue(xv, IntWidth(xw)), UIntValue(yv, IntWidth(yw))) => UIntValue(xv << yw.toInt | yv, IntWidth(xw + yw)) + case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) case _ => e } private def foldShiftLeft(e: DoPrim) = e.consts(0).toInt match { case 0 => e.args(0) case x => e.args(0) match { - case UIntValue(v, IntWidth(w)) => UIntValue(v << x, IntWidth(w + x)) - case SIntValue(v, IntWidth(w)) => SIntValue(v << x, IntWidth(w + x)) + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) case _ => e } } @@ -118,9 +118,9 @@ object ConstProp extends Pass { case 0 => e.args(0) case x => e.args(0) match { // TODO when amount >= x.width, return a zero-width wire - case UIntValue(v, IntWidth(w)) => UIntValue(v >> x, IntWidth((w - x) max 1)) + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) // take sign bit if shift amount is larger than arg width - case SIntValue(v, IntWidth(w)) => SIntValue(v >> x, IntWidth((w - x) max 1)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) case _ => e } } @@ -208,15 +208,15 @@ object ConstProp extends Pass { case NEQUAL_OP => FoldNotEqual(e) case LESS_OP|LESS_EQ_OP|GREATER_OP|GREATER_EQ_OP => foldComparison(e) case NOT_OP => e.args(0) match { - case UIntValue(v, IntWidth(w)) => UIntValue(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) case _ => e } case BITS_SELECT_OP => e.args(0) match { - case UIntValue(v, _) => { + case UIntLiteral(v, _) => { val hi = e.consts(0).toInt val lo = e.consts(1).toInt require(hi >= lo) - UIntValue((v >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e))) + UIntLiteral((v >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e))) } case x if long_BANG(tpe(e)) == long_BANG(tpe(x)) => tpe(x) match { case t: UIntType => x @@ -230,28 +230,28 @@ object ConstProp extends Pass { private def constPropMuxCond(m: Mux) = { // Only propagate a value if its width matches the mux width def propagate(e: Expression, muxWidth: BigInt) = e match { - case UIntValue(v, _) => UIntValue(v, IntWidth(muxWidth)) + case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(muxWidth)) case _ => tpe(e) match { case UIntType(IntWidth(w)) if muxWidth == w => e case _ => m } } (m.cond, m.tpe) match { - case (UIntValue(c, _), UIntType(IntWidth(w))) => propagate(if (c == 1) m.tval else m.fval, w) + case (UIntLiteral(c, _), UIntType(IntWidth(w))) => propagate(if (c == 1) m.tval else m.fval, w) case _ => m } } private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { case _ if m.tval == m.fval => m.tval - case (t: UIntValue, f: UIntValue) => + case (t: UIntLiteral, f: UIntLiteral) => if (t.value == 1 && f.value == 0 && long_BANG(m.tpe) == 1) m.cond else constPropMuxCond(m) case _ => constPropMuxCond(m) } private def constPropNodeRef(r: WRef, e: Expression) = e match { - case _: UIntValue | _: SIntValue | _: WRef => e + case _: UIntLiteral | _: SIntLiteral | _: WRef => e case _ => r } -- cgit v1.2.3