diff options
Diffstat (limited to 'src/main/scala/firrtl/transforms/ConstantPropagation.scala')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 303 |
1 files changed, 303 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala new file mode 100644 index 00000000..930fe45a --- /dev/null +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -0,0 +1,303 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import firrtl.PrimOps._ + +import annotation.tailrec + +class ConstantPropagation extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { + case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) + case (we, wt) if we == wt => e + } + + private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) + + trait FoldLogicalOp { + def fold(c1: Literal, c2: Literal): Expression + def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression + + def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { + case (lhs: Literal, rhs: Literal) => fold(lhs, rhs) + case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) + case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) + case _ => e + } + } + + object FoldAND extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs + case _ => e + } + } + + object FoldOR extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, _) if v == BigInt(0) => rhs + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs + case _ => e + } + } + + object FoldXOR extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, _) if v == BigInt(0) => rhs + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case _ => e + } + } + + object FoldEqual extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs + case _ => e + } + } + + object FoldNotEqual extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs + case _ => e + } + } + + private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match { + case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) + case _ => e + } + + private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match { + case 0 => e.args.head + case x => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) + case _ => e + } + } + + private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { + case 0 => e.args.head + case x => e.args.head match { + // TODO when amount >= x.width, return a zero-width wire + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) + // take sign bit if shift amount is larger than arg width + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) + case _ => e + } + } + + private def foldComparison(e: DoPrim) = { + def foldIfZeroedArg(x: Expression): Expression = { + def isUInt(e: Expression): Boolean = e.tpe match { + case UIntType(_) => true + case _ => false + } + def isZero(e: Expression) = e match { + case UIntLiteral(value, _) => value == BigInt(0) + case SIntLiteral(value, _) => value == BigInt(0) + case _ => false + } + x match { + case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero + case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one + case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero + case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one + case ex => ex + } + } + + def foldIfOutsideRange(x: Expression): Expression = { + //Note, only abides by a partial ordering + case class Range(min: BigInt, max: BigInt) { + def === (that: Range) = + Seq(this.min, this.max, that.min, that.max) + .sliding(2,1) + .map(x => x.head == x(1)) + .reduce(_ && _) + def > (that: Range) = this.min > that.max + def >= (that: Range) = this.min >= that.max + def < (that: Range) = this.max < that.min + def <= (that: Range) = this.max <= that.min + } + def range(e: Expression): Range = e match { + case UIntLiteral(value, _) => Range(value, value) + case SIntLiteral(value, _) => Range(value, value) + case _ => e.tpe match { + case SIntType(IntWidth(width)) => Range( + min = BigInt(0) - BigInt(2).pow(width.toInt - 1), + max = BigInt(2).pow(width.toInt - 1) - BigInt(1) + ) + case UIntType(IntWidth(width)) => Range( + min = BigInt(0), + max = BigInt(2).pow(width.toInt) - BigInt(1) + ) + } + } + // Calculates an expression's range of values + x match { + case ex: DoPrim => + def r0 = range(ex.args.head) + def r1 = range(ex.args(1)) + ex.op match { + // Always true + case Lt if r0 < r1 => one + case Leq if r0 <= r1 => one + case Gt if r0 > r1 => one + case Geq if r0 >= r1 => one + // Always false + case Lt if r0 >= r1 => zero + case Leq if r0 > r1 => zero + case Gt if r0 <= r1 => zero + case Geq if r0 < r1 => zero + case _ => ex + } + case ex => ex + } + } + foldIfZeroedArg(foldIfOutsideRange(e)) + } + + private def constPropPrim(e: DoPrim): Expression = e.op match { + case Shl => foldShiftLeft(e) + case Shr => foldShiftRight(e) + case Cat => foldConcat(e) + case And => FoldAND(e) + case Or => FoldOR(e) + case Xor => FoldXOR(e) + case Eq => FoldEqual(e) + case Neq => FoldNotEqual(e) + case (Lt | Leq | Gt | Geq) => foldComparison(e) + case Not => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) + case _ => e + } + case AsUInt => e.args.head match { + case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) + case u: UIntLiteral => u + case _ => e + } + case AsSInt => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) + case s: SIntLiteral => s + case _ => e + } + case Pad => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w)) + case _ if bitWidth(e.args.head.tpe) == e.consts.head => e.args.head + case _ => e + } + case Bits => e.args.head match { + case lit: Literal => + val hi = e.consts.head.toInt + val lo = e.consts(1).toInt + require(hi >= lo) + UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) + case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { + case t: UIntType => x + case _ => asUInt(x, e.tpe) + } + case _ => e + } + case _ => e + } + + private def constPropMuxCond(m: Mux) = m.cond match { + case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe) + case _ => m + } + + private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { + case _ if m.tval == m.fval => m.tval + case (t: UIntLiteral, f: UIntLiteral) => + if (t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1)) m.cond + else constPropMuxCond(m) + case _ => constPropMuxCond(m) + } + + private def constPropNodeRef(r: WRef, e: Expression) = e match { + case _: UIntLiteral | _: SIntLiteral | _: WRef => e + case _ => r + } + + // Two pass process + // 1. Propagate constants in expressions and forward propagate references + // 2. Propagate references again for backwards reference (Wires) + // TODO Replacing all wires with nodes makes the second pass unnecessary + @tailrec + private def constPropModule(m: Module): Module = { + var nPropagated = 0L + val nodeMap = collection.mutable.HashMap[String, Expression]() + + def backPropExpr(expr: Expression): Expression = { + val old = expr map backPropExpr + val propagated = old match { + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case x => x + } + if (old ne propagated) { + nPropagated += 1 + } + propagated + } + def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr + + def constPropExpression(e: Expression): Expression = { + val old = e map constPropExpression + val propagated = old match { + case p: DoPrim => constPropPrim(p) + case m: Mux => constPropMux(m) + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case x => x + } + propagated + } + + def constPropStmt(s: Statement): Statement = { + val stmtx = s map constPropStmt map constPropExpression + stmtx match { + case x: DefNode => nodeMap(x.name) = x.value + case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => + val exprx = constPropExpression(pad(expr, wtpe)) + nodeMap(wname) = exprx + case _ => + } + stmtx + } + + val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) + if (nPropagated > 0) constPropModule(res) else res + } + + def run(c: Circuit): Circuit = { + val modulesx = c.modules.map { + case m: ExtModule => m + case m: Module => constPropModule(m) + } + Circuit(c.info, modulesx, c.main) + } + + def execute(state: CircuitState): CircuitState = { + state.copy(circuit = run(state.circuit)) + } +} |
