aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala24
1 files changed, 22 insertions, 2 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index da7f1a46..8a273476 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -67,7 +67,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
object FoldADD extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = (c1, c2) match {
+ def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match {
case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
}
@@ -137,6 +137,13 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
}
+ private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match {
+ case UIntLiteral(v, IntWidth(w)) =>
+ val shl = DoPrim(Shl, Seq(e.args.head), Seq(v), UnknownType)
+ pad(PrimOps.set_primop_type(shl), e.tpe)
+ case _ => e
+ }
+
private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
case 0 => e.args.head
case x => e.args.head match {
@@ -148,6 +155,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
}
+ private def foldDynamicShiftRight(e: DoPrim) = e.args.last match {
+ case UIntLiteral(v, IntWidth(w)) =>
+ val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType)
+ pad(PrimOps.set_primop_type(shr), e.tpe)
+ case _ => e
+ }
+
+
private def foldComparison(e: DoPrim) = {
def foldIfZeroedArg(x: Expression): Expression = {
def isUInt(e: Expression): Boolean = e.tpe match {
@@ -221,7 +236,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
private def constPropPrim(e: DoPrim): Expression = e.op match {
case Shl => foldShiftLeft(e)
+ case Dshl => foldDynamicShiftLeft(e)
case Shr => foldShiftRight(e)
+ case Dshr => foldDynamicShiftRight(e)
case Cat => foldConcat(e)
case Add => FoldADD(e)
case And => FoldAND(e)
@@ -277,6 +294,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
+
private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = {
val old = e map constPropExpression(nodeMap, instMap, constSubOutputs)
val propagated = old match {
@@ -290,7 +308,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref)
case x => x
}
- propagated
+ // We're done when the Expression no longer changes
+ if (propagated eq old) propagated
+ else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated)
}
/** Constant propagate a Module