diff options
| author | Jack Koenig | 2018-04-11 10:19:16 -0700 |
|---|---|---|
| committer | GitHub | 2018-04-11 10:19:16 -0700 |
| commit | 27ee6fbbdf2b1854503ef51ffc0e2108a939d50c (patch) | |
| tree | c0a4de1fbacec6ea1b5828f1b50181adc936f3df | |
| parent | 0c93a121c109c3e167d80173dbfe8c2e355b30ef (diff) | |
Make DiGraph.linearize be iterative instead of recursive (#785)
Also make DiGraphTests more ScalaTest-y
| -rw-r--r-- | src/main/scala/firrtl/graph/DiGraph.scala | 40 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/graph/DiGraphTests.scala | 58 |
2 files changed, 66 insertions, 32 deletions
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala index 135603ff..22f5fcc5 100644 --- a/src/main/scala/firrtl/graph/DiGraph.scala +++ b/src/main/scala/firrtl/graph/DiGraph.scala @@ -25,7 +25,7 @@ object DiGraph { } for ((k, v) <- edgeData) { for (n <- v) { - require(edgeDataCopy.contains(n)) + require(edgeDataCopy.contains(n), s"Does not contain $n") edgeDataCopy(k) += n } } @@ -77,24 +77,32 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link val unmarked = new mutable.LinkedHashSet[T] val tempMarked = new mutable.LinkedHashSet[T] - def visit(n: T): Unit = { - if (tempMarked.contains(n)) { - throw new CyclicException(n) - } - if (unmarked.contains(n)) { - tempMarked += n - unmarked -= n - for (m <- getEdges(n)) { - visit(m) - } - tempMarked -= n - order.append(n) - } - } + case class LinearizeFrame[T](v: T, expanded: Boolean) + val callStack = mutable.Stack[LinearizeFrame[T]]() unmarked ++= getVertices while (unmarked.nonEmpty) { - visit(unmarked.head) + callStack.push(LinearizeFrame(unmarked.head, false)) + while (callStack.nonEmpty) { + val LinearizeFrame(n, expanded) = callStack.pop() + if (!expanded) { + if (tempMarked.contains(n)) { + throw new CyclicException(n) + } + if (unmarked.contains(n)) { + tempMarked += n + unmarked -= n + callStack.push(LinearizeFrame(n, true)) + // We want to visit the first edge first (so push it last) + for (m <- edges.getOrElse(n, Set.empty).toSeq.reverse) { + callStack.push(LinearizeFrame(m, false)) + } + } + } else { + tempMarked -= n + order.append(n) + } + } } // visited nodes are in post-traversal order, so must be reversed diff --git a/src/test/scala/firrtlTests/graph/DiGraphTests.scala b/src/test/scala/firrtlTests/graph/DiGraphTests.scala index a0f45c80..52ded253 100644 --- a/src/test/scala/firrtlTests/graph/DiGraphTests.scala +++ b/src/test/scala/firrtlTests/graph/DiGraphTests.scala @@ -40,31 +40,45 @@ class DiGraphTests extends FirrtlFlatSpec { val degenerateGraph = DiGraph(Map("a" -> Set.empty[String])) - acyclicGraph.findSCCs.filter(_.length > 1) shouldBe empty - - cyclicGraph.findSCCs.filter(_.length > 1) should not be empty + "A graph without cycles" should "have NOT SCCs" in { + acyclicGraph.findSCCs.filter(_.length > 1) shouldBe empty + } - acyclicGraph.path("a","e") should not be empty + "A graph with cycles" should "have SCCs" in { + cyclicGraph.findSCCs.filter(_.length > 1) should not be empty + } - an [PathNotFoundException] should be thrownBy acyclicGraph.path("e","a") + "Asking a DiGraph for a path that exists" should "work" in { + acyclicGraph.path("a","e") should not be empty + } - acyclicGraph.linearize.head should equal ("a") + "Asking a DiGraph for a path from one node to another with no path" should "error" in { + an [PathNotFoundException] should be thrownBy acyclicGraph.path("e","a") + } - a [CyclicException] should be thrownBy cyclicGraph.linearize + "The first element in a linearized graph with a single root node" should "be the root" in { + acyclicGraph.linearize.head should equal ("a") + } - try { - cyclicGraph.linearize + "A DiGraph with a cycle" should "error when linearized" in { + a [CyclicException] should be thrownBy cyclicGraph.linearize } - catch { - case c: CyclicException => - c.getMessage.contains("found at a") should be (true) - c.node.asInstanceOf[String] should be ("a") - case _: Throwable => + + "CyclicExceptions" should "contain information about the cycle" in { + val c = the [CyclicException] thrownBy { + cyclicGraph.linearize + } + c.getMessage.contains("found at a") should be (true) + c.node.asInstanceOf[String] should be ("a") } - acyclicGraph.reverse.getEdgeMap should equal (reversedAcyclicGraph.getEdgeMap) + "Reversing a graph" should "reverse all of the edges" in { + acyclicGraph.reverse.getEdgeMap should equal (reversedAcyclicGraph.getEdgeMap) + } - degenerateGraph.getEdgeMap should equal (degenerateGraph.reverse.getEdgeMap) + "Reversing a graph with no edges" should "equal the graph itself" in { + degenerateGraph.getEdgeMap should equal (degenerateGraph.reverse.getEdgeMap) + } "transformNodes" should "combine vertices that collide, not drop them" in { tupleGraph.transformNodes(_._1).getEdgeMap should contain ("a" -> Set("b", "c")) @@ -84,4 +98,16 @@ class DiGraphTests extends FirrtlFlatSpec { (first + second + second + second).getEdgeMap should equal (acyclicGraph.getEdgeMap) } + "linearize" should "not cause a stack overflow on very large graphs" in { + // Graph of 0 -> 1, 1 -> 2, etc. + val N = 10000 + val edges = (1 to N).zipWithIndex.map({ case (n, idx) => idx -> Set(n)}).toMap + val bigGraph = DiGraph(edges + (N -> Set.empty[Int])) + bigGraph.linearize should be (0 to N) + } + + it should "work on multi-rooted graphs" in { + val graph = DiGraph(Map("a" -> Set[String](), "b" -> Set[String]())) + graph.linearize.toSet should be (graph.getVertices) + } } |
