diff options
| author | Jack Koenig | 2021-03-03 17:01:53 -0800 |
|---|---|---|
| committer | GitHub | 2021-03-04 01:01:53 +0000 |
| commit | e58ba0c12e5d650983c70a61a45542f0cd43fb88 (patch) | |
| tree | fc0689df82dd3b9fcadb7ea8d5fc082b35afb20e /src/test | |
| parent | 5be1abb4c654279762a463a861526ce4e0c48035 (diff) | |
CSE SubAccesses (#2099)
Fixes n^2 performance problem when dynamically indexing Vecs of
aggregate types.
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src/test')
3 files changed, 190 insertions, 1 deletions
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 78d03e68..9425a582 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -486,7 +486,8 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | out <= in0[in1[in2[0]]][in1[in2[1]]] |""".stripMargin val expected = Seq( - "out <= _in0_in1_in1_in2_1" + "node _in0_in1_in1 = _in0_in1_in1_in2_1", + "out <= _in0_in1_in1" ) executeTest(input, expected) diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index bdc72e7b..ee6077d3 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -180,6 +180,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { it should "replicate the old order" in { val tm = new TransformManager(Forms.MidForm, Forms.Deduped) val patches = Seq( + Add(2, Seq(Dependency[firrtl.transforms.CSESubAccesses])), Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))), Add(5, Seq(Dependency(firrtl.passes.ResolveKinds))), // Uniquify is now part of [[firrtl.passes.LowerTypes]] diff --git a/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala b/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala new file mode 100644 index 00000000..55ce07df --- /dev/null +++ b/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests +package transforms + +import firrtl._ +import firrtl.testutils._ +import firrtl.stage.TransformManager +import firrtl.options.Dependency +import firrtl.transforms.CSESubAccesses + +class CSESubAccessesSpec extends FirrtlFlatSpec { + def compile(input: String): String = { + val manager = new TransformManager(Dependency[CSESubAccesses] :: Nil) + val result = manager.execute(CircuitState(parse(input), Nil)) + result.circuit.serialize + } + def circuit(body: String): String = { + """|circuit Test : + | module Test : + |""".stripMargin + body.stripMargin.split("\n").mkString(" ", "\n ", "\n") + } + + behavior.of("CSESubAccesses") + + it should "hoist a single RHS subaccess" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= in[idx]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |node _in_idx = in[idx] + |out <= _in_idx""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "be idempotent" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= in[idx]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |node _in_idx = in[idx] + |out <= _in_idx""" + ) + val first = compile(input) + val second = compile(first) + first should be(second) + first should be(parse(expected).serialize) + } + + it should "hoist a redundant RHS subaccess" in { + val input = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4] + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> } + |out.foo <= in[idx].foo + |out.bar <= in[idx].bar""" + ) + val expected = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4] + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> } + |node _in_idx = in[idx] + |out.foo <= _in_idx.foo + |out.bar <= _in_idx.bar""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "correctly place hosited subaccess after last declaration it depends on" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out is invalid + |when UInt(1) : + | node nidx = not(idx) + | out <= in[nidx] + |""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out is invalid + |when UInt(1) : + | node nidx = not(idx) + | node _in_nidx = in[nidx] + | out <= _in_nidx + |""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "support complex expressions" in { + val input = circuit( + s"""|input clock : Clock + |input in : UInt<8>[4] + |input idx : UInt<2> + |input sel : UInt<1> + |output out : UInt<8> + |reg r : UInt<2>, clock + |out <= in[mux(sel, r, idx)] + |r <= not(idx)""" + ) + val expected = circuit( + s"""|input clock : Clock + |input in : UInt<8>[4] + |input idx : UInt<2> + |input sel : UInt<1> + |output out : UInt<8> + |reg r : UInt<2>, clock + |node _in_mux = in[mux(sel, r, idx)] + |out <= _in_mux + |r <= not(idx)""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "support nested subaccesses" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2>[4] + |input jdx : UInt<2> + |output out : UInt<8> + |out <= in[idx[jdx]]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2>[4] + |input jdx : UInt<2> + |output out : UInt<8> + |node _idx_jdx = idx[jdx] + |node _in_idx = in[_idx_jdx] + |out <= _in_idx""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "avoid name collisions" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= in[idx] + |node _in_idx = not(idx)""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |node _in_idx_0 = in[idx] + |out <= _in_idx_0 + |node _in_idx = not(idx)""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "have no effect on LHS SubAccesses" in { + val input = circuit( + s"""|input in : UInt<8> + |input idx : UInt<2> + |output out : UInt<8>[4] + |out[idx] <= in""" + ) + val expected = circuit( + s"""|input in : UInt<8> + |input idx : UInt<2> + |output out : UInt<8>[4] + |out[idx] <= in""" + ) + compile(input) should be(parse(expected).serialize) + } + +} |
