diff options
| author | Jack Koenig | 2021-03-29 10:21:26 -0700 |
|---|---|---|
| committer | GitHub | 2021-03-29 17:21:26 +0000 |
| commit | a41af6f0a34f9e13866002f19040a40ef55ee9e5 (patch) | |
| tree | 6ef270e7cc4e848940c00cda0b9518d5a7481761 /src/main | |
| parent | abeff01f0714d5474b9d18d78fc13011e5ad6b99 (diff) | |
Fix RemoveAccesses, delete CSESubAccesses (#2157)
CSESubAccesses was intended to be a simple workaround for a quadratic
performance bug in RemoveAccesses but ended up having tricky corner
cases and was hard to get right. The solution to the RemoveAccesses
bug--quadratic expansion of dynamic indexes of vecs of aggreate
type--turned out to be quite simple and makes CSESubAccesses much less
useful and not worth fixing.
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 63 | ||||
| -rw-r--r-- | src/main/scala/firrtl/stage/Forms.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/CSESubAccesses.scala | 168 |
3 files changed, 35 insertions, 197 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index 7449db51..073bf49d 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -9,7 +9,6 @@ import firrtl.Mappers._ import firrtl.Utils._ import firrtl.WrappedExpression._ import firrtl.options.Dependency -import firrtl.transforms.CSESubAccesses import scala.collection.mutable @@ -22,8 +21,7 @@ object RemoveAccesses extends Pass { Dependency(PullMuxes), Dependency(ZeroLengthVecs), Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency[CSESubAccesses] + Dependency(ExpandConnects) ) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform): Boolean = a match { @@ -122,26 +120,26 @@ object RemoveAccesses extends Pass { /** Replaces a subaccess in a given source expression */ val stmts = mutable.ArrayBuffer[Statement]() - def removeSource(e: Expression): Expression = e match { - case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(e) => - val rs = getLocations(e) - rs.find(x => x.guard != one) match { - case None => throwInternalError(s"removeSource: shouldn't be here - $e") - case Some(_) => - val (wire, temp) = create_temp(e) - val temps = create_exps(temp) - def getTemp(i: Int) = temps(i % temps.size) - stmts += wire - rs.zipWithIndex.foreach { - case (x, i) if i < temps.size => - stmts += IsInvalid(get_info(s), getTemp(i)) - stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) - case (x, i) => - stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) - } - temp - } - case _ => e + // Only called on RefLikes that definitely have a SubAccess + // Must accept Expression because that's the output type of fixIndices + def removeSource(e: Expression): Expression = { + val rs = getLocations(e) + rs.find(x => x.guard != one) match { + case None => throwInternalError(s"removeSource: shouldn't be here - $e") + case Some(_) => + val (wire, temp) = create_temp(e) + val temps = create_exps(temp) + def getTemp(i: Int) = temps(i % temps.size) + stmts += wire + rs.zipWithIndex.foreach { + case (x, i) if i < temps.size => + stmts += IsInvalid(get_info(s), getTemp(i)) + stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) + case (x, i) => + stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) + } + temp + } } /** Replaces a subaccess in a given sink expression @@ -162,14 +160,23 @@ object RemoveAccesses extends Pass { case _ => loc } + /** Recurse until find SubAccess and call fixSource on its index + * @note this only accepts [[RefLikeExpression]]s but we can't enforce it because map + * requires Expression => Expression + */ + def fixIndices(e: Expression): Expression = e match { + case e: SubAccess => e.copy(index = fixSource(e.index)) + case other => other.map(fixIndices) + } + /** Recursively walks a source expression and fixes all subaccesses - * If we see a sub-access, replace it. - * Otherwise, map to children. + * + * If we see a RefLikeExpression that contains a SubAccess, we recursively remove + * subaccesses from the indices of any SubAccesses, then process modified RefLikeExpression */ def fixSource(e: Expression): Expression = e match { - case w: WSubAccess => removeSource(WSubAccess(w.expr, fixSource(w.index), w.tpe, w.flow)) - //case w: WSubIndex => removeSource(w) - //case w: WSubField => removeSource(w) + case ref: RefLikeExpression => + if (hasAccess(ref)) removeSource(fixIndices(ref)) else ref case x => x.map(fixSource) } diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala index ba27b552..ab082151 100644 --- a/src/main/scala/firrtl/stage/Forms.scala +++ b/src/main/scala/firrtl/stage/Forms.scala @@ -57,7 +57,6 @@ object Forms { val MidForm: Seq[TransformDependency] = HighForm ++ Seq( Dependency(passes.PullMuxes), - Dependency[firrtl.transforms.CSESubAccesses], Dependency(passes.ReplaceAccesses), Dependency(passes.ExpandConnects), Dependency(passes.RemoveAccesses), diff --git a/src/main/scala/firrtl/transforms/CSESubAccesses.scala b/src/main/scala/firrtl/transforms/CSESubAccesses.scala deleted file mode 100644 index 6ca27a83..00000000 --- a/src/main/scala/firrtl/transforms/CSESubAccesses.scala +++ /dev/null @@ -1,168 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package firrtl -package transforms - -import firrtl.ir._ -import firrtl.traversals.Foreachers._ -import firrtl.Mappers._ -import firrtl.PrimOps._ -import firrtl.WrappedExpression._ -import firrtl.options.Dependency -import firrtl.passes._ -import firrtl.Utils.{distinctBy, flow, getAllRefs, get_info, niceName} - -import scala.collection.mutable - -object CSESubAccesses { - - // Get all SubAccesses used on the right-hand side along with the info from the outer Statement - private def collectRValueSubAccesses(mod: Module): Seq[(SubAccess, Info)] = { - val acc = new mutable.ListBuffer[(SubAccess, Info)] - def onExpr(outer: Statement)(expr: Expression): Unit = { - // Need postorder because we want to visit inner SubAccesses first - // Stop recursing on any non-Source because flips can make the SubAccess a Source despite the - // overall Expression being a Sink - if (flow(expr) == SourceFlow) expr.foreach(onExpr(outer)) - expr match { - case e: SubAccess if flow(e) == SourceFlow => acc += e -> get_info(outer) - case _ => // Do nothing - } - } - def onStmt(stmt: Statement): Unit = { - stmt.foreach(onStmt) - stmt match { - // Don't record SubAccesses that are already assigned to a Node, but *do* record any nested - // inside of the SubAccess. This makes the transform idempotent and avoids unnecessary work. - case DefNode(_, _, acc: SubAccess) => acc.foreach(onExpr(stmt)) - case other => other.foreach(onExpr(stmt)) - } - } - onStmt(mod.body) - distinctBy(acc.toList)(_._1) - } - - // Replaces all right-hand side SubAccesses with References - private def replaceOnSourceExpr(replace: SubAccess => Reference)(expr: Expression): Expression = expr match { - // Stop is we ever see a non-SourceFlow - case e if flow(e) != SourceFlow => e - // Don't traverse children of SubAccess, just replace it - // Nested SubAccesses are handled during creation of the nodes that the references refer to - case acc: SubAccess if flow(acc) == SourceFlow => replace(acc) - case other => other.map(replaceOnSourceExpr(replace)) - } - - private def hoistSubAccesses( - hoist: String => List[DefNode], - replace: SubAccess => Reference - )(stmt: Statement - ): Statement = { - val onExpr = replaceOnSourceExpr(replace) _ - def onStmt(s: Statement): Statement = s.map(onExpr).map(onStmt) match { - case decl: IsDeclaration => - val nodes = hoist(decl.name) - if (nodes.isEmpty) decl else Block(decl :: nodes) - case other => other - } - onStmt(stmt) - } - - // Given some nodes, determine after which String declaration each node should be inserted - // This function is *mutable*, it keeps track of which declarations each node is sensitive to and - // returns nodes in groups once the last declaration they depend on is seen - private def getSensitivityLookup(nodes: Iterable[DefNode]): String => List[DefNode] = { - case class ReferenceCount(var n: Int, node: DefNode) - // Gather names of declarations each node depends on - val nodeDeps = nodes.map(node => getAllRefs(node.value).view.map(_.name).toSet -> node) - // Map from declaration names to the indices of nodeDeps that depend on it - val lookup = new mutable.HashMap[String, mutable.ArrayBuffer[Int]] - for (((decls, _), idx) <- nodeDeps.zipWithIndex) { - for (d <- decls) { - val indices = lookup.getOrElseUpdate(d, new mutable.ArrayBuffer[Int]) - indices += idx - } - } - // Now we can just associate each List of nodes with how many declarations they need to see - // We use an Array because we're mutating anyway and might as well be quick about it - val nodeLists: Array[ReferenceCount] = - nodeDeps.view.map { case (deps, node) => ReferenceCount(deps.size, node) }.toArray - - // Must be a def because it's recursive - def func(decl: String): List[DefNode] = { - if (lookup.contains(decl)) { - val indices = lookup(decl) - val result = new mutable.ListBuffer[DefNode] - lookup -= decl - for (i <- indices) { - val refCount = nodeLists(i) - refCount.n -= 1 - assert(refCount.n >= 0, "Internal Error!") - if (refCount.n == 0) result += refCount.node - } - // DefNodes can depend on each other, recurse - result.toList.flatMap { node => node :: func(node.name) } - } else { - Nil - } - } - func _ - } - - /** Performs [[CSESubAccesses]] on a single [[ir.Module Module]] */ - def onMod(mod: Module): Module = { - // ***** Pre-Analyze (do we even need to do anything) ***** - val accesses = collectRValueSubAccesses(mod) - if (accesses.isEmpty) mod - else { - // ***** Analyze ***** - val namespace = Namespace(mod) - val replace = new mutable.HashMap[SubAccess, Reference] - val nodes = new mutable.ArrayBuffer[DefNode] - for ((acc, info) <- accesses) { - val name = namespace.newName(niceName(acc)) - // SubAccesses can be nested, so replace any nested ones with prior references - // This is why post-order traversal in collectRValueSubAccesses is important - val accx = acc.map(replaceOnSourceExpr(replace)) - val node = DefNode(info, name, accx) - val ref = Reference(node) - // Record in replace - replace(acc) = ref - // Record node - nodes += node - } - val hoist = getSensitivityLookup(nodes) - - // ***** Transform ***** - val portStmts = mod.ports.flatMap(x => hoist(x.name)) - val bodyx = hoistSubAccesses(hoist, replace)(mod.body) - mod.copy(body = if (portStmts.isEmpty) bodyx else Block(Block(portStmts), bodyx)) - } - } -} - -/** Performs Common Subexpression Elimination (CSE) on right-hand side [[ir.SubAccess SubAccess]]es - * - * This avoids quadratic node creation behavior in [[passes.RemoveAccesses RemoveAccesses]]. For - * simplicity of implementation, all SubAccesses on the right-hand side are also split into - * individual nodes. - */ -class CSESubAccesses extends Transform with DependencyAPIMigration { - - override def prerequisites = Dependency(ResolveFlows) :: Dependency(CheckHighForm) :: Nil - - // Faster to run after these - override def optionalPrerequisites = Dependency(ReplaceAccesses) :: Dependency[DedupModules] :: Nil - - // Running before ExpandConnects is an optimization - override def optionalPrerequisiteOf = Dependency(ExpandConnects) :: Nil - - override def invalidates(a: Transform) = false - - def execute(state: CircuitState): CircuitState = { - val modulesx = state.circuit.modules.map { - case ext: ExtModule => ext - case mod: Module => CSESubAccesses.onMod(mod) - } - state.copy(circuit = state.circuit.copy(modules = modulesx)) - } -} |
