diff options
| author | Jack Koenig | 2021-03-03 17:01:53 -0800 |
|---|---|---|
| committer | GitHub | 2021-03-04 01:01:53 +0000 |
| commit | e58ba0c12e5d650983c70a61a45542f0cd43fb88 (patch) | |
| tree | fc0689df82dd3b9fcadb7ea8d5fc082b35afb20e | |
| parent | 5be1abb4c654279762a463a861526ce4e0c48035 (diff) | |
CSE SubAccesses (#2099)
Fixes an O(n^2) performance problem when dynamically indexing Vecs of
aggregate types.
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 32 |
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 4 |
| -rw-r--r-- | src/main/scala/firrtl/stage/Forms.scala | 1 |
| -rw-r--r-- | src/main/scala/firrtl/transforms/CSESubAccesses.scala | 164 |
| -rw-r--r-- | src/test/scala/firrtlTests/LowerTypesSpec.scala | 3 |
| -rw-r--r-- | src/test/scala/firrtlTests/LoweringCompilersSpec.scala | 1 |
| -rw-r--r-- | src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala | 187 |
7 files changed, 390 insertions, 2 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index d187ea5f..72884e25 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -5,6 +5,7 @@ package firrtl import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ import firrtl.WrappedExpression._ import scala.collection.mutable @@ -210,6 +211,24 @@ object Utils extends LazyLogging { case _ => false } + /** Selects all the elements of this list ignoring the duplicates as determined by == after + * applying the transforming function f + * + * @note In Scala Standard Library starting in 2.13 + */ + def distinctBy[A, B](xs: List[A])(f: A => B): List[A] = { + val buf = new mutable.ListBuffer[A] + val seen = new mutable.HashSet[B] + for (x <- xs) { + val y = f(x) + if (!seen(y)) { + buf += x + seen += y + } + } + buf.toList + } + /** Provide a nice name to create a temporary * */ def niceName(e: Expression): String = niceName(1)(e) def niceName(depth: Int)(e: Expression): String = { @@ -649,6 +668,19 @@ object Utils extends LazyLogging { case _ => NoInfo } + /** Finds all root References in a nested Expression */ + def getAllRefs(expr: Expression): Seq[Reference] = { + val refs = mutable.ListBuffer.empty[Reference] + def rec(e: Expression): Unit = { + e match { + case ref: Reference => refs += ref + case other => other.foreach(rec) + } + } + rec(expr) + refs.toList + } + /** Splits an Expression into root Ref and tail * * @example diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index 90437e56..7449db51 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -9,6 +9,7 @@ import firrtl.Mappers._ import firrtl.Utils._ import firrtl.WrappedExpression._ import firrtl.options.Dependency +import firrtl.transforms.CSESubAccesses import scala.collection.mutable @@ -21,7 +22,8 @@ 
object RemoveAccesses extends Pass { Dependency(PullMuxes), Dependency(ZeroLengthVecs), Dependency(ReplaceAccesses), - Dependency(ExpandConnects) + Dependency(ExpandConnects), + Dependency[CSESubAccesses] ) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform): Boolean = a match { diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala index ab082151..ba27b552 100644 --- a/src/main/scala/firrtl/stage/Forms.scala +++ b/src/main/scala/firrtl/stage/Forms.scala @@ -57,6 +57,7 @@ object Forms { val MidForm: Seq[TransformDependency] = HighForm ++ Seq( Dependency(passes.PullMuxes), + Dependency[firrtl.transforms.CSESubAccesses], Dependency(passes.ReplaceAccesses), Dependency(passes.ExpandConnects), Dependency(passes.RemoveAccesses), diff --git a/src/main/scala/firrtl/transforms/CSESubAccesses.scala b/src/main/scala/firrtl/transforms/CSESubAccesses.scala new file mode 100644 index 00000000..6ed3a5b5 --- /dev/null +++ b/src/main/scala/firrtl/transforms/CSESubAccesses.scala @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.traversals.Foreachers._ +import firrtl.Mappers._ +import firrtl.PrimOps._ +import firrtl.WrappedExpression._ +import firrtl.options.Dependency +import firrtl.passes._ +import firrtl.Utils.{distinctBy, flow, getAllRefs, get_info, niceName} + +import scala.collection.mutable + +object CSESubAccesses { + + // Get all SubAccesses used on the right-hand side along with the info from the outer Statement + private def collectRValueSubAccesses(mod: Module): Seq[(SubAccess, Info)] = { + val acc = new mutable.ListBuffer[(SubAccess, Info)] + def onExpr(outer: Statement)(expr: Expression): Unit = { + // Need postorder because we want to visit inner SubAccesses first + expr.foreach(onExpr(outer)) + expr match { + case e: SubAccess if flow(e) == SourceFlow => acc += e -> get_info(outer) + case _ => // Do nothing + } + } + def 
onStmt(stmt: Statement): Unit = { + stmt.foreach(onStmt) + stmt match { + // Don't record SubAccesses that are already assigned to a Node, but *do* record any nested + // inside of the SubAccess. This makes the transform idempotent and avoids unnecessary work. + case DefNode(_, _, acc: SubAccess) => acc.foreach(onExpr(stmt)) + case other => other.foreach(onExpr(stmt)) + } + } + onStmt(mod.body) + distinctBy(acc.toList)(_._1) + } + + // Replaces all right-hand side SubAccesses with References + private def replaceOnSourceExpr(replace: SubAccess => Reference)(expr: Expression): Expression = expr match { + // Don't traverse children of SubAccess, just replace it + // Nested SubAccesses are handled during creation of the nodes that the references refer to + case acc: SubAccess if flow(acc) == SourceFlow => replace(acc) + case other => other.map(replaceOnSourceExpr(replace)) + } + + private def hoistSubAccesses( + hoist: String => List[DefNode], + replace: SubAccess => Reference + )(stmt: Statement + ): Statement = { + val onExpr = replaceOnSourceExpr(replace) _ + def onStmt(s: Statement): Statement = s.map(onExpr).map(onStmt) match { + case decl: IsDeclaration => + val nodes = hoist(decl.name) + if (nodes.isEmpty) decl else Block(decl :: nodes) + case other => other + } + onStmt(stmt) + } + + // Given some nodes, determine after which String declaration each node should be inserted + // This function is *mutable*, it keeps track of which declarations each node is sensitive to and + // returns nodes in groups once the last declaration they depend on is seen + private def getSensitivityLookup(nodes: Iterable[DefNode]): String => List[DefNode] = { + case class ReferenceCount(var n: Int, node: DefNode) + // Gather names of declarations each node depends on + val nodeDeps = nodes.map(node => getAllRefs(node.value).view.map(_.name).toSet -> node) + // Map from declaration names to the indices of nodeDeps that depend on it + val lookup = new mutable.HashMap[String, 
mutable.ArrayBuffer[Int]] + for (((decls, _), idx) <- nodeDeps.zipWithIndex) { + for (d <- decls) { + val indices = lookup.getOrElseUpdate(d, new mutable.ArrayBuffer[Int]) + indices += idx + } + } + // Now we can just associate each List of nodes with how many declarations they need to see + // We use an Array because we're mutating anyway and might as well be quick about it + val nodeLists: Array[ReferenceCount] = + nodeDeps.view.map { case (deps, node) => ReferenceCount(deps.size, node) }.toArray + + // Must be a def because it's recursive + def func(decl: String): List[DefNode] = { + if (lookup.contains(decl)) { + val indices = lookup(decl) + val result = new mutable.ListBuffer[DefNode] + lookup -= decl + for (i <- indices) { + val refCount = nodeLists(i) + refCount.n -= 1 + assert(refCount.n >= 0, "Internal Error!") + if (refCount.n == 0) result += refCount.node + } + // DefNodes can depend on each other, recurse + result.toList.flatMap { node => node :: func(node.name) } + } else { + Nil + } + } + func _ + } + + /** Performs [[CSESubAccesses]] on a single [[ir.Module Module]] */ + def onMod(mod: Module): Module = { + // ***** Pre-Analyze (do we even need to do anything) ***** + val accesses = collectRValueSubAccesses(mod) + if (accesses.isEmpty) mod + else { + // ***** Analyze ***** + val namespace = Namespace(mod) + val replace = new mutable.HashMap[SubAccess, Reference] + val nodes = new mutable.ArrayBuffer[DefNode] + for ((acc, info) <- accesses) { + val name = namespace.newName(niceName(acc)) + // SubAccesses can be nested, so replace any nested ones with prior references + // This is why post-order traversal in collectRValueSubAccesses is important + val accx = acc.map(replaceOnSourceExpr(replace)) + val node = DefNode(info, name, accx) + val ref = Reference(node) + // Record in replace + replace(acc) = ref + // Record node + nodes += node + } + val hoist = getSensitivityLookup(nodes) + + // ***** Transform ***** + val portStmts = mod.ports.flatMap(x => 
hoist(x.name)) + val bodyx = hoistSubAccesses(hoist, replace)(mod.body) + mod.copy(body = if (portStmts.isEmpty) bodyx else Block(Block(portStmts), bodyx)) + } + } +} + +/** Performs Common Subexpression Elimination (CSE) on right-hand side [[ir.SubAccess SubAccess]]es + * + * This avoids quadratic node creation behavior in [[passes.RemoveAccesses RemoveAccesses]]. For + * simplicity of implementation, all SubAccesses on the right-hand side are also split into + * individual nodes. + */ +class CSESubAccesses extends Transform with DependencyAPIMigration { + + override def prerequisites = Dependency(ResolveFlows) :: Dependency(CheckHighForm) :: Nil + + // Faster to run after these + override def optionalPrerequisites = Dependency(ReplaceAccesses) :: Dependency[DedupModules] :: Nil + + // Running before ExpandConnects is an optimization + override def optionalPrerequisiteOf = Dependency(ExpandConnects) :: Nil + + override def invalidates(a: Transform) = false + + def execute(state: CircuitState): CircuitState = { + val modulesx = state.circuit.modules.map { + case ext: ExtModule => ext + case mod: Module => CSESubAccesses.onMod(mod) + } + state.copy(circuit = state.circuit.copy(modules = modulesx)) + } +} diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 78d03e68..9425a582 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -486,7 +486,8 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | out <= in0[in1[in2[0]]][in1[in2[1]]] |""".stripMargin val expected = Seq( - "out <= _in0_in1_in1_in2_1" + "node _in0_in1_in1 = _in0_in1_in1_in2_1", + "out <= _in0_in1_in1" ) executeTest(input, expected) diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index bdc72e7b..ee6077d3 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ 
b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -180,6 +180,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { it should "replicate the old order" in { val tm = new TransformManager(Forms.MidForm, Forms.Deduped) val patches = Seq( + Add(2, Seq(Dependency[firrtl.transforms.CSESubAccesses])), Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))), Add(5, Seq(Dependency(firrtl.passes.ResolveKinds))), // Uniquify is now part of [[firrtl.passes.LowerTypes]] diff --git a/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala b/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala new file mode 100644 index 00000000..55ce07df --- /dev/null +++ b/src/test/scala/firrtlTests/transforms/CSESubAccessesSpec.scala @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests +package transforms + +import firrtl._ +import firrtl.testutils._ +import firrtl.stage.TransformManager +import firrtl.options.Dependency +import firrtl.transforms.CSESubAccesses + +class CSESubAccessesSpec extends FirrtlFlatSpec { + def compile(input: String): String = { + val manager = new TransformManager(Dependency[CSESubAccesses] :: Nil) + val result = manager.execute(CircuitState(parse(input), Nil)) + result.circuit.serialize + } + def circuit(body: String): String = { + """|circuit Test : + | module Test : + |""".stripMargin + body.stripMargin.split("\n").mkString(" ", "\n ", "\n") + } + + behavior.of("CSESubAccesses") + + it should "hoist a single RHS subaccess" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= in[idx]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |node _in_idx = in[idx] + |out <= _in_idx""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "be idempotent" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= 
in[idx]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |node _in_idx = in[idx] + |out <= _in_idx""" + ) + val first = compile(input) + val second = compile(first) + first should be(second) + first should be(parse(expected).serialize) + } + + it should "hoist a redundant RHS subaccess" in { + val input = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4] + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> } + |out.foo <= in[idx].foo + |out.bar <= in[idx].bar""" + ) + val expected = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4] + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> } + |node _in_idx = in[idx] + |out.foo <= _in_idx.foo + |out.bar <= _in_idx.bar""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "correctly place hosited subaccess after last declaration it depends on" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out is invalid + |when UInt(1) : + | node nidx = not(idx) + | out <= in[nidx] + |""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out is invalid + |when UInt(1) : + | node nidx = not(idx) + | node _in_nidx = in[nidx] + | out <= _in_nidx + |""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "support complex expressions" in { + val input = circuit( + s"""|input clock : Clock + |input in : UInt<8>[4] + |input idx : UInt<2> + |input sel : UInt<1> + |output out : UInt<8> + |reg r : UInt<2>, clock + |out <= in[mux(sel, r, idx)] + |r <= not(idx)""" + ) + val expected = circuit( + s"""|input clock : Clock + |input in : UInt<8>[4] + |input idx : UInt<2> + |input sel : UInt<1> + |output out : UInt<8> + |reg r : UInt<2>, clock + |node _in_mux = in[mux(sel, r, idx)] + |out <= _in_mux + |r <= not(idx)""" + ) + compile(input) should 
be(parse(expected).serialize) + } + + it should "support nested subaccesses" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2>[4] + |input jdx : UInt<2> + |output out : UInt<8> + |out <= in[idx[jdx]]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2>[4] + |input jdx : UInt<2> + |output out : UInt<8> + |node _idx_jdx = idx[jdx] + |node _in_idx = in[_idx_jdx] + |out <= _in_idx""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "avoid name collisions" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= in[idx] + |node _in_idx = not(idx)""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |node _in_idx_0 = in[idx] + |out <= _in_idx_0 + |node _in_idx = not(idx)""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "have no effect on LHS SubAccesses" in { + val input = circuit( + s"""|input in : UInt<8> + |input idx : UInt<2> + |output out : UInt<8>[4] + |out[idx] <= in""" + ) + val expected = circuit( + s"""|input in : UInt<8> + |input idx : UInt<2> + |output out : UInt<8>[4] + |out[idx] <= in""" + ) + compile(input) should be(parse(expected).serialize) + } + +} |
