diff options
| author | Jack Koenig | 2021-03-29 10:21:26 -0700 |
|---|---|---|
| committer | GitHub | 2021-03-29 17:21:26 +0000 |
| commit | a41af6f0a34f9e13866002f19040a40ef55ee9e5 (patch) | |
| tree | 6ef270e7cc4e848940c00cda0b9518d5a7481761 /src/test | |
| parent | abeff01f0714d5474b9d18d78fc13011e5ad6b99 (diff) | |
Fix RemoveAccesses, delete CSESubAccesses (#2157)
CSESubAccesses was intended to be a simple workaround for a quadratic
performance bug in RemoveAccesses but ended up having tricky corner
cases and was hard to get right. The solution to the RemoveAccesses
bug--quadratic expansion of dynamic indexes of vecs of aggreate
type--turned out to be quite simple and makes CSESubAccesses much less
useful and not worth fixing.
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src/test')
5 files changed, 262 insertions, 231 deletions
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 9425a582..78d03e68 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -486,8 +486,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | out <= in0[in1[in2[0]]][in1[in2[1]]] |""".stripMargin val expected = Seq( - "node _in0_in1_in1 = _in0_in1_in1_in2_1", - "out <= _in0_in1_in1" + "out <= _in0_in1_in1_in2_1" ) executeTest(input, expected) diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index ee6077d3..bdc72e7b 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -180,7 +180,6 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { it should "replicate the old order" in { val tm = new TransformManager(Forms.MidForm, Forms.Deduped) val patches = Seq( - Add(2, Seq(Dependency[firrtl.transforms.CSESubAccesses])), Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))), Add(5, Seq(Dependency(firrtl.passes.ResolveKinds))), // Uniquify is now part of [[firrtl.passes.LowerTypes]] diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 0a0df355..061837d7 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -189,14 +189,14 @@ class UnitTests extends FirrtlFlatSpec { //TODO(azidar): I realize this is brittle, but unfortunately there // isn't a better way to test this pass val check = Seq( - """wire _table_1 : { a : UInt<8>}""", - """_table_1.a is invalid""", + """wire _table_1_a : UInt<8>""", + """_table_1_a is invalid""", """when UInt<1>("h1") :""", - """_table_1.a <= table[1].a""", + """_table_1_a <= table[1].a""", """wire _otherTable_table_1_a_a : UInt<8>""", - """when eq(UInt<1>("h0"), _table_1.a) :""", + """when eq(UInt<1>("h0"), _table_1_a) :""", """otherTable[0].a <= _otherTable_table_1_a_a""", - """when eq(UInt<1>("h1"), _table_1.a) :""", + """when eq(UInt<1>("h1"), _table_1_a) :""", """otherTable[1].a <= _otherTable_table_1_a_a""", """_otherTable_table_1_a_a <= UInt<1>("h0")""" ) diff --git a/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala b/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala new file mode 100644 index 00000000..1f1f1968 --- /dev/null +++ b/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests +package passes + +import firrtl._ +import firrtl.testutils._ +import firrtl.stage.TransformManager +import firrtl.options.Dependency +import firrtl.passes._ + +class RemoveAccessesSpec extends FirrtlFlatSpec { + def compile(input: String): String = { + val manager = new TransformManager(Dependency(RemoveAccesses) :: Nil) + val result = manager.execute(CircuitState(parse(input), Nil)) + val checks = List( + CheckHighForm, + CheckTypes, + CheckFlows + ) + for (check <- checks) { check.run(result.circuit) } + result.circuit.serialize + } + def circuit(body: String): String = { + """|circuit Test : + | module Test : + |""".stripMargin + body.stripMargin.split("\n").mkString(" ", "\n ", "\n") + } + + behavior.of("RemoveAccesses") + + it should "handle a simple RHS subaccess" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= in[idx]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |wire _in_idx : UInt<8> + |_in_idx is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx <= in[0] + |when eq(UInt<1>("h1"), idx) : + | _in_idx <= in[1] + |when eq(UInt<2>("h2"), idx) : + | _in_idx <= in[2] + |when eq(UInt<2>("h3"), idx) : + | _in_idx <= in[3] + |out <= _in_idx""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "support complex expressions" in { + val input = circuit( + s"""|input clock : Clock + |input in : UInt<8>[4] + |input idx : UInt<2> + |input sel : UInt<1> + |output out : UInt<8> + |reg r : UInt<2>, clock + |out <= in[mux(sel, r, idx)] + |r <= not(idx)""" + ) + val expected = circuit( + s"""|input clock : Clock + |input in : UInt<8>[4] + |input idx : UInt<2> + |input sel : UInt<1> + |output out : UInt<8> + |reg r : UInt<2>, clock + |wire _in_mux : UInt<8> + |_in_mux is invalid + |when eq(UInt<1>("h0"), mux(sel, r, idx)) : + | _in_mux <= in[0] + |when eq(UInt<1>("h1"), mux(sel, r, idx)) : + | _in_mux <= in[1] + |when eq(UInt<2>("h2"), mux(sel, r, idx)) : + | _in_mux <= in[2] + |when eq(UInt<2>("h3"), mux(sel, r, idx)) : + | _in_mux <= in[3] + |out <= _in_mux + |r <= not(idx)""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "support nested subaccesses" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2>[4] + |input jdx : UInt<2> + |output out : UInt<8> + |out <= in[idx[jdx]]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2>[4] + |input jdx : UInt<2> + |output out : UInt<8> + |wire _idx_jdx : UInt<2> + |_idx_jdx is invalid + |when eq(UInt<1>("h0"), jdx) : + | _idx_jdx <= idx[0] + |when eq(UInt<1>("h1"), jdx) : + | _idx_jdx <= idx[1] + |when eq(UInt<2>("h2"), jdx) : + | _idx_jdx <= idx[2] + |when eq(UInt<2>("h3"), jdx) : + | _idx_jdx <= idx[3] + |wire _in_idx_jdx : UInt<8> + |_in_idx_jdx is invalid + |when eq(UInt<1>("h0"), _idx_jdx) : + | _in_idx_jdx <= in[0] + |when eq(UInt<1>("h1"), _idx_jdx) : + | _in_idx_jdx <= in[1] + |when eq(UInt<2>("h2"), _idx_jdx) : + | _in_idx_jdx <= in[2] + |when eq(UInt<2>("h3"), _idx_jdx) : + | _in_idx_jdx <= in[3] + |out <= _in_idx_jdx""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "avoid name collisions" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= in[idx] + |node _in_idx = not(idx)""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |wire _in_idx_0 : UInt<8> + |_in_idx_0 is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx_0 <= in[0] + |when eq(UInt<1>("h1"), idx) : + | _in_idx_0 <= in[1] + |when eq(UInt<2>("h2"), idx) : + | _in_idx_0 <= in[2] + |when eq(UInt<2>("h3"), idx) : + | _in_idx_0 <= in[3] + |out <= _in_idx_0 + |node _in_idx = not(idx)""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "handle a simple LHS subaccess" in { + val input = circuit( + s"""|input in : UInt<8> + |input idx : UInt<2> + |output out : UInt<8>[4] + |out[idx] <= in""" + ) + val expected = circuit( + s"""|input in : UInt<8> + |input idx : UInt<2> + |output out : UInt<8>[4] + |wire _out_idx : UInt<8> + |when eq(UInt<1>("h0"), idx) : + | out[0] <= _out_idx + |when eq(UInt<1>("h1"), idx) : + | out[1] <= _out_idx + |when eq(UInt<2>("h2"), idx) : + | out[2] <= _out_idx + |when eq(UInt<2>("h3"), idx) : + | out[3] <= _out_idx + |_out_idx <= in""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "linearly expand RHS subaccesses of aggregate-typed vecs" in { + val input = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4] + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> } + |out.foo <= in[idx].foo + |out.bar <= in[idx].bar""" + ) + val expected = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8>}[4] + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8>} + |wire _in_idx_foo : UInt<8> + |_in_idx_foo is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx_foo <= in[0].foo + |when eq(UInt<1>("h1"), idx) : + | _in_idx_foo <= in[1].foo + |when eq(UInt<2>("h2"), idx) : + | _in_idx_foo <= in[2].foo + |when eq(UInt<2>("h3"), idx) : + | _in_idx_foo <= in[3].foo + |out.foo <= _in_idx_foo + |wire _in_idx_bar : UInt<8> + |_in_idx_bar is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx_bar <= in[0].bar + |when eq(UInt<1>("h1"), idx) : + | _in_idx_bar <= in[1].bar + |when eq(UInt<2>("h2"), idx) : + | _in_idx_bar <= in[2].bar + |when eq(UInt<2>("h3"), idx) : + | _in_idx_bar <= in[3].bar + |out.bar <= _in_idx_bar""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "linearly expand LHS subaccesses of aggregate-typed vecs" in { + val input = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> } + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> }[4] + |out[idx].foo <= in.foo + |out[idx].bar <= in.bar""" + ) + val expected = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> } + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> }[4] + |wire _out_idx_foo : UInt<8> + |when eq(UInt<1>("h0"), idx) : + | out[0].foo <= _out_idx_foo + |when eq(UInt<1>("h1"), idx) : + | out[1].foo <= _out_idx_foo + |when eq(UInt<2>("h2"), idx) : + | out[2].foo <= _out_idx_foo + |when eq(UInt<2>("h3"), idx) : + | out[3].foo <= _out_idx_foo + |_out_idx_foo <= in.foo + |wire _out_idx_bar : UInt<8> + |when eq(UInt<1>("h0"), idx) : + | out[0].bar <= _out_idx_bar + |when eq(UInt<1>("h1"), idx) : + | out[1].bar <= _out_idx_bar + |when eq(UInt<2>("h2"), idx) : + | out[2].bar <= _out_idx_bar + |when eq(UInt<2>("h3"), idx) : + | out[3].bar <= _out_idx_bar + |_out_idx_bar <= in.bar""" + ) + compile(input) should be(parse(expected).serialize) + } +} diff --git a/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala b/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala deleted file mode 100644 index f7d67026..00000000 --- a/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala +++ /dev/null @@ -1,223 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package firrtlTests -package transforms - -import firrtl._ -import firrtl.testutils._ -import firrtl.stage.TransformManager -import firrtl.options.Dependency -import firrtl.transforms.CSESubAccesses - -class CSESubAccessesSpec extends FirrtlFlatSpec { - def compile(input: String): String = { - val manager = new TransformManager(Dependency[CSESubAccesses] :: Nil) - val result = manager.execute(CircuitState(parse(input), Nil)) - result.circuit.serialize - } - def circuit(body: String): String = { - """|circuit Test : - | module Test : - |""".stripMargin + body.stripMargin.split("\n").mkString(" ", "\n ", "\n") - } - - behavior.of("CSESubAccesses") - - it should "hoist a single RHS subaccess" in { - val input = circuit( - s"""|input in : UInt<8>[4] - |input idx : UInt<2> - |output out : UInt<8> - |out <= in[idx]""" - ) - val expected = circuit( - s"""|input in : UInt<8>[4] - |input idx : UInt<2> - |output out : UInt<8> - |node _in_idx = in[idx] - |out <= _in_idx""" - ) - compile(input) should be(parse(expected).serialize) - } - - it should "be idempotent" in { - val input = circuit( - s"""|input in : UInt<8>[4] - |input idx : UInt<2> - |output out : UInt<8> - |out <= in[idx]""" - ) - val expected = circuit( - s"""|input in : UInt<8>[4] - |input idx : UInt<2> - |output out : UInt<8> - |node _in_idx = in[idx] - |out <= _in_idx""" - ) - val first = compile(input) - val second = compile(first) - first should be(second) - first should be(parse(expected).serialize) - } - - it should "hoist a redundant RHS subaccess" in { - val input = circuit( - s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4] - |input idx : UInt<2> - |output out : { foo : UInt<8>, bar : UInt<8> } - |out.foo <= in[idx].foo - |out.bar <= in[idx].bar""" - ) - val expected = circuit( - s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4] - |input idx : UInt<2> - |output out : { foo : UInt<8>, bar : UInt<8> } - |node _in_idx = in[idx] - |out.foo <= _in_idx.foo - |out.bar <= _in_idx.bar""" - ) - compile(input) should be(parse(expected).serialize) - } - - it should "correctly place hosited subaccess after last declaration it depends on" in { - val input = circuit( - s"""|input in : UInt<8>[4] - |input idx : UInt<2> - |output out : UInt<8> - |out is invalid - |when UInt(1) : - | node nidx = not(idx) - | out <= in[nidx] - |""" - ) - val expected = circuit( - s"""|input in : UInt<8>[4] - |input idx : UInt<2> - |output out : UInt<8> - |out is invalid - |when UInt(1) : - | node nidx = not(idx) - | node _in_nidx = in[nidx] - | out <= _in_nidx - |""" - ) - compile(input) should be(parse(expected).serialize) - } - - it should "support complex expressions" in { - val input = circuit( - s"""|input clock : Clock - |input in : UInt<8>[4] - |input idx : UInt<2> - |input sel : UInt<1> - |output out : UInt<8> - |reg r : UInt<2>, clock - |out <= in[mux(sel, r, idx)] - |r <= not(idx)""" - ) - val expected = circuit( - s"""|input clock : Clock - |input in : UInt<8>[4] - |input idx : UInt<2> - |input sel : UInt<1> - |output out : UInt<8> - |reg r : UInt<2>, clock - |node _in_mux = in[mux(sel, r, idx)] - |out <= _in_mux - |r <= not(idx)""" - ) - compile(input) should be(parse(expected).serialize) - } - - it should "support nested subaccesses" in { - val input = circuit( - s"""|input in : UInt<8>[4] - |input idx : UInt<2>[4] - |input jdx : UInt<2> - |output out : UInt<8> - |out <= in[idx[jdx]]""" - ) - val expected = circuit( - s"""|input in : UInt<8>[4] - |input idx : UInt<2>[4] - |input jdx : UInt<2> - |output out : UInt<8> - |node _idx_jdx = idx[jdx] - |node _in_idx = in[_idx_jdx] - |out <= _in_idx""" - ) - compile(input) should be(parse(expected).serialize) - } - - it should "avoid name collisions" in { - val input = circuit( - s"""|input in : UInt<8>[4] - |input idx : UInt<2> - |output out : UInt<8> - |out <= in[idx] - |node _in_idx = not(idx)""" - ) - val expected = circuit( - s"""|input in : UInt<8>[4] - |input idx : UInt<2> - |output out : UInt<8> - |node _in_idx_0 = in[idx] - |out <= _in_idx_0 - |node _in_idx = not(idx)""" - ) - compile(input) should be(parse(expected).serialize) - } - - it should "have no effect on LHS SubAccesses" in { - val input = circuit( - s"""|input in : UInt<8> - |input idx : UInt<2> - |output out : UInt<8>[4] - |out[idx] <= in""" - ) - val expected = circuit( - s"""|input in : UInt<8> - |input idx : UInt<2> - |output out : UInt<8>[4] - |out[idx] <= in""" - ) - compile(input) should be(parse(expected).serialize) - } - - it should "ignore flipped LHS SubAccesses" in { - val input = circuit( - s"""|input in : { foo : UInt<8> } - |input idx : UInt<1> - |input out : { flip foo : UInt<8> }[2] - |out[0].foo <= UInt(0) - |out[1].foo <= UInt(0) - |out[idx].foo <= in.foo""" - ) - val expected = circuit( - s"""|input in : { foo : UInt<8> } - |input idx : UInt<1> - |input out : { flip foo : UInt<8> }[2] - |out[0].foo <= UInt(0) - |out[1].foo <= UInt(0) - |out[idx].foo <= in.foo""" - ) - compile(input) should be(parse(expected).serialize) - } - - it should "ignore SubAccesses of bidirectional aggregates" in { - val input = circuit( - s"""|input in : { flip foo : UInt<8>, bar : UInt<8> } - |input idx : UInt<2> - |output out : { flip foo : UInt<8>, bar : UInt<8> }[4] - |out[idx] <= in""" - ) - val expected = circuit( - s"""|input in : { flip foo : UInt<8>, bar : UInt<8> } - |input idx : UInt<2> - |output out : { flip foo : UInt<8>, bar : UInt<8> }[4] - |out[idx] <= in""" - ) - compile(input) should be(parse(expected).serialize) - } - -} |
