diff options
| author | Jack Koenig | 2017-06-28 17:52:56 -0700 |
|---|---|---|
| committer | Jack Koenig | 2017-06-28 17:52:56 -0700 |
| commit | 39665e1f74cfe8243067442cccf4e7eab66ade68 (patch) | |
| tree | 8ba403e298c39bc6104f32a93754079dc458752a /src/main/scala/firrtl/passes | |
| parent | 818cfde4ad42ffa9ee30d0f9ae72533ede80e4ce (diff) | |
Promote ConstProp to a transform
Diffstat (limited to 'src/main/scala/firrtl/passes')
| -rw-r--r-- | src/main/scala/firrtl/passes/ConstProp.scala | 295 |
1 files changed, 0 insertions, 295 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala deleted file mode 100644 index f2aa1a03..00000000 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ /dev/null @@ -1,295 +0,0 @@ -// See LICENSE for license details. - -package firrtl.passes - -import firrtl._ -import firrtl.ir._ -import firrtl.Utils._ -import firrtl.Mappers._ -import firrtl.PrimOps._ - -import annotation.tailrec - -object ConstProp extends Pass { - private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { - case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) - case (we, wt) if we == wt => e - } - - private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) - - trait FoldLogicalOp { - def fold(c1: Literal, c2: Literal): Expression - def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression - - def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { - case (lhs: Literal, rhs: Literal) => fold(lhs, rhs) - case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) - case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) - case _ => e - } - } - - object FoldAND extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) - case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) - case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs - case _ => e - } - } - - object FoldOR extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) - case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs - case _ => e - } - } - - object FoldXOR extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) - case _ => e - } - } - - object FoldEqual extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case _ => e - } - } - - object FoldNotEqual extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case _ => e - } - } - - private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match { - case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) - case _ => e - } - - private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match { - case 0 => e.args.head - case x => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) - case _ => e - } - } - - private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { - case 0 => e.args.head - case x => e.args.head match { - // TODO when amount >= x.width, return a zero-width wire - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) - // take sign bit if shift amount is larger than arg width - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) - case _ => e - } - } - - private def foldComparison(e: DoPrim) = { - def foldIfZeroedArg(x: Expression): Expression = { - def isUInt(e: Expression): Boolean = e.tpe match { - case UIntType(_) => true - case _ => false - } - def isZero(e: Expression) = e match { - case UIntLiteral(value, _) => value == BigInt(0) - case SIntLiteral(value, _) => value == BigInt(0) - case _ => false - } - x match { - case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero - case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one - case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero - case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one - case ex => ex - } - } - - def foldIfOutsideRange(x: Expression): Expression = { - //Note, only abides by a partial ordering - case class Range(min: BigInt, max: BigInt) { - def === (that: Range) = - Seq(this.min, this.max, that.min, that.max) - .sliding(2,1) - .map(x => x.head == x(1)) - .reduce(_ && _) - def > (that: Range) = this.min > that.max - def >= (that: Range) = this.min >= that.max - def < (that: Range) = this.max < that.min - def <= (that: Range) = this.max <= that.min - } - def range(e: Expression): Range = e match { - case UIntLiteral(value, _) => Range(value, value) - case SIntLiteral(value, _) => Range(value, value) - case _ => e.tpe match { - case SIntType(IntWidth(width)) => Range( - min = BigInt(0) - BigInt(2).pow(width.toInt - 1), - max = BigInt(2).pow(width.toInt - 1) - BigInt(1) - ) - case UIntType(IntWidth(width)) => Range( - min = BigInt(0), - max = BigInt(2).pow(width.toInt) - BigInt(1) - ) - } - } - // Calculates an expression's range of values - x match { - case ex: DoPrim => - def r0 = range(ex.args.head) - def r1 = range(ex.args(1)) - ex.op match { - // Always true - case Lt if r0 < r1 => one - case Leq if r0 <= r1 => one - case Gt if r0 > r1 => one - case Geq if r0 >= r1 => one - // Always false - case Lt if r0 >= r1 => zero - case Leq if r0 > r1 => zero - case Gt if r0 <= r1 => zero - case Geq if r0 < r1 => zero - case _ => ex - } - case ex => ex - } - } - foldIfZeroedArg(foldIfOutsideRange(e)) - } - - private def constPropPrim(e: DoPrim): Expression = e.op match { - case Shl => foldShiftLeft(e) - case Shr => foldShiftRight(e) - case Cat => foldConcat(e) - case And => FoldAND(e) - case Or => FoldOR(e) - case Xor => FoldXOR(e) - case Eq => FoldEqual(e) - case Neq => FoldNotEqual(e) - case (Lt | Leq | Gt | Geq) => foldComparison(e) - case Not => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) - case _ => e - } - case AsUInt => e.args.head match { - case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) - case u: UIntLiteral => u - case _ => e - } - case AsSInt => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) - case s: SIntLiteral => s - case _ => e - } - case Pad => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w)) - case _ if bitWidth(e.args.head.tpe) == e.consts.head => e.args.head - case _ => e - } - case Bits => e.args.head match { - case lit: Literal => - val hi = e.consts.head.toInt - val lo = e.consts(1).toInt - require(hi >= lo) - UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) - case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { - case t: UIntType => x - case _ => asUInt(x, e.tpe) - } - case _ => e - } - case _ => e - } - - private def constPropMuxCond(m: Mux) = m.cond match { - case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe) - case _ => m - } - - private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { - case _ if m.tval == m.fval => m.tval - case (t: UIntLiteral, f: UIntLiteral) => - if (t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1)) m.cond - else constPropMuxCond(m) - case _ => constPropMuxCond(m) - } - - private def constPropNodeRef(r: WRef, e: Expression) = e match { - case _: UIntLiteral | _: SIntLiteral | _: WRef => e - case _ => r - } - - // Two pass process - // 1. Propagate constants in expressions and forward propagate references - // 2. Propagate references again for backwards reference (Wires) - // TODO Replacing all wires with nodes makes the second pass unnecessary - @tailrec - private def constPropModule(m: Module): Module = { - var nPropagated = 0L - val nodeMap = collection.mutable.HashMap[String, Expression]() - - def backPropExpr(expr: Expression): Expression = { - val old = expr map backPropExpr - val propagated = old match { - case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) - case x => x - } - if (old ne propagated) { - nPropagated += 1 - } - propagated - } - def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr - - def constPropExpression(e: Expression): Expression = { - val old = e map constPropExpression - val propagated = old match { - case p: DoPrim => constPropPrim(p) - case m: Mux => constPropMux(m) - case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) - case x => x - } - propagated - } - - def constPropStmt(s: Statement): Statement = { - val stmtx = s map constPropStmt map constPropExpression - stmtx match { - case x: DefNode => nodeMap(x.name) = x.value - case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => - val exprx = constPropExpression(pad(expr, wtpe)) - nodeMap(wname) = exprx - case _ => - } - stmtx - } - - val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) - if (nPropagated > 0) constPropModule(res) else res - } - - def run(c: Circuit): Circuit = { - val modulesx = c.modules.map { - case m: ExtModule => m - case m: Module => constPropModule(m) - } - Circuit(c.info, modulesx, c.main) - } -} |
