aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorAndrew Waterman2016-04-08 13:19:53 -0700
committerAndrew Waterman2016-04-13 15:11:22 -0700
commit6fe81557f893081a7365d66c1711c6f21ef16284 (patch)
treea7f5d8f41281f665cf39695396c62fe34a8cdd4a /src/main
parent168392889d42506fae5a1aa637ebe1e61d799e62 (diff)
Add shift/concat constant propagation
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/passes/ConstProp.scala39
1 files changed, 28 insertions, 11 deletions
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala
index 96dab45b..848e50d0 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/passes/ConstProp.scala
@@ -100,18 +100,35 @@ object ConstProp extends Pass {
}
}
- private def constPropPrim(e: DoPrim): Expression = e.op match {
- case SHIFT_RIGHT_OP => {
- val amount = e.consts(0).toInt
- def shiftWidth(w: Width) = (w - IntWidth(amount)) max IntWidth(1)
- e.args(0) match {
- // TODO when amount >= x.width, return a zero-width wire
- case UIntValue(v, w) => UIntValue(v >> amount, shiftWidth(w))
- // take sign bit if shift amount is larger than arg width
- case SIntValue(v, w) => SIntValue(v >> amount, shiftWidth(w))
- case _ => e
- }
+ private def foldConcat(e: DoPrim) = (e.args(0), e.args(1)) match {
+ case (UIntValue(xv, IntWidth(xw)), UIntValue(yv, IntWidth(yw))) => UIntValue(xv << yw.toInt | yv, IntWidth(xw + yw))
+ case _ => e
+ }
+
+ private def foldShiftLeft(e: DoPrim) = e.consts(0).toInt match {
+ case 0 => e.args(0)
+ case x => e.args(0) match {
+ case UIntValue(v, IntWidth(w)) => UIntValue(v << x, IntWidth(w + x))
+ case SIntValue(v, IntWidth(w)) => SIntValue(v << x, IntWidth(w + x))
+ case _ => e
}
+ }
+
+ private def foldShiftRight(e: DoPrim) = e.consts(0).toInt match {
+ case 0 => e.args(0)
+ case x => e.args(0) match {
+ // TODO when amount >= x.width, return a zero-width wire
+ case UIntValue(v, IntWidth(w)) => UIntValue(v >> x, IntWidth((w - x) max 1))
+ // take sign bit if shift amount is larger than arg width
+ case SIntValue(v, IntWidth(w)) => SIntValue(v >> x, IntWidth((w - x) max 1))
+ case _ => e
+ }
+ }
+
+ private def constPropPrim(e: DoPrim): Expression = e.op match {
+ case SHIFT_LEFT_OP => foldShiftLeft(e)
+ case SHIFT_RIGHT_OP => foldShiftRight(e)
+ case CONCAT_OP => foldConcat(e)
case AND_OP => FoldAND(e)
case OR_OP => FoldOR(e)
case XOR_OP => FoldXOR(e)