aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/graph/DiGraph.scala
blob: 99bf84038d4c7d6abd07df668789719beb3ed097 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
// SPDX-License-Identifier: Apache-2.0

package firrtl.graph

import scala.collection.{mutable, Map, Set}
import scala.collection.mutable.{LinkedHashMap, LinkedHashSet}
import firrtl.options.DependencyManagerUtils.{CharSet, PrettyCharSet}

/** Thrown when a graph that is assumed to be acyclic (a DAG) turns out to contain a cycle
  *
  * @param node the vertex at which the cycle was detected
  */
class CyclicException(val node: Any)
    extends Exception(s"No valid linearization for cyclic graph, found at $node")

/** Thrown when a path is requested between two vertices but the destination cannot be reached
  * from the start vertex
  */
class PathNotFoundException
    extends Exception(
      "Unreachable node"
    )

/** A companion to create DiGraphs from mutable data */
object DiGraph {

  /** Create a DiGraph from a MutableDigraph, representing the same graph
    *
    * No copy is made: the result is `mdg` itself viewed as a DiGraph, so later mutations of `mdg`
    * are visible through the returned graph.
    */
  def apply[T](mdg: MutableDiGraph[T]): DiGraph[T] = mdg

  /** Create a DiGraph from a Map[T,Set[T]] of edge data
    *
    * The edge data is copied, so the result does not share storage with `edgeData`. Vertices keep the
    * iteration order of `edgeData`'s keys.
    *
    * @param edgeData map from each vertex to the set of vertices it has edges to; every edge target
    *                 must itself be a key
    * @throws java.lang.IllegalArgumentException if an edge targets a vertex that is not a key of `edgeData`
    */
  def apply[T](edgeData: Map[T, Set[T]]): DiGraph[T] = {
    val edgeDataCopy = new LinkedHashMap[T, LinkedHashSet[T]]
    // Register every vertex first so the membership check below sees all keys
    for (k <- edgeData.keys) {
      edgeDataCopy(k) = new LinkedHashSet[T]
    }
    for ((k, v) <- edgeData; n <- v) {
      require(edgeDataCopy.contains(n), s"Does not contain $n")
      edgeDataCopy(k) += n
    }
    new DiGraph(edgeDataCopy)
  }

  /** Create a DiGraph from edges
    *
    * Both endpoints of every edge become vertices. Vertices appear in order of their first occurrence
    * as an edge source. Vertices that only ever occur as edge targets (pure sinks) are appended after
    * them, in order of first occurrence.
    *
    * @param edges (from, to) pairs, each describing a directed edge from `from` to `to`
    */
  def apply[T](edges: (T, T)*): DiGraph[T] = {
    val edgeMap = new LinkedHashMap[T, LinkedHashSet[T]]
    for ((from, to) <- edges) {
      edgeMap.getOrElseUpdate(from, new LinkedHashSet[T]) += to
    }
    // Targets that never appear as a source must still be vertices. Otherwise `contains`,
    // `getVertices` and `linearize` would not see them, and `reverse` (and thus `findSinks`) would
    // fail because adding the reversed edge requires both endpoints to be present.
    for ((_, to) <- edges) {
      edgeMap.getOrElseUpdate(to, new LinkedHashSet[T])
    }
    new DiGraph(edgeMap)
  }
}

/** Represents common behavior of all directed graphs
  *
  * The graph is stored as an adjacency map from each vertex to the set of vertices it has edges to.
  * Both the map and the sets are insertion-ordered (`LinkedHashMap`/`LinkedHashSet`), so iteration
  * order, and therefore the order of traversal results below, follows insertion order.
  *
  * @param edges adjacency data; every vertex (including those with no out-edges) is expected to be a
  *              key of this map
  */
class DiGraph[T](private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) {

  /** Check whether the graph contains vertex v */
  def contains(v: T): Boolean = edges.contains(v)

  /** Get all vertices in the graph
    * @return a freshly built, insertion-ordered Set[T] of all vertices in the graph
    */
  // The pattern of mapping map pairs to keys maintains LinkedHashMap ordering
  def getVertices: Set[T] = new LinkedHashSet ++ edges.map({ case (k, _) => k })

  /** Get all edges of a node
    * @param v the specified node
    * @return a Set[T] of all vertices that v has edges to; empty if v is not in the graph
    */
  def getEdges(v: T): Set[T] = edges.getOrElse(v, Set.empty)

