aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlbert Magyar2017-11-10 14:02:39 -0800
committerJack Koenig2017-11-10 14:02:39 -0800
commit7d86a35e19519d92dc436c07359d5120b44b5a85 (patch)
tree230cf0220d373f0840e337586c983920bf3e71b6 /src
parent8ed378dfc9be7e5ebaff1e6b7393b5b991ea691d (diff)
Make digraph methods deterministic (#653)
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/analyses/Netlist.scala1
-rw-r--r--src/main/scala/firrtl/graph/DiGraph.scala232
-rw-r--r--src/main/scala/firrtl/transforms/CheckCombLoops.scala27
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala4
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala40
-rw-r--r--src/test/scala/firrtlTests/CheckCombLoopsSpec.scala25
-rw-r--r--src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala2
-rw-r--r--src/test/scala/firrtlTests/graph/DiGraphTests.scala4
8 files changed, 183 insertions, 152 deletions
diff --git a/src/main/scala/firrtl/analyses/Netlist.scala b/src/main/scala/firrtl/analyses/Netlist.scala
index f00e96bb..4e211d8b 100644
--- a/src/main/scala/firrtl/analyses/Netlist.scala
+++ b/src/main/scala/firrtl/analyses/Netlist.scala
@@ -47,6 +47,7 @@ class InstanceGraph(c: Circuit) {
for (child <- childInstances(current.module)) {
if (!instanceGraph.contains(child)) {
instanceQueue.enqueue(child)
+ instanceGraph.addVertex(child)
}
instanceGraph.addEdge(current,child)
}
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala
index 39016732..3a657cc0 100644
--- a/src/main/scala/firrtl/graph/DiGraph.scala
+++ b/src/main/scala/firrtl/graph/DiGraph.scala
@@ -1,111 +1,70 @@
package firrtl.graph
-import scala.collection.immutable.{Set, Map, HashSet, HashMap}
+import scala.collection.{Set, Map}
import scala.collection.mutable
-import scala.collection.mutable.MultiMap
-
-/** Represents common behavior of all directed graphs */
-trait DiGraphLike[T] {
- /** Check whether the graph contains vertex v */
- def contains(v: T): Boolean
-
- /** Get all vertices in the graph
- * @return a Set[T] of all vertices in the graph
- */
- def getVertices: collection.Set[T]
-
- /** Get all edges of a node
- * @param v the specified node
- * @return a Set[T] of all vertices that v has edges to
- */
- def getEdges(v: T): collection.Set[T]
-}
-
-/** A class to represent a mutable directed graph with nodes of type T
- *
- * @constructor Create a new graph with the provided edge data
- * @param edges a mutable.MultiMap[T,T] of edge data
- *
- * For the edge data MultiMap, the values associated with each vertex
- * u in the graph are the vertices with inedges from u
- */
-class MutableDiGraph[T](
- private[graph] val edgeData: MultiMap[T,T] =
- new mutable.HashMap[T, mutable.Set[T]] with MultiMap[T, T]) extends DiGraphLike[T] {
-
- // Inherited methods from DiGraphLike
- def contains(v: T) = edgeData.contains(v)
- def getVertices = edgeData.keySet
- def getEdges(v: T) = edgeData(v)
-
- /** Add vertex v to the graph
- * @return v, the added vertex
- */
- def addVertex(v: T): T = {
- edgeData.getOrElseUpdate(v,new mutable.HashSet[T])
- v
- }
-
- /** Add edge (u,v) to the graph */
- def addEdge(u: T, v: T) = {
- // Add v to keys to maintain invariant that all vertices are keys
- // of edge data
- edgeData.getOrElseUpdate(v, new mutable.HashSet[T])
- edgeData.addBinding(u,v)
- }
-}
+import scala.collection.mutable.{LinkedHashSet, LinkedHashMap}
/** A companion to create immutable DiGraphs from mutable data */
object DiGraph {
/** Create a DiGraph from a MutableDigraph, representing the same graph */
- def apply[T](mdg: MutableDiGraph[T]): DiGraph[T] =
- new DiGraph((mdg.edgeData mapValues { _.toSet }).toMap[T, Set[T]])
-
- /** Create a DiGraph from a MultiMap[T] of edge data */
- def apply[T](edgeData: MultiMap[T,T]): DiGraph[T] =
- new DiGraph((edgeData mapValues { _.toSet }).toMap[T, Set[T]])
+ def apply[T](mdg: MutableDiGraph[T]): DiGraph[T] = mdg
/** Create a DiGraph from a Map[T,Set[T]] of edge data */
- def apply[T](edgeData: Map[T,Set[T]]) = new DiGraph(edgeData)
+ def apply[T](edgeData: Map[T,Set[T]]): DiGraph[T] = {
+ val edgeDataCopy = new LinkedHashMap[T, LinkedHashSet[T]]
+ for ((k, v) <- edgeData) {
+ edgeDataCopy(k) = new LinkedHashSet[T]
+ }
+ for ((k, v) <- edgeData) {
+ for (n <- v) {
+ require(edgeDataCopy.contains(n))
+ edgeDataCopy(k) += n
+ }
+ }
+ new DiGraph(edgeDataCopy)
+ }
}
-/**
- * A class to represent an immutable directed graph with nodes of
- * type T
- *
- * @constructor Create a new graph with the provided edge data
- * @param edges a Map[T,Set[T]] of edge data
- *
- * For the edge data Map, the value associated with each vertex u in
- * the graph is a Set[T] of nodes where for each node v in the set,
- * the directed edge (u,v) exists in the graph.
- */
-class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] {
+/** Represents common behavior of all directed graphs */
+class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) {
/** An exception that is raised when an assumed DAG has a cycle */
class CyclicException extends Exception("No valid linearization for cyclic graph")
/** An exception that is raised when attempting to find an unreachable node */
class PathNotFoundException extends Exception("Unreachable node")
- // Inherited methods from DiGraphLike
- def contains(v: T) = edges.contains(v)
- def getVertices = edges.keySet
- def getEdges(v: T) = edges.getOrElse(v, new HashSet[T])
+
+ /** Check whether the graph contains vertex v */
+ def contains(v: T): Boolean = edges.contains(v)
+
+ /** Get all vertices in the graph
+ * @return a Set[T] of all vertices in the graph
+ */
+ // The pattern of mapping map pairs to keys maintains LinkedHashMap ordering
+ def getVertices: Set[T] = new LinkedHashSet ++ edges.map({ case (k, _) => k })
+
+ /** Get all edges of a node
+ * @param v the specified node
+ * @return a Set[T] of all vertices that v has edges to
+ */
+ def getEdges(v: T): Set[T] = edges.getOrElse(v, Set.empty)
+
+ def getEdgeMap: Map[T, Set[T]] = edges
/** Find all sources in the graph
- *
+ *
* @return a Set[T] of source nodes
*/
- def findSources: Set[T] = edges.keySet -- edges.values.flatten.toSet
+ def findSources: Set[T] = getVertices -- edges.values.flatten.toSet
/** Find all sinks in the graph
- *
+ *
* @return a Set[T] of sink nodes
*/
def findSinks: Set[T] = reverse.findSources
/** Linearizes (topologically sorts) a DAG
- *
+ *
* @param root the start node
* @throws CyclicException if the graph is cyclic
* @return a Map[T,T] from each visited node to its predecessor in the
@@ -115,8 +74,8 @@ class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] {
// permanently marked nodes are implicitly held in order
val order = new mutable.ArrayBuffer[T]
// invariant: no intersection between unmarked and tempMarked
- val unmarked = new mutable.HashSet[T]
- val tempMarked = new mutable.HashSet[T]
+ val unmarked = new LinkedHashSet[T]
+ val tempMarked = new LinkedHashSet[T]
def visit(n: T): Unit = {
if (tempMarked.contains(n)) {
@@ -143,13 +102,13 @@ class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] {
}
/** Performs breadth-first search on the directed graph
- *
+ *
* @param root the start node
* @return a Map[T,T] from each visited node to its predecessor in the
* traversal
*/
def BFS(root: T): Map[T,T] = {
- val prev = new mutable.HashMap[T,T]
+ val prev = new LinkedHashMap[T,T]
val queue = new mutable.Queue[T]
queue.enqueue(root)
while (!queue.isEmpty) {
@@ -161,24 +120,24 @@ class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] {
}
}
}
- prev.toMap
+ prev
}
/** Finds the set of nodes reachable from a particular node
- *
+ *
* @param root the start node
* @return a Set[T] of nodes reachable from the root
*/
- def reachableFrom(root: T): Set[T] = BFS(root).keys.toSet
+ def reachableFrom(root: T): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root).map({ case (k, v) => k })
/** Finds a path (if one exists) from one node to another
- *
+ *
* @param start the start node
* @param end the destination node
* @throws PathNotFoundException
* @return a Seq[T] of nodes defining an arbitrary valid path
*/
- def path(start: T, end: T) = {
+ def path(start: T, end: T): Seq[T] = {
val nodePath = new mutable.ArrayBuffer[T]
val prev = BFS(start)
nodePath += end
@@ -192,15 +151,15 @@ class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] {
}
/** Finds the strongly connected components in the graph
- *
+ *
* @return a Seq of Seq[T], each containing nodes of an SCC in traversable order
*/
def findSCCs: Seq[Seq[T]] = {
var counter: BigInt = 0
val stack = new mutable.Stack[T]
- val onstack = new mutable.HashSet[T]
- val indices = new mutable.HashMap[T, BigInt]
- val lowlinks = new mutable.HashMap[T, BigInt]
+ val onstack = new LinkedHashSet[T]
+ val indices = new LinkedHashMap[T, BigInt]
+ val lowlinks = new LinkedHashMap[T, BigInt]
val sccs = new mutable.ArrayBuffer[Seq[T]]
/*
@@ -260,83 +219,130 @@ class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] {
}
/** Finds all paths starting at a particular node in a DAG
- *
+ *
* WARNING: This is an exponential time algorithm (as any algorithm
* must be for this problem), but is useful for flattening circuit
* graph hierarchies. Each path is represented by a Seq[T] of nodes
* in a traversable order.
- *
+ *
* @param start the node to start at
* @return a Map[T,Seq[Seq[T]]] where the value associated with v is the Seq of all paths from start to v
*/
def pathsInDAG(start: T): Map[T,Seq[Seq[T]]] = {
// paths(v) holds the set of paths from start to v
- val paths = new mutable.HashMap[T,mutable.Set[Seq[T]]] with mutable.MultiMap[T,Seq[T]]
+ val paths = new LinkedHashMap[T, mutable.Set[Seq[T]]]
val queue = new mutable.Queue[T]
val reachable = reachableFrom(start)
- paths.addBinding(start,Seq(start))
+ def addBinding(n: T, p: Seq[T]): Unit = {
+ paths.getOrElseUpdate(n, new LinkedHashSet[Seq[T]]) += p
+ }
+ addBinding(start,Seq(start))
queue += start
queue ++= linearize.filter(reachable.contains(_))
while (!queue.isEmpty) {
val current = queue.dequeue
for (v <- getEdges(current)) {
for (p <- paths(current)) {
- paths.addBinding(v, p :+ v)
+ addBinding(v, p :+ v)
}
}
}
- (paths map { case (k,v) => (k,v.toSeq) }).toMap
+ paths.map({ case (k,v) => (k,v.toSeq) })
}
/** Returns a graph with all edges reversed */
def reverse: DiGraph[T] = {
val mdg = new MutableDiGraph[T]
- edges.foreach { case (u, edges) =>
- mdg.addVertex(u)
+ edges.foreach({ case (u, edges) => mdg.addVertex(u) })
+ edges.foreach({ case (u, edges) =>
edges.foreach(v => mdg.addEdge(v,u))
- }
+ })
DiGraph(mdg)
}
+ private def filterEdges(vprime: Set[T]): LinkedHashMap[T, LinkedHashSet[T]] = {
+ def filterNodeSet(s: LinkedHashSet[T]): LinkedHashSet[T] = s.filter({ case (k) => vprime.contains(k) })
+ def filterAdjacencyLists(m: LinkedHashMap[T, LinkedHashSet[T]]): LinkedHashMap[T, LinkedHashSet[T]] = m.map({ case (k, v) => (k, filterNodeSet(v)) })
+ var eprime: LinkedHashMap[T, LinkedHashSet[T]] = edges.filter({ case (k, v) => vprime.contains(k) })
+ filterAdjacencyLists(eprime)
+ }
+
/** Return a graph with only a subset of the nodes
*
* Any edge including a deleted node will be deleted
- *
+ *
* @param vprime the Set[T] of desired vertices
* @throws IllegalArgumentException if vprime is not a subset of V
* @return the subgraph
*/
def subgraph(vprime: Set[T]): DiGraph[T] = {
require(vprime.subsetOf(edges.keySet))
- val eprime = vprime.map(v => (v,getEdges(v) & vprime)).toMap
- new DiGraph(eprime)
+ new DiGraph(filterEdges(vprime))
}
- /** Return a graph with only a subset of the nodes
+ /** Return a simplified connectivity graph with only a subset of the nodes
+ *
+ * Any path between two non-deleted nodes (u,v) in the original graph will be
+ * transformed into an edge (u,v).
*
- * Any path between two non-deleted nodes (u,v) that traverses only
- * deleted nodes will be transformed into an edge (u,v).
- *
* @param vprime the Set[T] of desired vertices
* @throws IllegalArgumentException if vprime is not a subset of V
* @return the simplified graph
*/
def simplify(vprime: Set[T]): DiGraph[T] = {
require(vprime.subsetOf(edges.keySet))
- val eprime = vprime.map( v => (v,reachableFrom(v) & (vprime-v)) ).toMap
- new DiGraph(eprime)
+ val pathEdges = vprime.map( v => (v, reachableFrom(v) & (vprime-v)) )
+ new DiGraph(new LinkedHashMap[T, LinkedHashSet[T]] ++ pathEdges)
}
/** Return a graph with all the nodes of the current graph transformed
* by a function. Edge connectivity will be the same as the current
* graph.
- *
+ *
* @param f A function {(T) => Q} that transforms each node
* @return a transformed DiGraph[Q]
*/
def transformNodes[Q](f: (T) => Q): DiGraph[Q] = {
- val eprime = edges.map({ case (k,v) => (f(k),v.map(f(_))) })
+ val eprime = edges.map({ case (k, v) => (f(k), v.map(f(_))) })
new DiGraph(eprime)
}
}
+
+class MutableDiGraph[T] extends DiGraph[T](new LinkedHashMap[T, LinkedHashSet[T]]) {
+ /** Add vertex v to the graph
+ * @return v, the added vertex
+ */
+ def addVertex(v: T): T = {
+ edges.getOrElseUpdate(v, new LinkedHashSet[T])
+ v
+ }
+
+ /** Add edge (u,v) to the graph.
+ * @throws IllegalArgumentException if u and/or v is not in the graph
+ */
+ def addEdge(u: T, v: T): Unit = {
+ require(contains(u))
+ require(contains(v))
+ edges(u) += v
+ }
+
+ /** Add edge (u,v) to the graph, adding u and/or v if they are not
+ * already in the graph.
+ */
+ def addPairWithEdge(u: T, v: T): Unit = {
+ edges.getOrElseUpdate(v, new LinkedHashSet[T])
+ edges.getOrElseUpdate(u, new LinkedHashSet[T]) += v
+ }
+
+ /** Add edge (u,v) to the graph if and only if both u and v are in
+ * the graph prior to calling addEdgeIfValid.
+ */
+ def addEdgeIfValid(u: T, v: T): Boolean = {
+ val valid = contains(u) && contains(v)
+ if (contains(u) && contains(v)) {
+ edges(u) += v
+ }
+ valid
+ }
+}
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
index d50a027c..bb2ffea9 100644
--- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala
+++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
@@ -77,15 +77,16 @@ class CheckCombLoops extends Transform {
}
}
- private def getExprDeps(deps: mutable.Set[LogicNode])(e: Expression): Expression = e match {
+
+ private def getExprDeps(deps: MutableDiGraph[LogicNode], v: LogicNode)(e: Expression): Expression = e match {
case r: WRef =>
- deps += toLogicNode(r)
+ deps.addEdgeIfValid(v, toLogicNode(r))
r
case s: WSubField =>
- deps += toLogicNode(s)
+ deps.addEdgeIfValid(v, toLogicNode(s))
s
case _ =>
- e map getExprDeps(deps)
+ e map getExprDeps(deps, v)
}
private def getStmtDeps(
@@ -95,14 +96,14 @@ class CheckCombLoops extends Transform {
case Connect(_,loc,expr) =>
val lhs = toLogicNode(loc)
if (deps.contains(lhs)) {
- getExprDeps(deps.getEdges(lhs))(expr)
+ getExprDeps(deps, lhs)(expr)
}
case w: DefWire =>
deps.addVertex(LogicNode(w.name))
case n: DefNode =>
val lhs = LogicNode(n.name)
deps.addVertex(lhs)
- getExprDeps(deps.getEdges(lhs))(n.value)
+ getExprDeps(deps, lhs)(n.value)
case m: DefMemory if (m.readLatency == 0) =>
for (rp <- m.readers) {
val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp)))
@@ -111,10 +112,8 @@ class CheckCombLoops extends Transform {
}
case i: WDefInstance =>
val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
- for (v <- iGraph.getVertices) {
- deps.addVertex(v)
- iGraph.getEdges(v).foreach { deps.addEdge(v,_) }
- }
+ iGraph.getVertices.foreach(deps.addVertex(_))
+ iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } })
case _ =>
s map getStmtDeps(simplifiedModules,deps)
}
@@ -196,9 +195,9 @@ class CheckCombLoops extends Transform {
* exist. Maybe warn when iterating through modules.
*/
val moduleMap = c.modules.map({m => (m.name,m) }).toMap
- val iGraph = new InstanceGraph(c)
- val moduleDeps = iGraph.graph.edges.map{ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }
- val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse map { moduleMap(_) }
+ val iGraph = new InstanceGraph(c).graph
+ val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap
+ val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) }
val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
for (m <- topoSortedModules) {
@@ -212,7 +211,7 @@ class CheckCombLoops extends Transform {
val sccSubgraph = moduleGraphs(m.name).subgraph(scc.toSet)
val cycle = findCycleInSCC(sccSubgraph)
(cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) })
- val expandedCycle = expandInstancePaths(m.name,moduleGraphs,moduleDeps,Seq(m.name),cycle.reverse)
+ val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse)
errors.append(new CombLoopException(m.info, m.name, expandedCycle))
}
}
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 69502911..84b63e3d 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -398,9 +398,9 @@ class ConstantPropagation extends Transform {
private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = {
val iGraph = (new InstanceGraph(c)).graph
- val moduleDeps = iGraph.edges.map { case (mod, children) =>
+ val moduleDeps = iGraph.getEdgeMap.map({ case (mod, children) =>
mod.module -> children.map(i => i.name -> i.module).toMap
- }
+ })
// Module name to number of instances
val instCount: Map[String, Int] = iGraph.getVertices.groupBy(_.module).mapValues(_.size)
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
index 24c1c51c..578de264 100644
--- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -95,11 +95,11 @@ class DeadCodeElimination extends Transform {
case DefRegister(_, name, _, clock, reset, init) =>
val node = LogicNode(mod.name, name)
depGraph.addVertex(node)
- Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(node, ref))
+ Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(node, ref))
case DefNode(_, name, value) =>
val node = LogicNode(mod.name, name)
depGraph.addVertex(node)
- getDeps(value).foreach(ref => depGraph.addEdge(node, ref))
+ getDeps(value).foreach(ref => depGraph.addPairWithEdge(node, ref))
case DefWire(_, name, _) =>
depGraph.addVertex(LogicNode(mod.name, name))
case mem: DefMemory =>
@@ -111,23 +111,23 @@ class DeadCodeElimination extends Transform {
val sinks = exprs.getOrElse(FEMALE, List.empty).flatMap(getDeps(_))
val memNode = getDeps(memRef) match { case Seq(node) => node }
depGraph.addVertex(memNode)
- sinks.foreach(sink => depGraph.addEdge(sink, memNode))
- sources.foreach(source => depGraph.addEdge(memNode, source))
+ sinks.foreach(sink => depGraph.addPairWithEdge(sink, memNode))
+ sources.foreach(source => depGraph.addPairWithEdge(memNode, source))
case Attach(_, exprs) => // Add edge between each expression
exprs.flatMap(getDeps(_)).toSet.subsets(2).map(_.toList).foreach {
case Seq(a, b) =>
- depGraph.addEdge(a, b)
- depGraph.addEdge(b, a)
+ depGraph.addPairWithEdge(a, b)
+ depGraph.addPairWithEdge(b, a)
}
case Connect(_, loc, expr) =>
// This match enforces the low Firrtl requirement of expanded connections
val node = getDeps(loc) match { case Seq(elt) => elt }
- getDeps(expr).foreach(ref => depGraph.addEdge(node, ref))
+ getDeps(expr).foreach(ref => depGraph.addPairWithEdge(node, ref))
// Simulation constructs are treated as top-level outputs
case Stop(_,_, clk, en) =>
- Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref))
+ Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref))
case Print(_, _, args, clk, en) =>
- (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref))
+ (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref))
case Block(stmts) => stmts.foreach(onStmt(_))
case ignore @ (_: IsInvalid | _: WDefInstance | EmptyStmt) => // do nothing
case other => throw new Exception(s"Unexpected Statement $other")
@@ -153,30 +153,30 @@ class DeadCodeElimination extends Transform {
val node = LogicNode(ext)
// Don't touch external modules *unless* they are specifically marked as doTouch
// Simply marking the extmodule itself is sufficient to prevent inputs from being removed
- if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, node)
+ if (!doTouchExtMods.contains(ext.name)) depGraph.addPairWithEdge(circuitSink, node)
ext.ports.foreach {
case Port(_, pname, _, AnalogType(_)) =>
- depGraph.addEdge(LogicNode(ext.name, pname), node)
- depGraph.addEdge(node, LogicNode(ext.name, pname))
+ depGraph.addPairWithEdge(LogicNode(ext.name, pname), node)
+ depGraph.addPairWithEdge(node, LogicNode(ext.name, pname))
case Port(_, pname, Output, _) =>
val portNode = LogicNode(ext.name, pname)
- depGraph.addEdge(portNode, node)
+ depGraph.addPairWithEdge(portNode, node)
// Also mark all outputs as circuit sinks (unless marked doTouch obviously)
- if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, portNode)
- case Port(_, pname, Input, _) => depGraph.addEdge(node, LogicNode(ext.name, pname))
+ if (!doTouchExtMods.contains(ext.name)) depGraph.addPairWithEdge(circuitSink, portNode)
+ case Port(_, pname, Input, _) => depGraph.addPairWithEdge(node, LogicNode(ext.name, pname))
}
}
// Connect circuitSink to ALL top-level ports (we don't want to change the top-level interface)
val topModule = c.modules.find(_.name == c.main).get
val topOutputs = topModule.ports.foreach { port =>
- depGraph.addEdge(circuitSink, LogicNode(c.main, port.name))
+ depGraph.addPairWithEdge(circuitSink, LogicNode(c.main, port.name))
}
depGraph
}
private def deleteDeadCode(instMap: collection.Map[String, String],
- deadNodes: Set[LogicNode],
+ deadNodes: collection.Set[LogicNode],
moduleMap: collection.Map[String, DefModule],
renames: RenameMap)
(mod: DefModule): Option[DefModule] = {
@@ -243,9 +243,9 @@ class DeadCodeElimination extends Transform {
val c = state.circuit
val moduleMap = c.modules.map(m => m.name -> m).toMap
val iGraph = new InstanceGraph(c)
- val moduleDeps = iGraph.graph.edges.map { case (k,v) =>
+ val moduleDeps = iGraph.graph.getEdgeMap.map({ case (k,v) =>
k.module -> v.map(i => i.name -> i.module).toMap
- }
+ })
val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_))
val depGraph = {
@@ -255,7 +255,7 @@ class DeadCodeElimination extends Transform {
dontTouches.foreach { dontTouch =>
// Ensure that they are actually found
if (vertices.contains(dontTouch)) {
- dGraph.addEdge(circuitSink, dontTouch)
+ dGraph.addPairWithEdge(circuitSink, dontTouch)
} else {
val (root, tail) = Utils.splitRef(dontTouch.e1)
DontTouchAnnotation.errorNotFound(root.serialize, tail.serialize)
diff --git a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala
index 2c12d4ca..6c8a2f20 100644
--- a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala
+++ b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala
@@ -23,6 +23,31 @@ class CheckCombLoopsSpec extends SimpleTransformSpec {
new MiddleFirrtlToLowFirrtl
)
+ "Loop-free circuit" should "not throw an exception" in {
+ val input = """circuit hasnoloops :
+ | module thru :
+ | input in1 : UInt<1>
+ | input in2 : UInt<1>
+ | output out1 : UInt<1>
+ | output out2 : UInt<1>
+ | out1 <= in1
+ | out2 <= in2
+ | module hasnoloops :
+ | input clk : Clock
+ | input a : UInt<1>
+ | output b : UInt<1>
+ | wire x : UInt<1>
+ | inst inner of thru
+ | inner.in1 <= a
+ | x <= inner.out1
+ | inner.in2 <= x
+ | b <= inner.out2
+ |""".stripMargin
+
+ val writer = new java.io.StringWriter
+ compile(CircuitState(parse(input), ChirrtlForm, None), writer)
+ }
+
"Simple combinational loop" should "throw an exception" in {
val input = """circuit hasloops :
| module hasloops :
diff --git a/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala b/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala
index 300fa8c4..3e517079 100644
--- a/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala
+++ b/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala
@@ -11,7 +11,7 @@ import firrtl.passes._
import firrtlTests._
class InstanceGraphTests extends FirrtlFlatSpec {
- private def getEdgeSet(graph: DiGraph[String]): Map[String, Set[String]] = {
+ private def getEdgeSet(graph: DiGraph[String]): collection.Map[String, collection.Set[String]] = {
(graph.getVertices map {v => (v, graph.getEdges(v))}).toMap
}
diff --git a/src/test/scala/firrtlTests/graph/DiGraphTests.scala b/src/test/scala/firrtlTests/graph/DiGraphTests.scala
index 9eb1c7f8..da268e4f 100644
--- a/src/test/scala/firrtlTests/graph/DiGraphTests.scala
+++ b/src/test/scala/firrtlTests/graph/DiGraphTests.scala
@@ -43,8 +43,8 @@ class DiGraphTests extends FirrtlFlatSpec {
a [cyclicGraph.CyclicException] should be thrownBy cyclicGraph.linearize
- acyclicGraph.reverse.edges should equal (reversedAcyclicGraph.edges)
+ acyclicGraph.reverse.getEdgeMap should equal (reversedAcyclicGraph.getEdgeMap)
- degenerateGraph.edges should equal (degenerateGraph.reverse.edges)
+ degenerateGraph.getEdgeMap should equal (degenerateGraph.reverse.getEdgeMap)
}