diff options
| author | Jiuyang Liu | 2020-08-01 00:25:13 +0800 |
|---|---|---|
| committer | GitHub | 2020-07-31 16:25:13 +0000 |
| commit | f22652a330afe1daa77be2aadb525d65ab05e9fe (patch) | |
| tree | 59424ccbe5634993b62a3040f74d077e66ed7c1d /src | |
| parent | ba2be50f42c1ec760decc22cfda73fbd39113b53 (diff) | |
[WIP] Implement CircuitGraph and IRLookup to firrtl.analyses (#1603)
* WIP Commit
* Add EdgeDataDiGraph with views to amortize graph construction
* WIP, got basic structure, need tests to pipeclean
* First tests pass. Need more.
* Tests pass, more need to be written
* More tests pass! Things should work, except for memories
* Added clearPrev to fix digraph uses where caching prev breaks
* Removed old Component. Documented IRLookup
* Added comments. Make prev arg to getEdges
* WIP: Refactoring for CircuitGraph
* Refactored into CircuitGraph. Can do topological module analysis
* Removed old versions
* Added support for memories
* Added cached test
* More stufffff
* Added implicit caching of connectivity
* Added tests for IRLookup, and others
* Many major changes.
Replaced CircuitGraph as ConnectionGraph
Added CircuitGraph to be top-level user-facing object
ConnectionGraph now automatically shortcuts getEdges
ConnectionGraph overwrites BFS as PriorityBFS
Added leafModule to Target
Added lookup by kind to IRLookup
Added more tests
* Reordered stuff in ConnectionGraph
* Made path work with deep hierarchies. Added PML for IllegalClockCrossings
* Made pathsInDAG work with current shortcut semantics
* Bugfix: check pathless targets when shortcutting paths
* Added documentation/licenses
* Removed UnnamedToken and related functionality
* Added documentation of ConnectionGraph
* Added back topo, needed for correct solving of intermediate modules
* Bugfix. Cache intermediate clockSources from same BFS with same root, but not BFS with different root
* Added literal/invalid clock source, and unknown top for getclocksource
* Bugfix for clocks in bundles
* Add CompleteTargetSerializer and test
* remove ClockFinder, be able to compile.
* test is able to compile, but need to fix.
* public and abstract DiGraph, remove DiGraphLike.
* revert some DiGraph code, ConnectionGraphSpec passed.
* CircuitGraphSpec passed.
* minimize diff between master
* codes clean up
* override linearize and revert DiGraph
* keep DiGraph unchanged.
* make ci happy again.
* codes clean up.
* bug fix for rebase
* remove wir
* make scaladoc happy again.
* update for review.
* add some documentation.
* remove tag
* wip IRLookup
* code clean up and add some doucmentations.
* IRLookup cache with ModuleTarget guarded.
* make unidoc and 2.13 happy
Co-authored-by: Adam Izraelevitz <azidar@gmail.com>
Co-authored-by: Albert Magyar <albert.magyar@gmail.com>
Co-authored-by: Jack Koenig <koenig@sifive.com>
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/analyses/CircuitGraph.scala | 137 | ||||
| -rw-r--r-- | src/main/scala/firrtl/analyses/ConnectionGraph.scala | 573 | ||||
| -rw-r--r-- | src/main/scala/firrtl/analyses/IRLookup.scala | 265 | ||||
| -rw-r--r-- | src/main/scala/firrtl/analyses/InstanceGraph.scala | 40 | ||||
| -rw-r--r-- | src/main/scala/firrtl/annotations/Target.scala | 99 | ||||
| -rw-r--r-- | src/main/scala/firrtl/graph/DiGraph.scala | 21 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala | 50 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala | 107 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/analyses/IRLookupSpec.scala | 296 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala | 2 |
10 files changed, 1554 insertions, 36 deletions
diff --git a/src/main/scala/firrtl/analyses/CircuitGraph.scala b/src/main/scala/firrtl/analyses/CircuitGraph.scala new file mode 100644 index 00000000..a5c811c6 --- /dev/null +++ b/src/main/scala/firrtl/analyses/CircuitGraph.scala @@ -0,0 +1,137 @@ +// See LICENSE for license details. + +package firrtl.analyses + +import firrtl.Kind +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.annotations._ +import firrtl.ir.{Circuit, DefInstance} + +/** Use to construct [[CircuitGraph]] + * Also contains useful related functions + */ +object CircuitGraph { + + /** Build a CircuitGraph + * [[firrtl.ir.Circuit]] must be of MiddleForm or lower + * @param circuit + * @return + */ + def apply(circuit: Circuit): CircuitGraph = new CircuitGraph(ConnectionGraph(circuit)) + + /** Return a nicely-formatted string of a path of [[firrtl.annotations.ReferenceTarget]] + * @param connectionPath + * @param tab + * @return + */ + def prettyToString(connectionPath: Seq[ReferenceTarget], tab: String = ""): String = { + tab + connectionPath.mkString(s"\n$tab") + } +} + +/** Graph-representation of a FIRRTL Circuit + * + * Requires Middle FIRRTL + * Useful for writing design-specific custom-transforms that require connectivity information + * + * @param connectionGraph Source-to-sink connectivity graph + */ +class CircuitGraph private[analyses] (val connectionGraph: ConnectionGraph) { + + // Reverse (sink-to-source) connectivity graph + lazy val reverseConnectionGraph = connectionGraph.reverseConnectionGraph + + // AST Circuit + val circuit = connectionGraph.circuit + + // AST Information + val irLookup = connectionGraph.irLookup + + // Module/Instance Hierarchy information + lazy val instanceGraph = new InstanceGraph(circuit) + + // Per module, which modules does it instantiate + lazy val moduleChildren = instanceGraph.getChildrenInstanceOfModule + + // Top-level module target + val main = ModuleTarget(circuit.main, circuit.main) + + /** Given a signal, return the signals that it drives + * @param source + * @return + */ + def fanOutSignals(source: ReferenceTarget): Set[ReferenceTarget] = connectionGraph.getEdges(source).toSet + + /** Given a signal, return the signals that drive it + * @param sink + * @return + */ + def fanInSignals(sink: ReferenceTarget): Set[ReferenceTarget] = reverseConnectionGraph.getEdges(sink).toSet + + /** Return the absolute paths of all instances of this module. + * + * For example: + * - Top instantiates a1 of A and a2 of A + * - A instantiates b1 of B and b2 of B + * Then, absolutePaths of B will return: + * - Seq(~Top|Top/a1:A/b1:B, ~Top|Top/a1:A/b2:B, ~Top|Top/a2:A/b1:B, ~Top|Top/a2:A/b2:B) + * @param mt + * @return + */ + def absolutePaths(mt: ModuleTarget): Seq[IsModule] = instanceGraph.findInstancesInHierarchy(mt.module).map { + case seq if seq.nonEmpty => seq.foldLeft(CircuitTarget(circuit.main).module(circuit.main): IsModule) { + case (it, DefInstance(_, instance, ofModule, _)) => it.instOf(instance, ofModule) + } + } + + /** Return the sequence of nodes from source to sink, inclusive + * @param source + * @param sink + * @return + */ + def connectionPath(source: ReferenceTarget, sink: ReferenceTarget): Seq[ReferenceTarget] = + connectionGraph.path(source, sink) + + /** Return a reference to all nodes of given kind, directly contained in the referenced module/instance + * Path can be either a module, or an instance + * @param path + * @param kind + * @return + */ + def localReferences(path: IsModule, kind: Kind): Seq[ReferenceTarget] = { + val leafModule = path.leafModule + irLookup.kindFinder(ModuleTarget(circuit.main, leafModule), kind).map(_.setPathTarget(path)) + } + + /** Return a reference to all nodes of given kind, contained in the referenced module/instance or any child instance + * Path can be either a module, or an instance + * @param kind + * @param path + * @return + */ + def deepReferences(kind: Kind, path: IsModule = ModuleTarget(circuit.main, circuit.main)): Seq[ReferenceTarget] = { + val leafModule = path.leafModule + val children = moduleChildren(leafModule) + val localRefs = localReferences(path, kind) + localRefs ++ children.flatMap { + case (Instance(inst), OfModule(ofModule)) => deepReferences(kind, path.instOf(inst, ofModule)) + } + } + + /** Return all absolute references to signals of the given kind directly contained in the module + * @param moduleTarget + * @param kind + * @return + */ + def absoluteReferences(moduleTarget: ModuleTarget, kind: Kind): Seq[ReferenceTarget] = { + localReferences(moduleTarget, kind).flatMap(makeAbsolute) + } + + /** Given a reference, return all instances of that reference (i.e. with absolute paths) + * @param reference + * @return + */ + def makeAbsolute(reference: ReferenceTarget): Seq[ReferenceTarget] = { + absolutePaths(reference.moduleTarget).map(abs => reference.setPathTarget(abs)) + } +} diff --git a/src/main/scala/firrtl/analyses/ConnectionGraph.scala b/src/main/scala/firrtl/analyses/ConnectionGraph.scala new file mode 100644 index 00000000..0e13711a --- /dev/null +++ b/src/main/scala/firrtl/analyses/ConnectionGraph.scala @@ -0,0 +1,573 @@ +// See LICENSE for license details. + +package firrtl.analyses + +import firrtl.Mappers._ +import firrtl.annotations.{TargetToken, _} +import firrtl.graph.{CyclicException, DiGraph, MutableDiGraph} +import firrtl.ir._ +import firrtl.passes.MemPortUtils +import firrtl.{InstanceKind, PortKind, SinkFlow, SourceFlow, Utils, WInvalid} + +import scala.collection.mutable + +/** Class to represent circuit connection. + * + * @param circuit firrtl AST of this graph. + * @param digraph Directed graph of ReferenceTarget in the AST. + * @param irLookup [[IRLookup]] instance of circuit graph. + * */ +class ConnectionGraph protected(val circuit: Circuit, + val digraph: DiGraph[ReferenceTarget], + val irLookup: IRLookup) + extends DiGraph[ReferenceTarget](digraph.getEdgeMap.asInstanceOf[mutable.LinkedHashMap[ReferenceTarget, mutable.LinkedHashSet[ReferenceTarget]]]) { + + lazy val serialize: String = s"""{ + |${getEdgeMap.map { case (k, vs) => + s""" "$k": { + | "kind": "${irLookup.kind(k)}", + | "type": "${irLookup.tpe(k)}", + | "expr": "${irLookup.expr(k, irLookup.flow(k))}", + | "sinks": [${vs.map { v => s""""$v"""" }.mkString(", ")}], + | "declaration": "${irLookup.declaration(k)}" + | }""".stripMargin }.mkString(",\n")} + |}""".stripMargin + + /** Used by BFS to map each visited node to the list of instance inputs visited thus far + * + * When BFS descends into a child instance, the child instance port is prepended to the list + * When BFS ascends into a parent instance, the head of the list is removed + * In essence, the list is a stack that you push when descending, and pop when ascending + * + * Because the search is BFS not DFS, we must record the state of the stack for each edge node, so + * when that edge node is finally visited, we know the state of the stack + * + * For example: + * circuit Top: + * module Top: + * input in: UInt + * output out: UInt + * inst a of A + * a.in <= in + * out <= a.out + * module A: + * input in: UInt + * output out: UInt + * inst b of B + * b.in <= in + * out <= b.out + * module B: + * input in: UInt + * output out: UInt + * out <= in + * + * We perform BFS starting at `Top>in`, + * Node [[portConnectivityStack]] + * + * Top>in List() + * Top>a.in List() + * Top/a:A>in List(Top>a.in) + * Top/a:A>b.in List(Top>a.in) + * Top/a:A/b:B/in List(Top/a:A>b.in, Top>a.in) + * Top/a:A/b:B/out List(Top/a:A>b.in, Top>a.in) + * Top/a:A>b.out List(Top>a.in) + * Top/a:A>out List(Top>a.in) + * Top>a.out List() + * Top>out List() + * when we reach `Top/a:A>`, `Top>a.in` will be pushed into [[portConnectivityStack]]; + * when we reach `Top/a:A/b:B>`, `Top/a:A>b.in` will be pushed into [[portConnectivityStack]]; + * when we leave `Top/a:A/b:B>`, `Top/a:A>b.in` will be popped from [[portConnectivityStack]]; + * when we leave `Top/a:A>`, `Top/a:A>b.in` will be popped from [[portConnectivityStack]]. + */ + private val portConnectivityStack: mutable.HashMap[ReferenceTarget, List[ReferenceTarget]] = + mutable.HashMap.empty[ReferenceTarget, List[ReferenceTarget]] + + /** Records connectivities found while BFS is executing, from a module's source port to sink ports of a module + * + * All keys and values are local references. + * + * A BFS search will first query this map. If the query fails, then it continues and populates the map. If the query + * succeeds, then the BFS shortcuts with the values provided by the query. + * + * Because this BFS implementation uses a priority queue which prioritizes exploring deeper instances first, a + * successful query during BFS will only occur after all paths which leave the module from that reference have + * already been searched. + */ + private val bfsShortCuts: mutable.HashMap[ReferenceTarget, mutable.HashSet[ReferenceTarget]] = + mutable.HashMap.empty[ReferenceTarget, mutable.HashSet[ReferenceTarget]] + + /** Records connectivities found after BFS is completed, from a module's source port to sink ports of a module + * + * All keys and values are local references. + * + * If its keys contain a reference, then the value will be complete, in that all paths from the reference out of + * the module will have been explored + * + * For example, if Top>in connects to Top>out1 and Top>out2, then foundShortCuts(Top>in) will contain + * Set(Top>out1, Top>out2), not Set(Top>out1) or Set(Top>out2) + */ + private val foundShortCuts: mutable.HashMap[ReferenceTarget, mutable.HashSet[ReferenceTarget]] = + mutable.HashMap.empty[ReferenceTarget, mutable.HashSet[ReferenceTarget]] + + /** Returns whether a previous BFS search has found a shortcut out of a module, starting from target + * + * @param target first target to find shortcut. + * @return true if find a shortcut. + */ + def hasShortCut(target: ReferenceTarget): Boolean = getShortCut(target).nonEmpty + + /** Optionally returns the shortcut a previous BFS search may have found out of a module, starting from target + * + * @param target first target to find shortcut. + * @return [[firrtl.annotations.ReferenceTarget]] of short cut. + */ + def getShortCut(target: ReferenceTarget): Option[Set[ReferenceTarget]] = + foundShortCuts.get(target.pathlessTarget).map(set => set.map(_.setPathTarget(target.pathTarget)).toSet) + + /** Returns the shortcut a previous BFS search may have found out of a module, starting from target + * + * @param target first target to find shortcut. + * @return [[firrtl.annotations.ReferenceTarget]] of short cut. + */ + def shortCut(target: ReferenceTarget): Set[ReferenceTarget] = getShortCut(target).get + + /** @return a new, reversed connection graph where edges point from sinks to sources. */ + def reverseConnectionGraph: ConnectionGraph = new ConnectionGraph(circuit, digraph.reverse, irLookup) + + override def BFS(root: ReferenceTarget, blacklist: collection.Set[ReferenceTarget]): collection.Map[ReferenceTarget, ReferenceTarget] = { + val prev = new mutable.LinkedHashMap[ReferenceTarget, ReferenceTarget]() + val ordering = new Ordering[ReferenceTarget] { + override def compare(x: ReferenceTarget, y: ReferenceTarget): Int = x.path.size - y.path.size + } + val bfsQueue = new mutable.PriorityQueue[ReferenceTarget]()(ordering) + bfsQueue.enqueue(root) + while (bfsQueue.nonEmpty) { + val u = bfsQueue.dequeue + for (v <- getEdges(u)) { + if (!prev.contains(v) && !blacklist.contains(v)) { + prev(v) = u + bfsQueue.enqueue(v) + } + } + } + + foundShortCuts ++= bfsShortCuts + bfsShortCuts.clear() + portConnectivityStack.clear() + + prev + } + + /** Linearizes (topologically sorts) a DAG + * + * @throws firrtl.graph.CyclicException if the graph is cyclic + * @return a Seq[T] describing the topological order of the DAG + * traversal + */ + override def linearize: Seq[ReferenceTarget] = { + // permanently marked nodes are implicitly held in order + val order = new mutable.ArrayBuffer[ReferenceTarget] + // invariant: no intersection between unmarked and tempMarked + val unmarked = new mutable.LinkedHashSet[ReferenceTarget] + val tempMarked = new mutable.LinkedHashSet[ReferenceTarget] + val finished = new mutable.LinkedHashSet[ReferenceTarget] + + case class LinearizeFrame[A](v: A, expanded: Boolean) + val callStack = mutable.Stack[LinearizeFrame[ReferenceTarget]]() + + unmarked ++= getVertices + while (unmarked.nonEmpty) { + callStack.push(LinearizeFrame(unmarked.head, false)) + while (callStack.nonEmpty) { + val LinearizeFrame(n, expanded) = callStack.pop() + if (!expanded) { + if (tempMarked.contains(n)) { + throw new CyclicException(n) + } + if (unmarked.contains(n)) { + tempMarked += n + unmarked -= n + callStack.push(LinearizeFrame(n, true)) + // We want to visit the first edge first (so push it last) + for (m <- getEdges(n).toSeq.reverse) { + if (!unmarked.contains(m) && !tempMarked.contains(m) && !finished.contains(m)) { + unmarked += m + } + callStack.push(LinearizeFrame(m, false)) + } + } + } else { + tempMarked -= n + finished += n + order.append(n) + } + } + } + + // visited nodes are in post-traversal order, so must be reversed + order.toSeq.reverse + } + + override def getEdges(source: ReferenceTarget): collection.Set[ReferenceTarget] = { + import ConnectionGraph._ + + val localSource = source.pathlessTarget + + bfsShortCuts.get(localSource) match { + case Some(set) => set.map { x => x.setPathTarget(source.pathTarget) } + case None => + + val pathlessEdges = super.getEdges(localSource) + + val ret = pathlessEdges.flatMap { + + case localSink if withinSameInstance(source)(localSink) => + portConnectivityStack(localSink) = portConnectivityStack.getOrElse(localSource, Nil) + Set[ReferenceTarget](localSink.setPathTarget(source.pathTarget)) + + case localSink if enteringParentInstance(source)(localSink) => + val currentStack = portConnectivityStack.getOrElse(localSource, Nil) + if (currentStack.nonEmpty && currentStack.head.module == localSink.module) { + // Exiting back to parent module + // Update shortcut path from entrance from parent to new exit to parent + val instancePort = currentStack.head + val modulePort = ReferenceTarget( + localSource.circuit, + localSource.module, + Nil, + instancePort.component.head.value.toString, + instancePort.component.tail + ) + val destinations = bfsShortCuts.getOrElse(modulePort, mutable.HashSet.empty[ReferenceTarget]) + bfsShortCuts(modulePort) = destinations + localSource + // Remove entrance from parent from stack + portConnectivityStack(localSink) = currentStack.tail + } else { + // Exiting to parent, but had unresolved trip through child, so don't update shortcut + portConnectivityStack(localSink) = localSource +: currentStack + } + Set[ReferenceTarget](localSink.setPathTarget(source.noComponents.targetParent.asInstanceOf[IsComponent].pathTarget)) + + case localSink if enteringChildInstance(source)(localSink) => + portConnectivityStack(localSink) = localSource +: portConnectivityStack.getOrElse(localSource, Nil) + val x = localSink.setPathTarget(source.pathTarget.instOf(source.ref, localSink.module)) + Set[ReferenceTarget](x) + + case localSink if leavingRootInstance(source)(localSink) => Set[ReferenceTarget]() + + case localSink if enteringNonParentInstance(source)(localSink) => Set[ReferenceTarget]() + + case other => Utils.throwInternalError(s"BAD? $source -> $other") + + } + ret + } + + } + + override def path(start: ReferenceTarget, end: ReferenceTarget, blacklist: collection.Set[ReferenceTarget]): Seq[ReferenceTarget] = { + insertShortCuts(super.path(start, end, blacklist)) + } + + private def insertShortCuts(path: Seq[ReferenceTarget]): Seq[ReferenceTarget] = { + val soFar = mutable.HashSet[ReferenceTarget]() + if (path.size > 1) { + path.head +: path.sliding(2).flatMap { + case Seq(from, to) => + getShortCut(from) match { + case Some(set) if set.contains(to) && soFar.contains(from.pathlessTarget) => + soFar += from.pathlessTarget + Seq(from.pathTarget.ref("..."), to) + case _ => + soFar += from.pathlessTarget + Seq(to) + } + }.toSeq + } else path + } + + /** Finds all paths starting at a particular node in a DAG + * + * WARNING: This is an exponential time algorithm (as any algorithm + * must be for this problem), but is useful for flattening circuit + * graph hierarchies. Each path is represented by a Seq[T] of nodes + * in a traversable order. + * + * @param start the node to start at + * @return a `Map[T,Seq[Seq[T]]]` where the value associated with v is the Seq of all paths from start to v + */ + override def pathsInDAG(start: ReferenceTarget): mutable.LinkedHashMap[ReferenceTarget, Seq[Seq[ReferenceTarget]]] = { + val linkedMap = super.pathsInDAG(start) + linkedMap.keysIterator.foreach { key => + linkedMap(key) = linkedMap(key).map(insertShortCuts) + } + linkedMap + } + + override def findSCCs: Seq[Seq[ReferenceTarget]] = Utils.throwInternalError("Cannot call findSCCs on ConnectionGraph") +} + +object ConnectionGraph { + + /** Returns a [[firrtl.graph.DiGraph]] of [[firrtl.annotations.Target]] and corresponding [[IRLookup]] + * Represents the directed connectivity of a FIRRTL circuit + * + * @param circuit firrtl AST of graph to be constructed. + * @return [[ConnectionGraph]] of this `circuit`. + */ + def apply(circuit: Circuit): ConnectionGraph = buildCircuitGraph(circuit) + + /** Within a module, given an [[firrtl.ir.Expression]] inside a module, return a corresponding [[firrtl.annotations.ReferenceTarget]] + * @todo why no subaccess. + * + * @param m Target of module containing the expression + * @param e + * @return + */ + def asTarget(m: ModuleTarget, tagger: TokenTagger)(e: FirrtlNode): ReferenceTarget = e match { + case l: Literal => m.ref(tagger.getRef(l.value.toString)) + case r: Reference => m.ref(r.name) + case s: SubIndex => asTarget(m, tagger)(s.expr).index(s.value) + case s: SubField => asTarget(m, tagger)(s.expr).field(s.name) + case d: DoPrim => m.ref(tagger.getRef(d.op.serialize)) + case _: Mux => m.ref(tagger.getRef("mux")) + case _: ValidIf => m.ref(tagger.getRef("validif")) + case WInvalid => m.ref(tagger.getRef("invalid")) + case _: Print => m.ref(tagger.getRef("print")) + case _: Stop => m.ref(tagger.getRef("print")) + case other => sys.error(s"Unsupported: $other") + } + + def withinSameInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = { + source.encapsulatingModule == localSink.encapsulatingModule + } + + def enteringParentInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = { + def b1 = source.path.nonEmpty + + def b2 = source.noComponents.targetParent.asInstanceOf[InstanceTarget].encapsulatingModule == localSink.module + + def b3 = localSink.ref == source.path.last._1.value + + b1 && b2 && b3 + } + + def enteringNonParentInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = { + source.path.nonEmpty && + (source.noComponents.targetParent.asInstanceOf[InstanceTarget].encapsulatingModule != localSink.module || + localSink.ref != source.path.last._1.value) + } + + def enteringChildInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = source match { + case ReferenceTarget(_, _, _, _, TargetToken.Field(port) +: comps) + if port == localSink.ref && comps == localSink.component => true + case _ => false + } + + def leavingRootInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = source match { + case ReferenceTarget(_, _, Seq(), port, comps) + if port == localSink.component.head.value && comps == localSink.component.tail => true + case _ => false + } + + + private def buildCircuitGraph(circuit: Circuit): ConnectionGraph = { + val mdg = new MutableDiGraph[ReferenceTarget]() + val declarations = mutable.LinkedHashMap[ModuleTarget, mutable.LinkedHashMap[ReferenceTarget, FirrtlNode]]() + val circuitTarget = CircuitTarget(circuit.main) + val moduleMap = circuit.modules.map { m => circuitTarget.module(m.name) -> m }.toMap + + circuit map buildModule(circuitTarget) + + def addLabeledVertex(v: ReferenceTarget, f: FirrtlNode): Unit = { + mdg.addVertex(v) + declarations.getOrElseUpdate(v.moduleTarget, mutable.LinkedHashMap.empty[ReferenceTarget, FirrtlNode])(v) = f + } + + def buildModule(c: CircuitTarget)(module: DefModule): DefModule = { + val m = c.module(module.name) + module map buildPort(m) map buildStatement(m, new TokenTagger()) + } + + def buildPort(m: ModuleTarget)(port: Port): Port = { + val p = m.ref(port.name) + addLabeledVertex(p, port) + port + } + + def buildInstance(m: ModuleTarget, tagger: TokenTagger, name: String, ofModule: String, tpe: Type): Unit = { + val instPorts = Utils.create_exps(Reference(name, tpe, InstanceKind, SinkFlow)) + val modulePorts = tpe.asInstanceOf[BundleType].fields.flatMap { + // Module output + case firrtl.ir.Field(name, Default, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SourceFlow)) + // Module input + case firrtl.ir.Field(name, Flip, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SinkFlow)) + } + assert(instPorts.size == modulePorts.size) + val o = m.circuitTarget.module(ofModule) + instPorts.zip(modulePorts).foreach { x => + val (instExp, modExp) = x + val it = asTarget(m, tagger)(instExp) + val mt = asTarget(o, tagger)(modExp) + (Utils.flow(instExp), Utils.flow(modExp)) match { + case (SourceFlow, SinkFlow) => mdg.addPairWithEdge(it, mt) + case (SinkFlow, SourceFlow) => mdg.addPairWithEdge(mt, it) + case _ => sys.error("Something went wrong...") + } + } + } + + def buildMemory(mt: ModuleTarget, d: DefMemory): Unit = { + val readers = d.readers.toSet + val readwriters = d.readwriters.toSet + val mem = mt.ref(d.name) + MemPortUtils.memType(d).fields.foreach { + case Field(name, _, _: BundleType) if readers.contains(name) || readwriters.contains(name) => + val port = mem.field(name) + val sources = Seq( + port.field("clk"), + port.field("en"), + port.field("addr") + ) ++ (if (readwriters.contains(name)) Seq(port.field("wmode")) else Nil) + + val data = if (readers.contains(name)) port.field("data") else port.field("rdata") + val sinks = data.leafSubTargets(d.dataType) + + sources.foreach { + mdg.addVertex + } + sinks.foreach { sink => + mdg.addVertex(sink) + sources.foreach { source => mdg.addEdge(source, sink) } + } + case _ => + } + } + + def buildRegister(m: ModuleTarget, tagger: TokenTagger, d: DefRegister): Unit = { + val regTarget = m.ref(d.name) + val clockTarget = regTarget.clock + val resetTarget = regTarget.reset + val initTarget = regTarget.init + + // Build clock expression + mdg.addVertex(clockTarget) + buildExpression(m, tagger, clockTarget)(d.clock) + + // Build reset expression + mdg.addVertex(resetTarget) + buildExpression(m, tagger, resetTarget)(d.reset) + + // Connect each subTarget to the corresponding init subTarget + val allRegTargets = regTarget.leafSubTargets(d.tpe) + val allInitTargets = initTarget.leafSubTargets(d.tpe).zip(Utils.create_exps(d.init)) + allRegTargets.zip(allInitTargets).foreach { case (r, (i, e)) => + mdg.addVertex(i) + mdg.addVertex(r) + mdg.addEdge(clockTarget, r) + mdg.addEdge(resetTarget, r) + mdg.addEdge(i, r) + buildExpression(m, tagger, i)(e) + } + } + + def buildStatement(m: ModuleTarget, tagger: TokenTagger)(stmt: Statement): Statement = { + stmt match { + case d: DefWire => + addLabeledVertex(m.ref(d.name), stmt) + + case d: DefNode => + val sinkTarget = m.ref(d.name) + addLabeledVertex(sinkTarget, stmt) + val nodeTargets = sinkTarget.leafSubTargets(d.value.tpe) + nodeTargets.zip(Utils.create_exps(d.value)).foreach { case (n, e) => + mdg.addVertex(n) + buildExpression(m, tagger, n)(e) + } + + case c: Connect => + val sinkTarget = asTarget(m, tagger)(c.loc) + mdg.addVertex(sinkTarget) + buildExpression(m, tagger, sinkTarget)(c.expr) + + case i: IsInvalid => + val sourceTarget = asTarget(m, tagger)(WInvalid) + addLabeledVertex(sourceTarget, stmt) + mdg.addVertex(sourceTarget) + val sinkTarget = asTarget(m, tagger)(i.expr) + sinkTarget.allSubTargets(i.expr.tpe).foreach { st => + mdg.addVertex(st) + mdg.addEdge(sourceTarget, st) + } + + case DefInstance(_, name, ofModule, tpe) => + addLabeledVertex(m.ref(name), stmt) + buildInstance(m, tagger, name, ofModule, tpe) + + case d: DefRegister => + addLabeledVertex(m.ref(d.name), d) + buildRegister(m, tagger, d) + + case d: DefMemory => + addLabeledVertex(m.ref(d.name), d) + buildMemory(m, d) + + /** @todo [[firrtl.Transform.prerequisites]] ++ [[firrtl.passes.ExpandWhensAndCheck]]*/ + case _: Conditionally => sys.error("Unsupported! Only works on Middle Firrtl") + + case s: Block => s map buildStatement(m, tagger) + + case a: Attach => + val attachTargets = a.exprs.map { r => + val at = asTarget(m, tagger)(r) + mdg.addVertex(at) + at + } + attachTargets.combinations(2).foreach { case Seq(l, r) => + mdg.addEdge(l, r) + mdg.addEdge(r, l) + } + case p: Print => addLabeledVertex(asTarget(m, tagger)(p), p) + case s: Stop => addLabeledVertex(asTarget(m, tagger)(s), s) + case EmptyStmt => + } + stmt + } + + def buildExpression(m: ModuleTarget, tagger: TokenTagger, sinkTarget: ReferenceTarget)(expr: Expression): Expression = { + /** @todo [[firrtl.Transform.prerequisites]] ++ [[firrtl.stage.Forms.Resolved]]. */ + val sourceTarget = asTarget(m, tagger)(expr) + mdg.addVertex(sourceTarget) + mdg.addEdge(sourceTarget, sinkTarget) + expr match { + case _: DoPrim | _: Mux | _: ValidIf | _: Literal => + addLabeledVertex(sourceTarget, expr) + expr map buildExpression(m, tagger, sourceTarget) + case _ => + } + expr + } + + new ConnectionGraph(circuit, DiGraph(mdg), new IRLookup(declarations.mapValues(_.toMap).toMap, moduleMap)) + } +} + + +/** Used for obtaining a tag for a given label unnamed Target. */ +class TokenTagger { + private val counterMap = mutable.HashMap[String, Int]() + + def getTag(label: String): Int = { + val tag = counterMap.getOrElse(label, 0) + counterMap(label) = tag + 1 + tag + } + + def getRef(label: String): String = { + "@" + label + "#" + getTag(label) + } +} + +object TokenTagger { + val literalRegex = "@([-]?[0-9]+)#[0-9]+".r +} diff --git a/src/main/scala/firrtl/analyses/IRLookup.scala b/src/main/scala/firrtl/analyses/IRLookup.scala new file mode 100644 index 00000000..f9819ebd --- /dev/null +++ b/src/main/scala/firrtl/analyses/IRLookup.scala @@ -0,0 +1,265 @@ +// See LICENSE for license details. + +package firrtl.analyses + +import firrtl.annotations.TargetToken._ +import firrtl.annotations._ +import firrtl.ir._ +import firrtl.passes.MemPortUtils +import firrtl.{DuplexFlow, ExpKind, Flow, InstanceKind, Kind, MemKind, PortKind, RegKind, SinkFlow, SourceFlow, UnknownFlow, Utils, WInvalid, WireKind} + +import scala.collection.mutable + +object IRLookup { + def apply(circuit: Circuit): IRLookup = ConnectionGraph(circuit).irLookup +} + +/** Handy lookup for obtaining AST information about a given Target + * + * @param declarations Maps references (not subreferences) to declarations + * @param modules Maps module targets to modules + */ +class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map[ReferenceTarget, FirrtlNode]], + private val modules: Map[ModuleTarget, DefModule]) { + + private val flowCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Flow]]() + private val kindCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Kind]]() + private val tpeCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Type]]() + private val exprCache = mutable.HashMap[ModuleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]]() + private val refCache = mutable.HashMap[ModuleTarget, mutable.LinkedHashMap[Kind, mutable.ArrayBuffer[ReferenceTarget]]]() + + + /** @example Given ~Top|MyModule/inst:Other>foo.bar, returns ~Top|Other>foo + * @return the target converted to its local reference + */ + def asLocalRef(t: ReferenceTarget): ReferenceTarget = t.pathlessTarget.copy(component = Nil) + + def flow(t: ReferenceTarget): Flow = flowCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Flow]()).getOrElseUpdate(t.pathlessTarget, Utils.flow(expr(t.pathlessTarget))) + + def kind(t: ReferenceTarget): Kind = kindCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Kind]()).getOrElseUpdate(t.pathlessTarget, Utils.kind(expr(t.pathlessTarget))) + + def tpe(t: ReferenceTarget): Type = tpeCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Type]()).getOrElseUpdate(t.pathlessTarget, expr(t.pathlessTarget).tpe) + + /** get expression of the target. + * It can return None for many reasons, including + * - declaration is missing + * - flow is wrong + * - component is wrong + * + * @param t [[firrtl.annotations.ReferenceTarget]] to be queried. + * @param flow flow of the target + * @return Some(e) if expression exists, None if it does not + */ + def getExpr(t: ReferenceTarget, flow: Flow): Option[Expression] = { + val pathless = t.pathlessTarget + + inCache(pathless, flow) match { + case e@Some(_) => return e + case None => + val mt = pathless.moduleTarget + val emt = t.encapsulatingModuleTarget + if (declarations.contains(emt) && declarations(emt).contains(asLocalRef(t))) { + declarations(emt)(asLocalRef(t)) match { + case e: Expression => + require(e.tpe.isInstanceOf[GroundType]) + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).getOrElseUpdate((pathless, Utils.flow(e)), e) + case d: IsDeclaration => d match { + case n: DefNode => + updateExpr(mt, Reference(n.name, n.value.tpe, ExpKind, SourceFlow)) + case p: Port => + updateExpr(mt, Reference(p.name, p.tpe, PortKind, Utils.get_flow(p))) + case w: DefInstance => + updateExpr(mt, Reference(w.name, w.tpe, InstanceKind, SourceFlow)) + case w: DefWire => + updateExpr(mt, Reference(w.name, w.tpe, WireKind, SourceFlow)) + updateExpr(mt, Reference(w.name, w.tpe, WireKind, SinkFlow)) + updateExpr(mt, Reference(w.name, w.tpe, WireKind, DuplexFlow)) + case r: DefRegister if pathless.tokens.last == Clock => + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.clock + case r: DefRegister if pathless.tokens.isDefinedAt(1) && pathless.tokens(1) == Init => + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.init + updateExpr(pathless, r.init) + case r: DefRegister if pathless.tokens.last == Reset => + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.reset + case r: DefRegister => + updateExpr(mt, Reference(r.name, r.tpe, RegKind, SourceFlow)) + updateExpr(mt, Reference(r.name, r.tpe, RegKind, SinkFlow)) + updateExpr(mt, Reference(r.name, r.tpe, RegKind, DuplexFlow)) + case m: DefMemory => + updateExpr(mt, Reference(m.name, MemPortUtils.memType(m), MemKind, SourceFlow)) + case other => + sys.error(s"Cannot call expr with: $t, given declaration $other") + } + case _: IsInvalid => + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = WInvalid + } + } + } + + inCache(pathless, flow) + } + + /** + * @param t [[firrtl.annotations.ReferenceTarget]] to be queried. + * @param flow flow of the target + * @return expression of `t` + */ + def expr(t: ReferenceTarget, flow: Flow = UnknownFlow): Expression = { + require(contains(t), s"Cannot find\n${t.prettyPrint()}\nin circuit!") + getExpr(t, flow) match { + case Some(e) => e + case None => + require(getExpr(t.pathlessTarget, UnknownFlow).isEmpty, s"Illegal flow $flow with target $t") + sys.error("") + } + } + + /** Find [[firrtl.annotations.ReferenceTarget]] with a specific [[firrtl.Kind]] in a [[firrtl.annotations.ModuleTarget]] + * + * @param moduleTarget [[firrtl.annotations.ModuleTarget]] to be queried. + * @param kind [[firrtl.Kind]] to be find. + * @return all [[firrtl.annotations.ReferenceTarget]] in this node. */ + def kindFinder(moduleTarget: ModuleTarget, kind: Kind): Seq[ReferenceTarget] = { + def updateRefs(kind: Kind, rt: ReferenceTarget): Unit = refCache + .getOrElseUpdate(rt.moduleTarget, mutable.LinkedHashMap.empty[Kind, mutable.ArrayBuffer[ReferenceTarget]]) + .getOrElseUpdate(kind, mutable.ArrayBuffer.empty[ReferenceTarget]) += rt + + require(contains(moduleTarget), s"Cannot find\n${moduleTarget.prettyPrint()}\nin circuit!") + if (refCache.contains(moduleTarget) && refCache(moduleTarget).contains(kind)) refCache(moduleTarget)(kind).toSeq + else { + declarations(moduleTarget).foreach { + case (rt, _: DefRegister) => updateRefs(RegKind, rt) + case (rt, _: DefWire) => updateRefs(WireKind, rt) + case (rt, _: DefNode) => updateRefs(ExpKind, rt) + case (rt, _: DefMemory) => updateRefs(MemKind, rt) + case (rt, _: DefInstance) => updateRefs(InstanceKind, rt) + case (rt, _: Port) => updateRefs(PortKind, rt) + case _ => + } + refCache.get(moduleTarget).map(_.getOrElse(kind, Seq.empty[ReferenceTarget])).getOrElse(Seq.empty[ReferenceTarget]).toSeq + } + } + + /** + * @param t [[firrtl.annotations.ReferenceTarget]] to be queried. + * @return the statement containing the declaration of the target + */ + def declaration(t: ReferenceTarget): FirrtlNode = { + require(contains(t), s"Cannot find\n${t.prettyPrint()}\nin circuit!") + declarations(t.encapsulatingModuleTarget)(asLocalRef(t)) + } + + /** Returns the references to the module's ports + * + * @param mt [[firrtl.annotations.ModuleTarget]] to be queried. + * @return the port references of `mt` + */ + def ports(mt: ModuleTarget): Seq[ReferenceTarget] = { + require(contains(mt), s"Cannot find\n${mt.prettyPrint()}\nin circuit!") + modules(mt).ports.map { p => mt.ref(p.name) } + } + + /** Given: + * A [[firrtl.annotations.ReferenceTarget]] of ~Top|Module>ref, which is a type of {foo: {bar: UInt}} + * Return: + * Seq(~Top|Module>ref, ~Top|Module>ref.foo, ~Top|Module>ref.foo.bar) + * + * @return a target to each sub-component, including intermediate subcomponents + */ + def allTargets(r: ReferenceTarget): Seq[ReferenceTarget] = r.allSubTargets(tpe(r)) + + /** Given: + * A [[firrtl.annotations.ReferenceTarget]] of ~Top|Module>ref and a type of {foo: {bar: UInt}} + * Return: + * Seq(~Top|Module>ref.foo.bar) + * + * @return a target to each sub-component, excluding intermediate subcomponents. + */ + def leafTargets(r: ReferenceTarget): Seq[ReferenceTarget] = r.leafSubTargets(tpe(r)) + + /** @return Returns ((inputs, outputs)) target and type of each module port. */ + def moduleLeafPortTargets(m: ModuleTarget): (Seq[(ReferenceTarget, Type)], Seq[(ReferenceTarget, Type)]) = + modules(m).ports.flatMap { + case Port(_, name, Output, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SourceFlow)) + case Port(_, name, Input, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SinkFlow)) + }.foldLeft((Vector.empty[(ReferenceTarget, Type)], Vector.empty[(ReferenceTarget, Type)])) { + case ((inputs, outputs), e) if Utils.flow(e) == SourceFlow => + (inputs, outputs :+ (ConnectionGraph.asTarget(m, new TokenTagger())(e), e.tpe)) + case ((inputs, outputs), e) => + (inputs :+ (ConnectionGraph.asTarget(m, new TokenTagger())(e), e.tpe), outputs) + } + + + /** @param t [[firrtl.annotations.ReferenceTarget]] to be queried. + * @return whether a ReferenceTarget is contained in this IRLookup + */ + def contains(t: ReferenceTarget): Boolean = validPath(t.pathTarget) && + declarations.contains(t.encapsulatingModuleTarget) && + declarations(t.encapsulatingModuleTarget).contains(asLocalRef(t)) && + getExpr(t, UnknownFlow).nonEmpty + + /** @param mt [[firrtl.annotations.ModuleTarget]] or [[firrtl.annotations.InstanceTarget]] to be queried. + * @return whether a ModuleTarget or InstanceTarget is contained in this IRLookup + */ + def contains(mt: IsModule): Boolean = validPath(mt) + + /** @param t [[firrtl.annotations.ReferenceTarget]] to be queried. + * @return whether a given [[firrtl.annotations.IsModule]] is valid, given the circuit's module/instance hierarchy + */ + def validPath(t: IsModule): Boolean = { + t match { + case m: ModuleTarget => declarations.contains(m) + case i: InstanceTarget => + val all = i.pathAsTargets :+ i.encapsulatingModuleTarget.instOf(i.instance, i.ofModule) + all.map { x => + declarations.contains(x.moduleTarget) && declarations(x.moduleTarget).contains(x.asReference) && + (declarations(x.moduleTarget)(x.asReference) match { + case DefInstance(_, _, of, _) if of == x.ofModule => validPath(x.ofModuleTarget) + case _ => false + }) + }.reduce(_ && _) + } + } + + /** Updates expression cache with expression. */ + private def updateExpr(mt: ModuleTarget, ref: Expression): Unit = { + val refs = Utils.expandRef(ref) + refs.foreach { e => + val target = ConnectionGraph.asTarget(mt, new TokenTagger())(e) + exprCache(target.moduleTarget)((target, Utils.flow(e))) = e + } + } + + /** Updates expression cache with expression. */ + private def updateExpr(gt: ReferenceTarget, e: Expression): Unit = { + val g = Utils.flow(e) + e.tpe match { + case _: GroundType => + exprCache(gt.moduleTarget)((gt, g)) = e + case VectorType(t, size) => + exprCache(gt.moduleTarget)((gt, g)) = e + (0 until size).foreach { i => updateExpr(gt.index(i), SubIndex(e, i, t, g)) } + case BundleType(fields) => + exprCache(gt.moduleTarget)((gt, g)) = e + fields.foreach { f => updateExpr(gt.field(f.name), SubField(e, f.name, f.tpe, Utils.times(g, f.flip))) } + case other => sys.error(s"Error! Unexpected type $other") + } + } + + /** Optionally returns the expression corresponding to the target if contained in the expression cache. */ + private def inCache(pathless: ReferenceTarget, flow: Flow): Option[Expression] = { + (flow, + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains((pathless, SourceFlow)), + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains((pathless, SinkFlow)), + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains(pathless, DuplexFlow) + ) match { + case (SourceFlow, true, _, _) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, flow))) + case (SinkFlow, _, true, _) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, flow))) + case (DuplexFlow, _, _, true) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, DuplexFlow))) + case (UnknownFlow, _, _, true) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, DuplexFlow))) + case (UnknownFlow, true, false, false) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow))) + case (UnknownFlow, false, true, false) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SinkFlow))) + case _ => None + } + } +} diff --git a/src/main/scala/firrtl/analyses/InstanceGraph.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala index ddd5eb8b..c0aec4aa 100644 --- a/src/main/scala/firrtl/analyses/InstanceGraph.scala +++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala @@ -30,18 +30,18 @@ class InstanceGraph(c: Circuit) { val moduleMap = c.modules.map({m => (m.name,m) }).toMap private val instantiated = new mutable.LinkedHashSet[String] private val childInstances = - new mutable.LinkedHashMap[String, mutable.LinkedHashSet[WDefInstance]] + new mutable.LinkedHashMap[String, mutable.LinkedHashSet[DefInstance]] for (m <- c.modules) { - childInstances(m.name) = new mutable.LinkedHashSet[WDefInstance] + childInstances(m.name) = new mutable.LinkedHashSet[DefInstance] m.foreach(InstanceGraph.collectInstances(childInstances(m.name))) instantiated ++= childInstances(m.name).map(i => i.module) } - private val instanceGraph = new MutableDiGraph[WDefInstance] - private val instanceQueue = new mutable.Queue[WDefInstance] + private val instanceGraph = new MutableDiGraph[DefInstance] + private val instanceQueue = new mutable.Queue[DefInstance] for (subTop <- c.modules.view.map(_.name).filterNot(instantiated)) { - val topInstance = WDefInstance(subTop,subTop) + val topInstance = DefInstance(subTop,subTop) instanceQueue.enqueue(topInstance) while (instanceQueue.nonEmpty) { val current = instanceQueue.dequeue @@ -57,11 +57,11 @@ class InstanceGraph(c: Circuit) { } // The true top module (circuit main) - private val trueTopInstance = WDefInstance(c.main, c.main) + private val trueTopInstance = DefInstance(c.main, c.main) /** A directed graph showing the instance dependencies among modules - * in the circuit. Every WDefInstance of a module has an edge to - * every WDefInstance arising from every instance statement in + * in the circuit. Every DefInstance of a module has an edge to + * every DefInstance arising from every instance statement in * that module. */ lazy val graph = DiGraph(instanceGraph) @@ -69,7 +69,7 @@ class InstanceGraph(c: Circuit) { /** A list of absolute paths (each represented by a Seq of instances) * of all module instances in the Circuit. */ - lazy val fullHierarchy: mutable.LinkedHashMap[WDefInstance,Seq[Seq[WDefInstance]]] = graph.pathsInDAG(trueTopInstance) + lazy val fullHierarchy: mutable.LinkedHashMap[DefInstance,Seq[Seq[DefInstance]]] = graph.pathsInDAG(trueTopInstance) /** A count of the *static* number of instances of each module. For any module other than the top (main) module, this is * equivalent to the number of inst statements in the circuit instantiating each module, irrespective of the number @@ -96,9 +96,9 @@ class InstanceGraph(c: Circuit) { * hierarchy of the top module of the circuit, it will return Nil. * * @param module the name of the selected module - * @return a Seq[ Seq[WDefInstance] ] of absolute instance paths + * @return a Seq[ Seq[DefInstance] ] of absolute instance paths */ - def findInstancesInHierarchy(module: String): Seq[Seq[WDefInstance]] = { + def findInstancesInHierarchy(module: String): Seq[Seq[DefInstance]] = { val instances = graph.getVertices.filter(_.module == module).toSeq instances flatMap { i => fullHierarchy.getOrElse(i, Nil) } } @@ -109,8 +109,8 @@ class InstanceGraph(c: Circuit) { /** Finds the lowest common ancestor instances for two module names in * a design */ - def lowestCommonAncestor(moduleA: Seq[WDefInstance], - moduleB: Seq[WDefInstance]): Seq[WDefInstance] = { + def lowestCommonAncestor(moduleA: Seq[DefInstance], + moduleB: Seq[DefInstance]): Seq[DefInstance] = { tour.rmq(moduleA, moduleB) } @@ -126,7 +126,7 @@ class InstanceGraph(c: Circuit) { /** Given a circuit, returns a map from module name to children * instance/module definitions */ - def getChildrenInstances: mutable.LinkedHashMap[String, mutable.LinkedHashSet[WDefInstance]] = childInstances + def getChildrenInstances: mutable.LinkedHashMap[String, mutable.LinkedHashSet[DefInstance]] = childInstances /** Given a circuit, returns a map from module name to children * instance/module [[firrtl.annotations.TargetToken]]s @@ -147,7 +147,7 @@ class InstanceGraph(c: Circuit) { * in turn mapping instances names to corresponding module names */ def getChildrenInstanceMap: collection.Map[OfModule, collection.Map[Instance, OfModule]] = - childInstances.map(kv => kv._1.OfModule -> asOrderedMap(kv._2, (i: WDefInstance) => i.toTokens)) + childInstances.map(kv => kv._1.OfModule -> asOrderedMap(kv._2, (i: DefInstance) => i.toTokens)) /** The set of all modules in the circuit */ lazy val modules: collection.Set[OfModule] = graph.getVertices.map(_.OfModule) @@ -163,17 +163,17 @@ class InstanceGraph(c: Circuit) { object InstanceGraph { - /** Returns all WDefInstances in a Statement + /** Returns all DefInstances in a Statement * * @param insts mutable datastructure to append to * @param s statement to descend * @return */ - def collectInstances(insts: mutable.Set[WDefInstance]) + def collectInstances(insts: mutable.Set[DefInstance]) (s: Statement): Unit = s match { - case i: WDefInstance => insts += i - case i: DefInstance => throwInternalError("Expecting WDefInstance, found a DefInstance!") - case i: WDefInstanceConnector => throwInternalError("Expecting WDefInstance, found a WDefInstanceConnector!") + case i: DefInstance => insts += i + case i: DefInstance => throwInternalError("Expecting DefInstance, found a DefInstance!") + case i: WDefInstanceConnector => throwInternalError("Expecting DefInstance, found a DefInstanceConnector!") case _ => s.foreach(collectInstances(insts)) } } diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala index e7ea07ca..4d1cdc2f 100644 --- a/src/main/scala/firrtl/annotations/Target.scala +++ b/src/main/scala/firrtl/annotations/Target.scala @@ -99,6 +99,15 @@ sealed trait Target extends Named { /** Whether the target is directly instantiated in its root module */ def isLocal: Boolean + + /** Share root module */ + def sharedRoot(other: Target): Boolean = this.moduleOpt == other.moduleOpt && other.moduleOpt.nonEmpty + + /** Checks whether this is inside of other */ + def encapsulatedBy(other: IsModule): Boolean = this.moduleOpt.contains(other.encapsulatingModule) + + /** @return Returns the instance hierarchy path, if one exists */ + def path: Seq[(Instance, OfModule)] } object Target { @@ -193,6 +202,33 @@ object Target { case b: InstanceTarget => b.ofModuleTarget case b: ReferenceTarget => b.pathlessTarget.moduleTarget } + + def getPathlessTarget(t: Target): Target = { + t.tryToComplete match { + case c: CircuitTarget => c + case m: IsMember => m.pathlessTarget + case t: GenericTarget if t.isLegal => + val newTokens = t.tokens.dropWhile(x => x.isInstanceOf[Instance] || x.isInstanceOf[OfModule]) + GenericTarget(t.circuitOpt, t.moduleOpt, newTokens) + case other => sys.error(s"Can't make $other pathless!") + } + } + + def getReferenceTarget(t: Target): Target = { + (t.toGenericTarget match { + case t: GenericTarget if t.isLegal => + val newTokens = t.tokens.reverse.dropWhile({ + case x: Field => true + case x: Index => true + case Clock => true + case Init => true + case Reset => true + case other => false + }).reverse + GenericTarget(t.circuitOpt, t.moduleOpt, newTokens) + case other => sys.error(s"Can't make $other pathless!") + }).tryToComplete + } } /** Represents incomplete or non-standard [[Target]]s @@ -235,6 +271,12 @@ case class GenericTarget(circuitOpt: Option[String], override def isLocal: Boolean = !(getPath.nonEmpty && getPath.get.nonEmpty) + def path: Vector[(Instance, OfModule)] = if(isComplete){ + tokens.zip(tokens.tail).collect { + case (i: Instance, o: OfModule) => (i, o) + } + } else Vector.empty[(Instance, OfModule)] + /** If complete, return this [[GenericTarget]]'s path * @return */ @@ -342,6 +384,14 @@ case class GenericTarget(circuitOpt: Option[String], def isCircuitTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.isEmpty && tokens.isEmpty def isModuleTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.isEmpty def isComponentTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.nonEmpty + + lazy val (parentModule: Option[String], astModule: Option[String]) = path match { + case Seq() => (None, moduleOpt) + case Seq((i, OfModule(o))) => (moduleOpt, Some(o)) + case seq if seq.size > 1 => + val reversed = seq.reverse + (Some(reversed(1)._2.value), Some(reversed(0)._2.value)) + } } /** Concretely points to a FIRRTL target, no generic selectors @@ -368,7 +418,7 @@ trait CompleteTarget extends Target { override def toTarget: CompleteTarget = this // Very useful for debugging, I (@azidar) think this is reasonable - override def toString = serialize + override def toString: String = serialize } @@ -417,6 +467,13 @@ trait IsMember extends CompleteTarget { * @return */ def setPathTarget(newPath: IsModule): CompleteTarget + + /** @return The [[ModuleTarget]] of the module that directly contains this component */ + def encapsulatingModule: String = if(path.isEmpty) module else path.last._2.value + + def encapsulatingModuleTarget: ModuleTarget = ModuleTarget(circuit, encapsulatingModule) + + def leafModule: String } /** References a module-like target (e.g. a [[ModuleTarget]] or an [[InstanceTarget]]) @@ -435,10 +492,6 @@ trait IsModule extends IsMember { /** A component of a FIRRTL Module (e.g. cannot point to a CircuitTarget or ModuleTarget) */ trait IsComponent extends IsMember { - - /** @return The [[ModuleTarget]] of the module that directly contains this component */ - def encapsulatingModule: String = if(path.isEmpty) module else path.last._2.value - /** Removes n levels of instance hierarchy * * Example: n=1, transforms (Top, A)/b:B/c:C -> (Top, B)/c:C @@ -505,6 +558,8 @@ case class CircuitTarget(circuit: String) extends CompleteTarget { override def addHierarchy(root: String, instance: String): ReferenceTarget = ReferenceTarget(circuit, root, Nil, instance, Nil) + override def path = Seq() + override def toNamed: CircuitName = CircuitName(circuit) } @@ -545,6 +600,8 @@ case class ModuleTarget(circuit: String, module: String) extends IsModule { override def setPathTarget(newPath: IsModule): IsModule = newPath override def toNamed: ModuleName = ModuleName(module, CircuitName(circuit)) + + override def leafModule: String = module } /** Target pointing to a declared named component in a [[firrtl.ir.DefModule]] @@ -631,6 +688,30 @@ case class ReferenceTarget(circuit: String, ReferenceTarget(newPath.circuit, newPath.module, newPath.asPath, ref, component) override def asPath: Seq[(Instance, OfModule)] = path + + def isClock: Boolean = tokens.last == Clock + + def isInit: Boolean = tokens.last == Init + + def isReset: Boolean = tokens.last == Reset + + def noComponents: ReferenceTarget = this.copy(component = Nil) + + def leafSubTargets(tpe: firrtl.ir.Type): Seq[ReferenceTarget] = tpe match { + case _: firrtl.ir.GroundType => Vector(this) + case firrtl.ir.VectorType(t, size) => (0 until size).flatMap { i => index(i).leafSubTargets(t) } + case firrtl.ir.BundleType(fields) => fields.flatMap { f => field(f.name).leafSubTargets(f.tpe)} + case other => sys.error(s"Error! Unexpected type $other") + } + + def allSubTargets(tpe: firrtl.ir.Type): Seq[ReferenceTarget] = tpe match { + case _: firrtl.ir.GroundType => Vector(this) + case firrtl.ir.VectorType(t, size) => this +: (0 until size).flatMap { i => index(i).allSubTargets(t) } + case firrtl.ir.BundleType(fields) => this +: fields.flatMap { f => field(f.name).allSubTargets(f.tpe)} + case other => sys.error(s"Error! Unexpected type $other") + } + + override def leafModule: String = encapsulatingModule } /** Points to an instance declaration of a module (termed an ofModule) @@ -652,6 +733,12 @@ case class InstanceTarget(circuit: String, /** @return a [[ModuleTarget]] referring to declaration of this ofModule */ def ofModuleTarget: ModuleTarget = ModuleTarget(circuit, ofModule) + /** @return a [[ReferenceTarget]] referring to given reference within this instance */ + def addReference(rt: ReferenceTarget): ReferenceTarget = { + require(rt.module == ofModule) + ReferenceTarget(circuit, module, asPath, rt.ref, rt.component) + } + override def circuitOpt: Option[String] = Some(circuit) override def moduleOpt: Option[String] = Some(module) @@ -690,6 +777,8 @@ case class InstanceTarget(circuit: String, override def setPathTarget(newPath: IsModule): InstanceTarget = InstanceTarget(newPath.circuit, newPath.module, newPath.asPath, instance, ofModule) + + override def leafModule: String = ofModule } diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala index f30beec1..32bcac5f 100644 --- a/src/main/scala/firrtl/graph/DiGraph.scala +++ b/src/main/scala/firrtl/graph/DiGraph.scala @@ -2,9 +2,8 @@ package firrtl.graph -import scala.collection.{Set, Map} -import scala.collection.mutable -import scala.collection.mutable.{LinkedHashSet, LinkedHashMap} +import scala.collection.{Map, Set, mutable} +import scala.collection.mutable.{LinkedHashMap, LinkedHashSet} /** An exception that is raised when an assumed DAG has a cycle */ class CyclicException(val node: Any) extends Exception(s"No valid linearization for cyclic graph, found at $node") @@ -18,7 +17,7 @@ object DiGraph { def apply[T](mdg: MutableDiGraph[T]): DiGraph[T] = mdg /** Create a DiGraph from a Map[T,Set[T]] of edge data */ - def apply[T](edgeData: Map[T,Set[T]]): DiGraph[T] = { + def apply[T](edgeData: Map[T, Set[T]]): DiGraph[T] = { val edgeDataCopy = new LinkedHashMap[T, LinkedHashSet[T]] for ((k, v) <- edgeData) { edgeDataCopy(k) = new LinkedHashSet[T] @@ -34,7 +33,7 @@ object DiGraph { } /** Represents common behavior of all directed graphs */ -class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) { +class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) { /** Check whether the graph contains vertex v */ def contains(v: T): Boolean = edges.contains(v) @@ -188,7 +187,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link * * @param start the start node * @param end the destination node - * @throws PathNotFoundException + * @throws firrtl.graph.PathNotFoundException * @return a Seq[T] of nodes defining an arbitrary valid path */ def path(start: T, end: T): Seq[T] = path(start, end, Set.empty[T]) @@ -198,7 +197,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link * @param start the start node * @param end the destination node * @param blacklist list of nodes which break path, if encountered - * @throws PathNotFoundException + * @throws firrtl.graph.PathNotFoundException * @return a Seq[T] of nodes defining an arbitrary valid path */ def path(start: T, end: T, blacklist: Set[T]): Seq[T] = { @@ -336,7 +335,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link * Any edge including a deleted node will be deleted * * @param vprime the Set[T] of desired vertices - * @throws scala.IllegalArgumentException if vprime is not a subset of V + * @throws java.lang.IllegalArgumentException if vprime is not a subset of V * @return the subgraph */ def subgraph(vprime: Set[T]): DiGraph[T] = { @@ -350,12 +349,12 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link * transformed into an edge (u,v). * * @param vprime the Set[T] of desired vertices - * @throws scala.IllegalArgumentException if vprime is not a subset of V + * @throws java.lang.IllegalArgumentException if vprime is not a subset of V * @return the simplified graph */ def simplify(vprime: Set[T]): DiGraph[T] = { require(vprime.subsetOf(edges.keySet)) - val pathEdges = vprime.map( v => (v, reachableFrom(v) & (vprime-v)) ) + val pathEdges = vprime.map(v => (v, reachableFrom(v) & (vprime-v)) ) new DiGraph(new LinkedHashMap[T, LinkedHashSet[T]] ++ pathEdges) } @@ -394,7 +393,7 @@ class MutableDiGraph[T] extends DiGraph[T](new LinkedHashMap[T, LinkedHashSet[T] } /** Add edge (u,v) to the graph. - * @throws scala.IllegalArgumentException if u and/or v is not in the graph + * @throws java.lang.IllegalArgumentException if u and/or v is not in the graph */ def addEdge(u: T, v: T): Unit = { require(contains(u)) diff --git a/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala b/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala new file mode 100644 index 00000000..79922fa9 --- /dev/null +++ b/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala @@ -0,0 +1,50 @@ +// See LICENSE for license details. + +package firrtlTests.analyses + +import firrtl.analyses.CircuitGraph +import firrtl.annotations.CircuitTarget +import firrtl.options.Dependency +import firrtl.passes.ExpandWhensAndCheck +import firrtl.stage.{Forms, TransformManager} +import firrtl.testutils.FirrtlFlatSpec +import firrtl.{ChirrtlForm, CircuitState, FileUtils, UnknownForm} + +class CircuitGraphSpec extends FirrtlFlatSpec { + "CircuitGraph" should "find paths with deep hierarchy quickly" in { + def mkChild(n: Int): String = + s""" module Child${n} : + | input in: UInt<8> + | output out: UInt<8> + | inst c1 of Child${n+1} + | inst c2 of Child${n+1} + | c1.in <= in + | c2.in <= c1.out + | out <= c2.out + """.stripMargin + def mkLeaf(n: Int): String = + s""" module Child${n} : + | input in: UInt<8> + | output out: UInt<8> + | wire middle: UInt<8> + | middle <= in + | out <= middle + """.stripMargin + (2 until 23 by 2).foreach { n => + val input = new StringBuilder() + input ++= + """circuit Child0: + |""".stripMargin + (0 until n).foreach { i => input ++= mkChild(i); input ++= "\n" } + input ++= mkLeaf(n) + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( + CircuitState(parse(input.toString()), UnknownForm) + ).circuit + val circuitGraph = CircuitGraph(circuit) + val C = CircuitTarget("Child0") + val Child0 = C.module("Child0") + circuitGraph.connectionPath(Child0.ref("in"), Child0.ref("out")) + } + } + +} diff --git a/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala b/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala new file mode 100644 index 00000000..06f59a3c --- /dev/null +++ b/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala @@ -0,0 +1,107 @@ +// See LICENSE for license details. + +package firrtlTests.analyses + +import firrtl.{ChirrtlForm, CircuitState, FileUtils, IRToWorkingIR, UnknownForm} +import firrtl.analyses.{CircuitGraph, ConnectionGraph} +import firrtl.annotations.ModuleTarget +import firrtl.options.Dependency +import firrtl.passes.ExpandWhensAndCheck +import firrtl.stage.{Forms, TransformManager} +import firrtl.testutils.FirrtlFlatSpec + +class ConnectionGraphSpec extends FirrtlFlatSpec { + + "ConnectionGraph" should "build connection graph for rocket-chip" in { + ConnectionGraph( + new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( + CircuitState(parse(FileUtils.getTextResource("/regress/RocketCore.fir")), UnknownForm) + ).circuit + ) + } + + val input = + """circuit Test: + | module Test : + | input in: UInt<8> + | input clk: Clock + | input reset: UInt<1> + | output out: {a: UInt<8>, b: UInt<8>[2]} + | out is invalid + | reg r: UInt<8>, clk with: + | (reset => (reset, UInt(0))) + | r <= in + | node x = r + | wire y: UInt<8> + | y <= x + | out.b[0] <= and(y, asUInt(SInt(-1))) + | inst child of Child + | child.in <= in + | out.a <= child.out + | module Child: + | input in: UInt<8> + | output out: UInt<8> + | out <= in + |""".stripMargin + + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( + CircuitState(parse(input), UnknownForm) + ).circuit + + "ConnectionGraph" should "work with pathsInDAG" in { + val Test = ModuleTarget("Test", "Test") + val irGraph = ConnectionGraph(circuit) + + val paths = irGraph.pathsInDAG(Test.ref("in")) + paths(Test.ref("out").field("b").index(0)) shouldBe Seq( + Seq( + Test.ref("in"), + Test.ref("r"), + Test.ref("x"), + Test.ref("y"), + Test.ref("@and#0"), + Test.ref("out").field("b").index(0) + ) + ) + paths(Test.ref("out").field("a")) shouldBe Seq( + Seq( + Test.ref("in"), + Test.ref("child").field("in"), + Test.instOf("child", "Child").ref("in"), + Test.instOf("child", "Child").ref("out"), + Test.ref("child").field("out"), + Test.ref("out").field("a") + ) + ) + + } + + "ConnectionGraph" should "work with path" in { + val Test = ModuleTarget("Test", "Test") + val irGraph = ConnectionGraph(circuit) + + irGraph.path(Test.ref("in"), Test.ref("out").field("b").index(0)) shouldBe Seq( + Test.ref("in"), + Test.ref("r"), + Test.ref("x"), + Test.ref("y"), + Test.ref("@and#0"), + Test.ref("out").field("b").index(0) + ) + + irGraph.path(Test.ref("in"), Test.ref("out").field("a")) shouldBe Seq( + Test.ref("in"), + Test.ref("child").field("in"), + Test.instOf("child", "Child").ref("in"), + Test.instOf("child", "Child").ref("out"), + Test.ref("child").field("out"), + Test.ref("out").field("a") + ) + + irGraph.path(Test.ref("@invalid#0"), Test.ref("out").field("b").index(1)) shouldBe Seq( + Test.ref("@invalid#0"), + Test.ref("out").field("b").index(1) + ) + } + +} diff --git a/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala b/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala new file mode 100644 index 00000000..50ee75ac --- /dev/null +++ b/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala @@ -0,0 +1,296 @@ +package firrtlTests.analyses + +import firrtl.PrimOps.AsUInt +import firrtl.analyses.IRLookup +import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget} +import firrtl._ +import firrtl.ir._ +import firrtl.options.Dependency +import firrtl.passes.ExpandWhensAndCheck +import firrtl.stage.{Forms, TransformManager} +import firrtl.testutils.FirrtlFlatSpec + + +class IRLookupSpec extends FirrtlFlatSpec { + + "IRLookup" should "return declarations" in { + val input = + """circuit Test: + | module Test : + | input in: UInt<8> + | input clk: Clock + | input reset: UInt<1> + | output out: {a: UInt<8>, b: UInt<8>[2]} + | input ana1: Analog<8> + | output ana2: Analog<8> + | out is invalid + | reg r: UInt<8>, clk with: + | (reset => (reset, UInt(0))) + | node x = r + | wire y: UInt<8> + | y <= x + | out.b[0] <= and(y, asUInt(SInt(-1))) + | attach(ana1, ana2) + | inst child of Child + | out.a <= child.out + | module Child: + | output out: UInt<8> + | out <= UInt(1) + |""".stripMargin + + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( + CircuitState(parse(input), UnknownForm) + ).circuit + val irLookup = IRLookup(circuit) + val Test = ModuleTarget("Test", "Test") + val uint8 = UIntType(IntWidth(8)) + + irLookup.declaration(Test.ref("in")) shouldBe Port(NoInfo, "in", Input, uint8) + irLookup.declaration(Test.ref("clk")) shouldBe Port(NoInfo, "clk", Input, ClockType) + irLookup.declaration(Test.ref("reset")) shouldBe Port(NoInfo, "reset", Input, UIntType(IntWidth(1))) + + val out = Port(NoInfo, "out", Output, + BundleType(Seq(Field("a", Default, uint8), Field("b", Default, VectorType(uint8, 2)))) + ) + irLookup.declaration(Test.ref("out")) shouldBe out + irLookup.declaration(Test.ref("out").field("a")) shouldBe out + irLookup.declaration(Test.ref("out").field("b").index(0)) shouldBe out + irLookup.declaration(Test.ref("out").field("b").index(1)) shouldBe out + + irLookup.declaration(Test.ref("ana1")) shouldBe Port(NoInfo, "ana1", Input, AnalogType(IntWidth(8))) + irLookup.declaration(Test.ref("ana2")) shouldBe Port(NoInfo, "ana2", Output, AnalogType(IntWidth(8))) + + val clk = WRef("clk", ClockType, PortKind, SourceFlow) + val reset = WRef("reset", UIntType(IntWidth(1)), PortKind, SourceFlow) + val init = UIntLiteral(0) + val reg = DefRegister(NoInfo, "r", uint8, clk, reset, init) + irLookup.declaration(Test.ref("r")) shouldBe reg + irLookup.declaration(Test.ref("r").clock) shouldBe reg + irLookup.declaration(Test.ref("r").reset) shouldBe reg + irLookup.declaration(Test.ref("r").init) shouldBe reg + irLookup.kindFinder(Test, RegKind) shouldBe Seq(Test.ref("r")) + irLookup.declaration(Test.ref("x")) shouldBe DefNode(NoInfo, "x", WRef("r", uint8, RegKind, SourceFlow)) + irLookup.declaration(Test.ref("y")) shouldBe DefWire(NoInfo, "y", uint8) + + irLookup.declaration(Test.ref("@and#0")) shouldBe + DoPrim(PrimOps.And, + Seq(WRef("y", uint8, WireKind, SourceFlow), DoPrim(AsUInt, Seq(SIntLiteral(-1)), Nil, UIntType(IntWidth(1)))), + Nil, + uint8 + ) + + val inst = WDefInstance(NoInfo, "child", "Child", BundleType(Seq(Field("out", Default, uint8)))) + irLookup.declaration(Test.ref("child")) shouldBe inst + irLookup.declaration(Test.ref("child").field("out")) shouldBe inst + irLookup.declaration(Test.instOf("child", "Child").ref("out")) shouldBe Port(NoInfo, "out", Output, uint8) + + intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Child").ref("missing")) } + intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Missing").ref("out")) } + intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("missing", "Child").ref("out")) } + intercept[IllegalArgumentException]{ irLookup.declaration(Test.ref("missing")) } + intercept[IllegalArgumentException]{ irLookup.declaration(Test.ref("out").field("c")) } + intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Child").ref("out").field("missing")) } + } + + "IRLookup" should "return mem declarations" in { + def commonFields: Seq[String] = Seq("clk", "en", "addr") + def readerTargets(rt: ReferenceTarget): Seq[ReferenceTarget] = { + (commonFields ++ Seq("data")).map(rt.field) + } + def writerTargets(rt: ReferenceTarget): Seq[ReferenceTarget] = { + (commonFields ++ Seq("data", "mask")).map(rt.field) + } + def readwriterTargets(rt: ReferenceTarget): Seq[ReferenceTarget] = { + (commonFields ++ Seq("wdata", "wmask", "wmode", "rdata")).map(rt.field) + } + val input = + s"""circuit Test: + | module Test : + | input in : UInt<8> + | input clk: Clock[3] + | input dataClk: Clock + | input mode: UInt<1> + | output out : UInt<8>[2] + | mem m: + | data-type => UInt<8> + | reader => r + | writer => w + | readwriter => rw + | depth => 2 + | write-latency => 1 + | read-latency => 0 + | + | reg addr: UInt<1>, dataClk + | reg en: UInt<1>, dataClk + | reg indata: UInt<8>, dataClk + | + | m.r.clk <= clk[0] + | m.r.en <= en + | m.r.addr <= addr + | out[0] <= m.r.data + | + | m.w.clk <= clk[1] + | m.w.en <= en + | m.w.addr <= addr + | m.w.data <= indata + | m.w.mask <= en + | + | m.rw.clk <= clk[2] + | m.rw.en <= en + | m.rw.addr <= addr + | m.rw.wdata <= indata + | m.rw.wmask <= en + | m.rw.wmode <= en + | out[1] <= m.rw.rdata + |""".stripMargin + + val C = CircuitTarget("Test") + val MemTest = C.module("Test") + val Mem = MemTest.ref("m") + val Reader = Mem.field("r") + val Writer = Mem.field("w") + val Readwriter = Mem.field("rw") + val allSignals = readerTargets(Reader) ++ writerTargets(Writer) ++ readwriterTargets(Readwriter) + + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( + CircuitState(parse(input), UnknownForm) + ).circuit + val irLookup = IRLookup(circuit) + val uint8 = UIntType(IntWidth(8)) + val mem = DefMemory(NoInfo, "m", uint8, 2, 1, 0, Seq("r"), Seq("w"), Seq("rw")) + allSignals.foreach { at => + irLookup.declaration(at) shouldBe mem + } + } + + "IRLookup" should "return expressions, types, kinds, and flows" in { + val input = + """circuit Test: + | module Test : + | input in: UInt<8> + | input clk: Clock + | input reset: UInt<1> + | output out: {a: UInt<8>, b: UInt<8>[2]} + | input ana1: Analog<8> + | output ana2: Analog<8> + | out is invalid + | reg r: UInt<8>, clk with: + | (reset => (reset, UInt(0))) + | node x = r + | wire y: UInt<8> + | y <= x + | out.b[0] <= and(y, asUInt(SInt(-1))) + | attach(ana1, ana2) + | inst child of Child + | out.a <= child.out + | module Child: + | output out: UInt<8> + | out <= UInt(1) + |""".stripMargin + + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( + CircuitState(parse(input), UnknownForm) + ).circuit + val irLookup = IRLookup(circuit) + val Test = ModuleTarget("Test", "Test") + val uint8 = UIntType(IntWidth(8)) + val uint1 = UIntType(IntWidth(1)) + + def check(rt: ReferenceTarget, e: Expression): Unit = { + irLookup.expr(rt) shouldBe e + irLookup.tpe(rt) shouldBe e.tpe + irLookup.kind(rt) shouldBe Utils.kind(e) + irLookup.flow(rt) shouldBe Utils.flow(e) + } + + check(Test.ref("in"), WRef("in", uint8, PortKind, SourceFlow)) + check(Test.ref("clk"), WRef("clk", ClockType, PortKind, SourceFlow)) + check(Test.ref("reset"), WRef("reset", uint1, PortKind, SourceFlow)) + + val out = Test.ref("out") + val outExpr = + WRef("out", + BundleType(Seq(Field("a", Default, uint8), Field("b", Default, VectorType(uint8, 2)))), + PortKind, + SinkFlow + ) + check(out, outExpr) + check(out.field("a"), WSubField(outExpr, "a", uint8, SinkFlow)) + val outB = out.field("b") + val outBExpr = WSubField(outExpr, "b", VectorType(uint8, 2), SinkFlow) + check(outB, outBExpr) + check(outB.index(0), WSubIndex(outBExpr, 0, uint8, SinkFlow)) + check(outB.index(1), WSubIndex(outBExpr, 1, uint8, SinkFlow)) + + check(Test.ref("ana1"), WRef("ana1", AnalogType(IntWidth(8)), PortKind, SourceFlow)) + check(Test.ref("ana2"), WRef("ana2", AnalogType(IntWidth(8)), PortKind, SinkFlow)) + + val clk = WRef("clk", ClockType, PortKind, SourceFlow) + val reset = WRef("reset", UIntType(IntWidth(1)), PortKind, SourceFlow) + val init = UIntLiteral(0) + check(Test.ref("r"), WRef("r", uint8, RegKind, DuplexFlow)) + check(Test.ref("r").clock, clk) + check(Test.ref("r").reset, reset) + check(Test.ref("r").init, init) + + check(Test.ref("x"), WRef("x", uint8, ExpKind, SourceFlow)) + + check(Test.ref("y"), WRef("y", uint8, WireKind, DuplexFlow)) + + check(Test.ref("@and#0"), + DoPrim(PrimOps.And, + Seq(WRef("y", uint8, WireKind, SourceFlow), DoPrim(AsUInt, Seq(SIntLiteral(-1)), Nil, UIntType(IntWidth(1)))), + Nil, + uint8 + ) + ) + + val child = WRef("child", BundleType(Seq(Field("out", Default, uint8))), InstanceKind, SourceFlow) + check(Test.ref("child"), child) + check(Test.ref("child").field("out"), + WSubField(child, "out", uint8, SourceFlow) + ) + } + + "IRLookup" should "cache expressions" in { + def mkType(i: Int): String = { + if(i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" + } + + val depth = 500 + + val input = + s"""circuit Test: + | module Test : + | input in: ${mkType(depth)} + | output out: ${mkType(depth)} + | out <= in + |""".stripMargin + + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( + CircuitState(parse(input), UnknownForm) + ).circuit + val Test = ModuleTarget("Test", "Test") + val irLookup = IRLookup(circuit) + def mkReferences(parent: ReferenceTarget, i: Int): Seq[ReferenceTarget] = { + if(i == 0) Seq(parent) else { + val newParent = parent.field("x") + newParent +: mkReferences(newParent, i - 1) + } + } + + // Check caching from root to leaf + val inRefs = mkReferences(Test.ref("in"), depth) + val (inStartTime, _) = Utils.time(irLookup.expr(inRefs.head)) + inRefs.tail.foreach { r => + val (ms, _) = Utils.time(irLookup.expr(r)) + require(inStartTime > ms) + } + val outRefs = mkReferences(Test.ref("out"), depth).reverse + val (outStartTime, _) = Utils.time(irLookup.expr(outRefs.head)) + outRefs.tail.foreach { r => + val (ms, _) = Utils.time(irLookup.expr(r)) + require(outStartTime > ms) + } + } +} diff --git a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala index 4d38340f..f847fb6c 100644 --- a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala +++ b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala @@ -1,3 +1,5 @@ +// See LICENSE for license details. + package firrtlTests package transforms |
