diff options
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 24 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 89 |
2 files changed, 111 insertions, 2 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index da7f1a46..8a273476 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -67,7 +67,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } object FoldADD extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = (c1, c2) match { + def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match { case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) } @@ -137,6 +137,13 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } } + private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match { + case UIntLiteral(v, IntWidth(w)) => + val shl = DoPrim(Shl, Seq(e.args.head), Seq(v), UnknownType) + pad(PrimOps.set_primop_type(shl), e.tpe) + case _ => e + } + private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head case x => e.args.head match { @@ -148,6 +155,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } } + private def foldDynamicShiftRight(e: DoPrim) = e.args.last match { + case UIntLiteral(v, IntWidth(w)) => + val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType) + pad(PrimOps.set_primop_type(shr), e.tpe) + case _ => e + } + + private def foldComparison(e: DoPrim) = { def foldIfZeroedArg(x: Expression): Expression = { def isUInt(e: Expression): Boolean = e.tpe match { @@ -221,7 +236,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { private def constPropPrim(e: DoPrim): Expression = e.op match { case Shl => foldShiftLeft(e) + case Dshl => foldDynamicShiftLeft(e) case Shr => foldShiftRight(e) + case Dshr => foldDynamicShiftRight(e) case Cat => foldConcat(e) case Add => FoldADD(e) case And => FoldAND(e) @@ -277,6 +294,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) + private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = { val old = e map constPropExpression(nodeMap, instMap, constSubOutputs) val propagated = old match { @@ -290,7 +308,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref) case x => x } - propagated + // We're done when the Expression no longer changes + if (propagated eq old) propagated + else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated) } /** Constant propagate a Module diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 603ddc25..8a69fcaa 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -734,6 +734,24 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { """.stripMargin (parse(exec(input))) should be(parse(check)) } + + // Optimizing this mux gives: z <= pad(UInt<2>(0), 4) + // Thus this checks that we then optimize that pad + "ConstProp" should "optimize nested Expressions" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<4> + | z <= mux(UInt(1), UInt<2>(0), UInt<4>(0)) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<4> + | z <= UInt<4>("h0") + """.stripMargin + (parse(exec(input))) should be(parse(check)) + } } // More sophisticated tests of the full compiler @@ -1104,6 +1122,77 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { | z <= _T_61""".stripMargin execute(input, check, Seq.empty) } + + behavior of "ConstProp" + + it should "optimize shl of constants" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<7> + | z <= shl(UInt(5), 4) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<7> + | z <= UInt<7>("h50") + """.stripMargin + execute(input, check, Seq.empty) + } + + it should "optimize shr of constants" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<1> + | z <= shr(UInt(5), 2) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<1> + | z <= UInt<1>("h1") + """.stripMargin + execute(input, check, Seq.empty) + } + + // Due to #866, we need dshl optimized away or it'll become a dshlw and error in parsing + // Include cat to verify width is correct + it should "optimize dshl of constant" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<8> + | node n = dshl(UInt<1>(0), UInt<2>(0)) + | z <= cat(UInt<4>("hf"), n) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<8> + | z <= UInt<8>("hf0") + """.stripMargin + execute(input, check, Seq.empty) + } + + // Include cat and constants to verify width is correct + it should "optimize dshr of constant" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<8> + | node n = dshr(UInt<4>(0), UInt<2>(2)) + | z <= cat(UInt<4>("hf"), n) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<8> + | z <= UInt<8>("hf0") + """.stripMargin + execute(input, check, Seq.empty) + } } |
