aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes
diff options
context:
space:
mode:
authorSchuyler Eldridge2018-01-15 18:53:28 -0500
committerJack Koenig2018-01-15 15:53:28 -0800
commit347cc522e96f8090d53b3b042af646e4a0e765b2 (patch)
tree5ce616a60da92583bcf996b56fb6ba041c68b3a2 /src/main/scala/firrtl/passes
parent8e18404b2919ef6226b511bb666116f657082aa8 (diff)
WiringTransform Refactor (#648)
Massive refactoring to WiringTransform with the use of a new EulerTour class to speed things up via fast least common ancestor (LCA) queries. Changes include (but are not limited to): * Use lowest common ancestor when wiring * Add EulerTour class with naive and Berkman-Vishkin RMQ * Adds LCA method for Instance Graph * Enables "Two Sources" using "Top" wiring test as this is now valid * Remove TopAnnotation from WiringTransform * Represent WiringTransform sink as `Seq[Named]` * Remove WiringUtils.countInstances, fix imports * Support sources under sinks in WiringTransform * Enable internal module wiring * Support Wiring of Aggregates h/t @edcote fixes #728 Signed-off-by: Schuyler Eldridge <schuyler.eldridge@ibm.com> Reviewed-by: Jack Koenig<jack.koenig3@gmail.com>
Diffstat (limited to 'src/main/scala/firrtl/passes')
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockList.scala6
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala15
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala2
-rw-r--r--src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala2
-rw-r--r--src/main/scala/firrtl/passes/memlib/DecorateMems.scala3
-rw-r--r--src/main/scala/firrtl/passes/wiring/Wiring.scala287
-rw-r--r--src/main/scala/firrtl/passes/wiring/WiringTransform.scala103
-rw-r--r--src/main/scala/firrtl/passes/wiring/WiringUtils.scala226
8 files changed, 373 insertions, 271 deletions
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockList.scala b/src/main/scala/firrtl/passes/clocklist/ClockList.scala
index bd2536ab..073eb050 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockList.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockList.scala
@@ -8,7 +8,7 @@ import firrtl.ir._
import annotations._
import Utils.error
import java.io.{File, CharArrayWriter, PrintWriter, Writer}
-import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage}
+import wiring.WiringUtils.{getChildrenMap, getLineage}
import wiring.Lineage
import ClockListUtils._
import Utils._
@@ -30,7 +30,7 @@ class ClockList(top: String, writer: Writer) extends Pass {
// === Checks ===
// TODO(izraelevitz): Check all registers/memories use "clock" clock port
// ==============
-
+
// Clock sources must be blackbox outputs and top's clock
val partialSourceList = getSourceList(moduleMap)(lineages)
val sourceList = partialSourceList ++ moduleMap(top).ports.collect{ case Port(i, n, Input, ClockType) => n }
@@ -39,7 +39,7 @@ class ClockList(top: String, writer: Writer) extends Pass {
// Remove everything from the circuit, unless it has a clock type
// This simplifies the circuit drastically so InlineInstances doesn't take forever.
val onlyClockCircuit = RemoveAllButClocks.run(c)
-
+
// Inline the clock-only circuit up to the specified top module
val modulesToInline = (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet
val inlineTransform = new InlineInstances
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
index b04171a7..24f25525 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
@@ -8,7 +8,6 @@ import firrtl.ir._
import annotations._
import Utils.error
import java.io.{File, CharArrayWriter, PrintWriter, Writer}
-import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage}
import wiring.Lineage
import ClockListUtils._
import Utils._
@@ -22,23 +21,23 @@ object ClockListAnnotation {
[Optional] ClockList
List which signal drives each clock of every descendent of specified module
-Usage:
+Usage:
--list-clocks -c:<circuit>:-m:<module>:-o:<filename>
*** Note: sub-arguments to --list-clocks should be delimited by : and not white space!
-"""
-
+"""
+
//Parse pass options
val passOptions = PassConfigUtil.getPassOptions(t, usage)
val outputConfig = passOptions.getOrElse(
- OutputConfigFileName,
+ OutputConfigFileName,
error("No output config file provided for ClockList!" + usage)
)
val passCircuit = passOptions.getOrElse(
- PassCircuitName,
+ PassCircuitName,
error("No circuit name specified for ClockList!" + usage)
)
val passModule = passOptions.getOrElse(
- PassModuleName,
+ PassModuleName,
error("No module name specified for ClockList!" + usage)
)
passOptions.get(InputConfigFileName) match {
@@ -65,7 +64,7 @@ class ClockListTransform extends Transform {
def passSeq(top: String, writer: Writer): Seq[Pass] =
Seq(new ClockList(top, writer))
def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match {
- case Seq(ClockListAnnotation(ModuleName(top, CircuitName(state.circuit.main)), out)) =>
+ case Seq(ClockListAnnotation(ModuleName(top, CircuitName(state.circuit.main)), out)) =>
val outputFile = new PrintWriter(out)
val newC = (new ClockList(top, outputFile)).run(state.circuit)
outputFile.close()
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
index b81d0c7e..892f1642 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
@@ -8,7 +8,6 @@ import firrtl.ir._
import annotations._
import Utils.error
import java.io.{File, CharArrayWriter, PrintWriter, Writer}
-import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage}
import wiring.Lineage
import ClockListUtils._
import Utils._
@@ -61,4 +60,3 @@ object ClockListUtils {
}
}
}
-
diff --git a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
index 53787b1d..1178ce69 100644
--- a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
+++ b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
@@ -8,8 +8,6 @@ import firrtl.ir._
import annotations._
import Utils.error
import java.io.{File, CharArrayWriter, PrintWriter, Writer}
-import wiring.WiringUtils.{getChildrenMap, countInstances, ChildrenMap, getLineage}
-import wiring.Lineage
import ClockListUtils._
import Utils._
import memlib.AnalysisUtils._
diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
index 648b0234..ad3616ad 100644
--- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
+++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
@@ -18,9 +18,8 @@ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform
val cN = CircuitName(state.circuit.main)
val oldAnnos = state.annotations.getOrElse(AnnotationMap(Seq.empty)).annotations
val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { case ((annos, pins), config) =>
- val top = TopAnnotation(ModuleName(config.top.name, cN), config.pin.name)
val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN)), config.pin.name)
- (annos ++ Seq(top, source), pins :+ config.pin.name)
+ (annos, pins :+ config.pin.name)
}
state.copy(annotations = Some(AnnotationMap(as :+ PinAnnotation(cN, pins.toSeq))))
}
diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala
index 9656abb2..a268dba7 100644
--- a/src/main/scala/firrtl/passes/wiring/Wiring.scala
+++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala
@@ -10,157 +10,194 @@ import firrtl.Mappers._
import scala.collection.mutable
import firrtl.annotations._
import firrtl.annotations.AnnotationUtils._
+import firrtl.analyses.InstanceGraph
import WiringUtils._
-case class WiringException(msg: String) extends PassException(msg)
+/** A data store of one sink--source wiring relationship */
+case class WiringInfo(source: ComponentName, sinks: Seq[Named], pin: String)
-case class WiringInfo(source: String, comp: String, sinks: Set[String], pin: String, top: String)
+/** A data store of wiring names */
+case class WiringNames(compName: String, source: String, sinks: Seq[Named],
+ pin: String)
+/** Pass that computes and applies a sequence of wiring modifications
+ *
+ * @constructor construct a new Wiring pass
+ * @param wiSeq the [[WiringInfo]] to apply
+ */
class Wiring(wiSeq: Seq[WiringInfo]) extends Pass {
- def run(c: Circuit): Circuit = {
- wiSeq.foldLeft(c) { (circuit, wi) => wire(circuit, wi) }
+ def run(c: Circuit): Circuit = analyze(c)
+ .foldLeft(c){
+ case (cx, (tpe, modsMap)) => cx.copy(
+ modules = cx.modules map onModule(tpe, modsMap)) }
+
+ /** Converts multiple units of wiring information to module modifications */
+ private def analyze(c: Circuit): Seq[(Type, Map[String, Modifications])] = {
+
+ val names = wiSeq
+ .map ( wi => (wi.source, wi.sinks, wi.pin) match {
+ case (ComponentName(comp, ModuleName(source,_)), sinks, pin) =>
+ WiringNames(comp, source, sinks, pin) })
+
+ val portNames = mutable.Seq.fill(names.size)(Map[String, String]())
+ c.modules.foreach{ m =>
+ val ns = Namespace(m)
+ names.zipWithIndex.foreach{ case (WiringNames(c, so, si, p), i) =>
+ portNames(i) = portNames(i) +
+ ( m.name -> {
+ if (si.exists(getModuleName(_) == m.name)) ns.newName(p)
+ else ns.newName(tokenize(c) filterNot ("[]." contains _) mkString "_")
+ })}}
+
+ val iGraph = new InstanceGraph(c)
+ names.zip(portNames).map{ case(WiringNames(comp, so, si, _), pn) =>
+ computeModifications(c, iGraph, comp, so, si, pn) }
}
- /** Add pins to modules and wires a signal to them, under the scope of a specified top module
- * Description:
- * Adds a pin to each sink module
- * Punches ports up from the source signal to the specified top module
- * Punches ports down to each sink module
- * Wires the source up and down, connecting to all sink modules
- * Restrictions:
- * - Can only have one source module instance under scope of the specified top
- * - All instances of each sink module must be under the scope of the specified top
- * Notes:
- * - No module uniquification occurs (due to imposed restrictions)
+ /** Converts a single unit of wiring information to module modifications
+ *
+ * @param c the circuit that will be modified
+ * @param iGraph an InstanceGraph representation of the circuit
+ * @param compName the name of a component
+ * @param source the name of the source component
+ * @param sinks a list of sink components/modules that the source
+ * should be connected to
+ * @param portNames a mapping of module names to new ports/wires
+ * that should be generated if needed
+ *
+ * @return a tuple of the component type and a map of module names
+ * to pending modifications
*/
- def wire(c: Circuit, wi: WiringInfo): Circuit = {
- // Split out WiringInfo
- val source = wi.source
- val sinks = wi.sinks
- val compName = wi.comp
- val pin = wi.pin
-
- // Maps modules to children instances, i.e. (instance, module)
- val childrenMap = getChildrenMap(c)
-
- // Check restrictions
- val nSources = countInstances(childrenMap, wi.top, source)
- if(nSources != 1)
- throw new WiringException(s"Cannot have $nSources instance of $source under ${wi.top}")
- sinks.foreach { m =>
- val total = countInstances(childrenMap, c.main, m)
- val nTop = countInstances(childrenMap, c.main, wi.top)
- val perTop = countInstances(childrenMap, wi.top, m)
- if(total != nTop * perTop)
- throw new WiringException(s"Module ${wi.top} does not contain all instances of $m.")
- }
-
- // Create valid port names for wiring that have no name conflicts
- val portNames = c.modules.foldLeft(Map.empty[String, String]) { (map, m) =>
- map + (m.name -> {
- val ns = Namespace(m)
- if(sinks.contains(m.name)) ns.newName(pin)
- else ns.newName(tokenize(compName) filterNot ("[]." contains _) mkString "_")
- })
- }
-
- // Create a lineage tree from children map
- val lineages = getLineage(childrenMap, wi.top)
-
- // Populate lineage tree with relationship information, i.e. who is source,
- // sink, parent of source, etc.
- val withFields = setSharedParent(wi.top)(setFields(sinks, source)(lineages))
-
- // Populate lineage tree with what to instantiate, connect to/from, etc.
- val withThings = setThings(portNames, compName)(withFields)
+ private def computeModifications(c: Circuit,
+ iGraph: InstanceGraph,
+ compName: String,
+ source: String,
+ sinks: Seq[Named],
+ portNames: Map[String, String]):
+ (Type, Map[String, Modifications]) = {
- // Create a map from module name to lineage information
- val map = pointToLineage(withThings)
-
- // Obtain the source component type
val sourceComponentType = getType(c, source, compName)
-
- // Return new circuit with correct wiring
- val cx = c.copy(modules = c.modules map onModule(map, sourceComponentType))
-
- // Replace inserted IR nodes with WIR nodes
- ToWorkingIR.run(cx)
+ val sinkComponents: Map[String, Seq[String]] = sinks
+ .collect{ case ComponentName(c, ModuleName(m, _)) => (c, m) }
+ .foldLeft(new scala.collection.immutable.HashMap[String, Seq[String]]){
+ case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil)) ) }
+
+ // Determine "ownership" of sources to sinks via minimum distance
+ val owners = sinksToSources(sinks, source, iGraph)
+
+ // Determine port and pending modifications for all sink--source
+ // ownership pairs
+ val meta = new mutable.HashMap[String, Modifications]
+ .withDefaultValue(Modifications())
+ owners.foreach { case (sink, source) =>
+ val lca = iGraph.lowestCommonAncestor(sink, source)
+
+ // Compute metadata along Sink to LCA paths.
+ sink.drop(lca.size - 1).sliding(2).toList.reverse.map {
+ case Seq(WDefInstance(_,_,pm,_), WDefInstance(_,ci,cm,_)) =>
+ val to = s"$ci.${portNames(cm)}"
+ val from = s"${portNames(pm)}"
+ meta(pm) = meta(pm).copy(
+ addPortOrWire = Some((portNames(pm), DecWire)),
+ cons = (meta(pm).cons :+ (to, from)).distinct
+ )
+ meta(cm) = meta(cm).copy(
+ addPortOrWire = Some((portNames(cm), DecInput))
+ )
+ // Case where the sink is the LCA
+ case Seq(WDefInstance(_,_,pm,_)) =>
+ // Case where the source is also the LCA
+ if (source.drop(lca.size).isEmpty) {
+ meta(pm) = meta(pm).copy (
+ addPortOrWire = Some((portNames(pm), DecWire))
+ )
+ } else {
+ val WDefInstance(_,ci,cm,_) = source.drop(lca.size).head
+ val to = s"${portNames(pm)}"
+ val from = s"$ci.${portNames(cm)}"
+ meta(pm) = meta(pm).copy(
+ addPortOrWire = Some((portNames(pm), DecWire)),
+ cons = (meta(pm).cons :+ (to, from)).distinct
+ )
+ }
+ }
+
+ // Compute metadata for the Sink
+ sink.last match { case WDefInstance(_, _, m, _) =>
+ if (sinkComponents.contains(m)) {
+ val from = s"${portNames(m)}"
+ sinkComponents(m).foreach( to =>
+ meta(m) = meta(m).copy(
+ cons = (meta(m).cons :+ (to, from)).distinct
+ )
+ )
+ }
+ }
+
+ // Compute metadata for the Source
+ source.last match { case WDefInstance(_, _, m, _) =>
+ val to = s"${portNames(m)}"
+ val from = compName
+ meta(m) = meta(m).copy(
+ cons = (meta(m).cons :+ (to, from)).distinct
+ )
+ }
+
+ // Compute metadata along Source to LCA path
+ source.drop(lca.size - 1).sliding(2).toList.reverse.map {
+ case Seq(WDefInstance(_,_,pm,_), WDefInstance(_,ci,cm,_)) => {
+ val to = s"${portNames(pm)}"
+ val from = s"$ci.${portNames(cm)}"
+ meta(pm) = meta(pm).copy(
+ cons = (meta(pm).cons :+ (to, from)).distinct
+ )
+ meta(cm) = meta(cm).copy(
+ addPortOrWire = Some((portNames(cm), DecOutput))
+ )
+ }
+ // Case where the source is the LCA
+ case Seq(WDefInstance(_,_,pm,_)) => {
+ // Case where the sink is also the LCA. We do nothing here,
+ // as we've created the connecting wire above
+ if (sink.drop(lca.size).isEmpty) {
+ } else {
+ val WDefInstance(_,ci,cm,_) = sink.drop(lca.size).head
+ val to = s"$ci.${portNames(cm)}"
+ val from = s"${portNames(pm)}"
+ meta(pm) = meta(pm).copy(
+ cons = (meta(pm).cons :+ (to, from)).distinct
+ )
+ }
+ }
+ }
+ }
+ (sourceComponentType, meta.toMap)
}
- /** Return new module with correct wiring
- */
- private def onModule(map: Map[String, Lineage], t: Type)(m: DefModule) = {
+ /** Apply modifications to a module */
+ private def onModule(t: Type, map: Map[String, Modifications])(m: DefModule) = {
map.get(m.name) match {
case None => m
case Some(l) =>
- val stmts = mutable.ArrayBuffer[Statement]()
+ val defines = mutable.ArrayBuffer[Statement]()
+ val connects = mutable.ArrayBuffer[Statement]()
val ports = mutable.ArrayBuffer[Port]()
- l.addPort match {
+ l.addPortOrWire match {
case None =>
case Some((s, dt)) => dt match {
- case DecInput => ports += Port(NoInfo, s, Input, t)
+ case DecInput => ports += Port(NoInfo, s, Input, t)
case DecOutput => ports += Port(NoInfo, s, Output, t)
- case DecWire =>
- stmts += DefWire(NoInfo, s, t)
+ case DecWire => defines += DefWire(NoInfo, s, t)
}
}
- stmts ++= (l.cons map { case ((l, r)) =>
+ connects ++= (l.cons map { case ((l, r)) =>
Connect(NoInfo, toExp(l), toExp(r))
})
- def onStmt(s: Statement): Statement = Block(Seq(s) ++ stmts)
m match {
- case Module(i, n, ps, s) => Module(i, n, ps ++ ports, Block(Seq(s) ++ stmts))
+ case Module(i, n, ps, s) => Module(i, n, ps ++ ports,
+ Block(defines ++ Seq(s) ++ connects))
case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps ++ ports, dn, p)
}
}
}
-
- /** Returns the type of the component specified
- */
- private def getType(c: Circuit, module: String, comp: String) = {
- def getRoot(e: Expression): String = e match {
- case r: Reference => r.name
- case i: SubIndex => getRoot(i.expr)
- case a: SubAccess => getRoot(a.expr)
- case f: SubField => getRoot(f.expr)
- }
- val eComp = toExp(comp)
- val root = getRoot(eComp)
- var tpe: Option[Type] = None
- def getType(s: Statement): Statement = s match {
- case DefRegister(_, n, t, _, _, _) if n == root =>
- tpe = Some(t)
- s
- case DefWire(_, n, t) if n == root =>
- tpe = Some(t)
- s
- case WDefInstance(_, n, m, t) if n == root =>
- tpe = Some(t)
- s
- case DefNode(_, n, e) if n == root =>
- tpe = Some(e.tpe)
- s
- case sx: DefMemory if sx.name == root =>
- tpe = Some(MemPortUtils.memType(sx))
- sx
- case sx => sx map getType
- }
- val m = c.modules find (_.name == module) getOrElse error(s"Must have a module named $module")
- tpe = m.ports find (_.name == root) map (_.tpe)
- m match {
- case Module(i, n, ps, b) => getType(b)
- case e: ExtModule =>
- }
- tpe match {
- case None => error(s"Didn't find $comp in $module!")
- case Some(t) =>
- def setType(e: Expression): Expression = e map setType match {
- case ex: Reference => ex.copy(tpe = t)
- case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name))
- case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe))
- case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe))
- }
- setType(eComp).tpe
- }
- }
}
diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
index a8ef5f58..01e6f83a 100644
--- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
+++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
@@ -11,86 +11,79 @@ import scala.collection.mutable
import firrtl.annotations._
import WiringUtils._
-/** A component, e.g. register etc. Must be declared only once under the TopAnnotation
- */
+/** A class for all exceptions originating from firrtl.passes.wiring */
+case class WiringException(msg: String) extends PassException(msg)
+
+/** An extractor of annotated source components */
object SourceAnnotation {
- def apply(target: ComponentName, pin: String): Annotation = Annotation(target, classOf[WiringTransform], s"source $pin")
+ def apply(target: ComponentName, pin: String): Annotation =
+ Annotation(target, classOf[WiringTransform], s"source $pin")
private val matcher = "source (.+)".r
def unapply(a: Annotation): Option[(ComponentName, String)] = a match {
- case Annotation(ComponentName(n, m), _, matcher(pin)) => Some((ComponentName(n, m), pin))
+ case Annotation(ComponentName(n, m), _, matcher(pin)) =>
+ Some((ComponentName(n, m), pin))
case _ => None
}
}
-/** A module, e.g. ExtModule etc., that should add the input pin
- */
+/** An extractor of annotation sink components or modules */
object SinkAnnotation {
- def apply(target: ModuleName, pin: String): Annotation = Annotation(target, classOf[WiringTransform], s"sink $pin")
+ def apply(target: Named, pin: String): Annotation =
+ Annotation(target, classOf[WiringTransform], s"sink $pin")
private val matcher = "sink (.+)".r
- def unapply(a: Annotation): Option[(ModuleName, String)] = a match {
- case Annotation(ModuleName(n, c), _, matcher(pin)) => Some((ModuleName(n, c), pin))
+ def unapply(a: Annotation): Option[(Named, String)] = a match {
+ case Annotation(ModuleName(n, c), _, matcher(pin)) =>
+ Some((ModuleName(n, c), pin))
+ case Annotation(ComponentName(n, m), _, matcher(pin)) =>
+ Some((ComponentName(n, m), pin))
case _ => None
}
}
-/** A module under which all sink module must be declared, and there is only
- * one source component
+/** Wires a Module's Source Component to one or more Sink
+ * Modules/Components
+ *
+ * Sinks are wired to their closest source through their lowest
+ * common ancestor (LCA). Verbosely, this modifies the circuit in
+ * the following ways:
+ * - Adds a pin to each sink module
+ * - Punches ports up from source signals to the LCA
+ * - Punches ports down from LCAs to each sink module
+ * - Wires sources up to LCA, sinks down from LCA, and across each LCA
+ *
+ * @throws WiringException if a sink is equidistant to two sources
*/
-object TopAnnotation {
- def apply(target: ModuleName, pin: String): Annotation = Annotation(target, classOf[WiringTransform], s"top $pin")
+class WiringTransform extends Transform {
+ def inputForm: CircuitForm = MidForm
+ def outputForm: CircuitForm = HighForm
- private val matcher = "top (.+)".r
- def unapply(a: Annotation): Option[(ModuleName, String)] = a match {
- case Annotation(ModuleName(n, c), _, matcher(pin)) => Some((ModuleName(n, c), pin))
- case _ => None
- }
-}
+ /** Defines the sequence of Transform that should be applied */
+ private def transforms(w: Seq[WiringInfo]): Seq[Transform] = Seq(
+ new Wiring(w),
+ ToWorkingIR
+ )
-/** Add pins to modules and wires a signal to them, under the scope of a specified top module
- * Description:
- * Adds a pin to each sink module
- * Punches ports up from the source signal to the specified top module
- * Punches ports down to each sink module
- * Wires the source up and down, connecting to all sink modules
- * Restrictions:
- * - Can only have one source module instance under scope of the specified top
- * - All instances of each sink module must be under the scope of the specified top
- * Notes:
- * - No module uniquification occurs (due to imposed restrictions)
- */
-class WiringTransform extends Transform {
- def inputForm = MidForm
- def outputForm = MidForm
- def transforms(wis: Seq[WiringInfo]) =
- Seq(new Wiring(wis),
- InferTypes,
- ResolveKinds,
- ResolveGenders)
def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match {
case Nil => state
- case p =>
- val sinks = mutable.HashMap[String, Set[String]]()
- val sources = mutable.HashMap[String, String]()
- val tops = mutable.HashMap[String, String]()
- val comp = mutable.HashMap[String, String]()
- p.foreach {
+ case p =>
+ val sinks = mutable.HashMap[String, Seq[Named]]()
+ val sources = mutable.HashMap[String, ComponentName]()
+ p.foreach {
case SinkAnnotation(m, pin) =>
- sinks(pin) = sinks.getOrElse(pin, Set.empty) + m.name
+ sinks(pin) = sinks.getOrElse(pin, Seq.empty) :+ m
case SourceAnnotation(c, pin) =>
- sources(pin) = c.module.name
- comp(pin) = c.name
- case TopAnnotation(m, pin) => tops(pin) = m.name
+ sources(pin) = c
}
- (sources.size, tops.size, sinks.size, comp.size) match {
- case (0, 0, p, 0) => state
- case (s, t, p, c) if (p > 0) & (s == t) & (t == c) =>
- val wis = tops.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, top)) =>
- seq :+ WiringInfo(sources(pin), comp(pin), sinks(pin), pin, top)
+ (sources.size, sinks.size) match {
+ case (0, p) => state
+ case (s, p) if (p > 0) =>
+ val wis = sources.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, source)) =>
+ seq :+ WiringInfo(source, sinks(pin), pin)
}
transforms(wis).foldLeft(state) { (in, xform) => xform.runTransform(in) }
- case _ => error("Wrong number of sources, tops, or sinks!")
+ case _ => error("Wrong number of sources or sinks!")
}
}
}
diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
index 29c93ca7..117a3824 100644
--- a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
+++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
@@ -9,6 +9,9 @@ import firrtl.Utils._
import firrtl.Mappers._
import scala.collection.mutable
import firrtl.annotations._
+import firrtl.annotations.AnnotationUtils._
+import firrtl.analyses.InstanceGraph
+import firrtl.graph.DiGraph
import WiringUtils._
/** Declaration kind in lineage (e.g. input port, output port, wire)
@@ -18,6 +21,19 @@ case object DecInput extends DecKind
case object DecOutput extends DecKind
case object DecWire extends DecKind
+/** Store of pending wiring information for a Module */
+case class Modifications(
+ addPortOrWire: Option[(String, DecKind)] = None,
+ cons: Seq[(String, String)] = Seq.empty) {
+
+ override def toString: String = serialize("")
+
+ def serialize(tab: String): String = s"""
+ |$tab addPortOrWire: $addPortOrWire
+ |$tab cons: $cons
+ |""".stripMargin
+}
+
/** A lineage tree representing the instance hierarchy in a design
*/
case class Lineage(
@@ -41,7 +57,7 @@ case class Lineage(
|$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))}
|""".stripMargin
- def foldLeft[B](z: B)(op: (B, (String, Lineage)) => B): B =
+ def foldLeft[B](z: B)(op: (B, (String, Lineage)) => B): B =
this.children.foldLeft(z)(op)
def serialize(tab: String): String = s"""
@@ -57,9 +73,6 @@ case class Lineage(
|""".stripMargin
}
-
-
-
object WiringUtils {
type ChildrenMap = mutable.HashMap[String, Seq[(String, String)]]
@@ -69,10 +82,10 @@ object WiringUtils {
def getChildrenMap(c: Circuit): ChildrenMap = {
val childrenMap = new ChildrenMap()
def getChildren(mname: String)(s: Statement): Statement = s match {
- case s: WDefInstance =>
+ case s: WDefInstance =>
childrenMap(mname) = childrenMap(mname) :+ (s.name, s.module)
s
- case s: DefInstance =>
+ case s: DefInstance =>
childrenMap(mname) = childrenMap(mname) :+ (s.name, s.module)
s
case s => s map getChildren(mname)
@@ -84,86 +97,151 @@ object WiringUtils {
childrenMap
}
- /** Counts the number of instances of a module declared under a top module
- */
- def countInstances(childrenMap: ChildrenMap, top: String, module: String): Int = {
- if(top == module) 1
- else childrenMap(top).foldLeft(0) { case (count, (i, child)) =>
- count + countInstances(childrenMap, child, module)
- }
- }
-
/** Returns a module's lineage, containing all children lineages as well
*/
def getLineage(childrenMap: ChildrenMap, module: String): Lineage =
Lineage(module, childrenMap(module) map { case (i, m) => (i, getLineage(childrenMap, m)) } )
- /** Sets the sink, sinkParent, source, and sourceParent fields of every
- * Lineage in tree
+ /** Return a map of sink instances to source instances that minimizes
+ * distance
+ *
+ * @param sinks a sequence of sink modules
+ * @param source the source module
+ * @param i a graph representing a circuit
+ * @return a map of sink instance names to source instance names
+ * @throws WiringException if a sink is equidistant to two sources
*/
- def setFields(sinks: Set[String], source: String)(lin: Lineage): Lineage = lin map setFields(sinks, source) match {
- case l if sinks.contains(l.name) => l.copy(sink = true)
- case l =>
- val src = l.name == source
- val sinkParent = l.children.foldLeft(false) { case (b, (i, m)) => b || m.sink || m.sinkParent }
- val sourceParent = if(src) true else l.children.foldLeft(false) { case (b, (i, m)) => b || m.source || m.sourceParent }
- l.copy(sinkParent=sinkParent, sourceParent=sourceParent, source=src)
- }
+ def sinksToSources(sinks: Seq[Named],
+ source: String,
+ i: InstanceGraph):
+ Map[Seq[WDefInstance], Seq[WDefInstance]] = {
+ val owners = new mutable.HashMap[Seq[WDefInstance], Vector[Seq[WDefInstance]]]
+ .withDefaultValue(Vector())
+ val queue = new mutable.Queue[Seq[WDefInstance]]
+ val visited = new mutable.HashMap[Seq[WDefInstance], Boolean]
+ .withDefaultValue(false)
+
+ i.fullHierarchy.keys.filter { case WDefInstance(_,_,m,_) => m == source }
+ .foreach( i.fullHierarchy(_)
+ .foreach { l =>
+ queue.enqueue(l)
+ owners(l) = Vector(l)
+ }
+ )
+
+ val sinkInsts = i.fullHierarchy.keys
+ .filter { case WDefInstance(_, _, module, _) =>
+ sinks.map(getModuleName(_)).contains(module) }
+ .flatMap { k => i.fullHierarchy(k) }
+ .toSet
+
+ /** If we're lucky and there is only one source, then that source owns
+ * all sinks. If we're unlucky, we need to do a full (slow) BFS
+ * to figure out who owns what. Currently, the BFS is not
+ * performant.
+ *
+ * [todo] The performance of this will need to be improved.
+ * Possible directions are that if we're purely source-under-sink
+ * or sink-under-source, then ownership is trivially a mapping
+ * down/up. Ownership seems to require a BFS if we have
+ * sources/sinks not under sinks/sources.
+ */
+ if (queue.size == 1) {
+ val u = queue.dequeue
+ sinkInsts.foreach { v => owners(v) = Vector(u) }
+ } else {
+ while (queue.nonEmpty) {
+ val u = queue.dequeue
+ visited(u) = true
+
+ val edges = (i.graph.getEdges(u.last).map(u :+ _).toVector :+ u.dropRight(1))
+
+ // [todo] This is the critical section
+ edges
+ .filter( e => !visited(e) && e.nonEmpty )
+ .foreach{ v =>
+ owners(v) = owners(v) ++ owners(u)
+ queue.enqueue(v)
+ }
+ }
+
+ // Check that every sink has one unique owner. The only time that
+ // this should fail is if a sink is equidistant to two sources.
+ sinkInsts.foreach { s =>
+ if (!owners.contains(s) || owners(s).size > 1) {
+ throw new WiringException(
+ s"Unable to determine source mapping for sink '${s.map(_.name)}'") }
+ }
+ }
- /** Sets the sharedParent of lineage top
- */
- def setSharedParent(top: String)(lin: Lineage): Lineage = lin map setSharedParent(top) match {
- case l if l.name == top => l.copy(sharedParent = true)
- case l => l
+ owners
+ .collect { case (k, v) if sinkInsts.contains(k) => (k, v.flatten) }.toMap
}
- /** Sets the addPort and cons fields of the lineage tree
- */
- def setThings(portNames:Map[String, String], compName: String)(lin: Lineage): Lineage = {
- val funs = Seq(
- ((l: Lineage) => l map setThings(portNames, compName)),
- ((l: Lineage) => l match {
- case Lineage(name, _, _, _, _, _, true, _, _) => //SharedParent
- l.copy(addPort=Some((portNames(name), DecWire)))
- case Lineage(name, _, _, _, true, _, _, _, _) => //SourceParent
- l.copy(addPort=Some((portNames(name), DecOutput)))
- case Lineage(name, _, _, _, _, true, _, _, _) => //SinkParent
- l.copy(addPort=Some((portNames(name), DecInput)))
- case Lineage(name, _, _, true, _, _, _, _, _) => //Sink
- l.copy(addPort=Some((portNames(name), DecInput)))
- case l => l
- }),
- ((l: Lineage) => l match {
- case Lineage(name, _, true, _, _, _, _, _, _) => //Source
- val tos = Seq(s"${portNames(name)}")
- val from = compName
- l.copy(cons = l.cons ++ tos.map(t => (t, from)))
- case Lineage(name, _, _, _, true, _, _, _, _) => //SourceParent
- val tos = Seq(s"${portNames(name)}")
- val from = l.children.filter { case (i, c) => c.sourceParent }.map { case (i, c) => s"$i.${portNames(c.name)}" }.head
- l.copy(cons = l.cons ++ tos.map(t => (t, from)))
- case l => l
- }),
- ((l: Lineage) => l match {
- case Lineage(name, _, _, _, _, true, _, _, _) => //SinkParent
- val tos = l.children.filter { case (i, c) => (c.sinkParent || c.sink) && !c.sourceParent } map { case (i, c) => s"$i.${portNames(c.name)}" }
- val from = s"${portNames(name)}"
- l.copy(cons = l.cons ++ tos.map(t => (t, from)))
- case l => l
- })
- )
- funs.foldLeft(lin)((l, fun) => fun(l))
+ /** Helper script to extract a module name from a named Module or Component */
+ def getModuleName(n: Named): String = {
+ n match {
+ case ModuleName(m, _) => m
+ case ComponentName(_, ModuleName(m, _)) => m
+ case _ => throw new WiringException(
+ "Only Components or Modules have an associated Module name")
+ }
}
- /** Return a map from module to its lineage in the tree
+ /** Determine the Type of a specific component
+ *
+ * @param c the circuit containing the target module
+ * @param module the module containing the target component
+ * @param comp the target component
+ * @return the component's type
+ * @throws WiringException if the module is not contained in the
+ * circuit or if the component is not contained in the module
*/
- def pointToLineage(lin: Lineage): Map[String, Lineage] = {
- val map = mutable.HashMap[String, Lineage]()
- def onLineage(l: Lineage): Lineage = {
- map(l.name) = l
- l map onLineage
+ def getType(c: Circuit, module: String, comp: String): Type = {
+ def getRoot(e: Expression): String = e match {
+ case r: Reference => r.name
+ case i: SubIndex => getRoot(i.expr)
+ case a: SubAccess => getRoot(a.expr)
+ case f: SubField => getRoot(f.expr)
+ }
+ val eComp = toExp(comp)
+ val root = getRoot(eComp)
+ var tpe: Option[Type] = None
+ def getType(s: Statement): Statement = s match {
+ case DefRegister(_, n, t, _, _, _) if n == root =>
+ tpe = Some(t)
+ s
+ case DefWire(_, n, t) if n == root =>
+ tpe = Some(t)
+ s
+ case WDefInstance(_, n, _, t) if n == root =>
+ tpe = Some(t)
+ s
+ case DefNode(_, n, e) if n == root =>
+ tpe = Some(e.tpe)
+ s
+ case sx: DefMemory if sx.name == root =>
+ tpe = Some(MemPortUtils.memType(sx))
+ sx
+ case sx => sx map getType
+ }
+ val m = c.modules find (_.name == module) getOrElse {
+ throw new WiringException(s"Must have a module named $module") }
+ tpe = m.ports find (_.name == root) map (_.tpe)
+ m match {
+ case Module(i, n, ps, b) => getType(b)
+ case e: ExtModule =>
+ }
+ tpe match {
+ case None => throw new WiringException(s"Didn't find $comp in $module!")
+ case Some(t) =>
+ def setType(e: Expression): Expression = e map setType match {
+ case ex: Reference => ex.copy(tpe = t)
+ case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name))
+ case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe))
+ case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe))
+ }
+ setType(eComp).tpe
}
- onLineage(lin)
- map.toMap
}
}