aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/graph/DiGraph.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/graph/DiGraph.scala')
-rw-r--r--src/main/scala/firrtl/graph/DiGraph.scala34
1 files changed, 25 insertions, 9 deletions
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala
index 6dad56d7..450ec4ff 100644
--- a/src/main/scala/firrtl/graph/DiGraph.scala
+++ b/src/main/scala/firrtl/graph/DiGraph.scala
@@ -66,7 +66,6 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
/** Linearizes (topologically sorts) a DAG
*
- * @param root the start node
* @throws CyclicException if the graph is cyclic
* @return a Map[T,T] from each visited node to its predecessor in the
* traversal
@@ -75,8 +74,8 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
// permanently marked nodes are implicitly held in order
val order = new mutable.ArrayBuffer[T]
// invariant: no intersection between unmarked and tempMarked
- val unmarked = new LinkedHashSet[T]
- val tempMarked = new LinkedHashSet[T]
+ val unmarked = new mutable.LinkedHashSet[T]
+ val tempMarked = new mutable.LinkedHashSet[T]
def visit(n: T): Unit = {
if (tempMarked.contains(n)) {
@@ -94,7 +93,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
}
unmarked ++= getVertices
- while (!unmarked.isEmpty) {
+ while (unmarked.nonEmpty) {
visit(unmarked.head)
}
@@ -108,14 +107,23 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
* @return a Map[T,T] from each visited node to its predecessor in the
* traversal
*/
- def BFS(root: T): Map[T,T] = {
- val prev = new LinkedHashMap[T,T]
+ def BFS(root: T): Map[T,T] = BFS(root, Set.empty[T])
+
+ /** Performs breadth-first search on the directed graph, with a blacklist of nodes
+ *
+ * @param root the start node
+ * @param blacklist list of nodes to stop searching, if encountered
+ * @return a Map[T,T] from each visited node to its predecessor in the
+ * traversal
+ */
+ def BFS(root: T, blacklist: Set[T]): Map[T,T] = {
+ val prev = new mutable.LinkedHashMap[T,T]
val queue = new mutable.Queue[T]
queue.enqueue(root)
- while (!queue.isEmpty) {
+ while (queue.nonEmpty) {
val u = queue.dequeue
for (v <- getEdges(u)) {
- if (!prev.contains(v)) {
+ if (!prev.contains(v) && !blacklist.contains(v)) {
prev(v) = u
queue.enqueue(v)
}
@@ -129,7 +137,15 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
* @param root the start node
* @return a Set[T] of nodes reachable from the root
*/
- def reachableFrom(root: T): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root).map({ case (k, v) => k })
+ def reachableFrom(root: T): LinkedHashSet[T] = reachableFrom(root, Set.empty[T])
+
+ /** Finds the set of nodes reachable from a particular node, with a blacklist
+ *
+ * @param root the start node
+ * @param blacklist list of nodes to stop searching, if encountered
+ * @return a Set[T] of nodes reachable from the root
+ */
+ def reachableFrom(root: T, blacklist: Set[T]): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root, blacklist).map({ case (k, v) => k })
/** Finds a path (if one exists) from one node to another
*