  /** Get the adjacency map of the graph
    *
    * This is the underlying map itself (not a copy), exposed through a read-only interface.
    *
    * @return a Map[T, Set[T]] from each vertex to the vertices it has edges to
    */
  def getEdgeMap: Map[T, Set[T]] = edges

  /** Find all sources in the graph
    *
    * @return a Set[T] of source nodes (vertices with no incoming edges)
    */
  def findSources: Set[T] = getVertices -- edges.values.flatten.toSet

  /** Find all sinks in the graph
    *
    * @return a Set[T] of sink nodes (vertices with no outgoing edges)
    */
  def findSinks: Set[T] = reverse.findSources

  /**
    * Finds a Seq of Nodes that form a loop
    *
    * Successors of `node` are tried in edge order. The first one from which `node` is reachable
    * yields the loop, which starts at that successor and ends at `node`.
    *
    * @param node Node to start loop path search from.
    * @return     The found Seq, the Seq is empty if there is no loop
    */
  def findLoopAtNode(node: T): Seq[T] =
    // The iterator is lazy, so the search stops at the first successor that leads back to `node`.
    // Only PathNotFoundException means "no loop through this successor"; any other exception propagates.
    getEdges(node).iterator
      .map { vertex =>
        try Some(path(vertex, node, blacklist = Set.empty[T]))
        catch { case _: PathNotFoundException => None }
      }
      .collectFirst { case Some(loop) => loop }
      .getOrElse(Seq.empty[T])

  /** Linearizes (topologically sorts) a DAG
    *
    * Uses an iterative depth-first search. Each search starts from the first not-yet-visited vertex in
    * insertion order, and out-edges are followed in insertion order.
    *
    * @throws CyclicException if the graph is cyclic
    * @return a Seq[T] describing the topological order of the DAG
    * traversal
    */
  def linearize: Seq[T] = {
    // permanently marked nodes are implicitly held in order
    val order = new mutable.ArrayBuffer[T]
    // invariant: no intersection between unmarked and tempMarked
    val unmarked = new mutable.LinkedHashSet[T]
    // nodes whose DFS is still in progress, i.e. the current DFS path
    val tempMarked = new mutable.LinkedHashSet[T]

    // `expanded` = false: first visit of `v`; true: all of `v`'s successors have been finished
    case class LinearizeFrame[A](v: A, expanded: Boolean)
    val callStack = mutable.Stack[LinearizeFrame[T]]()

    unmarked ++= getVertices
    while (unmarked.nonEmpty) {
      callStack.push(LinearizeFrame(unmarked.head, false))
      while (callStack.nonEmpty) {
        val LinearizeFrame(n, expanded) = callStack.pop()
        if (!expanded) {
          // Reaching a node that is still on the current DFS path means a back edge, i.e. a cycle
          if (tempMarked.contains(n)) {
            throw new CyclicException(n)
          }
          if (unmarked.contains(n)) {
            tempMarked += n
            unmarked -= n
            callStack.push(LinearizeFrame(n, true))
            // We want to visit the first edge first (so push it last)
            for (m <- edges.getOrElse(n, Set.empty).toSeq.reverse) {
              callStack.push(LinearizeFrame(m, false))
            }
          }
        } else {
          tempMarked -= n
          order.append(n)
        }
      }
    }

    // visited nodes are in post-traversal order, so must be reversed
    order.reverse.toSeq
  }

  /** Performs breadth-first search on the directed graph
    *
    * @param root the start node
    * @return a Map[T,T] from each visited node to its predecessor in the
    * traversal; `root` itself is only a key if it lies on a cycle
    */
  def BFS(root: T): Map[T, T] = BFS(root, Set.empty[T])

  /** Performs breadth-first search on the directed graph, with a blacklist of nodes
    *
    * @param root the start node
    * @param blacklist list of nodes to avoid visiting, if encountered
    * @return a Map[T,T] from each visited node to its predecessor in the
    * traversal; `root` itself is only a key if it lies on a cycle
    */
  def BFS(root: T, blacklist: Set[T]): Map[T, T] = {
    val prev = new mutable.LinkedHashMap[T, T]
    val queue = new mutable.Queue[T]
    queue.enqueue(root)
    while (queue.nonEmpty) {
      val u = queue.dequeue()
      for (v <- getEdges(u)) {
        // `prev` doubles as the visited set; root is not pre-marked, so it is recorded if re-entered
        if (!prev.contains(v) && !blacklist.contains(v)) {
          prev(v) = u
          queue.enqueue(v)
        }
      }
    }
    prev
  }

