aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala30
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala22
2 files changed, 28 insertions, 24 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index bb65201b..04bfb19c 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -183,19 +183,23 @@ object ExpandConnects extends Pass {
object Legalize extends Pass {
private def legalizeShiftRight(e: DoPrim): Expression = {
require(e.op == Shr)
- val amount = e.consts.head.toInt
- val width = bitWidth(e.args.head.tpe)
- lazy val msb = width - 1
- if (amount >= width) {
- e.tpe match {
- case UIntType(_) => zero
- case SIntType(_) =>
- val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
- DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
- case t => error(s"Unsupported type $t for Primop Shift Right")
- }
- } else {
- e
+ e.args.head match {
+ case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e)
+ case _ =>
+ val amount = e.consts.head.toInt
+ val width = bitWidth(e.args.head.tpe)
+ lazy val msb = width - 1
+ if (amount >= width) {
+ e.tpe match {
+ case UIntType(_) => zero
+ case SIntType(_) =>
+ val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
+ DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
+ case t => error(s"Unsupported type $t for Primop Shift Right")
+ }
+ } else {
+ e
+ }
}
}
private def legalizeBitExtract(expr: DoPrim): Expression = {
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 54338719..6618312a 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -45,6 +45,17 @@ object ConstantPropagation {
case _ => e
}
}
+
+ def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
+ case 0 => e.args.head
+ case x => e.args.head match {
+ // TODO when amount >= x.width, return a zero-width wire
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1))
+ // take sign bit if shift amount is larger than arg width
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1))
+ case _ => e
+ }
+ }
}
class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
@@ -144,17 +155,6 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case _ => e
}
- private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
- case 0 => e.args.head
- case x => e.args.head match {
- // TODO when amount >= x.width, return a zero-width wire
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1))
- // take sign bit if shift amount is larger than arg width
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1))
- case _ => e
- }
- }
-
private def foldDynamicShiftRight(e: DoPrim) = e.args.last match {
case UIntLiteral(v, IntWidth(w)) =>
val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType)