1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
|
// SPDX-License-Identifier: Apache-2.0
package firrtl
package transforms
import firrtl.ir._
import firrtl.traversals.Foreachers._
import firrtl.Mappers._
import firrtl.PrimOps._
import firrtl.WrappedExpression._
import firrtl.options.Dependency
import firrtl.passes._
import firrtl.Utils.{distinctBy, flow, getAllRefs, get_info, niceName}
import scala.collection.mutable
/** Per-module implementation of the [[CSESubAccesses]] transform.
  *
  * Every right-hand side (source flow) SubAccess is bound to its own node. Each node is inserted
  * directly after the last declaration (port or statement) it depends on, and the original uses
  * are replaced with references to that node.
  */
object CSESubAccesses {

  // Get all SubAccesses used on the right-hand side along with the info from the outer Statement
  // The result is in post-order (inner SubAccesses come before the SubAccesses containing them)
  // and is deduplicated on the SubAccess itself
  private def collectRValueSubAccesses(mod: Module): Seq[(SubAccess, Info)] = {
    val acc = new mutable.ListBuffer[(SubAccess, Info)]
    def onExpr(outer: Statement)(expr: Expression): Unit = {
      // Need postorder because we want to visit inner SubAccesses first
      // Stop recursing on any non-Source because flips can make the SubAccess a Source despite the
      //   overall Expression being a Sink
      if (flow(expr) == SourceFlow) expr.foreach(onExpr(outer))
      expr match {
        case e: SubAccess if flow(e) == SourceFlow => acc += e -> get_info(outer)
        case _ => // Do nothing
      }
    }
    def onStmt(stmt: Statement): Unit = {
      stmt.foreach(onStmt)
      stmt match {
        // Don't record SubAccesses that are already assigned to a Node, but *do* record any nested
        // inside of the SubAccess. This makes the transform idempotent and avoids unnecessary work.
        // Consequence: such a SubAccess may have no replacement, see replaceOnSourceExpr.
        case DefNode(_, _, access: SubAccess) => access.foreach(onExpr(stmt))
        case other => other.foreach(onExpr(stmt))
      }
    }
    onStmt(mod.body)
    distinctBy(acc.toList)(_._1)
  }

  // Replaces all right-hand side SubAccesses that have a replacement with References
  // `replace` is partial: SubAccesses that are the direct value of a node are intentionally not
  // collected, so they may be missing from it. Those are kept as-is and only their children
  // (e.g. a nested SubAccess used as the index) are rewritten. Looking them up unconditionally
  // would throw a NoSuchElementException when `replace` is a Map.
  private def replaceOnSourceExpr(
    replace: PartialFunction[SubAccess, Reference]
  )(expr:    Expression
  ): Expression = expr match {
    // Stop if we ever see a non-SourceFlow
    case e if flow(e) != SourceFlow => e
    // Don't traverse children of a replaceable SubAccess, just replace it
    // Nested SubAccesses are handled during creation of the nodes that the references refer to
    case acc: SubAccess if replace.isDefinedAt(acc) => replace(acc)
    // Everything else (including a SubAccess without a replacement): recurse into the children
    case other => other.map(replaceOnSourceExpr(replace))
  }

  // Rewrites a Statement tree: replaces right-hand side SubAccesses via `replace` and, after each
  // declaration, inserts the nodes that `hoist` releases for that declaration's name
  private def hoistSubAccesses(
    hoist:   String => List[DefNode],
    replace: PartialFunction[SubAccess, Reference]
  )(stmt:    Statement
  ): Statement = {
    val onExpr = replaceOnSourceExpr(replace) _
    def onStmt(s: Statement): Statement = s.map(onExpr).map(onStmt) match {
      case decl: IsDeclaration =>
        // `hoist` is stateful, it must be called exactly once per declaration, in program order
        val nodes = hoist(decl.name)
        if (nodes.isEmpty) decl else Block(decl :: nodes)
      case other => other
    }
    onStmt(stmt)
  }

