1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
|
// SPDX-License-Identifier: Apache-2.0
package firrtl.transforms
import scala.collection.mutable
import firrtl._
import firrtl.ir._
import firrtl.passes.{Errors, PassException}
import firrtl.traversals.Foreachers._
import firrtl.annotations._
import firrtl.Utils.throwInternalError
import firrtl.graph._
import firrtl.analyses.InstanceKeyGraph
import firrtl.options.{Dependency, RegisteredTransform, ShellOption}
/**
* A case class that represents a net in the circuit. This is necessary since combinational loop
* checking is an analysis on the netlist of the circuit; the fields are specialized for low FIRRTL.
* Since all wires are ground types, a given ground type net may only be a subfield of an instance
* or a memory port. Therefore, it is uniquely specified within its module context by its name, its
* optional parent instance (a WDefInstance or WDefMemory), and its optional memory port name.
*/
case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None)
object LogicNode {
/**
* Construct a LogicNode from a *Low FIRRTL* reference or subfield that refers to a component.
* Since aggregate types appear in Low FIRRTL only as the full types of instances or memories,
* the only reference-like expressions that may appear are WRefs, or trees of up to two levels of
* WSubField field selection.
*
* @param e The reference-like expression to describe with a LogicNode
* @return a LogicNode referring to e
*/
def apply(e: Expression): LogicNode = e match {
case r: WRef =>
LogicNode(r.name)
case s: WSubField =>
s.expr match {
case modref: WRef =>
LogicNode(s.name, Some(modref.name))
case memport: WSubField =>
memport.expr match {
case memref: WRef =>
LogicNode(s.name, Some(memref.name), Some(memport.name))
case _ => throwInternalError(s"LogicNode: unrecognized subsubfield expression - $memport")
}
case _ => throwInternalError(s"LogicNode: unrecognized subfield expression - $s")
}
}
}
object CheckCombLoops {
type AbstractConnMap = DiGraph[LogicNode]
type ConnMap = DiGraph[LogicNode] with EdgeData[LogicNode, Info]
type MutableConnMap = MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info]
class CombLoopException(info: Info, mname: String, cycle: Seq[String])
extends PassException(s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n"))
}
case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation
case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarget) extends Annotation {
if (!source.isLocal || !sink.isLocal || source.module != sink.module) {
throwInternalError(s"ExtModulePathAnnotation must connect two local targets from the same module")
}
override def getTargets: Seq[ReferenceTarget] = Seq(source, sink)
override def update(renames: RenameMap): Seq[Annotation] = {
val sources = renames.get(source).getOrElse(Seq(source))
val sinks = renames.get(sink).getOrElse(Seq(sink))
val paths = sources.flatMap { s => sinks.map((s, _)) }
paths.collect {
case (source: ReferenceTarget, sink: ReferenceTarget) => ExtModulePathAnnotation(source, sink)
}
}
}
case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget]) extends Annotation {
override def update(renames: RenameMap): Seq[Annotation] = {
val newSources = sources.flatMap { s => renames(s) }.collect { case x: ReferenceTarget if x.isLocal => x }
val newSinks = renames(sink).collect { case x: ReferenceTarget if x.isLocal => x }
newSinks.map(snk => CombinationalPath(snk, newSources))
}
}
/** Finds and detects combinational logic loops in a circuit, if any exist. Returns the input circuit with no
* modifications.
*
* @throws firrtl.transforms.CheckCombLoops.CombLoopException if a loop is found
* @note Input form: Low FIRRTL
* @note Output form: Low FIRRTL (identity transform)
* @note The pass looks for loops through combinational-read memories
* @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules
* @note The pass will throw exceptions on "false paths"
*/
class CheckCombLoops extends Transform with RegisteredTransform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
Seq(Dependency(passes.LowerTypes), Dependency(firrtl.transforms.RemoveReset))
override def optionalPrerequisites = Seq.empty
override def optionalPrerequisiteOf = Seq.empty
override def invalidates(a: Transform) = false
import CheckCombLoops._
val options = Seq(
new ShellOption[Unit](
longOption = "no-check-comb-loops",
toAnnotationSeq = (_: Unit) => Seq(DontCheckCombLoopsAnnotation),
helpText = "Disable combinational loop checking"
)
)
private def getExprDeps(deps: MutableConnMap, v: LogicNode, info: Info)(e: Expression): Unit = e match {
case r: WRef => deps.addEdgeIfValid(v, LogicNode(r), info)
case s: WSubField => deps.addEdgeIfValid(v, LogicNode(s), info)
case _ => e.foreach(getExprDeps(deps, v, info))
}
private def getStmtDeps(
simplifiedModules: mutable.Map[String, AbstractConnMap],
deps: MutableConnMap
)(s: Statement
): Unit = s match {
case Connect(info, loc, expr) =>
val lhs = LogicNode(loc)
if (deps.contains(lhs)) {
getExprDeps(deps, lhs, info)(expr)
}
case w: DefWire =>
deps.addVertex(LogicNode(w.name))
case DefNode(info, name, value) =>
val lhs = LogicNode(name)
deps.addVertex(lhs)
getExprDeps(deps, lhs, info)(value)
case m: DefMemory if (m.readLatency == 0) =>
for (rp <- m.readers) {
val dataNode = deps.addVertex(LogicNode("data", Some(m.name), Some(rp)))
val addr = LogicNode("addr", Some(m.name), Some(rp))
val en = LogicNode("en", Some(m.name), Some(rp))
deps.addEdge(dataNode, deps.addVertex(addr), m.info)
deps.addEdge(dataNode, deps.addVertex(en), m.info)
}
case i: WDefInstance =>
val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
iGraph.getVertices.foreach(deps.addVertex(_))
iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v, _) } })
case _ =>
s.foreach(getStmtDeps(simplifiedModules, deps))
}
// Pretty-print a LogicNode with a prepended hierarchical path
private def prettyPrintAbsoluteRef(hierPrefix: Seq[String], node: LogicNode): String = {
(hierPrefix ++ node.inst ++ node.memport :+ node.name).mkString(".")
}
/*
* Recover the full path from a path passing through simplified
* instances. Since edges may pass through simplified instances, the
* hierarchy that the path passes through must be recursively
* recovered.
*/
private def expandInstancePaths(
m: String,
moduleGraphs: mutable.Map[String, ConnMap],
moduleDeps: Map[String, Map[String, String]],
hierPrefix: Seq[String],
path: Seq[LogicNode]
): Seq[String] = {
// Recover info from edge data, add to error string
def info(u: LogicNode, v: LogicNode): String =
moduleGraphs(m).getEdgeData(u, v).map(_.toString).mkString("\t", "", "")
// lhs comes after rhs
val pathNodes = (path.zip(path.tail)).map {
case (rhs, lhs) =>
if (lhs.inst.isDefined && !lhs.memport.isDefined && lhs.inst == rhs.inst) {
val child = moduleDeps(m)(lhs.inst.get)
val newHierPrefix = hierPrefix :+ lhs.inst.get
val subpath = moduleGraphs(child).path(lhs.copy(inst = None), rhs.copy(inst = None)).reverse
expandInstancePaths(child, moduleGraphs, moduleDeps, newHierPrefix, subpath)
} else {
Seq(prettyPrintAbsoluteRef(hierPrefix, lhs) ++ info(lhs, rhs))
}
}
pathNodes.flatten
}
/*
* An SCC may contain more than one loop. In this case, the sequence
* of nodes forming the SCC cannot be interpreted as a simple
* cycle. However, it is desirable to print an error consisting of a
* loop rather than an arbitrary ordering of the SCC. This function
* operates on a pruned subgraph composed only of the SCC and finds
* a simple cycle by performing an arbitrary walk.
*/
private def findCycleInSCC[T](sccGraph: DiGraph[T]): Seq[T] = {
val walk = new mutable.ArrayBuffer[T]
val visited = new mutable.HashSet[T]
var current = sccGraph.getVertices.head
while (!visited.contains(current)) {
walk += current
visited += current
current = sccGraph.getEdges(current).head
}
walk.drop(walk.indexOf(current)).toSeq :+ current
}
/*
* This implementation of combinational loop detection avoids ever
* generating a full netlist from the FIRRTL circuit. Instead, each
* module is converted to a netlist and analyzed locally, with its
* subinstances represented by trivial, simplified subgraphs. The
* overall outline of the process is:
*
* 1. Create a graph of module instance dependances
* 2. Linearize this acyclic graph
*
* 3. Generate a local netlist; replace any instances with
* simplified subgraphs representing connectivity of their IOs
*
* 4. Check for nontrivial strongly connected components
*
* 5. Create a reduced representation of the netlist with only the
* module IOs as nodes, where output X (which must be a ground type,
* as only low FIRRTL is supported) will have an edge to input Y if
* and only if it combinationally depends on input Y. Associate this
* reduced graph with the module for future use.
*/
private def run(state: CircuitState) = {
val c = state.circuit
val errors = new Errors()
val extModulePaths = state.annotations.groupBy {
case ann: ExtModulePathAnnotation => ModuleTarget(c.main, ann.source.module)
case ann: Annotation => CircuitTarget(c.main)
}
val moduleMap = c.modules.map({ m => (m.name, m) }).toMap
val iGraph = InstanceKeyGraph(c).graph
val moduleDeps =
iGraph.getEdgeMap.map({ case (k, v) => (k.module, (v.map { i => (i.name, i.module) }).toMap) }).toMap
val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse.map { moduleMap(_) }
val moduleGraphs = new mutable.HashMap[String, ConnMap]
val simplifiedModuleGraphs = new mutable.HashMap[String, AbstractConnMap]
topoSortedModules.foreach {
case em: ExtModule =>
val portSet = em.ports.map(p => LogicNode(p.name)).toSet
val extModuleDeps = new MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info]
portSet.foreach(extModuleDeps.addVertex(_))
extModulePaths.getOrElse(ModuleTarget(c.main, em.name), Nil).collect {
case a: ExtModulePathAnnotation =>
extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref))
}
moduleGraphs(em.name) = extModuleDeps
simplifiedModuleGraphs(em.name) = extModuleDeps.simplify(portSet)
case m: Module =>
val portSet = m.ports.map(p => LogicNode(p.name)).toSet
val internalDeps = new MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info]
portSet.foreach(internalDeps.addVertex(_))
m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps))
moduleGraphs(m.name) = internalDeps
simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify(portSet)
// Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs!
for (unitLoopNode <- internalDeps.getVertices.filter(v => internalDeps.getEdges(v).contains(v))) {
errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name)))
}
for (scc <- internalDeps.findSCCs.filter(_.length > 1)) {
val sccSubgraph = internalDeps.subgraph(scc.toSet)
val cycle = findCycleInSCC(sccSubgraph)
(cycle.zip(cycle.tail)).foreach({ case (a, b) => require(internalDeps.getEdges(a).contains(b)) })
// Reverse to make sure LHS comes after RHS, print repeated vertex at start for legibility
val intuitiveCycle = cycle.reverse
val repeatedInitial = prettyPrintAbsoluteRef(Seq(m.name), intuitiveCycle.head)
val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), intuitiveCycle)
errors.append(new CombLoopException(m.info, m.name, repeatedInitial +: expandedCycle))
}
case m => throwInternalError(s"Module ${m.name} has unrecognized type")
}
val mt = ModuleTarget(c.main, c.main)
val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect {
case (from, tos) if tos.nonEmpty =>
val sink = mt.ref(from.name)
val sources = tos.map(to => mt.ref(to.name))
CombinationalPath(sink, sources.toSeq)
}
(state.copy(annotations = state.annotations ++ annos), errors, simplifiedModuleGraphs, moduleGraphs)
}
/**
* Returns a Map from Module name to port connectivity
*/
def analyze(state: CircuitState): collection.Map[String, DiGraph[String]] = {
val (result, errors, connectivity, _) = run(state)
connectivity.map {
case (k, v) => (k, v.transformNodes(ln => ln.name))
}
}
/**
* Returns a Map from Module name to complete netlist connectivity
*/
def analyzeFull(state: CircuitState): collection.Map[String, DiGraph[LogicNode]] = {
run(state)._4
}
def execute(state: CircuitState): CircuitState = {
val dontRun = state.annotations.contains(DontCheckCombLoopsAnnotation)
if (dontRun) {
logger.warn("Skipping Combinational Loop Detection")
state
} else {
val (result, errors, connectivity, _) = run(state)
errors.trigger()
result
}
}
}
|