diff options
| author | Adam Izraelevitz | 2016-10-27 13:00:02 -0700 |
|---|---|---|
| committer | GitHub | 2016-10-27 13:00:02 -0700 |
| commit | 5b35f2d2722f72c81d2d6c507cd379be2a1476d8 (patch) | |
| tree | 78dc2db9e12c6db52fcbf222e339a37b6ebc0b72 /src | |
| parent | 1c61a0e7102983891d99d8e9c49e331c8a2178a6 (diff) | |
Wiring (#348)
Added wiring pass and simple test
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Annotations.scala | 71 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/DecorateMems.scala | 25 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala | 27 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala | 46 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/YamlUtils.scala | 10 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/wiring/Wiring.scala | 162 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/wiring/WiringTransform.scala | 80 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/wiring/WiringUtils.scala | 164 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/WiringTests.scala | 347 |
10 files changed, 913 insertions, 21 deletions
diff --git a/src/main/scala/firrtl/Annotations.scala b/src/main/scala/firrtl/Annotations.scala index 2d76a832..d47ce67e 100644 --- a/src/main/scala/firrtl/Annotations.scala +++ b/src/main/scala/firrtl/Annotations.scala @@ -1,5 +1,7 @@ package firrtl +import firrtl.ir._ + import scala.collection.mutable import java.io.Writer @@ -38,6 +40,62 @@ import java.io.Writer * ----------|----------|----------|------------|-----------| */ object Annotations { + /** Returns true if a valid Module name */ + val SerializedModuleName = """([a-zA-Z_][a-zA-Z_0-9~!@#$%^*\-+=?/]*)""".r + def validModuleName(s: String): Boolean = s match { + case SerializedModuleName(name) => true + case _ => false + } + + /** Returns true if a valid component/subcomponent name */ + val SerializedComponentName = """([a-zA-Z_][a-zA-Z_0-9\[\]\.~!@#$%^*\-+=?/]*)""".r + def validComponentName(s: String): Boolean = s match { + case SerializedComponentName(name) => true + case _ => false + } + + /** Tokenizes a string with '[', ']', '.' as tokens, e.g.: + * "foo.bar[boo.far]" becomes Seq("foo" "." "bar" "[" "boo" "." "far" "]") + */ + def tokenize(s: String): Seq[String] = s.find(c => "[].".contains(c)) match { + case Some(_) => + val i = s.indexWhere(c => "[].".contains(c)) + Seq(s.slice(0, i), s(i).toString) ++ tokenize(s.drop(i + 1)) + case None => Seq(s) + } + + /** Given a serialized component/subcomponent reference, subindex, subaccess, + * or subfield, return the corresponding IR expression. + */ + def toExp(s: String): Expression = { + def parse(tokens: Seq[String]): Expression = { + val DecPattern = """([1-9]\d*)""".r + def findClose(tokens: Seq[String], index: Int, nOpen: Int): Seq[String] = + if(index >= tokens.size) { + error("Cannot find closing bracket ]") + } else tokens(index) match { + case "[" => findClose(tokens, index + 1, nOpen + 1) + case "]" if nOpen == 1 => tokens.slice(1, index) + case _ => findClose(tokens, index + 1, nOpen) + } + def buildup(e: Expression, tokens: Seq[String]): Expression = tokens match { + case "[" :: tail => + val indexOrAccess = findClose(tokens, 0, 0) + indexOrAccess.head match { + case DecPattern(d) => SubIndex(e, d.toInt, UnknownType) + case _ => buildup(SubAccess(e, parse(indexOrAccess), UnknownType), tokens.slice(1, indexOrAccess.size)) + } + case "." :: tail => + buildup(SubField(e, tokens(1), UnknownType), tokens.drop(2)) + case Nil => e + } + val root = Reference(tokens.head, UnknownType) + buildup(root, tokens.tail) + } + if(validComponentName(s)) { + parse(tokenize(s)) + } else error(s"Cannot convert $s into an expression.") + } case class AnnotationException(message: String) extends Exception(message) @@ -45,9 +103,16 @@ object Annotations { * Named classes associate an annotation with a component in a Firrtl circuit */ trait Named { def name: String } - case class CircuitName(name: String) extends Named - case class ModuleName(name: String, circuit: CircuitName) extends Named - case class ComponentName(name: String, module: ModuleName) extends Named + case class CircuitName(name: String) extends Named { + if(!validModuleName(name)) throw AnnotationException(s"Illegal circuit name: $name") + } + case class ModuleName(name: String, circuit: CircuitName) extends Named { + if(!validModuleName(name)) throw AnnotationException(s"Illegal module name: $name") + } + case class ComponentName(name: String, module: ModuleName) extends Named { + if(!validComponentName(name)) throw AnnotationException(s"Illegal component name: $name") + def expr: Expression = toExp(name) + } /** * Transform ID (TransID) associates an annotation with an instantiated diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 294afe57..22a3eac6 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -50,6 +50,8 @@ import com.typesafe.scalalogging.LazyLogging class FIRRTLException(str: String) extends Exception(str) object Utils extends LazyLogging { + def throwInternalError = + error("Internal Error! Please file an issue at https://github.com/ucb-bar/firrtl/issues") private[firrtl] def time[R](name: String)(block: => R): R = { logger.info(s"Starting $name") val t0 = System.nanoTime() diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala new file mode 100644 index 00000000..10cc8f88 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala @@ -0,0 +1,25 @@ +package firrtl +package passes +package memlib +import ir._ +import Annotations._ +import wiring._ + +class CreateMemoryAnnotations(reader: Option[YamlFileReader], replaceID: TransID, wiringID: TransID) extends Transform { + def name = "Create Memory Annotations" + def execute(c: Circuit, map: AnnotationMap): TransformResult = reader match { + case None => TransformResult(c) + case Some(r) => + import CustomYAMLProtocol._ + r.parse[Config] match { + case Seq(config) => + val cN = CircuitName(c.main) + val top = TopAnnotation(ModuleName(config.top.name, cN), wiringID) + val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN)), wiringID) + val pin = PinAnnotation(cN, replaceID, config.pin.name) + TransformResult(c, None, Some(AnnotationMap(Seq(top, source, pin)))) + case Nil => TransformResult(c, None, None) + case _ => error("Can only have one config in yaml file") + } + } +} diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index a52f7d38..9ab496d2 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -10,12 +10,24 @@ import firrtl.Mappers._ import MemPortUtils.{MemPortMap, Modules} import MemTransformUtils._ import AnalysisUtils._ +import Annotations._ +import wiring._ + + +/** Annotates the name of the pin to add for WiringTransform + */ +case class PinAnnotation(target: CircuitName, tID: TransID, pin: String) extends Annotation with Loose with Unstable { + def duplicate(n: Named) = n match { + case n: CircuitName => this.copy(target = n) + case _ => throwInternalError + } +} /** Replace DefAnnotatedMemory with memory blackbox + wrapper + conf file. * This will not generate wmask ports if not needed. * Creates the minimum # of black boxes needed by the design. */ -class ReplaceMemMacros(writer: ConfWriter) extends Pass { +class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) extends Transform { def name = "Replace Memory Macros" /** Return true if mask granularity is per bit, false if per byte or unspecified @@ -194,7 +206,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { map updateStmtRefs(memPortMap)) } - def run(c: Circuit) = { + def execute(c: Circuit, map: AnnotationMap): TransformResult = { val namespace = Namespace(c) val memMods = new Modules val nameMap = new NameMap @@ -202,6 +214,15 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { val modules = c.modules map updateMemMods(namespace, nameMap, memMods) // print conf writer.serialize() - c copy (modules = modules ++ memMods) + val pin = map get myID match { + case Some(p) => + p.values.head match { + case PinAnnotation(c, _, pin) => pin + case _ => error(s"Bad Annotations: ${p.values}") + } + case None => "pin" + } + val annos = memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), wiringID, pin) } + TransformResult(c.copy(modules = modules ++ memMods), None, Some(AnnotationMap(annos))) } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index dfa828c9..01f020f5 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -9,6 +9,7 @@ import Annotations._ import AnalysisUtils._ import Utils.error import java.io.{File, CharArrayWriter, PrintWriter} +import wiring._ sealed trait PassOption case object InputConfigFileName extends PassOption @@ -92,21 +93,36 @@ Optional Arguments: def duplicate(n: Named) = this copy (t = t.replace(s"-c:$passCircuit", s"-c:${n.name}")) } +case class SimpleTransform(p: Pass) extends Transform { + def execute(c: Circuit, map: AnnotationMap): TransformResult = + TransformResult(p.run(c)) +} class ReplSeqMem(transID: TransID) extends Transform with SimpleRun { - def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter) = - Seq(Legalize, - ToMemIR, - ResolveMaskGranularity, - RenameAnnotatedMemoryPorts, - ResolveMemoryReference, - //new AnnotateValidMemConfigs(inConfigFile), - new ReplaceMemMacros(outConfigFile), - RemoveEmpty, - CheckInitialization, - InferTypes, - Uniquify, - ResolveKinds, // Must be run for the transform to work! - ResolveGenders) + def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = + Seq(SimpleTransform(Legalize), + SimpleTransform(ToMemIR), + SimpleTransform(ResolveMaskGranularity), + SimpleTransform(RenameAnnotatedMemoryPorts), + SimpleTransform(ResolveMemoryReference), + new CreateMemoryAnnotations(inConfigFile, TransID(-7), TransID(-8)), + new ReplaceMemMacros(outConfigFile, TransID(-7), TransID(-8)), + new WiringTransform(TransID(-8)), + SimpleTransform(RemoveEmpty), + SimpleTransform(CheckInitialization), + SimpleTransform(InferTypes), + SimpleTransform(Uniquify), + SimpleTransform(ResolveKinds), + SimpleTransform(ResolveGenders)) + def run(circuit: Circuit, map: AnnotationMap, xForms: Seq[Transform]): TransformResult = { + (xForms.foldLeft(TransformResult(circuit, None, Some(map)))) { case (tr: TransformResult, xForm: Transform) => + val x = xForm.execute(tr.circuit, tr.annotation.get) + x.annotation match { + case None => TransformResult(x.circuit, None, Some(map)) + case Some(ann) => TransformResult(x.circuit, None, Some( + AnnotationMap(ann.annotations ++ tr.annotation.get.annotations))) + } + } + } def execute(c: Circuit, map: AnnotationMap) = map get transID match { case Some(p) => p get CircuitName(c.main) match { @@ -118,7 +134,7 @@ class ReplSeqMem(transID: TransID) extends Transform with SimpleRun { else error("Input configuration file does not exist!") } val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName)) - run(c, passSeq(inConfigFile, outConfigFile)) + run(c, map, passSeq(inConfigFile, outConfigFile)) case _ => error("Unexpected transform annotation") } case _ => TransformResult(c) diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala index a1088300..fcef4229 100644 --- a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala @@ -5,8 +5,18 @@ import java.io.{File, CharArrayWriter, PrintWriter} object CustomYAMLProtocol extends DefaultYamlProtocol { // bottom depends on top + implicit val _pin = yamlFormat1(Pin) + implicit val _source = yamlFormat2(Source) + implicit val _top = yamlFormat1(Top) + implicit val _configs = yamlFormat3(Config) } +case class Pin(name: String) +case class Source(name: String, module: String) +case class Top(name: String) +case class Config(pin: Pin, source: Source, top: Top) + + class YamlFileReader(file: String) { import CustomYAMLProtocol._ def parse[A](implicit reader: YamlReader[A]) : Seq[A] = { diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala new file mode 100644 index 00000000..d3f6f3dd --- /dev/null +++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala @@ -0,0 +1,162 @@ +package firrtl.passes +package wiring + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import scala.collection.mutable +import firrtl.Annotations._ +import WiringUtils._ + +case class WiringException(msg: String) extends PassException(msg) + +case class WiringInfo(source: String, comp: String, sinks: Map[String, String], top: String) + +class Wiring(wi: WiringInfo) extends Pass { + def name = this.getClass.getSimpleName + + /** Add pins to modules and wires a signal to them, under the scope of a specified top module + * Description: + * Adds a pin to each sink module + * Punches ports up from the source signal to the specified top module + * Punches ports down to each sink module + * Wires the source up and down, connecting to all sink modules + * Restrictions: + * - Can only have one source module instance under scope of the specified top + * - All instances of each sink module must be under the scope of the specified top + * Notes: + * - No module uniquification occurs (due to imposed restrictions) + */ + def run(c: Circuit): Circuit = { + // Split out WiringInfo + val source = wi.source + val sinks = wi.sinks.keys.toSet + val compName = wi.comp + + // Maps modules to children instances, i.e. (instance, module) + val childrenMap = getChildrenMap(c) + + // Check restrictions + val nSources = countInstances(childrenMap, wi.top, source) + if(nSources != 1) + throw new WiringException(s"Cannot have $nSources instance of $source under ${wi.top}") + sinks.foreach { m => + val total = countInstances(childrenMap, c.main, m) + val nTop = countInstances(childrenMap, c.main, wi.top) + val perTop = countInstances(childrenMap, wi.top, m) + if(total != nTop * perTop) + throw new WiringException(s"Module ${wi.top} does not contain all instances of $m.") + } + + // Create valid port names for wiring that have no name conflicts + val portNames = c.modules.foldLeft(Map.empty[String, String]) { (map, m) => + map + (m.name -> { + val ns = Namespace(m) + wi.sinks.get(m.name) match { + case Some(pin) => ns.newName(pin) + case None => ns.newName(tokenize(compName) filterNot ("[]." contains _) mkString "_") + } + }) + } + + // Create a lineage tree from children map + val lineages = getLineage(childrenMap, wi.top) + + // Populate lineage tree with relationship information, i.e. who is source, + // sink, parent of source, etc. + val withFields = setSharedParent(wi.top)(setFields(sinks, source)(lineages)) + + // Populate lineage tree with what to instantiate, connect to/from, etc. + val withThings = setThings(portNames, compName)(withFields) + + // Create a map from module name to lineage information + val map = pointToLineage(withThings) + + // Obtain the source component type + val sourceComponentType = getType(c, source, compName) + + // Return new circuit with correct wiring + val cx = c.copy(modules = c.modules map onModule(map, sourceComponentType)) + + // Replace inserted IR nodes with WIR nodes + ToWorkingIR.run(cx) + } + + /** Return new module with correct wiring + */ + private def onModule(map: Map[String, Lineage], t: Type)(m: DefModule) = { + map.get(m.name) match { + case None => m + case Some(l) => + val stmts = mutable.ArrayBuffer[Statement]() + val ports = mutable.ArrayBuffer[Port]() + l.addPort match { + case None => + case Some((s, dt)) => dt match { + case DecInput => ports += Port(NoInfo, s, Input, t) + case DecOutput => ports += Port(NoInfo, s, Output, t) + case DecWire => + stmts += DefWire(NoInfo, s, t) + } + } + stmts ++= (l.cons map { case ((l, r)) => + Connect(NoInfo, toExp(l), toExp(r)) + }) + def onStmt(s: Statement): Statement = Block(Seq(s) ++ stmts) + m match { + case Module(i, n, ps, s) => Module(i, n, ps ++ ports, Block(Seq(s) ++ stmts)) + case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps ++ ports, dn, p) + } + } + } + + /** Returns the type of the component specified + */ + private def getType(c: Circuit, module: String, comp: String) = { + def getRoot(e: Expression): String = e match { + case r: Reference => r.name + case i: SubIndex => getRoot(i.expr) + case a: SubAccess => getRoot(a.expr) + case f: SubField => getRoot(f.expr) + } + val eComp = toExp(comp) + val root = getRoot(eComp) + var tpe: Option[Type] = None + def getType(s: Statement): Statement = s match { + case DefRegister(_, n, t, _, _, _) if n == root => + tpe = Some(t) + s + case DefWire(_, n, t) if n == root => + tpe = Some(t) + s + case WDefInstance(_, n, m, t) if n == root => + tpe = Some(t) + s + case DefNode(_, n, e) if n == root => + tpe = Some(e.tpe) + s + case sx: DefMemory if sx.name == root => + tpe = Some(MemPortUtils.memType(sx)) + sx + case sx => sx map getType + } + val m = c.modules find (_.name == module) getOrElse error(s"Must have a module named $module") + tpe = m.ports find (_.name == root) map (_.tpe) + m match { + case Module(i, n, ps, b) => getType(b) + case e: ExtModule => + } + tpe match { + case None => error(s"Didn't find $comp in $module!") + case Some(t) => + def setType(e: Expression): Expression = e map setType match { + case ex: Reference => ex.copy(tpe = t) + case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name)) + case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe)) + case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe)) + } + setType(eComp).tpe + } + } +} diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala new file mode 100644 index 00000000..919948b6 --- /dev/null +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -0,0 +1,80 @@ +package firrtl.passes +package wiring + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import scala.collection.mutable +import firrtl.Annotations._ +import WiringUtils._ + +/** A component, e.g. register etc. Must be declared only once under the TopAnnotation + */ +case class SourceAnnotation(target: ComponentName, tID: TransID) extends Annotation with Loose with Unstable { + def duplicate(n: Named) = n match { + case n: ComponentName => this.copy(target = n) + case _ => throwInternalError + } +} + +/** A module, e.g. ExtModule etc., that should add the input pin + */ +case class SinkAnnotation(target: ModuleName, tID: TransID, pin: String) extends Annotation with Loose with Unstable { + def duplicate(n: Named) = n match { + case n: ModuleName => this.copy(target = n) + case _ => throwInternalError + } +} + +/** A module under which all sink module must be declared, and there is only + * one source component + */ +case class TopAnnotation(target: ModuleName, tID: TransID) extends Annotation with Loose with Unstable { + def duplicate(n: Named) = n match { + case n: ModuleName => this.copy(target = n) + case _ => throwInternalError + } +} + +/** Add pins to modules and wires a signal to them, under the scope of a specified top module + * Description: + * Adds a pin to each sink module + * Punches ports up from the source signal to the specified top module + * Punches ports down to each sink module + * Wires the source up and down, connecting to all sink modules + * Restrictions: + * - Can only have one source module instance under scope of the specified top + * - All instances of each sink module must be under the scope of the specified top + * Notes: + * - No module uniquification occurs (due to imposed restrictions) + */ +class WiringTransform(transID: TransID) extends Transform with SimpleRun { + def passSeq(wi: WiringInfo) = + Seq(new Wiring(wi), + InferTypes, + ResolveKinds, + ResolveGenders) + def execute(c: Circuit, map: AnnotationMap) = map get transID match { + case Some(p) => + val sinks = mutable.HashMap[String, String]() + val sources = mutable.Set[String]() + val tops = mutable.Set[String]() + val comp = mutable.Set[String]() + p.values.foreach { a => + a match { + case SinkAnnotation(m, _, pin) => sinks(m.name) = pin + case SourceAnnotation(c, _) => + sources += c.module.name + comp += c.name + case TopAnnotation(m, _) => tops += m.name + } + } + (sources.size, tops.size, sinks.size, comp.size) match { + case (0, 0, p, 0) => TransformResult(c) + case (1, 1, p, 1) if p > 0 => run(c, passSeq(WiringInfo(sources.head, comp.head, sinks.toMap, tops.head))) + case _ => error("Wrong number of sources, tops, or sinks!") + } + case None => TransformResult(c) + } +} diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala new file mode 100644 index 00000000..bfa94a81 --- /dev/null +++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala @@ -0,0 +1,164 @@ +package firrtl.passes +package wiring + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import scala.collection.mutable +import firrtl.Annotations._ +import WiringUtils._ + +/** Declaration kind in lineage (e.g. input port, output port, wire) + */ +sealed trait DecKind +case object DecInput extends DecKind +case object DecOutput extends DecKind +case object DecWire extends DecKind + +/** A lineage tree representing the instance hierarchy in a design + */ +case class Lineage( + name: String, + children: Seq[(String, Lineage)] = Seq.empty, + source: Boolean = false, + sink: Boolean = false, + sourceParent: Boolean = false, + sinkParent: Boolean = false, + sharedParent: Boolean = false, + addPort: Option[(String, DecKind)] = None, + cons: Seq[(String, String)] = Seq.empty) { + + def map(f: Lineage => Lineage): Lineage = + this.copy(children = children.map{ case (i, m) => (i, f(m)) }) + + override def toString: String = shortSerialize("") + + def shortSerialize(tab: String): String = s""" + |$tab name: $name, + |$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))} + |""".stripMargin + + def serialize(tab: String): String = s""" + |$tab name: $name, + |$tab source: $source, + |$tab sink: $sink, + |$tab sourceParent: $sourceParent, + |$tab sinkParent: $sinkParent, + |$tab sharedParent: $sharedParent, + |$tab addPort: $addPort + |$tab cons: $cons + |$tab children: ${children.map(c => tab + " " + c._2.serialize(tab + " "))} + |""".stripMargin +} + + + + +object WiringUtils { + type ChildrenMap = mutable.HashMap[String, Seq[(String, String)]] + + /** Given a circuit, returns a map from module name to children + * instance/module names + */ + def getChildrenMap(c: Circuit): ChildrenMap = { + val childrenMap = new ChildrenMap() + def getChildren(mname: String)(s: Statement): Statement = s match { + case s: WDefInstance => + childrenMap(mname) = childrenMap(mname) :+ (s.name, s.module) + s + case s: DefInstance => + childrenMap(mname) = childrenMap(mname) :+ (s.name, s.module) + s + case s => s map getChildren(mname) + } + c.modules.foreach{ m => + childrenMap(m.name) = Nil + m map getChildren(m.name) + } + childrenMap + } + + /** Counts the number of instances of a module declared under a top module + */ + def countInstances(childrenMap: ChildrenMap, top: String, module: String): Int = { + if(top == module) 1 + else childrenMap(top).foldLeft(0) { case (count, (i, child)) => + count + countInstances(childrenMap, child, module) + } + } + + /** Returns a module's lineage, containing all children lineages as well + */ + def getLineage(childrenMap: ChildrenMap, module: String): Lineage = + Lineage(module, childrenMap(module) map { case (i, m) => (i, getLineage(childrenMap, m)) } ) + + /** Sets the sink, sinkParent, source, and sourceParent fields of every + * Lineage in tree + */ + def setFields(sinks: Set[String], source: String)(lin: Lineage): Lineage = lin map setFields(sinks, source) match { + case l if sinks.contains(l.name) => l.copy(sink = true) + case l => + val src = l.name == source + val sinkParent = l.children.foldLeft(false) { case (b, (i, m)) => b || m.sink || m.sinkParent } + val sourceParent = if(src) true else l.children.foldLeft(false) { case (b, (i, m)) => b || m.source || m.sourceParent } + l.copy(sinkParent=sinkParent, sourceParent=sourceParent, source=src) + } + + /** Sets the sharedParent of lineage top + */ + def setSharedParent(top: String)(lin: Lineage): Lineage = lin map setSharedParent(top) match { + case l if l.name == top => l.copy(sharedParent = true) + case l => l + } + + /** Sets the addPort and cons fields of the lineage tree + */ + def setThings(portNames:Map[String, String], compName: String)(lin: Lineage): Lineage = { + val funs = Seq( + ((l: Lineage) => l map setThings(portNames, compName)), + ((l: Lineage) => l match { + case Lineage(name, _, _, _, _, _, true, _, _) => //SharedParent + l.copy(addPort=Some((portNames(name), DecWire))) + case Lineage(name, _, _, _, true, _, _, _, _) => //SourceParent + l.copy(addPort=Some((portNames(name), DecOutput))) + case Lineage(name, _, _, _, _, true, _, _, _) => //SinkParent + l.copy(addPort=Some((portNames(name), DecInput))) + case Lineage(name, _, _, true, _, _, _, _, _) => //Sink + l.copy(addPort=Some((portNames(name), DecInput))) + case l => l + }), + ((l: Lineage) => l match { + case Lineage(name, _, true, _, _, _, _, _, _) => //Source + val tos = Seq(s"${portNames(name)}") + val from = compName + l.copy(cons = l.cons ++ tos.map(t => (t, from))) + case Lineage(name, _, _, _, true, _, _, _, _) => //SourceParent + val tos = Seq(s"${portNames(name)}") + val from = l.children.filter { case (i, c) => c.sourceParent }.map { case (i, c) => s"$i.${portNames(c.name)}" }.head + l.copy(cons = l.cons ++ tos.map(t => (t, from))) + case l => l + }), + ((l: Lineage) => l match { + case Lineage(name, _, _, _, _, true, _, _, _) => //SinkParent + val tos = l.children.filter { case (i, c) => (c.sinkParent || c.sink) && !c.sourceParent } map { case (i, c) => s"$i.${portNames(c.name)}" } + val from = s"${portNames(name)}" + l.copy(cons = l.cons ++ tos.map(t => (t, from))) + case l => l + }) + ) + funs.foldLeft(lin)((l, fun) => fun(l)) + } + + /** Return a map from module to its lineage in the tree + */ + def pointToLineage(lin: Lineage): Map[String, Lineage] = { + val map = mutable.HashMap[String, Lineage]() + def onLineage(l: Lineage): Lineage = { + map(l.name) = l + l map onLineage + } + onLineage(lin) + map.toMap + } +} diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala new file mode 100644 index 00000000..5f40d861 --- /dev/null +++ b/src/test/scala/firrtlTests/WiringTests.scala @@ -0,0 +1,347 @@ +package firrtlTests + +import java.io._ +import org.scalatest._ +import org.scalatest.prop._ +import firrtl._ +import firrtl.ir.Circuit +import firrtl.passes._ +import firrtl.Parser.IgnoreInfo +import Annotations._ +import wiring.WiringUtils._ +import wiring._ + +class WiringTests extends FirrtlFlatSpec { + def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) + private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Pass) => p.run(c) + } + val lines = c.serialize.split("\n") map normalized + + expected foreach { e => + lines should contain(e) + } + } + + def passes = Seq( + ToWorkingIR, + ResolveKinds, + InferTypes, + ResolveGenders, + InferWidths + ) + + "Wiring from r to X" should "work" in { + val sinks = Map(("X"-> "pin")) + val sas = WiringInfo("C", "r", sinks, "A") + val input = + """circuit Top : + | module Top : + | input clk: Clock + | inst a of A + | a.clk <= clk + | module A : + | input clk: Clock + | inst b of B + | b.clk <= clk + | inst x of X + | x.clk <= clk + | inst d of D + | d.clk <= clk + | module B : + | input clk: Clock + | inst c of C + | c.clk <= clk + | inst d of D + | d.clk <= clk + | module C : + | input clk: Clock + | reg r: UInt<5>, clk + | module D : + | input clk: Clock + | inst x1 of X + | x1.clk <= clk + | inst x2 of X + | x2.clk <= clk + | extmodule X : + | input clk: Clock + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clk: Clock + | inst a of A + | a.clk <= clk + | module A : + | input clk: Clock + | inst b of B + | b.clk <= clk + | inst x of X + | x.clk <= clk + | inst d of D + | d.clk <= clk + | wire r: UInt<5> + | r <= b.r + | x.pin <= r + | d.r <= r + | module B : + | input clk: Clock + | output r: UInt<5> + | inst c of C + | c.clk <= clk + | inst d of D + | d.clk <= clk + | r <= c.r_0 + | d.r <= r + | module C : + | input clk: Clock + | output r_0: UInt<5> + | reg r: UInt<5>, clk + | r_0 <= r + | module D : + | input clk: Clock + | input r: UInt<5> + | inst x1 of X + | x1.clk <= clk + | inst x2 of X + | x2.clk <= clk + | x1.pin <= r + | x2.pin <= r + | extmodule X : + | input clk: Clock + | input pin: UInt<5> + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(sas) + val retC = wiringPass.run(c) + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + + "Wiring from r.x to X" should "work" in { + val sinks = Map(("X"-> "pin")) + val sas = WiringInfo("A", "r.x", sinks, "A") + val input = + """circuit Top : + | module Top : + | input clk: Clock + | inst a of A + | a.clk <= clk + | module A : + | input clk: Clock + | reg r : {x: UInt<5>}, clk + | inst x of X + | x.clk <= clk + | extmodule X : + | input clk: Clock + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clk: Clock + | inst a of A + | a.clk <= clk + | module A : + | input clk: Clock + | reg r: {x: UInt<5>}, clk + | inst x of X + | x.clk <= clk + | wire r_x: UInt<5> + | r_x <= r.x + | x.pin <= r_x + | extmodule X : + | input clk: Clock + | input pin: UInt<5> + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(sas) + val retC = wiringPass.run(c) + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + "Wiring from clk to X" should "work" in { + val sinks = Map(("X"-> "pin")) + val sas = WiringInfo("A", "clk", sinks, "A") + val input = + """circuit Top : + | module Top : + | input clk: Clock + | inst a of A + | a.clk <= clk + | module A : + | input clk: Clock + | inst x of X + | x.clk <= clk + | extmodule X : + | input clk: Clock + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clk: Clock + | inst a of A + | a.clk <= clk + | module A : + | input clk: Clock + | inst x of X + | x.clk <= clk + | wire clk_0: Clock + | clk_0 <= clk + | x.pin <= clk_0 + | extmodule X : + | input clk: Clock + | input pin: Clock + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(sas) + val retC = wiringPass.run(c) + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + "Two sources" should "fail" in { + val sinks = Map(("X"-> "pin")) + val sas = WiringInfo("A", "clk", sinks, "Top") + val input = + """circuit Top : + | module Top : + | input clk: Clock + | inst a1 of A + | a1.clk <= clk + | inst a2 of A + | a2.clk <= clk + | module A : + | input clk: Clock + | inst x of X + | x.clk <= clk + | extmodule X : + | input clk: Clock + |""".stripMargin + intercept[WiringException] { + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(sas) + val retC = wiringPass.run(c) + } + } + "Wiring from A.clk to X, with 2 A's, and A as top" should "work" in { + val sinks = Map(("X"-> "pin")) + val sas = WiringInfo("A", "clk", sinks, "A") + val input = + """circuit Top : + | module Top : + | input clk: Clock + | inst a1 of A + | a1.clk <= clk + | inst a2 of A + | a2.clk <= clk + | module A : + | input clk: Clock + | inst x of X + | x.clk <= clk + | extmodule X : + | input clk: Clock + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clk: Clock + | inst a1 of A + | a1.clk <= clk + | inst a2 of A + | a2.clk <= clk + | module A : + | input clk: Clock + | inst x of X + | x.clk <= clk + | wire clk_0: Clock + | clk_0 <= clk + | x.pin <= clk_0 + | extmodule X : + | input clk: Clock + | input pin: Clock + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(sas) + val retC = wiringPass.run(c) + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } + "Wiring from A.clk to X, with 2 A's, and A as top, but Top instantiates X" should "error" in { + val sinks = Map(("X"-> "pin")) + val sas = WiringInfo("A", "clk", sinks, "A") + val input = + """circuit Top : + | module Top : + | input clk: Clock + | inst a1 of A + | a1.clk <= clk + | inst a2 of A + | a2.clk <= clk + | inst x of X + | x.clk <= clk + | module A : + | input clk: Clock + | inst x of X + | x.clk <= clk + | extmodule X : + | input clk: Clock + |""".stripMargin + intercept[WiringException] { + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(sas) + val retC = wiringPass.run(c) + } + } + "Wiring from A.r[a] to X" should "work" in { + val sinks = Map(("X"-> "pin")) + val sas = WiringInfo("A", "r[a]", sinks, "A") + val input = + """circuit Top : + | module Top : + | input clk: Clock + | inst a of A + | a.clk <= clk + | module A : + | input clk: Clock + | reg r: UInt<2>[5], clk + | node a = UInt(5) + | inst x of X + | x.clk <= clk + | extmodule X : + | input clk: Clock + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clk: Clock + | inst a of A + | a.clk <= clk + | module A : + | input clk: Clock + | reg r: UInt<2>[5], clk + | node a = UInt(5) + | inst x of X + | x.clk <= clk + | wire r_a: UInt<2> + | r_a <= r[a] + | x.pin <= r_a + | extmodule X : + | input clk: Clock + | input pin: UInt<2> + |""".stripMargin + val c = passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + val wiringPass = new Wiring(sas) + val retC = wiringPass.run(c) + (parse(retC.serialize).serialize) should be (parse(check).serialize) + } +} |
