diff options
| author | Jack Koenig | 2021-03-29 10:21:26 -0700 |
|---|---|---|
| committer | GitHub | 2021-03-29 17:21:26 +0000 |
| commit | a41af6f0a34f9e13866002f19040a40ef55ee9e5 (patch) | |
| tree | 6ef270e7cc4e848940c00cda0b9518d5a7481761 /src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala | |
| parent | abeff01f0714d5474b9d18d78fc13011e5ad6b99 (diff) | |
Fix RemoveAccesses, delete CSESubAccesses (#2157)
CSESubAccesses was intended to be a simple workaround for a quadratic
performance bug in RemoveAccesses but ended up having tricky corner
cases and was hard to get right. The solution to the RemoveAccesses
bug--quadratic expansion of dynamic indexes of vecs of aggreate
type--turned out to be quite simple and makes CSESubAccesses much less
useful and not worth fixing.
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala')
| -rw-r--r-- | src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala | 256 |
1 files changed, 256 insertions, 0 deletions
diff --git a/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala b/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala new file mode 100644 index 00000000..1f1f1968 --- /dev/null +++ b/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests +package passes + +import firrtl._ +import firrtl.testutils._ +import firrtl.stage.TransformManager +import firrtl.options.Dependency +import firrtl.passes._ + +class RemoveAccessesSpec extends FirrtlFlatSpec { + def compile(input: String): String = { + val manager = new TransformManager(Dependency(RemoveAccesses) :: Nil) + val result = manager.execute(CircuitState(parse(input), Nil)) + val checks = List( + CheckHighForm, + CheckTypes, + CheckFlows + ) + for (check <- checks) { check.run(result.circuit) } + result.circuit.serialize + } + def circuit(body: String): String = { + """|circuit Test : + | module Test : + |""".stripMargin + body.stripMargin.split("\n").mkString(" ", "\n ", "\n") + } + + behavior.of("RemoveAccesses") + + it should "handle a simple RHS subaccess" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= in[idx]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |wire _in_idx : UInt<8> + |_in_idx is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx <= in[0] + |when eq(UInt<1>("h1"), idx) : + | _in_idx <= in[1] + |when eq(UInt<2>("h2"), idx) : + | _in_idx <= in[2] + |when eq(UInt<2>("h3"), idx) : + | _in_idx <= in[3] + |out <= _in_idx""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "support complex expressions" in { + val input = circuit( + s"""|input clock : Clock + |input in : UInt<8>[4] + |input idx : UInt<2> + |input sel : UInt<1> + |output out : UInt<8> + |reg r : UInt<2>, clock + |out <= in[mux(sel, r, idx)] + |r <= not(idx)""" + ) + val expected = circuit( + s"""|input clock : Clock + |input in : UInt<8>[4] + |input idx : UInt<2> + |input sel : UInt<1> + |output out : UInt<8> + |reg r : UInt<2>, clock + |wire _in_mux : UInt<8> + |_in_mux is invalid + |when eq(UInt<1>("h0"), mux(sel, r, idx)) : + | _in_mux <= in[0] + |when eq(UInt<1>("h1"), mux(sel, r, idx)) : + | _in_mux <= in[1] + |when eq(UInt<2>("h2"), mux(sel, r, idx)) : + | _in_mux <= in[2] + |when eq(UInt<2>("h3"), mux(sel, r, idx)) : + | _in_mux <= in[3] + |out <= _in_mux + |r <= not(idx)""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "support nested subaccesses" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2>[4] + |input jdx : UInt<2> + |output out : UInt<8> + |out <= in[idx[jdx]]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2>[4] + |input jdx : UInt<2> + |output out : UInt<8> + |wire _idx_jdx : UInt<2> + |_idx_jdx is invalid + |when eq(UInt<1>("h0"), jdx) : + | _idx_jdx <= idx[0] + |when eq(UInt<1>("h1"), jdx) : + | _idx_jdx <= idx[1] + |when eq(UInt<2>("h2"), jdx) : + | _idx_jdx <= idx[2] + |when eq(UInt<2>("h3"), jdx) : + | _idx_jdx <= idx[3] + |wire _in_idx_jdx : UInt<8> + |_in_idx_jdx is invalid + |when eq(UInt<1>("h0"), _idx_jdx) : + | _in_idx_jdx <= in[0] + |when eq(UInt<1>("h1"), _idx_jdx) : + | _in_idx_jdx <= in[1] + |when eq(UInt<2>("h2"), _idx_jdx) : + | _in_idx_jdx <= in[2] + |when eq(UInt<2>("h3"), _idx_jdx) : + | _in_idx_jdx <= in[3] + |out <= _in_idx_jdx""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "avoid name collisions" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= in[idx] + |node _in_idx = not(idx)""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |wire _in_idx_0 : UInt<8> + |_in_idx_0 is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx_0 <= in[0] + |when eq(UInt<1>("h1"), idx) : + | _in_idx_0 <= in[1] + |when eq(UInt<2>("h2"), idx) : + | _in_idx_0 <= in[2] + |when eq(UInt<2>("h3"), idx) : + | _in_idx_0 <= in[3] + |out <= _in_idx_0 + |node _in_idx = not(idx)""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "handle a simple LHS subaccess" in { + val input = circuit( + s"""|input in : UInt<8> + |input idx : UInt<2> + |output out : UInt<8>[4] + |out[idx] <= in""" + ) + val expected = circuit( + s"""|input in : UInt<8> + |input idx : UInt<2> + |output out : UInt<8>[4] + |wire _out_idx : UInt<8> + |when eq(UInt<1>("h0"), idx) : + | out[0] <= _out_idx + |when eq(UInt<1>("h1"), idx) : + | out[1] <= _out_idx + |when eq(UInt<2>("h2"), idx) : + | out[2] <= _out_idx + |when eq(UInt<2>("h3"), idx) : + | out[3] <= _out_idx + |_out_idx <= in""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "linearly expand RHS subaccesses of aggregate-typed vecs" in { + val input = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4] + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> } + |out.foo <= in[idx].foo + |out.bar <= in[idx].bar""" + ) + val expected = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8>}[4] + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8>} + |wire _in_idx_foo : UInt<8> + |_in_idx_foo is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx_foo <= in[0].foo + |when eq(UInt<1>("h1"), idx) : + | _in_idx_foo <= in[1].foo + |when eq(UInt<2>("h2"), idx) : + | _in_idx_foo <= in[2].foo + |when eq(UInt<2>("h3"), idx) : + | _in_idx_foo <= in[3].foo + |out.foo <= _in_idx_foo + |wire _in_idx_bar : UInt<8> + |_in_idx_bar is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx_bar <= in[0].bar + |when eq(UInt<1>("h1"), idx) : + | _in_idx_bar <= in[1].bar + |when eq(UInt<2>("h2"), idx) : + | _in_idx_bar <= in[2].bar + |when eq(UInt<2>("h3"), idx) : + | _in_idx_bar <= in[3].bar + |out.bar <= _in_idx_bar""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "linearly expand LHS subaccesses of aggregate-typed vecs" in { + val input = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> } + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> }[4] + |out[idx].foo <= in.foo + |out[idx].bar <= in.bar""" + ) + val expected = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> } + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> }[4] + |wire _out_idx_foo : UInt<8> + |when eq(UInt<1>("h0"), idx) : + | out[0].foo <= _out_idx_foo + |when eq(UInt<1>("h1"), idx) : + | out[1].foo <= _out_idx_foo + |when eq(UInt<2>("h2"), idx) : + | out[2].foo <= _out_idx_foo + |when eq(UInt<2>("h3"), idx) : + | out[3].foo <= _out_idx_foo + |_out_idx_foo <= in.foo + |wire _out_idx_bar : UInt<8> + |when eq(UInt<1>("h0"), idx) : + | out[0].bar <= _out_idx_bar + |when eq(UInt<1>("h1"), idx) : + | out[1].bar <= _out_idx_bar + |when eq(UInt<2>("h2"), idx) : + | out[2].bar <= _out_idx_bar + |when eq(UInt<2>("h3"), idx) : + | out[3].bar <= _out_idx_bar + |_out_idx_bar <= in.bar""" + ) + compile(input) should be(parse(expected).serialize) + } +} |
