aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJiuyang Liu2020-08-01 00:25:13 +0800
committerGitHub2020-07-31 16:25:13 +0000
commitf22652a330afe1daa77be2aadb525d65ab05e9fe (patch)
tree59424ccbe5634993b62a3040f74d077e66ed7c1d /src
parentba2be50f42c1ec760decc22cfda73fbd39113b53 (diff)
[WIP] Implement CircuitGraph and IRLookup to firrtl.analyses (#1603)
* WIP Commit * Add EdgeDataDiGraph with views to amortize graph construction * WIP, got basic structure, need tests to pipeclean * First tests pass. Need more. * Tests pass, more need to be written * More tests pass! Things should work, except for memories * Added clearPrev to fix digraph uses where caching prev breaks * Removed old Component. Documented IRLookup * Added comments. Make prev arg to getEdges * WIP: Refactoring for CircuitGraph * Refactored into CircuitGraph. Can do topological module analysis * Removed old versions * Added support for memories * Added cached test * More stufffff * Added implicit caching of connectivity * Added tests for IRLookup, and others * Many major changes. Replaced CircuitGraph as ConnectionGraph Added CircuitGraph to be top-level user-facing object ConnectionGraph now automatically shortcuts getEdges ConnectionGraph overwrites BFS as PriorityBFS Added leafModule to Target Added lookup by kind to IRLookup Added more tests * Reordered stuff in ConnectionGraph * Made path work with deep hierarchies. Added PML for IllegalClockCrossings * Made pathsInDAG work with current shortcut semantics * Bugfix: check pathless targets when shortcutting paths * Added documentation/licenses * Removed UnnamedToken and related functionality * Added documentation of ConnectionGraph * Added back topo, needed for correct solving of intermediate modules * Bugfix. Cache intermediate clockSources from same BFS with same root, but not BFS with different root * Added literal/invalid clock source, and unknown top for getclocksource * Bugfix for clocks in bundles * Add CompleteTargetSerializer and test * remove ClockFinder, be able to compile. * test is able to compile, but need to fix. * public and abstract DiGraph, remove DiGraphLike. * revert some DiGraph code, ConnectionGraphSpec passed. * CircuitGraphSpec passed. * minimize diff between master * codes clean up * override linearize and revert DiGraph * keep DiGraph unchanged. * make ci happy again. * codes clean up. * bug fix for rebase * remove wir * make scaladoc happy again. * update for review. * add some documentation. * remove tag * wip IRLookup * code clean up and add some doucmentations. * IRLookup cache with ModuleTarget guarded. * make unidoc and 2.13 happy Co-authored-by: Adam Izraelevitz <azidar@gmail.com> Co-authored-by: Albert Magyar <albert.magyar@gmail.com> Co-authored-by: Jack Koenig <koenig@sifive.com>
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/analyses/CircuitGraph.scala137
-rw-r--r--src/main/scala/firrtl/analyses/ConnectionGraph.scala573
-rw-r--r--src/main/scala/firrtl/analyses/IRLookup.scala265
-rw-r--r--src/main/scala/firrtl/analyses/InstanceGraph.scala40
-rw-r--r--src/main/scala/firrtl/annotations/Target.scala99
-rw-r--r--src/main/scala/firrtl/graph/DiGraph.scala21
-rw-r--r--src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala50
-rw-r--r--src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala107
-rw-r--r--src/test/scala/firrtlTests/analyses/IRLookupSpec.scala296
-rw-r--r--src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala2
10 files changed, 1554 insertions, 36 deletions
diff --git a/src/main/scala/firrtl/analyses/CircuitGraph.scala b/src/main/scala/firrtl/analyses/CircuitGraph.scala
new file mode 100644
index 00000000..a5c811c6
--- /dev/null
+++ b/src/main/scala/firrtl/analyses/CircuitGraph.scala
@@ -0,0 +1,137 @@
+// See LICENSE for license details.
+
+package firrtl.analyses
+
+import firrtl.Kind
+import firrtl.annotations.TargetToken.{Instance, OfModule}
+import firrtl.annotations._
+import firrtl.ir.{Circuit, DefInstance}
+
+/** Use to construct [[CircuitGraph]]
+ * Also contains useful related functions
+ */
+object CircuitGraph {
+
+ /** Build a CircuitGraph
+ * [[firrtl.ir.Circuit]] must be of MiddleForm or lower
+ * @param circuit
+ * @return
+ */
+ def apply(circuit: Circuit): CircuitGraph = new CircuitGraph(ConnectionGraph(circuit))
+
+ /** Return a nicely-formatted string of a path of [[firrtl.annotations.ReferenceTarget]]
+ * @param connectionPath
+ * @param tab
+ * @return
+ */
+ def prettyToString(connectionPath: Seq[ReferenceTarget], tab: String = ""): String = {
+ tab + connectionPath.mkString(s"\n$tab")
+ }
+}
+
+/** Graph-representation of a FIRRTL Circuit
+ *
+ * Requires Middle FIRRTL
+ * Useful for writing design-specific custom-transforms that require connectivity information
+ *
+ * @param connectionGraph Source-to-sink connectivity graph
+ */
+class CircuitGraph private[analyses] (val connectionGraph: ConnectionGraph) {
+
+ // Reverse (sink-to-source) connectivity graph
+ lazy val reverseConnectionGraph = connectionGraph.reverseConnectionGraph
+
+ // AST Circuit
+ val circuit = connectionGraph.circuit
+
+ // AST Information
+ val irLookup = connectionGraph.irLookup
+
+ // Module/Instance Hierarchy information
+ lazy val instanceGraph = new InstanceGraph(circuit)
+
+ // Per module, which modules does it instantiate
+ lazy val moduleChildren = instanceGraph.getChildrenInstanceOfModule
+
+ // Top-level module target
+ val main = ModuleTarget(circuit.main, circuit.main)
+
+ /** Given a signal, return the signals that it drives
+ * @param source
+ * @return
+ */
+ def fanOutSignals(source: ReferenceTarget): Set[ReferenceTarget] = connectionGraph.getEdges(source).toSet
+
+ /** Given a signal, return the signals that drive it
+ * @param sink
+ * @return
+ */
+ def fanInSignals(sink: ReferenceTarget): Set[ReferenceTarget] = reverseConnectionGraph.getEdges(sink).toSet
+
+ /** Return the absolute paths of all instances of this module.
+ *
+ * For example:
+ * - Top instantiates a1 of A and a2 of A
+ * - A instantiates b1 of B and b2 of B
+ * Then, absolutePaths of B will return:
+ * - Seq(~Top|Top/a1:A/b1:B, ~Top|Top/a1:A/b2:B, ~Top|Top/a2:A/b1:B, ~Top|Top/a2:A/b2:B)
+ * @param mt
+ * @return
+ */
+ def absolutePaths(mt: ModuleTarget): Seq[IsModule] = instanceGraph.findInstancesInHierarchy(mt.module).map {
+ case seq if seq.nonEmpty => seq.foldLeft(CircuitTarget(circuit.main).module(circuit.main): IsModule) {
+ case (it, DefInstance(_, instance, ofModule, _)) => it.instOf(instance, ofModule)
+ }
+ }
+
+ /** Return the sequence of nodes from source to sink, inclusive
+ * @param source
+ * @param sink
+ * @return
+ */
+ def connectionPath(source: ReferenceTarget, sink: ReferenceTarget): Seq[ReferenceTarget] =
+ connectionGraph.path(source, sink)
+
+ /** Return a reference to all nodes of given kind, directly contained in the referenced module/instance
+ * Path can be either a module, or an instance
+ * @param path
+ * @param kind
+ * @return
+ */
+ def localReferences(path: IsModule, kind: Kind): Seq[ReferenceTarget] = {
+ val leafModule = path.leafModule
+ irLookup.kindFinder(ModuleTarget(circuit.main, leafModule), kind).map(_.setPathTarget(path))
+ }
+
+ /** Return a reference to all nodes of given kind, contained in the referenced module/instance or any child instance
+ * Path can be either a module, or an instance
+ * @param kind
+ * @param path
+ * @return
+ */
+ def deepReferences(kind: Kind, path: IsModule = ModuleTarget(circuit.main, circuit.main)): Seq[ReferenceTarget] = {
+ val leafModule = path.leafModule
+ val children = moduleChildren(leafModule)
+ val localRefs = localReferences(path, kind)
+ localRefs ++ children.flatMap {
+ case (Instance(inst), OfModule(ofModule)) => deepReferences(kind, path.instOf(inst, ofModule))
+ }
+ }
+
+ /** Return all absolute references to signals of the given kind directly contained in the module
+ * @param moduleTarget
+ * @param kind
+ * @return
+ */
+ def absoluteReferences(moduleTarget: ModuleTarget, kind: Kind): Seq[ReferenceTarget] = {
+ localReferences(moduleTarget, kind).flatMap(makeAbsolute)
+ }
+
+ /** Given a reference, return all instances of that reference (i.e. with absolute paths)
+ * @param reference
+ * @return
+ */
+ def makeAbsolute(reference: ReferenceTarget): Seq[ReferenceTarget] = {
+ absolutePaths(reference.moduleTarget).map(abs => reference.setPathTarget(abs))
+ }
+}
diff --git a/src/main/scala/firrtl/analyses/ConnectionGraph.scala b/src/main/scala/firrtl/analyses/ConnectionGraph.scala
new file mode 100644
index 00000000..0e13711a
--- /dev/null
+++ b/src/main/scala/firrtl/analyses/ConnectionGraph.scala
@@ -0,0 +1,573 @@
+// See LICENSE for license details.
+
+package firrtl.analyses
+
+import firrtl.Mappers._
+import firrtl.annotations.{TargetToken, _}
+import firrtl.graph.{CyclicException, DiGraph, MutableDiGraph}
+import firrtl.ir._
+import firrtl.passes.MemPortUtils
+import firrtl.{InstanceKind, PortKind, SinkFlow, SourceFlow, Utils, WInvalid}
+
+import scala.collection.mutable
+
+/** Class to represent circuit connection.
+ *
+ * @param circuit firrtl AST of this graph.
+ * @param digraph Directed graph of ReferenceTarget in the AST.
+ * @param irLookup [[IRLookup]] instance of circuit graph.
+ * */
+class ConnectionGraph protected(val circuit: Circuit,
+ val digraph: DiGraph[ReferenceTarget],
+ val irLookup: IRLookup)
+ extends DiGraph[ReferenceTarget](digraph.getEdgeMap.asInstanceOf[mutable.LinkedHashMap[ReferenceTarget, mutable.LinkedHashSet[ReferenceTarget]]]) {
+
+ lazy val serialize: String = s"""{
+ |${getEdgeMap.map { case (k, vs) =>
+ s""" "$k": {
+ | "kind": "${irLookup.kind(k)}",
+ | "type": "${irLookup.tpe(k)}",
+ | "expr": "${irLookup.expr(k, irLookup.flow(k))}",
+ | "sinks": [${vs.map { v => s""""$v"""" }.mkString(", ")}],
+ | "declaration": "${irLookup.declaration(k)}"
+ | }""".stripMargin }.mkString(",\n")}
+ |}""".stripMargin
+
+ /** Used by BFS to map each visited node to the list of instance inputs visited thus far
+ *
+ * When BFS descends into a child instance, the child instance port is prepended to the list
+ * When BFS ascends into a parent instance, the head of the list is removed
+ * In essence, the list is a stack that you push when descending, and pop when ascending
+ *
+ * Because the search is BFS not DFS, we must record the state of the stack for each edge node, so
+ * when that edge node is finally visited, we know the state of the stack
+ *
+ * For example:
+ * circuit Top:
+ * module Top:
+ * input in: UInt
+ * output out: UInt
+ * inst a of A
+ * a.in <= in
+ * out <= a.out
+ * module A:
+ * input in: UInt
+ * output out: UInt
+ * inst b of B
+ * b.in <= in
+ * out <= b.out
+ * module B:
+ * input in: UInt
+ * output out: UInt
+ * out <= in
+ *
+ * We perform BFS starting at `Top>in`,
+ * Node [[portConnectivityStack]]
+ *
+ * Top>in List()
+ * Top>a.in List()
+ * Top/a:A>in List(Top>a.in)
+ * Top/a:A>b.in List(Top>a.in)
+ * Top/a:A/b:B/in List(Top/a:A>b.in, Top>a.in)
+ * Top/a:A/b:B/out List(Top/a:A>b.in, Top>a.in)
+ * Top/a:A>b.out List(Top>a.in)
+ * Top/a:A>out List(Top>a.in)
+ * Top>a.out List()
+ * Top>out List()
+ * when we reach `Top/a:A>`, `Top>a.in` will be pushed into [[portConnectivityStack]];
+ * when we reach `Top/a:A/b:B>`, `Top/a:A>b.in` will be pushed into [[portConnectivityStack]];
+ * when we leave `Top/a:A/b:B>`, `Top/a:A>b.in` will be popped from [[portConnectivityStack]];
+ * when we leave `Top/a:A>`, `Top/a:A>b.in` will be popped from [[portConnectivityStack]].
+ */
+ private val portConnectivityStack: mutable.HashMap[ReferenceTarget, List[ReferenceTarget]] =
+ mutable.HashMap.empty[ReferenceTarget, List[ReferenceTarget]]
+
+ /** Records connectivities found while BFS is executing, from a module's source port to sink ports of a module
+ *
+ * All keys and values are local references.
+ *
+ * A BFS search will first query this map. If the query fails, then it continues and populates the map. If the query
+ * succeeds, then the BFS shortcuts with the values provided by the query.
+ *
+ * Because this BFS implementation uses a priority queue which prioritizes exploring deeper instances first, a
+ * successful query during BFS will only occur after all paths which leave the module from that reference have
+ * already been searched.
+ */
+ private val bfsShortCuts: mutable.HashMap[ReferenceTarget, mutable.HashSet[ReferenceTarget]] =
+ mutable.HashMap.empty[ReferenceTarget, mutable.HashSet[ReferenceTarget]]
+
+ /** Records connectivities found after BFS is completed, from a module's source port to sink ports of a module
+ *
+ * All keys and values are local references.
+ *
+ * If its keys contain a reference, then the value will be complete, in that all paths from the reference out of
+ * the module will have been explored
+ *
+ * For example, if Top>in connects to Top>out1 and Top>out2, then foundShortCuts(Top>in) will contain
+ * Set(Top>out1, Top>out2), not Set(Top>out1) or Set(Top>out2)
+ */
+ private val foundShortCuts: mutable.HashMap[ReferenceTarget, mutable.HashSet[ReferenceTarget]] =
+ mutable.HashMap.empty[ReferenceTarget, mutable.HashSet[ReferenceTarget]]
+
+ /** Returns whether a previous BFS search has found a shortcut out of a module, starting from target
+ *
+ * @param target first target to find shortcut.
+ * @return true if find a shortcut.
+ */
+ def hasShortCut(target: ReferenceTarget): Boolean = getShortCut(target).nonEmpty
+
+ /** Optionally returns the shortcut a previous BFS search may have found out of a module, starting from target
+ *
+ * @param target first target to find shortcut.
+ * @return [[firrtl.annotations.ReferenceTarget]] of short cut.
+ */
+ def getShortCut(target: ReferenceTarget): Option[Set[ReferenceTarget]] =
+ foundShortCuts.get(target.pathlessTarget).map(set => set.map(_.setPathTarget(target.pathTarget)).toSet)
+
+ /** Returns the shortcut a previous BFS search may have found out of a module, starting from target
+ *
+ * @param target first target to find shortcut.
+ * @return [[firrtl.annotations.ReferenceTarget]] of short cut.
+ */
+ def shortCut(target: ReferenceTarget): Set[ReferenceTarget] = getShortCut(target).get
+
+ /** @return a new, reversed connection graph where edges point from sinks to sources. */
+ def reverseConnectionGraph: ConnectionGraph = new ConnectionGraph(circuit, digraph.reverse, irLookup)
+
+ override def BFS(root: ReferenceTarget, blacklist: collection.Set[ReferenceTarget]): collection.Map[ReferenceTarget, ReferenceTarget] = {
+ val prev = new mutable.LinkedHashMap[ReferenceTarget, ReferenceTarget]()
+ val ordering = new Ordering[ReferenceTarget] {
+ override def compare(x: ReferenceTarget, y: ReferenceTarget): Int = x.path.size - y.path.size
+ }
+ val bfsQueue = new mutable.PriorityQueue[ReferenceTarget]()(ordering)
+ bfsQueue.enqueue(root)
+ while (bfsQueue.nonEmpty) {
+ val u = bfsQueue.dequeue
+ for (v <- getEdges(u)) {
+ if (!prev.contains(v) && !blacklist.contains(v)) {
+ prev(v) = u
+ bfsQueue.enqueue(v)
+ }
+ }
+ }
+
+ foundShortCuts ++= bfsShortCuts
+ bfsShortCuts.clear()
+ portConnectivityStack.clear()
+
+ prev
+ }
+
+ /** Linearizes (topologically sorts) a DAG
+ *
+ * @throws firrtl.graph.CyclicException if the graph is cyclic
+ * @return a Seq[T] describing the topological order of the DAG
+ * traversal
+ */
+ override def linearize: Seq[ReferenceTarget] = {
+ // permanently marked nodes are implicitly held in order
+ val order = new mutable.ArrayBuffer[ReferenceTarget]
+ // invariant: no intersection between unmarked and tempMarked
+ val unmarked = new mutable.LinkedHashSet[ReferenceTarget]
+ val tempMarked = new mutable.LinkedHashSet[ReferenceTarget]
+ val finished = new mutable.LinkedHashSet[ReferenceTarget]
+
+ case class LinearizeFrame[A](v: A, expanded: Boolean)
+ val callStack = mutable.Stack[LinearizeFrame[ReferenceTarget]]()
+
+ unmarked ++= getVertices
+ while (unmarked.nonEmpty) {
+ callStack.push(LinearizeFrame(unmarked.head, false))
+ while (callStack.nonEmpty) {
+ val LinearizeFrame(n, expanded) = callStack.pop()
+ if (!expanded) {
+ if (tempMarked.contains(n)) {
+ throw new CyclicException(n)
+ }
+ if (unmarked.contains(n)) {
+ tempMarked += n
+ unmarked -= n
+ callStack.push(LinearizeFrame(n, true))
+ // We want to visit the first edge first (so push it last)
+ for (m <- getEdges(n).toSeq.reverse) {
+ if (!unmarked.contains(m) && !tempMarked.contains(m) && !finished.contains(m)) {
+ unmarked += m
+ }
+ callStack.push(LinearizeFrame(m, false))
+ }
+ }
+ } else {
+ tempMarked -= n
+ finished += n
+ order.append(n)
+ }
+ }
+ }
+
+ // visited nodes are in post-traversal order, so must be reversed
+ order.toSeq.reverse
+ }
+
+ override def getEdges(source: ReferenceTarget): collection.Set[ReferenceTarget] = {
+ import ConnectionGraph._
+
+ val localSource = source.pathlessTarget
+
+ bfsShortCuts.get(localSource) match {
+ case Some(set) => set.map { x => x.setPathTarget(source.pathTarget) }
+ case None =>
+
+ val pathlessEdges = super.getEdges(localSource)
+
+ val ret = pathlessEdges.flatMap {
+
+ case localSink if withinSameInstance(source)(localSink) =>
+ portConnectivityStack(localSink) = portConnectivityStack.getOrElse(localSource, Nil)
+ Set[ReferenceTarget](localSink.setPathTarget(source.pathTarget))
+
+ case localSink if enteringParentInstance(source)(localSink) =>
+ val currentStack = portConnectivityStack.getOrElse(localSource, Nil)
+ if (currentStack.nonEmpty && currentStack.head.module == localSink.module) {
+ // Exiting back to parent module
+ // Update shortcut path from entrance from parent to new exit to parent
+ val instancePort = currentStack.head
+ val modulePort = ReferenceTarget(
+ localSource.circuit,
+ localSource.module,
+ Nil,
+ instancePort.component.head.value.toString,
+ instancePort.component.tail
+ )
+ val destinations = bfsShortCuts.getOrElse(modulePort, mutable.HashSet.empty[ReferenceTarget])
+ bfsShortCuts(modulePort) = destinations + localSource
+ // Remove entrance from parent from stack
+ portConnectivityStack(localSink) = currentStack.tail
+ } else {
+ // Exiting to parent, but had unresolved trip through child, so don't update shortcut
+ portConnectivityStack(localSink) = localSource +: currentStack
+ }
+ Set[ReferenceTarget](localSink.setPathTarget(source.noComponents.targetParent.asInstanceOf[IsComponent].pathTarget))
+
+ case localSink if enteringChildInstance(source)(localSink) =>
+ portConnectivityStack(localSink) = localSource +: portConnectivityStack.getOrElse(localSource, Nil)
+ val x = localSink.setPathTarget(source.pathTarget.instOf(source.ref, localSink.module))
+ Set[ReferenceTarget](x)
+
+ case localSink if leavingRootInstance(source)(localSink) => Set[ReferenceTarget]()
+
+ case localSink if enteringNonParentInstance(source)(localSink) => Set[ReferenceTarget]()
+
+ case other => Utils.throwInternalError(s"BAD? $source -> $other")
+
+ }
+ ret
+ }
+
+ }
+
+ override def path(start: ReferenceTarget, end: ReferenceTarget, blacklist: collection.Set[ReferenceTarget]): Seq[ReferenceTarget] = {
+ insertShortCuts(super.path(start, end, blacklist))
+ }
+
+ private def insertShortCuts(path: Seq[ReferenceTarget]): Seq[ReferenceTarget] = {
+ val soFar = mutable.HashSet[ReferenceTarget]()
+ if (path.size > 1) {
+ path.head +: path.sliding(2).flatMap {
+ case Seq(from, to) =>
+ getShortCut(from) match {
+ case Some(set) if set.contains(to) && soFar.contains(from.pathlessTarget) =>
+ soFar += from.pathlessTarget
+ Seq(from.pathTarget.ref("..."), to)
+ case _ =>
+ soFar += from.pathlessTarget
+ Seq(to)
+ }
+ }.toSeq
+ } else path
+ }
+
+ /** Finds all paths starting at a particular node in a DAG
+ *
+ * WARNING: This is an exponential time algorithm (as any algorithm
+ * must be for this problem), but is useful for flattening circuit
+ * graph hierarchies. Each path is represented by a Seq[T] of nodes
+ * in a traversable order.
+ *
+ * @param start the node to start at
+ * @return a `Map[T,Seq[Seq[T]]]` where the value associated with v is the Seq of all paths from start to v
+ */
+ override def pathsInDAG(start: ReferenceTarget): mutable.LinkedHashMap[ReferenceTarget, Seq[Seq[ReferenceTarget]]] = {
+ val linkedMap = super.pathsInDAG(start)
+ linkedMap.keysIterator.foreach { key =>
+ linkedMap(key) = linkedMap(key).map(insertShortCuts)
+ }
+ linkedMap
+ }
+
+ override def findSCCs: Seq[Seq[ReferenceTarget]] = Utils.throwInternalError("Cannot call findSCCs on ConnectionGraph")
+}
+
+object ConnectionGraph {
+
+ /** Returns a [[firrtl.graph.DiGraph]] of [[firrtl.annotations.Target]] and corresponding [[IRLookup]]
+ * Represents the directed connectivity of a FIRRTL circuit
+ *
+ * @param circuit firrtl AST of graph to be constructed.
+ * @return [[ConnectionGraph]] of this `circuit`.
+ */
+ def apply(circuit: Circuit): ConnectionGraph = buildCircuitGraph(circuit)
+
+ /** Within a module, given an [[firrtl.ir.Expression]] inside a module, return a corresponding [[firrtl.annotations.ReferenceTarget]]
+ * @todo why no subaccess.
+ *
+ * @param m Target of module containing the expression
+ * @param e
+ * @return
+ */
+ def asTarget(m: ModuleTarget, tagger: TokenTagger)(e: FirrtlNode): ReferenceTarget = e match {
+ case l: Literal => m.ref(tagger.getRef(l.value.toString))
+ case r: Reference => m.ref(r.name)
+ case s: SubIndex => asTarget(m, tagger)(s.expr).index(s.value)
+ case s: SubField => asTarget(m, tagger)(s.expr).field(s.name)
+ case d: DoPrim => m.ref(tagger.getRef(d.op.serialize))
+ case _: Mux => m.ref(tagger.getRef("mux"))
+ case _: ValidIf => m.ref(tagger.getRef("validif"))
+ case WInvalid => m.ref(tagger.getRef("invalid"))
+ case _: Print => m.ref(tagger.getRef("print"))
+ case _: Stop => m.ref(tagger.getRef("print"))
+ case other => sys.error(s"Unsupported: $other")
+ }
+
+ def withinSameInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = {
+ source.encapsulatingModule == localSink.encapsulatingModule
+ }
+
+ def enteringParentInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = {
+ def b1 = source.path.nonEmpty
+
+ def b2 = source.noComponents.targetParent.asInstanceOf[InstanceTarget].encapsulatingModule == localSink.module
+
+ def b3 = localSink.ref == source.path.last._1.value
+
+ b1 && b2 && b3
+ }
+
+ def enteringNonParentInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = {
+ source.path.nonEmpty &&
+ (source.noComponents.targetParent.asInstanceOf[InstanceTarget].encapsulatingModule != localSink.module ||
+ localSink.ref != source.path.last._1.value)
+ }
+
+ def enteringChildInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = source match {
+ case ReferenceTarget(_, _, _, _, TargetToken.Field(port) +: comps)
+ if port == localSink.ref && comps == localSink.component => true
+ case _ => false
+ }
+
+ def leavingRootInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = source match {
+ case ReferenceTarget(_, _, Seq(), port, comps)
+ if port == localSink.component.head.value && comps == localSink.component.tail => true
+ case _ => false
+ }
+
+
+ private def buildCircuitGraph(circuit: Circuit): ConnectionGraph = {
+ val mdg = new MutableDiGraph[ReferenceTarget]()
+ val declarations = mutable.LinkedHashMap[ModuleTarget, mutable.LinkedHashMap[ReferenceTarget, FirrtlNode]]()
+ val circuitTarget = CircuitTarget(circuit.main)
+ val moduleMap = circuit.modules.map { m => circuitTarget.module(m.name) -> m }.toMap
+
+ circuit map buildModule(circuitTarget)
+
+ def addLabeledVertex(v: ReferenceTarget, f: FirrtlNode): Unit = {
+ mdg.addVertex(v)
+ declarations.getOrElseUpdate(v.moduleTarget, mutable.LinkedHashMap.empty[ReferenceTarget, FirrtlNode])(v) = f
+ }
+
+ def buildModule(c: CircuitTarget)(module: DefModule): DefModule = {
+ val m = c.module(module.name)
+ module map buildPort(m) map buildStatement(m, new TokenTagger())
+ }
+
+ def buildPort(m: ModuleTarget)(port: Port): Port = {
+ val p = m.ref(port.name)
+ addLabeledVertex(p, port)
+ port
+ }
+
+ def buildInstance(m: ModuleTarget, tagger: TokenTagger, name: String, ofModule: String, tpe: Type): Unit = {
+ val instPorts = Utils.create_exps(Reference(name, tpe, InstanceKind, SinkFlow))
+ val modulePorts = tpe.asInstanceOf[BundleType].fields.flatMap {
+ // Module output
+ case firrtl.ir.Field(name, Default, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SourceFlow))
+ // Module input
+ case firrtl.ir.Field(name, Flip, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SinkFlow))
+ }
+ assert(instPorts.size == modulePorts.size)
+ val o = m.circuitTarget.module(ofModule)
+ instPorts.zip(modulePorts).foreach { x =>
+ val (instExp, modExp) = x
+ val it = asTarget(m, tagger)(instExp)
+ val mt = asTarget(o, tagger)(modExp)
+ (Utils.flow(instExp), Utils.flow(modExp)) match {
+ case (SourceFlow, SinkFlow) => mdg.addPairWithEdge(it, mt)
+ case (SinkFlow, SourceFlow) => mdg.addPairWithEdge(mt, it)
+ case _ => sys.error("Something went wrong...")
+ }
+ }
+ }
+
+ def buildMemory(mt: ModuleTarget, d: DefMemory): Unit = {
+ val readers = d.readers.toSet
+ val readwriters = d.readwriters.toSet
+ val mem = mt.ref(d.name)
+ MemPortUtils.memType(d).fields.foreach {
+ case Field(name, _, _: BundleType) if readers.contains(name) || readwriters.contains(name) =>
+ val port = mem.field(name)
+ val sources = Seq(
+ port.field("clk"),
+ port.field("en"),
+ port.field("addr")
+ ) ++ (if (readwriters.contains(name)) Seq(port.field("wmode")) else Nil)
+
+ val data = if (readers.contains(name)) port.field("data") else port.field("rdata")
+ val sinks = data.leafSubTargets(d.dataType)
+
+ sources.foreach {
+ mdg.addVertex
+ }
+ sinks.foreach { sink =>
+ mdg.addVertex(sink)
+ sources.foreach { source => mdg.addEdge(source, sink) }
+ }
+ case _ =>
+ }
+ }
+
+ def buildRegister(m: ModuleTarget, tagger: TokenTagger, d: DefRegister): Unit = {
+ val regTarget = m.ref(d.name)
+ val clockTarget = regTarget.clock
+ val resetTarget = regTarget.reset
+ val initTarget = regTarget.init
+
+ // Build clock expression
+ mdg.addVertex(clockTarget)
+ buildExpression(m, tagger, clockTarget)(d.clock)
+
+ // Build reset expression
+ mdg.addVertex(resetTarget)
+ buildExpression(m, tagger, resetTarget)(d.reset)
+
+ // Connect each subTarget to the corresponding init subTarget
+ val allRegTargets = regTarget.leafSubTargets(d.tpe)
+ val allInitTargets = initTarget.leafSubTargets(d.tpe).zip(Utils.create_exps(d.init))
+ allRegTargets.zip(allInitTargets).foreach { case (r, (i, e)) =>
+ mdg.addVertex(i)
+ mdg.addVertex(r)
+ mdg.addEdge(clockTarget, r)
+ mdg.addEdge(resetTarget, r)
+ mdg.addEdge(i, r)
+ buildExpression(m, tagger, i)(e)
+ }
+ }
+
+ def buildStatement(m: ModuleTarget, tagger: TokenTagger)(stmt: Statement): Statement = {
+ stmt match {
+ case d: DefWire =>
+ addLabeledVertex(m.ref(d.name), stmt)
+
+ case d: DefNode =>
+ val sinkTarget = m.ref(d.name)
+ addLabeledVertex(sinkTarget, stmt)
+ val nodeTargets = sinkTarget.leafSubTargets(d.value.tpe)
+ nodeTargets.zip(Utils.create_exps(d.value)).foreach { case (n, e) =>
+ mdg.addVertex(n)
+ buildExpression(m, tagger, n)(e)
+ }
+
+ case c: Connect =>
+ val sinkTarget = asTarget(m, tagger)(c.loc)
+ mdg.addVertex(sinkTarget)
+ buildExpression(m, tagger, sinkTarget)(c.expr)
+
+ case i: IsInvalid =>
+ val sourceTarget = asTarget(m, tagger)(WInvalid)
+ addLabeledVertex(sourceTarget, stmt)
+ mdg.addVertex(sourceTarget)
+ val sinkTarget = asTarget(m, tagger)(i.expr)
+ sinkTarget.allSubTargets(i.expr.tpe).foreach { st =>
+ mdg.addVertex(st)
+ mdg.addEdge(sourceTarget, st)
+ }
+
+ case DefInstance(_, name, ofModule, tpe) =>
+ addLabeledVertex(m.ref(name), stmt)
+ buildInstance(m, tagger, name, ofModule, tpe)
+
+ case d: DefRegister =>
+ addLabeledVertex(m.ref(d.name), d)
+ buildRegister(m, tagger, d)
+
+ case d: DefMemory =>
+ addLabeledVertex(m.ref(d.name), d)
+ buildMemory(m, d)
+
+ /** @todo [[firrtl.Transform.prerequisites]] ++ [[firrtl.passes.ExpandWhensAndCheck]]*/
+ case _: Conditionally => sys.error("Unsupported! Only works on Middle Firrtl")
+
+ case s: Block => s map buildStatement(m, tagger)
+
+ case a: Attach =>
+ val attachTargets = a.exprs.map { r =>
+ val at = asTarget(m, tagger)(r)
+ mdg.addVertex(at)
+ at
+ }
+ attachTargets.combinations(2).foreach { case Seq(l, r) =>
+ mdg.addEdge(l, r)
+ mdg.addEdge(r, l)
+ }
+ case p: Print => addLabeledVertex(asTarget(m, tagger)(p), p)
+ case s: Stop => addLabeledVertex(asTarget(m, tagger)(s), s)
+ case EmptyStmt =>
+ }
+ stmt
+ }
+
+ def buildExpression(m: ModuleTarget, tagger: TokenTagger, sinkTarget: ReferenceTarget)(expr: Expression): Expression = {
+ /** @todo [[firrtl.Transform.prerequisites]] ++ [[firrtl.stage.Forms.Resolved]]. */
+ val sourceTarget = asTarget(m, tagger)(expr)
+ mdg.addVertex(sourceTarget)
+ mdg.addEdge(sourceTarget, sinkTarget)
+ expr match {
+ case _: DoPrim | _: Mux | _: ValidIf | _: Literal =>
+ addLabeledVertex(sourceTarget, expr)
+ expr map buildExpression(m, tagger, sourceTarget)
+ case _ =>
+ }
+ expr
+ }
+
+ new ConnectionGraph(circuit, DiGraph(mdg), new IRLookup(declarations.mapValues(_.toMap).toMap, moduleMap))
+ }
+}
+
+
+/** Used for obtaining a tag for a given label unnamed Target. */
+class TokenTagger {
+ private val counterMap = mutable.HashMap[String, Int]()
+
+ def getTag(label: String): Int = {
+ val tag = counterMap.getOrElse(label, 0)
+ counterMap(label) = tag + 1
+ tag
+ }
+
+ def getRef(label: String): String = {
+ "@" + label + "#" + getTag(label)
+ }
+}
+
+object TokenTagger {
+ val literalRegex = "@([-]?[0-9]+)#[0-9]+".r
+}
diff --git a/src/main/scala/firrtl/analyses/IRLookup.scala b/src/main/scala/firrtl/analyses/IRLookup.scala
new file mode 100644
index 00000000..f9819ebd
--- /dev/null
+++ b/src/main/scala/firrtl/analyses/IRLookup.scala
@@ -0,0 +1,265 @@
+// See LICENSE for license details.
+
+package firrtl.analyses
+
+import firrtl.annotations.TargetToken._
+import firrtl.annotations._
+import firrtl.ir._
+import firrtl.passes.MemPortUtils
+import firrtl.{DuplexFlow, ExpKind, Flow, InstanceKind, Kind, MemKind, PortKind, RegKind, SinkFlow, SourceFlow, UnknownFlow, Utils, WInvalid, WireKind}
+
+import scala.collection.mutable
+
+object IRLookup {
+ def apply(circuit: Circuit): IRLookup = ConnectionGraph(circuit).irLookup
+}
+
+/** Handy lookup for obtaining AST information about a given Target
+ *
+ * @param declarations Maps references (not subreferences) to declarations
+ * @param modules Maps module targets to modules
+ */
+class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map[ReferenceTarget, FirrtlNode]],
+ private val modules: Map[ModuleTarget, DefModule]) {
+
+ private val flowCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Flow]]()
+ private val kindCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Kind]]()
+ private val tpeCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Type]]()
+ private val exprCache = mutable.HashMap[ModuleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]]()
+ private val refCache = mutable.HashMap[ModuleTarget, mutable.LinkedHashMap[Kind, mutable.ArrayBuffer[ReferenceTarget]]]()
+
+
+ /** @example Given ~Top|MyModule/inst:Other>foo.bar, returns ~Top|Other>foo
+ * @return the target converted to its local reference
+ */
+ def asLocalRef(t: ReferenceTarget): ReferenceTarget = t.pathlessTarget.copy(component = Nil)
+
+ def flow(t: ReferenceTarget): Flow = flowCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Flow]()).getOrElseUpdate(t.pathlessTarget, Utils.flow(expr(t.pathlessTarget)))
+
+ def kind(t: ReferenceTarget): Kind = kindCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Kind]()).getOrElseUpdate(t.pathlessTarget, Utils.kind(expr(t.pathlessTarget)))
+
+ def tpe(t: ReferenceTarget): Type = tpeCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Type]()).getOrElseUpdate(t.pathlessTarget, expr(t.pathlessTarget).tpe)
+
+ /** get expression of the target.
+ * It can return None for many reasons, including
+ * - declaration is missing
+ * - flow is wrong
+ * - component is wrong
+ *
+ * @param t [[firrtl.annotations.ReferenceTarget]] to be queried.
+ * @param flow flow of the target
+ * @return Some(e) if expression exists, None if it does not
+ */
+ def getExpr(t: ReferenceTarget, flow: Flow): Option[Expression] = {
+ val pathless = t.pathlessTarget
+
+ inCache(pathless, flow) match {
+ case e@Some(_) => return e
+ case None =>
+ val mt = pathless.moduleTarget
+ val emt = t.encapsulatingModuleTarget
+ if (declarations.contains(emt) && declarations(emt).contains(asLocalRef(t))) {
+ declarations(emt)(asLocalRef(t)) match {
+ case e: Expression =>
+ require(e.tpe.isInstanceOf[GroundType])
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).getOrElseUpdate((pathless, Utils.flow(e)), e)
+ case d: IsDeclaration => d match {
+ case n: DefNode =>
+ updateExpr(mt, Reference(n.name, n.value.tpe, ExpKind, SourceFlow))
+ case p: Port =>
+ updateExpr(mt, Reference(p.name, p.tpe, PortKind, Utils.get_flow(p)))
+ case w: DefInstance =>
+ updateExpr(mt, Reference(w.name, w.tpe, InstanceKind, SourceFlow))
+ case w: DefWire =>
+ updateExpr(mt, Reference(w.name, w.tpe, WireKind, SourceFlow))
+ updateExpr(mt, Reference(w.name, w.tpe, WireKind, SinkFlow))
+ updateExpr(mt, Reference(w.name, w.tpe, WireKind, DuplexFlow))
+ case r: DefRegister if pathless.tokens.last == Clock =>
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.clock
+ case r: DefRegister if pathless.tokens.isDefinedAt(1) && pathless.tokens(1) == Init =>
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.init
+ updateExpr(pathless, r.init)
+ case r: DefRegister if pathless.tokens.last == Reset =>
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.reset
+ case r: DefRegister =>
+ updateExpr(mt, Reference(r.name, r.tpe, RegKind, SourceFlow))
+ updateExpr(mt, Reference(r.name, r.tpe, RegKind, SinkFlow))
+ updateExpr(mt, Reference(r.name, r.tpe, RegKind, DuplexFlow))
+ case m: DefMemory =>
+ updateExpr(mt, Reference(m.name, MemPortUtils.memType(m), MemKind, SourceFlow))
+ case other =>
+ sys.error(s"Cannot call expr with: $t, given declaration $other")
+ }
+ case _: IsInvalid =>
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = WInvalid
+ }
+ }
+ }
+
+ inCache(pathless, flow)
+ }
+
+ /**
+ * @param t [[firrtl.annotations.ReferenceTarget]] to be queried.
+ * @param flow flow of the target
+ * @return expression of `t`
+ */
+ def expr(t: ReferenceTarget, flow: Flow = UnknownFlow): Expression = {
+ require(contains(t), s"Cannot find\n${t.prettyPrint()}\nin circuit!")
+ getExpr(t, flow) match {
+ case Some(e) => e
+ case None =>
+ require(getExpr(t.pathlessTarget, UnknownFlow).isEmpty, s"Illegal flow $flow with target $t")
+ sys.error("")
+ }
+ }
+
+ /** Find [[firrtl.annotations.ReferenceTarget]] with a specific [[firrtl.Kind]] in a [[firrtl.annotations.ModuleTarget]]
+ *
+ * @param moduleTarget [[firrtl.annotations.ModuleTarget]] to be queried.
+ * @param kind [[firrtl.Kind]] to be find.
+ * @return all [[firrtl.annotations.ReferenceTarget]] in this node. */
+ def kindFinder(moduleTarget: ModuleTarget, kind: Kind): Seq[ReferenceTarget] = {
+ def updateRefs(kind: Kind, rt: ReferenceTarget): Unit = refCache
+ .getOrElseUpdate(rt.moduleTarget, mutable.LinkedHashMap.empty[Kind, mutable.ArrayBuffer[ReferenceTarget]])
+ .getOrElseUpdate(kind, mutable.ArrayBuffer.empty[ReferenceTarget]) += rt
+
+ require(contains(moduleTarget), s"Cannot find\n${moduleTarget.prettyPrint()}\nin circuit!")
+ if (refCache.contains(moduleTarget) && refCache(moduleTarget).contains(kind)) refCache(moduleTarget)(kind).toSeq
+ else {
+ declarations(moduleTarget).foreach {
+ case (rt, _: DefRegister) => updateRefs(RegKind, rt)
+ case (rt, _: DefWire) => updateRefs(WireKind, rt)
+ case (rt, _: DefNode) => updateRefs(ExpKind, rt)
+ case (rt, _: DefMemory) => updateRefs(MemKind, rt)
+ case (rt, _: DefInstance) => updateRefs(InstanceKind, rt)
+ case (rt, _: Port) => updateRefs(PortKind, rt)
+ case _ =>
+ }
+ refCache.get(moduleTarget).map(_.getOrElse(kind, Seq.empty[ReferenceTarget])).getOrElse(Seq.empty[ReferenceTarget]).toSeq
+ }
+ }
+
+ /**
+ * @param t [[firrtl.annotations.ReferenceTarget]] to be queried.
+ * @return the statement containing the declaration of the target
+ */
+ def declaration(t: ReferenceTarget): FirrtlNode = {
+ require(contains(t), s"Cannot find\n${t.prettyPrint()}\nin circuit!")
+ declarations(t.encapsulatingModuleTarget)(asLocalRef(t))
+ }
+
+ /** Returns the references to the module's ports
+ *
+ * @param mt [[firrtl.annotations.ModuleTarget]] to be queried.
+ * @return the port references of `mt`
+ */
+ def ports(mt: ModuleTarget): Seq[ReferenceTarget] = {
+ require(contains(mt), s"Cannot find\n${mt.prettyPrint()}\nin circuit!")
+ modules(mt).ports.map { p => mt.ref(p.name) }
+ }
+
+ /** Given:
+ * A [[firrtl.annotations.ReferenceTarget]] of ~Top|Module>ref, which is a type of {foo: {bar: UInt}}
+ * Return:
+ * Seq(~Top|Module>ref, ~Top|Module>ref.foo, ~Top|Module>ref.foo.bar)
+ *
+ * @return a target to each sub-component, including intermediate subcomponents
+ */
+ def allTargets(r: ReferenceTarget): Seq[ReferenceTarget] = r.allSubTargets(tpe(r))
+
+ /** Given:
+ * A [[firrtl.annotations.ReferenceTarget]] of ~Top|Module>ref and a type of {foo: {bar: UInt}}
+ * Return:
+ * Seq(~Top|Module>ref.foo.bar)
+ *
+ * @return a target to each sub-component, excluding intermediate subcomponents.
+ */
+ def leafTargets(r: ReferenceTarget): Seq[ReferenceTarget] = r.leafSubTargets(tpe(r))
+
+ /** @return Returns ((inputs, outputs)) target and type of each module port. */
+ def moduleLeafPortTargets(m: ModuleTarget): (Seq[(ReferenceTarget, Type)], Seq[(ReferenceTarget, Type)]) =
+ modules(m).ports.flatMap {
+ case Port(_, name, Output, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SourceFlow))
+ case Port(_, name, Input, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SinkFlow))
+ }.foldLeft((Vector.empty[(ReferenceTarget, Type)], Vector.empty[(ReferenceTarget, Type)])) {
+ case ((inputs, outputs), e) if Utils.flow(e) == SourceFlow =>
+ (inputs, outputs :+ (ConnectionGraph.asTarget(m, new TokenTagger())(e), e.tpe))
+ case ((inputs, outputs), e) =>
+ (inputs :+ (ConnectionGraph.asTarget(m, new TokenTagger())(e), e.tpe), outputs)
+ }
+
+
+ /** @param t [[firrtl.annotations.ReferenceTarget]] to be queried.
+ * @return whether a ReferenceTarget is contained in this IRLookup
+ */
+ def contains(t: ReferenceTarget): Boolean = validPath(t.pathTarget) &&
+ declarations.contains(t.encapsulatingModuleTarget) &&
+ declarations(t.encapsulatingModuleTarget).contains(asLocalRef(t)) &&
+ getExpr(t, UnknownFlow).nonEmpty
+
+ /** @param mt [[firrtl.annotations.ModuleTarget]] or [[firrtl.annotations.InstanceTarget]] to be queried.
+ * @return whether a ModuleTarget or InstanceTarget is contained in this IRLookup
+ */
+ def contains(mt: IsModule): Boolean = validPath(mt)
+
+ /** @param t [[firrtl.annotations.ReferenceTarget]] to be queried.
+ * @return whether a given [[firrtl.annotations.IsModule]] is valid, given the circuit's module/instance hierarchy
+ */
+ def validPath(t: IsModule): Boolean = {
+ t match {
+ case m: ModuleTarget => declarations.contains(m)
+ case i: InstanceTarget =>
+ val all = i.pathAsTargets :+ i.encapsulatingModuleTarget.instOf(i.instance, i.ofModule)
+ all.map { x =>
+ declarations.contains(x.moduleTarget) && declarations(x.moduleTarget).contains(x.asReference) &&
+ (declarations(x.moduleTarget)(x.asReference) match {
+ case DefInstance(_, _, of, _) if of == x.ofModule => validPath(x.ofModuleTarget)
+ case _ => false
+ })
+ }.reduce(_ && _)
+ }
+ }
+
+ /** Updates expression cache with expression. */
+ private def updateExpr(mt: ModuleTarget, ref: Expression): Unit = {
+ val refs = Utils.expandRef(ref)
+ refs.foreach { e =>
+ val target = ConnectionGraph.asTarget(mt, new TokenTagger())(e)
+ exprCache(target.moduleTarget)((target, Utils.flow(e))) = e
+ }
+ }
+
+ /** Updates expression cache with expression. */
+ private def updateExpr(gt: ReferenceTarget, e: Expression): Unit = {
+ val g = Utils.flow(e)
+ e.tpe match {
+ case _: GroundType =>
+ exprCache(gt.moduleTarget)((gt, g)) = e
+ case VectorType(t, size) =>
+ exprCache(gt.moduleTarget)((gt, g)) = e
+ (0 until size).foreach { i => updateExpr(gt.index(i), SubIndex(e, i, t, g)) }
+ case BundleType(fields) =>
+ exprCache(gt.moduleTarget)((gt, g)) = e
+ fields.foreach { f => updateExpr(gt.field(f.name), SubField(e, f.name, f.tpe, Utils.times(g, f.flip))) }
+ case other => sys.error(s"Error! Unexpected type $other")
+ }
+ }
+
+ /** Optionally returns the expression corresponding to the target if contained in the expression cache. */
+ private def inCache(pathless: ReferenceTarget, flow: Flow): Option[Expression] = {
+ (flow,
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains((pathless, SourceFlow)),
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains((pathless, SinkFlow)),
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains(pathless, DuplexFlow)
+ ) match {
+ case (SourceFlow, true, _, _) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, flow)))
+ case (SinkFlow, _, true, _) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, flow)))
+ case (DuplexFlow, _, _, true) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, DuplexFlow)))
+ case (UnknownFlow, _, _, true) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, DuplexFlow)))
+ case (UnknownFlow, true, false, false) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)))
+ case (UnknownFlow, false, true, false) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SinkFlow)))
+ case _ => None
+ }
+ }
+}
diff --git a/src/main/scala/firrtl/analyses/InstanceGraph.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala
index ddd5eb8b..c0aec4aa 100644
--- a/src/main/scala/firrtl/analyses/InstanceGraph.scala
+++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala
@@ -30,18 +30,18 @@ class InstanceGraph(c: Circuit) {
val moduleMap = c.modules.map({m => (m.name,m) }).toMap
private val instantiated = new mutable.LinkedHashSet[String]
private val childInstances =
- new mutable.LinkedHashMap[String, mutable.LinkedHashSet[WDefInstance]]
+ new mutable.LinkedHashMap[String, mutable.LinkedHashSet[DefInstance]]
for (m <- c.modules) {
- childInstances(m.name) = new mutable.LinkedHashSet[WDefInstance]
+ childInstances(m.name) = new mutable.LinkedHashSet[DefInstance]
m.foreach(InstanceGraph.collectInstances(childInstances(m.name)))
instantiated ++= childInstances(m.name).map(i => i.module)
}
- private val instanceGraph = new MutableDiGraph[WDefInstance]
- private val instanceQueue = new mutable.Queue[WDefInstance]
+ private val instanceGraph = new MutableDiGraph[DefInstance]
+ private val instanceQueue = new mutable.Queue[DefInstance]
for (subTop <- c.modules.view.map(_.name).filterNot(instantiated)) {
- val topInstance = WDefInstance(subTop,subTop)
+ val topInstance = DefInstance(subTop,subTop)
instanceQueue.enqueue(topInstance)
while (instanceQueue.nonEmpty) {
val current = instanceQueue.dequeue
@@ -57,11 +57,11 @@ class InstanceGraph(c: Circuit) {
}
// The true top module (circuit main)
- private val trueTopInstance = WDefInstance(c.main, c.main)
+ private val trueTopInstance = DefInstance(c.main, c.main)
/** A directed graph showing the instance dependencies among modules
- * in the circuit. Every WDefInstance of a module has an edge to
- * every WDefInstance arising from every instance statement in
+ * in the circuit. Every DefInstance of a module has an edge to
+ * every DefInstance arising from every instance statement in
* that module.
*/
lazy val graph = DiGraph(instanceGraph)
@@ -69,7 +69,7 @@ class InstanceGraph(c: Circuit) {
/** A list of absolute paths (each represented by a Seq of instances)
* of all module instances in the Circuit.
*/
- lazy val fullHierarchy: mutable.LinkedHashMap[WDefInstance,Seq[Seq[WDefInstance]]] = graph.pathsInDAG(trueTopInstance)
+ lazy val fullHierarchy: mutable.LinkedHashMap[DefInstance,Seq[Seq[DefInstance]]] = graph.pathsInDAG(trueTopInstance)
/** A count of the *static* number of instances of each module. For any module other than the top (main) module, this is
* equivalent to the number of inst statements in the circuit instantiating each module, irrespective of the number
@@ -96,9 +96,9 @@ class InstanceGraph(c: Circuit) {
* hierarchy of the top module of the circuit, it will return Nil.
*
* @param module the name of the selected module
- * @return a Seq[ Seq[WDefInstance] ] of absolute instance paths
+ * @return a Seq[ Seq[DefInstance] ] of absolute instance paths
*/
- def findInstancesInHierarchy(module: String): Seq[Seq[WDefInstance]] = {
+ def findInstancesInHierarchy(module: String): Seq[Seq[DefInstance]] = {
val instances = graph.getVertices.filter(_.module == module).toSeq
instances flatMap { i => fullHierarchy.getOrElse(i, Nil) }
}
@@ -109,8 +109,8 @@ class InstanceGraph(c: Circuit) {
/** Finds the lowest common ancestor instances for two module names in
* a design
*/
- def lowestCommonAncestor(moduleA: Seq[WDefInstance],
- moduleB: Seq[WDefInstance]): Seq[WDefInstance] = {
+ def lowestCommonAncestor(moduleA: Seq[DefInstance],
+ moduleB: Seq[DefInstance]): Seq[DefInstance] = {
tour.rmq(moduleA, moduleB)
}
@@ -126,7 +126,7 @@ class InstanceGraph(c: Circuit) {
/** Given a circuit, returns a map from module name to children
* instance/module definitions
*/
- def getChildrenInstances: mutable.LinkedHashMap[String, mutable.LinkedHashSet[WDefInstance]] = childInstances
+ def getChildrenInstances: mutable.LinkedHashMap[String, mutable.LinkedHashSet[DefInstance]] = childInstances
/** Given a circuit, returns a map from module name to children
* instance/module [[firrtl.annotations.TargetToken]]s
@@ -147,7 +147,7 @@ class InstanceGraph(c: Circuit) {
* in turn mapping instances names to corresponding module names
*/
def getChildrenInstanceMap: collection.Map[OfModule, collection.Map[Instance, OfModule]] =
- childInstances.map(kv => kv._1.OfModule -> asOrderedMap(kv._2, (i: WDefInstance) => i.toTokens))
+ childInstances.map(kv => kv._1.OfModule -> asOrderedMap(kv._2, (i: DefInstance) => i.toTokens))
/** The set of all modules in the circuit */
lazy val modules: collection.Set[OfModule] = graph.getVertices.map(_.OfModule)
@@ -163,17 +163,17 @@ class InstanceGraph(c: Circuit) {
object InstanceGraph {
- /** Returns all WDefInstances in a Statement
+ /** Returns all DefInstances in a Statement
*
* @param insts mutable datastructure to append to
* @param s statement to descend
* @return
*/
- def collectInstances(insts: mutable.Set[WDefInstance])
+ def collectInstances(insts: mutable.Set[DefInstance])
(s: Statement): Unit = s match {
- case i: WDefInstance => insts += i
- case i: DefInstance => throwInternalError("Expecting WDefInstance, found a DefInstance!")
- case i: WDefInstanceConnector => throwInternalError("Expecting WDefInstance, found a WDefInstanceConnector!")
+ case i: DefInstance => insts += i
+ case i: DefInstance => throwInternalError("Expecting DefInstance, found a DefInstance!")
+ case i: WDefInstanceConnector => throwInternalError("Expecting DefInstance, found a DefInstanceConnector!")
case _ => s.foreach(collectInstances(insts))
}
}
diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala
index e7ea07ca..4d1cdc2f 100644
--- a/src/main/scala/firrtl/annotations/Target.scala
+++ b/src/main/scala/firrtl/annotations/Target.scala
@@ -99,6 +99,15 @@ sealed trait Target extends Named {
/** Whether the target is directly instantiated in its root module */
def isLocal: Boolean
+
+ /** Share root module */
+ def sharedRoot(other: Target): Boolean = this.moduleOpt == other.moduleOpt && other.moduleOpt.nonEmpty
+
+ /** Checks whether this is inside of other */
+ def encapsulatedBy(other: IsModule): Boolean = this.moduleOpt.contains(other.encapsulatingModule)
+
+ /** @return Returns the instance hierarchy path, if one exists */
+ def path: Seq[(Instance, OfModule)]
}
object Target {
@@ -193,6 +202,33 @@ object Target {
case b: InstanceTarget => b.ofModuleTarget
case b: ReferenceTarget => b.pathlessTarget.moduleTarget
}
+
+ def getPathlessTarget(t: Target): Target = {
+ t.tryToComplete match {
+ case c: CircuitTarget => c
+ case m: IsMember => m.pathlessTarget
+ case t: GenericTarget if t.isLegal =>
+ val newTokens = t.tokens.dropWhile(x => x.isInstanceOf[Instance] || x.isInstanceOf[OfModule])
+ GenericTarget(t.circuitOpt, t.moduleOpt, newTokens)
+ case other => sys.error(s"Can't make $other pathless!")
+ }
+ }
+
+ def getReferenceTarget(t: Target): Target = {
+ (t.toGenericTarget match {
+ case t: GenericTarget if t.isLegal =>
+ val newTokens = t.tokens.reverse.dropWhile({
+ case x: Field => true
+ case x: Index => true
+ case Clock => true
+ case Init => true
+ case Reset => true
+ case other => false
+ }).reverse
+ GenericTarget(t.circuitOpt, t.moduleOpt, newTokens)
+ case other => sys.error(s"Can't make $other pathless!")
+ }).tryToComplete
+ }
}
/** Represents incomplete or non-standard [[Target]]s
@@ -235,6 +271,12 @@ case class GenericTarget(circuitOpt: Option[String],
override def isLocal: Boolean = !(getPath.nonEmpty && getPath.get.nonEmpty)
+ def path: Vector[(Instance, OfModule)] = if(isComplete){
+ tokens.zip(tokens.tail).collect {
+ case (i: Instance, o: OfModule) => (i, o)
+ }
+ } else Vector.empty[(Instance, OfModule)]
+
/** If complete, return this [[GenericTarget]]'s path
* @return
*/
@@ -342,6 +384,14 @@ case class GenericTarget(circuitOpt: Option[String],
def isCircuitTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.isEmpty && tokens.isEmpty
def isModuleTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.isEmpty
def isComponentTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.nonEmpty
+
+ lazy val (parentModule: Option[String], astModule: Option[String]) = path match {
+ case Seq() => (None, moduleOpt)
+ case Seq((i, OfModule(o))) => (moduleOpt, Some(o))
+ case seq if seq.size > 1 =>
+ val reversed = seq.reverse
+ (Some(reversed(1)._2.value), Some(reversed(0)._2.value))
+ }
}
/** Concretely points to a FIRRTL target, no generic selectors
@@ -368,7 +418,7 @@ trait CompleteTarget extends Target {
override def toTarget: CompleteTarget = this
// Very useful for debugging, I (@azidar) think this is reasonable
- override def toString = serialize
+ override def toString: String = serialize
}
@@ -417,6 +467,13 @@ trait IsMember extends CompleteTarget {
* @return
*/
def setPathTarget(newPath: IsModule): CompleteTarget
+
+ /** @return The [[ModuleTarget]] of the module that directly contains this component */
+ def encapsulatingModule: String = if(path.isEmpty) module else path.last._2.value
+
+ def encapsulatingModuleTarget: ModuleTarget = ModuleTarget(circuit, encapsulatingModule)
+
+ def leafModule: String
}
/** References a module-like target (e.g. a [[ModuleTarget]] or an [[InstanceTarget]])
@@ -435,10 +492,6 @@ trait IsModule extends IsMember {
/** A component of a FIRRTL Module (e.g. cannot point to a CircuitTarget or ModuleTarget)
*/
trait IsComponent extends IsMember {
-
- /** @return The [[ModuleTarget]] of the module that directly contains this component */
- def encapsulatingModule: String = if(path.isEmpty) module else path.last._2.value
-
/** Removes n levels of instance hierarchy
*
* Example: n=1, transforms (Top, A)/b:B/c:C -> (Top, B)/c:C
@@ -505,6 +558,8 @@ case class CircuitTarget(circuit: String) extends CompleteTarget {
override def addHierarchy(root: String, instance: String): ReferenceTarget =
ReferenceTarget(circuit, root, Nil, instance, Nil)
+ override def path = Seq()
+
override def toNamed: CircuitName = CircuitName(circuit)
}
@@ -545,6 +600,8 @@ case class ModuleTarget(circuit: String, module: String) extends IsModule {
override def setPathTarget(newPath: IsModule): IsModule = newPath
override def toNamed: ModuleName = ModuleName(module, CircuitName(circuit))
+
+ override def leafModule: String = module
}
/** Target pointing to a declared named component in a [[firrtl.ir.DefModule]]
@@ -631,6 +688,30 @@ case class ReferenceTarget(circuit: String,
ReferenceTarget(newPath.circuit, newPath.module, newPath.asPath, ref, component)
override def asPath: Seq[(Instance, OfModule)] = path
+
+ def isClock: Boolean = tokens.last == Clock
+
+ def isInit: Boolean = tokens.last == Init
+
+ def isReset: Boolean = tokens.last == Reset
+
+ def noComponents: ReferenceTarget = this.copy(component = Nil)
+
+ def leafSubTargets(tpe: firrtl.ir.Type): Seq[ReferenceTarget] = tpe match {
+ case _: firrtl.ir.GroundType => Vector(this)
+ case firrtl.ir.VectorType(t, size) => (0 until size).flatMap { i => index(i).leafSubTargets(t) }
+ case firrtl.ir.BundleType(fields) => fields.flatMap { f => field(f.name).leafSubTargets(f.tpe)}
+ case other => sys.error(s"Error! Unexpected type $other")
+ }
+
+ def allSubTargets(tpe: firrtl.ir.Type): Seq[ReferenceTarget] = tpe match {
+ case _: firrtl.ir.GroundType => Vector(this)
+ case firrtl.ir.VectorType(t, size) => this +: (0 until size).flatMap { i => index(i).allSubTargets(t) }
+ case firrtl.ir.BundleType(fields) => this +: fields.flatMap { f => field(f.name).allSubTargets(f.tpe)}
+ case other => sys.error(s"Error! Unexpected type $other")
+ }
+
+ override def leafModule: String = encapsulatingModule
}
/** Points to an instance declaration of a module (termed an ofModule)
@@ -652,6 +733,12 @@ case class InstanceTarget(circuit: String,
/** @return a [[ModuleTarget]] referring to declaration of this ofModule */
def ofModuleTarget: ModuleTarget = ModuleTarget(circuit, ofModule)
+ /** @return a [[ReferenceTarget]] referring to given reference within this instance */
+ def addReference(rt: ReferenceTarget): ReferenceTarget = {
+ require(rt.module == ofModule)
+ ReferenceTarget(circuit, module, asPath, rt.ref, rt.component)
+ }
+
override def circuitOpt: Option[String] = Some(circuit)
override def moduleOpt: Option[String] = Some(module)
@@ -690,6 +777,8 @@ case class InstanceTarget(circuit: String,
override def setPathTarget(newPath: IsModule): InstanceTarget =
InstanceTarget(newPath.circuit, newPath.module, newPath.asPath, instance, ofModule)
+
+ override def leafModule: String = ofModule
}
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala
index f30beec1..32bcac5f 100644
--- a/src/main/scala/firrtl/graph/DiGraph.scala
+++ b/src/main/scala/firrtl/graph/DiGraph.scala
@@ -2,9 +2,8 @@
package firrtl.graph
-import scala.collection.{Set, Map}
-import scala.collection.mutable
-import scala.collection.mutable.{LinkedHashSet, LinkedHashMap}
+import scala.collection.{Map, Set, mutable}
+import scala.collection.mutable.{LinkedHashMap, LinkedHashSet}
/** An exception that is raised when an assumed DAG has a cycle */
class CyclicException(val node: Any) extends Exception(s"No valid linearization for cyclic graph, found at $node")
@@ -18,7 +17,7 @@ object DiGraph {
def apply[T](mdg: MutableDiGraph[T]): DiGraph[T] = mdg
/** Create a DiGraph from a Map[T,Set[T]] of edge data */
- def apply[T](edgeData: Map[T,Set[T]]): DiGraph[T] = {
+ def apply[T](edgeData: Map[T, Set[T]]): DiGraph[T] = {
val edgeDataCopy = new LinkedHashMap[T, LinkedHashSet[T]]
for ((k, v) <- edgeData) {
edgeDataCopy(k) = new LinkedHashSet[T]
@@ -34,7 +33,7 @@ object DiGraph {
}
/** Represents common behavior of all directed graphs */
-class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) {
+class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) {
/** Check whether the graph contains vertex v */
def contains(v: T): Boolean = edges.contains(v)
@@ -188,7 +187,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
*
* @param start the start node
* @param end the destination node
- * @throws PathNotFoundException
+ * @throws firrtl.graph.PathNotFoundException
* @return a Seq[T] of nodes defining an arbitrary valid path
*/
def path(start: T, end: T): Seq[T] = path(start, end, Set.empty[T])
@@ -198,7 +197,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
* @param start the start node
* @param end the destination node
* @param blacklist list of nodes which break path, if encountered
- * @throws PathNotFoundException
+ * @throws firrtl.graph.PathNotFoundException
* @return a Seq[T] of nodes defining an arbitrary valid path
*/
def path(start: T, end: T, blacklist: Set[T]): Seq[T] = {
@@ -336,7 +335,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
* Any edge including a deleted node will be deleted
*
* @param vprime the Set[T] of desired vertices
- * @throws scala.IllegalArgumentException if vprime is not a subset of V
+ * @throws java.lang.IllegalArgumentException if vprime is not a subset of V
* @return the subgraph
*/
def subgraph(vprime: Set[T]): DiGraph[T] = {
@@ -350,12 +349,12 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
* transformed into an edge (u,v).
*
* @param vprime the Set[T] of desired vertices
- * @throws scala.IllegalArgumentException if vprime is not a subset of V
+ * @throws java.lang.IllegalArgumentException if vprime is not a subset of V
* @return the simplified graph
*/
def simplify(vprime: Set[T]): DiGraph[T] = {
require(vprime.subsetOf(edges.keySet))
- val pathEdges = vprime.map( v => (v, reachableFrom(v) & (vprime-v)) )
+ val pathEdges = vprime.map(v => (v, reachableFrom(v) & (vprime-v)) )
new DiGraph(new LinkedHashMap[T, LinkedHashSet[T]] ++ pathEdges)
}
@@ -394,7 +393,7 @@ class MutableDiGraph[T] extends DiGraph[T](new LinkedHashMap[T, LinkedHashSet[T]
}
/** Add edge (u,v) to the graph.
- * @throws scala.IllegalArgumentException if u and/or v is not in the graph
+ * @throws java.lang.IllegalArgumentException if u and/or v is not in the graph
*/
def addEdge(u: T, v: T): Unit = {
require(contains(u))
diff --git a/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala b/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala
new file mode 100644
index 00000000..79922fa9
--- /dev/null
+++ b/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala
@@ -0,0 +1,50 @@
+// See LICENSE for license details.
+
+package firrtlTests.analyses
+
+import firrtl.analyses.CircuitGraph
+import firrtl.annotations.CircuitTarget
+import firrtl.options.Dependency
+import firrtl.passes.ExpandWhensAndCheck
+import firrtl.stage.{Forms, TransformManager}
+import firrtl.testutils.FirrtlFlatSpec
+import firrtl.{ChirrtlForm, CircuitState, FileUtils, UnknownForm}
+
+class CircuitGraphSpec extends FirrtlFlatSpec {
+ "CircuitGraph" should "find paths with deep hierarchy quickly" in {
+ def mkChild(n: Int): String =
+ s""" module Child${n} :
+ | input in: UInt<8>
+ | output out: UInt<8>
+ | inst c1 of Child${n+1}
+ | inst c2 of Child${n+1}
+ | c1.in <= in
+ | c2.in <= c1.out
+ | out <= c2.out
+ """.stripMargin
+ def mkLeaf(n: Int): String =
+ s""" module Child${n} :
+ | input in: UInt<8>
+ | output out: UInt<8>
+ | wire middle: UInt<8>
+ | middle <= in
+ | out <= middle
+ """.stripMargin
+ (2 until 23 by 2).foreach { n =>
+ val input = new StringBuilder()
+ input ++=
+ """circuit Child0:
+ |""".stripMargin
+ (0 until n).foreach { i => input ++= mkChild(i); input ++= "\n" }
+ input ++= mkLeaf(n)
+ val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform(
+ CircuitState(parse(input.toString()), UnknownForm)
+ ).circuit
+ val circuitGraph = CircuitGraph(circuit)
+ val C = CircuitTarget("Child0")
+ val Child0 = C.module("Child0")
+ circuitGraph.connectionPath(Child0.ref("in"), Child0.ref("out"))
+ }
+ }
+
+}
diff --git a/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala b/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala
new file mode 100644
index 00000000..06f59a3c
--- /dev/null
+++ b/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala
@@ -0,0 +1,107 @@
+// See LICENSE for license details.
+
+package firrtlTests.analyses
+
+import firrtl.{ChirrtlForm, CircuitState, FileUtils, IRToWorkingIR, UnknownForm}
+import firrtl.analyses.{CircuitGraph, ConnectionGraph}
+import firrtl.annotations.ModuleTarget
+import firrtl.options.Dependency
+import firrtl.passes.ExpandWhensAndCheck
+import firrtl.stage.{Forms, TransformManager}
+import firrtl.testutils.FirrtlFlatSpec
+
+class ConnectionGraphSpec extends FirrtlFlatSpec {
+
+ "ConnectionGraph" should "build connection graph for rocket-chip" in {
+ ConnectionGraph(
+ new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform(
+ CircuitState(parse(FileUtils.getTextResource("/regress/RocketCore.fir")), UnknownForm)
+ ).circuit
+ )
+ }
+
+ val input =
+ """circuit Test:
+ | module Test :
+ | input in: UInt<8>
+ | input clk: Clock
+ | input reset: UInt<1>
+ | output out: {a: UInt<8>, b: UInt<8>[2]}
+ | out is invalid
+ | reg r: UInt<8>, clk with:
+ | (reset => (reset, UInt(0)))
+ | r <= in
+ | node x = r
+ | wire y: UInt<8>
+ | y <= x
+ | out.b[0] <= and(y, asUInt(SInt(-1)))
+ | inst child of Child
+ | child.in <= in
+ | out.a <= child.out
+ | module Child:
+ | input in: UInt<8>
+ | output out: UInt<8>
+ | out <= in
+ |""".stripMargin
+
+ val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform(
+ CircuitState(parse(input), UnknownForm)
+ ).circuit
+
+ "ConnectionGraph" should "work with pathsInDAG" in {
+ val Test = ModuleTarget("Test", "Test")
+ val irGraph = ConnectionGraph(circuit)
+
+ val paths = irGraph.pathsInDAG(Test.ref("in"))
+ paths(Test.ref("out").field("b").index(0)) shouldBe Seq(
+ Seq(
+ Test.ref("in"),
+ Test.ref("r"),
+ Test.ref("x"),
+ Test.ref("y"),
+ Test.ref("@and#0"),
+ Test.ref("out").field("b").index(0)
+ )
+ )
+ paths(Test.ref("out").field("a")) shouldBe Seq(
+ Seq(
+ Test.ref("in"),
+ Test.ref("child").field("in"),
+ Test.instOf("child", "Child").ref("in"),
+ Test.instOf("child", "Child").ref("out"),
+ Test.ref("child").field("out"),
+ Test.ref("out").field("a")
+ )
+ )
+
+ }
+
+ "ConnectionGraph" should "work with path" in {
+ val Test = ModuleTarget("Test", "Test")
+ val irGraph = ConnectionGraph(circuit)
+
+ irGraph.path(Test.ref("in"), Test.ref("out").field("b").index(0)) shouldBe Seq(
+ Test.ref("in"),
+ Test.ref("r"),
+ Test.ref("x"),
+ Test.ref("y"),
+ Test.ref("@and#0"),
+ Test.ref("out").field("b").index(0)
+ )
+
+ irGraph.path(Test.ref("in"), Test.ref("out").field("a")) shouldBe Seq(
+ Test.ref("in"),
+ Test.ref("child").field("in"),
+ Test.instOf("child", "Child").ref("in"),
+ Test.instOf("child", "Child").ref("out"),
+ Test.ref("child").field("out"),
+ Test.ref("out").field("a")
+ )
+
+ irGraph.path(Test.ref("@invalid#0"), Test.ref("out").field("b").index(1)) shouldBe Seq(
+ Test.ref("@invalid#0"),
+ Test.ref("out").field("b").index(1)
+ )
+ }
+
+}
diff --git a/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala b/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala
new file mode 100644
index 00000000..50ee75ac
--- /dev/null
+++ b/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala
@@ -0,0 +1,296 @@
+package firrtlTests.analyses
+
+import firrtl.PrimOps.AsUInt
+import firrtl.analyses.IRLookup
+import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget}
+import firrtl._
+import firrtl.ir._
+import firrtl.options.Dependency
+import firrtl.passes.ExpandWhensAndCheck
+import firrtl.stage.{Forms, TransformManager}
+import firrtl.testutils.FirrtlFlatSpec
+
+
+class IRLookupSpec extends FirrtlFlatSpec {
+
+ "IRLookup" should "return declarations" in {
+ val input =
+ """circuit Test:
+ | module Test :
+ | input in: UInt<8>
+ | input clk: Clock
+ | input reset: UInt<1>
+ | output out: {a: UInt<8>, b: UInt<8>[2]}
+ | input ana1: Analog<8>
+ | output ana2: Analog<8>
+ | out is invalid
+ | reg r: UInt<8>, clk with:
+ | (reset => (reset, UInt(0)))
+ | node x = r
+ | wire y: UInt<8>
+ | y <= x
+ | out.b[0] <= and(y, asUInt(SInt(-1)))
+ | attach(ana1, ana2)
+ | inst child of Child
+ | out.a <= child.out
+ | module Child:
+ | output out: UInt<8>
+ | out <= UInt(1)
+ |""".stripMargin
+
+ val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform(
+ CircuitState(parse(input), UnknownForm)
+ ).circuit
+ val irLookup = IRLookup(circuit)
+ val Test = ModuleTarget("Test", "Test")
+ val uint8 = UIntType(IntWidth(8))
+
+ irLookup.declaration(Test.ref("in")) shouldBe Port(NoInfo, "in", Input, uint8)
+ irLookup.declaration(Test.ref("clk")) shouldBe Port(NoInfo, "clk", Input, ClockType)
+ irLookup.declaration(Test.ref("reset")) shouldBe Port(NoInfo, "reset", Input, UIntType(IntWidth(1)))
+
+ val out = Port(NoInfo, "out", Output,
+ BundleType(Seq(Field("a", Default, uint8), Field("b", Default, VectorType(uint8, 2))))
+ )
+ irLookup.declaration(Test.ref("out")) shouldBe out
+ irLookup.declaration(Test.ref("out").field("a")) shouldBe out
+ irLookup.declaration(Test.ref("out").field("b").index(0)) shouldBe out
+ irLookup.declaration(Test.ref("out").field("b").index(1)) shouldBe out
+
+ irLookup.declaration(Test.ref("ana1")) shouldBe Port(NoInfo, "ana1", Input, AnalogType(IntWidth(8)))
+ irLookup.declaration(Test.ref("ana2")) shouldBe Port(NoInfo, "ana2", Output, AnalogType(IntWidth(8)))
+
+ val clk = WRef("clk", ClockType, PortKind, SourceFlow)
+ val reset = WRef("reset", UIntType(IntWidth(1)), PortKind, SourceFlow)
+ val init = UIntLiteral(0)
+ val reg = DefRegister(NoInfo, "r", uint8, clk, reset, init)
+ irLookup.declaration(Test.ref("r")) shouldBe reg
+ irLookup.declaration(Test.ref("r").clock) shouldBe reg
+ irLookup.declaration(Test.ref("r").reset) shouldBe reg
+ irLookup.declaration(Test.ref("r").init) shouldBe reg
+ irLookup.kindFinder(Test, RegKind) shouldBe Seq(Test.ref("r"))
+ irLookup.declaration(Test.ref("x")) shouldBe DefNode(NoInfo, "x", WRef("r", uint8, RegKind, SourceFlow))
+ irLookup.declaration(Test.ref("y")) shouldBe DefWire(NoInfo, "y", uint8)
+
+ irLookup.declaration(Test.ref("@and#0")) shouldBe
+ DoPrim(PrimOps.And,
+ Seq(WRef("y", uint8, WireKind, SourceFlow), DoPrim(AsUInt, Seq(SIntLiteral(-1)), Nil, UIntType(IntWidth(1)))),
+ Nil,
+ uint8
+ )
+
+ val inst = WDefInstance(NoInfo, "child", "Child", BundleType(Seq(Field("out", Default, uint8))))
+ irLookup.declaration(Test.ref("child")) shouldBe inst
+ irLookup.declaration(Test.ref("child").field("out")) shouldBe inst
+ irLookup.declaration(Test.instOf("child", "Child").ref("out")) shouldBe Port(NoInfo, "out", Output, uint8)
+
+ intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Child").ref("missing")) }
+ intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Missing").ref("out")) }
+ intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("missing", "Child").ref("out")) }
+ intercept[IllegalArgumentException]{ irLookup.declaration(Test.ref("missing")) }
+ intercept[IllegalArgumentException]{ irLookup.declaration(Test.ref("out").field("c")) }
+ intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Child").ref("out").field("missing")) }
+ }
+
+ "IRLookup" should "return mem declarations" in {
+ def commonFields: Seq[String] = Seq("clk", "en", "addr")
+ def readerTargets(rt: ReferenceTarget): Seq[ReferenceTarget] = {
+ (commonFields ++ Seq("data")).map(rt.field)
+ }
+ def writerTargets(rt: ReferenceTarget): Seq[ReferenceTarget] = {
+ (commonFields ++ Seq("data", "mask")).map(rt.field)
+ }
+ def readwriterTargets(rt: ReferenceTarget): Seq[ReferenceTarget] = {
+ (commonFields ++ Seq("wdata", "wmask", "wmode", "rdata")).map(rt.field)
+ }
+ val input =
+ s"""circuit Test:
+ | module Test :
+ | input in : UInt<8>
+ | input clk: Clock[3]
+ | input dataClk: Clock
+ | input mode: UInt<1>
+ | output out : UInt<8>[2]
+ | mem m:
+ | data-type => UInt<8>
+ | reader => r
+ | writer => w
+ | readwriter => rw
+ | depth => 2
+ | write-latency => 1
+ | read-latency => 0
+ |
+ | reg addr: UInt<1>, dataClk
+ | reg en: UInt<1>, dataClk
+ | reg indata: UInt<8>, dataClk
+ |
+ | m.r.clk <= clk[0]
+ | m.r.en <= en
+ | m.r.addr <= addr
+ | out[0] <= m.r.data
+ |
+ | m.w.clk <= clk[1]
+ | m.w.en <= en
+ | m.w.addr <= addr
+ | m.w.data <= indata
+ | m.w.mask <= en
+ |
+ | m.rw.clk <= clk[2]
+ | m.rw.en <= en
+ | m.rw.addr <= addr
+ | m.rw.wdata <= indata
+ | m.rw.wmask <= en
+ | m.rw.wmode <= en
+ | out[1] <= m.rw.rdata
+ |""".stripMargin
+
+ val C = CircuitTarget("Test")
+ val MemTest = C.module("Test")
+ val Mem = MemTest.ref("m")
+ val Reader = Mem.field("r")
+ val Writer = Mem.field("w")
+ val Readwriter = Mem.field("rw")
+ val allSignals = readerTargets(Reader) ++ writerTargets(Writer) ++ readwriterTargets(Readwriter)
+
+ val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform(
+ CircuitState(parse(input), UnknownForm)
+ ).circuit
+ val irLookup = IRLookup(circuit)
+ val uint8 = UIntType(IntWidth(8))
+ val mem = DefMemory(NoInfo, "m", uint8, 2, 1, 0, Seq("r"), Seq("w"), Seq("rw"))
+ allSignals.foreach { at =>
+ irLookup.declaration(at) shouldBe mem
+ }
+ }
+
+ "IRLookup" should "return expressions, types, kinds, and flows" in {
+ val input =
+ """circuit Test:
+ | module Test :
+ | input in: UInt<8>
+ | input clk: Clock
+ | input reset: UInt<1>
+ | output out: {a: UInt<8>, b: UInt<8>[2]}
+ | input ana1: Analog<8>
+ | output ana2: Analog<8>
+ | out is invalid
+ | reg r: UInt<8>, clk with:
+ | (reset => (reset, UInt(0)))
+ | node x = r
+ | wire y: UInt<8>
+ | y <= x
+ | out.b[0] <= and(y, asUInt(SInt(-1)))
+ | attach(ana1, ana2)
+ | inst child of Child
+ | out.a <= child.out
+ | module Child:
+ | output out: UInt<8>
+ | out <= UInt(1)
+ |""".stripMargin
+
+ val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform(
+ CircuitState(parse(input), UnknownForm)
+ ).circuit
+ val irLookup = IRLookup(circuit)
+ val Test = ModuleTarget("Test", "Test")
+ val uint8 = UIntType(IntWidth(8))
+ val uint1 = UIntType(IntWidth(1))
+
+ def check(rt: ReferenceTarget, e: Expression): Unit = {
+ irLookup.expr(rt) shouldBe e
+ irLookup.tpe(rt) shouldBe e.tpe
+ irLookup.kind(rt) shouldBe Utils.kind(e)
+ irLookup.flow(rt) shouldBe Utils.flow(e)
+ }
+
+ check(Test.ref("in"), WRef("in", uint8, PortKind, SourceFlow))
+ check(Test.ref("clk"), WRef("clk", ClockType, PortKind, SourceFlow))
+ check(Test.ref("reset"), WRef("reset", uint1, PortKind, SourceFlow))
+
+ val out = Test.ref("out")
+ val outExpr =
+ WRef("out",
+ BundleType(Seq(Field("a", Default, uint8), Field("b", Default, VectorType(uint8, 2)))),
+ PortKind,
+ SinkFlow
+ )
+ check(out, outExpr)
+ check(out.field("a"), WSubField(outExpr, "a", uint8, SinkFlow))
+ val outB = out.field("b")
+ val outBExpr = WSubField(outExpr, "b", VectorType(uint8, 2), SinkFlow)
+ check(outB, outBExpr)
+ check(outB.index(0), WSubIndex(outBExpr, 0, uint8, SinkFlow))
+ check(outB.index(1), WSubIndex(outBExpr, 1, uint8, SinkFlow))
+
+ check(Test.ref("ana1"), WRef("ana1", AnalogType(IntWidth(8)), PortKind, SourceFlow))
+ check(Test.ref("ana2"), WRef("ana2", AnalogType(IntWidth(8)), PortKind, SinkFlow))
+
+ val clk = WRef("clk", ClockType, PortKind, SourceFlow)
+ val reset = WRef("reset", UIntType(IntWidth(1)), PortKind, SourceFlow)
+ val init = UIntLiteral(0)
+ check(Test.ref("r"), WRef("r", uint8, RegKind, DuplexFlow))
+ check(Test.ref("r").clock, clk)
+ check(Test.ref("r").reset, reset)
+ check(Test.ref("r").init, init)
+
+ check(Test.ref("x"), WRef("x", uint8, ExpKind, SourceFlow))
+
+ check(Test.ref("y"), WRef("y", uint8, WireKind, DuplexFlow))
+
+ check(Test.ref("@and#0"),
+ DoPrim(PrimOps.And,
+ Seq(WRef("y", uint8, WireKind, SourceFlow), DoPrim(AsUInt, Seq(SIntLiteral(-1)), Nil, UIntType(IntWidth(1)))),
+ Nil,
+ uint8
+ )
+ )
+
+ val child = WRef("child", BundleType(Seq(Field("out", Default, uint8))), InstanceKind, SourceFlow)
+ check(Test.ref("child"), child)
+ check(Test.ref("child").field("out"),
+ WSubField(child, "out", uint8, SourceFlow)
+ )
+ }
+
+ "IRLookup" should "cache expressions" in {
+ def mkType(i: Int): String = {
+ if(i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}"
+ }
+
+ val depth = 500
+
+ val input =
+ s"""circuit Test:
+ | module Test :
+ | input in: ${mkType(depth)}
+ | output out: ${mkType(depth)}
+ | out <= in
+ |""".stripMargin
+
+ val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform(
+ CircuitState(parse(input), UnknownForm)
+ ).circuit
+ val Test = ModuleTarget("Test", "Test")
+ val irLookup = IRLookup(circuit)
+ def mkReferences(parent: ReferenceTarget, i: Int): Seq[ReferenceTarget] = {
+ if(i == 0) Seq(parent) else {
+ val newParent = parent.field("x")
+ newParent +: mkReferences(newParent, i - 1)
+ }
+ }
+
+ // Check caching from root to leaf
+ val inRefs = mkReferences(Test.ref("in"), depth)
+ val (inStartTime, _) = Utils.time(irLookup.expr(inRefs.head))
+ inRefs.tail.foreach { r =>
+ val (ms, _) = Utils.time(irLookup.expr(r))
+ require(inStartTime > ms)
+ }
+ val outRefs = mkReferences(Test.ref("out"), depth).reverse
+ val (outStartTime, _) = Utils.time(irLookup.expr(outRefs.head))
+ outRefs.tail.foreach { r =>
+ val (ms, _) = Utils.time(irLookup.expr(r))
+ require(outStartTime > ms)
+ }
+ }
+}
diff --git a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala
index 4d38340f..f847fb6c 100644
--- a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala
+++ b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala
@@ -1,3 +1,5 @@
+// See LICENSE for license details.
+
package firrtlTests
package transforms