aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes
diff options
context:
space:
mode:
authorJack Koenig2017-06-28 17:52:56 -0700
committerJack Koenig2017-06-28 17:52:56 -0700
commit39665e1f74cfe8243067442cccf4e7eab66ade68 (patch)
tree8ba403e298c39bc6104f32a93754079dc458752a /src/main/scala/firrtl/passes
parent818cfde4ad42ffa9ee30d0f9ae72533ede80e4ce (diff)
Promote ConstProp to a transform
Diffstat (limited to 'src/main/scala/firrtl/passes')
-rw-r--r--src/main/scala/firrtl/passes/ConstProp.scala295
1 files changed, 0 insertions, 295 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala
deleted file mode 100644
index f2aa1a03..00000000
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ /dev/null
@@ -1,295 +0,0 @@
-// See LICENSE for license details.
-
-package firrtl.passes
-
-import firrtl._
-import firrtl.ir._
-import firrtl.Utils._
-import firrtl.Mappers._
-import firrtl.PrimOps._
-
-import annotation.tailrec
-
-object ConstProp extends Pass {
- private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match {
- case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
- case (we, wt) if we == wt => e
- }
-
- private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t)
-
- trait FoldLogicalOp {
- def fold(c1: Literal, c2: Literal): Expression
- def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression
-
- def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match {
- case (lhs: Literal, rhs: Literal) => fold(lhs, rhs)
- case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe)
- case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe)
- case _ => e
- }
- }
-
- object FoldAND extends FoldLogicalOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width)
- def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
- case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
- case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs
- case _ => e
- }
- }
-
- object FoldOR extends FoldLogicalOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width)
- def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, _) if v == BigInt(0) => rhs
- case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
- case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs
- case _ => e
- }
- }
-
- object FoldXOR extends FoldLogicalOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width)
- def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, _) if v == BigInt(0) => rhs
- case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
- case _ => e
- }
- }
-
- object FoldEqual extends FoldLogicalOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
- def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
- case _ => e
- }
- }
-
- object FoldNotEqual extends FoldLogicalOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
- def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
- case _ => e
- }
- }
-
- private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match {
- case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw))
- case _ => e
- }
-
- private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match {
- case 0 => e.args.head
- case x => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x))
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x))
- case _ => e
- }
- }
-
- private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
- case 0 => e.args.head
- case x => e.args.head match {
- // TODO when amount >= x.width, return a zero-width wire
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1))
- // take sign bit if shift amount is larger than arg width
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1))
- case _ => e
- }
- }
-
- private def foldComparison(e: DoPrim) = {
- def foldIfZeroedArg(x: Expression): Expression = {
- def isUInt(e: Expression): Boolean = e.tpe match {
- case UIntType(_) => true
- case _ => false
- }
- def isZero(e: Expression) = e match {
- case UIntLiteral(value, _) => value == BigInt(0)
- case SIntLiteral(value, _) => value == BigInt(0)
- case _ => false
- }
- x match {
- case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero
- case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one
- case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero
- case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one
- case ex => ex
- }
- }
-
- def foldIfOutsideRange(x: Expression): Expression = {
- //Note, only abides by a partial ordering
- case class Range(min: BigInt, max: BigInt) {
- def === (that: Range) =
- Seq(this.min, this.max, that.min, that.max)
- .sliding(2,1)
- .map(x => x.head == x(1))
- .reduce(_ && _)
- def > (that: Range) = this.min > that.max
- def >= (that: Range) = this.min >= that.max
- def < (that: Range) = this.max < that.min
- def <= (that: Range) = this.max <= that.min
- }
- def range(e: Expression): Range = e match {
- case UIntLiteral(value, _) => Range(value, value)
- case SIntLiteral(value, _) => Range(value, value)
- case _ => e.tpe match {
- case SIntType(IntWidth(width)) => Range(
- min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
- max = BigInt(2).pow(width.toInt - 1) - BigInt(1)
- )
- case UIntType(IntWidth(width)) => Range(
- min = BigInt(0),
- max = BigInt(2).pow(width.toInt) - BigInt(1)
- )
- }
- }
- // Calculates an expression's range of values
- x match {
- case ex: DoPrim =>
- def r0 = range(ex.args.head)
- def r1 = range(ex.args(1))
- ex.op match {
- // Always true
- case Lt if r0 < r1 => one
- case Leq if r0 <= r1 => one
- case Gt if r0 > r1 => one
- case Geq if r0 >= r1 => one
- // Always false
- case Lt if r0 >= r1 => zero
- case Leq if r0 > r1 => zero
- case Gt if r0 <= r1 => zero
- case Geq if r0 < r1 => zero
- case _ => ex
- }
- case ex => ex
- }
- }
- foldIfZeroedArg(foldIfOutsideRange(e))
- }
-
- private def constPropPrim(e: DoPrim): Expression = e.op match {
- case Shl => foldShiftLeft(e)
- case Shr => foldShiftRight(e)
- case Cat => foldConcat(e)
- case And => FoldAND(e)
- case Or => FoldOR(e)
- case Xor => FoldXOR(e)
- case Eq => FoldEqual(e)
- case Neq => FoldNotEqual(e)
- case (Lt | Leq | Gt | Geq) => foldComparison(e)
- case Not => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
- case _ => e
- }
- case AsUInt => e.args.head match {
- case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w))
- case u: UIntLiteral => u
- case _ => e
- }
- case AsSInt => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w))
- case s: SIntLiteral => s
- case _ => e
- }
- case Pad => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w))
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w))
- case _ if bitWidth(e.args.head.tpe) == e.consts.head => e.args.head
- case _ => e
- }
- case Bits => e.args.head match {
- case lit: Literal =>
- val hi = e.consts.head.toInt
- val lo = e.consts(1).toInt
- require(hi >= lo)
- UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe))
- case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match {
- case t: UIntType => x
- case _ => asUInt(x, e.tpe)
- }
- case _ => e
- }
- case _ => e
- }
-
- private def constPropMuxCond(m: Mux) = m.cond match {
- case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe)
- case _ => m
- }
-
- private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match {
- case _ if m.tval == m.fval => m.tval
- case (t: UIntLiteral, f: UIntLiteral) =>
- if (t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1)) m.cond
- else constPropMuxCond(m)
- case _ => constPropMuxCond(m)
- }
-
- private def constPropNodeRef(r: WRef, e: Expression) = e match {
- case _: UIntLiteral | _: SIntLiteral | _: WRef => e
- case _ => r
- }
-
- // Two pass process
- // 1. Propagate constants in expressions and forward propagate references
- // 2. Propagate references again for backwards reference (Wires)
- // TODO Replacing all wires with nodes makes the second pass unnecessary
- @tailrec
- private def constPropModule(m: Module): Module = {
- var nPropagated = 0L
- val nodeMap = collection.mutable.HashMap[String, Expression]()
-
- def backPropExpr(expr: Expression): Expression = {
- val old = expr map backPropExpr
- val propagated = old match {
- case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
- constPropNodeRef(ref, nodeMap(rname))
- case x => x
- }
- if (old ne propagated) {
- nPropagated += 1
- }
- propagated
- }
- def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr
-
- def constPropExpression(e: Expression): Expression = {
- val old = e map constPropExpression
- val propagated = old match {
- case p: DoPrim => constPropPrim(p)
- case m: Mux => constPropMux(m)
- case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
- constPropNodeRef(ref, nodeMap(rname))
- case x => x
- }
- propagated
- }
-
- def constPropStmt(s: Statement): Statement = {
- val stmtx = s map constPropStmt map constPropExpression
- stmtx match {
- case x: DefNode => nodeMap(x.name) = x.value
- case Connect(_, WRef(wname, wtpe, WireKind, _), expr) =>
- val exprx = constPropExpression(pad(expr, wtpe))
- nodeMap(wname) = exprx
- case _ =>
- }
- stmtx
- }
-
- val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body)))
- if (nPropagated > 0) constPropModule(res) else res
- }
-
- def run(c: Circuit): Circuit = {
- val modulesx = c.modules.map {
- case m: ExtModule => m
- case m: Module => constPropModule(m)
- }
- Circuit(c.info, modulesx, c.main)
- }
-}