diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/transforms/ConstantPropagation.scala | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/transforms/ConstantPropagation.scala')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 463 |
1 files changed, 247 insertions, 216 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index ce36dd72..dc9b2bbe 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -28,7 +28,7 @@ object ConstantPropagation { /** Pads e to the width of t */ def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { - case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) + case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e } @@ -44,38 +44,40 @@ object ConstantPropagation { case lit: Literal => require(hi >= lo) UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) - case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { - case t: UIntType => x - case _ => asUInt(x, e.tpe) - } + case x if bitWidth(e.tpe) == bitWidth(x.tpe) => + x.tpe match { + case t: UIntType => x + case _ => asUInt(x, e.tpe) + } case _ => e } } def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head - case x => e.args.head match { - // TODO when amount >= x.width, return a zero-width wire - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) - // take sign bit if shift amount is larger than arg width - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) - case _ => e - } + case x => + e.args.head match { + // TODO when amount >= x.width, return a zero-width wire + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1))) + // take sign bit if shift amount is larger than arg width + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1))) + case _ => e + } } - - /********************************************** - * REGISTER CONSTANT PROPAGATION HELPER TYPES * - **********************************************/ + /** ******************************************** + * REGISTER CONSTANT PROPAGATION HELPER TYPES * + * ******************************************** + */ // A utility class that is somewhat like an Option but with two variants containing Nothing. // for register constant propagation (register or literal). private abstract class ConstPropBinding[+T] { def resolve[V >: T](that: ConstPropBinding[V]): ConstPropBinding[V] = (this, that) match { - case (x, y) if (x == y) => x + case (x, y) if (x == y) => x case (x, UnboundConstant) => x case (UnboundConstant, y) => y - case _ => NonConstant + case _ => NonConstant } } @@ -103,21 +105,23 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res override def prerequisites = ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize) - + Dependency(firrtl.passes.RemoveValidIf)).toSeq + ++ firrtl.stage.Forms.LowForm + - Dependency(firrtl.passes.Legalize) + + Dependency(firrtl.passes.RemoveValidIf)).toSeq override def optionalPrerequisites = Seq.empty override def optionalPrerequisiteOf = - Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + Seq( + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] + ) override def invalidates(a: Transform): Boolean = a match { case firrtl.passes.Legalize => true - case _ => false + case _ => false } override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DontTouchAnnotation]) @@ -130,7 +134,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } sealed trait FoldCommutativeOp extends SimplifyBinaryOp { - def fold(c1: Literal, c2: Literal): Expression + def fold(c1: Literal, c2: Literal): Expression def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression override def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { @@ -138,7 +142,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) case (lhs, rhs) if (lhs == rhs) => matchingArgsValue(e, lhs) - case _ => e + case _ => e } } @@ -177,20 +181,20 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res */ def apply(prim: DoPrim): Expression = prim.args.head match { case a: Literal => simplifyLiteral(a) - case _ => prim + case _ => prim } } object FoldADD extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match { - case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) - case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) + case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1)) + case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1)) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, w) if v == BigInt(0) => rhs case SIntLiteral(v, w) if v == BigInt(0) => rhs - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = e } @@ -209,77 +213,81 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res object FoldAND extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = { - val width = (c1.width max c2.width).asInstanceOf[IntWidth] + val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth] UIntLiteral.masked(c1.value & c2.value, width) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) - case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe) } object FoldOR extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = { - val width = (c1.width max c2.width).asInstanceOf[IntWidth] + val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth] UIntLiteral.masked((c1.value | c2.value), width) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case UIntLiteral(v, _) if v == BigInt(0) => rhs + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe) } object FoldXOR extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = { - val width = (c1.width max c2.width).asInstanceOf[IntWidth] + val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth] UIntLiteral.masked((c1.value ^ c2.value), width) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0, getWidth(arg.tpe)) } object FoldEqual extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe) + case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => + DoPrim(Not, Seq(rhs), Nil, e.tpe) case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(1) } object FoldNotEqual extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe) + case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => + DoPrim(Not, Seq(rhs), Nil, e.tpe) case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0) } private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match { - case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) + case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => + UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) case _ => e } private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head - case x => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) - case _ => e - } + case x => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) + case _ => e + } } private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match { @@ -296,53 +304,55 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case _ => e } - private def foldComparison(e: DoPrim) = { def foldIfZeroedArg(x: Expression): Expression = { def isUInt(e: Expression): Boolean = e.tpe match { case UIntType(_) => true - case _ => false + case _ => false } def isZero(e: Expression) = e match { - case UIntLiteral(value, _) => value == BigInt(0) - case SIntLiteral(value, _) => value == BigInt(0) - case _ => false - } + case UIntLiteral(value, _) => value == BigInt(0) + case SIntLiteral(value, _) => value == BigInt(0) + case _ => false + } x match { - case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero - case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one - case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero - case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one - case ex => ex + case DoPrim(Lt, Seq(a, b), _, _) if isUInt(a) && isZero(b) => zero + case DoPrim(Leq, Seq(a, b), _, _) if isZero(a) && isUInt(b) => one + case DoPrim(Gt, Seq(a, b), _, _) if isZero(a) && isUInt(b) => zero + case DoPrim(Geq, Seq(a, b), _, _) if isUInt(a) && isZero(b) => one + case ex => ex } } def foldIfOutsideRange(x: Expression): Expression = { //Note, only abides by a partial ordering case class Range(min: BigInt, max: BigInt) { - def === (that: Range) = + def ===(that: Range) = Seq(this.min, this.max, that.min, that.max) - .sliding(2,1) + .sliding(2, 1) .map(x => x.head == x(1)) .reduce(_ && _) - def > (that: Range) = this.min > that.max - def >= (that: Range) = this.min >= that.max - def < (that: Range) = this.max < that.min - def <= (that: Range) = this.max <= that.min + def >(that: Range) = this.min > that.max + def >=(that: Range) = this.min >= that.max + def <(that: Range) = this.max < that.min + def <=(that: Range) = this.max <= that.min } def range(e: Expression): Range = e match { case UIntLiteral(value, _) => Range(value, value) case SIntLiteral(value, _) => Range(value, value) - case _ => e.tpe match { - case SIntType(IntWidth(width)) => Range( - min = BigInt(0) - BigInt(2).pow(width.toInt - 1), - max = BigInt(2).pow(width.toInt - 1) - BigInt(1) - ) - case UIntType(IntWidth(width)) => Range( - min = BigInt(0), - max = BigInt(2).pow(width.toInt) - BigInt(1) - ) - } + case _ => + e.tpe match { + case SIntType(IntWidth(width)) => + Range( + min = BigInt(0) - BigInt(2).pow(width.toInt - 1), + max = BigInt(2).pow(width.toInt - 1) - BigInt(1) + ) + case UIntType(IntWidth(width)) => + Range( + min = BigInt(0), + max = BigInt(2).pow(width.toInt) - BigInt(1) + ) + } } // Calculates an expression's range of values x match { @@ -351,27 +361,28 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res def r1 = range(ex.args(1)) ex.op match { // Always true - case Lt if r0 < r1 => one + case Lt if r0 < r1 => one case Leq if r0 <= r1 => one - case Gt if r0 > r1 => one + case Gt if r0 > r1 => one case Geq if r0 >= r1 => one // Always false - case Lt if r0 >= r1 => zero + case Lt if r0 >= r1 => zero case Leq if r0 > r1 => zero - case Gt if r0 <= r1 => zero + case Gt if r0 <= r1 => zero case Geq if r0 < r1 => zero - case _ => ex + case _ => ex } case ex => ex } } def foldIfMatchingArgs(x: Expression) = x match { - case DoPrim(op, Seq(a, b), _, _) if (a == b) => op match { - case (Lt | Gt) => zero - case (Leq | Geq) => one - case _ => x - } + case DoPrim(op, Seq(a, b), _, _) if (a == b) => + op match { + case (Lt | Gt) => zero + case (Leq | Geq) => one + case _ => x + } case _ => x } foldIfZeroedArg(foldIfOutsideRange(foldIfMatchingArgs(e))) @@ -393,43 +404,47 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } private def constPropPrim(e: DoPrim): Expression = e.op match { - case Shl => foldShiftLeft(e) - case Dshl => foldDynamicShiftLeft(e) - case Shr => foldShiftRight(e) - case Dshr => foldDynamicShiftRight(e) - case Cat => foldConcat(e) - case Add => FoldADD(e) - case Sub => SimplifySUB(e) - case Div => SimplifyDIV(e) - case Rem => SimplifyREM(e) - case And => FoldAND(e) - case Or => FoldOR(e) - case Xor => FoldXOR(e) - case Eq => FoldEqual(e) - case Neq => FoldNotEqual(e) - case Andr => FoldANDR(e) - case Orr => FoldORR(e) - case Xorr => FoldXORR(e) + case Shl => foldShiftLeft(e) + case Dshl => foldDynamicShiftLeft(e) + case Shr => foldShiftRight(e) + case Dshr => foldDynamicShiftRight(e) + case Cat => foldConcat(e) + case Add => FoldADD(e) + case Sub => SimplifySUB(e) + case Div => SimplifyDIV(e) + case Rem => SimplifyREM(e) + case And => FoldAND(e) + case Or => FoldOR(e) + case Xor => FoldXOR(e) + case Eq => FoldEqual(e) + case Neq => FoldNotEqual(e) + case Andr => FoldANDR(e) + case Orr => FoldORR(e) + case Xorr => FoldXORR(e) case (Lt | Leq | Gt | Geq) => foldComparison(e) - case Not => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) - case _ => e - } + case Not => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) + case _ => e + } case AsUInt => e.args.head match { case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) - case arg => arg.tpe match { - case _: UIntType => arg - case _ => e - } + case arg => + arg.tpe match { + case _: UIntType => arg + case _ => e + } } - case AsSInt => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) - case arg => arg.tpe match { - case _: SIntType => arg - case _ => e + case AsSInt => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w)) + case arg => + arg.tpe match { + case _: SIntType => arg + case _ => e + } } - } case AsClock => val arg = e.args.head arg.tpe match { @@ -442,25 +457,27 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case AsyncResetType => arg case _ => e } - case Pad => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w)) - case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head - case _ => e - } + case Pad => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w))) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w))) + case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head + case _ => e + } case (Bits | Head | Tail) => constPropBitExtract(e) - case _ => e + case _ => e } private def constPropMuxCond(m: Mux) = m.cond match { case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe) - case _ => m + case _ => m } private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { case _ if m.tval == m.fval => m.tval case (t: UIntLiteral, f: UIntLiteral) - if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => m.cond + if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => + m.cond case (t: UIntLiteral, _) if t.value == BigInt(1) && bitWidth(m.tpe) == BigInt(1) => DoPrim(Or, Seq(m.cond, m.fval), Nil, m.tpe) case (_, f: UIntLiteral) if f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => @@ -479,15 +496,22 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Is "a" a "better name" than "b"? private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') - def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) - def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) - - private def constPropExpression(nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], constSubOutputs: Map[OfModule, Map[String, Literal]])(e: Expression): Expression = { - val old = e map constPropExpression(nodeMap, instMap, constSubOutputs) + def optimize(e: Expression): Expression = + constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + def optimize(e: Expression, nodeMap: NodeMap): Expression = + constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + + private def constPropExpression( + nodeMap: NodeMap, + instMap: collection.Map[Instance, OfModule], + constSubOutputs: Map[OfModule, Map[String, Literal]] + )(e: Expression + ): Expression = { + val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs)) val propagated = old match { case p: DoPrim => constPropPrim(p) - case m: Mux => constPropMux(m) - case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => + case m: Mux => constPropMux(m) + case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) => constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) => val module = instMap(inst.Instance) @@ -506,17 +530,17 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res * @todo generalize source locator propagation across Expressions and delete this method * @todo is the `orElse` the way we want to do propagation here? */ - private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String]) - (stmt: Statement): Statement = stmt match { - // We check rname because inlining it would cause the original declaration to go away - case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => - val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) - node.copy(info = InfoExpr.orElse(info1, info0)) - case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => - val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) - con.copy(info = InfoExpr.orElse(info1, info0)) - case other => other - } + private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String])(stmt: Statement): Statement = + stmt match { + // We check rname because inlining it would cause the original declaration to go away + case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + node.copy(info = InfoExpr.orElse(info1, info0)) + case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + con.copy(info = InfoExpr.orElse(info1, info0)) + case other => other + } /* Constant propagate a Module * @@ -538,12 +562,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res */ @tailrec private def constPropModule( - m: Module, - dontTouches: Set[String], - instMap: collection.Map[Instance, OfModule], - constInputs: Map[String, Literal], - constSubOutputs: Map[OfModule, Map[String, Literal]] - ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = { + m: Module, + dontTouches: Set[String], + instMap: collection.Map[Instance, OfModule], + constInputs: Map[String, Literal], + constSubOutputs: Map[OfModule, Map[String, Literal]] + ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = { var nPropagated = 0L val nodeMap = new NodeMap() @@ -571,13 +595,13 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // to constant wires, we don't need to worry about propagating primops or muxes since we'll do // that on the next iteration if necessary def backPropExpr(expr: Expression): Expression = { - val old = expr map backPropExpr + val old = expr.map(backPropExpr) val propagated = old match { // When swapping, we swap both rhs and lhs - case ref @ WRef(rname, _,_,_) if swapMap.contains(rname) => + case ref @ WRef(rname, _, _, _) if swapMap.contains(rname) => ref.copy(name = swapMap(rname)) // Only const prop on the rhs - case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => + case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) => constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case x => x } @@ -590,27 +614,29 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res def backPropStmt(stmt: Statement): Statement = stmt match { case reg: DefRegister if (WrappedExpression.weq(reg.init, WRef(reg))) => // Self-init reset is an idiom for "no reset," and must be handled separately - swapMap.get(reg.name) - .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName))) - .getOrElse(reg) - case s => s map backPropExpr match { - case decl: IsDeclaration if swapMap.contains(decl.name) => - val newName = swapMap(decl.name) - nPropagated += 1 - decl match { - case node: DefNode => node.copy(name = newName) - case wire: DefWire => wire.copy(name = newName) - case reg: DefRegister => reg.copy(name = newName) - case other => throwInternalError() - } - case other => other map backPropStmt - } + swapMap + .get(reg.name) + .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName))) + .getOrElse(reg) + case s => + s.map(backPropExpr) match { + case decl: IsDeclaration if swapMap.contains(decl.name) => + val newName = swapMap(decl.name) + nPropagated += 1 + decl match { + case node: DefNode => node.copy(name = newName) + case wire: DefWire => wire.copy(name = newName) + case reg: DefRegister => reg.copy(name = newName) + case other => throwInternalError() + } + case other => other.map(backPropStmt) + } } // When propagating a reference, check if we want to keep the name that would be deleted def propagateRef(lname: String, value: Expression, info: Info): Unit = { value match { - case WRef(rname,_,kind,_) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind => + case WRef(rname, _, kind, _) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind => assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a // node declaration or the single connection to a wire or register swapMap += (lname -> rname, rname -> lname) @@ -639,25 +665,24 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns // This requires that reset has been made explicit case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches(lname) => - - /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment. - * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of - * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can - * be resolved: each side is allowed one candidate register and one candidate literal to - * appear in their source trees, referring to the potential constant propagation case that - * they could allow. If the two are compatible (no different bound sources of either of - * the two types), they can be resolved by combining sources. Otherwise, they propagate - * NonConstant values. When encountering a node reference, it expands the node by to its - * RHS assignment and recurses. - * - * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This - * can stifle register constant propagations, which looks at drivers through value-preserving - * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into - * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole - * de-optimization that does not actually appear in the output code. - * - * @return a RegCPEntry describing the constant prop-compatible sources driving this expression - */ + /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment. + * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of + * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can + * be resolved: each side is allowed one candidate register and one candidate literal to + * appear in their source trees, referring to the potential constant propagation case that + * they could allow. If the two are compatible (no different bound sources of either of + * the two types), they can be resolved by combining sources. Otherwise, they propagate + * NonConstant values. When encountering a node reference, it expands the node by to its + * RHS assignment and recurses. + * + * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This + * can stifle register constant propagations, which looks at drivers through value-preserving + * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into + * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole + * de-optimization that does not actually appear in the output code. + * + * @return a RegCPEntry describing the constant prop-compatible sources driving this expression + */ val unbound = RegCPEntry(UnboundConstant, UnboundConstant) val selfBound = RegCPEntry(BoundConstant(lname), UnboundConstant) @@ -684,11 +709,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Updates nodeMap after analyzing the returned value from regConstant def updateNodeMapIfConstant(e: Expression): Unit = regConstant(e, selfBound) match { - case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero) + case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero) case RegCPEntry(BoundConstant(_), UnboundConstant) => nodeMap(lname) = padCPExp(zero) - case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit) + case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit) case RegCPEntry(BoundConstant(_), BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit) - case _ => + case _ => } def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe)) @@ -733,11 +758,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Unify two maps using f to combine values of duplicate keys private def unify[K, V](a: Map[K, V], b: Map[K, V])(f: (V, V) => V): Map[K, V] = - b.foldLeft(a) { case (acc, (k, v)) => - acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) + b.foldLeft(a) { + case (acc, (k, v)) => + acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) } - private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = { val iGraph = InstanceKeyGraph(c) val moduleDeps = iGraph.getChildInstanceMap @@ -754,9 +779,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // are driven with the same constant value. Then, if we find a Module input where each instance // is driven with the same constant (and not seen in a previous iteration), we iterate again @tailrec - def iterate(toVisit: Set[OfModule], - modules: Map[OfModule, Module], - constInputs: Map[OfModule, Map[String, Literal]]): Map[OfModule, DefModule] = { + def iterate( + toVisit: Set[OfModule], + modules: Map[OfModule, Module], + constInputs: Map[OfModule, Map[String, Literal]] + ): Map[OfModule, DefModule] = { if (toVisit.isEmpty) modules else { // Order from leaf modules to root so that any module driving an output @@ -767,31 +794,36 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Aggreagte Module outputs that are driven constant for use by instaniating Modules // Aggregate submodule inputs driven constant for checking later val (modulesx, _, constInputsx) = - order.foldLeft((modules, - Map[OfModule, Map[String, Literal]](), - Map[OfModule, Map[String, Seq[Literal]]]())) { + order.foldLeft((modules, Map[OfModule, Map[String, Literal]](), Map[OfModule, Map[String, Seq[Literal]]]())) { case ((mmap, constOutputs, constInputsAcc), mname) => val dontTouches = dontTouchMap.getOrElse(mname, Set.empty) - val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname), - constInputs.getOrElse(mname, Map.empty), constOutputs) + val (mx, mco, mci) = constPropModule( + modules(mname), + dontTouches, + moduleDeps(mname), + constInputs.getOrElse(mname, Map.empty), + constOutputs + ) // Accumulate all Literals used to drive a particular Module port val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d)) (mmap + (mname -> mx), constOutputs + (mname -> mco), constInputsx) } // Determine which module inputs have all of the same, new constants driving them - val newProppedInputs = constInputsx.flatMap { case (mname, ports) => - val portsx = ports.flatMap { case (pname, lits) => - val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false) - val isModule = modules.contains(mname) // ExtModules are not contained in modules - val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1 - if (isModule && newPort && allSameConst) Some(pname -> lits.head) - else None - } - if (portsx.nonEmpty) Some(mname -> portsx) else None + val newProppedInputs = constInputsx.flatMap { + case (mname, ports) => + val portsx = ports.flatMap { + case (pname, lits) => + val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false) + val isModule = modules.contains(mname) // ExtModules are not contained in modules + val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1 + if (isModule && newPort && allSameConst) Some(pname -> lits.head) + else None + } + if (portsx.nonEmpty) Some(mname -> portsx) else None } val modsWithConstInputs = newProppedInputs.keySet val newToVisit = modsWithConstInputs ++ - modsWithConstInputs.flatMap(parentGraph.reachableFrom) + modsWithConstInputs.flatMap(parentGraph.reachableFrom) // Combine const inputs (there can't be duplicate values in the inner maps) val nextConstInputs = unify(constInputs, newProppedInputs)((a, b) => a ++ b) iterate(newToVisit.toSet, modulesx, nextConstInputs) @@ -805,7 +837,6 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res c.modules.map(m => mmap.getOrElse(m.OfModule, m)) } - Circuit(c.info, modulesx, c.main) } |
