diff options
| author | Albert Magyar | 2017-11-10 14:02:39 -0800 |
|---|---|---|
| committer | Jack Koenig | 2017-11-10 14:02:39 -0800 |
| commit | 7d86a35e19519d92dc436c07359d5120b44b5a85 (patch) | |
| tree | 230cf0220d373f0840e337586c983920bf3e71b6 | |
| parent | 8ed378dfc9be7e5ebaff1e6b7393b5b991ea691d (diff) | |
Make digraph methods deterministic (#653)
8 files changed, 183 insertions, 152 deletions
diff --git a/src/main/scala/firrtl/analyses/Netlist.scala b/src/main/scala/firrtl/analyses/Netlist.scala index f00e96bb..4e211d8b 100644 --- a/src/main/scala/firrtl/analyses/Netlist.scala +++ b/src/main/scala/firrtl/analyses/Netlist.scala @@ -47,6 +47,7 @@ class InstanceGraph(c: Circuit) { for (child <- childInstances(current.module)) { if (!instanceGraph.contains(child)) { instanceQueue.enqueue(child) + instanceGraph.addVertex(child) } instanceGraph.addEdge(current,child) } diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala index 39016732..3a657cc0 100644 --- a/src/main/scala/firrtl/graph/DiGraph.scala +++ b/src/main/scala/firrtl/graph/DiGraph.scala @@ -1,111 +1,70 @@ package firrtl.graph -import scala.collection.immutable.{Set, Map, HashSet, HashMap} +import scala.collection.{Set, Map} import scala.collection.mutable -import scala.collection.mutable.MultiMap - -/** Represents common behavior of all directed graphs */ -trait DiGraphLike[T] { - /** Check whether the graph contains vertex v */ - def contains(v: T): Boolean - - /** Get all vertices in the graph - * @return a Set[T] of all vertices in the graph - */ - def getVertices: collection.Set[T] - - /** Get all edges of a node - * @param v the specified node - * @return a Set[T] of all vertices that v has edges to - */ - def getEdges(v: T): collection.Set[T] -} - -/** A class to represent a mutable directed graph with nodes of type T - * - * @constructor Create a new graph with the provided edge data - * @param edges a mutable.MultiMap[T,T] of edge data - * - * For the edge data MultiMap, the values associated with each vertex - * u in the graph are the vertices with inedges from u - */ -class MutableDiGraph[T]( - private[graph] val edgeData: MultiMap[T,T] = - new mutable.HashMap[T, mutable.Set[T]] with MultiMap[T, T]) extends DiGraphLike[T] { - - // Inherited methods from DiGraphLike - def contains(v: T) = edgeData.contains(v) - def getVertices = edgeData.keySet - def getEdges(v: T) = edgeData(v) - - /** Add vertex v to the graph - * @return v, the added vertex - */ - def addVertex(v: T): T = { - edgeData.getOrElseUpdate(v,new mutable.HashSet[T]) - v - } - - /** Add edge (u,v) to the graph */ - def addEdge(u: T, v: T) = { - // Add v to keys to maintain invariant that all vertices are keys - // of edge data - edgeData.getOrElseUpdate(v, new mutable.HashSet[T]) - edgeData.addBinding(u,v) - } -} +import scala.collection.mutable.{LinkedHashSet, LinkedHashMap} /** A companion to create immutable DiGraphs from mutable data */ object DiGraph { /** Create a DiGraph from a MutableDigraph, representing the same graph */ - def apply[T](mdg: MutableDiGraph[T]): DiGraph[T] = - new DiGraph((mdg.edgeData mapValues { _.toSet }).toMap[T, Set[T]]) - - /** Create a DiGraph from a MultiMap[T] of edge data */ - def apply[T](edgeData: MultiMap[T,T]): DiGraph[T] = - new DiGraph((edgeData mapValues { _.toSet }).toMap[T, Set[T]]) + def apply[T](mdg: MutableDiGraph[T]): DiGraph[T] = mdg /** Create a DiGraph from a Map[T,Set[T]] of edge data */ - def apply[T](edgeData: Map[T,Set[T]]) = new DiGraph(edgeData) + def apply[T](edgeData: Map[T,Set[T]]): DiGraph[T] = { + val edgeDataCopy = new LinkedHashMap[T, LinkedHashSet[T]] + for ((k, v) <- edgeData) { + edgeDataCopy(k) = new LinkedHashSet[T] + } + for ((k, v) <- edgeData) { + for (n <- v) { + require(edgeDataCopy.contains(n)) + edgeDataCopy(k) += n + } + } + new DiGraph(edgeDataCopy) + } } -/** - * A class to represent an immutable directed graph with nodes of - * type T - * - * @constructor Create a new graph with the provided edge data - * @param edges a Map[T,Set[T]] of edge data - * - * For the edge data Map, the value associated with each vertex u in - * the graph is a Set[T] of nodes where for each node v in the set, - * the directed edge (u,v) exists in the graph. - */ -class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] { +/** Represents common behavior of all directed graphs */ +class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) { /** An exception that is raised when an assumed DAG has a cycle */ class CyclicException extends Exception("No valid linearization for cyclic graph") /** An exception that is raised when attempting to find an unreachable node */ class PathNotFoundException extends Exception("Unreachable node") - // Inherited methods from DiGraphLike - def contains(v: T) = edges.contains(v) - def getVertices = edges.keySet - def getEdges(v: T) = edges.getOrElse(v, new HashSet[T]) + + /** Check whether the graph contains vertex v */ + def contains(v: T): Boolean = edges.contains(v) + + /** Get all vertices in the graph + * @return a Set[T] of all vertices in the graph + */ + // The pattern of mapping map pairs to keys maintains LinkedHashMap ordering + def getVertices: Set[T] = new LinkedHashSet ++ edges.map({ case (k, _) => k }) + + /** Get all edges of a node + * @param v the specified node + * @return a Set[T] of all vertices that v has edges to + */ + def getEdges(v: T): Set[T] = edges.getOrElse(v, Set.empty) + + def getEdgeMap: Map[T, Set[T]] = edges /** Find all sources in the graph - * + * * @return a Set[T] of source nodes */ - def findSources: Set[T] = edges.keySet -- edges.values.flatten.toSet + def findSources: Set[T] = getVertices -- edges.values.flatten.toSet /** Find all sinks in the graph - * + * * @return a Set[T] of sink nodes */ def findSinks: Set[T] = reverse.findSources /** Linearizes (topologically sorts) a DAG - * + * * @param root the start node * @throws CyclicException if the graph is cyclic * @return a Map[T,T] from each visited node to its predecessor in the @@ -115,8 +74,8 @@ class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] { // permanently marked nodes are implicitly held in order val order = new mutable.ArrayBuffer[T] // invariant: no intersection between unmarked and tempMarked - val unmarked = new mutable.HashSet[T] - val tempMarked = new mutable.HashSet[T] + val unmarked = new LinkedHashSet[T] + val tempMarked = new LinkedHashSet[T] def visit(n: T): Unit = { if (tempMarked.contains(n)) { @@ -143,13 +102,13 @@ class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] { } /** Performs breadth-first search on the directed graph - * + * * @param root the start node * @return a Map[T,T] from each visited node to its predecessor in the * traversal */ def BFS(root: T): Map[T,T] = { - val prev = new mutable.HashMap[T,T] + val prev = new LinkedHashMap[T,T] val queue = new mutable.Queue[T] queue.enqueue(root) while (!queue.isEmpty) { @@ -161,24 +120,24 @@ class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] { } } } - prev.toMap + prev } /** Finds the set of nodes reachable from a particular node - * + * * @param root the start node * @return a Set[T] of nodes reachable from the root */ - def reachableFrom(root: T): Set[T] = BFS(root).keys.toSet + def reachableFrom(root: T): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root).map({ case (k, v) => k }) /** Finds a path (if one exists) from one node to another - * + * * @param start the start node * @param end the destination node * @throws PathNotFoundException * @return a Seq[T] of nodes defining an arbitrary valid path */ - def path(start: T, end: T) = { + def path(start: T, end: T): Seq[T] = { val nodePath = new mutable.ArrayBuffer[T] val prev = BFS(start) nodePath += end @@ -192,15 +151,15 @@ class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] { } /** Finds the strongly connected components in the graph - * + * * @return a Seq of Seq[T], each containing nodes of an SCC in traversable order */ def findSCCs: Seq[Seq[T]] = { var counter: BigInt = 0 val stack = new mutable.Stack[T] - val onstack = new mutable.HashSet[T] - val indices = new mutable.HashMap[T, BigInt] - val lowlinks = new mutable.HashMap[T, BigInt] + val onstack = new LinkedHashSet[T] + val indices = new LinkedHashMap[T, BigInt] + val lowlinks = new LinkedHashMap[T, BigInt] val sccs = new mutable.ArrayBuffer[Seq[T]] /* @@ -260,83 +219,130 @@ class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] { } /** Finds all paths starting at a particular node in a DAG - * + * * WARNING: This is an exponential time algorithm (as any algorithm * must be for this problem), but is useful for flattening circuit * graph hierarchies. Each path is represented by a Seq[T] of nodes * in a traversable order. - * + * * @param start the node to start at * @return a Map[T,Seq[Seq[T]]] where the value associated with v is the Seq of all paths from start to v */ def pathsInDAG(start: T): Map[T,Seq[Seq[T]]] = { // paths(v) holds the set of paths from start to v - val paths = new mutable.HashMap[T,mutable.Set[Seq[T]]] with mutable.MultiMap[T,Seq[T]] + val paths = new LinkedHashMap[T, mutable.Set[Seq[T]]] val queue = new mutable.Queue[T] val reachable = reachableFrom(start) - paths.addBinding(start,Seq(start)) + def addBinding(n: T, p: Seq[T]): Unit = { + paths.getOrElseUpdate(n, new LinkedHashSet[Seq[T]]) += p + } + addBinding(start,Seq(start)) queue += start queue ++= linearize.filter(reachable.contains(_)) while (!queue.isEmpty) { val current = queue.dequeue for (v <- getEdges(current)) { for (p <- paths(current)) { - paths.addBinding(v, p :+ v) + addBinding(v, p :+ v) } } } - (paths map { case (k,v) => (k,v.toSeq) }).toMap + paths.map({ case (k,v) => (k,v.toSeq) }) } /** Returns a graph with all edges reversed */ def reverse: DiGraph[T] = { val mdg = new MutableDiGraph[T] - edges.foreach { case (u, edges) => - mdg.addVertex(u) + edges.foreach({ case (u, edges) => mdg.addVertex(u) }) + edges.foreach({ case (u, edges) => edges.foreach(v => mdg.addEdge(v,u)) - } + }) DiGraph(mdg) } + private def filterEdges(vprime: Set[T]): LinkedHashMap[T, LinkedHashSet[T]] = { + def filterNodeSet(s: LinkedHashSet[T]): LinkedHashSet[T] = s.filter({ case (k) => vprime.contains(k) }) + def filterAdjacencyLists(m: LinkedHashMap[T, LinkedHashSet[T]]): LinkedHashMap[T, LinkedHashSet[T]] = m.map({ case (k, v) => (k, filterNodeSet(v)) }) + var eprime: LinkedHashMap[T, LinkedHashSet[T]] = edges.filter({ case (k, v) => vprime.contains(k) }) + filterAdjacencyLists(eprime) + } + /** Return a graph with only a subset of the nodes * * Any edge including a deleted node will be deleted - * + * * @param vprime the Set[T] of desired vertices * @throws IllegalArgumentException if vprime is not a subset of V * @return the subgraph */ def subgraph(vprime: Set[T]): DiGraph[T] = { require(vprime.subsetOf(edges.keySet)) - val eprime = vprime.map(v => (v,getEdges(v) & vprime)).toMap - new DiGraph(eprime) + new DiGraph(filterEdges(vprime)) } - /** Return a graph with only a subset of the nodes + /** Return a simplified connectivity graph with only a subset of the nodes + * + * Any path between two non-deleted nodes (u,v) in the original graph will be + * transformed into an edge (u,v). * - * Any path between two non-deleted nodes (u,v) that traverses only - * deleted nodes will be transformed into an edge (u,v). - * * @param vprime the Set[T] of desired vertices * @throws IllegalArgumentException if vprime is not a subset of V * @return the simplified graph */ def simplify(vprime: Set[T]): DiGraph[T] = { require(vprime.subsetOf(edges.keySet)) - val eprime = vprime.map( v => (v,reachableFrom(v) & (vprime-v)) ).toMap - new DiGraph(eprime) + val pathEdges = vprime.map( v => (v, reachableFrom(v) & (vprime-v)) ) + new DiGraph(new LinkedHashMap[T, LinkedHashSet[T]] ++ pathEdges) } /** Return a graph with all the nodes of the current graph transformed * by a function. Edge connectivity will be the same as the current * graph. - * + * * @param f A function {(T) => Q} that transforms each node * @return a transformed DiGraph[Q] */ def transformNodes[Q](f: (T) => Q): DiGraph[Q] = { - val eprime = edges.map({ case (k,v) => (f(k),v.map(f(_))) }) + val eprime = edges.map({ case (k, v) => (f(k), v.map(f(_))) }) new DiGraph(eprime) } } + +class MutableDiGraph[T] extends DiGraph[T](new LinkedHashMap[T, LinkedHashSet[T]]) { + /** Add vertex v to the graph + * @return v, the added vertex + */ + def addVertex(v: T): T = { + edges.getOrElseUpdate(v, new LinkedHashSet[T]) + v + } + + /** Add edge (u,v) to the graph. + * @throws IllegalArgumentException if u and/or v is not in the graph + */ + def addEdge(u: T, v: T): Unit = { + require(contains(u)) + require(contains(v)) + edges(u) += v + } + + /** Add edge (u,v) to the graph, adding u and/or v if they are not + * already in the graph. + */ + def addPairWithEdge(u: T, v: T): Unit = { + edges.getOrElseUpdate(v, new LinkedHashSet[T]) + edges.getOrElseUpdate(u, new LinkedHashSet[T]) += v + } + + /** Add edge (u,v) to the graph if and only if both u and v are in + * the graph prior to calling addEdgeIfValid. + */ + def addEdgeIfValid(u: T, v: T): Boolean = { + val valid = contains(u) && contains(v) + if (contains(u) && contains(v)) { + edges(u) += v + } + valid + } +} diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index d50a027c..bb2ffea9 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -77,15 +77,16 @@ class CheckCombLoops extends Transform { } } - private def getExprDeps(deps: mutable.Set[LogicNode])(e: Expression): Expression = e match { + + private def getExprDeps(deps: MutableDiGraph[LogicNode], v: LogicNode)(e: Expression): Expression = e match { case r: WRef => - deps += toLogicNode(r) + deps.addEdgeIfValid(v, toLogicNode(r)) r case s: WSubField => - deps += toLogicNode(s) + deps.addEdgeIfValid(v, toLogicNode(s)) s case _ => - e map getExprDeps(deps) + e map getExprDeps(deps, v) } private def getStmtDeps( @@ -95,14 +96,14 @@ class CheckCombLoops extends Transform { case Connect(_,loc,expr) => val lhs = toLogicNode(loc) if (deps.contains(lhs)) { - getExprDeps(deps.getEdges(lhs))(expr) + getExprDeps(deps, lhs)(expr) } case w: DefWire => deps.addVertex(LogicNode(w.name)) case n: DefNode => val lhs = LogicNode(n.name) deps.addVertex(lhs) - getExprDeps(deps.getEdges(lhs))(n.value) + getExprDeps(deps, lhs)(n.value) case m: DefMemory if (m.readLatency == 0) => for (rp <- m.readers) { val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp))) @@ -111,10 +112,8 @@ class CheckCombLoops extends Transform { } case i: WDefInstance => val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name))) - for (v <- iGraph.getVertices) { - deps.addVertex(v) - iGraph.getEdges(v).foreach { deps.addEdge(v,_) } - } + iGraph.getVertices.foreach(deps.addVertex(_)) + iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } }) case _ => s map getStmtDeps(simplifiedModules,deps) } @@ -196,9 +195,9 @@ class CheckCombLoops extends Transform { * exist. Maybe warn when iterating through modules. */ val moduleMap = c.modules.map({m => (m.name,m) }).toMap - val iGraph = new InstanceGraph(c) - val moduleDeps = iGraph.graph.edges.map{ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) } - val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse map { moduleMap(_) } + val iGraph = new InstanceGraph(c).graph + val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap + val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) } val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]] val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]] for (m <- topoSortedModules) { @@ -212,7 +211,7 @@ class CheckCombLoops extends Transform { val sccSubgraph = moduleGraphs(m.name).subgraph(scc.toSet) val cycle = findCycleInSCC(sccSubgraph) (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) }) - val expandedCycle = expandInstancePaths(m.name,moduleGraphs,moduleDeps,Seq(m.name),cycle.reverse) + val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse) errors.append(new CombLoopException(m.info, m.name, expandedCycle)) } } diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 69502911..84b63e3d 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -398,9 +398,9 @@ class ConstantPropagation extends Transform { private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = { val iGraph = (new InstanceGraph(c)).graph - val moduleDeps = iGraph.edges.map { case (mod, children) => + val moduleDeps = iGraph.getEdgeMap.map({ case (mod, children) => mod.module -> children.map(i => i.name -> i.module).toMap - } + }) // Module name to number of instances val instCount: Map[String, Int] = iGraph.getVertices.groupBy(_.module).mapValues(_.size) diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index 24c1c51c..578de264 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -95,11 +95,11 @@ class DeadCodeElimination extends Transform { case DefRegister(_, name, _, clock, reset, init) => val node = LogicNode(mod.name, name) depGraph.addVertex(node) - Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(node, ref)) + Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(node, ref)) case DefNode(_, name, value) => val node = LogicNode(mod.name, name) depGraph.addVertex(node) - getDeps(value).foreach(ref => depGraph.addEdge(node, ref)) + getDeps(value).foreach(ref => depGraph.addPairWithEdge(node, ref)) case DefWire(_, name, _) => depGraph.addVertex(LogicNode(mod.name, name)) case mem: DefMemory => @@ -111,23 +111,23 @@ class DeadCodeElimination extends Transform { val sinks = exprs.getOrElse(FEMALE, List.empty).flatMap(getDeps(_)) val memNode = getDeps(memRef) match { case Seq(node) => node } depGraph.addVertex(memNode) - sinks.foreach(sink => depGraph.addEdge(sink, memNode)) - sources.foreach(source => depGraph.addEdge(memNode, source)) + sinks.foreach(sink => depGraph.addPairWithEdge(sink, memNode)) + sources.foreach(source => depGraph.addPairWithEdge(memNode, source)) case Attach(_, exprs) => // Add edge between each expression exprs.flatMap(getDeps(_)).toSet.subsets(2).map(_.toList).foreach { case Seq(a, b) => - depGraph.addEdge(a, b) - depGraph.addEdge(b, a) + depGraph.addPairWithEdge(a, b) + depGraph.addPairWithEdge(b, a) } case Connect(_, loc, expr) => // This match enforces the low Firrtl requirement of expanded connections val node = getDeps(loc) match { case Seq(elt) => elt } - getDeps(expr).foreach(ref => depGraph.addEdge(node, ref)) + getDeps(expr).foreach(ref => depGraph.addPairWithEdge(node, ref)) // Simulation constructs are treated as top-level outputs case Stop(_,_, clk, en) => - Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref)) + Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref)) case Print(_, _, args, clk, en) => - (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref)) + (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref)) case Block(stmts) => stmts.foreach(onStmt(_)) case ignore @ (_: IsInvalid | _: WDefInstance | EmptyStmt) => // do nothing case other => throw new Exception(s"Unexpected Statement $other") @@ -153,30 +153,30 @@ class DeadCodeElimination extends Transform { val node = LogicNode(ext) // Don't touch external modules *unless* they are specifically marked as doTouch // Simply marking the extmodule itself is sufficient to prevent inputs from being removed - if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, node) + if (!doTouchExtMods.contains(ext.name)) depGraph.addPairWithEdge(circuitSink, node) ext.ports.foreach { case Port(_, pname, _, AnalogType(_)) => - depGraph.addEdge(LogicNode(ext.name, pname), node) - depGraph.addEdge(node, LogicNode(ext.name, pname)) + depGraph.addPairWithEdge(LogicNode(ext.name, pname), node) + depGraph.addPairWithEdge(node, LogicNode(ext.name, pname)) case Port(_, pname, Output, _) => val portNode = LogicNode(ext.name, pname) - depGraph.addEdge(portNode, node) + depGraph.addPairWithEdge(portNode, node) // Also mark all outputs as circuit sinks (unless marked doTouch obviously) - if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, portNode) - case Port(_, pname, Input, _) => depGraph.addEdge(node, LogicNode(ext.name, pname)) + if (!doTouchExtMods.contains(ext.name)) depGraph.addPairWithEdge(circuitSink, portNode) + case Port(_, pname, Input, _) => depGraph.addPairWithEdge(node, LogicNode(ext.name, pname)) } } // Connect circuitSink to ALL top-level ports (we don't want to change the top-level interface) val topModule = c.modules.find(_.name == c.main).get val topOutputs = topModule.ports.foreach { port => - depGraph.addEdge(circuitSink, LogicNode(c.main, port.name)) + depGraph.addPairWithEdge(circuitSink, LogicNode(c.main, port.name)) } depGraph } private def deleteDeadCode(instMap: collection.Map[String, String], - deadNodes: Set[LogicNode], + deadNodes: collection.Set[LogicNode], moduleMap: collection.Map[String, DefModule], renames: RenameMap) (mod: DefModule): Option[DefModule] = { @@ -243,9 +243,9 @@ class DeadCodeElimination extends Transform { val c = state.circuit val moduleMap = c.modules.map(m => m.name -> m).toMap val iGraph = new InstanceGraph(c) - val moduleDeps = iGraph.graph.edges.map { case (k,v) => + val moduleDeps = iGraph.graph.getEdgeMap.map({ case (k,v) => k.module -> v.map(i => i.name -> i.module).toMap - } + }) val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_)) val depGraph = { @@ -255,7 +255,7 @@ class DeadCodeElimination extends Transform { dontTouches.foreach { dontTouch => // Ensure that they are actually found if (vertices.contains(dontTouch)) { - dGraph.addEdge(circuitSink, dontTouch) + dGraph.addPairWithEdge(circuitSink, dontTouch) } else { val (root, tail) = Utils.splitRef(dontTouch.e1) DontTouchAnnotation.errorNotFound(root.serialize, tail.serialize) diff --git a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala index 2c12d4ca..6c8a2f20 100644 --- a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala +++ b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala @@ -23,6 +23,31 @@ class CheckCombLoopsSpec extends SimpleTransformSpec { new MiddleFirrtlToLowFirrtl ) + "Loop-free circuit" should "not throw an exception" in { + val input = """circuit hasnoloops : + | module thru : + | input in1 : UInt<1> + | input in2 : UInt<1> + | output out1 : UInt<1> + | output out2 : UInt<1> + | out1 <= in1 + | out2 <= in2 + | module hasnoloops : + | input clk : Clock + | input a : UInt<1> + | output b : UInt<1> + | wire x : UInt<1> + | inst inner of thru + | inner.in1 <= a + | x <= inner.out1 + | inner.in2 <= x + | b <= inner.out2 + |""".stripMargin + + val writer = new java.io.StringWriter + compile(CircuitState(parse(input), ChirrtlForm, None), writer) + } + "Simple combinational loop" should "throw an exception" in { val input = """circuit hasloops : | module hasloops : diff --git a/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala b/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala index 300fa8c4..3e517079 100644 --- a/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala +++ b/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala @@ -11,7 +11,7 @@ import firrtl.passes._ import firrtlTests._ class InstanceGraphTests extends FirrtlFlatSpec { - private def getEdgeSet(graph: DiGraph[String]): Map[String, Set[String]] = { + private def getEdgeSet(graph: DiGraph[String]): collection.Map[String, collection.Set[String]] = { (graph.getVertices map {v => (v, graph.getEdges(v))}).toMap } diff --git a/src/test/scala/firrtlTests/graph/DiGraphTests.scala b/src/test/scala/firrtlTests/graph/DiGraphTests.scala index 9eb1c7f8..da268e4f 100644 --- a/src/test/scala/firrtlTests/graph/DiGraphTests.scala +++ b/src/test/scala/firrtlTests/graph/DiGraphTests.scala @@ -43,8 +43,8 @@ class DiGraphTests extends FirrtlFlatSpec { a [cyclicGraph.CyclicException] should be thrownBy cyclicGraph.linearize - acyclicGraph.reverse.edges should equal (reversedAcyclicGraph.edges) + acyclicGraph.reverse.getEdgeMap should equal (reversedAcyclicGraph.getEdgeMap) - degenerateGraph.edges should equal (degenerateGraph.reverse.edges) + degenerateGraph.getEdgeMap should equal (degenerateGraph.reverse.getEdgeMap) } |
