aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/ConstProp.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/ConstProp.scala')
-rw-r--r--src/main/scala/firrtl/passes/ConstProp.scala122
1 files changed, 62 insertions, 60 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala
index 216e94b0..618a96c0 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/passes/ConstProp.scala
@@ -28,8 +28,10 @@ MODIFICATIONS.
package firrtl.passes
import firrtl._
+import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
+import firrtl.PrimOps._
import annotation.tailrec
@@ -37,20 +39,20 @@ object ConstProp extends Pass {
def name = "Constant Propagation"
trait FoldLogicalOp {
- def fold(c1: UIntValue, c2: UIntValue): UIntValue
- def simplify(e: Expression, lhs: UIntValue, rhs: Expression): Expression
+ def fold(c1: UIntLiteral, c2: UIntLiteral): UIntLiteral
+ def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression): Expression
def apply(e: DoPrim): Expression = (e.args(0), e.args(1)) match {
- case (lhs: UIntValue, rhs: UIntValue) => fold(lhs, rhs)
- case (lhs: UIntValue, rhs) => simplify(e, lhs, rhs)
- case (lhs, rhs: UIntValue) => simplify(e, rhs, lhs)
+ case (lhs: UIntLiteral, rhs: UIntLiteral) => fold(lhs, rhs)
+ case (lhs: UIntLiteral, rhs) => simplify(e, lhs, rhs)
+ case (lhs, rhs: UIntLiteral) => simplify(e, rhs, lhs)
case _ => e
}
}
object FoldAND extends FoldLogicalOp {
- def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value & c2.value, c1.width max c2.width)
- def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match {
+ def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value & c2.value, c1.width max c2.width)
+ def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match {
case IntWidth(w) if long_BANG(tpe(rhs)) == w =>
if (lhs.value == 0) lhs // and(x, 0) => 0
else if (lhs.value == (BigInt(1) << w.toInt) - 1) rhs // and(x, 1) => x
@@ -60,8 +62,8 @@ object ConstProp extends Pass {
}
object FoldOR extends FoldLogicalOp {
- def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value | c2.value, c1.width max c2.width)
- def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match {
+ def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value | c2.value, c1.width max c2.width)
+ def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match {
case IntWidth(w) if long_BANG(tpe(rhs)) == w =>
if (lhs.value == 0) rhs // or(x, 0) => x
else if (lhs.value == (BigInt(1) << w.toInt) - 1) lhs // or(x, 1) => 1
@@ -71,8 +73,8 @@ object ConstProp extends Pass {
}
object FoldXOR extends FoldLogicalOp {
- def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value ^ c2.value, c1.width max c2.width)
- def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match {
+ def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width)
+ def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match {
case IntWidth(w) if long_BANG(tpe(rhs)) == w =>
if (lhs.value == 0) rhs // xor(x, 0) => x
else e
@@ -81,8 +83,8 @@ object ConstProp extends Pass {
}
object FoldEqual extends FoldLogicalOp {
- def fold(c1: UIntValue, c2: UIntValue) = UIntValue(if (c1.value == c2.value) 1 else 0, IntWidth(1))
- def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match {
+ def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
+ def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match {
case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == 1 =>
if (lhs.value == 1) rhs // eq(x, 1) => x
else e
@@ -91,8 +93,8 @@ object ConstProp extends Pass {
}
object FoldNotEqual extends FoldLogicalOp {
- def fold(c1: UIntValue, c2: UIntValue) = UIntValue(if (c1.value != c2.value) 1 else 0, IntWidth(1))
- def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match {
+ def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
+ def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match {
case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == w =>
if (lhs.value == 0) rhs // neq(x, 0) => x
else e
@@ -101,15 +103,15 @@ object ConstProp extends Pass {
}
private def foldConcat(e: DoPrim) = (e.args(0), e.args(1)) match {
- case (UIntValue(xv, IntWidth(xw)), UIntValue(yv, IntWidth(yw))) => UIntValue(xv << yw.toInt | yv, IntWidth(xw + yw))
+ case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw))
case _ => e
}
private def foldShiftLeft(e: DoPrim) = e.consts(0).toInt match {
case 0 => e.args(0)
case x => e.args(0) match {
- case UIntValue(v, IntWidth(w)) => UIntValue(v << x, IntWidth(w + x))
- case SIntValue(v, IntWidth(w)) => SIntValue(v << x, IntWidth(w + x))
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x))
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x))
case _ => e
}
}
@@ -118,9 +120,9 @@ object ConstProp extends Pass {
case 0 => e.args(0)
case x => e.args(0) match {
// TODO when amount >= x.width, return a zero-width wire
- case UIntValue(v, IntWidth(w)) => UIntValue(v >> x, IntWidth((w - x) max 1))
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1))
// take sign bit if shift amount is larger than arg width
- case SIntValue(v, IntWidth(w)) => SIntValue(v >> x, IntWidth((w - x) max 1))
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1))
case _ => e
}
}
@@ -132,15 +134,15 @@ object ConstProp extends Pass {
case _ => false
}
def isZero(e: Expression) = e match {
- case UIntValue(value,_) => value == BigInt(0)
- case SIntValue(value,_) => value == BigInt(0)
+ case UIntLiteral(value, _) => value == BigInt(0)
+ case SIntLiteral(value, _) => value == BigInt(0)
case _ => false
}
x match {
- case DoPrim(LESS_OP, Seq(a,b),_,_) if(isUInt(a) && isZero(b)) => zero
- case DoPrim(LESS_EQ_OP, Seq(a,b),_,_) if(isZero(a) && isUInt(b)) => one
- case DoPrim(GREATER_OP, Seq(a,b),_,_) if(isZero(a) && isUInt(b)) => zero
- case DoPrim(GREATER_EQ_OP,Seq(a,b),_,_) if(isUInt(a) && isZero(b)) => one
+ case DoPrim(Lt, Seq(a,b),_,_) if(isUInt(a) && isZero(b)) => zero
+ case DoPrim(Leq, Seq(a,b),_,_) if(isZero(a) && isUInt(b)) => one
+ case DoPrim(Gt, Seq(a,b),_,_) if(isZero(a) && isUInt(b)) => zero
+ case DoPrim(Geq, Seq(a,b),_,_) if(isUInt(a) && isZero(b)) => one
case e => e
}
}
@@ -159,8 +161,8 @@ object ConstProp extends Pass {
def <= (that: Range) = this.max <= that.min
}
def range(e: Expression): Range = e match {
- case UIntValue(value, _) => Range(value, value)
- case SIntValue(value, _) => Range(value, value)
+ case UIntLiteral(value, _) => Range(value, value)
+ case SIntLiteral(value, _) => Range(value, value)
case _ => tpe(e) match {
case SIntType(IntWidth(width)) => Range(
min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
@@ -179,15 +181,15 @@ object ConstProp extends Pass {
def r1 = range(e.args(1))
e.op match {
// Always true
- case LESS_OP if (r0 < r1) => one
- case LESS_EQ_OP if (r0 <= r1) => one
- case GREATER_OP if (r0 > r1) => one
- case GREATER_EQ_OP if (r0 >= r1) => one
+ case Lt if (r0 < r1) => one
+ case Leq if (r0 <= r1) => one
+ case Gt if (r0 > r1) => one
+ case Geq if (r0 >= r1) => one
// Always false
- case LESS_OP if (r0 >= r1) => zero
- case LESS_EQ_OP if (r0 > r1) => zero
- case GREATER_OP if (r0 <= r1) => zero
- case GREATER_EQ_OP if (r0 < r1) => zero
+ case Lt if (r0 >= r1) => zero
+ case Leq if (r0 > r1) => zero
+ case Gt if (r0 <= r1) => zero
+ case Geq if (r0 < r1) => zero
case _ => e
}
}
@@ -198,29 +200,29 @@ object ConstProp extends Pass {
}
private def constPropPrim(e: DoPrim): Expression = e.op match {
- case SHIFT_LEFT_OP => foldShiftLeft(e)
- case SHIFT_RIGHT_OP => foldShiftRight(e)
- case CONCAT_OP => foldConcat(e)
- case AND_OP => FoldAND(e)
- case OR_OP => FoldOR(e)
- case XOR_OP => FoldXOR(e)
- case EQUAL_OP => FoldEqual(e)
- case NEQUAL_OP => FoldNotEqual(e)
- case LESS_OP|LESS_EQ_OP|GREATER_OP|GREATER_EQ_OP => foldComparison(e)
- case NOT_OP => e.args(0) match {
- case UIntValue(v, IntWidth(w)) => UIntValue(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
+ case Shl => foldShiftLeft(e)
+ case Shr => foldShiftRight(e)
+ case Cat => foldConcat(e)
+ case And => FoldAND(e)
+ case Or => FoldOR(e)
+ case Xor => FoldXOR(e)
+ case Eq => FoldEqual(e)
+ case Neq => FoldNotEqual(e)
+ case (Lt | Leq | Gt | Geq) => foldComparison(e)
+ case Not => e.args(0) match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
case _ => e
}
- case BITS_SELECT_OP => e.args(0) match {
- case UIntValue(v, _) => {
+ case Bits => e.args(0) match {
+ case UIntLiteral(v, _) => {
val hi = e.consts(0).toInt
val lo = e.consts(1).toInt
require(hi >= lo)
- UIntValue((v >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e)))
+ UIntLiteral((v >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e)))
}
case x if long_BANG(tpe(e)) == long_BANG(tpe(x)) => tpe(x) match {
case t: UIntType => x
- case _ => DoPrim(AS_UINT_OP, Seq(x), Seq(), tpe(e))
+ case _ => DoPrim(AsUInt, Seq(x), Seq(), tpe(e))
}
case _ => e
}
@@ -230,33 +232,33 @@ object ConstProp extends Pass {
private def constPropMuxCond(m: Mux) = {
// Only propagate a value if its width matches the mux width
def propagate(e: Expression, muxWidth: BigInt) = e match {
- case UIntValue(v, _) => UIntValue(v, IntWidth(muxWidth))
+ case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(muxWidth))
case _ => tpe(e) match {
case UIntType(IntWidth(w)) if muxWidth == w => e
case _ => m
}
}
(m.cond, m.tpe) match {
- case (UIntValue(c, _), UIntType(IntWidth(w))) => propagate(if (c == 1) m.tval else m.fval, w)
+ case (UIntLiteral(c, _), UIntType(IntWidth(w))) => propagate(if (c == 1) m.tval else m.fval, w)
case _ => m
}
}
private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match {
case _ if m.tval == m.fval => m.tval
- case (t: UIntValue, f: UIntValue) =>
+ case (t: UIntLiteral, f: UIntLiteral) =>
if (t.value == 1 && f.value == 0 && long_BANG(m.tpe) == 1) m.cond
else constPropMuxCond(m)
case _ => constPropMuxCond(m)
}
private def constPropNodeRef(r: WRef, e: Expression) = e match {
- case _: UIntValue | _: SIntValue | _: WRef => e
+ case _: UIntLiteral | _: SIntLiteral | _: WRef => e
case _ => r
}
@tailrec
- private def constPropModule(m: InModule): InModule = {
+ private def constPropModule(m: Module): Module = {
var nPropagated = 0L
val nodeMap = collection.mutable.HashMap[String, Expression]()
@@ -273,7 +275,7 @@ object ConstProp extends Pass {
propagated
}
- def constPropStmt(s: Stmt): Stmt = {
+ def constPropStmt(s: Statement): Statement = {
s match {
case x: DefNode => nodeMap(x.name) = x.value
case _ =>
@@ -281,14 +283,14 @@ object ConstProp extends Pass {
s map constPropStmt map constPropExpression
}
- val res = InModule(m.info, m.name, m.ports, constPropStmt(m.body))
+ val res = Module(m.info, m.name, m.ports, constPropStmt(m.body))
if (nPropagated > 0) constPropModule(res) else res
}
def run(c: Circuit): Circuit = {
val modulesx = c.modules.map {
- case m: ExModule => m
- case m: InModule => constPropModule(m)
+ case m: ExtModule => m
+ case m: Module => constPropModule(m)
}
Circuit(c.info, modulesx, c.main)
}