aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlbert Magyar2017-02-26 16:11:37 -0800
committerAlbert Magyar2017-03-17 14:26:44 -0700
commit13fc27f35e85026c002e644b61c32268bd258d78 (patch)
tree2c3a36c3d7fa8c7bf44c2a48ac7218c539bafac6 /src
parent3608401852baa18b4deaa22669529830b751901a (diff)
Add utilites for digraphs and netlist analyses
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/analyses/Netlist.scala74
-rw-r--r--src/main/scala/firrtl/graph/DiGraph.scala306
-rw-r--r--src/test/scala/firrtlTests/graph/DiGraphTests.scala38
3 files changed, 418 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/analyses/Netlist.scala b/src/main/scala/firrtl/analyses/Netlist.scala
new file mode 100644
index 00000000..c83dcc2b
--- /dev/null
+++ b/src/main/scala/firrtl/analyses/Netlist.scala
@@ -0,0 +1,74 @@
+package firrtl.analyses
+
+import scala.collection.mutable
+
+import firrtl._
+import firrtl.ir._
+import firrtl.graph._
+import firrtl.Utils._
+import firrtl.Mappers._
+
+
+/** A class representing the instance hierarchy of a working IR Circuit
+ *
+ * @constructor constructs an instance graph from a Circuit
+ * @param c the Circuit to analyze
+ */
+class InstanceGraph(c: Circuit) {
+
+ private def collectInstances(insts: mutable.Set[WDefInstance])(s: Statement): Statement = s match {
+ case i: WDefInstance =>
+ insts += i
+ i
+ case _ =>
+ s map collectInstances(insts)
+ }
+
+ private val moduleMap = c.modules.map({m => (m.name,m) }).toMap
+ private val childInstances =
+ new mutable.HashMap[String,mutable.Set[WDefInstance]]
+ for (m <- c.modules) {
+ childInstances(m.name) = new mutable.HashSet[WDefInstance]
+ m map collectInstances(childInstances(m.name))
+ }
+ private val instanceGraph = new MutableDiGraph[WDefInstance]
+ private val instanceQueue = new mutable.Queue[WDefInstance]
+ private val topInstance = WDefInstance(c.main,c.main) // top instance
+ instanceQueue.enqueue(topInstance)
+ while (!instanceQueue.isEmpty) {
+ val current = instanceQueue.dequeue
+ instanceGraph.addVertex(current)
+ for (child <- childInstances(current.module)) {
+ if (!instanceGraph.contains(child)) {
+ instanceQueue.enqueue(child)
+ }
+ instanceGraph.addEdge(current,child)
+ }
+ }
+
+ /** A directed graph showing the instance dependencies among modules
+ * in the circuit. Every WDefInstance of a module has an edge to
+ * every WDefInstance arising from every instance statement in
+ * that module.
+ */
+ lazy val graph = DiGraph(instanceGraph)
+
+ /** A list of absolute paths (each represented by a Seq of instances)
+ * of all module instances in the Circuit.
+ */
+ lazy val fullHierarchy = graph.pathsInDAG(topInstance)
+
+ /** Finds the absolute paths (each represented by a Seq of instances
+ * representing the chain of hierarchy) of all instances of a
+ * particular module.
+ *
+ * @param module the name of the selected module
+ * @return a Seq[Seq[WDefInstance]] of absolute instance paths
+ */
+ def findInstancesInHierarchy(module: String): Seq[Seq[WDefInstance]] = {
+ val instances = graph.getVertices.filter(_.module == module).toSeq
+ instances flatMap { i => fullHierarchy(i) }
+ }
+
+}
+
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala
new file mode 100644
index 00000000..aa93fd5f
--- /dev/null
+++ b/src/main/scala/firrtl/graph/DiGraph.scala
@@ -0,0 +1,306 @@
+package firrtl.graph
+
+import scala.collection.immutable.{Set, Map, HashSet, HashMap}
+import scala.collection.mutable
+import scala.collection.mutable.MultiMap
+
+/** Represents common behavior of all directed graphs */
+trait DiGraphLike[T] {
+ /** Check whether the graph contains vertex v */
+ def contains(v: T): Boolean
+
+ /** Get all vertices in the graph
+ * @return a Set[T] of all vertices in the graph
+ */
+ def getVertices: collection.Set[T]
+
+ /** Get all edges of a node
+ * @param v the specified node
+ * @return a Set[T] of all vertices that v has edges to
+ */
+ def getEdges(v: T): collection.Set[T]
+}
+
+/** A class to represent a mutable directed graph with nodes of type T
+ *
+ * @constructor Create a new graph with the provided edge data
+ * @param edges a mutable.MultiMap[T,T] of edge data
+ *
+ * For the edge data MultiMap, the values associated with each vertex
+ * u in the graph are the vertices with inedges from u
+ */
+class MutableDiGraph[T](
+ private[graph] val edgeData: MultiMap[T,T] =
+ new mutable.HashMap[T, mutable.Set[T]] with MultiMap[T, T]) extends DiGraphLike[T] {
+
+ // Inherited methods from DiGraphLike
+ def contains(v: T) = edgeData.contains(v)
+ def getVertices = edgeData.keySet
+ def getEdges(v: T) = edgeData(v)
+
+ /** Add vertex v to the graph
+ * @return v, the added vertex
+ */
+ def addVertex(v: T): T = {
+ edgeData.getOrElseUpdate(v,new mutable.HashSet[T])
+ v
+ }
+
+ /** Add edge (u,v) to the graph */
+ def addEdge(u: T, v: T) = {
+ // Add v to keys to maintain invariant that all vertices are keys
+ // of edge data
+ edgeData.getOrElseUpdate(v, new mutable.HashSet[T])
+ edgeData.addBinding(u,v)
+ }
+}
+
+/** A companion to create immutable DiGraphs from mutable data */
+object DiGraph {
+ /** Create a DiGraph from a MutableDigraph, representing the same graph */
+ def apply[T](mdg: MutableDiGraph[T]): DiGraph[T] =
+ new DiGraph((mdg.edgeData mapValues { _.toSet }).toMap[T, Set[T]])
+
+ /** Create a DiGraph from a MultiMap[T] of edge data */
+ def apply[T](edgeData: MultiMap[T,T]): DiGraph[T] =
+ new DiGraph((edgeData mapValues { _.toSet }).toMap[T, Set[T]])
+
+ /** Create a DiGraph from a Map[T,Set[T]] of edge data */
+ def apply[T](edgeData: Map[T,Set[T]]) = new DiGraph(edgeData)
+}
+
+/**
+ * A class to represent an immutable directed graph with nodes of
+ * type T
+ *
+ * @constructor Create a new graph with the provided edge data
+ * @param edges a Map[T,Set[T]] of edge data
+ *
+ * For the edge data Map, the value associated with each vertex u in
+ * the graph is a Set[T] of nodes where for each node v in the set,
+ * the directed edge (u,v) exists in the graph.
+ */
+class DiGraph[T] (val edges: Map[T, Set[T]]) extends DiGraphLike[T] {
+
+ /** An exception that is raised when an assumed DAG has a cycle */
+ class CyclicException extends Exception("No valid linearization for cyclic graph")
+ /** An exception that is raised when attempting to find an unreachable node */
+ class PathNotFoundException extends Exception("Unreachable node")
+
+ // Inherited methods from DiGraphLike
+ def contains(v: T) = edges.contains(v)
+ def getVertices = edges.keySet
+ def getEdges(v: T) = edges.getOrElse(v, new HashSet[T])
+
+ /** Find all sources in the graph
+ *
+ * @return a Set[T] of source nodes
+ */
+ def findSources: Set[T] = edges.keySet -- edges.values.flatten.toSet
+
+ /** Find all sinks in the graph
+ *
+ * @return a Set[T] of sink nodes
+ */
+ def findSinks: Set[T] = reverse.findSources
+
+ /** Linearizes (topologically sorts) a DAG
+ *
+ * @param root the start node
+ * @throws CyclicException if the graph is cyclic
+ * @return a Map[T,T] from each visited node to its predecessor in the
+ * traversal
+ */
+ def linearize: Seq[T] = {
+ // permanently marked nodes are implicitly held in order
+ val order = new mutable.ArrayBuffer[T]
+ // invariant: no intersection between unmarked and tempMarked
+ val unmarked = new mutable.HashSet[T]
+ val tempMarked = new mutable.HashSet[T]
+
+ def visit(n: T): Unit = {
+ if (tempMarked.contains(n)) {
+ throw new CyclicException
+ }
+ if (unmarked.contains(n)) {
+ tempMarked += n
+ unmarked -= n
+ for (m <- getEdges(n)) {
+ visit(m)
+ }
+ tempMarked -= n
+ order.append(n)
+ }
+ }
+
+ unmarked ++= getVertices
+ while (!unmarked.isEmpty) {
+ visit(unmarked.head)
+ }
+
+ // visited nodes are in post-traversal order, so must be reversed
+ order.reverse.toSeq
+ }
+
+ /** Performs breadth-first search on the directed graph
+ *
+ * @param root the start node
+ * @return a Map[T,T] from each visited node to its predecessor in the
+ * traversal
+ */
+ def BFS(root: T): Map[T,T] = {
+ val prev = new mutable.HashMap[T,T]
+ val queue = new mutable.Queue[T]
+ queue.enqueue(root)
+ while (!queue.isEmpty) {
+ val u = queue.dequeue
+ for (v <- getEdges(u)) {
+ if (!prev.contains(v)) {
+ prev(v) = u
+ queue.enqueue(v)
+ }
+ }
+ }
+ prev.toMap
+ }
+
+ /** Finds the set of nodes reachable from a particular node
+ *
+ * @param root the start node
+ * @return a Set[T] of nodes reachable from the root
+ */
+ def reachableFrom(root: T): Set[T] = BFS(root).keys.toSet
+
+ /** Finds a path (if one exists) from one node to another
+ *
+ * @param start the start node
+ * @param end the destination node
+ * @throws PathNotFoundException
+ * @return a Seq[T] of nodes defining an arbitrary valid path
+ */
+ def path(start: T, end: T) = {
+ val nodePath = new mutable.ArrayBuffer[T]
+ val prev = BFS(start)
+ nodePath += end
+ while (nodePath.last != start && prev.contains(nodePath.last)) {
+ nodePath += prev(nodePath.last)
+ }
+ if (nodePath.last != start) {
+ throw new PathNotFoundException
+ }
+ nodePath.toSeq.reverse
+ }
+
+ /** Finds the strongly connected components in the graph
+ *
+ * @return a Seq of Seq[T], each containing nodes of an SCC in traversable order
+ */
+ def findSCCs: Seq[Seq[T]] = {
+ var counter: BigInt = 0
+ val stack = new mutable.Stack[T]
+ val onstack = new mutable.HashSet[T]
+ val indices = new mutable.HashMap[T, BigInt]
+ val lowlinks = new mutable.HashMap[T, BigInt]
+ val sccs = new mutable.ArrayBuffer[Seq[T]]
+
+ def strongConnect(v: T): Unit = {
+ indices(v) = counter
+ lowlinks(v) = counter
+ counter = counter + 1
+ stack.push(v)
+ onstack += v
+ for (w <- getEdges(v)) {
+ if (!indices.contains(w)) {
+ strongConnect(w)
+ lowlinks(v) = lowlinks(v).min(lowlinks(w))
+ } else if (onstack.contains(w)) {
+ lowlinks(v) = lowlinks(v).min(indices(w))
+ }
+ }
+ if (lowlinks(v) == indices(v)) {
+ val scc = new mutable.ArrayBuffer[T]
+ do {
+ val w = stack.pop
+ onstack -= w
+ scc += w
+ }
+ while (scc.last != v);
+ sccs.append(scc.toSeq)
+ }
+ }
+
+ for (v <- getVertices) {
+ strongConnect(v)
+ }
+
+ sccs.toSeq
+ }
+
+ /** Finds all paths starting at a particular node in a DAG
+ *
+ * WARNING: This is an exponential time algorithm (as any algorithm
+ * must be for this problem), but is useful for flattening circuit
+ * graph hierarchies. Each path is represented by a Seq[T] of nodes
+ * in a traversable order.
+ *
+ * @param start the node to start at
+ * @return a Map[T,Seq[Seq[T]]] where the value associated with v is the Seq of all paths from start to v
+ */
+ def pathsInDAG(start: T): Map[T,Seq[Seq[T]]] = {
+ // paths(v) holds the set of paths from start to v
+ val paths = new mutable.HashMap[T,mutable.Set[Seq[T]]] with mutable.MultiMap[T,Seq[T]]
+ val queue = new mutable.Queue[T]
+ val visited = new mutable.HashSet[T]
+ paths.addBinding(start,Seq(start))
+ queue.enqueue(start)
+ visited += start
+ while (!queue.isEmpty) {
+ val current = queue.dequeue
+ for (v <- getEdges(current)) {
+ if (!visited.contains(v)) {
+ queue.enqueue(v)
+ visited += v
+ }
+ for (p <- paths(current)) {
+ paths.addBinding(v, p :+ v)
+ }
+ }
+ }
+ (paths map { case (k,v) => (k,v.toSeq) }).toMap
+ }
+
+ /** Returns a graph with all edges reversed */
+ def reverse: DiGraph[T] = {
+ val mdg = new MutableDiGraph[T]
+ edges foreach { case (u,edges) => edges.foreach({ v => mdg.addEdge(v,u) }) }
+ DiGraph(mdg)
+ }
+
+ /** Return a graph with only a subset of the nodes
+ *
+ * Any path between two non-deleted nodes (u,v) that traverses only
+ * deleted nodes will be transformed into an edge (u,v).
+ *
+ * @param vprime the Set[T] of desired vertices
+ * @throws IllegalArgumentException if vprime is not a subset of V
+ * @return the simplified graph
+ */
+ def simplify(vprime: Set[T]): DiGraph[T] = {
+ require(vprime.subsetOf(edges.keySet))
+ val eprime = vprime.map( v => (v,reachableFrom(v) & vprime) ).toMap
+ new DiGraph(eprime)
+ }
+
+ /** Return a graph with all the nodes of the current graph transformed
+ * by a function. Edge connectivity will be the same as the current
+ * graph.
+ *
+ * @param f A function {(T) => Q} that transforms each node
+ * @return a transformed DiGraph[Q]
+ */
+ def transformNodes[Q](f: (T) => Q): DiGraph[Q] = {
+ val eprime = edges.map({ case (k,v) => (f(k),v.map(f(_))) })
+ new DiGraph(eprime)
+ }
+
+}
diff --git a/src/test/scala/firrtlTests/graph/DiGraphTests.scala b/src/test/scala/firrtlTests/graph/DiGraphTests.scala
new file mode 100644
index 00000000..6546e147
--- /dev/null
+++ b/src/test/scala/firrtlTests/graph/DiGraphTests.scala
@@ -0,0 +1,38 @@
+package firrtlTests.graph
+
+import java.io._
+import org.scalatest._
+import org.scalatest.prop._
+import org.scalatest.Matchers._
+import firrtl.graph._
+import firrtlTests._
+
+class DiGraphTests extends FirrtlFlatSpec {
+
+ val acyclicGraph = DiGraph(Map(
+ "a" -> Set("b","c"),
+ "b" -> Set("d"),
+ "c" -> Set("d"),
+ "d" -> Set("e"),
+ "e" -> Set.empty[String]))
+
+ val cyclicGraph = DiGraph(Map(
+ "a" -> Set("b","c"),
+ "b" -> Set("d"),
+ "c" -> Set("d"),
+ "d" -> Set("a")))
+
+
+ acyclicGraph.findSCCs.filter(_.length > 1) shouldBe empty
+
+ cyclicGraph.findSCCs.filter(_.length > 1) should not be empty
+
+ acyclicGraph.path("a","e") should not be empty
+
+ an [acyclicGraph.PathNotFoundException] should be thrownBy acyclicGraph.path("e","a")
+
+ acyclicGraph.linearize.head should equal ("a")
+
+ a [cyclicGraph.CyclicException] should be thrownBy cyclicGraph.linearize
+
+}