diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/test/scala/firrtlTests/ConstantPropagationTests.scala | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/test/scala/firrtlTests/ConstantPropagationTests.scala')
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 1263 |
1 files changed, 629 insertions, 634 deletions
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index efe85e48..6ab54159 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -9,24 +9,22 @@ import firrtl.testutils._ import firrtl.annotations.Annotation class ConstantPropagationSpec extends FirrtlFlatSpec { - val transforms: Seq[Transform] = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - new ConstantPropagation) + val transforms: Seq[Transform] = + Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, new ConstantPropagation) protected def exec(input: String, annos: Seq[Annotation] = Nil) = { - transforms.foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit.serialize + transforms + .foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) => + t.runTransform(c) + } + .circuit + .serialize } } class ConstantPropagationMultiModule extends ConstantPropagationSpec { - "ConstProp" should "propagate constant inputs" in { - val input = -"""circuit Top : + "ConstProp" should "propagate constant inputs" in { + val input = + """circuit Top : module Child : input in0 : UInt<1> input in1 : UInt<1> @@ -40,8 +38,8 @@ class ConstantPropagationMultiModule extends ConstantPropagationSpec { c.in1 <= UInt<1>(1) z <= c.out """ - val check = -"""circuit Top : + val check = + """circuit Top : module Child : input in0 : UInt<1> input in1 : UInt<1> @@ -55,12 +53,12 @@ class ConstantPropagationMultiModule extends ConstantPropagationSpec { c.in1 <= UInt<1>(1) z <= c.out """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "propagate constant inputs ONLY if ALL instance inputs get the same value" in { - def circuit(allSame: Boolean) = -s"""circuit Top : + "ConstProp" should "propagate constant inputs ONLY if ALL instance inputs get the same value" in { + def circuit(allSame: Boolean) = + s"""circuit Top : module Bottom : input in : UInt<1> output out : UInt<1> @@ -83,8 +81,8 @@ s"""circuit Top : z <= and(and(b0.out, b1.out), c.out) """ - val resultFromAllSame = -"""circuit Top : + val resultFromAllSame = + """circuit Top : module Bottom : input in : UInt<1> output out : UInt<1> @@ -104,14 +102,14 @@ s"""circuit Top : b1.in <= UInt(1) z <= UInt(1) """ - (parse(exec(circuit(false)))) should be (parse(circuit(false))) - (parse(exec(circuit(true)))) should be (parse(resultFromAllSame)) - } - - // ============================= - "ConstProp" should "do nothing on unrelated modules" in { - val input = -"""circuit foo : + (parse(exec(circuit(false)))) should be(parse(circuit(false))) + (parse(exec(circuit(true)))) should be(parse(resultFromAllSame)) + } + + // ============================= + "ConstProp" should "do nothing on unrelated modules" in { + val input = + """circuit foo : module foo : input dummy : UInt<1> skip @@ -120,14 +118,14 @@ s"""circuit Top : input dummy : UInt<1> skip """ - val check = input - (parse(exec(input))) should be (parse(check)) - } - - // ============================= - "ConstProp" should "propagate module chains not connected to the top" in { - val input = -"""circuit foo : + val check = input + (parse(exec(input))) should be(parse(check)) + } + + // ============================= + "ConstProp" should "propagate module chains not connected to the top" in { + val input = + """circuit foo : module foo : input dummy : UInt<1> skip @@ -151,8 +149,8 @@ s"""circuit Top : output test : UInt<1> test <= UInt<1>(0) """ - val check = -"""circuit foo : + val check = + """circuit foo : module foo : input dummy : UInt<1> skip @@ -176,8 +174,8 @@ s"""circuit Top : output test : UInt<1> test <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } } // Tests the following cases for constant propagation: @@ -188,332 +186,332 @@ s"""circuit Top : // 3) Values are always greater than a number smaller // than their minimum value class ConstantPropagationSingleModule extends ConstantPropagationSpec { - // ============================= - "The rule x >= 0 " should " always be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule x >= 0 " should " always be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= geq(x, UInt(0)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>("h1") """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x < 0 " should " never be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule x < 0 " should " never be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= lt(x, UInt(0)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 0 <= x " should " always be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule 0 <= x " should " always be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= leq(UInt(0),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 0 > x " should " never be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule 0 > x " should " never be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= gt(UInt(0),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 1 < 3 " should " always be true" in { - val input = -"""circuit Top : + // ============================= + "The rule 1 < 3 " should " always be true" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= lt(UInt(0),UInt(3)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x < 8 " should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x < 8 " should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= lt(x,UInt(8)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x <= 7 " should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x <= 7 " should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= leq(x,UInt(7)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 8 > x" should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 8 > x" should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= gt(UInt(8),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 7 >= x" should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 7 >= x" should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= geq(UInt(7),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 10 == 10" should " always be true" in { - val input = -"""circuit Top : + // ============================= + "The rule 10 == 10" should " always be true" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= eq(UInt(10),UInt(10)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x == z " should " not be true even if they have the same number of bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x == z " should " not be true even if they have the same number of bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> input z : UInt<3> output y : UInt<1> y <= eq(x,z) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> input z : UInt<3> output y : UInt<1> y <= eq(x,z) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 10 != 10 " should " always be false" in { - val input = -"""circuit Top : + // ============================= + "The rule 10 != 10 " should " always be false" in { + val input = + """circuit Top : module Top : output y : UInt<1> y <= neq(UInt(10),UInt(10)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : output y : UInt<1> y <= UInt(0) """ - (parse(exec(input))) should be (parse(check)) - } - // ============================= - "The rule 1 >= 3 " should " always be false" in { - val input = -"""circuit Top : + (parse(exec(input))) should be(parse(check)) + } + // ============================= + "The rule 1 >= 3 " should " always be false" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= geq(UInt(1),UInt(3)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x >= 8 " should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x >= 8 " should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= geq(x,UInt(8)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x > 7 " should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x > 7 " should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= gt(x,UInt(7)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 8 <= x" should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 8 <= x" should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= leq(UInt(8),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 7 < x" should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 7 < x" should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= lt(UInt(7),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "work across wires" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "work across wires" in { + val input = + """circuit Top : module Top : input x : UInt<1> output y : UInt<1> @@ -521,8 +519,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { y <= z z <= mux(x, UInt<1>(0), UInt<1>(0)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> output y : UInt<1> @@ -530,13 +528,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { y <= UInt<1>(0) z <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "swap named nodes with temporary nodes that drive them" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "swap named nodes with temporary nodes that drive them" in { + val input = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -545,8 +543,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node n = _T_1 z <= and(n, x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -555,13 +553,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node _T_1 = n z <= and(n, x) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "swap named nodes with temporary wires that drive them" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "swap named nodes with temporary wires that drive them" in { + val input = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -571,8 +569,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n _T_1 <= and(x, y) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -582,13 +580,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n n <= and(x, y) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "swap named nodes with temporary registers that drive them" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "swap named nodes with temporary registers that drive them" in { + val input = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -598,8 +596,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n _T_1 <= x """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -609,13 +607,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n n <= x """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "only swap a given name with one other name" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "only swap a given name with one other name" in { + val input = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -625,8 +623,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node m = _T_1 z <= add(n, m) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -636,12 +634,12 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node m = n z <= add(n, n) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "NOT swap wire names with node names" in { - val input = -"""circuit Top : + "ConstProp" should "NOT swap wire names with node names" in { + val input = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -653,8 +651,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { hit <= _T_2 z <= hit """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -666,12 +664,12 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { hit <= or(x, y) z <= hit """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "propagate constant outputs" in { - val input = -"""circuit Top : + "ConstProp" should "propagate constant outputs" in { + val input = + """circuit Top : module Child : output out : UInt<1> out <= UInt<1>(0) @@ -681,8 +679,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { inst c of Child z <= and(x, c.out) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Child : output out : UInt<1> out <= UInt<1>(0) @@ -692,10 +690,10 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { inst c of Child z <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "propagate constant addition" in { + "ConstProp" should "propagate constant addition" in { val input = """circuit Top : | module Top : @@ -717,7 +715,7 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { (parse(exec(input))) should be(parse(check)) } - "ConstProp" should "propagate addition with zero" in { + "ConstProp" should "propagate addition with zero" in { val input = """circuit Top : | module Top : @@ -779,20 +777,20 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { def castCheck(tpe: String, cast: String): Unit = { val input = - s"""circuit Top : - | module Top : - | input x : $tpe - | output z : $tpe - | z <= $cast(x) + s"""circuit Top : + | module Top : + | input x : $tpe + | output z : $tpe + | z <= $cast(x) """.stripMargin val check = - s"""circuit Top : - | module Top : - | input x : $tpe - | output z : $tpe - | z <= x + s"""circuit Top : + | module Top : + | input x : $tpe + | output z : $tpe + | z <= x """.stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } it should "optimize unnecessary casts" in { castCheck("UInt<4>", "asUInt") @@ -807,218 +805,217 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { def transform = new LowFirrtlOptimization "ConstProp" should "NOT optimize across dontTouch on nodes" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | node z = x - | y <= z""".stripMargin - val check = input + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Top.z"))) } it should "NOT optimize across nodes marked dontTouch by other annotations" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | node z = x - | y <= z""".stripMargin - val check = input - val dontTouchRT = annotations.ModuleTarget("Top", "Top").ref("z") + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin + val check = input + val dontTouchRT = annotations.ModuleTarget("Top", "Top").ref("z") execute(input, check, Seq(AnnotationWithDontTouches(dontTouchRT))) } it should "NOT optimize across dontTouch on registers" in { - val input = - """circuit Top : - | module Top : - | input clk : Clock - | input reset : UInt<1> - | output y : UInt<1> - | reg z : UInt<1>, clk - | y <= z - | z <= mux(reset, UInt<1>("h0"), z)""".stripMargin - val check = input + val input = + """circuit Top : + | module Top : + | input clk : Clock + | input reset : UInt<1> + | output y : UInt<1> + | reg z : UInt<1>, clk + | y <= z + | z <= mux(reset, UInt<1>("h0"), z)""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Top.z"))) } - it should "NOT optimize across dontTouch on wires" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | wire z : UInt<1> - | y <= z - | z <= x""".stripMargin - val check = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | node z = x - | y <= z""".stripMargin + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | wire z : UInt<1> + | y <= z + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin execute(input, check, Seq(dontTouch("Top.z"))) } it should "NOT optimize across dontTouch on output ports" in { val input = """circuit Top : - | module Child : - | output out : UInt<1> - | out <= UInt<1>(0) - | module Top : - | input x : UInt<1> - | output z : UInt<1> - | inst c of Child - | z <= and(x, c.out)""".stripMargin - val check = input + | module Child : + | output out : UInt<1> + | out <= UInt<1>(0) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of Child + | z <= and(x, c.out)""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Child.out"))) } it should "NOT optimize across dontTouch on input ports" in { val input = """circuit Top : - | module Child : - | input in0 : UInt<1> - | input in1 : UInt<1> - | output out : UInt<1> - | out <= and(in0, in1) - | module Top : - | input x : UInt<1> - | output z : UInt<1> - | inst c of Child - | z <= c.out - | c.in0 <= x - | c.in1 <= UInt<1>(1)""".stripMargin - val check = input + | module Child : + | input in0 : UInt<1> + | input in1 : UInt<1> + | output out : UInt<1> + | out <= and(in0, in1) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of Child + | z <= c.out + | c.in0 <= x + | c.in1 <= UInt<1>(1)""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Child.in1"))) } it should "still propagate constants even when there is name swapping" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | input y : UInt<1> - | output z : UInt<1> - | node _T_1 = and(and(x, y), UInt<1>(0)) - | node n = _T_1 - | z <= n""".stripMargin - val check = - """circuit Top : - | module Top : - | input x : UInt<1> - | input y : UInt<1> - | output z : UInt<1> - | z <= UInt<1>(0)""".stripMargin + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | node _T_1 = and(and(x, y), UInt<1>(0)) + | node n = _T_1 + | z <= n""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | z <= UInt<1>(0)""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to wires when propagating" in { - val input = - """circuit Top : - | module Top : - | output z : UInt<16> - | wire w : { a : UInt<8>, b : UInt<8> } - | w.a <= UInt<2>("h3") - | w.b <= UInt<2>("h3") - | z <= cat(w.a, w.b)""".stripMargin - val check = - """circuit Top : - | module Top : - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Top : + | output z : UInt<16> + | wire w : { a : UInt<8>, b : UInt<8> } + | w.a <= UInt<2>("h3") + | w.b <= UInt<2>("h3") + | z <= cat(w.a, w.b)""".stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to registers when propagating" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | reg r : { a : UInt<8>, b : UInt<8> }, clock - | r.a <= UInt<2>("h3") - | r.b <= UInt<2>("h3") - | z <= cat(r.a, r.b)""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | reg r : { a : UInt<8>, b : UInt<8> }, clock + | r.a <= UInt<2>("h3") + | r.b <= UInt<2>("h3") + | z <= cat(r.a, r.b)""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "pad zero when constant propping a register replaced with zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | reg r : UInt<8>, clock - | r <= or(r, UInt(0)) - | node n = UInt("hab") - | z <= cat(n, r)""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | z <= UInt<16>("hab00")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | reg r : UInt<8>, clock + | r <= or(r, UInt(0)) + | node n = UInt("hab") + | z <= cat(n, r)""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | z <= UInt<16>("hab00")""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to outputs when propagating" in { - val input = - """circuit Top : - | module Child : - | output x : UInt<8> - | x <= UInt<2>("h3") - | module Top : - | output z : UInt<16> - | inst c of Child - | z <= cat(UInt<2>("h3"), c.x)""".stripMargin - val check = - """circuit Top : - | module Top : - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Child : + | output x : UInt<8> + | x <= UInt<2>("h3") + | module Top : + | output z : UInt<16> + | inst c of Child + | z <= cat(UInt<2>("h3"), c.x)""".stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to submodule inputs when propagating" in { - val input = - """circuit Top : - | module Child : - | input x : UInt<8> - | output y : UInt<16> - | y <= cat(UInt<2>("h3"), x) - | module Top : - | output z : UInt<16> - | inst c of Child - | c.x <= UInt<2>("h3") - | z <= c.y""".stripMargin - val check = - """circuit Top : - | module Top : - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Child : + | input x : UInt<8> + | output y : UInt<16> + | y <= cat(UInt<2>("h3"), x) + | module Top : + | output z : UInt<16> + | inst c of Child + | c.x <= UInt<2>("h3") + | z <= c.y""".stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "remove pads if the width is <= the width of the argument" in { def input(w: Int) = - s"""circuit Top : - | module Top : - | input x : UInt<8> - | output z : UInt<8> - | z <= pad(x, $w)""".stripMargin + s"""circuit Top : + | module Top : + | input x : UInt<8> + | output z : UInt<8> + | z <= pad(x, $w)""".stripMargin val check = """circuit Top : | module Top : @@ -1029,247 +1026,246 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input(8), check, Seq.empty) } - "Registers with no reset or connections" should "be replaced with constant zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<8> - | reg r : UInt<8>, clock - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<8> - | z <= UInt<8>(0)""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<8> + | reg r : UInt<8>, clock + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<8> + | z <= UInt<8>(0)""".stripMargin execute(input, check, Seq.empty) } "Registers with ONLY constant reset" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | z <= UInt<8>("hb")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | z <= UInt<8>("hb")""".stripMargin execute(input, check, Seq.empty) } "Registers async reset and a constant connection" should "NOT be removed" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : AsyncReset - | input en : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | when en : - | r <= UInt<4>("h0") - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : AsyncReset - | input en : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : - | reset => (reset, UInt<8>("hb")) - | z <= r - | r <= mux(en, UInt<8>("h0"), r)""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : AsyncReset + | input en : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | when en : + | r <= UInt<4>("h0") + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : AsyncReset + | input en : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : + | reset => (reset, UInt<8>("hb")) + | z <= r + | r <= mux(en, UInt<8>("h0"), r)""".stripMargin execute(input, check, Seq.empty) } "Registers with constant reset and connection to the same constant" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cond : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | when cond : - | r <= UInt<4>("hb") - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cond : UInt<1> - | output z : UInt<8> - | z <= UInt<8>("hb")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cond : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | when cond : + | r <= UInt<4>("hb") + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cond : UInt<1> + | output z : UInt<8> + | z <= UInt<8>("hb")""".stripMargin execute(input, check, Seq.empty) } "Const prop of registers" should "do limited speculative expansion of optimized muxes to absorb bigger cones" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input en : UInt<1> - | output out : UInt<1> - | reg r1 : UInt<1>, clock - | reg r2 : UInt<1>, clock - | when en : - | r1 <= UInt<1>(1) - | r2 <= UInt<1>(0) - | when en : - | r2 <= r2 - | out <= xor(r1, r2)""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input en : UInt<1> - | output out : UInt<1> - | out <= UInt<1>("h1")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input en : UInt<1> + | output out : UInt<1> + | reg r1 : UInt<1>, clock + | reg r2 : UInt<1>, clock + | when en : + | r1 <= UInt<1>(1) + | r2 <= UInt<1>(0) + | when en : + | r2 <= r2 + | out <= xor(r1, r2)""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input en : UInt<1> + | output out : UInt<1> + | out <= UInt<1>("h1")""".stripMargin execute(input, check, Seq.empty) } "A register with constant reset and all connection to either itself or the same constant" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cmd : UInt<3> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("h7"))) - | r <= r - | when eq(cmd, UInt<3>("h0")) : - | r <= UInt<3>("h7") - | else : - | when eq(cmd, UInt<3>("h1")) : - | r <= r - | else : - | when eq(cmd, UInt<3>("h2")) : - | r <= UInt<4>("h7") - | else : - | r <= r - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cmd : UInt<3> - | output z : UInt<8> - | z <= UInt<8>("h7")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cmd : UInt<3> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("h7"))) + | r <= r + | when eq(cmd, UInt<3>("h0")) : + | r <= UInt<3>("h7") + | else : + | when eq(cmd, UInt<3>("h1")) : + | r <= r + | else : + | when eq(cmd, UInt<3>("h2")) : + | r <= UInt<4>("h7") + | else : + | r <= r + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cmd : UInt<3> + | output z : UInt<8> + | z <= UInt<8>("h7")""".stripMargin execute(input, check, Seq.empty) } "Registers with ONLY constant connection" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : SInt<8> - | reg r : SInt<8>, clock - | r <= SInt<4>(-5) - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : SInt<8> - | z <= SInt<8>(-5)""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : SInt<8> + | reg r : SInt<8>, clock + | r <= SInt<4>(-5) + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : SInt<8> + | z <= SInt<8>(-5)""".stripMargin execute(input, check, Seq.empty) } "Registers with identical constant reset and connection" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | r <= UInt<4>("hb") - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | z <= UInt<8>("hb")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | r <= UInt<4>("hb") + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | z <= UInt<8>("hb")""".stripMargin execute(input, check, Seq.empty) } "Connections to a node reference" should "be replaced with the rhs of that node" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<8> - | input b : UInt<8> - | input c : UInt<1> - | output z : UInt<8> - | node x = mux(c, a, b) - | z <= x""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<8> - | input b : UInt<8> - | input c : UInt<1> - | output z : UInt<8> - | z <= mux(c, a, b)""".stripMargin + val input = + """circuit Top : + | module Top : + | input a : UInt<8> + | input b : UInt<8> + | input c : UInt<1> + | output z : UInt<8> + | node x = mux(c, a, b) + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<8> + | input b : UInt<8> + | input c : UInt<1> + | output z : UInt<8> + | z <= mux(c, a, b)""".stripMargin execute(input, check, Seq.empty) } "Registers connected only to themselves" should "be replaced with zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | reg ra : UInt<8>, clock - | ra <= ra - | a <= ra - |""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | a <= UInt<8>(0) - |""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | reg ra : UInt<8>, clock + | ra <= ra + | a <= ra + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | a <= UInt<8>(0) + |""".stripMargin execute(input, check, Seq.empty) } "Registers connected only to themselves from constant propagation" should "be replaced with zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | reg ra : UInt<8>, clock - | ra <= or(ra, UInt(0)) - | a <= ra - |""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | a <= UInt<8>(0) - |""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | reg ra : UInt<8>, clock + | ra <= or(ra, UInt(0)) + | a <= ra + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | a <= UInt<8>(0) + |""".stripMargin execute(input, check, Seq.empty) } @@ -1290,7 +1286,7 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input, check, Seq.empty) } - behavior of "ConstProp" + behavior.of("ConstProp") it should "optimize shl of constants" in { val input = @@ -1381,30 +1377,30 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { it should "optimize some binary operations when arguments match" in { // Signedness matters - matchingArgs("sub", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) - matchingArgs("sub", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """ ) - matchingArgs("div", "UInt<8>", "UInt<8>", """ UInt<8>("h1") """ ) - matchingArgs("div", "SInt<8>", "SInt<8>", """ SInt<8>("h1") """ ) - matchingArgs("rem", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) - matchingArgs("rem", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """ ) - matchingArgs("and", "UInt<8>", "UInt<8>", """ i """ ) - matchingArgs("and", "SInt<8>", "UInt<8>", """ asUInt(i) """ ) + matchingArgs("sub", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """) + matchingArgs("sub", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """) + matchingArgs("div", "UInt<8>", "UInt<8>", """ UInt<8>("h1") """) + matchingArgs("div", "SInt<8>", "SInt<8>", """ SInt<8>("h1") """) + matchingArgs("rem", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """) + matchingArgs("rem", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """) + matchingArgs("and", "UInt<8>", "UInt<8>", """ i """) + matchingArgs("and", "SInt<8>", "UInt<8>", """ asUInt(i) """) // Signedness doesn't matter - matchingArgs("or", "UInt<8>", "UInt<8>", """ i """ ) - matchingArgs("or", "SInt<8>", "UInt<8>", """ asUInt(i) """ ) - matchingArgs("xor", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) - matchingArgs("xor", "SInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) + matchingArgs("or", "UInt<8>", "UInt<8>", """ i """) + matchingArgs("or", "SInt<8>", "UInt<8>", """ asUInt(i) """) + matchingArgs("xor", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """) + matchingArgs("xor", "SInt<8>", "UInt<8>", """ UInt<8>("h0") """) // Always true - matchingArgs("eq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) - matchingArgs("leq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) - matchingArgs("geq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) + matchingArgs("eq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """) + matchingArgs("leq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """) + matchingArgs("geq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """) // Never true - matchingArgs("neq", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) - matchingArgs("lt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) - matchingArgs("gt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) + matchingArgs("neq", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """) + matchingArgs("lt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """) + matchingArgs("gt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """) } - behavior of "Reduction operators" + behavior.of("Reduction operators") it should "optimize andr of a literal" in { val input = @@ -1534,7 +1530,6 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { } - class ConstantPropagationEquivalenceSpec extends FirrtlFlatSpec { private val srcDir = "/constant_propagation_tests" private val transforms = Seq(new ConstantPropagation) @@ -1642,15 +1637,15 @@ class ConstantPropagationEquivalenceSpec extends FirrtlFlatSpec { firrtlEquivalenceTest(input, transforms) } - "addition of negative literals" should "be propagated" in { - val input = - s"""circuit AddTester : - | module AddTester : - | output ref : SInt<2> - | ref <= add(SInt<1>("h-1"), SInt<1>("h-1")) - |""".stripMargin - firrtlEquivalenceTest(input, transforms) - } + "addition of negative literals" should "be propagated" in { + val input = + s"""circuit AddTester : + | module AddTester : + | output ref : SInt<2> + | ref <= add(SInt<1>("h-1"), SInt<1>("h-1")) + |""".stripMargin + firrtlEquivalenceTest(input, transforms) + } "propagation of signed expressions" should "have the correct signs" in { val input = |
