From f5a42ce22193a038008a1c4f80618e38f72b40f1 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Sun, 13 Jan 2019 21:05:17 -0800 Subject: Keep constant propagating expressions until done optimizing --- .../scala/firrtl/transforms/ConstantPropagation.scala | 5 ++++- .../scala/firrtlTests/ConstantPropagationTests.scala | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index da7f1a46..ed4ecd96 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -277,6 +277,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) + private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = { val old = e map constPropExpression(nodeMap, instMap, constSubOutputs) val propagated = old match { @@ -290,7 +291,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref) case x => x } - propagated + // We're done when the Expression no longer changes + if (propagated eq old) propagated + else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated) } /** Constant propagate a Module diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 603ddc25..a6df1a3b 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -734,6 +734,24 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { """.stripMargin (parse(exec(input))) should be(parse(check)) } + + // Optimizing this mux gives: z <= pad(UInt<2>(0), 4) + // Thus this checks that we then optimize that pad + "ConstProp" should "optimize nested Expressions" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<4> + | z <= mux(UInt(1), UInt<2>(0), UInt<4>(0)) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<4> + | z <= UInt<4>("h0") + """.stripMargin + (parse(exec(input))) should be(parse(check)) + } } // More sophisticated tests of the full compiler -- cgit v1.2.3 From f961bfca704c9095309e110ff3a546a40b1a2dc5 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Sun, 13 Jan 2019 16:11:22 -0800 Subject: Constant Propagate dshl and dshr with constant amounts Fixes #990 h/t @pentin-as and @abejgonzalez --- .../firrtl/transforms/ConstantPropagation.scala | 17 ++++++ .../firrtlTests/ConstantPropagationTests.scala | 71 ++++++++++++++++++++++ 2 files changed, 88 insertions(+) (limited to 'src') diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index ed4ecd96..16960b34 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -137,6 +137,13 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } } + private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match { + case UIntLiteral(v, IntWidth(w)) => + val shl = DoPrim(Shl, Seq(e.args.head), Seq(v), UnknownType) + pad(PrimOps.set_primop_type(shl), e.tpe) + case _ => e + } + private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head case x => e.args.head match { @@ -148,6 +155,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } } + private def foldDynamicShiftRight(e: DoPrim) = e.args.last match { + case UIntLiteral(v, IntWidth(w)) => + val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType) + pad(PrimOps.set_primop_type(shr), e.tpe) + case _ => e + } + + private def foldComparison(e: DoPrim) = { def foldIfZeroedArg(x: Expression): Expression = { def isUInt(e: Expression): Boolean = e.tpe match { @@ -221,7 +236,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { private def constPropPrim(e: DoPrim): Expression = e.op match { case Shl => foldShiftLeft(e) + case Dshl => foldDynamicShiftLeft(e) case Shr => foldShiftRight(e) + case Dshr => foldDynamicShiftRight(e) case Cat => foldConcat(e) case Add => FoldADD(e) case And => FoldAND(e) diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index a6df1a3b..8a69fcaa 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -1122,6 +1122,77 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { | z <= _T_61""".stripMargin execute(input, check, Seq.empty) } + + behavior of "ConstProp" + + it should "optimize shl of constants" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<7> + | z <= shl(UInt(5), 4) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<7> + | z <= UInt<7>("h50") + """.stripMargin + execute(input, check, Seq.empty) + } + + it should "optimize shr of constants" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<1> + | z <= shr(UInt(5), 2) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<1> + | z <= UInt<1>("h1") + """.stripMargin + execute(input, check, Seq.empty) + } + + // Due to #866, we need dshl optimized away or it'll become a dshlw and error in parsing + // Include cat to verify width is correct + it should "optimize dshl of constant" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<8> + | node n = dshl(UInt<1>(0), UInt<2>(0)) + | z <= cat(UInt<4>("hf"), n) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<8> + | z <= UInt<8>("hf0") + """.stripMargin + execute(input, check, Seq.empty) + } + + // Include cat and constants to verify width is correct + it should "optimize dshr of constant" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<8> + | node n = dshr(UInt<4>(0), UInt<2>(2)) + | z <= cat(UInt<4>("hf"), n) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<8> + | z <= UInt<8>("hf0") + """.stripMargin + execute(input, check, Seq.empty) + } } -- cgit v1.2.3 From 9636b550505e4803c6d7307af7e01d996d0f0ea8 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Sun, 13 Jan 2019 22:03:53 -0800 Subject: Suppress unchecked warning in Constant Propagation --- src/main/scala/firrtl/transforms/ConstantPropagation.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 16960b34..8a273476 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -67,7 +67,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } object FoldADD extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = (c1, c2) match { + def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match { case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) } -- cgit v1.2.3