  /** Finds the set of nodes reachable from a particular node. The `root` node is *not* included in the
    * returned set unless it is possible to reach `root` along a non-trivial path beginning at
    * `root`; i.e., if the graph has a cycle that contains `root`.
    *
    * @param root the start node
    * @return a Set[T] of nodes reachable from `root`
    */
  def reachableFrom(root: T): LinkedHashSet[T] = reachableFrom(root, Set.empty[T])

  /** Finds the set of nodes reachable from a particular node, with a blacklist. The semantics of
    * adding a node to the blacklist is that any of its inedges will be ignored in the traversal.
    * The `root` node is *not* included in the returned set unless it is possible to reach `root` along
    * a non-trivial path beginning at `root`; i.e., if the graph has a cycle that contains `root`.
    *
    * @param root the start node
    * @param blacklist list of nodes to stop searching, if encountered
    * @return a Set[T] of nodes reachable from `root`, in BFS discovery order
    */
  def reachableFrom(root: T, blacklist: Set[T]): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root, blacklist).map({
    case (k, _) => k
  })

  /** Finds a path (if one exists) from one node to another
    *
    * @param start the start node
    * @param end the destination node
    * @throws firrtl.graph.PathNotFoundException
    * @return a Seq[T] of nodes defining a shortest (BFS) path, beginning with `start` and ending with
    * `end`; `Seq(start)` when `start == end`
    */
  def path(start: T, end: T): Seq[T] = path(start, end, Set.empty[T])

  /** Finds a path (if one exists) from one node to another, with a blacklist
    *
    * @param start the start node
    * @param end the destination node
    * @param blacklist list of nodes which break path, if encountered
    * @throws firrtl.graph.PathNotFoundException
    * @return a Seq[T] of nodes defining a shortest (BFS) path, beginning with `start` and ending with
    * `end`; `Seq(start)` when `start == end`
    */
  def path(start: T, end: T, blacklist: Set[T]): Seq[T] = {
    val nodePath = new mutable.ArrayBuffer[T]
    val prev = BFS(start, blacklist)
    // Walk predecessor links backwards from `end` until `start` is reached or the chain ends
    nodePath += end
    while (nodePath.last != start && prev.contains(nodePath.last)) {
      nodePath += prev(nodePath.last)
    }
    if (nodePath.last != start) {
      throw new PathNotFoundException
    }
    nodePath.toSeq.reverse
  }

  /** Finds the strongly connected components in the graph
    *
    * Implements Tarjan's algorithm iteratively, so deep graphs do not overflow the JVM call stack.
    * Every vertex appears in exactly one SCC. SCCs are emitted in the order Tarjan's algorithm completes
    * them.
    *
    * @return a Seq of Seq[T], each containing nodes of an SCC in traversable order
    */
  def findSCCs: Seq[Seq[T]] = {
    var counter: BigInt = 0
    val stack = new mutable.Stack[T]
    val onstack = new LinkedHashSet[T]
    val indices = new LinkedHashMap[T, BigInt]
    val lowlinks = new LinkedHashMap[T, BigInt]
    val sccs = new mutable.ArrayBuffer[Seq[T]]

    /*
     * Recursive code is transformed to iterative code by representing
     * call stack info in an explicit structure. Here, the stack data
     * consists of the current vertex, its currently active edge, and
     * the position in the function. Because there is only one
     * recursive call site, remembering whether a child call was
     * created on the last iteration where the current frame was
     * active is sufficient to track the position.
     */
    class StrongConnectFrame[A](val v: A, val edgeIter: Iterator[A], var childCall: Option[A] = None)
    val callStack = new mutable.Stack[StrongConnectFrame[T]]

    for (node <- getVertices) {
      // A vertex already indexed was reached from an earlier root and is already in an SCC.
      // Restarting the search from it would re-index it and emit a spurious duplicate SCC.
      if (!indices.contains(node)) {
        callStack.push(new StrongConnectFrame(node, getEdges(node).iterator))
        while (callStack.nonEmpty) {
          val frame = callStack.top
          val v = frame.v
          frame.childCall match {
            case None =>
              // First activation of this frame: assign v its DFS index
              indices(v) = counter
              lowlinks(v) = counter
              counter = counter + 1
              stack.push(v)
              onstack += v
            case Some(w) =>
              // Resuming after the recursive visit of w has returned
              lowlinks(v) = lowlinks(v).min(lowlinks(w))
          }
          frame.childCall = None
          while (frame.edgeIter.hasNext && frame.childCall.isEmpty) {
            val w = frame.edgeIter.next()
            if (!indices.contains(w)) {
              // "Recursive call": suspend this frame and visit w first
              frame.childCall = Some(w)
              callStack.push(new StrongConnectFrame(w, getEdges(w).iterator))
            } else if (onstack.contains(w)) {
              lowlinks(v) = lowlinks(v).min(indices(w))
            }
          }
          if (frame.childCall.isEmpty) {
            // All edges of v processed; v is the root of an SCC iff its lowlink equals its index
            if (lowlinks(v) == indices(v)) {
              val scc = new mutable.ArrayBuffer[T]
              do {
                val w = stack.pop()
                onstack -= w
                scc += w
              } while (scc.last != v);
              sccs.append(scc.toSeq)
            }
            callStack.pop()
          }
        }
      }
    }

    sccs.toSeq
  }

  /** Finds all paths starting at a particular node in a DAG
    *
    * WARNING: This is an exponential time algorithm (as any algorithm
    * must be for this problem), but is useful for flattening circuit
    * graph hierarchies. Each path is represented by a Seq[T] of nodes
    * in a traversable order.
    *
    * @param start the node to start at
    * @throws CyclicException if the graph contains a cycle anywhere, even one not reachable from
    * `start`, because the whole graph is linearized
    * @return a Map[T,Seq[Seq[T]]] where the value associated with v is the Seq of all paths from start to v
    */
  def pathsInDAG(start: T): LinkedHashMap[T, Seq[Seq[T]]] = {
    // paths(v) holds the set of paths from start to v
    val paths = new LinkedHashMap[T, mutable.Set[Seq[T]]]
    val queue = new mutable.Queue[T]
    val reachable = reachableFrom(start)
    def addBinding(n: T, p: Seq[T]): Unit = {
      paths.getOrElseUpdate(n, new LinkedHashSet[Seq[T]]) += p
    }
    addBinding(start, Seq(start))
    // Visiting reachable nodes in topological order guarantees every predecessor of a node has
    // contributed all of its paths before that node's paths are extended
    queue += start
    queue ++= linearize.filter(reachable.contains(_))
    while (!queue.isEmpty) {
      val current = queue.dequeue()
      for (v <- getEdges(current)) {
        for (p <- paths(current)) {
          addBinding(v, p :+ v)
        }
      }
    }
    paths.map({ case (k, v) => (k, v.toSeq) })
  }

  /** Returns a graph with all edges reversed
    *
    * Vertices keep their original order.
    */
  def reverse: DiGraph[T] = {
    val mdg = new MutableDiGraph[T]
    // Add all vertices first: addEdge requires both endpoints to already exist
    edges.foreach({ case (u, _) => mdg.addVertex(u) })
    edges.foreach({
      case (u, successors) =>
        successors.foreach(v => mdg.addEdge(v, u))
    })
    DiGraph(mdg)
  }

  /** Restricts the adjacency data to the vertices in `vprime`, dropping edges to other vertices */
  private def filterEdges(vprime: Set[T]): LinkedHashMap[T, LinkedHashSet[T]] = {
    def filterNodeSet(s:        LinkedHashSet[T]): LinkedHashSet[T] = s.filter({ case (k) => vprime.contains(k) })
    def filterAdjacencyLists(m: LinkedHashMap[T, LinkedHashSet[T]]): LinkedHashMap[T, LinkedHashSet[T]] = m.map({
      case (k, v) => (k, filterNodeSet(v))
    })
    val eprime: LinkedHashMap[T, LinkedHashSet[T]] = edges.filter({ case (k, v) => vprime.contains(k) })
    filterAdjacencyLists(eprime)
  }

  /** Return a graph with only a subset of the nodes
    *
    * Any edge including a deleted node will be deleted
    *
    * @param vprime the Set[T] of desired vertices
    * @throws java.lang.IllegalArgumentException if vprime is not a subset of V
    * @return the subgraph
    */
  def subgraph(vprime: Set[T]): DiGraph[T] = {
    require(vprime.subsetOf(edges.keySet))
    new DiGraph(filterEdges(vprime))
  }

  /** Return a simplified connectivity graph with only a subset of the nodes
    *
    * Any path between two non-deleted nodes (u,v) in the original graph will be
    * transformed into an edge (u,v). Self-loops are never produced.
    *
    * @param vprime the Set[T] of desired vertices
    * @throws java.lang.IllegalArgumentException if vprime is not a subset of V
    * @return the simplified graph
    */
  def simplify(vprime: Set[T]): DiGraph[T] = {
    require(vprime.subsetOf(edges.keySet))
    val pathEdges = vprime.map(v => (v, reachableFrom(v) & (vprime - v)))
    new DiGraph(new LinkedHashMap[T, LinkedHashSet[T]] ++ pathEdges)
  }

  /** Return a graph with all the nodes of the current graph transformed
    * by a function. Edge connectivity will be the same as the current
    * graph.
    *
    * If `f` maps several nodes to the same value, those nodes are merged and their edges are combined.
    *
    * @param f A function {(T) => Q} that transforms each node
    * @return a transformed DiGraph[Q]
    */
  def transformNodes[Q](f: (T) => Q): DiGraph[Q] = {
    val eprime = edges.map({ case (k, _) => (f(k), new LinkedHashSet[Q]) })
    edges.foreach({ case (k, v) => eprime(f(k)) ++= v.map(f(_)) })
    new DiGraph(eprime)
  }

  /** Graph sum of `this` and `that`
    *
    * Neither operand is modified; the edge sets of `this` are cloned before merging.
    *
    * @param that a second DiGraph[T]
    * @return a DiGraph[T] containing all vertices and edges of each graph
    */
  def +(that: DiGraph[T]): DiGraph[T] = {
    val eprime = edges.map({ case (k, v) => (k, v.clone) })
    that.edges.foreach({ case (k, v) => eprime.getOrElseUpdate(k, new LinkedHashSet[T]) ++= v })
    new DiGraph(eprime)
  }

  /** Serializes a `DiGraph[String]` as a pretty tree
    *
    * Multiple roots are supported, but cycles are not.
    */
  def prettyTree(charSet: CharSet = PrettyCharSet)(implicit ev: T =:= String): String = {
    // Set up characters for building the tree
    val (l, n, c) = (charSet.lastNode, charSet.notLastNode, charSet.continuation)
    val ctab = " " * c.size + " "

    // Recursively adds each node of the DiGraph to accumulating List[String]
    // Uses List because prepend is cheap and this prevents quadratic behavior of String
    //   concatenations or even flatMapping on Seqs
    def rec(tab: String, node: T, mark: String, prev: List[String]): List[String] = {
      val here = s"$mark$node"
      val children = this.getEdges(node)
      val last = children.size - 1
      children.toList // Convert LinkedHashSet to List to avoid determinism issues
        .zipWithIndex // Find last
        .foldLeft(here :: prev) {
          case (acc, (nodex, idx)) =>
            val nextTab = if (idx == last) tab + ctab else tab + c + " "
            val nextMark = if (idx == last) tab + l else tab + n
            rec(nextTab, nodex, nextMark + " ", acc)
        }
    }
    this.findSources.toList // Convert LinkedHashSet to List to avoid determinism issues
      .sortBy(_.toString) // Make order deterministic
      .foldLeft(Nil: List[String]) {
        case (acc, root) => rec("", root, "", acc)
      }
      .reverse
      .mkString("\n")
  }

}

