aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/graph/DiGraph.scala40
-rw-r--r--src/test/scala/firrtlTests/graph/DiGraphTests.scala58
2 files changed, 66 insertions, 32 deletions
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala
index 135603ff..22f5fcc5 100644
--- a/src/main/scala/firrtl/graph/DiGraph.scala
+++ b/src/main/scala/firrtl/graph/DiGraph.scala
@@ -25,7 +25,7 @@ object DiGraph {
}
for ((k, v) <- edgeData) {
for (n <- v) {
- require(edgeDataCopy.contains(n))
+ require(edgeDataCopy.contains(n), s"Does not contain $n")
edgeDataCopy(k) += n
}
}
@@ -77,24 +77,32 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
val unmarked = new mutable.LinkedHashSet[T]
val tempMarked = new mutable.LinkedHashSet[T]
- def visit(n: T): Unit = {
- if (tempMarked.contains(n)) {
- throw new CyclicException(n)
- }
- if (unmarked.contains(n)) {
- tempMarked += n
- unmarked -= n
- for (m <- getEdges(n)) {
- visit(m)
- }
- tempMarked -= n
- order.append(n)
- }
- }
+ case class LinearizeFrame[T](v: T, expanded: Boolean)
+ val callStack = mutable.Stack[LinearizeFrame[T]]()
unmarked ++= getVertices
while (unmarked.nonEmpty) {
- visit(unmarked.head)
+ callStack.push(LinearizeFrame(unmarked.head, false))
+ while (callStack.nonEmpty) {
+ val LinearizeFrame(n, expanded) = callStack.pop()
+ if (!expanded) {
+ if (tempMarked.contains(n)) {
+ throw new CyclicException(n)
+ }
+ if (unmarked.contains(n)) {
+ tempMarked += n
+ unmarked -= n
+ callStack.push(LinearizeFrame(n, true))
+ // We want to visit the first edge first (so push it last)
+ for (m <- edges.getOrElse(n, Set.empty).toSeq.reverse) {
+ callStack.push(LinearizeFrame(m, false))
+ }
+ }
+ } else {
+ tempMarked -= n
+ order.append(n)
+ }
+ }
}
// visited nodes are in post-traversal order, so must be reversed
diff --git a/src/test/scala/firrtlTests/graph/DiGraphTests.scala b/src/test/scala/firrtlTests/graph/DiGraphTests.scala
index a0f45c80..52ded253 100644
--- a/src/test/scala/firrtlTests/graph/DiGraphTests.scala
+++ b/src/test/scala/firrtlTests/graph/DiGraphTests.scala
@@ -40,31 +40,45 @@ class DiGraphTests extends FirrtlFlatSpec {
val degenerateGraph = DiGraph(Map("a" -> Set.empty[String]))
- acyclicGraph.findSCCs.filter(_.length > 1) shouldBe empty
-
- cyclicGraph.findSCCs.filter(_.length > 1) should not be empty
+ "A graph without cycles" should "have NOT SCCs" in {
+ acyclicGraph.findSCCs.filter(_.length > 1) shouldBe empty
+ }
- acyclicGraph.path("a","e") should not be empty
+ "A graph with cycles" should "have SCCs" in {
+ cyclicGraph.findSCCs.filter(_.length > 1) should not be empty
+ }
- an [PathNotFoundException] should be thrownBy acyclicGraph.path("e","a")
+ "Asking a DiGraph for a path that exists" should "work" in {
+ acyclicGraph.path("a","e") should not be empty
+ }
- acyclicGraph.linearize.head should equal ("a")
+ "Asking a DiGraph for a path from one node to another with no path" should "error" in {
+ an [PathNotFoundException] should be thrownBy acyclicGraph.path("e","a")
+ }
- a [CyclicException] should be thrownBy cyclicGraph.linearize
+ "The first element in a linearized graph with a single root node" should "be the root" in {
+ acyclicGraph.linearize.head should equal ("a")
+ }
- try {
- cyclicGraph.linearize
+ "A DiGraph with a cycle" should "error when linearized" in {
+ a [CyclicException] should be thrownBy cyclicGraph.linearize
}
- catch {
- case c: CyclicException =>
- c.getMessage.contains("found at a") should be (true)
- c.node.asInstanceOf[String] should be ("a")
- case _: Throwable =>
+
+ "CyclicExceptions" should "contain information about the cycle" in {
+ val c = the [CyclicException] thrownBy {
+ cyclicGraph.linearize
+ }
+ c.getMessage.contains("found at a") should be (true)
+ c.node.asInstanceOf[String] should be ("a")
}
- acyclicGraph.reverse.getEdgeMap should equal (reversedAcyclicGraph.getEdgeMap)
+ "Reversing a graph" should "reverse all of the edges" in {
+ acyclicGraph.reverse.getEdgeMap should equal (reversedAcyclicGraph.getEdgeMap)
+ }
- degenerateGraph.getEdgeMap should equal (degenerateGraph.reverse.getEdgeMap)
+ "Reversing a graph with no edges" should "equal the graph itself" in {
+ degenerateGraph.getEdgeMap should equal (degenerateGraph.reverse.getEdgeMap)
+ }
"transformNodes" should "combine vertices that collide, not drop them" in {
tupleGraph.transformNodes(_._1).getEdgeMap should contain ("a" -> Set("b", "c"))
@@ -84,4 +98,16 @@ class DiGraphTests extends FirrtlFlatSpec {
(first + second + second + second).getEdgeMap should equal (acyclicGraph.getEdgeMap)
}
+ "linearize" should "not cause a stack overflow on very large graphs" in {
+ // Graph of 0 -> 1, 1 -> 2, etc.
+ val N = 10000
+ val edges = (1 to N).zipWithIndex.map({ case (n, idx) => idx -> Set(n)}).toMap
+ val bigGraph = DiGraph(edges + (N -> Set.empty[Int]))
+ bigGraph.linearize should be (0 to N)
+ }
+
+ it should "work on multi-rooted graphs" in {
+ val graph = DiGraph(Map("a" -> Set[String](), "b" -> Set[String]()))
+ graph.linearize.toSet should be (graph.getVertices)
+ }
}