aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala5
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala18
2 files changed, 22 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index da7f1a46..ed4ecd96 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -277,6 +277,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
+
private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = {
val old = e map constPropExpression(nodeMap, instMap, constSubOutputs)
val propagated = old match {
@@ -290,7 +291,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref)
case x => x
}
- propagated
+ // We're done when the Expression no longer changes
+ if (propagated eq old) propagated
+ else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated)
}
/** Constant propagate a Module
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 603ddc25..a6df1a3b 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -734,6 +734,24 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec {
""".stripMargin
(parse(exec(input))) should be(parse(check))
}
+
+ // Optimizing this mux gives: z <= pad(UInt<2>(0), 4)
+ // Thus this checks that we then optimize that pad
+ "ConstProp" should "optimize nested Expressions" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<4>
+ | z <= mux(UInt(1), UInt<2>(0), UInt<4>(0))
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<4>
+ | z <= UInt<4>("h0")
+ """.stripMargin
+ (parse(exec(input))) should be(parse(check))
+ }
}
// More sophisticated tests of the full compiler