From 39665e1f74cfe8243067442cccf4e7eab66ade68 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Wed, 28 Jun 2017 17:52:56 -0700 Subject: Promote ConstProp to a transform --- src/main/scala/firrtl/LoweringCompilers.scala | 6 +- src/main/scala/firrtl/passes/ConstProp.scala | 295 -------------------- .../firrtl/transforms/ConstantPropagation.scala | 303 +++++++++++++++++++++ src/test/scala/firrtlTests/CInferMDirSpec.scala | 3 +- src/test/scala/firrtlTests/ChirrtlMemSpec.scala | 3 +- .../firrtlTests/ConstantPropagationTests.scala | 39 ++- src/test/scala/firrtlTests/LowerTypesSpec.scala | 3 +- src/test/scala/firrtlTests/ReplSeqMemTests.scala | 2 +- src/test/scala/firrtlTests/UnitTests.scala | 12 +- 9 files changed, 351 insertions(+), 315 deletions(-) delete mode 100644 src/main/scala/firrtl/passes/ConstProp.scala create mode 100644 src/main/scala/firrtl/transforms/ConstantPropagation.scala (limited to 'src') diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 66ae1673..8dd9b180 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -98,12 +98,12 @@ class LowFirrtlOptimization extends CoreTransform { def outputForm = LowForm def transforms = Seq( passes.RemoveValidIf, - passes.ConstProp, + new firrtl.transforms.ConstantPropagation, passes.PadWidths, - passes.ConstProp, + new firrtl.transforms.ConstantPropagation, passes.Legalize, passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter - passes.ConstProp, + new firrtl.transforms.ConstantPropagation, passes.SplitExpressions, passes.CommonSubexpressionElimination, new firrtl.transforms.DeadCodeElimination) diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala deleted file mode 100644 index f2aa1a03..00000000 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ /dev/null @@ -1,295 +0,0 @@ -// See LICENSE for license details. - -package firrtl.passes - -import firrtl._ -import firrtl.ir._ -import firrtl.Utils._ -import firrtl.Mappers._ -import firrtl.PrimOps._ - -import annotation.tailrec - -object ConstProp extends Pass { - private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { - case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) - case (we, wt) if we == wt => e - } - - private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) - - trait FoldLogicalOp { - def fold(c1: Literal, c2: Literal): Expression - def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression - - def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { - case (lhs: Literal, rhs: Literal) => fold(lhs, rhs) - case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) - case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) - case _ => e - } - } - - object FoldAND extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) - case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) - case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs - case _ => e - } - } - - object FoldOR extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) - case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs - case _ => e - } - } - - object FoldXOR extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) - case _ => e - } - } - - object FoldEqual extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case _ => e - } - } - - object FoldNotEqual extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case _ => e - } - } - - private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match { - case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) - case _ => e - } - - private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match { - case 0 => e.args.head - case x => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) - case _ => e - } - } - - private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { - case 0 => e.args.head - case x => e.args.head match { - // TODO when amount >= x.width, return a zero-width wire - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) - // take sign bit if shift amount is larger than arg width - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) - case _ => e - } - } - - private def foldComparison(e: DoPrim) = { - def foldIfZeroedArg(x: Expression): Expression = { - def isUInt(e: Expression): Boolean = e.tpe match { - case UIntType(_) => true - case _ => false - } - def isZero(e: Expression) = e match { - case UIntLiteral(value, _) => value == BigInt(0) - case SIntLiteral(value, _) => value == BigInt(0) - case _ => false - } - x match { - case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero - case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one - case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero - case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one - case ex => ex - } - } - - def foldIfOutsideRange(x: Expression): Expression = { - //Note, only abides by a partial ordering - case class Range(min: BigInt, max: BigInt) { - def === (that: Range) = - Seq(this.min, this.max, that.min, that.max) - .sliding(2,1) - .map(x => x.head == x(1)) - .reduce(_ && _) - def > (that: Range) = this.min > that.max - def >= (that: Range) = this.min >= that.max - def < (that: Range) = this.max < that.min - def <= (that: Range) = this.max <= that.min - } - def range(e: Expression): Range = e match { - case UIntLiteral(value, _) => Range(value, value) - case SIntLiteral(value, _) => Range(value, value) - case _ => e.tpe match { - case SIntType(IntWidth(width)) => Range( - min = BigInt(0) - BigInt(2).pow(width.toInt - 1), - max = BigInt(2).pow(width.toInt - 1) - BigInt(1) - ) - case UIntType(IntWidth(width)) => Range( - min = BigInt(0), - max = BigInt(2).pow(width.toInt) - BigInt(1) - ) - } - } - // Calculates an expression's range of values - x match { - case ex: DoPrim => - def r0 = range(ex.args.head) - def r1 = range(ex.args(1)) - ex.op match { - // Always true - case Lt if r0 < r1 => one - case Leq if r0 <= r1 => one - case Gt if r0 > r1 => one - case Geq if r0 >= r1 => one - // Always false - case Lt if r0 >= r1 => zero - case Leq if r0 > r1 => zero - case Gt if r0 <= r1 => zero - case Geq if r0 < r1 => zero - case _ => ex - } - case ex => ex - } - } - foldIfZeroedArg(foldIfOutsideRange(e)) - } - - private def constPropPrim(e: DoPrim): Expression = e.op match { - case Shl => foldShiftLeft(e) - case Shr => foldShiftRight(e) - case Cat => foldConcat(e) - case And => FoldAND(e) - case Or => FoldOR(e) - case Xor => FoldXOR(e) - case Eq => FoldEqual(e) - case Neq => FoldNotEqual(e) - case (Lt | Leq | Gt | Geq) => foldComparison(e) - case Not => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) - case _ => e - } - case AsUInt => e.args.head match { - case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) - case u: UIntLiteral => u - case _ => e - } - case AsSInt => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) - case s: SIntLiteral => s - case _ => e - } - case Pad => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w)) - case _ if bitWidth(e.args.head.tpe) == e.consts.head => e.args.head - case _ => e - } - case Bits => e.args.head match { - case lit: Literal => - val hi = e.consts.head.toInt - val lo = e.consts(1).toInt - require(hi >= lo) - UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) - case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { - case t: UIntType => x - case _ => asUInt(x, e.tpe) - } - case _ => e - } - case _ => e - } - - private def constPropMuxCond(m: Mux) = m.cond match { - case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe) - case _ => m - } - - private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { - case _ if m.tval == m.fval => m.tval - case (t: UIntLiteral, f: UIntLiteral) => - if (t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1)) m.cond - else constPropMuxCond(m) - case _ => constPropMuxCond(m) - } - - private def constPropNodeRef(r: WRef, e: Expression) = e match { - case _: UIntLiteral | _: SIntLiteral | _: WRef => e - case _ => r - } - - // Two pass process - // 1. Propagate constants in expressions and forward propagate references - // 2. Propagate references again for backwards reference (Wires) - // TODO Replacing all wires with nodes makes the second pass unnecessary - @tailrec - private def constPropModule(m: Module): Module = { - var nPropagated = 0L - val nodeMap = collection.mutable.HashMap[String, Expression]() - - def backPropExpr(expr: Expression): Expression = { - val old = expr map backPropExpr - val propagated = old match { - case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) - case x => x - } - if (old ne propagated) { - nPropagated += 1 - } - propagated - } - def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr - - def constPropExpression(e: Expression): Expression = { - val old = e map constPropExpression - val propagated = old match { - case p: DoPrim => constPropPrim(p) - case m: Mux => constPropMux(m) - case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) - case x => x - } - propagated - } - - def constPropStmt(s: Statement): Statement = { - val stmtx = s map constPropStmt map constPropExpression - stmtx match { - case x: DefNode => nodeMap(x.name) = x.value - case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => - val exprx = constPropExpression(pad(expr, wtpe)) - nodeMap(wname) = exprx - case _ => - } - stmtx - } - - val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) - if (nPropagated > 0) constPropModule(res) else res - } - - def run(c: Circuit): Circuit = { - val modulesx = c.modules.map { - case m: ExtModule => m - case m: Module => constPropModule(m) - } - Circuit(c.info, modulesx, c.main) - } -} diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala new file mode 100644 index 00000000..930fe45a --- /dev/null +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -0,0 +1,303 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import firrtl.PrimOps._ + +import annotation.tailrec + +class ConstantPropagation extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { + case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) + case (we, wt) if we == wt => e + } + + private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) + + trait FoldLogicalOp { + def fold(c1: Literal, c2: Literal): Expression + def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression + + def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { + case (lhs: Literal, rhs: Literal) => fold(lhs, rhs) + case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) + case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) + case _ => e + } + } + + object FoldAND extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs + case _ => e + } + } + + object FoldOR extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, _) if v == BigInt(0) => rhs + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs + case _ => e + } + } + + object FoldXOR extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, _) if v == BigInt(0) => rhs + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case _ => e + } + } + + object FoldEqual extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs + case _ => e + } + } + + object FoldNotEqual extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs + case _ => e + } + } + + private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match { + case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) + case _ => e + } + + private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match { + case 0 => e.args.head + case x => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) + case _ => e + } + } + + private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { + case 0 => e.args.head + case x => e.args.head match { + // TODO when amount >= x.width, return a zero-width wire + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) + // take sign bit if shift amount is larger than arg width + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) + case _ => e + } + } + + private def foldComparison(e: DoPrim) = { + def foldIfZeroedArg(x: Expression): Expression = { + def isUInt(e: Expression): Boolean = e.tpe match { + case UIntType(_) => true + case _ => false + } + def isZero(e: Expression) = e match { + case UIntLiteral(value, _) => value == BigInt(0) + case SIntLiteral(value, _) => value == BigInt(0) + case _ => false + } + x match { + case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero + case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one + case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero + case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one + case ex => ex + } + } + + def foldIfOutsideRange(x: Expression): Expression = { + //Note, only abides by a partial ordering + case class Range(min: BigInt, max: BigInt) { + def === (that: Range) = + Seq(this.min, this.max, that.min, that.max) + .sliding(2,1) + .map(x => x.head == x(1)) + .reduce(_ && _) + def > (that: Range) = this.min > that.max + def >= (that: Range) = this.min >= that.max + def < (that: Range) = this.max < that.min + def <= (that: Range) = this.max <= that.min + } + def range(e: Expression): Range = e match { + case UIntLiteral(value, _) => Range(value, value) + case SIntLiteral(value, _) => Range(value, value) + case _ => e.tpe match { + case SIntType(IntWidth(width)) => Range( + min = BigInt(0) - BigInt(2).pow(width.toInt - 1), + max = BigInt(2).pow(width.toInt - 1) - BigInt(1) + ) + case UIntType(IntWidth(width)) => Range( + min = BigInt(0), + max = BigInt(2).pow(width.toInt) - BigInt(1) + ) + } + } + // Calculates an expression's range of values + x match { + case ex: DoPrim => + def r0 = range(ex.args.head) + def r1 = range(ex.args(1)) + ex.op match { + // Always true + case Lt if r0 < r1 => one + case Leq if r0 <= r1 => one + case Gt if r0 > r1 => one + case Geq if r0 >= r1 => one + // Always false + case Lt if r0 >= r1 => zero + case Leq if r0 > r1 => zero + case Gt if r0 <= r1 => zero + case Geq if r0 < r1 => zero + case _ => ex + } + case ex => ex + } + } + foldIfZeroedArg(foldIfOutsideRange(e)) + } + + private def constPropPrim(e: DoPrim): Expression = e.op match { + case Shl => foldShiftLeft(e) + case Shr => foldShiftRight(e) + case Cat => foldConcat(e) + case And => FoldAND(e) + case Or => FoldOR(e) + case Xor => FoldXOR(e) + case Eq => FoldEqual(e) + case Neq => FoldNotEqual(e) + case (Lt | Leq | Gt | Geq) => foldComparison(e) + case Not => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) + case _ => e + } + case AsUInt => e.args.head match { + case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) + case u: UIntLiteral => u + case _ => e + } + case AsSInt => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) + case s: SIntLiteral => s + case _ => e + } + case Pad => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w)) + case _ if bitWidth(e.args.head.tpe) == e.consts.head => e.args.head + case _ => e + } + case Bits => e.args.head match { + case lit: Literal => + val hi = e.consts.head.toInt + val lo = e.consts(1).toInt + require(hi >= lo) + UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) + case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { + case t: UIntType => x + case _ => asUInt(x, e.tpe) + } + case _ => e + } + case _ => e + } + + private def constPropMuxCond(m: Mux) = m.cond match { + case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe) + case _ => m + } + + private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { + case _ if m.tval == m.fval => m.tval + case (t: UIntLiteral, f: UIntLiteral) => + if (t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1)) m.cond + else constPropMuxCond(m) + case _ => constPropMuxCond(m) + } + + private def constPropNodeRef(r: WRef, e: Expression) = e match { + case _: UIntLiteral | _: SIntLiteral | _: WRef => e + case _ => r + } + + // Two pass process + // 1. Propagate constants in expressions and forward propagate references + // 2. Propagate references again for backwards reference (Wires) + // TODO Replacing all wires with nodes makes the second pass unnecessary + @tailrec + private def constPropModule(m: Module): Module = { + var nPropagated = 0L + val nodeMap = collection.mutable.HashMap[String, Expression]() + + def backPropExpr(expr: Expression): Expression = { + val old = expr map backPropExpr + val propagated = old match { + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case x => x + } + if (old ne propagated) { + nPropagated += 1 + } + propagated + } + def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr + + def constPropExpression(e: Expression): Expression = { + val old = e map constPropExpression + val propagated = old match { + case p: DoPrim => constPropPrim(p) + case m: Mux => constPropMux(m) + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case x => x + } + propagated + } + + def constPropStmt(s: Statement): Statement = { + val stmtx = s map constPropStmt map constPropExpression + stmtx match { + case x: DefNode => nodeMap(x.name) = x.value + case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => + val exprx = constPropExpression(pad(expr, wtpe)) + nodeMap(wname) = exprx + case _ => + } + stmtx + } + + val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) + if (nPropagated > 0) constPropModule(res) else res + } + + def run(c: Circuit): Circuit = { + val modulesx = c.modules.map { + case m: ExtModule => m + case m: Module => constPropModule(m) + } + Circuit(c.info, modulesx, c.main) + } + + def execute(state: CircuitState): CircuitState = { + state.copy(circuit = run(state.circuit)) + } +} diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala index 0d31038a..299142d9 100644 --- a/src/test/scala/firrtlTests/CInferMDirSpec.scala +++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala @@ -5,6 +5,7 @@ package firrtlTests import firrtl._ import firrtl.ir._ import firrtl.passes._ +import firrtl.transforms._ import firrtl.Mappers._ import annotations._ @@ -39,7 +40,7 @@ class CInferMDir extends LowTransformSpec { def transform = new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(ConstProp, CInferMDirCheckPass) + def transforms = Seq(new ConstantPropagation, CInferMDirCheckPass) } "Memory" should "have correct mem port directions" in { diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala index c963c8ae..6fac5047 100644 --- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala @@ -5,6 +5,7 @@ package firrtlTests import firrtl._ import firrtl.ir._ import firrtl.passes._ +import firrtl.transforms._ import firrtl.Mappers._ import annotations._ @@ -53,7 +54,7 @@ class ChirrtlMemSpec extends LowTransformSpec { def transform = new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(ConstProp, MemEnableCheckPass) + def transforms = Seq(new ConstantPropagation, MemEnableCheckPass) } "Sequential Memory" should "have correct enable signals" in { diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 95785717..c94adbf6 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -2,11 +2,11 @@ package firrtlTests -import org.scalatest.Matchers +import firrtl._ import firrtl.ir.Circuit import firrtl.Parser.IgnoreInfo -import firrtl.Parser import firrtl.passes._ +import firrtl.transforms._ // Tests the following cases for constant propagation: // 1) Unsigned integers are always greater than or @@ -16,17 +16,17 @@ import firrtl.passes._ // 3) Values are always greater than a number smaller // than their minimum value class ConstantPropagationSpec extends FirrtlFlatSpec { - val passes = Seq( + val transforms = Seq( ToWorkingIR, ResolveKinds, InferTypes, ResolveGenders, InferWidths, - ConstProp) - private def exec (input: String) = { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - }.serialize + new ConstantPropagation) + private def exec(input: String) = { + transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, t: Transform) => t.runTransform(c) + }.circuit.serialize } // ============================= "The rule x >= 0 " should " always be true if x is a UInt" in { @@ -346,6 +346,29 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "ConstProp" should "work across wires" in { + val input = +"""circuit Top : + module Top : + input x : UInt<1> + output y : UInt<1> + wire z : UInt<1> + y <= z + z <= mux(x, UInt<1>(0), UInt<1>(0)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<1> + output y : UInt<1> + wire z : UInt<1> + y <= UInt<1>(0) + z <= UInt<1>(0) """ (parse(exec(input))) should be (parse(check)) } diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index b43df713..ab367554 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -8,6 +8,7 @@ import org.scalatest.prop._ import firrtl.Parser import firrtl.ir.Circuit import firrtl.passes._ +import firrtl.transforms._ import firrtl._ class LowerTypesSpec extends FirrtlFlatSpec { @@ -27,7 +28,7 @@ class LowerTypesSpec extends FirrtlFlatSpec { ExpandWhens, CheckInitialization, Legalize, - ConstProp, + new ConstantPropagation, ResolveKinds, InferTypes, ResolveGenders, diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 25f845bc..7cbfeafe 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -22,7 +22,7 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(ConstProp, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) + def transforms = Seq(new ConstantPropagation, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) } ) diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 0d5d098c..f717fc18 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -8,13 +8,15 @@ import org.scalatest.prop._ import firrtl._ import firrtl.ir.Circuit import firrtl.passes._ +import firrtl.transforms._ import firrtl.Parser.IgnoreInfo class UnitTests extends FirrtlFlatSpec { - private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + private def executeTest(input: String, expected: Seq[String], transforms: Seq[Transform]) = { + val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, t: Transform) => t.runTransform(c) + }.circuit + val lines = c.serialize.split("\n") map normalized expected foreach { e => @@ -199,7 +201,7 @@ class UnitTests extends FirrtlFlatSpec { PullMuxes, ExpandConnects, RemoveAccesses, - ConstProp + new ConstantPropagation ) val input = """circuit AssignViaDeref : -- cgit v1.2.3