// See LICENSE for license details. package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.PrimOps._ import annotation.tailrec object ConstProp extends Pass { private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e } private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) trait FoldLogicalOp { def fold(c1: Literal, c2: Literal): Expression def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { case (lhs: Literal, rhs: Literal) => fold(lhs, rhs) case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) case _ => e } } object FoldAND extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs case _ => e } } object FoldOR extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs case _ => e } } object FoldXOR extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) case _ => e } } object FoldEqual extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs case _ => e } } object FoldNotEqual extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs case _ => e } } private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match { case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) case _ => e } private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head case x => e.args.head match { case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) case _ => e } } private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head case x => e.args.head match { // TODO when amount >= x.width, return a zero-width wire case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) // take sign bit if shift amount is larger than arg width case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) case _ => e } } private def foldComparison(e: DoPrim) = { def foldIfZeroedArg(x: Expression): Expression = { def isUInt(e: Expression): Boolean = e.tpe match { case UIntType(_) => true case _ => false } def isZero(e: Expression) = e match { case UIntLiteral(value, _) => value == BigInt(0) case SIntLiteral(value, _) => value == BigInt(0) case _ => false } x match { case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one case ex => ex } } def foldIfOutsideRange(x: Expression): Expression = { //Note, only abides by a partial ordering case class Range(min: BigInt, max: BigInt) { def === (that: Range) = Seq(this.min, this.max, that.min, that.max) .sliding(2,1) .map(x => x.head == x(1)) .reduce(_ && _) def > (that: Range) = this.min > that.max def >= (that: Range) = this.min >= that.max def < (that: Range) = this.max < that.min def <= (that: Range) = this.max <= that.min } def range(e: Expression): Range = e match { case UIntLiteral(value, _) => Range(value, value) case SIntLiteral(value, _) => Range(value, value) case _ => e.tpe match { case SIntType(IntWidth(width)) => Range( min = BigInt(0) - BigInt(2).pow(width.toInt - 1), max = BigInt(2).pow(width.toInt - 1) - BigInt(1) ) case UIntType(IntWidth(width)) => Range( min = BigInt(0), max = BigInt(2).pow(width.toInt) - BigInt(1) ) } } // Calculates an expression's range of values x match { case ex: DoPrim => def r0 = range(ex.args.head) def r1 = range(ex.args(1)) ex.op match { // Always true case Lt if r0 < r1 => one case Leq if r0 <= r1 => one case Gt if r0 > r1 => one case Geq if r0 >= r1 => one // Always false case Lt if r0 >= r1 => zero case Leq if r0 > r1 => zero case Gt if r0 <= r1 => zero case Geq if r0 < r1 => zero case _ => ex } case ex => ex } } foldIfZeroedArg(foldIfOutsideRange(e)) } private def constPropPrim(e: DoPrim): Expression = e.op match { case Shl => foldShiftLeft(e) case Shr => foldShiftRight(e) case Cat => foldConcat(e) case And => FoldAND(e) case Or => FoldOR(e) case Xor => FoldXOR(e) case Eq => FoldEqual(e) case Neq => FoldNotEqual(e) case (Lt | Leq | Gt | Geq) => foldComparison(e) case Not => e.args.head match { case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) case _ => e } case AsUInt => e.args.head match { case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) case u: UIntLiteral => u case _ => e } case AsSInt => e.args.head match { case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) case s: SIntLiteral => s case _ => e } case Pad => e.args.head match { case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w)) case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w)) case _ if bitWidth(e.args.head.tpe) == e.consts.head => e.args.head case _ => e } case Bits => e.args.head match { case lit: Literal => val hi = e.consts.head.toInt val lo = e.consts(1).toInt require(hi >= lo) UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { case t: UIntType => x case _ => asUInt(x, e.tpe) } case _ => e } case _ => e } private def constPropMuxCond(m: Mux) = m.cond match { case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe) case _ => m } private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { case _ if m.tval == m.fval => m.tval case (t: UIntLiteral, f: UIntLiteral) => if (t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1)) m.cond else constPropMuxCond(m) case _ => constPropMuxCond(m) } private def constPropNodeRef(r: WRef, e: Expression) = e match { case _: UIntLiteral | _: SIntLiteral | _: WRef => e case _ => r } @tailrec private def constPropModule(m: Module): Module = { var nPropagated = 0L val nodeMap = collection.mutable.HashMap[String, Expression]() def constPropExpression(e: Expression): Expression = { val old = e map constPropExpression val propagated = old match { case p: DoPrim => constPropPrim(p) case m: Mux => constPropMux(m) case r: WRef if nodeMap contains r.name => constPropNodeRef(r, nodeMap(r.name)) case x => x } if (old ne propagated) nPropagated += 1 propagated } def constPropStmt(s: Statement): Statement = { s match { case x: DefNode => nodeMap(x.name) = x.value case _ => } s map constPropStmt map constPropExpression } val res = Module(m.info, m.name, m.ports, constPropStmt(m.body)) if (nPropagated > 0) constPropModule(res) else res } def run(c: Circuit): Circuit = { val modulesx = c.modules.map { case m: ExtModule => m case m: Module => constPropModule(m) } Circuit(c.info, modulesx, c.main) } }