aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author Jack Koenig 2021-03-03 17:01:53 -0800
committer GitHub 2021-03-04 01:01:53 +0000
commit e58ba0c12e5d650983c70a61a45542f0cd43fb88 (patch)
tree fc0689df82dd3b9fcadb7ea8d5fc082b35afb20e /src
parent 5be1abb4c654279762a463a861526ce4e0c48035 (diff)
CSE SubAccesses (#2099)
Fixes n^2 performance problem when dynamically indexing Vecs of aggregate types. Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src')
-rw-r--r-- src/main/scala/firrtl/Utils.scala 32
-rw-r--r-- src/main/scala/firrtl/passes/RemoveAccesses.scala 4
-rw-r--r-- src/main/scala/firrtl/stage/Forms.scala 1
-rw-r--r-- src/main/scala/firrtl/transforms/CSESubAccesses.scala 164
-rw-r--r-- src/test/scala/firrtlTests/LowerTypesSpec.scala 3
-rw-r--r-- src/test/scala/firrtlTests/LoweringCompilersSpec.scala 1
-rw-r--r-- src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala 187
7 files changed, 390 insertions, 2 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index d187ea5f..72884e25 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -5,6 +5,7 @@ package firrtl
import firrtl.ir._
import firrtl.PrimOps._
import firrtl.Mappers._
+import firrtl.traversals.Foreachers._
import firrtl.WrappedExpression._
import scala.collection.mutable
@@ -210,6 +211,24 @@ object Utils extends LazyLogging {
case _ => false
}
+ /** Selects all the elements of this list ignoring the duplicates as determined by == after
+ * applying the transforming function f
+ *
+ * The relative order of the input is preserved and, for each distinct key, the first element
+ * that produced it is the one that is kept.
+ *
+ * @tparam A element type of the list
+ * @tparam B type of the key used to decide whether two elements are duplicates
+ * @param xs the list to de-duplicate
+ * @param f function computing the key of each element
+ * @return a new list containing the first element seen for every distinct key
+ * @note In Scala Standard Library starting in 2.13
+ */
+ def distinctBy[A, B](xs: List[A])(f: A => B): List[A] = {
+ val buf = new mutable.ListBuffer[A] // elements to keep, in input order
+ val seen = new mutable.HashSet[B] // keys that have already been emitted
+ for (x <- xs) {
+ val y = f(x)
+ if (!seen(y)) { // first time this key shows up: keep the element
+ buf += x
+ seen += y
+ }
+ }
+ buf.toList
+ }
+
/** Provide a nice name to create a temporary * */
def niceName(e: Expression): String = niceName(1)(e)
def niceName(depth: Int)(e: Expression): String = {
@@ -649,6 +668,19 @@ object Utils extends LazyLogging {
case _ => NoInfo
}
+ /** Finds all root References in a nested Expression
+ *
+ * The expression tree is walked recursively; every [[Reference]] that is reached is appended to
+ * the result and is not descended into. Duplicates are not removed, so a Reference that is
+ * reached several times is returned several times.
+ *
+ * @param expr the expression to search
+ * @return the References found, in the order the traversal reached them
+ */
+ def getAllRefs(expr: Expression): Seq[Reference] = {
+ val refs = mutable.ListBuffer.empty[Reference] // accumulator filled by rec
+ def rec(e: Expression): Unit = {
+ e match {
+ case ref: Reference => refs += ref // record it; recursion stops here
+ case other => other.foreach(rec) // not a Reference: visit its sub-expressions
+ }
+ }
+ rec(expr)
+ refs.toList
+ }
+
/** Splits an Expression into root Ref and tail
*
* @example
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index 90437e56..7449db51 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -9,6 +9,7 @@ import firrtl.Mappers._
import firrtl.Utils._
import firrtl.WrappedExpression._
import firrtl.options.Dependency
+import firrtl.transforms.CSESubAccesses
import scala.collection.mutable
@@ -21,7 +22,8 @@ object RemoveAccesses extends Pass {
Dependency(PullMuxes),
Dependency(ZeroLengthVecs),
Dependency(ReplaceAccesses),
- Dependency(ExpandConnects)
+ Dependency(ExpandConnects),
+ Dependency[CSESubAccesses]
) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform): Boolean = a match {
diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala
index ab082151..ba27b552 100644
--- a/src/main/scala/firrtl/stage/Forms.scala
+++ b/src/main/scala/firrtl/stage/Forms.scala
@@ -57,6 +57,7 @@ object Forms {
val MidForm: Seq[TransformDependency] = HighForm ++
Seq(
Dependency(passes.PullMuxes),
+ Dependency[firrtl.transforms.CSESubAccesses],
Dependency(passes.ReplaceAccesses),
Dependency(passes.ExpandConnects),
Dependency(passes.RemoveAccesses),
diff --git a/src/main/scala/firrtl/transforms/CSESubAccesses.scala b/src/main/scala/firrtl/transforms/CSESubAccesses.scala
new file mode 100644
index 00000000..6ed3a5b5
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/CSESubAccesses.scala
@@ -0,0 +1,164 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.traversals.Foreachers._
+import firrtl.Mappers._
+import firrtl.PrimOps._
+import firrtl.WrappedExpression._
+import firrtl.options.Dependency
+import firrtl.passes._
+import firrtl.Utils.{distinctBy, flow, getAllRefs, get_info, niceName}
+
+import scala.collection.mutable
+
+object CSESubAccesses { // helpers implementing the per-module rewrite; entry point is onMod
+
+ // Get all distinct SubAccesses used on the right-hand side (SourceFlow) along with the info from the outer Statement; inner SubAccesses come before the ones containing them
+ private def collectRValueSubAccesses(mod: Module): Seq[(SubAccess, Info)] = {
+ val acc = new mutable.ListBuffer[(SubAccess, Info)]
+ def onExpr(outer: Statement)(expr: Expression): Unit = {
+ // Need postorder because we want to visit inner SubAccesses first
+ expr.foreach(onExpr(outer))
+ expr match {
+ case e: SubAccess if flow(e) == SourceFlow => acc += e -> get_info(outer) // pair the access with the Info of the enclosing statement
+ case _ => // Do nothing
+ }
+ }
+ def onStmt(stmt: Statement): Unit = {
+ stmt.foreach(onStmt)
+ stmt match {
+ // Don't record SubAccesses that are already assigned to a Node, but *do* record any nested
+ // inside of the SubAccess. This makes the transform idempotent and avoids unnecessary work.
+ case DefNode(_, _, acc: SubAccess) => acc.foreach(onExpr(stmt)) // NOTE: this `acc` shadows the ListBuffer above; onExpr still appends to the buffer
+ case other => other.foreach(onExpr(stmt))
+ }
+ }
+ onStmt(mod.body)
+ distinctBy(acc.toList)(_._1) // keep only the first (SubAccess, Info) pair recorded for each distinct SubAccess
+ }
+
+ // Replaces all right-hand side (SourceFlow) SubAccesses in expr with the Reference returned by `replace`
+ private def replaceOnSourceExpr(replace: SubAccess => Reference)(expr: Expression): Expression = expr match {
+ // Don't traverse children of SubAccess, just replace it
+ // Nested SubAccesses are handled during creation of the nodes that the references refer to
+ case acc: SubAccess if flow(acc) == SourceFlow => replace(acc)
+ case other => other.map(replaceOnSourceExpr(replace))
+ }
+
+ private def hoistSubAccesses( // rewrites stmt: source SubAccesses become References and each declaration is followed by the nodes that `hoist` releases for it
+ hoist: String => List[DefNode],
+ replace: SubAccess => Reference
+ )(stmt: Statement
+ ): Statement = {
+ val onExpr = replaceOnSourceExpr(replace) _
+ def onStmt(s: Statement): Statement = s.map(onExpr).map(onStmt) match { // rewrite this statement's expressions, then recurse into its child statements
+ case decl: IsDeclaration =>
+ val nodes = hoist(decl.name) // nodes whose last outstanding dependency is this declaration
+ if (nodes.isEmpty) decl else Block(decl :: nodes) // place the released nodes directly after the declaration
+ case other => other
+ }
+ onStmt(stmt)
+ }
+
+ // Given some nodes, determine after which String declaration each node should be inserted
+ // The returned function is *stateful*: it keeps track of which declarations each node is still waiting on and
+ // returns nodes in groups once the last declaration they depend on is seen
+ private def getSensitivityLookup(nodes: Iterable[DefNode]): String => List[DefNode] = {
+ case class ReferenceCount(var n: Int, node: DefNode) // n = number of declarations this node is still waiting for
+ // Gather names of declarations each node depends on
+ val nodeDeps = nodes.map(node => getAllRefs(node.value).view.map(_.name).toSet -> node)
+ // Map from declaration names to the indices of nodeDeps that depend on it
+ val lookup = new mutable.HashMap[String, mutable.ArrayBuffer[Int]]
+ for (((decls, _), idx) <- nodeDeps.zipWithIndex) {
+ for (d <- decls) {
+ val indices = lookup.getOrElseUpdate(d, new mutable.ArrayBuffer[Int])
+ indices += idx
+ }
+ }
+ // Now we can just associate each List of nodes with how many declarations they need to see
+ // We use an Array because we're mutating anyway and might as well be quick about it
+ val nodeLists: Array[ReferenceCount] =
+ nodeDeps.view.map { case (deps, node) => ReferenceCount(deps.size, node) }.toArray
+
+ // Must be a def because it's recursive
+ def func(decl: String): List[DefNode] = {
+ if (lookup.contains(decl)) {
+ val indices = lookup(decl)
+ val result = new mutable.ListBuffer[DefNode]
+ lookup -= decl // a declaration name is only ever processed once; later calls with it return Nil
+ for (i <- indices) {
+ val refCount = nodeLists(i)
+ refCount.n -= 1
+ assert(refCount.n >= 0, "Internal Error!")
+ if (refCount.n == 0) result += refCount.node // every dependency has now been seen: release the node
+ }
+ // DefNodes can depend on each other, recurse
+ result.toList.flatMap { node => node :: func(node.name) }
+ } else {
+ Nil
+ }
+ }
+ func _
+ }
+
+ /** Performs [[CSESubAccesses]] on a single [[ir.Module Module]]
+ *
+ * Every distinct source-flow SubAccess found in the body is bound to a freshly named node, the
+ * uses in the body are replaced by references to those nodes, and each node is inserted after
+ * the last declaration (port or statement) that it refers to.
+ *
+ * @param mod the module to transform
+ * @return `mod` itself when no SubAccess was collected, otherwise a copy with a rewritten body
+ */
+ def onMod(mod: Module): Module = {
+ // ***** Pre-Analyze (do we even need to do anything) *****
+ val accesses = collectRValueSubAccesses(mod)
+ if (accesses.isEmpty) mod // nothing to hoist: return the module unchanged
+ else {
+ // ***** Analyze *****
+ val namespace = Namespace(mod)
+ val replace = new mutable.HashMap[SubAccess, Reference]
+ val nodes = new mutable.ArrayBuffer[DefNode]
+ for ((acc, info) <- accesses) {
+ val name = namespace.newName(niceName(acc))
+ // SubAccesses can be nested, so replace any nested ones with prior references
+ // This is why post-order traversal in collectRValueSubAccesses is important
+ val accx = acc.map(replaceOnSourceExpr(replace))
+ val node = DefNode(info, name, accx)
+ val ref = Reference(node)
+ // Record in replace
+ replace(acc) = ref
+ // Record node
+ nodes += node
+ }
+ val hoist = getSensitivityLookup(nodes)
+
+ // ***** Transform *****
+ val portStmts = mod.ports.flatMap(x => hoist(x.name)) // nodes released by the ports alone; they are emitted ahead of the body
+ val bodyx = hoistSubAccesses(hoist, replace)(mod.body)
+ mod.copy(body = if (portStmts.isEmpty) bodyx else Block(Block(portStmts), bodyx))
+ }
+ }
+}
+
+/** Performs Common Subexpression Elimination (CSE) on right-hand side [[ir.SubAccess SubAccess]]es
+ *
+ * This avoids quadratic node creation behavior in [[passes.RemoveAccesses RemoveAccesses]]. For
+ * simplicity of implementation, all SubAccesses on the right-hand side are also split into
+ * individual nodes.
+ *
+ * The per-module work is delegated to `CSESubAccesses.onMod` in the companion object; external
+ * modules are left as they are.
+ */
+class CSESubAccesses extends Transform with DependencyAPIMigration {
+
+ override def prerequisites = Dependency(ResolveFlows) :: Dependency(CheckHighForm) :: Nil // the rewrite inspects expression flows, so ResolveFlows is listed as a prerequisite
+
+ // Faster to run after these
+ override def optionalPrerequisites = Dependency(ReplaceAccesses) :: Dependency[DedupModules] :: Nil
+
+ // Running before ExpandConnects is an optimization
+ override def optionalPrerequisiteOf = Dependency(ExpandConnects) :: Nil
+
+ override def invalidates(a: Transform) = false // reports that no other transform is invalidated by this one
+
+ def execute(state: CircuitState): CircuitState = { // applies CSESubAccesses.onMod to every Module of the circuit
+ val modulesx = state.circuit.modules.map {
+ case ext: ExtModule => ext // external modules are passed through untouched
+ case mod: Module => CSESubAccesses.onMod(mod)
+ }
+ state.copy(circuit = state.circuit.copy(modules = modulesx))
+ }
+}
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index 78d03e68..9425a582 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -486,7 +486,8 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec {
| out <= in0[in1[in2[0]]][in1[in2[1]]]
|""".stripMargin
val expected = Seq(
- "out <= _in0_in1_in1_in2_1"
+ "node _in0_in1_in1 = _in0_in1_in1_in2_1",
+ "out <= _in0_in1_in1"
)
executeTest(input, expected)
diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
index bdc72e7b..ee6077d3 100644
--- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
+++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
@@ -180,6 +180,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
it should "replicate the old order" in {
val tm = new TransformManager(Forms.MidForm, Forms.Deduped)
val patches = Seq(
+ Add(2, Seq(Dependency[firrtl.transforms.CSESubAccesses])),
Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))),
Add(5, Seq(Dependency(firrtl.passes.ResolveKinds))),
// Uniquify is now part of [[firrtl.passes.LowerTypes]]
diff --git a/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala b/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala
new file mode 100644
index 00000000..55ce07df
--- /dev/null
+++ b/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala
@@ -0,0 +1,187 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtlTests
+package transforms
+
+import firrtl._
+import firrtl.testutils._
+import firrtl.stage.TransformManager
+import firrtl.options.Dependency
+import firrtl.transforms.CSESubAccesses
+
+class CSESubAccessesSpec extends FirrtlFlatSpec { // checks the CSESubAccesses transform by comparing serialized circuits
+ def compile(input: String): String = { // parses input, executes a TransformManager targeting CSESubAccesses, and returns the serialized result
+ val manager = new TransformManager(Dependency[CSESubAccesses] :: Nil)
+ val result = manager.execute(CircuitState(parse(input), Nil))
+ result.circuit.serialize
+ }
+ def circuit(body: String): String = { // prepends a `circuit Test` / `module Test` header to a margin-delimited module body
+ """|circuit Test :
+ | module Test :
+ |""".stripMargin + body.stripMargin.split("\n").mkString(" ", "\n ", "\n")
+ }
+
+ behavior.of("CSESubAccesses")
+
+ it should "hoist a single RHS subaccess" in { // in[idx] on the right-hand side becomes node _in_idx
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out <= in[idx]"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |node _in_idx = in[idx]
+ |out <= _in_idx"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "be idempotent" in { // compiling the already-compiled output must not change it
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out <= in[idx]"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |node _in_idx = in[idx]
+ |out <= _in_idx"""
+ )
+ val first = compile(input)
+ val second = compile(first) // second pass over the output of the first pass
+ first should be(second)
+ first should be(parse(expected).serialize)
+ }
+
+ it should "hoist a redundant RHS subaccess" in { // two uses of in[idx] share a single node
+ val input = circuit(
+ s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4]
+ |input idx : UInt<2>
+ |output out : { foo : UInt<8>, bar : UInt<8> }
+ |out.foo <= in[idx].foo
+ |out.bar <= in[idx].bar"""
+ )
+ val expected = circuit(
+ s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4]
+ |input idx : UInt<2>
+ |output out : { foo : UInt<8>, bar : UInt<8> }
+ |node _in_idx = in[idx]
+ |out.foo <= _in_idx.foo
+ |out.bar <= _in_idx.bar"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "correctly place hosited subaccess after last declaration it depends on" in { // NOTE(review): "hosited" in the test name is a typo for "hoisted"; the node must land after `nidx` inside the when
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out is invalid
+ |when UInt(1) :
+ | node nidx = not(idx)
+ | out <= in[nidx]
+ |"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out is invalid
+ |when UInt(1) :
+ | node nidx = not(idx)
+ | node _in_nidx = in[nidx]
+ | out <= _in_nidx
+ |"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "support complex expressions" in { // the index is a mux that reads a register; the node is placed after the register declaration
+ val input = circuit(
+ s"""|input clock : Clock
+ |input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |input sel : UInt<1>
+ |output out : UInt<8>
+ |reg r : UInt<2>, clock
+ |out <= in[mux(sel, r, idx)]
+ |r <= not(idx)"""
+ )
+ val expected = circuit(
+ s"""|input clock : Clock
+ |input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |input sel : UInt<1>
+ |output out : UInt<8>
+ |reg r : UInt<2>, clock
+ |node _in_mux = in[mux(sel, r, idx)]
+ |out <= _in_mux
+ |r <= not(idx)"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "support nested subaccesses" in { // the inner access idx[jdx] gets its own node, which the outer node then uses
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>[4]
+ |input jdx : UInt<2>
+ |output out : UInt<8>
+ |out <= in[idx[jdx]]"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>[4]
+ |input jdx : UInt<2>
+ |output out : UInt<8>
+ |node _idx_jdx = idx[jdx]
+ |node _in_idx = in[_idx_jdx]
+ |out <= _in_idx"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "avoid name collisions" in { // _in_idx already exists in the module, so the new node is named _in_idx_0
+ val input = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |out <= in[idx]
+ |node _in_idx = not(idx)"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>[4]
+ |input idx : UInt<2>
+ |output out : UInt<8>
+ |node _in_idx_0 = in[idx]
+ |out <= _in_idx_0
+ |node _in_idx = not(idx)"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+ it should "have no effect on LHS SubAccesses" in { // a SubAccess that is assigned to must be left in place
+ val input = circuit(
+ s"""|input in : UInt<8>
+ |input idx : UInt<2>
+ |output out : UInt<8>[4]
+ |out[idx] <= in"""
+ )
+ val expected = circuit(
+ s"""|input in : UInt<8>
+ |input idx : UInt<2>
+ |output out : UInt<8>[4]
+ |out[idx] <= in"""
+ )
+ compile(input) should be(parse(expected).serialize)
+ }
+
+}