  // Given some nodes, determine after which String declaration each node should be inserted
  // This function is *mutable*, it keeps track of which declarations each node is sensitive to and
  // returns nodes in groups once the last declaration they depend on is seen
  private def getSensitivityLookup(nodes: Iterable[DefNode]): String => List[DefNode] = {
    // Number of not-yet-seen declarations (`n`) that `node` still waits for
    case class ReferenceCount(var n: Int, node: DefNode)
    // Gather names of declarations each node depends on
    val nodeDeps = nodes.map(node => getAllRefs(node.value).view.map(_.name).toSet -> node)
    // Map from declaration names to the indices of nodeDeps that depend on it
    val lookup = new mutable.HashMap[String, mutable.ArrayBuffer[Int]]
    for (((decls, _), idx) <- nodeDeps.zipWithIndex) {
      for (d <- decls) {
        val indices = lookup.getOrElseUpdate(d, new mutable.ArrayBuffer[Int])
        indices += idx
      }
    }
    // Now we can just associate each List of nodes with how many declarations they need to see
    // We use an Array because we're mutating anyway and might as well be quick about it
    val nodeLists: Array[ReferenceCount] =
      nodeDeps.view.map { case (deps, node) => ReferenceCount(deps.size, node) }.toArray
    // Must be a def because it's recursive
    def func(decl: String): List[DefNode] = {
      if (lookup.contains(decl)) {
        val indices = lookup(decl)
        val result = new mutable.ListBuffer[DefNode]
        // Remove the entry so that seeing the same name again does not decrement twice
        lookup -= decl
        for (i <- indices) {
          val refCount = nodeLists(i)
          refCount.n -= 1
          assert(refCount.n >= 0, "Internal Error!")
          // All dependencies of this node have now been seen, release it
          if (refCount.n == 0) result += refCount.node
        }
        // DefNodes can depend on each other, recurse
        result.toList.flatMap { node => node :: func(node.name) }
      } else {
        Nil
      }
    }
    func _
  }

  /** Performs [[CSESubAccesses]] on a single [[ir.Module Module]]
    *
    * @param mod the module to transform
    * @return `mod` itself if it contains no right-hand side SubAccesses that need a node, otherwise
    *         a copy whose body has the new nodes inserted and the SubAccesses replaced by
    *         references to them
    */
  def onMod(mod: Module): Module = {
    // ***** Pre-Analyze (do we even need to do anything) *****
    val accesses = collectRValueSubAccesses(mod)
    if (accesses.isEmpty) mod
    else {
      // ***** Analyze *****
      val namespace = Namespace(mod)
      val replace = new mutable.HashMap[SubAccess, Reference]
      val nodes = new mutable.ArrayBuffer[DefNode]
      for ((acc, info) <- accesses) {
        val name = namespace.newName(niceName(acc))
        // SubAccesses can be nested, so replace any nested ones with prior references
        // This is why post-order traversal in collectRValueSubAccesses is important
        val accx = acc.map(replaceOnSourceExpr(replace))
        val node = DefNode(info, name, accx)
        val ref = Reference(node)
        // Record in replace
        replace(acc) = ref
        // Record node
        nodes += node
      }
      val hoist = getSensitivityLookup(nodes)

      // ***** Transform *****
      // Ports are declared before the body, so they must be fed to `hoist` first; the nodes
      // released here are placed in front of the body
      val portStmts = mod.ports.flatMap(x => hoist(x.name))
      val bodyx = hoistSubAccesses(hoist, replace)(mod.body)
      mod.copy(body = if (portStmts.isEmpty) bodyx else Block(Block(portStmts), bodyx))
    }
  }
}
/** Common Subexpression Elimination (CSE) for right-hand side [[ir.SubAccess SubAccess]]es
  *
  * The purpose is to avoid quadratic node creation behavior in
  * [[passes.RemoveAccesses RemoveAccesses]]. To keep the implementation simple, every SubAccess on
  * the right-hand side is additionally split out into a node of its own.
  */
class CSESubAccesses extends Transform with DependencyAPIMigration {

  override def prerequisites = List(Dependency(ResolveFlows), Dependency(CheckHighForm))

  // Faster to run after these
  override def optionalPrerequisites = List(Dependency(ReplaceAccesses), Dependency[DedupModules])

  // Running before ExpandConnects is an optimization
  override def optionalPrerequisiteOf = List(Dependency(ExpandConnects))

  override def invalidates(a: Transform) = false

  /** Applies [[CSESubAccesses.onMod]] to every [[ir.Module Module]] of the circuit; external
    * modules are passed through untouched.
    */
  def execute(state: CircuitState): CircuitState = {
    val circuit = state.circuit
    val rewritten = circuit.modules.map {
      case mod: Module    => CSESubAccesses.onMod(mod)
      case ext: ExtModule => ext
    }
    state.copy(circuit = circuit.copy(modules = rewritten))
  }
}
|