diff options
Diffstat (limited to 'src/main/scala/firrtl/transforms/ConstantPropagation.scala')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 61 |
1 files changed, 41 insertions, 20 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 4a4f41d1..0d30446c 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -17,12 +17,33 @@ import annotation.tailrec import collection.mutable object ConstantPropagation { + private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) + /** Pads e to the width of t */ def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e } + def constPropBitExtract(e: DoPrim) = { + val arg = e.args.head + val (hi, lo) = e.op match { + case Bits => (e.consts.head.toInt, e.consts(1).toInt) + case Tail => ((bitWidth(arg.tpe) - 1 - e.consts.head).toInt, 0) + case Head => ((bitWidth(arg.tpe) - 1).toInt, (bitWidth(arg.tpe) - e.consts.head).toInt) + } + + arg match { + case lit: Literal => + require(hi >= lo) + UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) + case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { + case t: UIntType => x + case _ => asUInt(x, e.tpe) + } + case _ => e + } + } } class ConstantPropagation extends Transform { @@ -30,9 +51,7 @@ class ConstantPropagation extends Transform { def inputForm = LowForm def outputForm = LowForm - private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) - - trait FoldLogicalOp { + trait FoldCommutativeOp { def fold(c1: Literal, c2: Literal): Expression def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression @@ -44,7 +63,19 @@ class ConstantPropagation extends Transform { } } - object FoldAND extends FoldLogicalOp { + object FoldADD extends FoldCommutativeOp { + def fold(c1: Literal, c2: Literal) = (c1, c2) match { + case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) + case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) + } + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, w) if v == BigInt(0) => rhs + case SIntLiteral(v, w) if v == BigInt(0) => rhs + case _ => e + } + } + + object FoldAND extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) @@ -54,7 +85,7 @@ class ConstantPropagation extends Transform { } } - object FoldOR extends FoldLogicalOp { + object FoldOR extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs @@ -64,7 +95,7 @@ class ConstantPropagation extends Transform { } } - object FoldXOR extends FoldLogicalOp { + object FoldXOR extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs @@ -73,7 +104,7 @@ class ConstantPropagation extends Transform { } } - object FoldEqual extends FoldLogicalOp { + object FoldEqual extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs @@ -81,7 +112,7 @@ class ConstantPropagation extends Transform { } } - object FoldNotEqual extends FoldLogicalOp { + object FoldNotEqual extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs @@ -189,6 +220,7 @@ class ConstantPropagation extends Transform { case Shl => foldShiftLeft(e) case Shr => foldShiftRight(e) case Cat => foldConcat(e) + case Add => FoldADD(e) case And => FoldAND(e) case Or => FoldOR(e) case Xor => FoldXOR(e) @@ -215,18 +247,7 @@ class ConstantPropagation extends Transform { case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head case _ => e } - case Bits => e.args.head match { - case lit: Literal => - val hi = e.consts.head.toInt - val lo = e.consts(1).toInt - require(hi >= lo) - UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) - case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { - case t: UIntType => x - case _ => asUInt(x, e.tpe) - } - case _ => e - } + case (Bits | Head | Tail) => constPropBitExtract(e) case _ => e } |
