aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/ConstantPropagation.scala
diff options
context:
space:
mode:
authoralbertchen-sifive2018-07-20 14:36:30 -0700
committerAdam Izraelevitz2018-07-20 14:36:30 -0700
commit7dff927840a30893facae957595a8e88ea62509a (patch)
tree08210d9b2936fc4606ae8a0fe1c9f12a8c7c673e /src/main/scala/firrtl/transforms/ConstantPropagation.scala
parent897dad039a12a49b3c4ae833fbf0d02087b26ed5 (diff)
Constant prop add (#849)
* add FoldADD to const prop, add yosys miter tests * add option for verilog compiler without optimizations * rename FoldLogicalOp to FoldCommutativeOp * add GetNamespace and RenameModules, GetNamespace stores namespace as a ModuleNamespaceAnnotation * add constant propagation for Tail DoPrims * add scaladocs for MinimumLowFirrtlOptimization and yosysExpectFalure/Success, add constant propagation for Head DoPrim * add legalize pass to MinimumLowFirrtlOptimizations, use constPropBitExtract in legalize pass
Diffstat (limited to 'src/main/scala/firrtl/transforms/ConstantPropagation.scala')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala61
1 files changed, 41 insertions, 20 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 4a4f41d1..0d30446c 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -17,12 +17,33 @@ import annotation.tailrec
import collection.mutable
object ConstantPropagation {
+ private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t)
+
/** Pads e to the width of t */
def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match {
case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
case (we, wt) if we == wt => e
}
+ def constPropBitExtract(e: DoPrim) = {
+ val arg = e.args.head
+ val (hi, lo) = e.op match {
+ case Bits => (e.consts.head.toInt, e.consts(1).toInt)
+ case Tail => ((bitWidth(arg.tpe) - 1 - e.consts.head).toInt, 0)
+ case Head => ((bitWidth(arg.tpe) - 1).toInt, (bitWidth(arg.tpe) - e.consts.head).toInt)
+ }
+
+ arg match {
+ case lit: Literal =>
+ require(hi >= lo)
+ UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe))
+ case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match {
+ case t: UIntType => x
+ case _ => asUInt(x, e.tpe)
+ }
+ case _ => e
+ }
+ }
}
class ConstantPropagation extends Transform {
@@ -30,9 +51,7 @@ class ConstantPropagation extends Transform {
def inputForm = LowForm
def outputForm = LowForm
- private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t)
-
- trait FoldLogicalOp {
+ trait FoldCommutativeOp {
def fold(c1: Literal, c2: Literal): Expression
def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression
@@ -44,7 +63,19 @@ class ConstantPropagation extends Transform {
}
}
- object FoldAND extends FoldLogicalOp {
+ object FoldADD extends FoldCommutativeOp {
+ def fold(c1: Literal, c2: Literal) = (c1, c2) match {
+ case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
+ case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
+ }
+ def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
+ case UIntLiteral(v, w) if v == BigInt(0) => rhs
+ case SIntLiteral(v, w) if v == BigInt(0) => rhs
+ case _ => e
+ }
+ }
+
+ object FoldAND extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
@@ -54,7 +85,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldOR extends FoldLogicalOp {
+ object FoldOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
@@ -64,7 +95,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldXOR extends FoldLogicalOp {
+ object FoldXOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
@@ -73,7 +104,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldEqual extends FoldLogicalOp {
+ object FoldEqual extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
@@ -81,7 +112,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldNotEqual extends FoldLogicalOp {
+ object FoldNotEqual extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
@@ -189,6 +220,7 @@ class ConstantPropagation extends Transform {
case Shl => foldShiftLeft(e)
case Shr => foldShiftRight(e)
case Cat => foldConcat(e)
+ case Add => FoldADD(e)
case And => FoldAND(e)
case Or => FoldOR(e)
case Xor => FoldXOR(e)
@@ -215,18 +247,7 @@ class ConstantPropagation extends Transform {
case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
case _ => e
}
- case Bits => e.args.head match {
- case lit: Literal =>
- val hi = e.consts.head.toInt
- val lo = e.consts(1).toInt
- require(hi >= lo)
- UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe))
- case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match {
- case t: UIntType => x
- case _ => asUInt(x, e.tpe)
- }
- case _ => e
- }
+ case (Bits | Head | Tail) => constPropBitExtract(e)
case _ => e
}