aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorJack Koenig2018-04-11 10:19:16 -0700
committerGitHub2018-04-11 10:19:16 -0700
commit27ee6fbbdf2b1854503ef51ffc0e2108a939d50c (patch)
treec0a4de1fbacec6ea1b5828f1b50181adc936f3df /src/main
parent0c93a121c109c3e167d80173dbfe8c2e355b30ef (diff)
Make DiGraph.linearize be iterative instead of recursive (#785)
Also make DiGraphTests more ScalaTest-y
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/graph/DiGraph.scala40
1 files changed, 24 insertions, 16 deletions
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala
index 135603ff..22f5fcc5 100644
--- a/src/main/scala/firrtl/graph/DiGraph.scala
+++ b/src/main/scala/firrtl/graph/DiGraph.scala
@@ -25,7 +25,7 @@ object DiGraph {
}
for ((k, v) <- edgeData) {
for (n <- v) {
- require(edgeDataCopy.contains(n))
+ require(edgeDataCopy.contains(n), s"Does not contain $n")
edgeDataCopy(k) += n
}
}
@@ -77,24 +77,32 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
val unmarked = new mutable.LinkedHashSet[T]
val tempMarked = new mutable.LinkedHashSet[T]
- def visit(n: T): Unit = {
- if (tempMarked.contains(n)) {
- throw new CyclicException(n)
- }
- if (unmarked.contains(n)) {
- tempMarked += n
- unmarked -= n
- for (m <- getEdges(n)) {
- visit(m)
- }
- tempMarked -= n
- order.append(n)
- }
- }
+ case class LinearizeFrame[T](v: T, expanded: Boolean)
+ val callStack = mutable.Stack[LinearizeFrame[T]]()
unmarked ++= getVertices
while (unmarked.nonEmpty) {
- visit(unmarked.head)
+ callStack.push(LinearizeFrame(unmarked.head, false))
+ while (callStack.nonEmpty) {
+ val LinearizeFrame(n, expanded) = callStack.pop()
+ if (!expanded) {
+ if (tempMarked.contains(n)) {
+ throw new CyclicException(n)
+ }
+ if (unmarked.contains(n)) {
+ tempMarked += n
+ unmarked -= n
+ callStack.push(LinearizeFrame(n, true))
+ // We want to visit the first edge first (so push it last)
+ for (m <- edges.getOrElse(n, Set.empty).toSeq.reverse) {
+ callStack.push(LinearizeFrame(m, false))
+ }
+ }
+ } else {
+ tempMarked -= n
+ order.append(n)
+ }
+ }
}
// visited nodes are in post-traversal order, so must be reversed