diff options
Diffstat (limited to 'src/main/scala/firrtl/graph/DiGraph.scala')
| -rw-r--r-- | src/main/scala/firrtl/graph/DiGraph.scala | 34 |
1 files changed, 25 insertions, 9 deletions
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala index 6dad56d7..450ec4ff 100644 --- a/src/main/scala/firrtl/graph/DiGraph.scala +++ b/src/main/scala/firrtl/graph/DiGraph.scala @@ -66,7 +66,6 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link /** Linearizes (topologically sorts) a DAG * - * @param root the start node * @throws CyclicException if the graph is cyclic * @return a Map[T,T] from each visited node to its predecessor in the * traversal @@ -75,8 +74,8 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link // permanently marked nodes are implicitly held in order val order = new mutable.ArrayBuffer[T] // invariant: no intersection between unmarked and tempMarked - val unmarked = new LinkedHashSet[T] - val tempMarked = new LinkedHashSet[T] + val unmarked = new mutable.LinkedHashSet[T] + val tempMarked = new mutable.LinkedHashSet[T] def visit(n: T): Unit = { if (tempMarked.contains(n)) { @@ -94,7 +93,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link } unmarked ++= getVertices - while (!unmarked.isEmpty) { + while (unmarked.nonEmpty) { visit(unmarked.head) } @@ -108,14 +107,23 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link * @return a Map[T,T] from each visited node to its predecessor in the * traversal */ - def BFS(root: T): Map[T,T] = { - val prev = new LinkedHashMap[T,T] + def BFS(root: T): Map[T,T] = BFS(root, Set.empty[T]) + + /** Performs breadth-first search on the directed graph, with a blacklist of nodes + * + * @param root the start node + * @param blacklist list of nodes to stop searching, if encountered + * @return a Map[T,T] from each visited node to its predecessor in the + * traversal + */ + def BFS(root: T, blacklist: Set[T]): Map[T,T] = { + val prev = new mutable.LinkedHashMap[T,T] val queue = new mutable.Queue[T] queue.enqueue(root) - while (!queue.isEmpty) { + while (queue.nonEmpty) { val u = queue.dequeue for (v <- getEdges(u)) { - if (!prev.contains(v)) { + if (!prev.contains(v) && !blacklist.contains(v)) { prev(v) = u queue.enqueue(v) } @@ -129,7 +137,15 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link * @param root the start node * @return a Set[T] of nodes reachable from the root */ - def reachableFrom(root: T): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root).map({ case (k, v) => k }) + def reachableFrom(root: T): LinkedHashSet[T] = reachableFrom(root, Set.empty[T]) + + /** Finds the set of nodes reachable from a particular node, with a blacklist + * + * @param root the start node + * @param blacklist list of nodes to stop searching, if encountered + * @return a Set[T] of nodes reachable from the root + */ + def reachableFrom(root: T, blacklist: Set[T]): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root, blacklist).map({ case (k, v) => k }) /** Finds a path (if one exists) from one node to another * |
