aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/ConstProp.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/ConstProp.scala')
-rw-r--r--src/main/scala/firrtl/passes/ConstProp.scala26
1 files changed, 12 insertions, 14 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala
index bb8ca549..a95d3de0 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/passes/ConstProp.scala
@@ -60,8 +60,8 @@ object ConstProp extends Pass {
object FoldAND extends FoldLogicalOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, w) if v == 0 => UIntLiteral(0, w)
- case SIntLiteral(v, w) if v == 0 => UIntLiteral(0, w)
+ case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
+ case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs
case _ => e
}
@@ -70,8 +70,8 @@ object ConstProp extends Pass {
object FoldOR extends FoldLogicalOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, _) if v == 0 => rhs
- case SIntLiteral(v, _) if v == 0 => asUInt(rhs, e.tpe)
+ case UIntLiteral(v, _) if v == BigInt(0) => rhs
+ case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs
case _ => e
}
@@ -80,8 +80,8 @@ object ConstProp extends Pass {
object FoldXOR extends FoldLogicalOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, _) if v == 0 => rhs
- case SIntLiteral(v, _) if v == 0 => asUInt(rhs, e.tpe)
+ case UIntLiteral(v, _) if v == BigInt(0) => rhs
+ case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
case _ => e
}
}
@@ -89,7 +89,7 @@ object ConstProp extends Pass {
object FoldEqual extends FoldLogicalOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, IntWidth(w)) if v == 1 && w == 1 && bitWidth(rhs.tpe) == 1 => rhs
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
case _ => e
}
}
@@ -97,7 +97,7 @@ object ConstProp extends Pass {
object FoldNotEqual extends FoldLogicalOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, IntWidth(w)) if v == 0 && w == 1 && bitWidth(rhs.tpe) == 1 => rhs
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
case _ => e
}
}
@@ -176,7 +176,7 @@ object ConstProp extends Pass {
}
// Calculates an expression's range of values
x match {
- case e: DoPrim => {
+ case e: DoPrim =>
def r0 = range(e.args.head)
def r1 = range(e.args(1))
e.op match {
@@ -192,7 +192,6 @@ object ConstProp extends Pass {
case Geq if (r0 < r1) => zero
case _ => e
}
- }
case e => e
}
}
@@ -230,12 +229,11 @@ object ConstProp extends Pass {
case _ => e
}
case Bits => e.args.head match {
- case lit: Literal => {
+ case lit: Literal =>
val hi = e.consts.head.toInt
val lo = e.consts(1).toInt
require(hi >= lo)
UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe))
- }
case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match {
case t: UIntType => x
case _ => asUInt(x, e.tpe)
@@ -246,14 +244,14 @@ object ConstProp extends Pass {
}
private def constPropMuxCond(m: Mux) = m.cond match {
- case UIntLiteral(c, _) => pad(if (c == 1) m.tval else m.fval, m.tpe)
+ case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe)
case _ => m
}
private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match {
case _ if m.tval == m.fval => m.tval
case (t: UIntLiteral, f: UIntLiteral) =>
- if (t.value == 1 && f.value == 0 && bitWidth(m.tpe) == 1) m.cond
+ if (t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1)) m.cond
else constPropMuxCond(m)
case _ => constPropMuxCond(m)
}