/* Copyright (c) 2014 - 2016 The Regents of the University of California (Regents). All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following two paragraphs of disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following two paragraphs of disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the Regents nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 
*/

package firrtl.passes

import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import firrtl.PrimOps._

import annotation.tailrec

/** Constant propagation and peephole simplification of primops, muxes and node references.
  *
  * Each `Module` is rewritten repeatedly until a full sweep makes no further change;
  * `ExtModule`s are passed through untouched.
  */
object ConstProp extends Pass {
  def name = "Constant Propagation"

  /** Wraps `e` in a `pad` primop of type `t` when `e` is narrower than `t`; returns `e`
    * unchanged when the widths already match. A `t` narrower than `e` is not handled
    * (it fails with a MatchError).
    */
  private def pad(e: Expression, t: Type) = (long_BANG(e.tpe), long_BANG(t)) match {
    case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
    case (we, wt) if we == wt => e
  }

  private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t)

  /** A BigInt with the low `w` bits set, e.g. `onesMask(4) == 0xF`. */
  private def onesMask(w: BigInt): BigInt = (BigInt(1) << w.toInt) - 1

  /** Converts the surviving operand of a bitwise op into a UInt of the op's result type `t`.
    *
    * An SInt operand is padded (as SInt) to `t`'s width *before* the `asUInt` cast, so the
    * extended bits are copies of the sign bit. Casting first would yield a value narrower
    * than `t` whose upper bits get zero-filled. Non-SInt operands are returned unchanged;
    * `FoldLogicalOp.apply` pads the simplified result to the result width.
    */
  private def asSignExtendedUInt(e: Expression, t: Type): Expression = e.tpe match {
    case _: SIntType => asUInt(pad(e, SIntType(widthBANG(t))), t)
    case _ => e
  }

  /** Folds or simplifies a two-operand primop when at least one operand is a literal.
    *
    * Used for `and`, `or`, `xor`, `eq` and `neq`. All of these are commutative, so a literal
    * on either side is passed to `simplify` as `lhs`.
    */
  trait FoldLogicalOp {
    /** Evaluates the op on two literal operands. */
    def fold(c1: Literal, c2: Literal): Expression

    /** Simplifies `e` given its literal operand `lhs` and its other operand `rhs`;
      * returns `e` itself when no rule applies.
      */
    def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression

    def apply(e: DoPrim): Expression = (e.args(0), e.args(1)) match {
      case (lhs: Literal, rhs: Literal) => fold(lhs, rhs)
      case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe)
      case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe)
      case _ => e
    }

    /** Applies the bitwise `op` to two literals. The result is a UInt literal as wide as the
      * wider operand.
      *
      * BigInt bitwise operators treat negative (SInt) values as infinitely sign-extended, so
      * the raw result can itself be negative (e.g. `-1 | 3 == -1`). Masking to the result
      * width keeps the literal's value a non-negative bit pattern that fits its width.
      */
    protected def foldBitwise(c1: Literal, c2: Literal)(op: (BigInt, BigInt) => BigInt): UIntLiteral = {
      val width = c1.width max c2.width
      val bits = op(c1.value, c2.value)
      width match {
        case IntWidth(w) => UIntLiteral(bits & onesMask(w), width)
        case _ => UIntLiteral(bits, width)
      }
    }
  }

  object FoldAND extends FoldLogicalOp {
    def fold(c1: Literal, c2: Literal) = foldBitwise(c1, c2)(_ & _)
    def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
      // x & 0 == 0
      case UIntLiteral(v, w) if v == 0 => UIntLiteral(0, w)
      case SIntLiteral(v, w) if v == 0 => UIntLiteral(0, w)
      // x & (all ones over x's width) == x
      case UIntLiteral(v, IntWidth(_)) if v == onesMask(BigInt(long_BANG(rhs.tpe))) => rhs
      case _ => e
    }
  }

  object FoldOR extends FoldLogicalOp {
    def fold(c1: Literal, c2: Literal) = foldBitwise(c1, c2)(_ | _)
    def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
      // x | 0 == x
      case UIntLiteral(v, _) if v == 0 => rhs
      case SIntLiteral(v, _) if v == 0 => asSignExtendedUInt(rhs, e.tpe)
      // x | (all ones over x's width) == all ones
      case UIntLiteral(v, IntWidth(_)) if v == onesMask(BigInt(long_BANG(rhs.tpe))) => lhs
      case _ => e
    }
  }

  object FoldXOR extends FoldLogicalOp {
    def fold(c1: Literal, c2: Literal) = foldBitwise(c1, c2)(_ ^ _)
    def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
      // x ^ 0 == x
      case UIntLiteral(v, _) if v == 0 => rhs
      case SIntLiteral(v, _) if v == 0 => asSignExtendedUInt(rhs, e.tpe)
      case _ => e
    }
  }

  object FoldEqual extends FoldLogicalOp {
    def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
    def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
      // For a 1-bit x: (x == 1) == x
      case UIntLiteral(v, IntWidth(w)) if v == 1 && w == 1 && long_BANG(rhs.tpe) == 1 => rhs
      case _ => e
    }
  }

  object FoldNotEqual extends FoldLogicalOp {
    def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
    def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
      // For a 1-bit x: (x != 0) == x
      case UIntLiteral(v, IntWidth(w)) if v == 0 && w == 1 && long_BANG(rhs.tpe) == 1 => rhs
      case _ => e
    }
  }

  /** Folds `cat` of two UInt literals into one literal whose width is the sum of both widths. */
  private def foldConcat(e: DoPrim) = (e.args(0), e.args(1)) match {
    case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) =>
      UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw))
    case _ => e
  }

  /** Folds `shl`: a zero shift is the identity; a literal is shifted and widened by the amount. */
  private def foldShiftLeft(e: DoPrim) = e.consts(0).toInt match {
    case 0 => e.args(0)
    case x => e.args(0) match {
      case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x))
      case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x))
      case _ => e
    }
  }

  /** Folds `shr`: a zero shift is the identity; a literal is shifted and narrowed (to at least 1 bit). */
  private def foldShiftRight(e: DoPrim) = e.consts(0).toInt match {
    case 0 => e.args(0)
    case x => e.args(0) match {
      // TODO when amount >= x.width, return a zero-width wire
      case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1))
      // take sign bit if shift amount is larger than arg width
      case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1))
      case _ => e
    }
  }

  /** Folds `lt`/`leq`/`gt`/`geq` whose outcome is fixed by operand ranges or by an unsigned-vs-zero test. */
  private def foldComparison(e: DoPrim) = {
    // Comparisons of a UInt against literal zero, e.g. `x < 0` is always false for unsigned x.
    def foldIfZeroedArg(x: Expression): Expression = {
      def isUInt(e: Expression): Boolean = tpe(e) match {
        case UIntType(_) => true
        case _ => false
      }
      def isZero(e: Expression) = e match {
        case UIntLiteral(value, _) => value == BigInt(0)
        case SIntLiteral(value, _) => value == BigInt(0)
        case _ => false
      }
      x match {
        case DoPrim(Lt, Seq(a, b), _, _) if isUInt(a) && isZero(b) => zero
        case DoPrim(Leq, Seq(a, b), _, _) if isZero(a) && isUInt(b) => one
        case DoPrim(Gt, Seq(a, b), _, _) if isZero(a) && isUInt(b) => zero
        case DoPrim(Geq, Seq(a, b), _, _) if isUInt(a) && isZero(b) => one
        case e => e
      }
    }

    // Comparisons whose result is the same for every value the operands can take.
    def foldIfOutsideRange(x: Expression): Expression = {
      // Note: these operators only define a partial ordering; overlapping ranges satisfy none of them.
      case class Range(min: BigInt, max: BigInt) {
        def >(that: Range) = this.min > that.max
        def >=(that: Range) = this.min >= that.max
        def <(that: Range) = this.max < that.min
        def <=(that: Range) = this.max <= that.min
      }

      // Inclusive value range of a literal, or of a sized UInt/SInt expression.
      def range(e: Expression): Range = e match {
        case UIntLiteral(value, _) => Range(value, value)
        case SIntLiteral(value, _) => Range(value, value)
        case _ => tpe(e) match {
          case SIntType(IntWidth(width)) => Range(
            min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
            max = BigInt(2).pow(width.toInt - 1) - BigInt(1)
          )
          case UIntType(IntWidth(width)) => Range(
            min = BigInt(0),
            max = BigInt(2).pow(width.toInt) - BigInt(1)
          )
        }
      }

      x match {
        case e: DoPrim =>
          def r0 = range(e.args(0))
          def r1 = range(e.args(1))
          e.op match {
            // Always true
            case Lt if r0 < r1 => one
            case Leq if r0 <= r1 => one
            case Gt if r0 > r1 => one
            case Geq if r0 >= r1 => one
            // Always false
            case Lt if r0 >= r1 => zero
            case Leq if r0 > r1 => zero
            case Gt if r0 <= r1 => zero
            case Geq if r0 < r1 => zero
            case _ => e
          }
        case e => e
      }
    }

    foldIfZeroedArg(foldIfOutsideRange(e))
  }

  /** Applies the fold/simplification rule for `e`'s primop; returns `e` itself when none applies. */
  private def constPropPrim(e: DoPrim): Expression = e.op match {
    case Shl => foldShiftLeft(e)
    case Shr => foldShiftRight(e)
    case Cat => foldConcat(e)
    case And => FoldAND(e)
    case Or => FoldOR(e)
    case Xor => FoldXOR(e)
    case Eq => FoldEqual(e)
    case Neq => FoldNotEqual(e)
    case (Lt | Leq | Gt | Geq) => foldComparison(e)
    case Not => e.args(0) match {
      case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ onesMask(w), IntWidth(w))
      case _ => e
    }
    case AsUInt => e.args(0) match {
      // Two's-complement reinterpretation: a negative value gains 2^w.
      case SIntLiteral(v, IntWidth(w)) =>
        UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w))
      case u: UIntLiteral => u
      case _ => e
    }
    case AsSInt => e.args(0) match {
      // Subtract 2^w when the top bit is set to recover the signed value.
      case UIntLiteral(v, IntWidth(w)) =>
        SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w))
      case s: SIntLiteral => s
      case _ => e
    }
    case Pad => e.args(0) match {
      // pad never narrows: the literal keeps max(own width, requested width).
      case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(w max e.consts(0)))
      case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(w max e.consts(0)))
      // Padding to a width the argument already meets or exceeds is a no-op.
      case x if e.consts(0) <= long_BANG(tpe(x)) => x
      case _ => e
    }
    case Bits => e.args(0) match {
      case lit: Literal =>
        val hi = e.consts(0).toInt
        val lo = e.consts(1).toInt
        require(hi >= lo)
        UIntLiteral((lit.value >> lo) & onesMask(BigInt(hi - lo + 1)), widthBANG(tpe(e)))
      // Extracting every bit of x is just a reinterpretation of x as a UInt.
      case x if long_BANG(tpe(e)) == long_BANG(tpe(x)) => tpe(x) match {
        case _: UIntType => x
        case _ => asUInt(x, e.tpe)
      }
      case _ => e
    }
    case _ => e
  }

  /** Replaces a mux with a literal condition by the selected branch, padded to the mux's type. */
  private def constPropMuxCond(m: Mux) = m.cond match {
    case UIntLiteral(c, _) => pad(if (c == 1) m.tval else m.fval, m.tpe)
    case _ => m
  }

  /** Simplifies `mux(c, x, x)` to `x` and a 1-bit `mux(c, 1, 0)` to `c`, then tries a literal condition. */
  private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match {
    case _ if m.tval == m.fval => m.tval
    case (t: UIntLiteral, f: UIntLiteral) =>
      if (t.value == 1 && f.value == 0 && long_BANG(m.tpe) == 1) m.cond
      else constPropMuxCond(m)
    case _ => constPropMuxCond(m)
  }

  /** Substitutes a node reference by the node's value when that value is a literal or another reference. */
  private def constPropNodeRef(r: WRef, e: Expression) = e match {
    case _: UIntLiteral | _: SIntLiteral | _: WRef => e
    case _ => r
  }

  /** Propagates constants through `m`, repeating full sweeps until one rewrites nothing. */
  @tailrec
  private def constPropModule(m: Module): Module = {
    // Number of expressions rewritten in this sweep; zero means a fixed point was reached.
    var nPropagated = 0L
    // Node name -> that node's value as already propagated in this sweep.
    val nodeMap = collection.mutable.HashMap[String, Expression]()

    def constPropExpression(e: Expression): Expression = {
      val old = e map constPropExpression
      val propagated = old match {
        case p: DoPrim => constPropPrim(p)
        case m: Mux => constPropMux(m)
        case r: WRef if nodeMap contains r.name => constPropNodeRef(r, nodeMap(r.name))
        case x => x
      }
      // Reference inequality: the rules above return their input object when they do not fire.
      if (old ne propagated) nPropagated += 1
      propagated
    }

    def constPropStmt(s: Statement): Statement = {
      val sx = s map constPropStmt map constPropExpression
      // Record the node's value *after* propagation, so later references in this same sweep
      // see the folded value instead of waiting for another full pass over the module.
      sx match {
        case x: DefNode => nodeMap(x.name) = x.value
        case _ =>
      }
      sx
    }

    val res = Module(m.info, m.name, m.ports, constPropStmt(m.body))
    if (nPropagated > 0) constPropModule(res) else res
  }

  def run(c: Circuit): Circuit = {
    val modulesx = c.modules.map {
      case m: ExtModule => m
      case m: Module => constPropModule(m)
    }
    Circuit(c.info, modulesx, c.main)
  }
}