/** A directed graph that can be grown in place by adding vertices and edges
  *
  * It starts empty. Vertices keep the order in which they were first added.
  */
class MutableDiGraph[T] extends DiGraph[T](new LinkedHashMap[T, LinkedHashSet[T]]) {

  /** Add vertex v to the graph
    *
    * Adding a vertex that is already present is a no-op and keeps its existing edges.
    *
    * @return v, the added vertex
    */
  def addVertex(v: T): T = {
    edges.getOrElseUpdate(v, new LinkedHashSet[T])
    v
  }

  /** Add edge (u,v) to the graph.
    * @throws java.lang.IllegalArgumentException if u and/or v is not in the graph
    */
  def addEdge(u: T, v: T): Unit = {
    require(contains(u), s"Source vertex $u is not in the graph")
    require(contains(v), s"Target vertex $v is not in the graph")
    edges(u) += v
  }

  /** Add edge (u,v) to the graph, adding u and/or v if they are not
    * already in the graph.
    */
  def addPairWithEdge(u: T, v: T): Unit = {
    // v is registered before u, so when both are new, v precedes u in vertex order
    edges.getOrElseUpdate(v, new LinkedHashSet[T])
    edges.getOrElseUpdate(u, new LinkedHashSet[T]) += v
  }

  /** Add edge (u,v) to the graph if and only if both u and v are in
    * the graph prior to calling addEdgeIfValid.
    *
    * @return true if both endpoints were present (so the edge is now in the graph), false otherwise
    */
  def addEdgeIfValid(u: T, v: T): Boolean = {
    val valid = contains(u) && contains(v)
    if (valid) {
      edges(u) += v
    }
    valid
  }
}