aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Izraelevitz2016-10-27 13:00:02 -0700
committerGitHub2016-10-27 13:00:02 -0700
commit5b35f2d2722f72c81d2d6c507cd379be2a1476d8 (patch)
tree78dc2db9e12c6db52fcbf222e339a37b6ebc0b72
parent1c61a0e7102983891d99d8e9c49e331c8a2178a6 (diff)
Wiring (#348)
Added wiring pass and simple test
-rw-r--r--in.yaml9
-rw-r--r--src/main/scala/firrtl/Annotations.scala71
-rw-r--r--src/main/scala/firrtl/Utils.scala2
-rw-r--r--src/main/scala/firrtl/passes/memlib/DecorateMems.scala25
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala27
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala46
-rw-r--r--src/main/scala/firrtl/passes/memlib/YamlUtils.scala10
-rw-r--r--src/main/scala/firrtl/passes/wiring/Wiring.scala162
-rw-r--r--src/main/scala/firrtl/passes/wiring/WiringTransform.scala80
-rw-r--r--src/main/scala/firrtl/passes/wiring/WiringUtils.scala164
-rw-r--r--src/test/scala/firrtlTests/WiringTests.scala347
11 files changed, 922 insertions, 21 deletions
diff --git a/in.yaml b/in.yaml
new file mode 100644
index 00000000..f77e0a72
--- /dev/null
+++ b/in.yaml
@@ -0,0 +1,9 @@
+---
+pin:
+ name: mypin
+source:
+ name: cmd
+ module: Htif
+top:
+ name: Top
+---
diff --git a/src/main/scala/firrtl/Annotations.scala b/src/main/scala/firrtl/Annotations.scala
index 2d76a832..d47ce67e 100644
--- a/src/main/scala/firrtl/Annotations.scala
+++ b/src/main/scala/firrtl/Annotations.scala
@@ -1,5 +1,7 @@
package firrtl
+import firrtl.ir._
+
import scala.collection.mutable
import java.io.Writer
@@ -38,6 +40,62 @@ import java.io.Writer
* ----------|----------|----------|------------|-----------|
*/
object Annotations {
+ /** Returns true if a valid Module name */
+ val SerializedModuleName = """([a-zA-Z_][a-zA-Z_0-9~!@#$%^*\-+=?/]*)""".r
+ def validModuleName(s: String): Boolean = s match {
+ case SerializedModuleName(name) => true
+ case _ => false
+ }
+
+ /** Returns true if a valid component/subcomponent name */
+ val SerializedComponentName = """([a-zA-Z_][a-zA-Z_0-9\[\]\.~!@#$%^*\-+=?/]*)""".r
+ def validComponentName(s: String): Boolean = s match {
+ case SerializedComponentName(name) => true
+ case _ => false
+ }
+
+ /** Tokenizes a string with '[', ']', '.' as tokens, e.g.:
+ * "foo.bar[boo.far]" becomes Seq("foo" "." "bar" "[" "boo" "." "far" "]")
+ */
+ def tokenize(s: String): Seq[String] = s.find(c => "[].".contains(c)) match {
+ case Some(_) =>
+ val i = s.indexWhere(c => "[].".contains(c))
+ Seq(s.slice(0, i), s(i).toString) ++ tokenize(s.drop(i + 1))
+ case None => Seq(s)
+ }
+
+ /** Given a serialized component/subcomponent reference, subindex, subaccess,
+ * or subfield, return the corresponding IR expression.
+ */
+ def toExp(s: String): Expression = {
+ def parse(tokens: Seq[String]): Expression = {
+ val DecPattern = """([1-9]\d*)""".r
+ def findClose(tokens: Seq[String], index: Int, nOpen: Int): Seq[String] =
+ if(index >= tokens.size) {
+ error("Cannot find closing bracket ]")
+ } else tokens(index) match {
+ case "[" => findClose(tokens, index + 1, nOpen + 1)
+ case "]" if nOpen == 1 => tokens.slice(1, index)
+ case _ => findClose(tokens, index + 1, nOpen)
+ }
+ def buildup(e: Expression, tokens: Seq[String]): Expression = tokens match {
+ case "[" :: tail =>
+ val indexOrAccess = findClose(tokens, 0, 0)
+ indexOrAccess.head match {
+ case DecPattern(d) => SubIndex(e, d.toInt, UnknownType)
+ case _ => buildup(SubAccess(e, parse(indexOrAccess), UnknownType), tokens.slice(1, indexOrAccess.size))
+ }
+ case "." :: tail =>
+ buildup(SubField(e, tokens(1), UnknownType), tokens.drop(2))
+ case Nil => e
+ }
+ val root = Reference(tokens.head, UnknownType)
+ buildup(root, tokens.tail)
+ }
+ if(validComponentName(s)) {
+ parse(tokenize(s))
+ } else error(s"Cannot convert $s into an expression.")
+ }
case class AnnotationException(message: String) extends Exception(message)
@@ -45,9 +103,16 @@ object Annotations {
* Named classes associate an annotation with a component in a Firrtl circuit
*/
trait Named { def name: String }
- case class CircuitName(name: String) extends Named
- case class ModuleName(name: String, circuit: CircuitName) extends Named
- case class ComponentName(name: String, module: ModuleName) extends Named
+ case class CircuitName(name: String) extends Named {
+ if(!validModuleName(name)) throw AnnotationException(s"Illegal circuit name: $name")
+ }
+ case class ModuleName(name: String, circuit: CircuitName) extends Named {
+ if(!validModuleName(name)) throw AnnotationException(s"Illegal module name: $name")
+ }
+ case class ComponentName(name: String, module: ModuleName) extends Named {
+ if(!validComponentName(name)) throw AnnotationException(s"Illegal component name: $name")
+ def expr: Expression = toExp(name)
+ }
/**
* Transform ID (TransID) associates an annotation with an instantiated
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 294afe57..22a3eac6 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -50,6 +50,8 @@ import com.typesafe.scalalogging.LazyLogging
class FIRRTLException(str: String) extends Exception(str)
object Utils extends LazyLogging {
+ def throwInternalError =
+ error("Internal Error! Please file an issue at https://github.com/ucb-bar/firrtl/issues")
private[firrtl] def time[R](name: String)(block: => R): R = {
logger.info(s"Starting $name")
val t0 = System.nanoTime()
diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
new file mode 100644
index 00000000..10cc8f88
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
@@ -0,0 +1,25 @@
+package firrtl
+package passes
+package memlib
+import ir._
+import Annotations._
+import wiring._
+
+class CreateMemoryAnnotations(reader: Option[YamlFileReader], replaceID: TransID, wiringID: TransID) extends Transform {
+ def name = "Create Memory Annotations"
+ def execute(c: Circuit, map: AnnotationMap): TransformResult = reader match {
+ case None => TransformResult(c)
+ case Some(r) =>
+ import CustomYAMLProtocol._
+ r.parse[Config] match {
+ case Seq(config) =>
+ val cN = CircuitName(c.main)
+ val top = TopAnnotation(ModuleName(config.top.name, cN), wiringID)
+ val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN)), wiringID)
+ val pin = PinAnnotation(cN, replaceID, config.pin.name)
+ TransformResult(c, None, Some(AnnotationMap(Seq(top, source, pin))))
+ case Nil => TransformResult(c, None, None)
+ case _ => error("Can only have one config in yaml file")
+ }
+ }
+}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
index a52f7d38..9ab496d2 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
@@ -10,12 +10,24 @@ import firrtl.Mappers._
import MemPortUtils.{MemPortMap, Modules}
import MemTransformUtils._
import AnalysisUtils._
+import Annotations._
+import wiring._
+
+
+/** Annotates the name of the pin to add for WiringTransform
+ */
+case class PinAnnotation(target: CircuitName, tID: TransID, pin: String) extends Annotation with Loose with Unstable {
+ def duplicate(n: Named) = n match {
+ case n: CircuitName => this.copy(target = n)
+ case _ => throwInternalError
+ }
+}
/** Replace DefAnnotatedMemory with memory blackbox + wrapper + conf file.
* This will not generate wmask ports if not needed.
* Creates the minimum # of black boxes needed by the design.
*/
-class ReplaceMemMacros(writer: ConfWriter) extends Pass {
+class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) extends Transform {
def name = "Replace Memory Macros"
/** Return true if mask granularity is per bit, false if per byte or unspecified
@@ -194,7 +206,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
map updateStmtRefs(memPortMap))
}
- def run(c: Circuit) = {
+ def execute(c: Circuit, map: AnnotationMap): TransformResult = {
val namespace = Namespace(c)
val memMods = new Modules
val nameMap = new NameMap
@@ -202,6 +214,15 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
val modules = c.modules map updateMemMods(namespace, nameMap, memMods)
// print conf
writer.serialize()
- c copy (modules = modules ++ memMods)
+ val pin = map get myID match {
+ case Some(p) =>
+ p.values.head match {
+ case PinAnnotation(c, _, pin) => pin
+ case _ => error(s"Bad Annotations: ${p.values}")
+ }
+ case None => "pin"
+ }
+ val annos = memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), wiringID, pin) }
+ TransformResult(c.copy(modules = modules ++ memMods), None, Some(AnnotationMap(annos)))
}
}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index dfa828c9..01f020f5 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -9,6 +9,7 @@ import Annotations._
import AnalysisUtils._
import Utils.error
import java.io.{File, CharArrayWriter, PrintWriter}
+import wiring._
sealed trait PassOption
case object InputConfigFileName extends PassOption
@@ -92,21 +93,36 @@ Optional Arguments:
def duplicate(n: Named) = this copy (t = t.replace(s"-c:$passCircuit", s"-c:${n.name}"))
}
+case class SimpleTransform(p: Pass) extends Transform {
+ def execute(c: Circuit, map: AnnotationMap): TransformResult =
+ TransformResult(p.run(c))
+}
class ReplSeqMem(transID: TransID) extends Transform with SimpleRun {
- def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter) =
- Seq(Legalize,
- ToMemIR,
- ResolveMaskGranularity,
- RenameAnnotatedMemoryPorts,
- ResolveMemoryReference,
- //new AnnotateValidMemConfigs(inConfigFile),
- new ReplaceMemMacros(outConfigFile),
- RemoveEmpty,
- CheckInitialization,
- InferTypes,
- Uniquify,
- ResolveKinds, // Must be run for the transform to work!
- ResolveGenders)
+ def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] =
+ Seq(SimpleTransform(Legalize),
+ SimpleTransform(ToMemIR),
+ SimpleTransform(ResolveMaskGranularity),
+ SimpleTransform(RenameAnnotatedMemoryPorts),
+ SimpleTransform(ResolveMemoryReference),
+ new CreateMemoryAnnotations(inConfigFile, TransID(-7), TransID(-8)),
+ new ReplaceMemMacros(outConfigFile, TransID(-7), TransID(-8)),
+ new WiringTransform(TransID(-8)),
+ SimpleTransform(RemoveEmpty),
+ SimpleTransform(CheckInitialization),
+ SimpleTransform(InferTypes),
+ SimpleTransform(Uniquify),
+ SimpleTransform(ResolveKinds),
+ SimpleTransform(ResolveGenders))
+ def run(circuit: Circuit, map: AnnotationMap, xForms: Seq[Transform]): TransformResult = {
+ (xForms.foldLeft(TransformResult(circuit, None, Some(map)))) { case (tr: TransformResult, xForm: Transform) =>
+ val x = xForm.execute(tr.circuit, tr.annotation.get)
+ x.annotation match {
+ case None => TransformResult(x.circuit, None, Some(map))
+ case Some(ann) => TransformResult(x.circuit, None, Some(
+ AnnotationMap(ann.annotations ++ tr.annotation.get.annotations)))
+ }
+ }
+ }
def execute(c: Circuit, map: AnnotationMap) = map get transID match {
case Some(p) => p get CircuitName(c.main) match {
@@ -118,7 +134,7 @@ class ReplSeqMem(transID: TransID) extends Transform with SimpleRun {
else error("Input configuration file does not exist!")
}
val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName))
- run(c, passSeq(inConfigFile, outConfigFile))
+ run(c, map, passSeq(inConfigFile, outConfigFile))
case _ => error("Unexpected transform annotation")
}
case _ => TransformResult(c)
diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
index a1088300..fcef4229 100644
--- a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
@@ -5,8 +5,18 @@ import java.io.{File, CharArrayWriter, PrintWriter}
object CustomYAMLProtocol extends DefaultYamlProtocol {
// bottom depends on top
+ implicit val _pin = yamlFormat1(Pin)
+ implicit val _source = yamlFormat2(Source)
+ implicit val _top = yamlFormat1(Top)
+ implicit val _configs = yamlFormat3(Config)
}
+case class Pin(name: String)
+case class Source(name: String, module: String)
+case class Top(name: String)
+case class Config(pin: Pin, source: Source, top: Top)
+
+
class YamlFileReader(file: String) {
import CustomYAMLProtocol._
def parse[A](implicit reader: YamlReader[A]) : Seq[A] = {
diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala
new file mode 100644
index 00000000..d3f6f3dd
--- /dev/null
+++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala
@@ -0,0 +1,162 @@
+package firrtl.passes
+package wiring
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import scala.collection.mutable
+import firrtl.Annotations._
+import WiringUtils._
+
+case class WiringException(msg: String) extends PassException(msg)
+
+case class WiringInfo(source: String, comp: String, sinks: Map[String, String], top: String)
+
+class Wiring(wi: WiringInfo) extends Pass {
+ def name = this.getClass.getSimpleName
+
+ /** Add pins to modules and wires a signal to them, under the scope of a specified top module
+ * Description:
+ * Adds a pin to each sink module
+ * Punches ports up from the source signal to the specified top module
+ * Punches ports down to each sink module
+ * Wires the source up and down, connecting to all sink modules
+ * Restrictions:
+ * - Can only have one source module instance under scope of the specified top
+ * - All instances of each sink module must be under the scope of the specified top
+ * Notes:
+ * - No module uniquification occurs (due to imposed restrictions)
+ */
+ def run(c: Circuit): Circuit = {
+ // Split out WiringInfo
+ val source = wi.source
+ val sinks = wi.sinks.keys.toSet
+ val compName = wi.comp
+
+ // Maps modules to children instances, i.e. (instance, module)
+ val childrenMap = getChildrenMap(c)
+
+ // Check restrictions
+ val nSources = countInstances(childrenMap, wi.top, source)
+ if(nSources != 1)
+ throw new WiringException(s"Cannot have $nSources instance of $source under ${wi.top}")
+ sinks.foreach { m =>
+ val total = countInstances(childrenMap, c.main, m)
+ val nTop = countInstances(childrenMap, c.main, wi.top)
+ val perTop = countInstances(childrenMap, wi.top, m)
+ if(total != nTop * perTop)
+ throw new WiringException(s"Module ${wi.top} does not contain all instances of $m.")
+ }
+
+ // Create valid port names for wiring that have no name conflicts
+ val portNames = c.modules.foldLeft(Map.empty[String, String]) { (map, m) =>
+ map + (m.name -> {
+ val ns = Namespace(m)
+ wi.sinks.get(m.name) match {
+ case Some(pin) => ns.newName(pin)
+ case None => ns.newName(tokenize(compName) filterNot ("[]." contains _) mkString "_")
+ }
+ })
+ }
+
+ // Create a lineage tree from children map
+ val lineages = getLineage(childrenMap, wi.top)
+
+ // Populate lineage tree with relationship information, i.e. who is source,
+ // sink, parent of source, etc.
+ val withFields = setSharedParent(wi.top)(setFields(sinks, source)(lineages))
+
+ // Populate lineage tree with what to instantiate, connect to/from, etc.
+ val withThings = setThings(portNames, compName)(withFields)
+
+ // Create a map from module name to lineage information
+ val map = pointToLineage(withThings)
+
+ // Obtain the source component type
+ val sourceComponentType = getType(c, source, compName)
+
+ // Return new circuit with correct wiring
+ val cx = c.copy(modules = c.modules map onModule(map, sourceComponentType))
+
+ // Replace inserted IR nodes with WIR nodes
+ ToWorkingIR.run(cx)
+ }
+
+ /** Return new module with correct wiring
+ */
+ private def onModule(map: Map[String, Lineage], t: Type)(m: DefModule) = {
+ map.get(m.name) match {
+ case None => m
+ case Some(l) =>
+ val stmts = mutable.ArrayBuffer[Statement]()
+ val ports = mutable.ArrayBuffer[Port]()
+ l.addPort match {
+ case None =>
+ case Some((s, dt)) => dt match {
+ case DecInput => ports += Port(NoInfo, s, Input, t)
+ case DecOutput => ports += Port(NoInfo, s, Output, t)
+ case DecWire =>
+ stmts += DefWire(NoInfo, s, t)
+ }
+ }
+ stmts ++= (l.cons map { case ((l, r)) =>
+ Connect(NoInfo, toExp(l), toExp(r))
+ })
+ def onStmt(s: Statement): Statement = Block(Seq(s) ++ stmts)
+ m match {
+ case Module(i, n, ps, s) => Module(i, n, ps ++ ports, Block(Seq(s) ++ stmts))
+ case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps ++ ports, dn, p)
+ }
+ }
+ }
+
+ /** Returns the type of the component specified
+ */
+ private def getType(c: Circuit, module: String, comp: String) = {
+ def getRoot(e: Expression): String = e match {
+ case r: Reference => r.name
+ case i: SubIndex => getRoot(i.expr)
+ case a: SubAccess => getRoot(a.expr)
+ case f: SubField => getRoot(f.expr)
+ }
+ val eComp = toExp(comp)
+ val root = getRoot(eComp)
+ var tpe: Option[Type] = None
+ def getType(s: Statement): Statement = s match {
+ case DefRegister(_, n, t, _, _, _) if n == root =>
+ tpe = Some(t)
+ s
+ case DefWire(_, n, t) if n == root =>
+ tpe = Some(t)
+ s
+ case WDefInstance(_, n, m, t) if n == root =>
+ tpe = Some(t)
+ s
+ case DefNode(_, n, e) if n == root =>
+ tpe = Some(e.tpe)
+ s
+ case sx: DefMemory if sx.name == root =>
+ tpe = Some(MemPortUtils.memType(sx))
+ sx
+ case sx => sx map getType
+ }
+ val m = c.modules find (_.name == module) getOrElse error(s"Must have a module named $module")
+ tpe = m.ports find (_.name == root) map (_.tpe)
+ m match {
+ case Module(i, n, ps, b) => getType(b)
+ case e: ExtModule =>
+ }
+ tpe match {
+ case None => error(s"Didn't find $comp in $module!")
+ case Some(t) =>
+ def setType(e: Expression): Expression = e map setType match {
+ case ex: Reference => ex.copy(tpe = t)
+ case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name))
+ case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe))
+ case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe))
+ }
+ setType(eComp).tpe
+ }
+ }
+}
diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
new file mode 100644
index 00000000..919948b6
--- /dev/null
+++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
@@ -0,0 +1,80 @@
+package firrtl.passes
+package wiring
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import scala.collection.mutable
+import firrtl.Annotations._
+import WiringUtils._
+
+/** A component, e.g. register etc. Must be declared only once under the TopAnnotation
+ */
+case class SourceAnnotation(target: ComponentName, tID: TransID) extends Annotation with Loose with Unstable {
+ def duplicate(n: Named) = n match {
+ case n: ComponentName => this.copy(target = n)
+ case _ => throwInternalError
+ }
+}
+
+/** A module, e.g. ExtModule etc., that should add the input pin
+ */
+case class SinkAnnotation(target: ModuleName, tID: TransID, pin: String) extends Annotation with Loose with Unstable {
+ def duplicate(n: Named) = n match {
+ case n: ModuleName => this.copy(target = n)
+ case _ => throwInternalError
+ }
+}
+
+/** A module under which all sink module must be declared, and there is only
+ * one source component
+ */
+case class TopAnnotation(target: ModuleName, tID: TransID) extends Annotation with Loose with Unstable {
+ def duplicate(n: Named) = n match {
+ case n: ModuleName => this.copy(target = n)
+ case _ => throwInternalError
+ }
+}
+
+/** Add pins to modules and wires a signal to them, under the scope of a specified top module
+ * Description:
+ * Adds a pin to each sink module
+ * Punches ports up from the source signal to the specified top module
+ * Punches ports down to each sink module
+ * Wires the source up and down, connecting to all sink modules
+ * Restrictions:
+ * - Can only have one source module instance under scope of the specified top
+ * - All instances of each sink module must be under the scope of the specified top
+ * Notes:
+ * - No module uniquification occurs (due to imposed restrictions)
+ */
+class WiringTransform(transID: TransID) extends Transform with SimpleRun {
+ def passSeq(wi: WiringInfo) =
+ Seq(new Wiring(wi),
+ InferTypes,
+ ResolveKinds,
+ ResolveGenders)
+ def execute(c: Circuit, map: AnnotationMap) = map get transID match {
+ case Some(p) =>
+ val sinks = mutable.HashMap[String, String]()
+ val sources = mutable.Set[String]()
+ val tops = mutable.Set[String]()
+ val comp = mutable.Set[String]()
+ p.values.foreach { a =>
+ a match {
+ case SinkAnnotation(m, _, pin) => sinks(m.name) = pin
+ case SourceAnnotation(c, _) =>
+ sources += c.module.name
+ comp += c.name
+ case TopAnnotation(m, _) => tops += m.name
+ }
+ }
+ (sources.size, tops.size, sinks.size, comp.size) match {
+ case (0, 0, p, 0) => TransformResult(c)
+ case (1, 1, p, 1) if p > 0 => run(c, passSeq(WiringInfo(sources.head, comp.head, sinks.toMap, tops.head)))
+ case _ => error("Wrong number of sources, tops, or sinks!")
+ }
+ case None => TransformResult(c)
+ }
+}
diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
new file mode 100644
index 00000000..bfa94a81
--- /dev/null
+++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
@@ -0,0 +1,164 @@
+package firrtl.passes
+package wiring
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import scala.collection.mutable
+import firrtl.Annotations._
+import WiringUtils._
+
+/** Declaration kind in lineage (e.g. input port, output port, wire)
+ */
+sealed trait DecKind
+case object DecInput extends DecKind
+case object DecOutput extends DecKind
+case object DecWire extends DecKind
+
+/** A lineage tree representing the instance hierarchy in a design
+ */
+case class Lineage(
+ name: String,
+ children: Seq[(String, Lineage)] = Seq.empty,
+ source: Boolean = false,
+ sink: Boolean = false,
+ sourceParent: Boolean = false,
+ sinkParent: Boolean = false,
+ sharedParent: Boolean = false,
+ addPort: Option[(String, DecKind)] = None,
+ cons: Seq[(String, String)] = Seq.empty) {
+
+ def map(f: Lineage => Lineage): Lineage =
+ this.copy(children = children.map{ case (i, m) => (i, f(m)) })
+
+ override def toString: String = shortSerialize("")
+
+ def shortSerialize(tab: String): String = s"""
+ |$tab name: $name,
+ |$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))}
+ |""".stripMargin
+
+ def serialize(tab: String): String = s"""
+ |$tab name: $name,
+ |$tab source: $source,
+ |$tab sink: $sink,
+ |$tab sourceParent: $sourceParent,
+ |$tab sinkParent: $sinkParent,
+ |$tab sharedParent: $sharedParent,
+ |$tab addPort: $addPort
+ |$tab cons: $cons
+ |$tab children: ${children.map(c => tab + " " + c._2.serialize(tab + " "))}
+ |""".stripMargin
+}
+
+
+
+
+object WiringUtils {
+ type ChildrenMap = mutable.HashMap[String, Seq[(String, String)]]
+
+ /** Given a circuit, returns a map from module name to children
+ * instance/module names
+ */
+ def getChildrenMap(c: Circuit): ChildrenMap = {
+ val childrenMap = new ChildrenMap()
+ def getChildren(mname: String)(s: Statement): Statement = s match {
+ case s: WDefInstance =>
+ childrenMap(mname) = childrenMap(mname) :+ (s.name, s.module)
+ s
+ case s: DefInstance =>
+ childrenMap(mname) = childrenMap(mname) :+ (s.name, s.module)
+ s
+ case s => s map getChildren(mname)
+ }
+ c.modules.foreach{ m =>
+ childrenMap(m.name) = Nil
+ m map getChildren(m.name)
+ }
+ childrenMap
+ }
+
+ /** Counts the number of instances of a module declared under a top module
+ */
+ def countInstances(childrenMap: ChildrenMap, top: String, module: String): Int = {
+ if(top == module) 1
+ else childrenMap(top).foldLeft(0) { case (count, (i, child)) =>
+ count + countInstances(childrenMap, child, module)
+ }
+ }
+
+ /** Returns a module's lineage, containing all children lineages as well
+ */
+ def getLineage(childrenMap: ChildrenMap, module: String): Lineage =
+ Lineage(module, childrenMap(module) map { case (i, m) => (i, getLineage(childrenMap, m)) } )
+
+ /** Sets the sink, sinkParent, source, and sourceParent fields of every
+ * Lineage in tree
+ */
+ def setFields(sinks: Set[String], source: String)(lin: Lineage): Lineage = lin map setFields(sinks, source) match {
+ case l if sinks.contains(l.name) => l.copy(sink = true)
+ case l =>
+ val src = l.name == source
+ val sinkParent = l.children.foldLeft(false) { case (b, (i, m)) => b || m.sink || m.sinkParent }
+ val sourceParent = if(src) true else l.children.foldLeft(false) { case (b, (i, m)) => b || m.source || m.sourceParent }
+ l.copy(sinkParent=sinkParent, sourceParent=sourceParent, source=src)
+ }
+
+ /** Sets the sharedParent of lineage top
+ */
+ def setSharedParent(top: String)(lin: Lineage): Lineage = lin map setSharedParent(top) match {
+ case l if l.name == top => l.copy(sharedParent = true)
+ case l => l
+ }
+
+ /** Sets the addPort and cons fields of the lineage tree
+ */
+ def setThings(portNames:Map[String, String], compName: String)(lin: Lineage): Lineage = {
+ val funs = Seq(
+ ((l: Lineage) => l map setThings(portNames, compName)),
+ ((l: Lineage) => l match {
+ case Lineage(name, _, _, _, _, _, true, _, _) => //SharedParent
+ l.copy(addPort=Some((portNames(name), DecWire)))
+ case Lineage(name, _, _, _, true, _, _, _, _) => //SourceParent
+ l.copy(addPort=Some((portNames(name), DecOutput)))
+ case Lineage(name, _, _, _, _, true, _, _, _) => //SinkParent
+ l.copy(addPort=Some((portNames(name), DecInput)))
+ case Lineage(name, _, _, true, _, _, _, _, _) => //Sink
+ l.copy(addPort=Some((portNames(name), DecInput)))
+ case l => l
+ }),
+ ((l: Lineage) => l match {
+ case Lineage(name, _, true, _, _, _, _, _, _) => //Source
+ val tos = Seq(s"${portNames(name)}")
+ val from = compName
+ l.copy(cons = l.cons ++ tos.map(t => (t, from)))
+ case Lineage(name, _, _, _, true, _, _, _, _) => //SourceParent
+ val tos = Seq(s"${portNames(name)}")
+ val from = l.children.filter { case (i, c) => c.sourceParent }.map { case (i, c) => s"$i.${portNames(c.name)}" }.head
+ l.copy(cons = l.cons ++ tos.map(t => (t, from)))
+ case l => l
+ }),
+ ((l: Lineage) => l match {
+ case Lineage(name, _, _, _, _, true, _, _, _) => //SinkParent
+ val tos = l.children.filter { case (i, c) => (c.sinkParent || c.sink) && !c.sourceParent } map { case (i, c) => s"$i.${portNames(c.name)}" }
+ val from = s"${portNames(name)}"
+ l.copy(cons = l.cons ++ tos.map(t => (t, from)))
+ case l => l
+ })
+ )
+ funs.foldLeft(lin)((l, fun) => fun(l))
+ }
+
+ /** Return a map from module to its lineage in the tree
+ */
+ def pointToLineage(lin: Lineage): Map[String, Lineage] = {
+ val map = mutable.HashMap[String, Lineage]()
+ def onLineage(l: Lineage): Lineage = {
+ map(l.name) = l
+ l map onLineage
+ }
+ onLineage(lin)
+ map.toMap
+ }
+}
diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala
new file mode 100644
index 00000000..5f40d861
--- /dev/null
+++ b/src/test/scala/firrtlTests/WiringTests.scala
@@ -0,0 +1,347 @@
+package firrtlTests
+
+import java.io._
+import org.scalatest._
+import org.scalatest.prop._
+import firrtl._
+import firrtl.ir.Circuit
+import firrtl.passes._
+import firrtl.Parser.IgnoreInfo
+import Annotations._
+import wiring.WiringUtils._
+import wiring._
+
+class WiringTests extends FirrtlFlatSpec {
+ def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
+ private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
+ val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val lines = c.serialize.split("\n") map normalized
+
+ expected foreach { e =>
+ lines should contain(e)
+ }
+ }
+
+ def passes = Seq(
+ ToWorkingIR,
+ ResolveKinds,
+ InferTypes,
+ ResolveGenders,
+ InferWidths
+ )
+
+ "Wiring from r to X" should "work" in {
+ val sinks = Map(("X"-> "pin"))
+ val sas = WiringInfo("C", "r", sinks, "A")
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a of A
+ | a.clk <= clk
+ | module A :
+ | input clk: Clock
+ | inst b of B
+ | b.clk <= clk
+ | inst x of X
+ | x.clk <= clk
+ | inst d of D
+ | d.clk <= clk
+ | module B :
+ | input clk: Clock
+ | inst c of C
+ | c.clk <= clk
+ | inst d of D
+ | d.clk <= clk
+ | module C :
+ | input clk: Clock
+ | reg r: UInt<5>, clk
+ | module D :
+ | input clk: Clock
+ | inst x1 of X
+ | x1.clk <= clk
+ | inst x2 of X
+ | x2.clk <= clk
+ | extmodule X :
+ | input clk: Clock
+ |""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a of A
+ | a.clk <= clk
+ | module A :
+ | input clk: Clock
+ | inst b of B
+ | b.clk <= clk
+ | inst x of X
+ | x.clk <= clk
+ | inst d of D
+ | d.clk <= clk
+ | wire r: UInt<5>
+ | r <= b.r
+ | x.pin <= r
+ | d.r <= r
+ | module B :
+ | input clk: Clock
+ | output r: UInt<5>
+ | inst c of C
+ | c.clk <= clk
+ | inst d of D
+ | d.clk <= clk
+ | r <= c.r_0
+ | d.r <= r
+ | module C :
+ | input clk: Clock
+ | output r_0: UInt<5>
+ | reg r: UInt<5>, clk
+ | r_0 <= r
+ | module D :
+ | input clk: Clock
+ | input r: UInt<5>
+ | inst x1 of X
+ | x1.clk <= clk
+ | inst x2 of X
+ | x2.clk <= clk
+ | x1.pin <= r
+ | x2.pin <= r
+ | extmodule X :
+ | input clk: Clock
+ | input pin: UInt<5>
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(sas)
+ val retC = wiringPass.run(c)
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+
+ "Wiring from r.x to X" should "work" in {
+ val sinks = Map(("X"-> "pin"))
+ val sas = WiringInfo("A", "r.x", sinks, "A")
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a of A
+ | a.clk <= clk
+ | module A :
+ | input clk: Clock
+ | reg r : {x: UInt<5>}, clk
+ | inst x of X
+ | x.clk <= clk
+ | extmodule X :
+ | input clk: Clock
+ |""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a of A
+ | a.clk <= clk
+ | module A :
+ | input clk: Clock
+ | reg r: {x: UInt<5>}, clk
+ | inst x of X
+ | x.clk <= clk
+ | wire r_x: UInt<5>
+ | r_x <= r.x
+ | x.pin <= r_x
+ | extmodule X :
+ | input clk: Clock
+ | input pin: UInt<5>
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(sas)
+ val retC = wiringPass.run(c)
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+ "Wiring from clk to X" should "work" in {
+ val sinks = Map(("X"-> "pin"))
+ val sas = WiringInfo("A", "clk", sinks, "A")
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a of A
+ | a.clk <= clk
+ | module A :
+ | input clk: Clock
+ | inst x of X
+ | x.clk <= clk
+ | extmodule X :
+ | input clk: Clock
+ |""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a of A
+ | a.clk <= clk
+ | module A :
+ | input clk: Clock
+ | inst x of X
+ | x.clk <= clk
+ | wire clk_0: Clock
+ | clk_0 <= clk
+ | x.pin <= clk_0
+ | extmodule X :
+ | input clk: Clock
+ | input pin: Clock
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(sas)
+ val retC = wiringPass.run(c)
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+ "Two sources" should "fail" in {
+ val sinks = Map(("X"-> "pin"))
+ val sas = WiringInfo("A", "clk", sinks, "Top")
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a1 of A
+ | a1.clk <= clk
+ | inst a2 of A
+ | a2.clk <= clk
+ | module A :
+ | input clk: Clock
+ | inst x of X
+ | x.clk <= clk
+ | extmodule X :
+ | input clk: Clock
+ |""".stripMargin
+ intercept[WiringException] {
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(sas)
+ val retC = wiringPass.run(c)
+ }
+ }
+ "Wiring from A.clk to X, with 2 A's, and A as top" should "work" in {
+ val sinks = Map(("X"-> "pin"))
+ val sas = WiringInfo("A", "clk", sinks, "A")
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a1 of A
+ | a1.clk <= clk
+ | inst a2 of A
+ | a2.clk <= clk
+ | module A :
+ | input clk: Clock
+ | inst x of X
+ | x.clk <= clk
+ | extmodule X :
+ | input clk: Clock
+ |""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a1 of A
+ | a1.clk <= clk
+ | inst a2 of A
+ | a2.clk <= clk
+ | module A :
+ | input clk: Clock
+ | inst x of X
+ | x.clk <= clk
+ | wire clk_0: Clock
+ | clk_0 <= clk
+ | x.pin <= clk_0
+ | extmodule X :
+ | input clk: Clock
+ | input pin: Clock
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(sas)
+ val retC = wiringPass.run(c)
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+ "Wiring from A.clk to X, with 2 A's, and A as top, but Top instantiates X" should "error" in {
+ val sinks = Map(("X"-> "pin"))
+ val sas = WiringInfo("A", "clk", sinks, "A")
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a1 of A
+ | a1.clk <= clk
+ | inst a2 of A
+ | a2.clk <= clk
+ | inst x of X
+ | x.clk <= clk
+ | module A :
+ | input clk: Clock
+ | inst x of X
+ | x.clk <= clk
+ | extmodule X :
+ | input clk: Clock
+ |""".stripMargin
+ intercept[WiringException] {
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(sas)
+ val retC = wiringPass.run(c)
+ }
+ }
+ "Wiring from A.r[a] to X" should "work" in {
+ val sinks = Map(("X"-> "pin"))
+ val sas = WiringInfo("A", "r[a]", sinks, "A")
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a of A
+ | a.clk <= clk
+ | module A :
+ | input clk: Clock
+ | reg r: UInt<2>[5], clk
+ | node a = UInt(5)
+ | inst x of X
+ | x.clk <= clk
+ | extmodule X :
+ | input clk: Clock
+ |""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | inst a of A
+ | a.clk <= clk
+ | module A :
+ | input clk: Clock
+ | reg r: UInt<2>[5], clk
+ | node a = UInt(5)
+ | inst x of X
+ | x.clk <= clk
+ | wire r_a: UInt<2>
+ | r_a <= r[a]
+ | x.pin <= r_a
+ | extmodule X :
+ | input clk: Clock
+ | input pin: UInt<2>
+ |""".stripMargin
+ val c = passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val wiringPass = new Wiring(sas)
+ val retC = wiringPass.run(c)
+ (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ }
+}