aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore3
-rw-r--r--src/main/scala/firrtl/analyses/Netlist.scala20
-rw-r--r--src/main/scala/firrtl/graph/DiGraph.scala3
-rw-r--r--src/main/scala/firrtl/graph/EulerTour.scala223
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockList.scala6
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala15
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala2
-rw-r--r--src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala2
-rw-r--r--src/main/scala/firrtl/passes/memlib/DecorateMems.scala3
-rw-r--r--src/main/scala/firrtl/passes/wiring/Wiring.scala287
-rw-r--r--src/main/scala/firrtl/passes/wiring/WiringTransform.scala103
-rw-r--r--src/main/scala/firrtl/passes/wiring/WiringUtils.scala226
-rw-r--r--src/test/scala/firrtlTests/WiringTests.scala1035
-rw-r--r--src/test/scala/firrtlTests/graph/EulerTourTests.scala36
14 files changed, 1411 insertions, 553 deletions
diff --git a/.gitignore b/.gitignore
index 41d9d628..94a12c29 100644
--- a/.gitignore
+++ b/.gitignore
@@ -56,3 +56,6 @@ project/plugins/project/
gen/
project/project/
/bin/
+
+*~
+*#*#
diff --git a/src/main/scala/firrtl/analyses/Netlist.scala b/src/main/scala/firrtl/analyses/Netlist.scala
index 4e211d8b..99f3645f 100644
--- a/src/main/scala/firrtl/analyses/Netlist.scala
+++ b/src/main/scala/firrtl/analyses/Netlist.scala
@@ -1,3 +1,5 @@
+// See LICENSE for license details.
+
package firrtl.analyses
import scala.collection.mutable
@@ -10,13 +12,14 @@ import firrtl.Mappers._
/** A class representing the instance hierarchy of a working IR Circuit
- *
+ *
* @constructor constructs an instance graph from a Circuit
* @param c the Circuit to analyze
*/
class InstanceGraph(c: Circuit) {
- private def collectInstances(insts: mutable.Set[WDefInstance])(s: Statement): Statement = s match {
+ private def collectInstances(insts: mutable.Set[WDefInstance])
+ (s: Statement): Statement = s match {
case i: WDefInstance =>
insts += i
i
@@ -72,7 +75,7 @@ class InstanceGraph(c: Circuit) {
/** Finds the absolute paths (each represented by a Seq of instances
* representing the chain of hierarchy) of all instances of a
* particular module.
- *
+ *
* @param module the name of the selected module
* @return a Seq[Seq[WDefInstance]] of absolute instance paths
*/
@@ -81,5 +84,14 @@ class InstanceGraph(c: Circuit) {
instances flatMap { i => fullHierarchy(i) }
}
-}
+ /** An `[[EulerTour]]` representation of the `[[DiGraph]]` */
+ lazy val tour = EulerTour(graph, trueTopInstance)
+ /** Finds the lowest common ancestor instances for two module names in
+ * a design
+ */
+ def lowestCommonAncestor(moduleA: Seq[WDefInstance],
+ moduleB: Seq[WDefInstance]): Seq[WDefInstance] = {
+ tour.rmq(moduleA, moduleB)
+ }
+}
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala
index b869982d..6538c880 100644
--- a/src/main/scala/firrtl/graph/DiGraph.scala
+++ b/src/main/scala/firrtl/graph/DiGraph.scala
@@ -1,3 +1,5 @@
+// See LICENSE for license details.
+
package firrtl.graph
import scala.collection.{Set, Map}
@@ -306,7 +308,6 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
edges.foreach({ case (k, v) => eprime(f(k)) ++= v.map(f(_)) })
new DiGraph(eprime)
}
-
}
class MutableDiGraph[T] extends DiGraph[T](new LinkedHashMap[T, LinkedHashSet[T]]) {
diff --git a/src/main/scala/firrtl/graph/EulerTour.scala b/src/main/scala/firrtl/graph/EulerTour.scala
new file mode 100644
index 00000000..db25d8d0
--- /dev/null
+++ b/src/main/scala/firrtl/graph/EulerTour.scala
@@ -0,0 +1,223 @@
+// See LICENSE for license details.
+
+package firrtl.graph
+
+import scala.collection.mutable
+
+/** Euler Tour companion object */
+object EulerTour {
+ /** Create an Euler Tour of a `DiGraph[T]` */
+ def apply[T](diGraph: DiGraph[T], start: T): EulerTour[Seq[T]] = {
+ val r = mutable.Map[Seq[T], Int]()
+ val e = mutable.ArrayBuffer[Seq[T]]()
+ val h = mutable.ArrayBuffer[Int]()
+
+ def tour(u: T, parent: Vector[T], height: Int): Unit = {
+ val id = parent :+ u
+ r.getOrElseUpdate(id, e.size)
+ e += id
+ h += height
+ diGraph.getEdges(id.last).foreach { v =>
+ tour(v, id, height + 1)
+ e += id
+ h += height
+ }
+ }
+
+ tour(start, Vector.empty, 0)
+ new EulerTour(r.toMap, e, h)
+ }
+}
+
+/** A class that represents an Euler Tour of a directed graph from a
+ * given root. This requires `O(n)` preprocessing time to generate
+ * the initial Euler Tour.
+ *
+ * @constructor Create a new EulerTour from the specified data
+ * @param r A map of a node to its first index
+ * @param e A representation of the EulerTour as a `Seq[T]`
+ * @param h The depths of the Euler Tour represented as a `Seq[Int]`
+ */
+class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) {
+ private def lg(x: Double): Double = math.log(x) / math.log(2)
+
+ /** Range Minimum Query of an Euler Tour using a naive algorithm.
+ *
+ * @param x The first query bound
+ * @param y The second query bound
+ * @return The minimum between the first and second query
+ * @note The order of '''x''' and '''y''' does not matter
+ * @note '''Performance''':
+ * - preprocessing: `O(1)`
+ * - query: `O(n)`
+ */
+ def rmqNaive(x: T, y: T): T = {
+ val Seq(i, j) = Seq(r(x), r(y)).sorted
+ e.zip(h).slice(i, j + 1).minBy(_._2)._1
+ }
+
+ // n: the length of the Euler Tour
+ // m: the size of blocks the Euler Tour is split into
+ private val n = h.size
+ private val m = math.ceil(lg(n) / 2).toInt
+
+ /** Split up the tour into blocks of size m, padding the last block to
+ * be a multiple of m. Compute the minimum of each block, a, and
+ * the index of that minimum in each block, b.
+ */
+ private lazy val blocks = (h ++ (1 to (m - n % m))).grouped(m).toArray
+ private lazy val a = blocks map (_.min)
+ private lazy val b = blocks map (b => b.indexOf(b.min))
+
+ /** Construct a Sparse Table (ST) representation for the minimum index
+ * of a sequence of integers. Data in the returned array is indexed
+ * as: [base, power of 2 range]
+ */
+ private def constructSparseTable(x: Seq[Int]): Array[Array[Int]] = {
+ val tmp = Array.ofDim[Int](x.size + 1, math.ceil(lg(x.size)).toInt)
+ for (i <- 0 to x.size - 1; j <- 0 to math.ceil(lg(x.size)).toInt - 1) {
+ tmp(i)(j) = -1
+ }
+
+ def tableRecursive(base: Int, size: Int): Int = {
+ if (size == 0) {
+ tmp(base)(size) = base
+ base
+ } else {
+ val (a, b, c) = (base, base + (1 << (size - 1)), size - 1)
+
+ val l = if (tmp(a)(c) != -1) { tmp(a)(c) }
+ else { tableRecursive(a, c) }
+
+ val r = if (tmp(b)(c) != -1) { tmp(b)(c) }
+ else { tableRecursive(b, c) }
+
+ val min = if (x(l) < x(r)) l else r
+ tmp(base)(size) = min
+ assert(min >= base)
+ min
+ }
+ }
+
+ for (i <- (0 to x.size - 1);
+ j <- (0 to math.ceil(lg(x.size)).toInt - 1);
+ if i + (1 << j) - 1 < x.size) {
+ tableRecursive(i, j)
+ }
+ tmp
+ }
+ private lazy val st = constructSparseTable(a)
+
+ /** Precompute all possible RMQs for an array of size `n where each
+ * entry in the range is different from the last by only +-1
+ */
+ private def constructTableLookups(n: Int): Array[Array[Array[Int]]] = {
+ def sortSeqSeq[T <: Int](x: Seq[T], y: Seq[T]): Boolean = {
+ if (x(0) != y(0)) x(0) < y(0) else sortSeqSeq(x.tail, y.tail)
+ }
+
+ val size = m - 1
+ val out = Seq.fill(size)(Seq(-1, 1))
+ .flatten.combinations(m - 1).flatMap(_.permutations).toList
+ .sortWith(sortSeqSeq)
+ .map(_.foldLeft(Seq(0))((h, pm) => (h.head + pm) +: h).reverse)
+ .map{ a =>
+ var tmp = Array.ofDim[Int](m, m)
+ for (i <- 0 to size; j <- i to size) yield {
+ val window = a.slice(i, j + 1)
+ tmp(i)(j) = window.indexOf(window.min) + i }
+ tmp }.toArray
+ out
+ }
+ private lazy val tables = constructTableLookups(m)
+
+ /** Compute the precomputed table index of a given block */
+ private def mapBlockToTable(block: Seq[Int]): Int = {
+ var index = 0
+ var power = block.size - 2
+ for (Seq(l, r) <- block.sliding(2)) {
+ if (l < r) { index += 1 << power }
+ power -= 1
+ }
+ index
+ }
+
+ /** Precompute a mapping of all blocks to their precomputed RMQ table
+ * indices
+ */
+ private def mapBlocksToTables(blocks: Seq[Seq[Int]]): Array[Int] = {
+ val out = blocks.map(mapBlockToTable(_)).toArray
+ out
+ }
+ private lazy val tableIdx = mapBlocksToTables(blocks)
+
+ /** Range Minimum Query using the Berkman--Vishkin algorithm with the
+ * simplifications of Bender--Farach-Colton.
+ *
+ * @param x The first query bound
+ * @param y The second query bound
+ * @return The minimum between the first and second query
+ * @note The order of '''x''' and '''y''' does not matter
+ * @note '''Performance''':
+ * - preprocessing: `O(n)`
+ * - query: `O(1)`
+ */
+ def rmqBV(x: T, y: T): T = {
+ val Seq(i, j) = Seq(r(x), r(y)).sorted
+
+ // Compute block and word indices
+ val (block_i, block_j) = (i / m, j / m)
+ val (word_i, word_j) = (i % m, j % m)
+
+ /** Up to four possible minimum indices are then computed based on the
+ * following conditions:
+ * 1. `i` and `j` are in the same block:
+ * - one precomputed RMQ from `i` to `j`
+ * 2. `i` and `j` are in adjacent blocks:
+ * - one precomputed RMQ from `i` to the end of its block
+ * - one precomputed RMQ from `j` to the beginning of its block
+ * 3. `i` and `j` have blocks between them:
+ * - one precomputed RMQ from `i` to the end of its block
+ * - one precomputed RMQ from `j` to the beginning of its block
+ * - two sparse table lookups to fully cover all blocks
+ * between `i` and `j`
+ */
+ val minIndices = (block_i, block_j) match {
+ case (bi, bj) if (block_i == block_j) =>
+ val min_i = block_i * m + tables(tableIdx(block_i))(word_i)(word_j)
+ Seq(min_i)
+ case (bi, bj) if (block_i == block_j - 1) =>
+ val min_i = block_i * m + tables(tableIdx(block_i))(word_i)( m - 1)
+ val min_j = block_j * m + tables(tableIdx(block_j))( 0)(word_j)
+ Seq(min_i, min_j)
+ case _ =>
+ val min_i = block_i * m + tables(tableIdx(block_i))(word_i)( m - 1)
+ val min_j = block_j * m + tables(tableIdx(block_j))( 0)(word_j)
+ val (min_between_l, min_between_r) = {
+ val range = math.floor(lg(block_j - block_i - 1)).toInt
+ val base_0 = block_i + 1
+ val base_1 = block_j - (1 << range)
+
+ val (idx_0, idx_1) = (st(base_0)(range), st(base_1)(range))
+ val (min_0, min_1) = (b(idx_0) + idx_0 * m, b(idx_1) + idx_1 * m)
+ (min_0, min_1) }
+ Seq(min_i, min_between_l, min_between_r, min_j)
+ }
+
+ // Return the minimum of all possible minimum indices
+ e(minIndices.minBy(h(_)))
+ }
+
+ /** Range Minimum Query of the Euler Tour.
+ *
+ * Use this for typical queries.
+ *
+ * @param x The first query bound
+ * @param y The second query bound
+ * @return The minimum between the first and second query
+ * @note This currently maps to `rmqBV`, but may choose to map to
+ * either `rmqBV` or `rmqNaive`
+ * @note The order of '''x''' and '''y''' does not matter
+ */
+ def rmq(x: T, y: T): T = rmqBV(x, y)
+}
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockList.scala b/src/main/scala/firrtl/passes/clocklist/ClockList.scala
index bd2536ab..073eb050 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockList.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockList.scala
@@ -8,7 +8,7 @@ import firrtl.ir._
import annotations._
import Utils.error
import java.io.{File, CharArrayWriter, PrintWriter, Writer}
-import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage}
+import wiring.WiringUtils.{getChildrenMap, getLineage}
import wiring.Lineage
import ClockListUtils._
import Utils._
@@ -30,7 +30,7 @@ class ClockList(top: String, writer: Writer) extends Pass {
// === Checks ===
// TODO(izraelevitz): Check all registers/memories use "clock" clock port
// ==============
-
+
// Clock sources must be blackbox outputs and top's clock
val partialSourceList = getSourceList(moduleMap)(lineages)
val sourceList = partialSourceList ++ moduleMap(top).ports.collect{ case Port(i, n, Input, ClockType) => n }
@@ -39,7 +39,7 @@ class ClockList(top: String, writer: Writer) extends Pass {
// Remove everything from the circuit, unless it has a clock type
// This simplifies the circuit drastically so InlineInstances doesn't take forever.
val onlyClockCircuit = RemoveAllButClocks.run(c)
-
+
// Inline the clock-only circuit up to the specified top module
val modulesToInline = (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet
val inlineTransform = new InlineInstances
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
index b04171a7..24f25525 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
@@ -8,7 +8,6 @@ import firrtl.ir._
import annotations._
import Utils.error
import java.io.{File, CharArrayWriter, PrintWriter, Writer}
-import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage}
import wiring.Lineage
import ClockListUtils._
import Utils._
@@ -22,23 +21,23 @@ object ClockListAnnotation {
[Optional] ClockList
List which signal drives each clock of every descendent of specified module
-Usage:
+Usage:
--list-clocks -c:<circuit>:-m:<module>:-o:<filename>
*** Note: sub-arguments to --list-clocks should be delimited by : and not white space!
-"""
-
+"""
+
//Parse pass options
val passOptions = PassConfigUtil.getPassOptions(t, usage)
val outputConfig = passOptions.getOrElse(
- OutputConfigFileName,
+ OutputConfigFileName,
error("No output config file provided for ClockList!" + usage)
)
val passCircuit = passOptions.getOrElse(
- PassCircuitName,
+ PassCircuitName,
error("No circuit name specified for ClockList!" + usage)
)
val passModule = passOptions.getOrElse(
- PassModuleName,
+ PassModuleName,
error("No module name specified for ClockList!" + usage)
)
passOptions.get(InputConfigFileName) match {
@@ -65,7 +64,7 @@ class ClockListTransform extends Transform {
def passSeq(top: String, writer: Writer): Seq[Pass] =
Seq(new ClockList(top, writer))
def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match {
- case Seq(ClockListAnnotation(ModuleName(top, CircuitName(state.circuit.main)), out)) =>
+ case Seq(ClockListAnnotation(ModuleName(top, CircuitName(state.circuit.main)), out)) =>
val outputFile = new PrintWriter(out)
val newC = (new ClockList(top, outputFile)).run(state.circuit)
outputFile.close()
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
index b81d0c7e..892f1642 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
@@ -8,7 +8,6 @@ import firrtl.ir._
import annotations._
import Utils.error
import java.io.{File, CharArrayWriter, PrintWriter, Writer}
-import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage}
import wiring.Lineage
import ClockListUtils._
import Utils._
@@ -61,4 +60,3 @@ object ClockListUtils {
}
}
}
-
diff --git a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
index 53787b1d..1178ce69 100644
--- a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
+++ b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
@@ -8,8 +8,6 @@ import firrtl.ir._
import annotations._
import Utils.error
import java.io.{File, CharArrayWriter, PrintWriter, Writer}
-import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage}
-import wiring.Lineage
import ClockListUtils._
import Utils._
import memlib.AnalysisUtils._
diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
index 648b0234..ad3616ad 100644
--- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
+++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
@@ -18,9 +18,8 @@ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform
val cN = CircuitName(state.circuit.main)
val oldAnnos = state.annotations.getOrElse(AnnotationMap(Seq.empty)).annotations
val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { case ((annos, pins), config) =>
- val top = TopAnnotation(ModuleName(config.top.name, cN), config.pin.name)
val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN)), config.pin.name)
- (annos ++ Seq(top, source), pins :+ config.pin.name)
+ (annos, pins :+ config.pin.name)
}
state.copy(annotations = Some(AnnotationMap(as :+ PinAnnotation(cN, pins.toSeq))))
}
diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala
index 9656abb2..a268dba7 100644
--- a/src/main/scala/firrtl/passes/wiring/Wiring.scala
+++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala
@@ -10,157 +10,194 @@ import firrtl.Mappers._
import scala.collection.mutable
import firrtl.annotations._
import firrtl.annotations.AnnotationUtils._
+import firrtl.analyses.InstanceGraph
import WiringUtils._
-case class WiringException(msg: String) extends PassException(msg)
+/** A data store of one sink--source wiring relationship */
+case class WiringInfo(source: ComponentName, sinks: Seq[Named], pin: String)
-case class WiringInfo(source: String, comp: String, sinks: Set[String], pin: String, top: String)
+/** A data store of wiring names */
+case class WiringNames(compName: String, source: String, sinks: Seq[Named],
+ pin: String)
+/** Pass that computes and applies a sequence of wiring modifications
+ *
+ * @constructor construct a new Wiring pass
+ * @param wiSeq the [[WiringInfo]] to apply
+ */
class Wiring(wiSeq: Seq[WiringInfo]) extends Pass {
- def run(c: Circuit): Circuit = {
- wiSeq.foldLeft(c) { (circuit, wi) => wire(circuit, wi) }
+ def run(c: Circuit): Circuit = analyze(c)
+ .foldLeft(c){
+ case (cx, (tpe, modsMap)) => cx.copy(
+ modules = cx.modules map onModule(tpe, modsMap)) }
+
+ /** Converts multiple units of wiring information to module modifications */
+ private def analyze(c: Circuit): Seq[(Type, Map[String, Modifications])] = {
+
+ val names = wiSeq
+ .map ( wi => (wi.source, wi.sinks, wi.pin) match {
+ case (ComponentName(comp, ModuleName(source,_)), sinks, pin) =>
+ WiringNames(comp, source, sinks, pin) })
+
+ val portNames = mutable.Seq.fill(names.size)(Map[String, String]())
+ c.modules.foreach{ m =>
+ val ns = Namespace(m)
+ names.zipWithIndex.foreach{ case (WiringNames(c, so, si, p), i) =>
+ portNames(i) = portNames(i) +
+ ( m.name -> {
+ if (si.exists(getModuleName(_) == m.name)) ns.newName(p)
+ else ns.newName(tokenize(c) filterNot ("[]." contains _) mkString "_")
+ })}}
+
+ val iGraph = new InstanceGraph(c)
+ names.zip(portNames).map{ case(WiringNames(comp, so, si, _), pn) =>
+ computeModifications(c, iGraph, comp, so, si, pn) }
}
- /** Add pins to modules and wires a signal to them, under the scope of a specified top module
- * Description:
- * Adds a pin to each sink module
- * Punches ports up from the source signal to the specified top module
- * Punches ports down to each sink module
- * Wires the source up and down, connecting to all sink modules
- * Restrictions:
- * - Can only have one source module instance under scope of the specified top
- * - All instances of each sink module must be under the scope of the specified top
- * Notes:
- * - No module uniquification occurs (due to imposed restrictions)
+ /** Converts a single unit of wiring information to module modifications
+ *
+ * @param c the circuit that will be modified
+ * @param iGraph an InstanceGraph representation of the circuit
+ * @param compName the name of a component
+ * @param source the name of the source component
+ * @param sinks a list of sink components/modules that the source
+ * should be connected to
+ * @param portNames a mapping of module names to new ports/wires
+ * that should be generated if needed
+ *
+ * @return a tuple of the component type and a map of module names
+ * to pending modifications
*/
- def wire(c: Circuit, wi: WiringInfo): Circuit = {
- // Split out WiringInfo
- val source = wi.source
- val sinks = wi.sinks
- val compName = wi.comp
- val pin = wi.pin
-
- // Maps modules to children instances, i.e. (instance, module)
- val childrenMap = getChildrenMap(c)
-
- // Check restrictions
- val nSources = countInstances(childrenMap, wi.top, source)
- if(nSources != 1)
- throw new WiringException(s"Cannot have $nSources instance of $source under ${wi.top}")
- sinks.foreach { m =>
- val total = countInstances(childrenMap, c.main, m)
- val nTop = countInstances(childrenMap, c.main, wi.top)
- val perTop = countInstances(childrenMap, wi.top, m)
- if(total != nTop * perTop)
- throw new WiringException(s"Module ${wi.top} does not contain all instances of $m.")
- }
-
- // Create valid port names for wiring that have no name conflicts
- val portNames = c.modules.foldLeft(Map.empty[String, String]) { (map, m) =>
- map + (m.name -> {
- val ns = Namespace(m)
- if(sinks.contains(m.name)) ns.newName(pin)
- else ns.newName(tokenize(compName) filterNot ("[]." contains _) mkString "_")
- })
- }
-
- // Create a lineage tree from children map
- val lineages = getLineage(childrenMap, wi.top)
-
- // Populate lineage tree with relationship information, i.e. who is source,
- // sink, parent of source, etc.
- val withFields = setSharedParent(wi.top)(setFields(sinks, source)(lineages))
-
- // Populate lineage tree with what to instantiate, connect to/from, etc.
- val withThings = setThings(portNames, compName)(withFields)
+ private def computeModifications(c: Circuit,
+ iGraph: InstanceGraph,
+ compName: String,
+ source: String,
+ sinks: Seq[Named],
+ portNames: Map[String, String]):
+ (Type, Map[String, Modifications]) = {
- // Create a map from module name to lineage information
- val map = pointToLineage(withThings)
-
- // Obtain the source component type
val sourceComponentType = getType(c, source, compName)
-
- // Return new circuit with correct wiring
- val cx = c.copy(modules = c.modules map onModule(map, sourceComponentType))
-
- // Replace inserted IR nodes with WIR nodes
- ToWorkingIR.run(cx)
+ val sinkComponents: Map[String, Seq[String]] = sinks
+ .collect{ case ComponentName(c, ModuleName(m, _)) => (c, m) }
+ .foldLeft(new scala.collection.immutable.HashMap[String, Seq[String]]){
+ case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil)) ) }
+
+ // Determine "ownership" of sources to sinks via minimum distance
+ val owners = sinksToSources(sinks, source, iGraph)
+
+ // Determine port and pending modifications for all sink--source
+ // ownership pairs
+ val meta = new mutable.HashMap[String, Modifications]
+ .withDefaultValue(Modifications())
+ owners.foreach { case (sink, source) =>
+ val lca = iGraph.lowestCommonAncestor(sink, source)
+
+ // Compute metadata along Sink to LCA paths.
+ sink.drop(lca.size - 1).sliding(2).toList.reverse.map {
+ case Seq(WDefInstance(_,_,pm,_), WDefInstance(_,ci,cm,_)) =>
+ val to = s"$ci.${portNames(cm)}"
+ val from = s"${portNames(pm)}"
+ meta(pm) = meta(pm).copy(
+ addPortOrWire = Some((portNames(pm), DecWire)),
+ cons = (meta(pm).cons :+ (to, from)).distinct
+ )
+ meta(cm) = meta(cm).copy(
+ addPortOrWire = Some((portNames(cm), DecInput))
+ )
+ // Case where the sink is the LCA
+ case Seq(WDefInstance(_,_,pm,_)) =>
+ // Case where the source is also the LCA
+ if (source.drop(lca.size).isEmpty) {
+ meta(pm) = meta(pm).copy (
+ addPortOrWire = Some((portNames(pm), DecWire))
+ )
+ } else {
+ val WDefInstance(_,ci,cm,_) = source.drop(lca.size).head
+ val to = s"${portNames(pm)}"
+ val from = s"$ci.${portNames(cm)}"
+ meta(pm) = meta(pm).copy(
+ addPortOrWire = Some((portNames(pm), DecWire)),
+ cons = (meta(pm).cons :+ (to, from)).distinct
+ )
+ }
+ }
+
+ // Compute metadata for the Sink
+ sink.last match { case WDefInstance(_, _, m, _) =>
+ if (sinkComponents.contains(m)) {
+ val from = s"${portNames(m)}"
+ sinkComponents(m).foreach( to =>
+ meta(m) = meta(m).copy(
+ cons = (meta(m).cons :+ (to, from)).distinct
+ )
+ )
+ }
+ }
+
+ // Compute metadata for the Source
+ source.last match { case WDefInstance(_, _, m, _) =>
+ val to = s"${portNames(m)}"
+ val from = compName
+ meta(m) = meta(m).copy(
+ cons = (meta(m).cons :+ (to, from)).distinct
+ )
+ }
+
+ // Compute metadata along Source to LCA path
+ source.drop(lca.size - 1).sliding(2).toList.reverse.map {
+ case Seq(WDefInstance(_,_,pm,_), WDefInstance(_,ci,cm,_)) => {
+ val to = s"${portNames(pm)}"
+ val from = s"$ci.${portNames(cm)}"
+ meta(pm) = meta(pm).copy(
+ cons = (meta(pm).cons :+ (to, from)).distinct
+ )
+ meta(cm) = meta(cm).copy(
+ addPortOrWire = Some((portNames(cm), DecOutput))
+ )
+ }
+ // Case where the source is the LCA
+ case Seq(WDefInstance(_,_,pm,_)) => {
+ // Case where the sink is also the LCA. We do nothing here,
+ // as we've created the connecting wire above
+ if (sink.drop(lca.size).isEmpty) {
+ } else {
+ val WDefInstance(_,ci,cm,_) = sink.drop(lca.size).head
+ val to = s"$ci.${portNames(cm)}"
+ val from = s"${portNames(pm)}"
+ meta(pm) = meta(pm).copy(
+ cons = (meta(pm).cons :+ (to, from)).distinct
+ )
+ }
+ }
+ }
+ }
+ (sourceComponentType, meta.toMap)
}
- /** Return new module with correct wiring
- */
- private def onModule(map: Map[String, Lineage], t: Type)(m: DefModule) = {
+ /** Apply modifications to a module */
+ private def onModule(t: Type, map: Map[String, Modifications])(m: DefModule) = {
map.get(m.name) match {
case None => m
case Some(l) =>
- val stmts = mutable.ArrayBuffer[Statement]()
+ val defines = mutable.ArrayBuffer[Statement]()
+ val connects = mutable.ArrayBuffer[Statement]()
val ports = mutable.ArrayBuffer[Port]()
- l.addPort match {
+ l.addPortOrWire match {
case None =>
case Some((s, dt)) => dt match {
- case DecInput => ports += Port(NoInfo, s, Input, t)
+ case DecInput => ports += Port(NoInfo, s, Input, t)
case DecOutput => ports += Port(NoInfo, s, Output, t)
- case DecWire =>
- stmts += DefWire(NoInfo, s, t)
+ case DecWire => defines += DefWire(NoInfo, s, t)
}
}
- stmts ++= (l.cons map { case ((l, r)) =>
+ connects ++= (l.cons map { case ((l, r)) =>
Connect(NoInfo, toExp(l), toExp(r))
})
- def onStmt(s: Statement): Statement = Block(Seq(s) ++ stmts)
m match {
- case Module(i, n, ps, s) => Module(i, n, ps ++ ports, Block(Seq(s) ++ stmts))
+ case Module(i, n, ps, s) => Module(i, n, ps ++ ports,
+ Block(defines ++ Seq(s) ++ connects))
case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps ++ ports, dn, p)
}
}
}
-
- /** Returns the type of the component specified
- */
- private def getType(c: Circuit, module: String, comp: String) = {
- def getRoot(e: Expression): String = e match {
- case r: Reference => r.name
- case i: SubIndex => getRoot(i.expr)
- case a: SubAccess => getRoot(a.expr)
- case f: SubField => getRoot(f.expr)
- }
- val eComp = toExp(comp)
- val root = getRoot(eComp)
- var tpe: Option[Type] = None
- def getType(s: Statement): Statement = s match {
- case DefRegister(_, n, t, _, _, _) if n == root =>
- tpe = Some(t)
- s
- case DefWire(_, n, t) if n == root =>
- tpe = Some(t)
- s
- case WDefInstance(_, n, m, t) if n == root =>
- tpe = Some(t)
- s
- case DefNode(_, n, e) if n == root =>
- tpe = Some(e.tpe)
- s
- case sx: DefMemory if sx.name == root =>
- tpe = Some(MemPortUtils.memType(sx))
- sx
- case sx => sx map getType
- }
- val m = c.modules find (_.name == module) getOrElse error(s"Must have a module named $module")
- tpe = m.ports find (_.name == root) map (_.tpe)
- m match {
- case Module(i, n, ps, b) => getType(b)
- case e: ExtModule =>
- }
- tpe match {
- case None => error(s"Didn't find $comp in $module!")
- case Some(t) =>
- def setType(e: Expression): Expression = e map setType match {
- case ex: Reference => ex.copy(tpe = t)
- case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name))
- case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe))
- case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe))
- }
- setType(eComp).tpe
- }
- }
}
diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
index a8ef5f58..01e6f83a 100644
--- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
+++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
@@ -11,86 +11,79 @@ import scala.collection.mutable
import firrtl.annotations._
import WiringUtils._
-/** A component, e.g. register etc. Must be declared only once under the TopAnnotation
- */
+/** A class for all exceptions originating from firrtl.passes.wiring */
+case class WiringException(msg: String) extends PassException(msg)
+
+/** An extractor of annotated source components */
object SourceAnnotation {
- def apply(target: ComponentName, pin: String): Annotation = Annotation(target, classOf[WiringTransform], s"source $pin")
+ def apply(target: ComponentName, pin: String): Annotation =
+ Annotation(target, classOf[WiringTransform], s"source $pin")
private val matcher = "source (.+)".r
def unapply(a: Annotation): Option[(ComponentName, String)] = a match {
- case Annotation(ComponentName(n, m), _, matcher(pin)) => Some((ComponentName(n, m), pin))
+ case Annotation(ComponentName(n, m), _, matcher(pin)) =>
+ Some((ComponentName(n, m), pin))
case _ => None
}
}
-/** A module, e.g. ExtModule etc., that should add the input pin
- */
+/** An extractor of annotation sink components or modules */
object SinkAnnotation {
- def apply(target: ModuleName, pin: String): Annotation = Annotation(target, classOf[WiringTransform], s"sink $pin")
+ def apply(target: Named, pin: String): Annotation =
+ Annotation(target, classOf[WiringTransform], s"sink $pin")
private val matcher = "sink (.+)".r
- def unapply(a: Annotation): Option[(ModuleName, String)] = a match {
- case Annotation(ModuleName(n, c), _, matcher(pin)) => Some((ModuleName(n, c), pin))
+ def unapply(a: Annotation): Option[(Named, String)] = a match {
+ case Annotation(ModuleName(n, c), _, matcher(pin)) =>
+ Some((ModuleName(n, c), pin))
+ case Annotation(ComponentName(n, m), _, matcher(pin)) =>
+ Some((ComponentName(n, m), pin))
case _ => None
}
}
-/** A module under which all sink module must be declared, and there is only
- * one source component
+/** Wires a Module's Source Component to one or more Sink
+ * Modules/Components
+ *
+ * Sinks are wired to their closest source through their lowest
+ * common ancestor (LCA). Verbosely, this modifies the circuit in
+ * the following ways:
+ * - Adds a pin to each sink module
+ * - Punches ports up from source signals to the LCA
+ * - Punches ports down from LCAs to each sink module
+ * - Wires sources up to LCA, sinks down from LCA, and across each LCA
+ *
+ * @throws WiringException if a sink is equidistant to two sources
*/
-object TopAnnotation {
- def apply(target: ModuleName, pin: String): Annotation = Annotation(target, classOf[WiringTransform], s"top $pin")
+class WiringTransform extends Transform {
+ def inputForm: CircuitForm = MidForm
+ def outputForm: CircuitForm = HighForm
- private val matcher = "top (.+)".r
- def unapply(a: Annotation): Option[(ModuleName, String)] = a match {
- case Annotation(ModuleName(n, c), _, matcher(pin)) => Some((ModuleName(n, c), pin))
- case _ => None
- }
-}
+ /** Defines the sequence of Transform that should be applied */
+ private def transforms(w: Seq[WiringInfo]): Seq[Transform] = Seq(
+ new Wiring(w),
+ ToWorkingIR
+ )
-/** Add pins to modules and wires a signal to them, under the scope of a specified top module
- * Description:
- * Adds a pin to each sink module
- * Punches ports up from the source signal to the specified top module
- * Punches ports down to each sink module
- * Wires the source up and down, connecting to all sink modules
- * Restrictions:
- * - Can only have one source module instance under scope of the specified top
- * - All instances of each sink module must be under the scope of the specified top
- * Notes:
- * - No module uniquification occurs (due to imposed restrictions)
- */
-class WiringTransform extends Transform {
- def inputForm = MidForm
- def outputForm = MidForm
- def transforms(wis: Seq[WiringInfo]) =
- Seq(new Wiring(wis),
- InferTypes,
- ResolveKinds,
- ResolveGenders)
def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match {
case Nil => state
- case p =>
- val sinks = mutable.HashMap[String, Set[String]]()
- val sources = mutable.HashMap[String, String]()
- val tops = mutable.HashMap[String, String]()
- val comp = mutable.HashMap[String, String]()
- p.foreach {
+ case p =>
+ val sinks = mutable.HashMap[String, Seq[Named]]()
+ val sources = mutable.HashMap[String, ComponentName]()
+ p.foreach {
case SinkAnnotation(m, pin) =>
- sinks(pin) = sinks.getOrElse(pin, Set.empty) + m.name
+ sinks(pin) = sinks.getOrElse(pin, Seq.empty) :+ m
case SourceAnnotation(c, pin) =>
- sources(pin) = c.module.name
- comp(pin) = c.name
- case TopAnnotation(m, pin) => tops(pin) = m.name
+ sources(pin) = c
}
- (sources.size, tops.size, sinks.size, comp.size) match {
- case (0, 0, p, 0) => state
- case (s, t, p, c) if (p > 0) & (s == t) & (t == c) =>
- val wis = tops.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, top)) =>
- seq :+ WiringInfo(sources(pin), comp(pin), sinks(pin), pin, top)
+ (sources.size, sinks.size) match {
+ case (0, p) => state
+ case (s, p) if (p > 0) =>
+ val wis = sources.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, source)) =>
+ seq :+ WiringInfo(source, sinks(pin), pin)
}
transforms(wis).foldLeft(state) { (in, xform) => xform.runTransform(in) }
- case _ => error("Wrong number of sources, tops, or sinks!")
+ case _ => error("Wrong number of sources or sinks!")
}
}
}
diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
index 29c93ca7..117a3824 100644
--- a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
+++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
@@ -9,6 +9,9 @@ import firrtl.Utils._
import firrtl.Mappers._
import scala.collection.mutable
import firrtl.annotations._
+import firrtl.annotations.AnnotationUtils._
+import firrtl.analyses.InstanceGraph
+import firrtl.graph.DiGraph
import WiringUtils._
/** Declaration kind in lineage (e.g. input port, output port, wire)
@@ -18,6 +21,19 @@ case object DecInput extends DecKind
case object DecOutput extends DecKind
case object DecWire extends DecKind
+/** Store of pending wiring information for a Module */
+case class Modifications(
+ addPortOrWire: Option[(String, DecKind)] = None,
+ cons: Seq[(String, String)] = Seq.empty) {
+
+ override def toString: String = serialize("")
+
+ def serialize(tab: String): String = s"""
+ |$tab addPortOrWire: $addPortOrWire
+ |$tab cons: $cons
+ |""".stripMargin
+}
+
/** A lineage tree representing the instance hierarchy in a design
*/
case class Lineage(
@@ -41,7 +57,7 @@ case class Lineage(
|$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))}
|""".stripMargin
- def foldLeft[B](z: B)(op: (B, (String, Lineage)) => B): B =
+ def foldLeft[B](z: B)(op: (B, (String, Lineage)) => B): B =
this.children.foldLeft(z)(op)
def serialize(tab: String): String = s"""
@@ -57,9 +73,6 @@ case class Lineage(
|""".stripMargin
}
-
-
-
object WiringUtils {
type ChildrenMap = mutable.HashMap[String, Seq[(String, String)]]
@@ -69,10 +82,10 @@ object WiringUtils {
def getChildrenMap(c: Circuit): ChildrenMap = {
val childrenMap = new ChildrenMap()
def getChildren(mname: String)(s: Statement): Statement = s match {
- case s: WDefInstance =>
+ case s: WDefInstance =>
childrenMap(mname) = childrenMap(mname) :+ (s.name, s.module)
s
- case s: DefInstance =>
+ case s: DefInstance =>
childrenMap(mname) = childrenMap(mname) :+ (s.name, s.module)
s
case s => s map getChildren(mname)
@@ -84,86 +97,151 @@ object WiringUtils {
childrenMap
}
- /** Counts the number of instances of a module declared under a top module
- */
- def countInstances(childrenMap: ChildrenMap, top: String, module: String): Int = {
- if(top == module) 1
- else childrenMap(top).foldLeft(0) { case (count, (i, child)) =>
- count + countInstances(childrenMap, child, module)
- }
- }
-
/** Returns a module's lineage, containing all children lineages as well
*/
def getLineage(childrenMap: ChildrenMap, module: String): Lineage =
Lineage(module, childrenMap(module) map { case (i, m) => (i, getLineage(childrenMap, m)) } )
- /** Sets the sink, sinkParent, source, and sourceParent fields of every
- * Lineage in tree
+ /** Return a map of sink instances to source instances that minimizes
+ * distance
+ *
+ * @param sinks a sequence of sink modules
+ * @param source the source module
+ * @param i a graph representing a circuit
+ * @return a map of sink instance names to source instance names
+ * @throws WiringException if a sink is equidistant to two sources
*/
- def setFields(sinks: Set[String], source: String)(lin: Lineage): Lineage = lin map setFields(sinks, source) match {
- case l if sinks.contains(l.name) => l.copy(sink = true)
- case l =>
- val src = l.name == source
- val sinkParent = l.children.foldLeft(false) { case (b, (i, m)) => b || m.sink || m.sinkParent }
- val sourceParent = if(src) true else l.children.foldLeft(false) { case (b, (i, m)) => b || m.source || m.sourceParent }
- l.copy(sinkParent=sinkParent, sourceParent=sourceParent, source=src)
- }
+ def sinksToSources(sinks: Seq[Named],
+ source: String,
+ i: InstanceGraph):
+ Map[Seq[WDefInstance], Seq[WDefInstance]] = {
+ val owners = new mutable.HashMap[Seq[WDefInstance], Vector[Seq[WDefInstance]]]
+ .withDefaultValue(Vector())
+ val queue = new mutable.Queue[Seq[WDefInstance]]
+ val visited = new mutable.HashMap[Seq[WDefInstance], Boolean]
+ .withDefaultValue(false)
+
+ i.fullHierarchy.keys.filter { case WDefInstance(_,_,m,_) => m == source }
+ .foreach( i.fullHierarchy(_)
+ .foreach { l =>
+ queue.enqueue(l)
+ owners(l) = Vector(l)
+ }
+ )
+
+ val sinkInsts = i.fullHierarchy.keys
+ .filter { case WDefInstance(_, _, module, _) =>
+ sinks.map(getModuleName(_)).contains(module) }
+ .flatMap { k => i.fullHierarchy(k) }
+ .toSet
+
+ /** If we're lucky and there is only one source, then that source owns
+ * all sinks. If we're unlucky, we need to do a full (slow) BFS
+ * to figure out who owns what. Currently, the BFS is not
+ * performant.
+ *
+ * [todo] The performance of this will need to be improved.
+ * Possible directions are that if we're purely source-under-sink
+ * or sink-under-source, then ownership is trivially a mapping
+ * down/up. Ownership seems to require a BFS if we have
+ * sources/sinks not under sinks/sources.
+ */
+ if (queue.size == 1) {
+ val u = queue.dequeue
+ sinkInsts.foreach { v => owners(v) = Vector(u) }
+ } else {
+ while (queue.nonEmpty) {
+ val u = queue.dequeue
+ visited(u) = true
+
+ val edges = (i.graph.getEdges(u.last).map(u :+ _).toVector :+ u.dropRight(1))
+
+ // [todo] This is the critical section
+ edges
+ .filter( e => !visited(e) && e.nonEmpty )
+ .foreach{ v =>
+ owners(v) = owners(v) ++ owners(u)
+ queue.enqueue(v)
+ }
+ }
+
+ // Check that every sink has one unique owner. The only time that
+ // this should fail is if a sink is equidistant to two sources.
+ sinkInsts.foreach { s =>
+ if (!owners.contains(s) || owners(s).size > 1) {
+ throw new WiringException(
+ s"Unable to determine source mapping for sink '${s.map(_.name)}'") }
+ }
+ }
- /** Sets the sharedParent of lineage top
- */
- def setSharedParent(top: String)(lin: Lineage): Lineage = lin map setSharedParent(top) match {
- case l if l.name == top => l.copy(sharedParent = true)
- case l => l
+ owners
+ .collect { case (k, v) if sinkInsts.contains(k) => (k, v.flatten) }.toMap
}
- /** Sets the addPort and cons fields of the lineage tree
- */
- def setThings(portNames:Map[String, String], compName: String)(lin: Lineage): Lineage = {
- val funs = Seq(
- ((l: Lineage) => l map setThings(portNames, compName)),
- ((l: Lineage) => l match {
- case Lineage(name, _, _, _, _, _, true, _, _) => //SharedParent
- l.copy(addPort=Some((portNames(name), DecWire)))
- case Lineage(name, _, _, _, true, _, _, _, _) => //SourceParent
- l.copy(addPort=Some((portNames(name), DecOutput)))
- case Lineage(name, _, _, _, _, true, _, _, _) => //SinkParent
- l.copy(addPort=Some((portNames(name), DecInput)))
- case Lineage(name, _, _, true, _, _, _, _, _) => //Sink
- l.copy(addPort=Some((portNames(name), DecInput)))
- case l => l
- }),
- ((l: Lineage) => l match {
- case Lineage(name, _, true, _, _, _, _, _, _) => //Source
- val tos = Seq(s"${portNames(name)}")
- val from = compName
- l.copy(cons = l.cons ++ tos.map(t => (t, from)))
- case Lineage(name, _, _, _, true, _, _, _, _) => //SourceParent
- val tos = Seq(s"${portNames(name)}")
- val from = l.children.filter { case (i, c) => c.sourceParent }.map { case (i, c) => s"$i.${portNames(c.name)}" }.head
- l.copy(cons = l.cons ++ tos.map(t => (t, from)))
- case l => l
- }),
- ((l: Lineage) => l match {
- case Lineage(name, _, _, _, _, true, _, _, _) => //SinkParent
- val tos = l.children.filter { case (i, c) => (c.sinkParent || c.sink) && !c.sourceParent } map { case (i, c) => s"$i.${portNames(c.name)}" }
- val from = s"${portNames(name)}"
- l.copy(cons = l.cons ++ tos.map(t => (t, from)))
- case l => l
- })
- )
- funs.foldLeft(lin)((l, fun) => fun(l))
+ /** Helper script to extract a module name from a named Module or Component */
+ def getModuleName(n: Named): String = {
+ n match {
+ case ModuleName(m, _) => m
+ case ComponentName(_, ModuleName(m, _)) => m
+ case _ => throw new WiringException(
+ "Only Components or Modules have an associated Module name")
+ }
}
- /** Return a map from module to its lineage in the tree
+ /** Determine the Type of a specific component
+ *
+ * @param c the circuit containing the target module
+ * @param module the module containing the target component
+ * @param comp the target component
+ * @return the component's type
+ * @throws WiringException if the module is not contained in the
+ * circuit or if the component is not contained in the module
*/
- def pointToLineage(lin: Lineage): Map[String, Lineage] = {
- val map = mutable.HashMap[String, Lineage]()
- def onLineage(l: Lineage): Lineage = {
- map(l.name) = l
- l map onLineage
+ def getType(c: Circuit, module: String, comp: String): Type = {
+ def getRoot(e: Expression): String = e match {
+ case r: Reference => r.name
+ case i: SubIndex => getRoot(i.expr)
+ case a: SubAccess => getRoot(a.expr)
+ case f: SubField => getRoot(f.expr)
+ }
+ val eComp = toExp(comp)
+ val root = getRoot(eComp)
+ var tpe: Option[Type] = None
+ def getType(s: Statement): Statement = s match {
+ case DefRegister(_, n, t, _, _, _) if n == root =>
+ tpe = Some(t)
+ s
+ case DefWire(_, n, t) if n == root =>
+ tpe = Some(t)
+ s
+ case WDefInstance(_, n, _, t) if n == root =>
+ tpe = Some(t)
+ s
+ case DefNode(_, n, e) if n == root =>
+ tpe = Some(e.tpe)
+ s
+ case sx: DefMemory if sx.name == root =>
+ tpe = Some(MemPortUtils.memType(sx))
+ sx
+ case sx => sx map getType
+ }
+ val m = c.modules find (_.name == module) getOrElse {
+ throw new WiringException(s"Must have a module named $module") }
+ tpe = m.ports find (_.name == root) map (_.tpe)
+ m match {
+ case Module(i, n, ps, b) => getType(b)
+ case e: ExtModule =>
+ }
+ tpe match {
+ case None => throw new WiringException(s"Didn't find $comp in $module!")
+ case Some(t) =>
+ def setType(e: Expression): Expression = e map setType match {
+ case ex: Reference => ex.copy(tpe = t)
+ case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name))
+ case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe))
+ case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe))
+ }
+ setType(eComp).tpe
}
- onLineage(lin)
- map.toMap
}
}
diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala
index 01ad573f..5dd048a3 100644
--- a/src/test/scala/firrtlTests/WiringTests.scala
+++ b/src/test/scala/firrtlTests/WiringTests.scala
@@ -33,86 +33,87 @@ class WiringTests extends FirrtlFlatSpec {
InferWidths
)
- "Wiring from r to X" should "work" in {
- val sinks = Set("X")
- val sas = WiringInfo("C", "r", sinks, "pin", "A")
+ it should "wire from a register source (r) to multiple extmodule sinks (X)" in {
+ val sinks = Seq(ModuleName("X", CircuitName("Top")))
+ val source = ComponentName("r", ModuleName("C", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
val input =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a of A
- | a.clock <= clock
- | module A :
- | input clock: Clock
- | inst b of B
- | b.clock <= clock
- | inst x of X
- | x.clock <= clock
- | inst d of D
- | d.clock <= clock
- | module B :
- | input clock: Clock
- | inst c of C
- | c.clock <= clock
- | inst d of D
- | d.clock <= clock
- | module C :
- | input clock: Clock
- | reg r: UInt<5>, clock
- | module D :
- | input clock: Clock
- | inst x1 of X
- | x1.clock <= clock
- | inst x2 of X
- | x2.clock <= clock
- | extmodule X :
- | input clock: Clock
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | inst b of B
+ | b.clock <= clock
+ | inst x of X
+ | x.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | module B :
+ | input clock: Clock
+ | inst c of C
+ | c.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | module C :
+ | input clock: Clock
+ | reg r: UInt<5>, clock
+ | module D :
+ | input clock: Clock
+ | inst x1 of X
+ | x1.clock <= clock
+ | inst x2 of X
+ | x2.clock <= clock
+ | extmodule X :
+ | input clock: Clock
+ |""".stripMargin
val check =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a of A
- | a.clock <= clock
- | module A :
- | input clock: Clock
- | inst b of B
- | b.clock <= clock
- | inst x of X
- | x.clock <= clock
- | inst d of D
- | d.clock <= clock
- | wire r: UInt<5>
- | r <= b.r
- | x.pin <= r
- | d.r <= r
- | module B :
- | input clock: Clock
- | output r: UInt<5>
- | inst c of C
- | c.clock <= clock
- | inst d of D
- | d.clock <= clock
- | r <= c.r_0
- | d.r <= r
- | module C :
- | input clock: Clock
- | output r_0: UInt<5>
- | reg r: UInt<5>, clock
- | r_0 <= r
- | module D :
- | input clock: Clock
- | input r: UInt<5>
- | inst x1 of X
- | x1.clock <= clock
- | inst x2 of X
- | x2.clock <= clock
- | x1.pin <= r
- | x2.pin <= r
- | extmodule X :
- | input clock: Clock
- | input pin: UInt<5>
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire r: UInt<5>
+ | inst b of B
+ | b.clock <= clock
+ | inst x of X
+ | x.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | d.r <= r
+ | r <= b.r
+ | x.pin <= r
+ | module B :
+ | input clock: Clock
+ | output r: UInt<5>
+ | inst c of C
+ | c.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | r <= c.r_0
+ | d.r <= r
+ | module C :
+ | input clock: Clock
+ | output r_0: UInt<5>
+ | reg r: UInt<5>, clock
+ | r_0 <= r
+ | module D :
+ | input clock: Clock
+ | input r: UInt<5>
+ | inst x1 of X
+ | x1.clock <= clock
+ | inst x2 of X
+ | x2.clock <= clock
+ | x1.pin <= r
+ | x2.pin <= r
+ | extmodule X :
+ | input clock: Clock
+ | input pin: UInt<5>
+ |""".stripMargin
val c = passes.foldLeft(parse(input)) {
(c: Circuit, p: Pass) => p.run(c)
}
@@ -121,41 +122,87 @@ class WiringTests extends FirrtlFlatSpec {
(parse(retC.serialize).serialize) should be (parse(check).serialize)
}
- "Wiring from r.x to X" should "work" in {
- val sinks = Set("X")
- val sas = WiringInfo("A", "r.x", sinks, "pin", "A")
+ it should "wire from a register source (r) to multiple module sinks (X)" in {
+ val sinks = Seq(ModuleName("X", CircuitName("Top")))
+ val source = ComponentName("r", ModuleName("C", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
val input =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a of A
- | a.clock <= clock
- | module A :
- | input clock: Clock
- | reg r : {x: UInt<5>}, clock
- | inst x of X
- | x.clock <= clock
- | extmodule X :
- | input clock: Clock
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | inst b of B
+ | b.clock <= clock
+ | inst x of X
+ | x.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | module B :
+ | input clock: Clock
+ | inst c of C
+ | c.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | module C :
+ | input clock: Clock
+ | reg r: UInt<5>, clock
+ | module D :
+ | input clock: Clock
+ | inst x1 of X
+ | x1.clock <= clock
+ | inst x2 of X
+ | x2.clock <= clock
+ | module X :
+ | input clock: Clock
+ |""".stripMargin
val check =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a of A
- | a.clock <= clock
- | module A :
- | input clock: Clock
- | reg r: {x: UInt<5>}, clock
- | inst x of X
- | x.clock <= clock
- | wire r_x: UInt<5>
- | r_x <= r.x
- | x.pin <= r_x
- | extmodule X :
- | input clock: Clock
- | input pin: UInt<5>
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire r: UInt<5>
+ | inst b of B
+ | b.clock <= clock
+ | inst x of X
+ | x.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | d.r <= r
+ | r <= b.r
+ | x.pin <= r
+ | module B :
+ | input clock: Clock
+ | output r: UInt<5>
+ | inst c of C
+ | c.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | r <= c.r_0
+ | d.r <= r
+ | module C :
+ | input clock: Clock
+ | output r_0: UInt<5>
+ | reg r: UInt<5>, clock
+ | r_0 <= r
+ | module D :
+ | input clock: Clock
+ | input r: UInt<5>
+ | inst x1 of X
+ | x1.clock <= clock
+ | inst x2 of X
+ | x2.clock <= clock
+ | x1.pin <= r
+ | x2.pin <= r
+ | module X :
+ | input clock: Clock
+ | input pin: UInt<5>
+ |""".stripMargin
val c = passes.foldLeft(parse(input)) {
(c: Circuit, p: Pass) => p.run(c)
}
@@ -163,39 +210,91 @@ class WiringTests extends FirrtlFlatSpec {
val retC = wiringPass.run(c)
(parse(retC.serialize).serialize) should be (parse(check).serialize)
}
- "Wiring from clock to X" should "work" in {
- val sinks = Set("X")
- val sas = WiringInfo("A", "clock", sinks, "pin", "A")
+
+ it should "wire from a register sink (r) to a wire source (s) in another module (X)" in {
+ val sinks = Seq(ComponentName("s", ModuleName("X", CircuitName("Top"))))
+ val source = ComponentName("r", ModuleName("C", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
val input =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a of A
- | a.clock <= clock
- | module A :
- | input clock: Clock
- | inst x of X
- | x.clock <= clock
- | extmodule X :
- | input clock: Clock
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | inst b of B
+ | b.clock <= clock
+ | inst x of X
+ | x.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | module B :
+ | input clock: Clock
+ | inst c of C
+ | c.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | module C :
+ | input clock: Clock
+ | reg r: UInt<5>, clock
+ | module D :
+ | input clock: Clock
+ | inst x1 of X
+ | x1.clock <= clock
+ | inst x2 of X
+ | x2.clock <= clock
+ | module X :
+ | input clock: Clock
+ | wire s: UInt<5>
+ |""".stripMargin
val check =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a of A
- | a.clock <= clock
- | module A :
- | input clock: Clock
- | inst x of X
- | x.clock <= clock
- | wire clock_0: Clock
- | clock_0 <= clock
- | x.pin <= clock_0
- | extmodule X :
- | input clock: Clock
- | input pin: Clock
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire r: UInt<5>
+ | inst b of B
+ | b.clock <= clock
+ | inst x of X
+ | x.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | d.r <= r
+ | r <= b.r
+ | x.pin <= r
+ | module B :
+ | input clock: Clock
+ | output r: UInt<5>
+ | inst c of C
+ | c.clock <= clock
+ | inst d of D
+ | d.clock <= clock
+ | r <= c.r_0
+ | d.r <= r
+ | module C :
+ | input clock: Clock
+ | output r_0: UInt<5>
+ | reg r: UInt<5>, clock
+ | r_0 <= r
+ | module D :
+ | input clock: Clock
+ | input r: UInt<5>
+ | inst x1 of X
+ | x1.clock <= clock
+ | inst x2 of X
+ | x2.clock <= clock
+ | x1.pin <= r
+ | x2.pin <= r
+ | module X :
+ | input clock: Clock
+ | input pin: UInt<5>
+ | wire s: UInt<5>
+ | s <= pin
+ |""".stripMargin
val c = passes.foldLeft(parse(input)) {
(c: Circuit, p: Pass) => p.run(c)
}
@@ -203,69 +302,210 @@ class WiringTests extends FirrtlFlatSpec {
val retC = wiringPass.run(c)
(parse(retC.serialize).serialize) should be (parse(check).serialize)
}
- "Two sources" should "fail" in {
- val sinks = Set("X")
- val sas = WiringInfo("A", "clock", sinks, "pin", "Top")
+
+ it should "wire from a SubField source (r.x) to an extmodule sink (X)" in {
+ val sinks = Seq(ModuleName("X", CircuitName("Top")))
+ val source = ComponentName("r.x", ModuleName("A", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
val input =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a1 of A
- | a1.clock <= clock
- | inst a2 of A
- | a2.clock <= clock
- | module A :
- | input clock: Clock
- | inst x of X
- | x.clock <= clock
- | extmodule X :
- | input clock: Clock
- |""".stripMargin
- intercept[WiringException] {
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
- val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | reg r : {x: UInt<5>}, clock
+ | inst x of X
+ | x.clock <= clock
+ | extmodule X :
+ | input clock: Clock
+ |""".stripMargin
+ val check =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire r_x: UInt<5>
+ | reg r: {x: UInt<5>}, clock
+ | inst x of X
+ | x.clock <= clock
+ | x.pin <= r_x
+ | r_x <= r.x
+ | extmodule X :
+ | input clock: Clock
+ | input pin: UInt<5>
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(Seq(sas))
+ val retC = wiringPass.run(c)
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+
+ it should "wire properly with a source as a submodule of a sink" in {
+ val sinks = Seq(ComponentName("s", ModuleName("A", CircuitName("Top"))))
+ val source = ComponentName("r", ModuleName("X", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
+ val input =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire s: UInt<5>
+ | inst x of X
+ | x.clock <= clock
+ | module X :
+ | input clock: Clock
+ | reg r: UInt<5>, clock
+ |""".stripMargin
+ val check =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire pin: UInt<5>
+ | wire s: UInt<5>
+ | inst x of X
+ | x.clock <= clock
+ | pin <= x.r_0
+ | s <= pin
+ | module X :
+ | input clock: Clock
+ | output r_0: UInt<5>
+ | reg r: UInt<5>, clock
+ | r_0 <= r
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(Seq(sas))
+ val retC = wiringPass.run(c)
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+
+ it should "wire with source and sink in the same module" in {
+ val sinks = Seq(ComponentName("s", ModuleName("A", CircuitName("Top"))))
+ val source = ComponentName("r", ModuleName("A", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
+ val input =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire s: UInt<5>
+ | reg r: UInt<5>, clock
+ |""".stripMargin
+ val check =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire pin: UInt<5>
+ | wire s: UInt<5>
+ | reg r: UInt<5>, clock
+ | s <= pin
+ | pin <= r
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(Seq(sas))
+ val retC = wiringPass.run(c)
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+
+ it should "wire multiple sinks in the same module" in {
+ val sinks = Seq(ComponentName("s", ModuleName("A", CircuitName("Top"))),
+ ComponentName("t", ModuleName("A", CircuitName("Top"))))
+ val source = ComponentName("r", ModuleName("A", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
+ val input =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire s: UInt<5>
+ | wire t: UInt<5>
+ | reg r: UInt<5>, clock
+ |""".stripMargin
+ val check =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire pin: UInt<5>
+ | wire s: UInt<5>
+ | wire t: UInt<5>
+ | reg r: UInt<5>, clock
+ | t <= pin
+ | s <= pin
+ | pin <= r
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
}
+ val wiringPass = new Wiring(Seq(sas))
+ val retC = wiringPass.run(c)
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
}
- "Wiring from A.clock to X, with 2 A's, and A as top" should "work" in {
- val sinks = Set("X")
- val sas = WiringInfo("A", "clock", sinks, "pin", "A")
+
+ it should "wire clocks" in {
+ val sinks = Seq(ModuleName("X", CircuitName("Top")))
+ val source = ComponentName("clock", ModuleName("A", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
val input =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a1 of A
- | a1.clock <= clock
- | inst a2 of A
- | a2.clock <= clock
- | module A :
- | input clock: Clock
- | inst x of X
- | x.clock <= clock
- | extmodule X :
- | input clock: Clock
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | inst x of X
+ | x.clock <= clock
+ | extmodule X :
+ | input clock: Clock
+ |""".stripMargin
val check =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a1 of A
- | a1.clock <= clock
- | inst a2 of A
- | a2.clock <= clock
- | module A :
- | input clock: Clock
- | inst x of X
- | x.clock <= clock
- | wire clock_0: Clock
- | clock_0 <= clock
- | x.pin <= clock_0
- | extmodule X :
- | input clock: Clock
- | input pin: Clock
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire clock_0: Clock
+ | inst x of X
+ | x.clock <= clock
+ | x.pin <= clock_0
+ | clock_0 <= clock
+ | extmodule X :
+ | input clock: Clock
+ | input pin: Clock
+ |""".stripMargin
val c = passes.foldLeft(parse(input)) {
(c: Circuit, p: Pass) => p.run(c)
}
@@ -273,26 +513,120 @@ class WiringTests extends FirrtlFlatSpec {
val retC = wiringPass.run(c)
(parse(retC.serialize).serialize) should be (parse(check).serialize)
}
- "Wiring from A.clock to X, with 2 A's, and A as top, but Top instantiates X" should "error" in {
- val sinks = Set("X")
- val sas = WiringInfo("A", "clock", sinks, "pin", "A")
+
+ it should "handle two source instances with clearly defined sinks" in {
+ val sinks = Seq(ModuleName("X", CircuitName("Top")))
+ val source = ComponentName("clock", ModuleName("A", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
val input =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a1 of A
- | a1.clock <= clock
- | inst a2 of A
- | a2.clock <= clock
- | inst x of X
- | x.clock <= clock
- | module A :
- | input clock: Clock
- | inst x of X
- | x.clock <= clock
- | extmodule X :
- | input clock: Clock
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a1 of A
+ | a1.clock <= clock
+ | inst a2 of A
+ | a2.clock <= clock
+ | module A :
+ | input clock: Clock
+ | inst x of X
+ | x.clock <= clock
+ | extmodule X :
+ | input clock: Clock
+ |""".stripMargin
+ val check =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a1 of A
+ | a1.clock <= clock
+ | inst a2 of A
+ | a2.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire clock_0: Clock
+ | inst x of X
+ | x.clock <= clock
+ | x.pin <= clock_0
+ | clock_0 <= clock
+ | extmodule X :
+ | input clock: Clock
+ | input pin: Clock
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(Seq(sas))
+ val retC = wiringPass.run(c)
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+
+ it should "wire multiple clocks" in {
+ val sinks = Seq(ModuleName("X", CircuitName("Top")))
+ val source = ComponentName("clock", ModuleName("A", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
+ val input =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a1 of A
+ | a1.clock <= clock
+ | inst a2 of A
+ | a2.clock <= clock
+ | module A :
+ | input clock: Clock
+ | inst x of X
+ | x.clock <= clock
+ | extmodule X :
+ | input clock: Clock
+ |""".stripMargin
+ val check =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a1 of A
+ | a1.clock <= clock
+ | inst a2 of A
+ | a2.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire clock_0: Clock
+ | inst x of X
+ | x.clock <= clock
+ | x.pin <= clock_0
+ | clock_0 <= clock
+ | extmodule X :
+ | input clock: Clock
+ | input pin: Clock
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(Seq(sas))
+ val retC = wiringPass.run(c)
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+
+ it should "error with WiringException for indeterminate ownership" in {
+ val sinks = Seq(ModuleName("X", CircuitName("Top")))
+ val source = ComponentName("clock", ModuleName("A", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
+ val input =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a1 of A
+ | a1.clock <= clock
+ | inst a2 of A
+ | a2.clock <= clock
+ | inst x of X
+ | x.clock <= clock
+ | module A :
+ | input clock: Clock
+ | inst x of X
+ | x.clock <= clock
+ | extmodule X :
+ | input clock: Clock
+ |""".stripMargin
intercept[WiringException] {
val c = passes.foldLeft(parse(input)) {
(c: Circuit, p: Pass) => p.run(c)
@@ -301,43 +635,45 @@ class WiringTests extends FirrtlFlatSpec {
val retC = wiringPass.run(c)
}
}
- "Wiring from A.r[a] to X" should "work" in {
- val sinks = Set("X")
- val sas = WiringInfo("A", "r[a]", sinks, "pin", "A")
+
+ it should "wire subindex source to sink" in {
+ val sinks = Seq(ModuleName("X", CircuitName("Top")))
+ val source = ComponentName("r[a]", ModuleName("A", CircuitName("Top")))
+ val sas = WiringInfo(source, sinks, "pin")
val input =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a of A
- | a.clock <= clock
- | module A :
- | input clock: Clock
- | reg r: UInt<2>[5], clock
- | node a = UInt(5)
- | inst x of X
- | x.clock <= clock
- | extmodule X :
- | input clock: Clock
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | reg r: UInt<2>[5], clock
+ | node a = UInt(5)
+ | inst x of X
+ | x.clock <= clock
+ | extmodule X :
+ | input clock: Clock
+ |""".stripMargin
val check =
- """circuit Top :
- | module Top :
- | input clock: Clock
- | inst a of A
- | a.clock <= clock
- | module A :
- | input clock: Clock
- | reg r: UInt<2>[5], clock
- | node a = UInt(5)
- | inst x of X
- | x.clock <= clock
- | wire r_a: UInt<2>
- | r_a <= r[a]
- | x.pin <= r_a
- | extmodule X :
- | input clock: Clock
- | input pin: UInt<2>
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire r_a: UInt<2>
+ | reg r: UInt<2>[5], clock
+ | node a = UInt(5)
+ | inst x of X
+ | x.clock <= clock
+ | x.pin <= r_a
+ | r_a <= r[a]
+ | extmodule X :
+ | input clock: Clock
+ | input pin: UInt<2>
+ |""".stripMargin
val c = passes.foldLeft(parse(input)) {
(c: Circuit, p: Pass) => p.run(c)
}
@@ -346,37 +682,182 @@ class WiringTests extends FirrtlFlatSpec {
(parse(retC.serialize).serialize) should be (parse(check).serialize)
}
- "Wiring annotations" should "work" in {
+ it should "wire using Annotations with a sink module" in {
val source = SourceAnnotation(ComponentName("r", ModuleName("Top", CircuitName("Top"))), "pin")
val sink = SinkAnnotation(ModuleName("X", CircuitName("Top")), "pin")
- val top = TopAnnotation(ModuleName("Top", CircuitName("Top")), "pin")
val input =
- """circuit Top :
- | module Top :
- | input clk: Clock
- | inst x of X
- | reg r: UInt<5>, clk
- | extmodule X :
- | input clk: Clock
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst x of X
+ | x.clk <= clk
+ | reg r: UInt<5>, clk
+ | extmodule X :
+ | input clk: Clock
+ |""".stripMargin
val check =
- """circuit Top :
- | module Top :
- | input clk: Clock
- | inst x of X
- | reg r: UInt<5>, clk
- | wire r_0 : UInt<5>
- | r_0 <= r
- | x.pin <= r_0
- | extmodule X :
- | input clk: Clock
- | input pin: UInt<5>
- |""".stripMargin
+ """|circuit Top :
+ | module Top :
+ | input clk: Clock
+ | wire r_0 : UInt<5>
+ | inst x of X
+ | x.clk <= clk
+ | reg r: UInt<5>, clk
+ | x.pin <= r_0
+ | r_0 <= r
+ | extmodule X :
+ | input clk: Clock
+ | input pin: UInt<5>
+ |""".stripMargin
val c = passes.foldLeft(parse(input)) {
(c: Circuit, p: Pass) => p.run(c)
}
val wiringXForm = new WiringTransform()
- val retC = wiringXForm.execute(CircuitState(c, LowForm, Some(AnnotationMap(Seq(source, sink, top))), None)).circuit
+ val retC = wiringXForm.execute(CircuitState(c, MidForm, Some(AnnotationMap(Seq(source, sink))), None)).circuit
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+
+ it should "wire using Annotations with a sink component" in {
+ val source = SourceAnnotation(ComponentName("r", ModuleName("Top", CircuitName("Top"))), "pin")
+ val sink = SinkAnnotation(ComponentName("s", ModuleName("X", CircuitName("Top"))), "pin")
+ val input =
+ """|circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst x of X
+ | x.clk <= clk
+ | reg r: UInt<5>, clk
+ | module X :
+ | input clk: Clock
+ | wire s: UInt<5>
+ |""".stripMargin
+ val check =
+ """|circuit Top :
+ | module Top :
+ | input clk: Clock
+ | wire r_0 : UInt<5>
+ | inst x of X
+ | x.clk <= clk
+ | reg r: UInt<5>, clk
+ | x.pin <= r_0
+ | r_0 <= r
+ | module X :
+ | input clk: Clock
+ | input pin: UInt<5>
+ | wire s: UInt<5>
+ | s <= pin
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringXForm = new WiringTransform()
+ val retC = wiringXForm.execute(CircuitState(c, MidForm, Some(AnnotationMap(Seq(source, sink))), None)).circuit
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+
+ it should "wire using annotations with Aggregate source" in {
+ val source = SourceAnnotation(ComponentName("bundle", ModuleName("A", CircuitName("Top"))), "pin")
+ val sink = SinkAnnotation(ModuleName("B", CircuitName("Top")), "pin")
+ val input =
+ """|circuit Top :
+ | module Top :
+ | input clock : Clock
+ | inst a of A
+ | inst b of B
+ | a.clock <= clock
+ | b.clock <= clock
+ | module A :
+ | input clock : Clock
+ | wire bundle : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }
+ | bundle is invalid
+ | module B :
+ | input clock : Clock""".stripMargin
+ val check =
+ """|circuit Top :
+ | module Top :
+ | input clock : Clock
+ | wire bundle : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }
+ | inst a of A
+ | inst b of B
+ | a.clock <= clock
+ | b.clock <= clock
+ | b.pin <= bundle
+ | bundle <= a.bundle_0
+ | module A :
+ | input clock : Clock
+ | output bundle_0 : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }
+ | wire bundle : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }
+ | bundle is invalid
+ | bundle_0 <= bundle
+ | module B :
+ | input clock : Clock
+ | input pin : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }"""
+ .stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringXForm = new WiringTransform()
+ val retC = wiringXForm.execute(CircuitState(c, MidForm, Some(AnnotationMap(Seq(source, sink))), None)).circuit
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+
+ it should "wire one sink to multiple, disjoint extmodules" in {
+ val sinkX = Seq(ModuleName("X", CircuitName("Top")))
+ val sourceX = ComponentName("r.x", ModuleName("A", CircuitName("Top")))
+ val sinkY = Seq(ModuleName("Y", CircuitName("Top")))
+ val sourceY = ComponentName("r.x", ModuleName("A", CircuitName("Top")))
+ val wiSeq = Seq(
+ WiringInfo(sourceX, sinkX, "pin"),
+ WiringInfo(sourceY, sinkY, "pin"))
+ val input =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | reg r : {x: UInt<5>}, clock
+ | inst x of X
+ | x.clock <= clock
+ | inst y of Y
+ | y.clock <= clock
+ | extmodule X :
+ | input clock: Clock
+ | extmodule Y :
+ | input clock: Clock
+ |""".stripMargin
+ val check =
+ """|circuit Top :
+ | module Top :
+ | input clock: Clock
+ | inst a of A
+ | a.clock <= clock
+ | module A :
+ | input clock: Clock
+ | wire r_x_0: UInt<5>
+ | wire r_x: UInt<5>
+ | reg r: {x: UInt<5>}, clock
+ | inst x of X
+ | x.clock <= clock
+ | inst y of Y
+ | y.clock <= clock
+ | x.pin <= r_x
+ | r_x <= r.x
+ | y.pin <= r_x_0
+ | r_x_0 <= r.x
+ | extmodule X :
+ | input clock: Clock
+ | input pin: UInt<5>
+ | extmodule Y :
+ | input clock: Clock
+ | input pin: UInt<5>
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(wiSeq)
+ val retC = wiringPass.run(c)
(parse(retC.serialize).serialize) should be (parse(check).serialize)
}
}
diff --git a/src/test/scala/firrtlTests/graph/EulerTourTests.scala b/src/test/scala/firrtlTests/graph/EulerTourTests.scala
new file mode 100644
index 00000000..0b69ce61
--- /dev/null
+++ b/src/test/scala/firrtlTests/graph/EulerTourTests.scala
@@ -0,0 +1,36 @@
+package firrtlTests.graph
+
+import firrtl.graph._
+import firrtlTests._
+
+class EulerTourTests extends FirrtlFlatSpec {
+
+ val top = "top"
+ val first_layer = Set("1a", "1b", "1c")
+ val second_layer = Set("2a", "2b", "2c")
+ val third_layer = Set("3a", "3b", "3c")
+ val last_null = Set.empty[String]
+
+ val m = Map(top -> first_layer) ++ first_layer.map{
+ case x => Map(x -> second_layer) }.flatten.toMap ++ second_layer.map{
+ case x => Map(x -> third_layer) }.flatten.toMap ++ third_layer.map{
+ case x => Map(x -> last_null) }.flatten.toMap
+
+ val graph = DiGraph(m)
+ val instances = graph.pathsInDAG(top).values.flatten
+ val tour = EulerTour(graph, top)
+
+ it should "show equivalency of Berkman--Vishkin and naive RMQs" in {
+ instances.toSeq.combinations(2).toList.map { case Seq(a, b) =>
+ tour.rmqNaive(a, b) should be (tour.rmqBV(a, b))
+ }
+ }
+
+ it should "determine naive RMQs of itself correctly" in {
+ instances.toSeq.map { case a => tour.rmqNaive(a, a) should be (a) }
+ }
+
+ it should "determine Berkman--Vishkin RMQs of itself correctly" in {
+ instances.toSeq.map { case a => tour.rmqNaive(a, a) should be (a) }
+ }
+}