diff options
| author | Schuyler Eldridge | 2018-01-15 18:53:28 -0500 |
|---|---|---|
| committer | Jack Koenig | 2018-01-15 15:53:28 -0800 |
| commit | 347cc522e96f8090d53b3b042af646e4a0e765b2 (patch) | |
| tree | 5ce616a60da92583bcf996b56fb6ba041c68b3a2 | |
| parent | 8e18404b2919ef6226b511bb666116f657082aa8 (diff) | |
WiringTransform Refactor (#648)
Massive refactoring to WiringTransform with the use of a new EulerTour
class to speed things up via fast least common ancestor (LCA) queries.
Changes include (but are not limited to):
* Use lowest common ancestor when wiring
* Add EulerTour class with naive and Berkman-Vishkin RMQ
* Adds LCA method for Instance Graph
* Enables "Two Sources" using "Top" wiring test as this is now valid
* Remove TopAnnotation from WiringTransform
* Represent WiringTransform sink as `Seq[Named]`
* Remove WiringUtils.countInstances, fix imports
* Support sources under sinks in WiringTransform
* Enable internal module wiring
* Support Wiring of Aggregates
h/t @edcote
fixes #728
Signed-off-by: Schuyler Eldridge <schuyler.eldridge@ibm.com>
Reviewed-by: Jack Koenig<jack.koenig3@gmail.com>
| -rw-r--r-- | .gitignore | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/analyses/Netlist.scala | 20 | ||||
| -rw-r--r-- | src/main/scala/firrtl/graph/DiGraph.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/graph/EulerTour.scala | 223 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/clocklist/ClockList.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala | 15 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/DecorateMems.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/wiring/Wiring.scala | 287 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/wiring/WiringTransform.scala | 103 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/wiring/WiringUtils.scala | 226 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/WiringTests.scala | 1035 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/graph/EulerTourTests.scala | 36 |
14 files changed, 1411 insertions, 553 deletions
@@ -56,3 +56,6 @@ project/plugins/project/ gen/ project/project/ /bin/ + +*~ +*#*# diff --git a/src/main/scala/firrtl/analyses/Netlist.scala b/src/main/scala/firrtl/analyses/Netlist.scala index 4e211d8b..99f3645f 100644 --- a/src/main/scala/firrtl/analyses/Netlist.scala +++ b/src/main/scala/firrtl/analyses/Netlist.scala @@ -1,3 +1,5 @@ +// See LICENSE for license details. + package firrtl.analyses import scala.collection.mutable @@ -10,13 +12,14 @@ import firrtl.Mappers._ /** A class representing the instance hierarchy of a working IR Circuit - * + * * @constructor constructs an instance graph from a Circuit * @param c the Circuit to analyze */ class InstanceGraph(c: Circuit) { - private def collectInstances(insts: mutable.Set[WDefInstance])(s: Statement): Statement = s match { + private def collectInstances(insts: mutable.Set[WDefInstance]) + (s: Statement): Statement = s match { case i: WDefInstance => insts += i i @@ -72,7 +75,7 @@ class InstanceGraph(c: Circuit) { /** Finds the absolute paths (each represented by a Seq of instances * representing the chain of hierarchy) of all instances of a * particular module. - * + * * @param module the name of the selected module * @return a Seq[Seq[WDefInstance]] of absolute instance paths */ @@ -81,5 +84,14 @@ class InstanceGraph(c: Circuit) { instances flatMap { i => fullHierarchy(i) } } -} + /** An `[[EulerTour]]` representation of the `[[DiGraph]]` */ + lazy val tour = EulerTour(graph, trueTopInstance) + /** Finds the lowest common ancestor instances for two module names in + * a design + */ + def lowestCommonAncestor(moduleA: Seq[WDefInstance], + moduleB: Seq[WDefInstance]): Seq[WDefInstance] = { + tour.rmq(moduleA, moduleB) + } +} diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala index b869982d..6538c880 100644 --- a/src/main/scala/firrtl/graph/DiGraph.scala +++ b/src/main/scala/firrtl/graph/DiGraph.scala @@ -1,3 +1,5 @@ +// See LICENSE for license details. + package firrtl.graph import scala.collection.{Set, Map} @@ -306,7 +308,6 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link edges.foreach({ case (k, v) => eprime(f(k)) ++= v.map(f(_)) }) new DiGraph(eprime) } - } class MutableDiGraph[T] extends DiGraph[T](new LinkedHashMap[T, LinkedHashSet[T]]) { diff --git a/src/main/scala/firrtl/graph/EulerTour.scala b/src/main/scala/firrtl/graph/EulerTour.scala new file mode 100644 index 00000000..db25d8d0 --- /dev/null +++ b/src/main/scala/firrtl/graph/EulerTour.scala @@ -0,0 +1,223 @@ +// See LICENSE for license details. + +package firrtl.graph + +import scala.collection.mutable + +/** Euler Tour companion object */ +object EulerTour { + /** Create an Euler Tour of a `DiGraph[T]` */ + def apply[T](diGraph: DiGraph[T], start: T): EulerTour[Seq[T]] = { + val r = mutable.Map[Seq[T], Int]() + val e = mutable.ArrayBuffer[Seq[T]]() + val h = mutable.ArrayBuffer[Int]() + + def tour(u: T, parent: Vector[T], height: Int): Unit = { + val id = parent :+ u + r.getOrElseUpdate(id, e.size) + e += id + h += height + diGraph.getEdges(id.last).foreach { v => + tour(v, id, height + 1) + e += id + h += height + } + } + + tour(start, Vector.empty, 0) + new EulerTour(r.toMap, e, h) + } +} + +/** A class that represents an Euler Tour of a directed graph from a + * given root. This requires `O(n)` preprocessing time to generate + * the initial Euler Tour. + * + * @constructor Create a new EulerTour from the specified data + * @param r A map of a node to its first index + * @param e A representation of the EulerTour as a `Seq[T]` + * @param h The depths of the Euler Tour represented as a `Seq[Int]` + */ +class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) { + private def lg(x: Double): Double = math.log(x) / math.log(2) + + /** Range Minimum Query of an Euler Tour using a naive algorithm. + * + * @param x The first query bound + * @param y The second query bound + * @return The minimum between the first and second query + * @note The order of '''x''' and '''y''' does not matter + * @note '''Performance''': + * - preprocessing: `O(1)` + * - query: `O(n)` + */ + def rmqNaive(x: T, y: T): T = { + val Seq(i, j) = Seq(r(x), r(y)).sorted + e.zip(h).slice(i, j + 1).minBy(_._2)._1 + } + + // n: the length of the Euler Tour + // m: the size of blocks the Euler Tour is split into + private val n = h.size + private val m = math.ceil(lg(n) / 2).toInt + + /** Split up the tour into blocks of size m, padding the last block to + * be a multiple of m. Compute the minimum of each block, a, and + * the index of that minimum in each block, b. + */ + private lazy val blocks = (h ++ (1 to (m - n % m))).grouped(m).toArray + private lazy val a = blocks map (_.min) + private lazy val b = blocks map (b => b.indexOf(b.min)) + + /** Construct a Sparse Table (ST) representation for the minimum index + * of a sequence of integers. Data in the returned array is indexed + * as: [base, power of 2 range] + */ + private def constructSparseTable(x: Seq[Int]): Array[Array[Int]] = { + val tmp = Array.ofDim[Int](x.size + 1, math.ceil(lg(x.size)).toInt) + for (i <- 0 to x.size - 1; j <- 0 to math.ceil(lg(x.size)).toInt - 1) { + tmp(i)(j) = -1 + } + + def tableRecursive(base: Int, size: Int): Int = { + if (size == 0) { + tmp(base)(size) = base + base + } else { + val (a, b, c) = (base, base + (1 << (size - 1)), size - 1) + + val l = if (tmp(a)(c) != -1) { tmp(a)(c) } + else { tableRecursive(a, c) } + + val r = if (tmp(b)(c) != -1) { tmp(b)(c) } + else { tableRecursive(b, c) } + + val min = if (x(l) < x(r)) l else r + tmp(base)(size) = min + assert(min >= base) + min + } + } + + for (i <- (0 to x.size - 1); + j <- (0 to math.ceil(lg(x.size)).toInt - 1); + if i + (1 << j) - 1 < x.size) { + tableRecursive(i, j) + } + tmp + } + private lazy val st = constructSparseTable(a) + + /** Precompute all possible RMQs for an array of size `n where each + * entry in the range is different from the last by only +-1 + */ + private def constructTableLookups(n: Int): Array[Array[Array[Int]]] = { + def sortSeqSeq[T <: Int](x: Seq[T], y: Seq[T]): Boolean = { + if (x(0) != y(0)) x(0) < y(0) else sortSeqSeq(x.tail, y.tail) + } + + val size = m - 1 + val out = Seq.fill(size)(Seq(-1, 1)) + .flatten.combinations(m - 1).flatMap(_.permutations).toList + .sortWith(sortSeqSeq) + .map(_.foldLeft(Seq(0))((h, pm) => (h.head + pm) +: h).reverse) + .map{ a => + var tmp = Array.ofDim[Int](m, m) + for (i <- 0 to size; j <- i to size) yield { + val window = a.slice(i, j + 1) + tmp(i)(j) = window.indexOf(window.min) + i } + tmp }.toArray + out + } + private lazy val tables = constructTableLookups(m) + + /** Compute the precomputed table index of a given block */ + private def mapBlockToTable(block: Seq[Int]): Int = { + var index = 0 + var power = block.size - 2 + for (Seq(l, r) <- block.sliding(2)) { + if (l < r) { index += 1 << power } + power -= 1 + } + index + } + + /** Precompute a mapping of all blocks to their precomputed RMQ table + * indices + */ + private def mapBlocksToTables(blocks: Seq[Seq[Int]]): Array[Int] = { + val out = blocks.map(mapBlockToTable(_)).toArray + out + } + private lazy val tableIdx = mapBlocksToTables(blocks) + + /** Range Minimum Query using the Berkman--Vishkin algorithm with the + * simplifications of Bender--Farach-Colton. + * + * @param x The first query bound + * @param y The second query bound + * @return The minimum between the first and second query + * @note The order of '''x''' and '''y''' does not matter + * @note '''Performance''': + * - preprocessing: `O(n)` + * - query: `O(1)` + */ + def rmqBV(x: T, y: T): T = { + val Seq(i, j) = Seq(r(x), r(y)).sorted + + // Compute block and word indices + val (block_i, block_j) = (i / m, j / m) + val (word_i, word_j) = (i % m, j % m) + + /** Up to four possible minimum indices are then computed based on the + * following conditions: + * 1. `i` and `j` are in the same block: + * - one precomputed RMQ from `i` to `j` + * 2. `i` and `j` are in adjacent blocks: + * - one precomputed RMQ from `i` to the end of its block + * - one precomputed RMQ from `j` to the beginning of its block + * 3. `i` and `j` have blocks between them: + * - one precomputed RMQ from `i` to the end of its block + * - one precomputed RMQ from `j` to the beginning of its block + * - two sparse table lookups to fully cover all blocks + * between `i` and `j` + */ + val minIndices = (block_i, block_j) match { + case (bi, bj) if (block_i == block_j) => + val min_i = block_i * m + tables(tableIdx(block_i))(word_i)(word_j) + Seq(min_i) + case (bi, bj) if (block_i == block_j - 1) => + val min_i = block_i * m + tables(tableIdx(block_i))(word_i)( m - 1) + val min_j = block_j * m + tables(tableIdx(block_j))( 0)(word_j) + Seq(min_i, min_j) + case _ => + val min_i = block_i * m + tables(tableIdx(block_i))(word_i)( m - 1) + val min_j = block_j * m + tables(tableIdx(block_j))( 0)(word_j) + val (min_between_l, min_between_r) = { + val range = math.floor(lg(block_j - block_i - 1)).toInt + val base_0 = block_i + 1 + val base_1 = block_j - (1 << range) + + val (idx_0, idx_1) = (st(base_0)(range), st(base_1)(range)) + val (min_0, min_1) = (b(idx_0) + idx_0 * m, b(idx_1) + idx_1 * m) + (min_0, min_1) } + Seq(min_i, min_between_l, min_between_r, min_j) + } + + // Return the minimum of all possible minimum indices + e(minIndices.minBy(h(_))) + } + + /** Range Minimum Query of the Euler Tour. + * + * Use this for typical queries. + * + * @param x The first query bound + * @param y The second query bound + * @return The minimum between the first and second query + * @note This currently maps to `rmqBV`, but may choose to map to + * either `rmqBV` or `rmqNaive` + * @note The order of '''x''' and '''y''' does not matter + */ + def rmq(x: T, y: T): T = rmqBV(x, y) +} diff --git a/src/main/scala/firrtl/passes/clocklist/ClockList.scala b/src/main/scala/firrtl/passes/clocklist/ClockList.scala index bd2536ab..073eb050 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockList.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockList.scala @@ -8,7 +8,7 @@ import firrtl.ir._ import annotations._ import Utils.error import java.io.{File, CharArrayWriter, PrintWriter, Writer} -import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage} +import wiring.WiringUtils.{getChildrenMap, getLineage} import wiring.Lineage import ClockListUtils._ import Utils._ @@ -30,7 +30,7 @@ class ClockList(top: String, writer: Writer) extends Pass { // === Checks === // TODO(izraelevitz): Check all registers/memories use "clock" clock port // ============== - + // Clock sources must be blackbox outputs and top's clock val partialSourceList = getSourceList(moduleMap)(lineages) val sourceList = partialSourceList ++ moduleMap(top).ports.collect{ case Port(i, n, Input, ClockType) => n } @@ -39,7 +39,7 @@ class ClockList(top: String, writer: Writer) extends Pass { // Remove everything from the circuit, unless it has a clock type // This simplifies the circuit drastically so InlineInstances doesn't take forever. val onlyClockCircuit = RemoveAllButClocks.run(c) - + // Inline the clock-only circuit up to the specified top module val modulesToInline = (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet val inlineTransform = new InlineInstances diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala index b04171a7..24f25525 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala @@ -8,7 +8,6 @@ import firrtl.ir._ import annotations._ import Utils.error import java.io.{File, CharArrayWriter, PrintWriter, Writer} -import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage} import wiring.Lineage import ClockListUtils._ import Utils._ @@ -22,23 +21,23 @@ object ClockListAnnotation { [Optional] ClockList List which signal drives each clock of every descendent of specified module -Usage: +Usage: --list-clocks -c:<circuit>:-m:<module>:-o:<filename> *** Note: sub-arguments to --list-clocks should be delimited by : and not white space! -""" - +""" + //Parse pass options val passOptions = PassConfigUtil.getPassOptions(t, usage) val outputConfig = passOptions.getOrElse( - OutputConfigFileName, + OutputConfigFileName, error("No output config file provided for ClockList!" + usage) ) val passCircuit = passOptions.getOrElse( - PassCircuitName, + PassCircuitName, error("No circuit name specified for ClockList!" + usage) ) val passModule = passOptions.getOrElse( - PassModuleName, + PassModuleName, error("No module name specified for ClockList!" + usage) ) passOptions.get(InputConfigFileName) match { @@ -65,7 +64,7 @@ class ClockListTransform extends Transform { def passSeq(top: String, writer: Writer): Seq[Pass] = Seq(new ClockList(top, writer)) def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { - case Seq(ClockListAnnotation(ModuleName(top, CircuitName(state.circuit.main)), out)) => + case Seq(ClockListAnnotation(ModuleName(top, CircuitName(state.circuit.main)), out)) => val outputFile = new PrintWriter(out) val newC = (new ClockList(top, outputFile)).run(state.circuit) outputFile.close() diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala index b81d0c7e..892f1642 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala @@ -8,7 +8,6 @@ import firrtl.ir._ import annotations._ import Utils.error import java.io.{File, CharArrayWriter, PrintWriter, Writer} -import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage} import wiring.Lineage import ClockListUtils._ import Utils._ @@ -61,4 +60,3 @@ object ClockListUtils { } } } - diff --git a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala index 53787b1d..1178ce69 100644 --- a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala +++ b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala @@ -8,8 +8,6 @@ import firrtl.ir._ import annotations._ import Utils.error import java.io.{File, CharArrayWriter, PrintWriter, Writer} -import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage} -import wiring.Lineage import ClockListUtils._ import Utils._ import memlib.AnalysisUtils._ diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala index 648b0234..ad3616ad 100644 --- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala +++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala @@ -18,9 +18,8 @@ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform val cN = CircuitName(state.circuit.main) val oldAnnos = state.annotations.getOrElse(AnnotationMap(Seq.empty)).annotations val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { case ((annos, pins), config) => - val top = TopAnnotation(ModuleName(config.top.name, cN), config.pin.name) val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN)), config.pin.name) - (annos ++ Seq(top, source), pins :+ config.pin.name) + (annos, pins :+ config.pin.name) } state.copy(annotations = Some(AnnotationMap(as :+ PinAnnotation(cN, pins.toSeq)))) } diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala index 9656abb2..a268dba7 100644 --- a/src/main/scala/firrtl/passes/wiring/Wiring.scala +++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala @@ -10,157 +10,194 @@ import firrtl.Mappers._ import scala.collection.mutable import firrtl.annotations._ import firrtl.annotations.AnnotationUtils._ +import firrtl.analyses.InstanceGraph import WiringUtils._ -case class WiringException(msg: String) extends PassException(msg) +/** A data store of one sink--source wiring relationship */ +case class WiringInfo(source: ComponentName, sinks: Seq[Named], pin: String) -case class WiringInfo(source: String, comp: String, sinks: Set[String], pin: String, top: String) +/** A data store of wiring names */ +case class WiringNames(compName: String, source: String, sinks: Seq[Named], + pin: String) +/** Pass that computes and applies a sequence of wiring modifications + * + * @constructor construct a new Wiring pass + * @param wiSeq the [[WiringInfo]] to apply + */ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { - def run(c: Circuit): Circuit = { - wiSeq.foldLeft(c) { (circuit, wi) => wire(circuit, wi) } + def run(c: Circuit): Circuit = analyze(c) + .foldLeft(c){ + case (cx, (tpe, modsMap)) => cx.copy( + modules = cx.modules map onModule(tpe, modsMap)) } + + /** Converts multiple units of wiring information to module modifications */ + private def analyze(c: Circuit): Seq[(Type, Map[String, Modifications])] = { + + val names = wiSeq + .map ( wi => (wi.source, wi.sinks, wi.pin) match { + case (ComponentName(comp, ModuleName(source,_)), sinks, pin) => + WiringNames(comp, source, sinks, pin) }) + + val portNames = mutable.Seq.fill(names.size)(Map[String, String]()) + c.modules.foreach{ m => + val ns = Namespace(m) + names.zipWithIndex.foreach{ case (WiringNames(c, so, si, p), i) => + portNames(i) = portNames(i) + + ( m.name -> { + if (si.exists(getModuleName(_) == m.name)) ns.newName(p) + else ns.newName(tokenize(c) filterNot ("[]." contains _) mkString "_") + })}} + + val iGraph = new InstanceGraph(c) + names.zip(portNames).map{ case(WiringNames(comp, so, si, _), pn) => + computeModifications(c, iGraph, comp, so, si, pn) } } - /** Add pins to modules and wires a signal to them, under the scope of a specified top module - * Description: - * Adds a pin to each sink module - * Punches ports up from the source signal to the specified top module - * Punches ports down to each sink module - * Wires the source up and down, connecting to all sink modules - * Restrictions: - * - Can only have one source module instance under scope of the specified top - * - All instances of each sink module must be under the scope of the specified top - * Notes: - * - No module uniquification occurs (due to imposed restrictions) + /** Converts a single unit of wiring information to module modifications + * + * @param c the circuit that will be modified + * @param iGraph an InstanceGraph representation of the circuit + * @param compName the name of a component + * @param source the name of the source component + * @param sinks a list of sink components/modules that the source + * should be connected to + * @param portNames a mapping of module names to new ports/wires + * that should be generated if needed + * + * @return a tuple of the component type and a map of module names + * to pending modifications */ - def wire(c: Circuit, wi: WiringInfo): Circuit = { - // Split out WiringInfo - val source = wi.source - val sinks = wi.sinks - val compName = wi.comp - val pin = wi.pin - - // Maps modules to children instances, i.e. (instance, module) - val childrenMap = getChildrenMap(c) - - // Check restrictions - val nSources = countInstances(childrenMap, wi.top, source) - if(nSources != 1) - throw new WiringException(s"Cannot have $nSources instance of $source under ${wi.top}") - sinks.foreach { m => - val total = countInstances(childrenMap, c.main, m) - val nTop = countInstances(childrenMap, c.main, wi.top) - val perTop = countInstances(childrenMap, wi.top, m) - if(total != nTop * perTop) - throw new WiringException(s"Module ${wi.top} does not contain all instances of $m.") - } - - // Create valid port names for wiring that have no name conflicts - val portNames = c.modules.foldLeft(Map.empty[String, String]) { (map, m) => - map + (m.name -> { - val ns = Namespace(m) - if(sinks.contains(m.name)) ns.newName(pin) - else ns.newName(tokenize(compName) filterNot ("[]." contains _) mkString "_") - }) - } - - // Create a lineage tree from children map - val lineages = getLineage(childrenMap, wi.top) - - // Populate lineage tree with relationship information, i.e. who is source, - // sink, parent of source, etc. - val withFields = setSharedParent(wi.top)(setFields(sinks, source)(lineages)) - - // Populate lineage tree with what to instantiate, connect to/from, etc. - val withThings = setThings(portNames, compName)(withFields) + private def computeModifications(c: Circuit, + iGraph: InstanceGraph, + compName: String, + source: String, + sinks: Seq[Named], + portNames: Map[String, String]): + (Type, Map[String, Modifications]) = { - // Create a map from module name to lineage information - val map = pointToLineage(withThings) - - // Obtain the source component type val sourceComponentType = getType(c, source, compName) - - // Return new circuit with correct wiring - val cx = c.copy(modules = c.modules map onModule(map, sourceComponentType)) - - // Replace inserted IR nodes with WIR nodes - ToWorkingIR.run(cx) + val sinkComponents: Map[String, Seq[String]] = sinks + .collect{ case ComponentName(c, ModuleName(m, _)) => (c, m) } + .foldLeft(new scala.collection.immutable.HashMap[String, Seq[String]]){ + case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil)) ) } + + // Determine "ownership" of sources to sinks via minimum distance + val owners = sinksToSources(sinks, source, iGraph) + + // Determine port and pending modifications for all sink--source + // ownership pairs + val meta = new mutable.HashMap[String, Modifications] + .withDefaultValue(Modifications()) + owners.foreach { case (sink, source) => + val lca = iGraph.lowestCommonAncestor(sink, source) + + // Compute metadata along Sink to LCA paths. + sink.drop(lca.size - 1).sliding(2).toList.reverse.map { + case Seq(WDefInstance(_,_,pm,_), WDefInstance(_,ci,cm,_)) => + val to = s"$ci.${portNames(cm)}" + val from = s"${portNames(pm)}" + meta(pm) = meta(pm).copy( + addPortOrWire = Some((portNames(pm), DecWire)), + cons = (meta(pm).cons :+ (to, from)).distinct + ) + meta(cm) = meta(cm).copy( + addPortOrWire = Some((portNames(cm), DecInput)) + ) + // Case where the sink is the LCA + case Seq(WDefInstance(_,_,pm,_)) => + // Case where the source is also the LCA + if (source.drop(lca.size).isEmpty) { + meta(pm) = meta(pm).copy ( + addPortOrWire = Some((portNames(pm), DecWire)) + ) + } else { + val WDefInstance(_,ci,cm,_) = source.drop(lca.size).head + val to = s"${portNames(pm)}" + val from = s"$ci.${portNames(cm)}" + meta(pm) = meta(pm).copy( + addPortOrWire = Some((portNames(pm), DecWire)), + cons = (meta(pm).cons :+ (to, from)).distinct + ) + } + } + + // Compute metadata for the Sink + sink.last match { case WDefInstance(_, _, m, _) => + if (sinkComponents.contains(m)) { + val from = s"${portNames(m)}" + sinkComponents(m).foreach( to => + meta(m) = meta(m).copy( + cons = (meta(m).cons :+ (to, from)).distinct + ) + ) + } + } + + // Compute metadata for the Source + source.last match { case WDefInstance(_, _, m, _) => + val to = s"${portNames(m)}" + val from = compName + meta(m) = meta(m).copy( + cons = (meta(m).cons :+ (to, from)).distinct + ) + } + + // Compute metadata along Source to LCA path + source.drop(lca.size - 1).sliding(2).toList.reverse.map { + case Seq(WDefInstance(_,_,pm,_), WDefInstance(_,ci,cm,_)) => { + val to = s"${portNames(pm)}" + val from = s"$ci.${portNames(cm)}" + meta(pm) = meta(pm).copy( + cons = (meta(pm).cons :+ (to, from)).distinct + ) + meta(cm) = meta(cm).copy( + addPortOrWire = Some((portNames(cm), DecOutput)) + ) + } + // Case where the source is the LCA + case Seq(WDefInstance(_,_,pm,_)) => { + // Case where the sink is also the LCA. We do nothing here, + // as we've created the connecting wire above + if (sink.drop(lca.size).isEmpty) { + } else { + val WDefInstance(_,ci,cm,_) = sink.drop(lca.size).head + val to = s"$ci.${portNames(cm)}" + val from = s"${portNames(pm)}" + meta(pm) = meta(pm).copy( + cons = (meta(pm).cons :+ (to, from)).distinct + ) + } + } + } + } + (sourceComponentType, meta.toMap) } - /** Return new module with correct wiring - */ - private def onModule(map: Map[String, Lineage], t: Type)(m: DefModule) = { + /** Apply modifications to a module */ + private def onModule(t: Type, map: Map[String, Modifications])(m: DefModule) = { map.get(m.name) match { case None => m case Some(l) => - val stmts = mutable.ArrayBuffer[Statement]() + val defines = mutable.ArrayBuffer[Statement]() + val connects = mutable.ArrayBuffer[Statement]() val ports = mutable.ArrayBuffer[Port]() - l.addPort match { + l.addPortOrWire match { case None => case Some((s, dt)) => dt match { - case DecInput => ports += Port(NoInfo, s, Input, t) + case DecInput => ports += Port(NoInfo, s, Input, t) case DecOutput => ports += Port(NoInfo, s, Output, t) - case DecWire => - stmts += DefWire(NoInfo, s, t) + case DecWire => defines += DefWire(NoInfo, s, t) } } - stmts ++= (l.cons map { case ((l, r)) => + connects ++= (l.cons map { case ((l, r)) => Connect(NoInfo, toExp(l), toExp(r)) }) - def onStmt(s: Statement): Statement = Block(Seq(s) ++ stmts) m match { - case Module(i, n, ps, s) => Module(i, n, ps ++ ports, Block(Seq(s) ++ stmts)) + case Module(i, n, ps, s) => Module(i, n, ps ++ ports, + Block(defines ++ Seq(s) ++ connects)) case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps ++ ports, dn, p) } } } - - /** Returns the type of the component specified - */ - private def getType(c: Circuit, module: String, comp: String) = { - def getRoot(e: Expression): String = e match { - case r: Reference => r.name - case i: SubIndex => getRoot(i.expr) - case a: SubAccess => getRoot(a.expr) - case f: SubField => getRoot(f.expr) - } - val eComp = toExp(comp) - val root = getRoot(eComp) - var tpe: Option[Type] = None - def getType(s: Statement): Statement = s match { - case DefRegister(_, n, t, _, _, _) if n == root => - tpe = Some(t) - s - case DefWire(_, n, t) if n == root => - tpe = Some(t) - s - case WDefInstance(_, n, m, t) if n == root => - tpe = Some(t) - s - case DefNode(_, n, e) if n == root => - tpe = Some(e.tpe) - s - case sx: DefMemory if sx.name == root => - tpe = Some(MemPortUtils.memType(sx)) - sx - case sx => sx map getType - } - val m = c.modules find (_.name == module) getOrElse error(s"Must have a module named $module") - tpe = m.ports find (_.name == root) map (_.tpe) - m match { - case Module(i, n, ps, b) => getType(b) - case e: ExtModule => - } - tpe match { - case None => error(s"Didn't find $comp in $module!") - case Some(t) => - def setType(e: Expression): Expression = e map setType match { - case ex: Reference => ex.copy(tpe = t) - case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name)) - case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe)) - case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe)) - } - setType(eComp).tpe - } - } } diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index a8ef5f58..01e6f83a 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -11,86 +11,79 @@ import scala.collection.mutable import firrtl.annotations._ import WiringUtils._ -/** A component, e.g. register etc. Must be declared only once under the TopAnnotation - */ +/** A class for all exceptions originating from firrtl.passes.wiring */ +case class WiringException(msg: String) extends PassException(msg) + +/** An extractor of annotated source components */ object SourceAnnotation { - def apply(target: ComponentName, pin: String): Annotation = Annotation(target, classOf[WiringTransform], s"source $pin") + def apply(target: ComponentName, pin: String): Annotation = + Annotation(target, classOf[WiringTransform], s"source $pin") private val matcher = "source (.+)".r def unapply(a: Annotation): Option[(ComponentName, String)] = a match { - case Annotation(ComponentName(n, m), _, matcher(pin)) => Some((ComponentName(n, m), pin)) + case Annotation(ComponentName(n, m), _, matcher(pin)) => + Some((ComponentName(n, m), pin)) case _ => None } } -/** A module, e.g. ExtModule etc., that should add the input pin - */ +/** An extractor of annotation sink components or modules */ object SinkAnnotation { - def apply(target: ModuleName, pin: String): Annotation = Annotation(target, classOf[WiringTransform], s"sink $pin") + def apply(target: Named, pin: String): Annotation = + Annotation(target, classOf[WiringTransform], s"sink $pin") private val matcher = "sink (.+)".r - def unapply(a: Annotation): Option[(ModuleName, String)] = a match { - case Annotation(ModuleName(n, c), _, matcher(pin)) => Some((ModuleName(n, c), pin)) + def unapply(a: Annotation): Option[(Named, String)] = a match { + case Annotation(ModuleName(n, c), _, matcher(pin)) => + Some((ModuleName(n, c), pin)) + case Annotation(ComponentName(n, m), _, matcher(pin)) => + Some((ComponentName(n, m), pin)) case _ => None } } -/** A module under which all sink module must be declared, and there is only - * one source component +/** Wires a Module's Source Component to one or more Sink + * Modules/Components + * + * Sinks are wired to their closest source through their lowest + * common ancestor (LCA). Verbosely, this modifies the circuit in + * the following ways: + * - Adds a pin to each sink module + * - Punches ports up from source signals to the LCA + * - Punches ports down from LCAs to each sink module + * - Wires sources up to LCA, sinks down from LCA, and across each LCA + * + * @throws WiringException if a sink is equidistant to two sources */ -object TopAnnotation { - def apply(target: ModuleName, pin: String): Annotation = Annotation(target, classOf[WiringTransform], s"top $pin") +class WiringTransform extends Transform { + def inputForm: CircuitForm = MidForm + def outputForm: CircuitForm = HighForm - private val matcher = "top (.+)".r - def unapply(a: Annotation): Option[(ModuleName, String)] = a match { - case Annotation(ModuleName(n, c), _, matcher(pin)) => Some((ModuleName(n, c), pin)) - case _ => None - } -} + /** Defines the sequence of Transform that should be applied */ + private def transforms(w: Seq[WiringInfo]): Seq[Transform] = Seq( + new Wiring(w), + ToWorkingIR + ) -/** Add pins to modules and wires a signal to them, under the scope of a specified top module - * Description: - * Adds a pin to each sink module - * Punches ports up from the source signal to the specified top module - * Punches ports down to each sink module - * Wires the source up and down, connecting to all sink modules - * Restrictions: - * - Can only have one source module instance under scope of the specified top - * - All instances of each sink module must be under the scope of the specified top - * Notes: - * - No module uniquification occurs (due to imposed restrictions) - */ -class WiringTransform extends Transform { - def inputForm = MidForm - def outputForm = MidForm - def transforms(wis: Seq[WiringInfo]) = - Seq(new Wiring(wis), - InferTypes, - ResolveKinds, - ResolveGenders) def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { case Nil => state - case p => - val sinks = mutable.HashMap[String, Set[String]]() - val sources = mutable.HashMap[String, String]() - val tops = mutable.HashMap[String, String]() - val comp = mutable.HashMap[String, String]() - p.foreach { + case p => + val sinks = mutable.HashMap[String, Seq[Named]]() + val sources = mutable.HashMap[String, ComponentName]() + p.foreach { case SinkAnnotation(m, pin) => - sinks(pin) = sinks.getOrElse(pin, Set.empty) + m.name + sinks(pin) = sinks.getOrElse(pin, Seq.empty) :+ m case SourceAnnotation(c, pin) => - sources(pin) = c.module.name - comp(pin) = c.name - case TopAnnotation(m, pin) => tops(pin) = m.name + sources(pin) = c } - (sources.size, tops.size, sinks.size, comp.size) match { - case (0, 0, p, 0) => state - case (s, t, p, c) if (p > 0) & (s == t) & (t == c) => - val wis = tops.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, top)) => - seq :+ WiringInfo(sources(pin), comp(pin), sinks(pin), pin, top) + (sources.size, sinks.size) match { + case (0, p) => state + case (s, p) if (p > 0) => + val wis = sources.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, source)) => + seq :+ WiringInfo(source, sinks(pin), pin) } transforms(wis).foldLeft(state) { (in, xform) => xform.runTransform(in) } - case _ => error("Wrong number of sources, tops, or sinks!") + case _ => error("Wrong number of sources or sinks!") } } } diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala index 29c93ca7..117a3824 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala @@ -9,6 +9,9 @@ import firrtl.Utils._ import firrtl.Mappers._ import scala.collection.mutable import firrtl.annotations._ +import firrtl.annotations.AnnotationUtils._ +import firrtl.analyses.InstanceGraph +import firrtl.graph.DiGraph import WiringUtils._ /** Declaration kind in lineage (e.g. input port, output port, wire) @@ -18,6 +21,19 @@ case object DecInput extends DecKind case object DecOutput extends DecKind case object DecWire extends DecKind +/** Store of pending wiring information for a Module */ +case class Modifications( + addPortOrWire: Option[(String, DecKind)] = None, + cons: Seq[(String, String)] = Seq.empty) { + + override def toString: String = serialize("") + + def serialize(tab: String): String = s""" + |$tab addPortOrWire: $addPortOrWire + |$tab cons: $cons + |""".stripMargin +} + /** A lineage tree representing the instance hierarchy in a design */ case class Lineage( @@ -41,7 +57,7 @@ case class Lineage( |$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))} |""".stripMargin - def foldLeft[B](z: B)(op: (B, (String, Lineage)) => B): B = + def foldLeft[B](z: B)(op: (B, (String, Lineage)) => B): B = this.children.foldLeft(z)(op) def serialize(tab: String): String = s""" @@ -57,9 +73,6 @@ case class Lineage( |""".stripMargin } - - - object WiringUtils { type ChildrenMap = mutable.HashMap[String, Seq[(String, String)]] @@ -69,10 +82,10 @@ object WiringUtils { def getChildrenMap(c: Circuit): ChildrenMap = { val childrenMap = new ChildrenMap() def getChildren(mname: String)(s: Statement): Statement = s match { - case s: WDefInstance => + case s: WDefInstance => childrenMap(mname) = childrenMap(mname) :+ (s.name, s.module) s - case s: DefInstance => + case s: DefInstance => childrenMap(mname) = childrenMap(mname) :+ (s.name, s.module) s case s => s map getChildren(mname) @@ -84,86 +97,151 @@ object WiringUtils { childrenMap } - /** Counts the number of instances of a module declared under a top module - */ - def countInstances(childrenMap: ChildrenMap, top: String, module: String): Int = { - if(top == module) 1 - else childrenMap(top).foldLeft(0) { case (count, (i, child)) => - count + countInstances(childrenMap, child, module) - } - } - /** Returns a module's lineage, containing all children lineages as well */ def getLineage(childrenMap: ChildrenMap, module: String): Lineage = Lineage(module, childrenMap(module) map { case (i, m) => (i, getLineage(childrenMap, m)) } ) - /** Sets the sink, sinkParent, source, and sourceParent fields of every - * Lineage in tree + /** Return a map of sink instances to source instances that minimizes + * distance + * + * @param sinks a sequence of sink modules + * @param source the source module + * @param i a graph representing a circuit + * @return a map of sink instance names to source instance names + * @throws WiringException if a sink is equidistant to two sources */ - def setFields(sinks: Set[String], source: String)(lin: Lineage): Lineage = lin map setFields(sinks, source) match { - case l if sinks.contains(l.name) => l.copy(sink = true) - case l => - val src = l.name == source - val sinkParent = l.children.foldLeft(false) { case (b, (i, m)) => b || m.sink || m.sinkParent } - val sourceParent = if(src) true else l.children.foldLeft(false) { case (b, (i, m)) => b || m.source || m.sourceParent } - l.copy(sinkParent=sinkParent, sourceParent=sourceParent, source=src) - } + def sinksToSources(sinks: Seq[Named], + source: String, + i: InstanceGraph): + Map[Seq[WDefInstance], Seq[WDefInstance]] = { + val owners = new mutable.HashMap[Seq[WDefInstance], Vector[Seq[WDefInstance]]] + .withDefaultValue(Vector()) + val queue = new mutable.Queue[Seq[WDefInstance]] + val visited = new mutable.HashMap[Seq[WDefInstance], Boolean] + .withDefaultValue(false) + + i.fullHierarchy.keys.filter { case WDefInstance(_,_,m,_) => m == source } + .foreach( i.fullHierarchy(_) + .foreach { l => + queue.enqueue(l) + owners(l) = Vector(l) + } + ) + + val sinkInsts = i.fullHierarchy.keys + .filter { case WDefInstance(_, _, module, _) => + sinks.map(getModuleName(_)).contains(module) } + .flatMap { k => i.fullHierarchy(k) } + .toSet + + /** If we're lucky and there is only one source, then that source owns + * all sinks. If we're unlucky, we need to do a full (slow) BFS + * to figure out who owns what. Currently, the BFS is not + * performant. + * + * [todo] The performance of this will need to be improved. + * Possible directions are that if we're purely source-under-sink + * or sink-under-source, then ownership is trivially a mapping + * down/up. Ownership seems to require a BFS if we have + * sources/sinks not under sinks/sources. + */ + if (queue.size == 1) { + val u = queue.dequeue + sinkInsts.foreach { v => owners(v) = Vector(u) } + } else { + while (queue.nonEmpty) { + val u = queue.dequeue + visited(u) = true + + val edges = (i.graph.getEdges(u.last).map(u :+ _).toVector :+ u.dropRight(1)) + + // [todo] This is the critical section + edges + .filter( e => !visited(e) && e.nonEmpty ) + .foreach{ v => + owners(v) = owners(v) ++ owners(u) + queue.enqueue(v) + } + } + + // Check that every sink has one unique owner. The only time that + // this should fail is if a sink is equidistant to two sources. + sinkInsts.foreach { s => + if (!owners.contains(s) || owners(s).size > 1) { + throw new WiringException( + s"Unable to determine source mapping for sink '${s.map(_.name)}'") } + } + } - /** Sets the sharedParent of lineage top - */ - def setSharedParent(top: String)(lin: Lineage): Lineage = lin map setSharedParent(top) match { - case l if l.name == top => l.copy(sharedParent = true) - case l => l + owners + .collect { case (k, v) if sinkInsts.contains(k) => (k, v.flatten) }.toMap } - /** Sets the addPort and cons fields of the lineage tree - */ - def setThings(portNames:Map[String, String], compName: String)(lin: Lineage): Lineage = { - val funs = Seq( - ((l: Lineage) => l map setThings(portNames, compName)), - ((l: Lineage) => l match { - case Lineage(name, _, _, _, _, _, true, _, _) => //SharedParent - l.copy(addPort=Some((portNames(name), DecWire))) - case Lineage(name, _, _, _, true, _, _, _, _) => //SourceParent - l.copy(addPort=Some((portNames(name), DecOutput))) - case Lineage(name, _, _, _, _, true, _, _, _) => //SinkParent - l.copy(addPort=Some((portNames(name), DecInput))) - case Lineage(name, _, _, true, _, _, _, _, _) => //Sink - l.copy(addPort=Some((portNames(name), DecInput))) - case l => l - }), - ((l: Lineage) => l match { - case Lineage(name, _, true, _, _, _, _, _, _) => //Source - val tos = Seq(s"${portNames(name)}") - val from = compName - l.copy(cons = l.cons ++ tos.map(t => (t, from))) - case Lineage(name, _, _, _, true, _, _, _, _) => //SourceParent - val tos = Seq(s"${portNames(name)}") - val from = l.children.filter { case (i, c) => c.sourceParent }.map { case (i, c) => s"$i.${portNames(c.name)}" }.head - l.copy(cons = l.cons ++ tos.map(t => (t, from))) - case l => l - }), - ((l: Lineage) => l match { - case Lineage(name, _, _, _, _, true, _, _, _) => //SinkParent - val tos = l.children.filter { case (i, c) => (c.sinkParent || c.sink) && !c.sourceParent } map { case (i, c) => s"$i.${portNames(c.name)}" } - val from = s"${portNames(name)}" - l.copy(cons = l.cons ++ tos.map(t => (t, from))) - case l => l - }) - ) - funs.foldLeft(lin)((l, fun) => fun(l)) + /** Helper script to extract a module name from a named Module or Component */ + def getModuleName(n: Named): String = { + n match { + case ModuleName(m, _) => m + case ComponentName(_, ModuleName(m, _)) => m + case _ => throw new WiringException( + "Only Components or Modules have an associated Module name") + } } - /** Return a map from module to its lineage in the tree + /** Determine the Type of a specific component + * + * @param c the circuit containing the target module + * @param module the module containing the target component + * @param comp the target component + * @return the component's type + * @throws WiringException if the module is not contained in the + * circuit or if the component is not contained in the module */ - def pointToLineage(lin: Lineage): Map[String, Lineage] = { - val map = mutable.HashMap[String, Lineage]() - def onLineage(l: Lineage): Lineage = { - map(l.name) = l - l map onLineage + def getType(c: Circuit, module: String, comp: String): Type = { + def getRoot(e: Expression): String = e match { + case r: Reference => r.name + case i: SubIndex => getRoot(i.expr) + case a: SubAccess => getRoot(a.expr) + case f: SubField => getRoot(f.expr) + } + val eComp = toExp(comp) + val root = getRoot(eComp) + var tpe: Option[Type] = None + def getType(s: Statement): Statement = s match { + case DefRegister(_, n, t, _, _, _) if n == root => + tpe = Some(t) + s + case DefWire(_, n, t) if n == root => + tpe = Some(t) + s + case WDefInstance(_, n, _, t) if n == root => + tpe = Some(t) + s + case DefNode(_, n, e) if n == root => + tpe = Some(e.tpe) + s + case sx: DefMemory if sx.name == root => + tpe = Some(MemPortUtils.memType(sx)) + sx + case sx => sx map getType + } + val m = c.modules find (_.name == module) getOrElse { + throw new WiringException(s"Must have a module named $module") } + tpe = m.ports find (_.name == root) map (_.tpe) + m match { + case Module(i, n, ps, b) => getType(b) + case e: ExtModule => + } + tpe match { + case None => throw new WiringException(s"Didn't find $comp in $module!") + case Some(t) => + def setType(e: Expression): Expression = e map setType match { + case ex: Reference => ex.copy(tpe = t) + case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name)) + case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe)) + case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe)) + } + setType(eComp).tpe } - onLineage(lin) - map.toMap } } diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala index 01ad573f..5dd048a3 100644 --- a/src/test/scala/firrtlTests/WiringTests.scala +++ b/src/test/scala/firrtlTests/WiringTests.scala @@ -33,86 +33,87 @@ class WiringTests extends FirrtlFlatSpec { InferWidths ) - "Wiring from r to X" should "work" in { - val sinks = Set("X") - val sas = WiringInfo("C", "r", sinks, "pin", "A") + it should "wire from a register source (r) to multiple extmodule sinks (X)" in { + val sinks = Seq(ModuleName("X", CircuitName("Top"))) + val source = ComponentName("r", ModuleName("C", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") val input = - """circuit Top : - | module Top : - | input clock: Clock - | inst a of A - | a.clock <= clock - | module A : - | input clock: Clock - | inst b of B - | b.clock <= clock - | inst x of X - | x.clock <= clock - | inst d of D - | d.clock <= clock - | module B : - | input clock: Clock - | inst c of C - | c.clock <= clock - | inst d of D - | d.clock <= clock - | module C : - | input clock: Clock - | reg r: UInt<5>, clock - | module D : - | input clock: Clock - | inst x1 of X - | x1.clock <= clock - | inst x2 of X - | x2.clock <= clock - | extmodule X : - | input clock: Clock - |""".stripMargin + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | inst b of B + | b.clock <= clock + | inst x of X + | x.clock <= clock + | inst d of D + | d.clock <= clock + | module B : + | input clock: Clock + | inst c of C + | c.clock <= clock + | inst d of D + | d.clock <= clock + | module C : + | input clock: Clock + | reg r: UInt<5>, clock + | module D : + | input clock: Clock + | inst x1 of X + | x1.clock <= clock + | inst x2 of X + | x2.clock <= clock + | extmodule X : + | input clock: Clock + |""".stripMargin val check = - """circuit Top : - | module Top : - | input clock: Clock - | inst a of A - | a.clock <= clock - | module A : - | input clock: Clock - | inst b of B - | b.clock <= clock - | inst x of X - | x.clock <= clock - | inst d of D - | d.clock <= clock - | wire r: UInt<5> - | r <= b.r - | x.pin <= r - | d.r <= r - | module B : - | input clock: Clock - | output r: UInt<5> - | inst c of C - | c.clock <= clock - | inst d of D - | d.clock <= clock - | r <= c.r_0 - | d.r <= r - | module C : - | input clock: Clock - | output r_0: UInt<5> - | reg r: UInt<5>, clock - | r_0 <= r - | module D : - | input clock: Clock - | input r: UInt<5> - | inst x1 of X - | x1.clock <= clock - | inst x2 of X - | x2.clock <= clock - | x1.pin <= r - | x2.pin <= r - | extmodule X : - | input clock: Clock - | input pin: UInt<5> - |""".stripMargin + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire r: UInt<5> + | inst b of B + | b.clock <= clock + | inst x of X + | x.clock <= clock + | inst d of D + | d.clock <= clock + | d.r <= r + | r <= b.r + | x.pin <= r + | module B : + | input clock: Clock + | output r: UInt<5> + | inst c of C + | c.clock <= clock + | inst d of D + | d.clock <= clock + | r <= c.r_0 + | d.r <= r + | module C : + | input clock: Clock + | output r_0: UInt<5> + | reg r: UInt<5>, clock + | r_0 <= r + | module D : + | input clock: Clock + | input r: UInt<5> + | inst x1 of X + | x1.clock <= clock + | inst x2 of X + | x2.clock <= clock + | x1.pin <= r + | x2.pin <= r + | extmodule X : + | input clock: Clock + | input pin: UInt<5> + |""".stripMargin val c = passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => p.run(c) } @@ -121,41 +122,87 @@ class WiringTests extends FirrtlFlatSpec { (parse(retC.serialize).serialize) should be (parse(check).serialize) } - "Wiring from r.x to X" should "work" in { - val sinks = Set("X") - val sas = WiringInfo("A", "r.x", sinks, "pin", "A") + it should "wire from a register source (r) to multiple module sinks (X)" in { + val sinks = Seq(ModuleName("X", CircuitName("Top"))) + val source = ComponentName("r", ModuleName("C", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") val input = - """circuit Top : - | module Top : - | input clock: Clock - | inst a of A - | a.clock <= clock - | module A : - | input clock: Clock - | reg r : {x: UInt<5>}, clock - | inst x of X - | x.clock <= clock - | extmodule X : - | input clock: Clock - |""".stripMargin + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | inst b of B + | b.clock <= clock + | inst x of X + | x.clock <= clock + | inst d of D + | d.clock <= clock + | module B : + | input clock: Clock + | inst c of C + | c.clock <= clock + | inst d of D + | d.clock <= clock + | module C : + | input clock: Clock + | reg r: UInt<5>, clock + | module D : + | input clock: Clock + | inst x1 of X + | x1.clock <= clock + | inst x2 of X + | x2.clock <= clock + | module X : + | input clock: Clock + |""".stripMargin val check = - """circuit Top : - | module Top : - | input clock: Clock - | inst a of A - | a.clock <= clock - | module A : - | input clock: Clock - | reg r: {x: UInt<5>}, clock - | inst x of X - | x.clock <= clock - | wire r_x: UInt<5> - | r_x <= r.x - | x.pin <= r_x - | extmodule X : - | input clock: Clock - | input pin: UInt<5> - |""".stripMargin + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire r: UInt<5> + | inst b of B + | b.clock <= clock + | inst x of X + | x.clock <= clock + | inst d of D + | d.clock <= clock + | d.r <= r + | r <= b.r + | x.pin <= r + | module B : + | input clock: Clock + | output r: UInt<5> + | inst c of C + | c.clock <= clock + | inst d of D + | d.clock <= clock + | r <= c.r_0 + | d.r <= r + | module C : + | input clock: Clock + | output r_0: UInt<5> + | reg r: UInt<5>, clock + | r_0 <= r + | module D : + | input clock: Clock + | input r: UInt<5> + | inst x1 of X + | x1.clock <= clock + | inst x2 of X + | x2.clock <= clock + | x1.pin <= r + | x2.pin <= r + | module X : + | input clock: Clock + | input pin: UInt<5> + |""".stripMargin val c = passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => p.run(c) } @@ -163,39 +210,91 @@ class WiringTests extends FirrtlFlatSpec { val retC = wiringPass.run(c) (parse(retC.serialize).serialize) should be (parse(check).serialize) } - "Wiring from clock to X" should "work" in { - val sinks = Set("X") - val sas = WiringInfo("A", "clock", sinks, "pin", "A") + + it should "wire from a register sink (r) to a wire source (s) in another module (X)" in { + val sinks = Seq(ComponentName("s", ModuleName("X", CircuitName("Top")))) + val source = ComponentName("r", ModuleName("C", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") val input = - """circuit Top : - | module Top : - | input clock: Clock - | inst a of A - | a.clock <= clock - | module A : - | input clock: Clock - | inst x of X - | x.clock <= clock - | extmodule X : - | input clock: Clock - |""".stripMargin + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | inst b of B + | b.clock <= clock + | inst x of X + | x.clock <= clock + | inst d of D + | d.clock <= clock + | module B : + | input clock: Clock + | inst c of C + | c.clock <= clock + | inst d of D + | d.clock <= clock + | module C : + | input clock: Clock + | reg r: UInt<5>, clock + | module D : + | input clock: Clock + | inst x1 of X + | x1.clock <= clock + | inst x2 of X + | x2.clock <= clock + | module X : + | input clock: Clock + | wire s: UInt<5> + |""".stripMargin val check = - """circuit Top : - | module Top : - | input clock: Clock - | inst a of A - | a.clock <= clock - | module A : - | input clock: Clock - | inst x of X - | x.clock <= clock - | wire clock_0: Clock - | clock_0 <= clock - | x.pin <= clock_0 - | extmodule X : - | input clock: Clock - | input pin: Clock - |""".stripMargin + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire r: UInt<5> + | inst b of B + | b.clock <= clock + | inst x of X + | x.clock <= clock + | inst d of D + | d.clock <= clock + | d.r <= r + | r <= b.r + | x.pin <= r + | module B : + | input clock: Clock + | output r: UInt<5> + | inst c of C + | c.clock <= clock + | inst d of D + | d.clock <= clock + | r <= c.r_0 + | d.r <= r + | module C : + | input clock: Clock + | output r_0: UInt<5> + | reg r: UInt<5>, clock + | r_0 <= r + | module D : + | input clock: Clock + | input r: UInt<5> + | inst x1 of X + | x1.clock <= clock + | inst x2 of X + | x2.clock <= clock + | x1.pin <= r + | x2.pin <= r + | module X : + | input clock: Clock + | input pin: UInt<5> + | wire s: UInt<5> + | s <= pin + |""".stripMargin val c = passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => p.run(c) } @@ -203,69 +302,210 @@ class WiringTests extends FirrtlFlatSpec { val retC = wiringPass.run(c) (parse(retC.serialize).serialize) should be (parse(check).serialize) } - "Two sources" should "fail" in { - val sinks = Set("X") - val sas = WiringInfo("A", "clock", sinks, "pin", "Top") + + it should "wire from a SubField source (r.x) to an extmodule sink (X)" in { + val sinks = Seq(ModuleName("X", CircuitName("Top"))) + val source = ComponentName("r.x", ModuleName("A", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") val input = - """circuit Top : - | module Top : - | input clock: Clock - | inst a1 of A - | a1.clock <= clock - | inst a2 of A - | a2.clock <= clock - | module A : - | input clock: Clock - | inst x of X - | x.clock <= clock - | extmodule X : - | input clock: Clock - |""".stripMargin - intercept[WiringException] { - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } - val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | reg r : {x: UInt<5>}, clock + | inst x of X + | x.clock <= clock + | extmodule X : + | input clock: Clock + |""".stripMargin + val check = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire r_x: UInt<5> + | reg r: {x: UInt<5>}, clock + | inst x of X + | x.clock <= clock + | x.pin <= r_x + | r_x <= r.x + | extmodule X : + | input clock: Clock + | input pin: UInt<5> + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(Seq(sas)) + val retC = wiringPass.run(c) + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + + it should "wire properly with a source as a submodule of a sink" in { + val sinks = Seq(ComponentName("s", ModuleName("A", CircuitName("Top")))) + val source = ComponentName("r", ModuleName("X", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") + val input = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire s: UInt<5> + | inst x of X + | x.clock <= clock + | module X : + | input clock: Clock + | reg r: UInt<5>, clock + |""".stripMargin + val check = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire pin: UInt<5> + | wire s: UInt<5> + | inst x of X + | x.clock <= clock + | pin <= x.r_0 + | s <= pin + | module X : + | input clock: Clock + | output r_0: UInt<5> + | reg r: UInt<5>, clock + | r_0 <= r + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(Seq(sas)) + val retC = wiringPass.run(c) + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + + it should "wire with source and sink in the same module" in { + val sinks = Seq(ComponentName("s", ModuleName("A", CircuitName("Top")))) + val source = ComponentName("r", ModuleName("A", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") + val input = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire s: UInt<5> + | reg r: UInt<5>, clock + |""".stripMargin + val check = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire pin: UInt<5> + | wire s: UInt<5> + | reg r: UInt<5>, clock + | s <= pin + | pin <= r + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(Seq(sas)) + val retC = wiringPass.run(c) + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + + it should "wire multiple sinks in the same module" in { + val sinks = Seq(ComponentName("s", ModuleName("A", CircuitName("Top"))), + ComponentName("t", ModuleName("A", CircuitName("Top")))) + val source = ComponentName("r", ModuleName("A", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") + val input = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire s: UInt<5> + | wire t: UInt<5> + | reg r: UInt<5>, clock + |""".stripMargin + val check = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire pin: UInt<5> + | wire s: UInt<5> + | wire t: UInt<5> + | reg r: UInt<5>, clock + | t <= pin + | s <= pin + | pin <= r + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) } + val wiringPass = new Wiring(Seq(sas)) + val retC = wiringPass.run(c) + (parse(retC.serialize).serialize) should be (parse(check).serialize) } - "Wiring from A.clock to X, with 2 A's, and A as top" should "work" in { - val sinks = Set("X") - val sas = WiringInfo("A", "clock", sinks, "pin", "A") + + it should "wire clocks" in { + val sinks = Seq(ModuleName("X", CircuitName("Top"))) + val source = ComponentName("clock", ModuleName("A", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") val input = - """circuit Top : - | module Top : - | input clock: Clock - | inst a1 of A - | a1.clock <= clock - | inst a2 of A - | a2.clock <= clock - | module A : - | input clock: Clock - | inst x of X - | x.clock <= clock - | extmodule X : - | input clock: Clock - |""".stripMargin + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | inst x of X + | x.clock <= clock + | extmodule X : + | input clock: Clock + |""".stripMargin val check = - """circuit Top : - | module Top : - | input clock: Clock - | inst a1 of A - | a1.clock <= clock - | inst a2 of A - | a2.clock <= clock - | module A : - | input clock: Clock - | inst x of X - | x.clock <= clock - | wire clock_0: Clock - | clock_0 <= clock - | x.pin <= clock_0 - | extmodule X : - | input clock: Clock - | input pin: Clock - |""".stripMargin + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire clock_0: Clock + | inst x of X + | x.clock <= clock + | x.pin <= clock_0 + | clock_0 <= clock + | extmodule X : + | input clock: Clock + | input pin: Clock + |""".stripMargin val c = passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => p.run(c) } @@ -273,26 +513,120 @@ class WiringTests extends FirrtlFlatSpec { val retC = wiringPass.run(c) (parse(retC.serialize).serialize) should be (parse(check).serialize) } - "Wiring from A.clock to X, with 2 A's, and A as top, but Top instantiates X" should "error" in { - val sinks = Set("X") - val sas = WiringInfo("A", "clock", sinks, "pin", "A") + + it should "handle two source instances with clearly defined sinks" in { + val sinks = Seq(ModuleName("X", CircuitName("Top"))) + val source = ComponentName("clock", ModuleName("A", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") val input = - """circuit Top : - | module Top : - | input clock: Clock - | inst a1 of A - | a1.clock <= clock - | inst a2 of A - | a2.clock <= clock - | inst x of X - | x.clock <= clock - | module A : - | input clock: Clock - | inst x of X - | x.clock <= clock - | extmodule X : - | input clock: Clock - |""".stripMargin + """|circuit Top : + | module Top : + | input clock: Clock + | inst a1 of A + | a1.clock <= clock + | inst a2 of A + | a2.clock <= clock + | module A : + | input clock: Clock + | inst x of X + | x.clock <= clock + | extmodule X : + | input clock: Clock + |""".stripMargin + val check = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a1 of A + | a1.clock <= clock + | inst a2 of A + | a2.clock <= clock + | module A : + | input clock: Clock + | wire clock_0: Clock + | inst x of X + | x.clock <= clock + | x.pin <= clock_0 + | clock_0 <= clock + | extmodule X : + | input clock: Clock + | input pin: Clock + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(Seq(sas)) + val retC = wiringPass.run(c) + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + + it should "wire multiple clocks" in { + val sinks = Seq(ModuleName("X", CircuitName("Top"))) + val source = ComponentName("clock", ModuleName("A", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") + val input = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a1 of A + | a1.clock <= clock + | inst a2 of A + | a2.clock <= clock + | module A : + | input clock: Clock + | inst x of X + | x.clock <= clock + | extmodule X : + | input clock: Clock + |""".stripMargin + val check = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a1 of A + | a1.clock <= clock + | inst a2 of A + | a2.clock <= clock + | module A : + | input clock: Clock + | wire clock_0: Clock + | inst x of X + | x.clock <= clock + | x.pin <= clock_0 + | clock_0 <= clock + | extmodule X : + | input clock: Clock + | input pin: Clock + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(Seq(sas)) + val retC = wiringPass.run(c) + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + + it should "error with WiringException for indeterminate ownership" in { + val sinks = Seq(ModuleName("X", CircuitName("Top"))) + val source = ComponentName("clock", ModuleName("A", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") + val input = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a1 of A + | a1.clock <= clock + | inst a2 of A + | a2.clock <= clock + | inst x of X + | x.clock <= clock + | module A : + | input clock: Clock + | inst x of X + | x.clock <= clock + | extmodule X : + | input clock: Clock + |""".stripMargin intercept[WiringException] { val c = passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => p.run(c) @@ -301,43 +635,45 @@ class WiringTests extends FirrtlFlatSpec { val retC = wiringPass.run(c) } } - "Wiring from A.r[a] to X" should "work" in { - val sinks = Set("X") - val sas = WiringInfo("A", "r[a]", sinks, "pin", "A") + + it should "wire subindex source to sink" in { + val sinks = Seq(ModuleName("X", CircuitName("Top"))) + val source = ComponentName("r[a]", ModuleName("A", CircuitName("Top"))) + val sas = WiringInfo(source, sinks, "pin") val input = - """circuit Top : - | module Top : - | input clock: Clock - | inst a of A - | a.clock <= clock - | module A : - | input clock: Clock - | reg r: UInt<2>[5], clock - | node a = UInt(5) - | inst x of X - | x.clock <= clock - | extmodule X : - | input clock: Clock - |""".stripMargin + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | reg r: UInt<2>[5], clock + | node a = UInt(5) + | inst x of X + | x.clock <= clock + | extmodule X : + | input clock: Clock + |""".stripMargin val check = - """circuit Top : - | module Top : - | input clock: Clock - | inst a of A - | a.clock <= clock - | module A : - | input clock: Clock - | reg r: UInt<2>[5], clock - | node a = UInt(5) - | inst x of X - | x.clock <= clock - | wire r_a: UInt<2> - | r_a <= r[a] - | x.pin <= r_a - | extmodule X : - | input clock: Clock - | input pin: UInt<2> - |""".stripMargin + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire r_a: UInt<2> + | reg r: UInt<2>[5], clock + | node a = UInt(5) + | inst x of X + | x.clock <= clock + | x.pin <= r_a + | r_a <= r[a] + | extmodule X : + | input clock: Clock + | input pin: UInt<2> + |""".stripMargin val c = passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => p.run(c) } @@ -346,37 +682,182 @@ class WiringTests extends FirrtlFlatSpec { (parse(retC.serialize).serialize) should be (parse(check).serialize) } - "Wiring annotations" should "work" in { + it should "wire using Annotations with a sink module" in { val source = SourceAnnotation(ComponentName("r", ModuleName("Top", CircuitName("Top"))), "pin") val sink = SinkAnnotation(ModuleName("X", CircuitName("Top")), "pin") - val top = TopAnnotation(ModuleName("Top", CircuitName("Top")), "pin") val input = - """circuit Top : - | module Top : - | input clk: Clock - | inst x of X - | reg r: UInt<5>, clk - | extmodule X : - | input clk: Clock - |""".stripMargin + """|circuit Top : + | module Top : + | input clk: Clock + | inst x of X + | x.clk <= clk + | reg r: UInt<5>, clk + | extmodule X : + | input clk: Clock + |""".stripMargin val check = - """circuit Top : - | module Top : - | input clk: Clock - | inst x of X - | reg r: UInt<5>, clk - | wire r_0 : UInt<5> - | r_0 <= r - | x.pin <= r_0 - | extmodule X : - | input clk: Clock - | input pin: UInt<5> - |""".stripMargin + """|circuit Top : + | module Top : + | input clk: Clock + | wire r_0 : UInt<5> + | inst x of X + | x.clk <= clk + | reg r: UInt<5>, clk + | x.pin <= r_0 + | r_0 <= r + | extmodule X : + | input clk: Clock + | input pin: UInt<5> + |""".stripMargin val c = passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => p.run(c) } val wiringXForm = new WiringTransform() - val retC = wiringXForm.execute(CircuitState(c, LowForm, Some(AnnotationMap(Seq(source, sink, top))), None)).circuit + val retC = wiringXForm.execute(CircuitState(c, MidForm, Some(AnnotationMap(Seq(source, sink))), None)).circuit + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + + it should "wire using Annotations with a sink component" in { + val source = SourceAnnotation(ComponentName("r", ModuleName("Top", CircuitName("Top"))), "pin") + val sink = SinkAnnotation(ComponentName("s", ModuleName("X", CircuitName("Top"))), "pin") + val input = + """|circuit Top : + | module Top : + | input clk: Clock + | inst x of X + | x.clk <= clk + | reg r: UInt<5>, clk + | module X : + | input clk: Clock + | wire s: UInt<5> + |""".stripMargin + val check = + """|circuit Top : + | module Top : + | input clk: Clock + | wire r_0 : UInt<5> + | inst x of X + | x.clk <= clk + | reg r: UInt<5>, clk + | x.pin <= r_0 + | r_0 <= r + | module X : + | input clk: Clock + | input pin: UInt<5> + | wire s: UInt<5> + | s <= pin + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringXForm = new WiringTransform() + val retC = wiringXForm.execute(CircuitState(c, MidForm, Some(AnnotationMap(Seq(source, sink))), None)).circuit + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + + it should "wire using annotations with Aggregate source" in { + val source = SourceAnnotation(ComponentName("bundle", ModuleName("A", CircuitName("Top"))), "pin") + val sink = SinkAnnotation(ModuleName("B", CircuitName("Top")), "pin") + val input = + """|circuit Top : + | module Top : + | input clock : Clock + | inst a of A + | inst b of B + | a.clock <= clock + | b.clock <= clock + | module A : + | input clock : Clock + | wire bundle : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} } + | bundle is invalid + | module B : + | input clock : Clock""".stripMargin + val check = + """|circuit Top : + | module Top : + | input clock : Clock + | wire bundle : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} } + | inst a of A + | inst b of B + | a.clock <= clock + | b.clock <= clock + | b.pin <= bundle + | bundle <= a.bundle_0 + | module A : + | input clock : Clock + | output bundle_0 : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} } + | wire bundle : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} } + | bundle is invalid + | bundle_0 <= bundle + | module B : + | input clock : Clock + | input pin : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }""" + .stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringXForm = new WiringTransform() + val retC = wiringXForm.execute(CircuitState(c, MidForm, Some(AnnotationMap(Seq(source, sink))), None)).circuit + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + + it should "wire one sink to multiple, disjoint extmodules" in { + val sinkX = Seq(ModuleName("X", CircuitName("Top"))) + val sourceX = ComponentName("r.x", ModuleName("A", CircuitName("Top"))) + val sinkY = Seq(ModuleName("Y", CircuitName("Top"))) + val sourceY = ComponentName("r.x", ModuleName("A", CircuitName("Top"))) + val wiSeq = Seq( + WiringInfo(sourceX, sinkX, "pin"), + WiringInfo(sourceY, sinkY, "pin")) + val input = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | reg r : {x: UInt<5>}, clock + | inst x of X + | x.clock <= clock + | inst y of Y + | y.clock <= clock + | extmodule X : + | input clock: Clock + | extmodule Y : + | input clock: Clock + |""".stripMargin + val check = + """|circuit Top : + | module Top : + | input clock: Clock + | inst a of A + | a.clock <= clock + | module A : + | input clock: Clock + | wire r_x_0: UInt<5> + | wire r_x: UInt<5> + | reg r: {x: UInt<5>}, clock + | inst x of X + | x.clock <= clock + | inst y of Y + | y.clock <= clock + | x.pin <= r_x + | r_x <= r.x + | y.pin <= r_x_0 + | r_x_0 <= r.x + | extmodule X : + | input clock: Clock + | input pin: UInt<5> + | extmodule Y : + | input clock: Clock + | input pin: UInt<5> + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(wiSeq) + val retC = wiringPass.run(c) (parse(retC.serialize).serialize) should be (parse(check).serialize) } } diff --git a/src/test/scala/firrtlTests/graph/EulerTourTests.scala b/src/test/scala/firrtlTests/graph/EulerTourTests.scala new file mode 100644 index 00000000..0b69ce61 --- /dev/null +++ b/src/test/scala/firrtlTests/graph/EulerTourTests.scala @@ -0,0 +1,36 @@ +package firrtlTests.graph + +import firrtl.graph._ +import firrtlTests._ + +class EulerTourTests extends FirrtlFlatSpec { + + val top = "top" + val first_layer = Set("1a", "1b", "1c") + val second_layer = Set("2a", "2b", "2c") + val third_layer = Set("3a", "3b", "3c") + val last_null = Set.empty[String] + + val m = Map(top -> first_layer) ++ first_layer.map{ + case x => Map(x -> second_layer) }.flatten.toMap ++ second_layer.map{ + case x => Map(x -> third_layer) }.flatten.toMap ++ third_layer.map{ + case x => Map(x -> last_null) }.flatten.toMap + + val graph = DiGraph(m) + val instances = graph.pathsInDAG(top).values.flatten + val tour = EulerTour(graph, top) + + it should "show equivalency of Berkman--Vishkin and naive RMQs" in { + instances.toSeq.combinations(2).toList.map { case Seq(a, b) => + tour.rmqNaive(a, b) should be (tour.rmqBV(a, b)) + } + } + + it should "determine naive RMQs of itself correctly" in { + instances.toSeq.map { case a => tour.rmqNaive(a, a) should be (a) } + } + + it should "determine Berkman--Vishkin RMQs of itself correctly" in { + instances.toSeq.map { case a => tour.rmqNaive(a, a) should be (a) } + } +} |
