diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/ConstProp.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/ConstProp.scala | 26 |
1 files changed, 12 insertions, 14 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index bb8ca549..a95d3de0 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -60,8 +60,8 @@ object ConstProp extends Pass { object FoldAND extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, w) if v == 0 => UIntLiteral(0, w) - case SIntLiteral(v, w) if v == 0 => UIntLiteral(0, w) + case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs case _ => e } @@ -70,8 +70,8 @@ object ConstProp extends Pass { object FoldOR extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, _) if v == 0 => rhs - case SIntLiteral(v, _) if v == 0 => asUInt(rhs, e.tpe) + case UIntLiteral(v, _) if v == BigInt(0) => rhs + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs case _ => e } @@ -80,8 +80,8 @@ object ConstProp extends Pass { object FoldXOR extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, _) if v == 0 => rhs - case SIntLiteral(v, _) if v == 0 => asUInt(rhs, e.tpe) + case UIntLiteral(v, _) if v == BigInt(0) => rhs + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) case _ => e } } @@ -89,7 +89,7 @@ object ConstProp extends Pass { object FoldEqual extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, IntWidth(w)) if v == 1 && w == 1 && bitWidth(rhs.tpe) == 1 => rhs + case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs case _ => e } } @@ -97,7 +97,7 @@ object ConstProp extends Pass { object FoldNotEqual extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, IntWidth(w)) if v == 0 && w == 1 && bitWidth(rhs.tpe) == 1 => rhs + case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs case _ => e } } @@ -176,7 +176,7 @@ object ConstProp extends Pass { } // Calculates an expression's range of values x match { - case e: DoPrim => { + case e: DoPrim => def r0 = range(e.args.head) def r1 = range(e.args(1)) e.op match { @@ -192,7 +192,6 @@ object ConstProp extends Pass { case Geq if (r0 < r1) => zero case _ => e } - } case e => e } } @@ -230,12 +229,11 @@ object ConstProp extends Pass { case _ => e } case Bits => e.args.head match { - case lit: Literal => { + case lit: Literal => val hi = e.consts.head.toInt val lo = e.consts(1).toInt require(hi >= lo) UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) - } case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { case t: UIntType => x case _ => asUInt(x, e.tpe) @@ -246,14 +244,14 @@ object ConstProp extends Pass { } private def constPropMuxCond(m: Mux) = m.cond match { - case UIntLiteral(c, _) => pad(if (c == 1) m.tval else m.fval, m.tpe) + case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe) case _ => m } private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { case _ if m.tval == m.fval => m.tval case (t: UIntLiteral, f: UIntLiteral) => - if (t.value == 1 && f.value == 0 && bitWidth(m.tpe) == 1) m.cond + if (t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1)) m.cond else constPropMuxCond(m) case _ => constPropMuxCond(m) } |
