aboutsummaryrefslogtreecommitdiff
path: root/src/test
diff options
context:
space:
mode:
authorJack Koenig2021-03-03 17:01:53 -0800
committerGitHub2021-03-04 01:01:53 +0000
commite58ba0c12e5d650983c70a61a45542f0cd43fb88 (patch)
treefc0689df82dd3b9fcadb7ea8d5fc082b35afb20e /src/test
parent5be1abb4c654279762a463a861526ce4e0c48035 (diff)
CSE SubAccesses (#2099)
Fixes n^2 performance problem when dynamically indexing Vecs of aggregate types. Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src/test')
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala3
-rw-r--r--src/test/scala/firrtlTests/LoweringCompilersSpec.scala1
-rw-r--r--src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala187
3 files changed, 190 insertions, 1 deletions
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index 78d03e68..9425a582 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -486,7 +486,8 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec {
| out <= in0[in1[in2[0]]][in1[in2[1]]]
|""".stripMargin
val expected = Seq(
- "out <= _in0_in1_in1_in2_1"
+ "node _in0_in1_in1 = _in0_in1_in1_in2_1",
+ "out <= _in0_in1_in1"
)
executeTest(input, expected)
diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
index bdc72e7b..ee6077d3 100644
--- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
+++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
@@ -180,6 +180,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
it should "replicate the old order" in {
val tm = new TransformManager(Forms.MidForm, Forms.Deduped)
val patches = Seq(
+ Add(2, Seq(Dependency[firrtl.transforms.CSESubAccesses])),
Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))),
Add(5, Seq(Dependency(firrtl.passes.ResolveKinds))),
// Uniquify is now part of [[firrtl.passes.LowerTypes]]
diff --git a/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala b/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala
new file mode 100644
index 00000000..55ce07df
--- /dev/null
+++ b/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala
@@ -0,0 +1,187 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtlTests
+package transforms
+
+import firrtl._
+import firrtl.testutils._
+import firrtl.stage.TransformManager
+import firrtl.options.Dependency
+import firrtl.transforms.CSESubAccesses
+
+class CSESubAccessesSpec extends FirrtlFlatSpec {
+ def compile(input: String): String = {
+ val manager = new TransformManager(Dependency[CSESubAccesses] :: Nil)
+ val result = manager.execute(CircuitState(parse(input), Nil))
+ result.circuit.serialize
+ }
+ def circuit(body: String): String = {
+ """|circuit Test :
+ | module Test :
+ |""".stripMargin + body.stripMargin.split("\n").mkString(" ", "\n ", "\n")
+ }
+
+ behavior.of("CSESubAccesses")
+
+ it should "hoist a single RHS subaccess" in {
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out <= in[idx]"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |node _in_idx = in[idx]
+ |out <= _in_idx"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "be idempotent" in {
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out <= in[idx]"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |node _in_idx = in[idx]
+ |out <= _in_idx"""
+ )
+ val first = compile(input)
+ val second = compile(first)
+ first should be(second)
+ first should be(parse(expected).serialize)
+ }
+
+ it should "hoist a redundant RHS subaccess" in {
+ val input = circuit(
+ s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4]
+ |input idx : UInt<2>
+ |output out : { foo : UInt<8>, bar : UInt<8> }
+ |out.foo <= in[idx].foo
+ |out.bar <= in[idx].bar"""
+ )
+ val expected = circuit(
+ s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4]
+ |input idx : UInt<2>
+ |output out : { foo : UInt<8>, bar : UInt<8> }
+ |node _in_idx = in[idx]
+ |out.foo <= _in_idx.foo
+ |out.bar <= _in_idx.bar"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "correctly place hosited subaccess after last declaration it depends on" in {
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out is invalid
+ |when UInt(1) :
+ | node nidx = not(idx)
+ | out <= in[nidx]
+ |"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out is invalid
+ |when UInt(1) :
+ | node nidx = not(idx)
+ | node _in_nidx = in[nidx]
+ | out <= _in_nidx
+ |"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "support complex expressions" in {
+ val input = circuit(
+ s"""|input clock : Clock
+ |input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |input sel : UInt<1>
+ |output out : UInt<8>
+ |reg r : UInt<2>, clock
+ |out <= in[mux(sel, r, idx)]
+ |r <= not(idx)"""
+ )
+ val expected = circuit(
+ s"""|input clock : Clock
+ |input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |input sel : UInt<1>
+ |output out : UInt<8>
+ |reg r : UInt<2>, clock
+ |node _in_mux = in[mux(sel, r, idx)]
+ |out <= _in_mux
+ |r <= not(idx)"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "support nested subaccesses" in {
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>[4]
+ |input jdx : UInt<2>
+ |output out : UInt<8>
+ |out <= in[idx[jdx]]"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>[4]
+ |input jdx : UInt<2>
+ |output out : UInt<8>
+ |node _idx_jdx = idx[jdx]
+ |node _in_idx = in[_idx_jdx]
+ |out <= _in_idx"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "avoid name collisions" in {
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out <= in[idx]
+ |node _in_idx = not(idx)"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |node _in_idx_0 = in[idx]
+ |out <= _in_idx_0
+ |node _in_idx = not(idx)"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "have no effect on LHS SubAccesses" in {
+ val input = circuit(
+ s"""|input in : UInt<8>
+ |input idx : UInt<2>
+ |output out : UInt<8>[4]
+ |out[idx] <= in"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>
+ |input idx : UInt<2>
+ |output out : UInt<8>[4]
+ |out[idx] <= in"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+}