aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala63
-rw-r--r--src/main/scala/firrtl/stage/Forms.scala1
-rw-r--r--src/main/scala/firrtl/transforms/CSESubAccesses.scala168
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala3
-rw-r--r--src/test/scala/firrtlTests/LoweringCompilersSpec.scala1
-rw-r--r--src/test/scala/firrtlTests/UnitTests.scala10
-rw-r--r--src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala256
-rw-r--r--src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala223
8 files changed, 297 insertions, 428 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index 7449db51..073bf49d 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -9,7 +9,6 @@ import firrtl.Mappers._
import firrtl.Utils._
import firrtl.WrappedExpression._
import firrtl.options.Dependency
-import firrtl.transforms.CSESubAccesses
import scala.collection.mutable
@@ -22,8 +21,7 @@ object RemoveAccesses extends Pass {
Dependency(PullMuxes),
Dependency(ZeroLengthVecs),
Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency[CSESubAccesses]
+ Dependency(ExpandConnects)
) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform): Boolean = a match {
@@ -122,26 +120,26 @@ object RemoveAccesses extends Pass {
/** Replaces a subaccess in a given source expression
*/
val stmts = mutable.ArrayBuffer[Statement]()
- def removeSource(e: Expression): Expression = e match {
- case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(e) =>
- val rs = getLocations(e)
- rs.find(x => x.guard != one) match {
- case None => throwInternalError(s"removeSource: shouldn't be here - $e")
- case Some(_) =>
- val (wire, temp) = create_temp(e)
- val temps = create_exps(temp)
- def getTemp(i: Int) = temps(i % temps.size)
- stmts += wire
- rs.zipWithIndex.foreach {
- case (x, i) if i < temps.size =>
- stmts += IsInvalid(get_info(s), getTemp(i))
- stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
- case (x, i) =>
- stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
- }
- temp
- }
- case _ => e
+ // Only called on RefLikes that definitely have a SubAccess
+ // Must accept Expression because that's the output type of fixIndices
+ def removeSource(e: Expression): Expression = {
+ val rs = getLocations(e)
+ rs.find(x => x.guard != one) match {
+ case None => throwInternalError(s"removeSource: shouldn't be here - $e")
+ case Some(_) =>
+ val (wire, temp) = create_temp(e)
+ val temps = create_exps(temp)
+ def getTemp(i: Int) = temps(i % temps.size)
+ stmts += wire
+ rs.zipWithIndex.foreach {
+ case (x, i) if i < temps.size =>
+ stmts += IsInvalid(get_info(s), getTemp(i))
+ stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
+ case (x, i) =>
+ stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
+ }
+ temp
+ }
}
/** Replaces a subaccess in a given sink expression
@@ -162,14 +160,23 @@ object RemoveAccesses extends Pass {
case _ => loc
}
+ /** Recurse until find SubAccess and call fixSource on its index
+ * @note this only accepts [[RefLikeExpression]]s but we can't enforce it because map
+ * requires Expression => Expression
+ */
+ def fixIndices(e: Expression): Expression = e match {
+ case e: SubAccess => e.copy(index = fixSource(e.index))
+ case other => other.map(fixIndices)
+ }
+
/** Recursively walks a source expression and fixes all subaccesses
- * If we see a sub-access, replace it.
- * Otherwise, map to children.
+ *
+ * If we see a RefLikeExpression that contains a SubAccess, we recursively remove
+ * subaccesses from the indices of any SubAccesses, then process modified RefLikeExpression
*/
def fixSource(e: Expression): Expression = e match {
- case w: WSubAccess => removeSource(WSubAccess(w.expr, fixSource(w.index), w.tpe, w.flow))
- //case w: WSubIndex => removeSource(w)
- //case w: WSubField => removeSource(w)
+ case ref: RefLikeExpression =>
+ if (hasAccess(ref)) removeSource(fixIndices(ref)) else ref
case x => x.map(fixSource)
}
diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala
index ba27b552..ab082151 100644
--- a/src/main/scala/firrtl/stage/Forms.scala
+++ b/src/main/scala/firrtl/stage/Forms.scala
@@ -57,7 +57,6 @@ object Forms {
val MidForm: Seq[TransformDependency] = HighForm ++
Seq(
Dependency(passes.PullMuxes),
- Dependency[firrtl.transforms.CSESubAccesses],
Dependency(passes.ReplaceAccesses),
Dependency(passes.ExpandConnects),
Dependency(passes.RemoveAccesses),
diff --git a/src/main/scala/firrtl/transforms/CSESubAccesses.scala b/src/main/scala/firrtl/transforms/CSESubAccesses.scala
deleted file mode 100644
index 6ca27a83..00000000
--- a/src/main/scala/firrtl/transforms/CSESubAccesses.scala
+++ /dev/null
@@ -1,168 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-
-package firrtl
-package transforms
-
-import firrtl.ir._
-import firrtl.traversals.Foreachers._
-import firrtl.Mappers._
-import firrtl.PrimOps._
-import firrtl.WrappedExpression._
-import firrtl.options.Dependency
-import firrtl.passes._
-import firrtl.Utils.{distinctBy, flow, getAllRefs, get_info, niceName}
-
-import scala.collection.mutable
-
-object CSESubAccesses {
-
- // Get all SubAccesses used on the right-hand side along with the info from the outer Statement
- private def collectRValueSubAccesses(mod: Module): Seq[(SubAccess, Info)] = {
- val acc = new mutable.ListBuffer[(SubAccess, Info)]
- def onExpr(outer: Statement)(expr: Expression): Unit = {
- // Need postorder because we want to visit inner SubAccesses first
- // Stop recursing on any non-Source because flips can make the SubAccess a Source despite the
- // overall Expression being a Sink
- if (flow(expr) == SourceFlow) expr.foreach(onExpr(outer))
- expr match {
- case e: SubAccess if flow(e) == SourceFlow => acc += e -> get_info(outer)
- case _ => // Do nothing
- }
- }
- def onStmt(stmt: Statement): Unit = {
- stmt.foreach(onStmt)
- stmt match {
- // Don't record SubAccesses that are already assigned to a Node, but *do* record any nested
- // inside of the SubAccess. This makes the transform idempotent and avoids unnecessary work.
- case DefNode(_, _, acc: SubAccess) => acc.foreach(onExpr(stmt))
- case other => other.foreach(onExpr(stmt))
- }
- }
- onStmt(mod.body)
- distinctBy(acc.toList)(_._1)
- }
-
- // Replaces all right-hand side SubAccesses with References
- private def replaceOnSourceExpr(replace: SubAccess => Reference)(expr: Expression): Expression = expr match {
- // Stop is we ever see a non-SourceFlow
- case e if flow(e) != SourceFlow => e
- // Don't traverse children of SubAccess, just replace it
- // Nested SubAccesses are handled during creation of the nodes that the references refer to
- case acc: SubAccess if flow(acc) == SourceFlow => replace(acc)
- case other => other.map(replaceOnSourceExpr(replace))
- }
-
- private def hoistSubAccesses(
- hoist: String => List[DefNode],
- replace: SubAccess => Reference
- )(stmt: Statement
- ): Statement = {
- val onExpr = replaceOnSourceExpr(replace) _
- def onStmt(s: Statement): Statement = s.map(onExpr).map(onStmt) match {
- case decl: IsDeclaration =>
- val nodes = hoist(decl.name)
- if (nodes.isEmpty) decl else Block(decl :: nodes)
- case other => other
- }
- onStmt(stmt)
- }
-
- // Given some nodes, determine after which String declaration each node should be inserted
- // This function is *mutable*, it keeps track of which declarations each node is sensitive to and
- // returns nodes in groups once the last declaration they depend on is seen
- private def getSensitivityLookup(nodes: Iterable[DefNode]): String => List[DefNode] = {
- case class ReferenceCount(var n: Int, node: DefNode)
- // Gather names of declarations each node depends on
- val nodeDeps = nodes.map(node => getAllRefs(node.value).view.map(_.name).toSet -> node)
- // Map from declaration names to the indices of nodeDeps that depend on it
- val lookup = new mutable.HashMap[String, mutable.ArrayBuffer[Int]]
- for (((decls, _), idx) <- nodeDeps.zipWithIndex) {
- for (d <- decls) {
- val indices = lookup.getOrElseUpdate(d, new mutable.ArrayBuffer[Int])
- indices += idx
- }
- }
- // Now we can just associate each List of nodes with how many declarations they need to see
- // We use an Array because we're mutating anyway and might as well be quick about it
- val nodeLists: Array[ReferenceCount] =
- nodeDeps.view.map { case (deps, node) => ReferenceCount(deps.size, node) }.toArray
-
- // Must be a def because it's recursive
- def func(decl: String): List[DefNode] = {
- if (lookup.contains(decl)) {
- val indices = lookup(decl)
- val result = new mutable.ListBuffer[DefNode]
- lookup -= decl
- for (i <- indices) {
- val refCount = nodeLists(i)
- refCount.n -= 1
- assert(refCount.n >= 0, "Internal Error!")
- if (refCount.n == 0) result += refCount.node
- }
- // DefNodes can depend on each other, recurse
- result.toList.flatMap { node => node :: func(node.name) }
- } else {
- Nil
- }
- }
- func _
- }
-
- /** Performs [[CSESubAccesses]] on a single [[ir.Module Module]] */
- def onMod(mod: Module): Module = {
- // ***** Pre-Analyze (do we even need to do anything) *****
- val accesses = collectRValueSubAccesses(mod)
- if (accesses.isEmpty) mod
- else {
- // ***** Analyze *****
- val namespace = Namespace(mod)
- val replace = new mutable.HashMap[SubAccess, Reference]
- val nodes = new mutable.ArrayBuffer[DefNode]
- for ((acc, info) <- accesses) {
- val name = namespace.newName(niceName(acc))
- // SubAccesses can be nested, so replace any nested ones with prior references
- // This is why post-order traversal in collectRValueSubAccesses is important
- val accx = acc.map(replaceOnSourceExpr(replace))
- val node = DefNode(info, name, accx)
- val ref = Reference(node)
- // Record in replace
- replace(acc) = ref
- // Record node
- nodes += node
- }
- val hoist = getSensitivityLookup(nodes)
-
- // ***** Transform *****
- val portStmts = mod.ports.flatMap(x => hoist(x.name))
- val bodyx = hoistSubAccesses(hoist, replace)(mod.body)
- mod.copy(body = if (portStmts.isEmpty) bodyx else Block(Block(portStmts), bodyx))
- }
- }
-}
-
-/** Performs Common Subexpression Elimination (CSE) on right-hand side [[ir.SubAccess SubAccess]]es
- *
- * This avoids quadratic node creation behavior in [[passes.RemoveAccesses RemoveAccesses]]. For
- * simplicity of implementation, all SubAccesses on the right-hand side are also split into
- * individual nodes.
- */
-class CSESubAccesses extends Transform with DependencyAPIMigration {
-
- override def prerequisites = Dependency(ResolveFlows) :: Dependency(CheckHighForm) :: Nil
-
- // Faster to run after these
- override def optionalPrerequisites = Dependency(ReplaceAccesses) :: Dependency[DedupModules] :: Nil
-
- // Running before ExpandConnects is an optimization
- override def optionalPrerequisiteOf = Dependency(ExpandConnects) :: Nil
-
- override def invalidates(a: Transform) = false
-
- def execute(state: CircuitState): CircuitState = {
- val modulesx = state.circuit.modules.map {
- case ext: ExtModule => ext
- case mod: Module => CSESubAccesses.onMod(mod)
- }
- state.copy(circuit = state.circuit.copy(modules = modulesx))
- }
-}
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index 9425a582..78d03e68 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -486,8 +486,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec {
| out <= in0[in1[in2[0]]][in1[in2[1]]]
|""".stripMargin
val expected = Seq(
- "node _in0_in1_in1 = _in0_in1_in1_in2_1",
- "out <= _in0_in1_in1"
+ "out <= _in0_in1_in1_in2_1"
)
executeTest(input, expected)
diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
index ee6077d3..bdc72e7b 100644
--- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
+++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
@@ -180,7 +180,6 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
it should "replicate the old order" in {
val tm = new TransformManager(Forms.MidForm, Forms.Deduped)
val patches = Seq(
- Add(2, Seq(Dependency[firrtl.transforms.CSESubAccesses])),
Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))),
Add(5, Seq(Dependency(firrtl.passes.ResolveKinds))),
// Uniquify is now part of [[firrtl.passes.LowerTypes]]
diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala
index 0a0df355..061837d7 100644
--- a/src/test/scala/firrtlTests/UnitTests.scala
+++ b/src/test/scala/firrtlTests/UnitTests.scala
@@ -189,14 +189,14 @@ class UnitTests extends FirrtlFlatSpec {
//TODO(azidar): I realize this is brittle, but unfortunately there
// isn't a better way to test this pass
val check = Seq(
- """wire _table_1 : { a : UInt<8>}""",
- """_table_1.a is invalid""",
+ """wire _table_1_a : UInt<8>""",
+ """_table_1_a is invalid""",
"""when UInt<1>("h1") :""",
- """_table_1.a <= table[1].a""",
+ """_table_1_a <= table[1].a""",
"""wire _otherTable_table_1_a_a : UInt<8>""",
- """when eq(UInt<1>("h0"), _table_1.a) :""",
+ """when eq(UInt<1>("h0"), _table_1_a) :""",
"""otherTable[0].a <= _otherTable_table_1_a_a""",
- """when eq(UInt<1>("h1"), _table_1.a) :""",
+ """when eq(UInt<1>("h1"), _table_1_a) :""",
"""otherTable[1].a <= _otherTable_table_1_a_a""",
"""_otherTable_table_1_a_a <= UInt<1>("h0")"""
)
diff --git a/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala b/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala
new file mode 100644
index 00000000..1f1f1968
--- /dev/null
+++ b/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala
@@ -0,0 +1,256 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtlTests
+package passes
+
+import firrtl._
+import firrtl.testutils._
+import firrtl.stage.TransformManager
+import firrtl.options.Dependency
+import firrtl.passes._
+
+class RemoveAccessesSpec extends FirrtlFlatSpec {
+ def compile(input: String): String = {
+ val manager = new TransformManager(Dependency(RemoveAccesses) :: Nil)
+ val result = manager.execute(CircuitState(parse(input), Nil))
+ val checks = List(
+ CheckHighForm,
+ CheckTypes,
+ CheckFlows
+ )
+ for (check <- checks) { check.run(result.circuit) }
+ result.circuit.serialize
+ }
+ def circuit(body: String): String = {
+ """|circuit Test :
+ | module Test :
+ |""".stripMargin + body.stripMargin.split("\n").mkString(" ", "\n ", "\n")
+ }
+
+ behavior.of("RemoveAccesses")
+
+ it should "handle a simple RHS subaccess" in {
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out <= in[idx]"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |wire _in_idx : UInt<8>
+ |_in_idx is invalid
+ |when eq(UInt<1>("h0"), idx) :
+ | _in_idx <= in[0]
+ |when eq(UInt<1>("h1"), idx) :
+ | _in_idx <= in[1]
+ |when eq(UInt<2>("h2"), idx) :
+ | _in_idx <= in[2]
+ |when eq(UInt<2>("h3"), idx) :
+ | _in_idx <= in[3]
+ |out <= _in_idx"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "support complex expressions" in {
+ val input = circuit(
+ s"""|input clock : Clock
+ |input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |input sel : UInt<1>
+ |output out : UInt<8>
+ |reg r : UInt<2>, clock
+ |out <= in[mux(sel, r, idx)]
+ |r <= not(idx)"""
+ )
+ val expected = circuit(
+ s"""|input clock : Clock
+ |input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |input sel : UInt<1>
+ |output out : UInt<8>
+ |reg r : UInt<2>, clock
+ |wire _in_mux : UInt<8>
+ |_in_mux is invalid
+ |when eq(UInt<1>("h0"), mux(sel, r, idx)) :
+ | _in_mux <= in[0]
+ |when eq(UInt<1>("h1"), mux(sel, r, idx)) :
+ | _in_mux <= in[1]
+ |when eq(UInt<2>("h2"), mux(sel, r, idx)) :
+ | _in_mux <= in[2]
+ |when eq(UInt<2>("h3"), mux(sel, r, idx)) :
+ | _in_mux <= in[3]
+ |out <= _in_mux
+ |r <= not(idx)"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "support nested subaccesses" in {
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>[4]
+ |input jdx : UInt<2>
+ |output out : UInt<8>
+ |out <= in[idx[jdx]]"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>[4]
+ |input jdx : UInt<2>
+ |output out : UInt<8>
+ |wire _idx_jdx : UInt<2>
+ |_idx_jdx is invalid
+ |when eq(UInt<1>("h0"), jdx) :
+ | _idx_jdx <= idx[0]
+ |when eq(UInt<1>("h1"), jdx) :
+ | _idx_jdx <= idx[1]
+ |when eq(UInt<2>("h2"), jdx) :
+ | _idx_jdx <= idx[2]
+ |when eq(UInt<2>("h3"), jdx) :
+ | _idx_jdx <= idx[3]
+ |wire _in_idx_jdx : UInt<8>
+ |_in_idx_jdx is invalid
+ |when eq(UInt<1>("h0"), _idx_jdx) :
+ | _in_idx_jdx <= in[0]
+ |when eq(UInt<1>("h1"), _idx_jdx) :
+ | _in_idx_jdx <= in[1]
+ |when eq(UInt<2>("h2"), _idx_jdx) :
+ | _in_idx_jdx <= in[2]
+ |when eq(UInt<2>("h3"), _idx_jdx) :
+ | _in_idx_jdx <= in[3]
+ |out <= _in_idx_jdx"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "avoid name collisions" in {
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out <= in[idx]
+ |node _in_idx = not(idx)"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |wire _in_idx_0 : UInt<8>
+ |_in_idx_0 is invalid
+ |when eq(UInt<1>("h0"), idx) :
+ | _in_idx_0 <= in[0]
+ |when eq(UInt<1>("h1"), idx) :
+ | _in_idx_0 <= in[1]
+ |when eq(UInt<2>("h2"), idx) :
+ | _in_idx_0 <= in[2]
+ |when eq(UInt<2>("h3"), idx) :
+ | _in_idx_0 <= in[3]
+ |out <= _in_idx_0
+ |node _in_idx = not(idx)"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "handle a simple LHS subaccess" in {
+ val input = circuit(
+ s"""|input in : UInt<8>
+ |input idx : UInt<2>
+ |output out : UInt<8>[4]
+ |out[idx] <= in"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>
+ |input idx : UInt<2>
+ |output out : UInt<8>[4]
+ |wire _out_idx : UInt<8>
+ |when eq(UInt<1>("h0"), idx) :
+ | out[0] <= _out_idx
+ |when eq(UInt<1>("h1"), idx) :
+ | out[1] <= _out_idx
+ |when eq(UInt<2>("h2"), idx) :
+ | out[2] <= _out_idx
+ |when eq(UInt<2>("h3"), idx) :
+ | out[3] <= _out_idx
+ |_out_idx <= in"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "linearly expand RHS subaccesses of aggregate-typed vecs" in {
+ val input = circuit(
+ s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4]
+ |input idx : UInt<2>
+ |output out : { foo : UInt<8>, bar : UInt<8> }
+ |out.foo <= in[idx].foo
+ |out.bar <= in[idx].bar"""
+ )
+ val expected = circuit(
+ s"""|input in : { foo : UInt<8>, bar : UInt<8>}[4]
+ |input idx : UInt<2>
+ |output out : { foo : UInt<8>, bar : UInt<8>}
+ |wire _in_idx_foo : UInt<8>
+ |_in_idx_foo is invalid
+ |when eq(UInt<1>("h0"), idx) :
+ | _in_idx_foo <= in[0].foo
+ |when eq(UInt<1>("h1"), idx) :
+ | _in_idx_foo <= in[1].foo
+ |when eq(UInt<2>("h2"), idx) :
+ | _in_idx_foo <= in[2].foo
+ |when eq(UInt<2>("h3"), idx) :
+ | _in_idx_foo <= in[3].foo
+ |out.foo <= _in_idx_foo
+ |wire _in_idx_bar : UInt<8>
+ |_in_idx_bar is invalid
+ |when eq(UInt<1>("h0"), idx) :
+ | _in_idx_bar <= in[0].bar
+ |when eq(UInt<1>("h1"), idx) :
+ | _in_idx_bar <= in[1].bar
+ |when eq(UInt<2>("h2"), idx) :
+ | _in_idx_bar <= in[2].bar
+ |when eq(UInt<2>("h3"), idx) :
+ | _in_idx_bar <= in[3].bar
+ |out.bar <= _in_idx_bar"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "linearly expand LHS subaccesses of aggregate-typed vecs" in {
+ val input = circuit(
+ s"""|input in : { foo : UInt<8>, bar : UInt<8> }
+ |input idx : UInt<2>
+ |output out : { foo : UInt<8>, bar : UInt<8> }[4]
+ |out[idx].foo <= in.foo
+ |out[idx].bar <= in.bar"""
+ )
+ val expected = circuit(
+ s"""|input in : { foo : UInt<8>, bar : UInt<8> }
+ |input idx : UInt<2>
+ |output out : { foo : UInt<8>, bar : UInt<8> }[4]
+ |wire _out_idx_foo : UInt<8>
+ |when eq(UInt<1>("h0"), idx) :
+ | out[0].foo <= _out_idx_foo
+ |when eq(UInt<1>("h1"), idx) :
+ | out[1].foo <= _out_idx_foo
+ |when eq(UInt<2>("h2"), idx) :
+ | out[2].foo <= _out_idx_foo
+ |when eq(UInt<2>("h3"), idx) :
+ | out[3].foo <= _out_idx_foo
+ |_out_idx_foo <= in.foo
+ |wire _out_idx_bar : UInt<8>
+ |when eq(UInt<1>("h0"), idx) :
+ | out[0].bar <= _out_idx_bar
+ |when eq(UInt<1>("h1"), idx) :
+ | out[1].bar <= _out_idx_bar
+ |when eq(UInt<2>("h2"), idx) :
+ | out[2].bar <= _out_idx_bar
+ |when eq(UInt<2>("h3"), idx) :
+ | out[3].bar <= _out_idx_bar
+ |_out_idx_bar <= in.bar"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+}
diff --git a/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala b/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala
deleted file mode 100644
index f7d67026..00000000
--- a/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala
+++ /dev/null
@@ -1,223 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-
-package firrtlTests
-package transforms
-
-import firrtl._
-import firrtl.testutils._
-import firrtl.stage.TransformManager
-import firrtl.options.Dependency
-import firrtl.transforms.CSESubAccesses
-
-class CSESubAccessesSpec extends FirrtlFlatSpec {
- def compile(input: String): String = {
- val manager = new TransformManager(Dependency[CSESubAccesses] :: Nil)
- val result = manager.execute(CircuitState(parse(input), Nil))
- result.circuit.serialize
- }
- def circuit(body: String): String = {
- """|circuit Test :
- | module Test :
- |""".stripMargin + body.stripMargin.split("\n").mkString(" ", "\n ", "\n")
- }
-
- behavior.of("CSESubAccesses")
-
- it should "hoist a single RHS subaccess" in {
- val input = circuit(
- s"""|input in : UInt<8>[4]
- |input idx : UInt<2>
- |output out : UInt<8>
- |out <= in[idx]"""
- )
- val expected = circuit(
- s"""|input in : UInt<8>[4]
- |input idx : UInt<2>
- |output out : UInt<8>
- |node _in_idx = in[idx]
- |out <= _in_idx"""
- )
- compile(input) should be(parse(expected).serialize)
- }
-
- it should "be idempotent" in {
- val input = circuit(
- s"""|input in : UInt<8>[4]
- |input idx : UInt<2>
- |output out : UInt<8>
- |out <= in[idx]"""
- )
- val expected = circuit(
- s"""|input in : UInt<8>[4]
- |input idx : UInt<2>
- |output out : UInt<8>
- |node _in_idx = in[idx]
- |out <= _in_idx"""
- )
- val first = compile(input)
- val second = compile(first)
- first should be(second)
- first should be(parse(expected).serialize)
- }
-
- it should "hoist a redundant RHS subaccess" in {
- val input = circuit(
- s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4]
- |input idx : UInt<2>
- |output out : { foo : UInt<8>, bar : UInt<8> }
- |out.foo <= in[idx].foo
- |out.bar <= in[idx].bar"""
- )
- val expected = circuit(
- s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4]
- |input idx : UInt<2>
- |output out : { foo : UInt<8>, bar : UInt<8> }
- |node _in_idx = in[idx]
- |out.foo <= _in_idx.foo
- |out.bar <= _in_idx.bar"""
- )
- compile(input) should be(parse(expected).serialize)
- }
-
- it should "correctly place hosited subaccess after last declaration it depends on" in {
- val input = circuit(
- s"""|input in : UInt<8>[4]
- |input idx : UInt<2>
- |output out : UInt<8>
- |out is invalid
- |when UInt(1) :
- | node nidx = not(idx)
- | out <= in[nidx]
- |"""
- )
- val expected = circuit(
- s"""|input in : UInt<8>[4]
- |input idx : UInt<2>
- |output out : UInt<8>
- |out is invalid
- |when UInt(1) :
- | node nidx = not(idx)
- | node _in_nidx = in[nidx]
- | out <= _in_nidx
- |"""
- )
- compile(input) should be(parse(expected).serialize)
- }
-
- it should "support complex expressions" in {
- val input = circuit(
- s"""|input clock : Clock
- |input in : UInt<8>[4]
- |input idx : UInt<2>
- |input sel : UInt<1>
- |output out : UInt<8>
- |reg r : UInt<2>, clock
- |out <= in[mux(sel, r, idx)]
- |r <= not(idx)"""
- )
- val expected = circuit(
- s"""|input clock : Clock
- |input in : UInt<8>[4]
- |input idx : UInt<2>
- |input sel : UInt<1>
- |output out : UInt<8>
- |reg r : UInt<2>, clock
- |node _in_mux = in[mux(sel, r, idx)]
- |out <= _in_mux
- |r <= not(idx)"""
- )
- compile(input) should be(parse(expected).serialize)
- }
-
- it should "support nested subaccesses" in {
- val input = circuit(
- s"""|input in : UInt<8>[4]
- |input idx : UInt<2>[4]
- |input jdx : UInt<2>
- |output out : UInt<8>
- |out <= in[idx[jdx]]"""
- )
- val expected = circuit(
- s"""|input in : UInt<8>[4]
- |input idx : UInt<2>[4]
- |input jdx : UInt<2>
- |output out : UInt<8>
- |node _idx_jdx = idx[jdx]
- |node _in_idx = in[_idx_jdx]
- |out <= _in_idx"""
- )
- compile(input) should be(parse(expected).serialize)
- }
-
- it should "avoid name collisions" in {
- val input = circuit(
- s"""|input in : UInt<8>[4]
- |input idx : UInt<2>
- |output out : UInt<8>
- |out <= in[idx]
- |node _in_idx = not(idx)"""
- )
- val expected = circuit(
- s"""|input in : UInt<8>[4]
- |input idx : UInt<2>
- |output out : UInt<8>
- |node _in_idx_0 = in[idx]
- |out <= _in_idx_0
- |node _in_idx = not(idx)"""
- )
- compile(input) should be(parse(expected).serialize)
- }
-
- it should "have no effect on LHS SubAccesses" in {
- val input = circuit(
- s"""|input in : UInt<8>
- |input idx : UInt<2>
- |output out : UInt<8>[4]
- |out[idx] <= in"""
- )
- val expected = circuit(
- s"""|input in : UInt<8>
- |input idx : UInt<2>
- |output out : UInt<8>[4]
- |out[idx] <= in"""
- )
- compile(input) should be(parse(expected).serialize)
- }
-
- it should "ignore flipped LHS SubAccesses" in {
- val input = circuit(
- s"""|input in : { foo : UInt<8> }
- |input idx : UInt<1>
- |input out : { flip foo : UInt<8> }[2]
- |out[0].foo <= UInt(0)
- |out[1].foo <= UInt(0)
- |out[idx].foo <= in.foo"""
- )
- val expected = circuit(
- s"""|input in : { foo : UInt<8> }
- |input idx : UInt<1>
- |input out : { flip foo : UInt<8> }[2]
- |out[0].foo <= UInt(0)
- |out[1].foo <= UInt(0)
- |out[idx].foo <= in.foo"""
- )
- compile(input) should be(parse(expected).serialize)
- }
-
- it should "ignore SubAccesses of bidirectional aggregates" in {
- val input = circuit(
- s"""|input in : { flip foo : UInt<8>, bar : UInt<8> }
- |input idx : UInt<2>
- |output out : { flip foo : UInt<8>, bar : UInt<8> }[4]
- |out[idx] <= in"""
- )
- val expected = circuit(
- s"""|input in : { flip foo : UInt<8>, bar : UInt<8> }
- |input idx : UInt<2>
- |output out : { flip foo : UInt<8>, bar : UInt<8> }[4]
- |out[idx] <= in"""
- )
- compile(input) should be(parse(expected).serialize)
- }
-
-}