aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/ConstantPropagation.scala
diff options
context:
space:
mode:
authorchick2020-08-14 19:47:53 -0700
committerJack Koenig2020-08-14 19:47:53 -0700
commit6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch)
tree2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/transforms/ConstantPropagation.scala
parentb516293f703c4de86397862fee1897aded2ae140 (diff)
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/transforms/ConstantPropagation.scala')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala463
1 files changed, 247 insertions, 216 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index ce36dd72..dc9b2bbe 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -28,7 +28,7 @@ object ConstantPropagation {
/** Pads e to the width of t */
def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match {
- case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
+ case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
case (we, wt) if we == wt => e
}
@@ -44,38 +44,40 @@ object ConstantPropagation {
case lit: Literal =>
require(hi >= lo)
UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe))
- case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match {
- case t: UIntType => x
- case _ => asUInt(x, e.tpe)
- }
+ case x if bitWidth(e.tpe) == bitWidth(x.tpe) =>
+ x.tpe match {
+ case t: UIntType => x
+ case _ => asUInt(x, e.tpe)
+ }
case _ => e
}
}
def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
case 0 => e.args.head
- case x => e.args.head match {
- // TODO when amount >= x.width, return a zero-width wire
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1))
- // take sign bit if shift amount is larger than arg width
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1))
- case _ => e
- }
+ case x =>
+ e.args.head match {
+ // TODO when amount >= x.width, return a zero-width wire
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1)))
+ // take sign bit if shift amount is larger than arg width
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1)))
+ case _ => e
+ }
}
-
- /**********************************************
- * REGISTER CONSTANT PROPAGATION HELPER TYPES *
- **********************************************/
+ /** ********************************************
+ * REGISTER CONSTANT PROPAGATION HELPER TYPES *
+ * ********************************************
+ */
// A utility class that is somewhat like an Option but with two variants containing Nothing.
// for register constant propagation (register or literal).
private abstract class ConstPropBinding[+T] {
def resolve[V >: T](that: ConstPropBinding[V]): ConstPropBinding[V] = (this, that) match {
- case (x, y) if (x == y) => x
+ case (x, y) if (x == y) => x
case (x, UnboundConstant) => x
case (UnboundConstant, y) => y
- case _ => NonConstant
+ case _ => NonConstant
}
}
@@ -103,21 +105,23 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
override def prerequisites =
((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)
- + Dependency(firrtl.passes.RemoveValidIf)).toSeq
+ ++ firrtl.stage.Forms.LowForm
+ - Dependency(firrtl.passes.Legalize)
+ + Dependency(firrtl.passes.RemoveValidIf)).toSeq
override def optionalPrerequisites = Seq.empty
override def optionalPrerequisiteOf =
- Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ Seq(
+ Dependency(firrtl.passes.memlib.VerilogMemDelays),
+ Dependency(firrtl.passes.SplitExpressions),
+ Dependency[SystemVerilogEmitter],
+ Dependency[VerilogEmitter]
+ )
override def invalidates(a: Transform): Boolean = a match {
case firrtl.passes.Legalize => true
- case _ => false
+ case _ => false
}
override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DontTouchAnnotation])
@@ -130,7 +134,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
}
sealed trait FoldCommutativeOp extends SimplifyBinaryOp {
- def fold(c1: Literal, c2: Literal): Expression
+ def fold(c1: Literal, c2: Literal): Expression
def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression
override def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match {
@@ -138,7 +142,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe)
case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe)
case (lhs, rhs) if (lhs == rhs) => matchingArgsValue(e, lhs)
- case _ => e
+ case _ => e
}
}
@@ -177,20 +181,20 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
*/
def apply(prim: DoPrim): Expression = prim.args.head match {
case a: Literal => simplifyLiteral(a)
- case _ => prim
+ case _ => prim
}
}
object FoldADD extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match {
- case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
- case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
+ case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1))
+ case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1))
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, w) if v == BigInt(0) => rhs
case SIntLiteral(v, w) if v == BigInt(0) => rhs
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = e
}
@@ -209,77 +213,81 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
object FoldAND extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = {
- val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth]
UIntLiteral.masked(c1.value & c2.value, width)
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
- case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
+ case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
+ case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe)
}
object FoldOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = {
- val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth]
UIntLiteral.masked((c1.value | c2.value), width)
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, _) if v == BigInt(0) => rhs
- case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
+ case UIntLiteral(v, _) if v == BigInt(0) => rhs
+ case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe)
}
object FoldXOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = {
- val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth]
UIntLiteral.masked((c1.value ^ c2.value), width)
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0, getWidth(arg.tpe))
}
object FoldEqual extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
+ def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
- case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) =>
+ DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(1)
}
object FoldNotEqual extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
+ def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
- case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) =>
+ DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0)
}
private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match {
- case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw))
+ case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) =>
+ UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw))
case _ => e
}
private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match {
case 0 => e.args.head
- case x => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x))
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x))
- case _ => e
- }
+ case x =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x))
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x))
+ case _ => e
+ }
}
private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match {
@@ -296,53 +304,55 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case _ => e
}
-
private def foldComparison(e: DoPrim) = {
def foldIfZeroedArg(x: Expression): Expression = {
def isUInt(e: Expression): Boolean = e.tpe match {
case UIntType(_) => true
- case _ => false
+ case _ => false
}
def isZero(e: Expression) = e match {
- case UIntLiteral(value, _) => value == BigInt(0)
- case SIntLiteral(value, _) => value == BigInt(0)
- case _ => false
- }
+ case UIntLiteral(value, _) => value == BigInt(0)
+ case SIntLiteral(value, _) => value == BigInt(0)
+ case _ => false
+ }
x match {
- case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero
- case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one
- case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero
- case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one
- case ex => ex
+ case DoPrim(Lt, Seq(a, b), _, _) if isUInt(a) && isZero(b) => zero
+ case DoPrim(Leq, Seq(a, b), _, _) if isZero(a) && isUInt(b) => one
+ case DoPrim(Gt, Seq(a, b), _, _) if isZero(a) && isUInt(b) => zero
+ case DoPrim(Geq, Seq(a, b), _, _) if isUInt(a) && isZero(b) => one
+ case ex => ex
}
}
def foldIfOutsideRange(x: Expression): Expression = {
//Note, only abides by a partial ordering
case class Range(min: BigInt, max: BigInt) {
- def === (that: Range) =
+ def ===(that: Range) =
Seq(this.min, this.max, that.min, that.max)
- .sliding(2,1)
+ .sliding(2, 1)
.map(x => x.head == x(1))
.reduce(_ && _)
- def > (that: Range) = this.min > that.max
- def >= (that: Range) = this.min >= that.max
- def < (that: Range) = this.max < that.min
- def <= (that: Range) = this.max <= that.min
+ def >(that: Range) = this.min > that.max
+ def >=(that: Range) = this.min >= that.max
+ def <(that: Range) = this.max < that.min
+ def <=(that: Range) = this.max <= that.min
}
def range(e: Expression): Range = e match {
case UIntLiteral(value, _) => Range(value, value)
case SIntLiteral(value, _) => Range(value, value)
- case _ => e.tpe match {
- case SIntType(IntWidth(width)) => Range(
- min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
- max = BigInt(2).pow(width.toInt - 1) - BigInt(1)
- )
- case UIntType(IntWidth(width)) => Range(
- min = BigInt(0),
- max = BigInt(2).pow(width.toInt) - BigInt(1)
- )
- }
+ case _ =>
+ e.tpe match {
+ case SIntType(IntWidth(width)) =>
+ Range(
+ min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
+ max = BigInt(2).pow(width.toInt - 1) - BigInt(1)
+ )
+ case UIntType(IntWidth(width)) =>
+ Range(
+ min = BigInt(0),
+ max = BigInt(2).pow(width.toInt) - BigInt(1)
+ )
+ }
}
// Calculates an expression's range of values
x match {
@@ -351,27 +361,28 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
def r1 = range(ex.args(1))
ex.op match {
// Always true
- case Lt if r0 < r1 => one
+ case Lt if r0 < r1 => one
case Leq if r0 <= r1 => one
- case Gt if r0 > r1 => one
+ case Gt if r0 > r1 => one
case Geq if r0 >= r1 => one
// Always false
- case Lt if r0 >= r1 => zero
+ case Lt if r0 >= r1 => zero
case Leq if r0 > r1 => zero
- case Gt if r0 <= r1 => zero
+ case Gt if r0 <= r1 => zero
case Geq if r0 < r1 => zero
- case _ => ex
+ case _ => ex
}
case ex => ex
}
}
def foldIfMatchingArgs(x: Expression) = x match {
- case DoPrim(op, Seq(a, b), _, _) if (a == b) => op match {
- case (Lt | Gt) => zero
- case (Leq | Geq) => one
- case _ => x
- }
+ case DoPrim(op, Seq(a, b), _, _) if (a == b) =>
+ op match {
+ case (Lt | Gt) => zero
+ case (Leq | Geq) => one
+ case _ => x
+ }
case _ => x
}
foldIfZeroedArg(foldIfOutsideRange(foldIfMatchingArgs(e)))
@@ -393,43 +404,47 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
}
private def constPropPrim(e: DoPrim): Expression = e.op match {
- case Shl => foldShiftLeft(e)
- case Dshl => foldDynamicShiftLeft(e)
- case Shr => foldShiftRight(e)
- case Dshr => foldDynamicShiftRight(e)
- case Cat => foldConcat(e)
- case Add => FoldADD(e)
- case Sub => SimplifySUB(e)
- case Div => SimplifyDIV(e)
- case Rem => SimplifyREM(e)
- case And => FoldAND(e)
- case Or => FoldOR(e)
- case Xor => FoldXOR(e)
- case Eq => FoldEqual(e)
- case Neq => FoldNotEqual(e)
- case Andr => FoldANDR(e)
- case Orr => FoldORR(e)
- case Xorr => FoldXORR(e)
+ case Shl => foldShiftLeft(e)
+ case Dshl => foldDynamicShiftLeft(e)
+ case Shr => foldShiftRight(e)
+ case Dshr => foldDynamicShiftRight(e)
+ case Cat => foldConcat(e)
+ case Add => FoldADD(e)
+ case Sub => SimplifySUB(e)
+ case Div => SimplifyDIV(e)
+ case Rem => SimplifyREM(e)
+ case And => FoldAND(e)
+ case Or => FoldOR(e)
+ case Xor => FoldXOR(e)
+ case Eq => FoldEqual(e)
+ case Neq => FoldNotEqual(e)
+ case Andr => FoldANDR(e)
+ case Orr => FoldORR(e)
+ case Xorr => FoldXORR(e)
case (Lt | Leq | Gt | Geq) => foldComparison(e)
- case Not => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
- case _ => e
- }
+ case Not =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
+ case _ => e
+ }
case AsUInt =>
e.args.head match {
case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w))
- case arg => arg.tpe match {
- case _: UIntType => arg
- case _ => e
- }
+ case arg =>
+ arg.tpe match {
+ case _: UIntType => arg
+ case _ => e
+ }
}
- case AsSInt => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w))
- case arg => arg.tpe match {
- case _: SIntType => arg
- case _ => e
+ case AsSInt =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w))
+ case arg =>
+ arg.tpe match {
+ case _: SIntType => arg
+ case _ => e
+ }
}
- }
case AsClock =>
val arg = e.args.head
arg.tpe match {
@@ -442,25 +457,27 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case AsyncResetType => arg
case _ => e
}
- case Pad => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w))
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w))
- case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
- case _ => e
- }
+ case Pad =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
+ case _ => e
+ }
case (Bits | Head | Tail) => constPropBitExtract(e)
- case _ => e
+ case _ => e
}
private def constPropMuxCond(m: Mux) = m.cond match {
case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe)
- case _ => m
+ case _ => m
}
private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match {
case _ if m.tval == m.fval => m.tval
case (t: UIntLiteral, f: UIntLiteral)
- if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => m.cond
+ if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) =>
+ m.cond
case (t: UIntLiteral, _) if t.value == BigInt(1) && bitWidth(m.tpe) == BigInt(1) =>
DoPrim(Or, Seq(m.cond, m.fval), Nil, m.tpe)
case (_, f: UIntLiteral) if f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) =>
@@ -479,15 +496,22 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Is "a" a "better name" than "b"?
private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_')
- def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
- def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
-
- private def constPropExpression(nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], constSubOutputs: Map[OfModule, Map[String, Literal]])(e: Expression): Expression = {
- val old = e map constPropExpression(nodeMap, instMap, constSubOutputs)
+ def optimize(e: Expression): Expression =
+ constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+ def optimize(e: Expression, nodeMap: NodeMap): Expression =
+ constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+
+ private def constPropExpression(
+ nodeMap: NodeMap,
+ instMap: collection.Map[Instance, OfModule],
+ constSubOutputs: Map[OfModule, Map[String, Literal]]
+ )(e: Expression
+ ): Expression = {
+ val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs))
val propagated = old match {
case p: DoPrim => constPropPrim(p)
- case m: Mux => constPropMux(m)
- case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) =>
+ case m: Mux => constPropMux(m)
+ case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) =>
constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2)
case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) =>
val module = instMap(inst.Instance)
@@ -506,17 +530,17 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
* @todo generalize source locator propagation across Expressions and delete this method
* @todo is the `orElse` the way we want to do propagation here?
*/
- private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String])
- (stmt: Statement): Statement = stmt match {
- // We check rname because inlining it would cause the original declaration to go away
- case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
- val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
- node.copy(info = InfoExpr.orElse(info1, info0))
- case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
- val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
- con.copy(info = InfoExpr.orElse(info1, info0))
- case other => other
- }
+ private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String])(stmt: Statement): Statement =
+ stmt match {
+ // We check rname because inlining it would cause the original declaration to go away
+ case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
+ val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
+ node.copy(info = InfoExpr.orElse(info1, info0))
+ case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
+ val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
+ con.copy(info = InfoExpr.orElse(info1, info0))
+ case other => other
+ }
/* Constant propagate a Module
*
@@ -538,12 +562,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
*/
@tailrec
private def constPropModule(
- m: Module,
- dontTouches: Set[String],
- instMap: collection.Map[Instance, OfModule],
- constInputs: Map[String, Literal],
- constSubOutputs: Map[OfModule, Map[String, Literal]]
- ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = {
+ m: Module,
+ dontTouches: Set[String],
+ instMap: collection.Map[Instance, OfModule],
+ constInputs: Map[String, Literal],
+ constSubOutputs: Map[OfModule, Map[String, Literal]]
+ ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = {
var nPropagated = 0L
val nodeMap = new NodeMap()
@@ -571,13 +595,13 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// to constant wires, we don't need to worry about propagating primops or muxes since we'll do
// that on the next iteration if necessary
def backPropExpr(expr: Expression): Expression = {
- val old = expr map backPropExpr
+ val old = expr.map(backPropExpr)
val propagated = old match {
// When swapping, we swap both rhs and lhs
- case ref @ WRef(rname, _,_,_) if swapMap.contains(rname) =>
+ case ref @ WRef(rname, _, _, _) if swapMap.contains(rname) =>
ref.copy(name = swapMap(rname))
// Only const prop on the rhs
- case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) =>
+ case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) =>
constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2)
case x => x
}
@@ -590,27 +614,29 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
def backPropStmt(stmt: Statement): Statement = stmt match {
case reg: DefRegister if (WrappedExpression.weq(reg.init, WRef(reg))) =>
// Self-init reset is an idiom for "no reset," and must be handled separately
- swapMap.get(reg.name)
- .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName)))
- .getOrElse(reg)
- case s => s map backPropExpr match {
- case decl: IsDeclaration if swapMap.contains(decl.name) =>
- val newName = swapMap(decl.name)
- nPropagated += 1
- decl match {
- case node: DefNode => node.copy(name = newName)
- case wire: DefWire => wire.copy(name = newName)
- case reg: DefRegister => reg.copy(name = newName)
- case other => throwInternalError()
- }
- case other => other map backPropStmt
- }
+ swapMap
+ .get(reg.name)
+ .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName)))
+ .getOrElse(reg)
+ case s =>
+ s.map(backPropExpr) match {
+ case decl: IsDeclaration if swapMap.contains(decl.name) =>
+ val newName = swapMap(decl.name)
+ nPropagated += 1
+ decl match {
+ case node: DefNode => node.copy(name = newName)
+ case wire: DefWire => wire.copy(name = newName)
+ case reg: DefRegister => reg.copy(name = newName)
+ case other => throwInternalError()
+ }
+ case other => other.map(backPropStmt)
+ }
}
// When propagating a reference, check if we want to keep the name that would be deleted
def propagateRef(lname: String, value: Expression, info: Info): Unit = {
value match {
- case WRef(rname,_,kind,_) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind =>
+ case WRef(rname, _, kind, _) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind =>
assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a
// node declaration or the single connection to a wire or register
swapMap += (lname -> rname, rname -> lname)
@@ -639,25 +665,24 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns
// This requires that reset has been made explicit
case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches(lname) =>
-
- /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment.
- * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of
- * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can
- * be resolved: each side is allowed one candidate register and one candidate literal to
- * appear in their source trees, referring to the potential constant propagation case that
- * they could allow. If the two are compatible (no different bound sources of either of
- * the two types), they can be resolved by combining sources. Otherwise, they propagate
- * NonConstant values. When encountering a node reference, it expands the node by to its
- * RHS assignment and recurses.
- *
- * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This
- * can stifle register constant propagations, which looks at drivers through value-preserving
- * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into
- * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole
- * de-optimization that does not actually appear in the output code.
- *
- * @return a RegCPEntry describing the constant prop-compatible sources driving this expression
- */
+ /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment.
+ * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of
+ * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can
+ * be resolved: each side is allowed one candidate register and one candidate literal to
+ * appear in their source trees, referring to the potential constant propagation case that
+ * they could allow. If the two are compatible (no different bound sources of either of
+ * the two types), they can be resolved by combining sources. Otherwise, they propagate
+ * NonConstant values. When encountering a node reference, it expands the node by to its
+ * RHS assignment and recurses.
+ *
+ * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This
+ * can stifle register constant propagations, which looks at drivers through value-preserving
+ * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into
+ * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole
+ * de-optimization that does not actually appear in the output code.
+ *
+ * @return a RegCPEntry describing the constant prop-compatible sources driving this expression
+ */
val unbound = RegCPEntry(UnboundConstant, UnboundConstant)
val selfBound = RegCPEntry(BoundConstant(lname), UnboundConstant)
@@ -684,11 +709,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Updates nodeMap after analyzing the returned value from regConstant
def updateNodeMapIfConstant(e: Expression): Unit = regConstant(e, selfBound) match {
- case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero)
+ case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero)
case RegCPEntry(BoundConstant(_), UnboundConstant) => nodeMap(lname) = padCPExp(zero)
- case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit)
+ case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit)
case RegCPEntry(BoundConstant(_), BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit)
- case _ =>
+ case _ =>
}
def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe))
@@ -733,11 +758,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Unify two maps using f to combine values of duplicate keys
private def unify[K, V](a: Map[K, V], b: Map[K, V])(f: (V, V) => V): Map[K, V] =
- b.foldLeft(a) { case (acc, (k, v)) =>
- acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
+ b.foldLeft(a) {
+ case (acc, (k, v)) =>
+ acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
}
-
private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = {
val iGraph = InstanceKeyGraph(c)
val moduleDeps = iGraph.getChildInstanceMap
@@ -754,9 +779,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// are driven with the same constant value. Then, if we find a Module input where each instance
// is driven with the same constant (and not seen in a previous iteration), we iterate again
@tailrec
- def iterate(toVisit: Set[OfModule],
- modules: Map[OfModule, Module],
- constInputs: Map[OfModule, Map[String, Literal]]): Map[OfModule, DefModule] = {
+ def iterate(
+ toVisit: Set[OfModule],
+ modules: Map[OfModule, Module],
+ constInputs: Map[OfModule, Map[String, Literal]]
+ ): Map[OfModule, DefModule] = {
if (toVisit.isEmpty) modules
else {
// Order from leaf modules to root so that any module driving an output
@@ -767,31 +794,36 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Aggreagte Module outputs that are driven constant for use by instaniating Modules
// Aggregate submodule inputs driven constant for checking later
val (modulesx, _, constInputsx) =
- order.foldLeft((modules,
- Map[OfModule, Map[String, Literal]](),
- Map[OfModule, Map[String, Seq[Literal]]]())) {
+ order.foldLeft((modules, Map[OfModule, Map[String, Literal]](), Map[OfModule, Map[String, Seq[Literal]]]())) {
case ((mmap, constOutputs, constInputsAcc), mname) =>
val dontTouches = dontTouchMap.getOrElse(mname, Set.empty)
- val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname),
- constInputs.getOrElse(mname, Map.empty), constOutputs)
+ val (mx, mco, mci) = constPropModule(
+ modules(mname),
+ dontTouches,
+ moduleDeps(mname),
+ constInputs.getOrElse(mname, Map.empty),
+ constOutputs
+ )
// Accumulate all Literals used to drive a particular Module port
val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d))
(mmap + (mname -> mx), constOutputs + (mname -> mco), constInputsx)
}
// Determine which module inputs have all of the same, new constants driving them
- val newProppedInputs = constInputsx.flatMap { case (mname, ports) =>
- val portsx = ports.flatMap { case (pname, lits) =>
- val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false)
- val isModule = modules.contains(mname) // ExtModules are not contained in modules
- val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1
- if (isModule && newPort && allSameConst) Some(pname -> lits.head)
- else None
- }
- if (portsx.nonEmpty) Some(mname -> portsx) else None
+ val newProppedInputs = constInputsx.flatMap {
+ case (mname, ports) =>
+ val portsx = ports.flatMap {
+ case (pname, lits) =>
+ val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false)
+ val isModule = modules.contains(mname) // ExtModules are not contained in modules
+ val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1
+ if (isModule && newPort && allSameConst) Some(pname -> lits.head)
+ else None
+ }
+ if (portsx.nonEmpty) Some(mname -> portsx) else None
}
val modsWithConstInputs = newProppedInputs.keySet
val newToVisit = modsWithConstInputs ++
- modsWithConstInputs.flatMap(parentGraph.reachableFrom)
+ modsWithConstInputs.flatMap(parentGraph.reachableFrom)
// Combine const inputs (there can't be duplicate values in the inner maps)
val nextConstInputs = unify(constInputs, newProppedInputs)((a, b) => a ++ b)
iterate(newToVisit.toSet, modulesx, nextConstInputs)
@@ -805,7 +837,6 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
c.modules.map(m => mmap.getOrElse(m.OfModule, m))
}
-
Circuit(c.info, modulesx, c.main)
}