aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorAdam Izraelevitz2018-03-21 14:24:25 -0700
committerGitHub2018-03-21 14:24:25 -0700
commit6ea4ac666e4ce8dfaca1545660f372fccff610f5 (patch)
tree8f2125855557962d642386fe8b49ed0396f562c2 /src/main
parent6b195e4a5348eed2e714e1183024588c5f91a283 (diff)
GroupModule Transform (#766)
* Added grouping pass * Added InfoMagnet and infomappers * Changed return type of execute to allow final CircuitState inspection * Updated dedup. Now is name-agnostic * Added GroupAndDedup transform
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/Mappers.scala9
-rw-r--r--src/main/scala/firrtl/WIR.scala4
-rw-r--r--src/main/scala/firrtl/analyses/InstanceGraph.scala (renamed from src/main/scala/firrtl/analyses/Netlist.scala)44
-rw-r--r--src/main/scala/firrtl/graph/DiGraph.scala34
-rw-r--r--src/main/scala/firrtl/ir/IR.scala19
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemIR.scala1
-rw-r--r--src/main/scala/firrtl/transforms/Dedup.scala377
-rw-r--r--src/main/scala/firrtl/transforms/GroupComponents.scala346
8 files changed, 705 insertions, 129 deletions
diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala
index aeb6e6fe..e8283d93 100644
--- a/src/main/scala/firrtl/Mappers.scala
+++ b/src/main/scala/firrtl/Mappers.scala
@@ -24,6 +24,9 @@ object Mappers {
implicit def forString(f: String => String): StmtMagnet = new StmtMagnet {
override def map(stmt: Statement): Statement = stmt mapString f
}
+ implicit def forInfo(f: Info => Info): StmtMagnet = new StmtMagnet {
+ override def map(stmt: Statement): Statement = stmt mapInfo f
+ }
}
implicit class StmtMap(val _stmt: Statement) extends AnyVal {
// Using implicit types to allow overloading of function type to map, see StmtMagnet above
@@ -95,6 +98,9 @@ object Mappers {
implicit def forString(f: String => String): ModuleMagnet = new ModuleMagnet {
override def map(module: DefModule): DefModule = module mapString f
}
+ implicit def forInfo(f: Info => Info): ModuleMagnet = new ModuleMagnet {
+ override def map(module: DefModule): DefModule = module mapInfo f
+ }
}
implicit class ModuleMap(val _module: DefModule) extends AnyVal {
def map[T](f: T => T)(implicit magnet: (T => T) => ModuleMagnet): DefModule = magnet(f).map(_module)
@@ -111,6 +117,9 @@ object Mappers {
implicit def forString(f: String => String): CircuitMagnet = new CircuitMagnet {
override def map(circuit: Circuit): Circuit = circuit mapString f
}
+ implicit def forInfo(f: Info => Info): CircuitMagnet = new CircuitMagnet {
+ override def map(circuit: Circuit): Circuit = circuit mapInfo f
+ }
}
implicit class CircuitMap(val _circuit: Circuit) extends AnyVal {
def map[T](f: T => T)(implicit magnet: (T => T) => CircuitMagnet): Circuit = magnet(f).map(_circuit)
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index 47e0f321..20da680f 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -87,6 +87,7 @@ case class WDefInstance(info: Info, name: String, module: String, tpe: Type) ext
def mapStmt(f: Statement => Statement): Statement = this
def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(f(info))
}
object WDefInstance {
def apply(name: String, module: String): WDefInstance = new WDefInstance(NoInfo, name, module, UnknownType)
@@ -104,6 +105,7 @@ case class WDefInstanceConnector(
def mapStmt(f: Statement => Statement): Statement = this
def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(f(info))
}
// Resultant width is the same as the maximum input width
@@ -280,6 +282,7 @@ case class CDefMemory(
def mapStmt(f: Statement => Statement): Statement = this
def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(f(info))
}
case class CDefMPort(info: Info,
name: String,
@@ -295,5 +298,6 @@ case class CDefMPort(info: Info,
def mapStmt(f: Statement => Statement): Statement = this
def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(f(info))
}
diff --git a/src/main/scala/firrtl/analyses/Netlist.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala
index 99f3645f..29942cd5 100644
--- a/src/main/scala/firrtl/analyses/Netlist.scala
+++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala
@@ -18,22 +18,13 @@ import firrtl.Mappers._
*/
class InstanceGraph(c: Circuit) {
- private def collectInstances(insts: mutable.Set[WDefInstance])
- (s: Statement): Statement = s match {
- case i: WDefInstance =>
- insts += i
- i
- case _ =>
- s map collectInstances(insts)
- }
-
private val moduleMap = c.modules.map({m => (m.name,m) }).toMap
private val instantiated = new mutable.HashSet[String]
private val childInstances =
new mutable.HashMap[String,mutable.Set[WDefInstance]]
for (m <- c.modules) {
childInstances(m.name) = new mutable.HashSet[WDefInstance]
- m map collectInstances(childInstances(m.name))
+ m map InstanceGraph.collectInstances(childInstances(m.name))
instantiated ++= childInstances(m.name).map(i => i.module)
}
@@ -44,7 +35,7 @@ class InstanceGraph(c: Circuit) {
uninstantiated.foreach({ subTop =>
val topInstance = WDefInstance(subTop,subTop)
instanceQueue.enqueue(topInstance)
- while (!instanceQueue.isEmpty) {
+ while (instanceQueue.nonEmpty) {
val current = instanceQueue.dequeue
instanceGraph.addVertex(current)
for (child <- childInstances(current.module)) {
@@ -70,14 +61,14 @@ class InstanceGraph(c: Circuit) {
/** A list of absolute paths (each represented by a Seq of instances)
* of all module instances in the Circuit.
*/
- lazy val fullHierarchy = graph.pathsInDAG(trueTopInstance)
+ lazy val fullHierarchy: collection.Map[WDefInstance,Seq[Seq[WDefInstance]]] = graph.pathsInDAG(trueTopInstance)
/** Finds the absolute paths (each represented by a Seq of instances
* representing the chain of hierarchy) of all instances of a
* particular module.
*
* @param module the name of the selected module
- * @return a Seq[Seq[WDefInstance]] of absolute instance paths
+ * @return a Seq[ Seq[WDefInstance] ] of absolute instance paths
*/
def findInstancesInHierarchy(module: String): Seq[Seq[WDefInstance]] = {
val instances = graph.getVertices.filter(_.module == module).toSeq
@@ -94,4 +85,31 @@ class InstanceGraph(c: Circuit) {
moduleB: Seq[WDefInstance]): Seq[WDefInstance] = {
tour.rmq(moduleA, moduleB)
}
+
+ /**
+ * Module order from highest module to leaf module
+ * @return sequence of modules in order from top to leaf
+ */
+ def moduleOrder: Seq[DefModule] = {
+ graph.transformNodes(_.module).linearize.map(moduleMap(_))
+ }
+}
+
+object InstanceGraph {
+
+ /** Returns all WDefInstances in a Statement
+ *
+ * @param insts mutable datastructure to append to
+ * @param s statement to descend
+ * @return
+ */
+ def collectInstances(insts: mutable.Set[WDefInstance])
+ (s: Statement): Statement = s match {
+ case i: WDefInstance =>
+ insts += i
+ i
+ case i: DefInstance => throwInternalError(Some("Expecting WDefInstance, found a DefInstance!"))
+ case i: WDefInstanceConnector => throwInternalError(Some("Expecting WDefInstance, found a WDefInstanceConnector!"))
+ case _ => s map collectInstances(insts)
+ }
}
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala
index 6dad56d7..450ec4ff 100644
--- a/src/main/scala/firrtl/graph/DiGraph.scala
+++ b/src/main/scala/firrtl/graph/DiGraph.scala
@@ -66,7 +66,6 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
/** Linearizes (topologically sorts) a DAG
*
- * @param root the start node
* @throws CyclicException if the graph is cyclic
* @return a Map[T,T] from each visited node to its predecessor in the
* traversal
@@ -75,8 +74,8 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
// permanently marked nodes are implicitly held in order
val order = new mutable.ArrayBuffer[T]
// invariant: no intersection between unmarked and tempMarked
- val unmarked = new LinkedHashSet[T]
- val tempMarked = new LinkedHashSet[T]
+ val unmarked = new mutable.LinkedHashSet[T]
+ val tempMarked = new mutable.LinkedHashSet[T]
def visit(n: T): Unit = {
if (tempMarked.contains(n)) {
@@ -94,7 +93,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
}
unmarked ++= getVertices
- while (!unmarked.isEmpty) {
+ while (unmarked.nonEmpty) {
visit(unmarked.head)
}
@@ -108,14 +107,23 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
* @return a Map[T,T] from each visited node to its predecessor in the
* traversal
*/
- def BFS(root: T): Map[T,T] = {
- val prev = new LinkedHashMap[T,T]
+ def BFS(root: T): Map[T,T] = BFS(root, Set.empty[T])
+
+ /** Performs breadth-first search on the directed graph, with a blacklist of nodes
+ *
+ * @param root the start node
+ * @param blacklist list of nodes to stop searching, if encountered
+ * @return a Map[T,T] from each visited node to its predecessor in the
+ * traversal
+ */
+ def BFS(root: T, blacklist: Set[T]): Map[T,T] = {
+ val prev = new mutable.LinkedHashMap[T,T]
val queue = new mutable.Queue[T]
queue.enqueue(root)
- while (!queue.isEmpty) {
+ while (queue.nonEmpty) {
val u = queue.dequeue
for (v <- getEdges(u)) {
- if (!prev.contains(v)) {
+ if (!prev.contains(v) && !blacklist.contains(v)) {
prev(v) = u
queue.enqueue(v)
}
@@ -129,7 +137,15 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link
* @param root the start node
* @return a Set[T] of nodes reachable from the root
*/
- def reachableFrom(root: T): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root).map({ case (k, v) => k })
+ def reachableFrom(root: T): LinkedHashSet[T] = reachableFrom(root, Set.empty[T])
+
+ /** Finds the set of nodes reachable from a particular node, with a blacklist
+ *
+ * @param root the start node
+ * @param blacklist list of nodes to stop searching, if encountered
+ * @return a Set[T] of nodes reachable from the root
+ */
+ def reachableFrom(root: T, blacklist: Set[T]): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root, blacklist).map({ case (k, v) => k })
/** Finds a path (if one exists) from one node to another
*
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index a3ad4231..53fbb765 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -189,6 +189,7 @@ abstract class Statement extends FirrtlNode {
def mapExpr(f: Expression => Expression): Statement
def mapType(f: Type => Type): Statement
def mapString(f: String => String): Statement
+ def mapInfo(f: Info => Info): Statement
}
case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration {
def serialize: String = s"wire $name : ${tpe.serialize}" + info.serialize
@@ -196,6 +197,7 @@ case class DefWire(info: Info, name: String, tpe: Type) extends Statement with I
def mapExpr(f: Expression => Expression): Statement = this
def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe))
def mapString(f: String => String): Statement = DefWire(info, f(name), tpe)
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case class DefRegister(
info: Info,
@@ -212,6 +214,7 @@ case class DefRegister(
DefRegister(info, name, tpe, f(clock), f(reset), f(init))
def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case class DefInstance(info: Info, name: String, module: String) extends Statement with IsDeclaration {
@@ -220,6 +223,7 @@ case class DefInstance(info: Info, name: String, module: String) extends Stateme
def mapExpr(f: Expression => Expression): Statement = this
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = DefInstance(info, f(name), module)
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case class DefMemory(
info: Info,
@@ -248,6 +252,7 @@ case class DefMemory(
def mapExpr(f: Expression => Expression): Statement = this
def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration {
def serialize: String = s"node $name = ${value.serialize}" + info.serialize
@@ -255,6 +260,7 @@ case class DefNode(info: Info, name: String, value: Expression) extends Statemen
def mapExpr(f: Expression => Expression): Statement = DefNode(info, name, f(value))
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = DefNode(info, f(name), value)
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case class Conditionally(
info: Info,
@@ -270,6 +276,7 @@ case class Conditionally(
def mapExpr(f: Expression => Expression): Statement = Conditionally(info, f(pred), conseq, alt)
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case class Block(stmts: Seq[Statement]) extends Statement {
def serialize: String = stmts map (_.serialize) mkString "\n"
@@ -277,6 +284,7 @@ case class Block(stmts: Seq[Statement]) extends Statement {
def mapExpr(f: Expression => Expression): Statement = this
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this
}
case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo {
def serialize: String = s"${loc.serialize} <- ${expr.serialize}" + info.serialize
@@ -284,6 +292,7 @@ case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends
def mapExpr(f: Expression => Expression): Statement = PartialConnect(info, f(loc), f(expr))
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo {
def serialize: String = s"${loc.serialize} <= ${expr.serialize}" + info.serialize
@@ -291,6 +300,7 @@ case class Connect(info: Info, loc: Expression, expr: Expression) extends Statem
def mapExpr(f: Expression => Expression): Statement = Connect(info, f(loc), f(expr))
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo {
def serialize: String = s"${expr.serialize} is invalid" + info.serialize
@@ -298,6 +308,7 @@ case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInf
def mapExpr(f: Expression => Expression): Statement = IsInvalid(info, f(expr))
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with HasInfo {
def serialize: String = "attach " + exprs.map(_.serialize).mkString("(", ", ", ")")
@@ -305,6 +316,7 @@ case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with Has
def mapExpr(f: Expression => Expression): Statement = Attach(info, exprs map f)
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Statement with HasInfo {
def serialize: String = s"stop(${clk.serialize}, ${en.serialize}, $ret)" + info.serialize
@@ -312,6 +324,7 @@ case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends S
def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en))
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case class Print(
info: Info,
@@ -328,6 +341,7 @@ case class Print(
def mapExpr(f: Expression => Expression): Statement = Print(info, string, args map f, f(clk), f(en))
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
case object EmptyStmt extends Statement {
def serialize: String = "skip"
@@ -335,6 +349,7 @@ case object EmptyStmt extends Statement {
def mapExpr(f: Expression => Expression): Statement = this
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this
}
abstract class Width extends FirrtlNode {
@@ -514,6 +529,7 @@ abstract class DefModule extends FirrtlNode with IsDeclaration {
def mapStmt(f: Statement => Statement): DefModule
def mapPort(f: Port => Port): DefModule
def mapString(f: String => String): DefModule
+ def mapInfo(f: Info => Info): DefModule
}
/** Internal Module
*
@@ -524,6 +540,7 @@ case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) e
def mapStmt(f: Statement => Statement): DefModule = this.copy(body = f(body))
def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f)
def mapString(f: String => String): DefModule = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): DefModule = this.copy(f(info))
}
/** External Module
*
@@ -541,6 +558,7 @@ case class ExtModule(
def mapStmt(f: Statement => Statement): DefModule = this
def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f)
def mapString(f: String => String): DefModule = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): DefModule = this.copy(f(info))
}
case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends FirrtlNode with HasInfo {
@@ -549,4 +567,5 @@ case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends Fi
(modules map ("\n" + _.serialize) map indent mkString "\n") + "\n"
def mapModule(f: DefModule => DefModule): Circuit = this.copy(modules = modules map f)
def mapString(f: String => String): Circuit = this.copy(main = f(main))
+ def mapInfo(f: Info => Info): Circuit = this.copy(f(info))
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala
index 54441481..5fb837c1 100644
--- a/src/main/scala/firrtl/passes/memlib/MemIR.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala
@@ -30,4 +30,5 @@ case class DefAnnotatedMemory(
def toMem = DefMemory(info, name, dataType, depth,
writeLatency, readLatency, readers, writers,
readwriters, readUnderWrite)
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
}
diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala
index f22415f0..91c82395 100644
--- a/src/main/scala/firrtl/transforms/Dedup.scala
+++ b/src/main/scala/firrtl/transforms/Dedup.scala
@@ -5,8 +5,9 @@ package transforms
import firrtl.ir._
import firrtl.Mappers._
+import firrtl.analyses.InstanceGraph
import firrtl.annotations._
-import firrtl.passes.PassException
+import firrtl.passes.{InferTypes, MemPortUtils}
// Datastructures
import scala.collection.mutable
@@ -18,134 +19,296 @@ case class NoDedupAnnotation(target: ModuleName) extends SingleTargetAnnotation[
def duplicate(n: ModuleName) = NoDedupAnnotation(n)
}
-// Only use on legal Firrtl. Specifically, the restriction of
-// instance loops must have been checked, or else this pass can
-// infinitely recurse
+/** Only use on legal Firrtl.
+ *
+ * Specifically, the restriction of instance loops must have been checked, or else this pass can
+ * infinitely recurse
+ */
class DedupModules extends Transform {
- def inputForm = HighForm
- def outputForm = HighForm
- // Orders the modules of a circuit from leaves to root
- // A module will appear *after* all modules it instantiates
- private def buildModuleOrder(c: Circuit): Seq[String] = {
- val moduleOrder = mutable.ArrayBuffer.empty[String]
- def hasInstance(b: Statement): Boolean = {
- var has = false
- def onStmt(s: Statement): Statement = s map onStmt match {
- case DefInstance(i, n, m) =>
- if(!(moduleOrder contains m)) has = true
- s
- case WDefInstance(i, n, m, t) =>
- if(!(moduleOrder contains m)) has = true
- s
- case _ => s
- }
- onStmt(b)
- has
+ def inputForm: CircuitForm = HighForm
+ def outputForm: CircuitForm = HighForm
+
+ /**
+ * Deduplicate a Circuit
+ * @param state Input Firrtl AST
+ * @return A transformed Firrtl AST
+ */
+ def execute(state: CircuitState): CircuitState = {
+ val noDedups = state.annotations.collect { case NoDedupAnnotation(ModuleName(m, c)) => m }
+ val (newC, renameMap) = run(state.circuit, noDedups)
+ state.copy(circuit = newC, renames = Some(renameMap))
+ }
+
+ /**
+ * Deduplicates a circuit, and records renaming
+ * @param c Circuit to dedup
+ * @param noDedups Modules not to dedup
+ * @return Deduped Circuit and corresponding RenameMap
+ */
+ def run(c: Circuit, noDedups: Seq[String]): (Circuit, RenameMap) = {
+
+ // RenameMap
+ val renameMap = RenameMap()
+ renameMap.setCircuit(c.main)
+
+ // Maps module name to corresponding dedup module
+ val dedupMap = DedupModules.deduplicate(c, noDedups.toSet, renameMap)
+
+ // Use old module list to preserve ordering
+ val dedupedModules = c.modules.map(m => dedupMap(m.name)).distinct
+
+ val cname = CircuitName(c.main)
+ renameMap.addMap(dedupMap.map { case (from, to) =>
+ logger.debug(s"[Dedup] $from -> ${to.name}")
+ ModuleName(from, cname) -> List(ModuleName(to.name, cname))
+ })
+
+ (InferTypes.run(c.copy(modules = dedupedModules)), renameMap)
+ }
+}
+
+/**
+ * Utility functions for [[DedupModules]]
+ */
+object DedupModules {
+ /**
+ * Change's a module's internal signal names, types, infos, and modules.
+ * @param rename Function to rename a signal. Called on declaration and references.
+ * @param retype Function to retype a signal. Called on declaration, references, and subfields
+ * @param reinfo Function to re-info a statement
+ * @param renameModule Function to rename an instance's module
+ * @param module Module to change internals
+ * @return Changed Module
+ */
+ def changeInternals(rename: String=>String,
+ retype: String=>Type=>Type,
+ reinfo: Info=>Info,
+ renameModule: String=>String
+ )(module: DefModule): DefModule = {
+ def onPort(p: Port): Port = Port(reinfo(p.info), rename(p.name), p.direction, retype(p.name)(p.tpe))
+ def onExp(e: Expression): Expression = e match {
+ case WRef(n, t, k, g) => WRef(rename(n), retype(n)(t), k, g)
+ case WSubField(expr, n, tpe, kind) =>
+ val fieldIndex = expr.tpe.asInstanceOf[BundleType].fields.indexWhere(f => f.name == n)
+ val newExpr = onExp(expr)
+ val newField = newExpr.tpe.asInstanceOf[BundleType].fields(fieldIndex)
+ val finalExpr = WSubField(newExpr, newField.name, newField.tpe, kind)
+ //TODO: renameMap.rename(e.serialize, finalExpr.serialize)
+ finalExpr
+ case other => other map onExp
}
- def addModule(m: DefModule): DefModule = m match {
- case Module(info, n, ps, b) =>
- if (!hasInstance(b)) moduleOrder += m.name
- m
- case e: ExtModule =>
- moduleOrder += m.name
- m
- case _ => m
+ def onStmt(s: Statement): Statement = s match {
+ case WDefInstance(i, n, m, t) =>
+ val newmod = renameModule(m)
+ WDefInstance(reinfo(i), rename(n), newmod, retype(n)(t))
+ case DefInstance(i, n, m) => DefInstance(reinfo(i), rename(n), renameModule(m))
+ case d: DefMemory =>
+ val oldType = MemPortUtils.memType(d)
+ val newType = retype(d.name)(oldType)
+ val index = oldType
+ .asInstanceOf[BundleType].fields.headOption
+ .map(_.tpe.asInstanceOf[BundleType].fields.indexWhere(
+ {
+ case Field("data" | "wdata" | "rdata", _, _) => true
+ case _ => false
+ }))
+ val newDataType = index match {
+ case Some(i) =>
+ //If index nonempty, then there exists a port
+ newType.asInstanceOf[BundleType].fields.head.tpe.asInstanceOf[BundleType].fields(i).tpe
+ case None =>
+ //If index is empty, this mem has no ports, and so we don't need to record the dataType
+ // Thus, call retype with an illegal name, so we can retype the memory's datatype, but not
+ // associate it with the type of the memory (as the memory type is different than the datatype)
+ retype(d.name + ";&*^$")(d.dataType)
+ }
+ d.copy(dataType = newDataType) map rename map reinfo
+ case h: IsDeclaration => h map rename map retype(h.name) map onExp map reinfo
+ case other => other map reinfo map onExp map onStmt
}
-
- while ((moduleOrder.size < c.modules.size)) {
- c.modules.foreach(m => if (!moduleOrder.contains(m.name)) addModule(m))
+ val finalModule = module match {
+ case m: Module => m map onPort map onStmt
+ case other => other
}
- moduleOrder
+ finalModule
}
- // Finds duplicate Modules
- // Also changes DefInstances to instantiate the deduplicated module
- // Returns (Deduped Module name -> Seq of identical modules,
- // Deuplicate Module name -> deduped module name)
- private def findDups(
- moduleOrder: Seq[String],
- moduleMap: Map[String, DefModule],
- noDedups: Seq[String]): (Map[String, Seq[DefModule]], Map[String, String]) = {
- // Module body -> Module name
- val dedupModules = mutable.HashMap.empty[String, String]
- // Old module name -> dup module name
- val dedupMap = mutable.HashMap.empty[String, String]
- // Deduplicated module name -> all identical modules
- val oldModuleMap = mutable.HashMap.empty[String, Seq[DefModule]]
-
- def onModule(m: DefModule): Unit = {
- def fixInstance(s: Statement): Statement = s map fixInstance match {
- case DefInstance(i, n, m) => DefInstance(i, n, dedupMap.getOrElse(m, m))
- case WDefInstance(i, n, m, t) => WDefInstance(i, n, dedupMap.getOrElse(m, m), t)
- case x => x
+ /**
+ * Turns a module into a name-agnostic module
+ * @param module module to change
+ * @return name-agnostic module
+ */
+ def agnostify(module: DefModule, name2tag: mutable.HashMap[String, String], tag2name: mutable.HashMap[String, String]): DefModule = {
+ val namespace = Namespace()
+ val nameMap = mutable.HashMap[String, String]()
+ val typeMap = mutable.HashMap[String, Type]()
+ def rename(name: String): String = {
+ if (nameMap.contains(name)) nameMap(name) else {
+ val newName = namespace.newTemp
+ nameMap(name) = newName
+ newName
}
- def removeInfo(stmt: Statement): Statement = stmt map removeInfo match {
- case sx: HasInfo => sx match {
- case s: DefWire => s.copy(info = NoInfo)
- case s: DefNode => s.copy(info = NoInfo)
- case s: DefRegister => s.copy(info = NoInfo)
- case s: DefInstance => s.copy(info = NoInfo)
- case s: WDefInstance => s.copy(info = NoInfo)
- case s: DefMemory => s.copy(info = NoInfo)
- case s: Connect => s.copy(info = NoInfo)
- case s: PartialConnect => s.copy(info = NoInfo)
- case s: IsInvalid => s.copy(info = NoInfo)
- case s: Attach => s.copy(info = NoInfo)
- case s: Stop => s.copy(info = NoInfo)
- case s: Print => s.copy(info = NoInfo)
- case s: Conditionally => s.copy(info = NoInfo)
+ }
+ def retype(name: String)(tpe: Type): Type = {
+ if (typeMap.contains(name)) typeMap(name) else {
+ def onType(tpe: Type): Type = tpe map onType match {
+ case BundleType(fields) => BundleType(fields.map(f => Field(rename(f.name), f.flip, f.tpe)))
+ case other => other
}
- case sx => sx
+ val newType = onType(tpe)
+ typeMap(name) = newType
+ newType
}
- def removePortInfo(p: Port): Port = p.copy(info = NoInfo)
+ }
+ def remodule(name: String): String = tag2name(name2tag(name))
+ changeInternals(rename, retype, {i: Info => NoInfo}, remodule)(module)
+ }
+ /** Dedup a module's instances based on dedup map
+ *
+ * Will fixes up module if deduped instance's ports are differently named
+ *
+ * @param moduleName Module name who's instances will be deduped
+ * @param moduleMap Map of module name to its original module
+ * @param name2name Map of module name to the module deduping it. Not mutated in this function.
+ * @param renameMap Will be modified to keep track of renames in this function
+ * @return fixed up module deduped instances
+ */
+ def dedupInstances(moduleName: String, moduleMap: Map[String, DefModule], name2name: mutable.Map[String, String], renameMap: RenameMap): DefModule = {
+ val module = moduleMap(moduleName)
- val mx = m map fixInstance
- val mxx = (mx map removeInfo) map removePortInfo
+ // If black box, return it (it has no instances)
+ if (module.isInstanceOf[ExtModule]) return module
- // If shouldn't dedup, just make it fail to be the same to any other modules
- val unique = if (!noDedups.contains(mxx.name)) "" else mxx.name
- val string = mxx match {
- case Module(i, n, ps, b) =>
- ps.map(_.serialize).mkString + b.serialize + unique
- case ExtModule(i, n, ps, dn, p) =>
- ps.map(_.serialize).mkString + dn + p.map(_.serialize).mkString + unique
- }
- dedupModules.get(string) match {
- case Some(dupname) =>
- dedupMap(mx.name) = dupname
- oldModuleMap(dupname) = oldModuleMap(dupname) :+ mx
- case None =>
- dedupModules(string) = mx.name
- oldModuleMap(mx.name) = Seq(mx)
+ // Get all instances to know what to rename in the module
+ val instances = mutable.Set[WDefInstance]()
+ InstanceGraph.collectInstances(instances)(module.asInstanceOf[Module].body)
+ val instanceModuleMap = instances.map(i => i.name -> i.module).toMap
+ val moduleNames = instances.map(_.module)
+
+ def getNewModule(old: String): DefModule = {
+ moduleMap(name2name(old))
+ }
+ // Define rename functions
+ def renameModule(name: String): String = getNewModule(name).name
+ val typeMap = mutable.HashMap[String, Type]()
+ def retype(name: String)(tpe: Type): Type = {
+ if (typeMap.contains(name)) typeMap(name) else {
+ if (instanceModuleMap.contains(name)) {
+ val newType = Utils.module_type(getNewModule(instanceModuleMap(name)))
+ typeMap(name) = newType
+ getAffectedExpressions(WRef(name, tpe)).zip(getAffectedExpressions(WRef(name, newType))).foreach {
+ case (old, nuu) => renameMap.rename(old.serialize, nuu.serialize)
+ }
+ newType
+ } else tpe
}
}
- moduleOrder.foreach(n => onModule(moduleMap(n)))
- (oldModuleMap.toMap, dedupMap.toMap)
+
+ renameMap.setModule(module.name)
+ // Change module internals
+ changeInternals({n => n}, retype, {i => i}, renameModule)(module)
}
- def run(c: Circuit, noDedups: Seq[String]): (Circuit, RenameMap) = {
- val moduleOrder = buildModuleOrder(c)
- val moduleMap = c.modules.map(m => m.name -> m).toMap
+ /**
+ * Deduplicate
+ * @param circuit Circuit
+ * @param noDedups list of modules to not dedup
+ * @param renameMap rename map to populate when deduping
+ * @return Map of original Module name -> Deduped Module
+ */
+ def deduplicate(circuit: Circuit,
+ noDedups: Set[String],
+ renameMap: RenameMap): Map[String, DefModule] = {
- val (oldModuleMap, dedupMap) = findDups(moduleOrder, moduleMap, noDedups)
+ // Order of modules, from leaf to top
+ val moduleLinearization = new InstanceGraph(circuit).moduleOrder.map(_.name).reverse
- // Use old module list to preserve ordering
- val dedupedModules = c.modules.flatMap(m => oldModuleMap.get(m.name).map(_.head))
+ // Maps module name to original module
+ val moduleMap = circuit.modules.map(m => m.name -> m).toMap
- val cname = CircuitName(c.main)
- val renameMap = RenameMap(dedupMap.map { case (from, to) =>
- logger.debug(s"[Dedup] $from -> $to")
- ModuleName(from, cname) -> List(ModuleName(to, cname))
- })
+ // Maps a module's tag to its deduplicated module
+ val tag2name = mutable.HashMap.empty[String, String]
+
+ // Maps a module's name to its tag
+ val name2tag = mutable.HashMap.empty[String, String]
+
+ // Maps a tag to all matching module names
+ val tag2all = mutable.HashMap.empty[String, mutable.Set[String]]
+
+ // Build dedupMap
+ moduleLinearization.foreach { moduleName =>
+ // Get original module
+ val originalModule = moduleMap(moduleName)
+
+ // Replace instance references to new deduped modules
+ val dontcare = RenameMap()
+ dontcare.setCircuit("dontcare")
+ //val fixedModule = DedupModules.dedupInstances(originalModule, tag2module, name2tag, name2module, dontcare)
+
+ if (noDedups.contains(originalModule.name)) {
+ // Don't dedup. Set dedup module to be the same as fixed module
+ name2tag(originalModule.name) = originalModule.name
+ tag2name(originalModule.name) = originalModule.name
+ //templateModules += originalModule.name
+ } else { // Try to dedup
+
+ // Build name-agnostic module
+ val agnosticModule = DedupModules.agnostify(originalModule, name2tag, tag2name)
+
+ // Build tag
+ val tag = (agnosticModule match {
+ case Module(i, n, ps, b) =>
+ ps.map(_.serialize).mkString + b.serialize
+ case ExtModule(i, n, ps, dn, p) =>
+ ps.map(_.serialize).mkString + dn + p.map(_.serialize).mkString
+ }).hashCode().toString
+
+ // Match old module name to its tag
+ name2tag(originalModule.name) = tag
+
+ // Set tag's module to be the first matching module
+ if (!tag2name.contains(tag)) {
+ tag2name(tag) = originalModule.name
+ tag2all(tag) = mutable.Set(originalModule.name)
+ } else {
+ tag2all(tag) += originalModule.name
+ }
+ }
+ }
+
+
+ // Set tag2name to be the best dedup module name
+ val moduleIndex = circuit.modules.zipWithIndex.map{case (m, i) => m.name -> i}.toMap
+ def order(l: String, r: String): String = if (moduleIndex(l) < moduleIndex(r)) l else r
+ tag2all.foreach { case (tag, all) => tag2name(tag) = all.reduce(order)}
- (c.copy(modules = dedupedModules), renameMap)
+ // Create map from original to dedup name
+ val name2name = name2tag.map({ case (name, tag) => name -> tag2name(tag) })
+
+ // Build Remap for modules with deduped module references
+ val tag2module = tag2name.map({ case (tag, name) => tag -> DedupModules.dedupInstances(name, moduleMap, name2name, renameMap) })
+
+ // Build map from original name to corresponding deduped module
+ val name2module = name2tag.map({ case (name, tag) => name -> tag2module(tag) })
+
+ name2module.toMap
}
- def execute(state: CircuitState): CircuitState = {
- val noDedups = state.annotations.collect { case NoDedupAnnotation(ModuleName(m, c)) => m }
- val (newC, renameMap) = run(state.circuit, noDedups)
- state.copy(circuit = newC, renames = Some(renameMap))
+ def getAffectedExpressions(root: Expression): Seq[Expression] = {
+ val all = mutable.ArrayBuffer[Expression]()
+
+ def onExp(expr: Expression): Unit = {
+ expr.tpe match {
+ case _: GroundType =>
+ case b: BundleType => b.fields.foreach { f => onExp(WSubField(expr, f.name, f.tpe)) }
+ case v: VectorType => (0 until v.size).foreach { i => onExp(WSubIndex(expr, i, v.tpe, UNKNOWNGENDER)) }
+ }
+ all += expr
+ }
+
+ onExp(root)
+ all
}
}
diff --git a/src/main/scala/firrtl/transforms/GroupComponents.scala b/src/main/scala/firrtl/transforms/GroupComponents.scala
new file mode 100644
index 00000000..43053e3d
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/GroupComponents.scala
@@ -0,0 +1,346 @@
+package firrtl.transforms
+
+import firrtl._
+import firrtl.Mappers._
+import firrtl.ir._
+import firrtl.annotations.{Annotation, ComponentName}
+import firrtl.passes.{InferTypes, LowerTypes, MemPortUtils}
+import firrtl.Utils.{kind, throwInternalError}
+import firrtl.graph.{DiGraph, MutableDiGraph}
+
+import scala.collection.mutable
+
+
+/**
+ * Specifies a group of components, within a module, to pull out into their own module
+ * Components that are only connected to a group's components will also be included
+ *
+ * @param components components in this group
+ * @param newModule suggested name of the new module
+ * @param newInstance suggested name of the instance of the new module
+ * @param outputSuffix suggested suffix of any output ports of the new module
+ * @param inputSuffix suggested suffix of any input ports of the new module
+ */
+case class GroupAnnotation(components: Seq[ComponentName], newModule: String, newInstance: String, outputSuffix: Option[String] = None, inputSuffix: Option[String] = None) extends Annotation {
+ if(components.nonEmpty) {
+ require(components.forall(_.module == components.head.module), "All components must be in the same module.")
+ require(components.forall(!_.name.contains('.')), "No components can be a subcomponent.")
+ }
+
+ /**
+ * The module that all components are located in
+ * @return
+ */
+ def currentModule: String = components.head.module.name
+
+ /* Only keeps components renamed to components */
+ def update(renames: RenameMap): Seq[Annotation] = {
+ val newComponents = components.flatMap{c => renames.get(c).getOrElse(Seq(c))}.collect {
+ case c: ComponentName => c
+ }
+ Seq(GroupAnnotation(newComponents, newModule, newInstance, outputSuffix, inputSuffix))
+ }
+}
+
+/**
+ * Splits a module into multiple modules by grouping its components via [[GroupAnnotation]]'s
+ */
+class GroupComponents extends firrtl.Transform {
+ type MSet[T] = mutable.Set[T]
+
+ def inputForm: CircuitForm = MidForm
+ def outputForm: CircuitForm = MidForm
+
+ override def execute(state: CircuitState): CircuitState = {
+ val groups = state.annotations.collect {case g: GroupAnnotation => g}
+ val module2group = groups.groupBy(_.currentModule)
+ val mnamespace = Namespace(state.circuit)
+ val newModules = state.circuit.modules.flatMap {
+ case m: Module if module2group.contains(m.name) =>
+ // do stuff
+ groupModule(m, module2group(m.name).filter(_.components.nonEmpty), mnamespace)
+ case other => Seq(other)
+ }
+ val cs = state.copy(circuit = state.circuit.copy(modules = newModules))
+ val csx = InferTypes.execute(cs)
+ csx
+ }
+
+ def groupModule(m: Module, groups: Seq[GroupAnnotation], mnamespace: Namespace): Seq[Module] = {
+ val namespace = Namespace(m)
+ val groupRoots = groups.map(_.components.map(_.name))
+ val totalSum = groupRoots.map(_.size).sum
+ val union = groupRoots.foldLeft(Set.empty[String]){(all, set) => all.union(set.toSet)}
+
+ require(groupRoots.forall{_.forall{namespace.contains}}, "All names should be in this module")
+ require(totalSum == union.size, "No name can be in more than one group")
+ require(groupRoots.forall(_.nonEmpty), "All groupRoots must by non-empty")
+
+
+ // Order of groups, according to their label. The label is the first root in the group
+ val labelOrder = groups.collect({ case g: GroupAnnotation => g.components.head.name })
+
+ // Annotations, by label
+ val label2annotation = groups.collect({ case g: GroupAnnotation => g.components.head.name -> g }).toMap
+
+ // Group roots, by label
+ // The label "" indicates the original module, and components belonging to that group will remain
+ // in the original module (not get moved into a new module)
+ val label2group: Map[String, MSet[String]] = groups.collect{
+ case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name):_*)
+ }.toMap + ("" -> mutable.Set(""))
+
+ // Name of new module containing each group, by label
+ val label2module: Map[String, String] =
+ groups.map(a => a.components.head.name -> mnamespace.newName(a.newModule)).toMap
+
+ // Name of instance of new module, by label
+ val label2instance: Map[String, String] =
+ groups.map(a => a.components.head.name -> namespace.newName(a.newInstance)).toMap
+
+ // Build set of components not in set
+ val notSet = label2group.map { case (key, value) => key -> union.diff(value) }
+
+
+ // Get all dependencies between components
+ val deps = getComponentConnectivity(m)
+
+ // For each node not in the set, which group (by label) can reach it
+ val reachableNodes = new mutable.HashMap[String, MSet[String]]()
+
+ // For each group (by label), add connectivity between nodes in set
+ // Populate reachableNodes with reachability, where blacklist is their notSet
+ label2group.foreach { case (label, set) =>
+ set.foreach { x =>
+ deps.addPairWithEdge(label, x)
+ }
+ deps.reachableFrom(label, notSet(label)) foreach { node =>
+ reachableNodes.getOrElseUpdate(node, mutable.Set.empty[String]) += label
+ }
+ }
+
+ // Add nodes who are reached by a single group, to that group
+ reachableNodes.foreach { case (node, membership) =>
+ if(membership.size == 1) {
+ label2group(membership.head) += node
+ } else {
+ label2group("") += node
+ }
+ }
+
+ applyGrouping(m, labelOrder, label2group, label2module, label2instance, label2annotation)
+ }
+
+ /**
+ * Applies datastructures to a module, to group its components into distinct modules
+ * @param m module to split apart
+ * @param labelOrder order of groups in SeqAnnotation, to make the grouping more deterministic
+ * @param label2group group components, by label
+ * @param label2module module name, by label
+ * @param label2instance instance name of the group's module, by label
+ * @param label2annotation annotation specifying the group, by label
+ * @return new modules, including each group's module and the new split module
+ */
+ def applyGrouping( m: Module,
+ labelOrder: Seq[String],
+ label2group: Map[String, MSet[String]],
+ label2module: Map[String, String],
+ label2instance: Map[String, String],
+ label2annotation: Map[String, GroupAnnotation]
+ ): Seq[Module] = {
+ // Maps node to group
+ val byNode = mutable.HashMap[String, String]()
+ label2group.foreach { case (group, nodes) =>
+ nodes.foreach { node =>
+ byNode(node) = group
+ }
+ }
+ val groupNamespace = label2group.map { case (head, set) => head -> Namespace(set.toSeq) }
+
+ val groupStatements = mutable.HashMap[String, mutable.ArrayBuffer[Statement]]()
+ val groupPorts = mutable.HashMap[String, mutable.ArrayBuffer[Port]]()
+ val groupPortNames = mutable.HashMap[String, mutable.HashMap[String, String]]()
+ label2group.keys.foreach { group =>
+ groupStatements(group) = new mutable.ArrayBuffer[Statement]()
+ groupPorts(group) = new mutable.ArrayBuffer[Port]()
+ groupPortNames(group) = new mutable.HashMap[String, String]()
+ }
+
+ def addPort(group: String, exp: Expression, d: Direction): String = {
+ val source = LowerTypes.loweredName(exp)
+ val portNames = groupPortNames(group)
+ val suffix = d match {
+ case Output => label2annotation(group).outputSuffix.getOrElse("")
+ case Input => label2annotation(group).inputSuffix.getOrElse("")
+ }
+ val newName = groupNamespace(group).newName(source + suffix)
+ val portName = portNames.getOrElseUpdate(source, newName)
+ groupPorts(group) += Port(NoInfo, portName, d, exp.tpe)
+ portName
+ }
+
+ def punchSignalOut(group: String, exp: Expression): String = {
+ val portName = addPort(group, exp, Output)
+ groupStatements(group) += Connect(NoInfo, WRef(portName), exp)
+ portName
+ }
+
+ // Given the sink is in a group, tidy up source references
+ def inGroupFixExps(group: String, added: mutable.ArrayBuffer[Statement])(e: Expression): Expression = e match {
+ case _: Literal => e
+ case _: DoPrim | _: Mux | _: ValidIf => e map inGroupFixExps(group, added)
+ case otherExp: Expression =>
+ val wref = getWRef(otherExp)
+ val source = wref.name
+ byNode(source) match {
+ // case 1: source in the same group as sink
+ case `group` => otherExp //do nothing
+
+ // case 2: source in top
+ case "" =>
+ // Add port to group's Module
+ val toPort = addPort(group, otherExp, Input)
+
+ // Add connection in Top to group's Module port
+ added += Connect(NoInfo, WSubField(WRef(label2instance(group)), toPort), otherExp)
+
+ // Return WRef with new kind (its inside the group Module now)
+ WRef(toPort, otherExp.tpe, PortKind, MALE)
+
+ // case 3: source in different group
+ case otherGroup =>
+ // Add port to otherGroup's Module
+ val fromPort = punchSignalOut(otherGroup, otherExp)
+ val toPort = addPort(group, otherExp, Input)
+
+ // Add connection in Top from otherGroup's port to group's port
+ val groupInst = label2instance(group)
+ val otherInst = label2instance(otherGroup)
+ added += Connect(NoInfo, WSubField(WRef(groupInst), toPort), WSubField(WRef(otherInst), fromPort))
+
+ // Return WRef with new kind (its inside the group Module now)
+ WRef(toPort, otherExp.tpe, PortKind, MALE)
+ }
+ }
+
+ // Given the sink is in the parent module, tidy up source references belonging to groups
+ def inTopFixExps(e: Expression): Expression = e match {
+ case _: DoPrim | _: Mux | _: ValidIf => e map inTopFixExps
+ case otherExp: Expression =>
+ val wref = getWRef(otherExp)
+ if(byNode(wref.name) != "") {
+ // Get the name of source's group
+ val otherGroup = byNode(wref.name)
+
+ // Add port to otherGroup's Module
+ val otherPortName = punchSignalOut(otherGroup, otherExp)
+
+ // Return WSubField (its inside the top Module still)
+ WSubField(WRef(label2instance(otherGroup)), otherPortName)
+
+ } else otherExp
+ }
+
+ def onStmt(s: Statement): Statement = {
+ s match {
+ // Sink is in a group
+ case r: IsDeclaration if byNode(r.name) != "" =>
+ val topStmts = mutable.ArrayBuffer[Statement]()
+ val group = byNode(r.name)
+ groupStatements(group) += r mapExpr inGroupFixExps(group, topStmts)
+ Block(topStmts)
+ case c: Connect if byNode(getWRef(c.loc).name) != "" =>
+ // Sink is in a group
+ val topStmts = mutable.ArrayBuffer[Statement]()
+ val group = byNode(getWRef(c.loc).name)
+ groupStatements(group) += Connect(c.info, c.loc, inGroupFixExps(group, topStmts)(c.expr))
+ Block(topStmts)
+ // TODO Attach if all are in a group?
+ case _: IsDeclaration | _: Connect | _: Attach =>
+ // Sink is in Top
+ val ret = s mapExpr inTopFixExps
+ ret
+ case other => other map onStmt
+ }
+ }
+
+
+ // Build datastructures
+ val newTopBody = Block(labelOrder.map(g => WDefInstance(NoInfo, label2instance(g), label2module(g), UnknownType)) ++ Seq(onStmt(m.body)))
+ val finalTopBody = Block(Utils.squashEmpty(newTopBody).asInstanceOf[Block].stmts.distinct)
+
+ // For all group labels (not including the original module label), return a new Module.
+ val newModules = labelOrder.filter(_ != "") map { group =>
+ Module(NoInfo, label2module(group), groupPorts(group).distinct, Block(groupStatements(group).distinct))
+ }
+ Seq(m.copy(body = finalTopBody)) ++ newModules
+ }
+
+ def getWRef(e: Expression): WRef = e match {
+ case w: WRef => w
+ case other =>
+ var w = WRef("")
+ other mapExpr { e => w = getWRef(e); e}
+ w
+ }
+
+ /**
+ * Compute how each component connects to each other component
+ * It is non-directioned; there is an edge from source to sink and from sink to souce
+ * @param m module to compute connectivity
+ * @return a bi-directional representation of component connectivity
+ */
+ def getComponentConnectivity(m: Module): MutableDiGraph[String] = {
+ val bidirGraph = new MutableDiGraph[String]
+ val simNamespace = Namespace()
+ val simulations = new mutable.HashMap[String, Statement]
+ def onExpr(sink: WRef)(e: Expression): Expression = e match {
+ case w @ WRef(name, _, _, _) =>
+ bidirGraph.addPairWithEdge(sink.name, name)
+ bidirGraph.addPairWithEdge(name, sink.name)
+ w
+ case other => other map onExpr(sink)
+ }
+ def onStmt(stmt: Statement): Unit = stmt match {
+ case w: WDefInstance =>
+ case h: IsDeclaration => h map onExpr(WRef(h.name))
+ case Attach(_, exprs) => // Add edge between each expression
+ exprs.tail map onExpr(getWRef(exprs.head))
+ case Connect(_, loc, expr) =>
+ onExpr(getWRef(loc))(expr)
+ case q @ Stop(_,_, clk, en) =>
+ val simName = simNamespace.newTemp
+ simulations(simName) = q
+ Seq(clk, en) map onExpr(WRef(simName))
+ case q @ Print(_, _, args, clk, en) =>
+ val simName = simNamespace.newTemp
+ simulations(simName) = q
+ (args :+ clk :+ en) map onExpr(WRef(simName))
+ case Block(stmts) => stmts.foreach(onStmt)
+ case ignore @ (_: IsInvalid | EmptyStmt) => // do nothing
+ case other => throw new Exception(s"Unexpected Statement $other")
+ }
+
+ onStmt(m.body)
+ m.ports.foreach { p =>
+ bidirGraph.addPairWithEdge("", p.name)
+ bidirGraph.addPairWithEdge(p.name, "")
+ }
+ bidirGraph
+ }
+}
+
+/**
+ * Splits a module into multiple modules by grouping its components via [[GroupAnnotation]]'s
+ * Tries to deduplicate the resulting circuit
+ */
+class GroupAndDedup extends Transform {
+ def inputForm: CircuitForm = MidForm
+ def outputForm: CircuitForm = MidForm
+
+ override def execute(state: CircuitState): CircuitState = {
+ val cs = new GroupComponents().execute(state)
+ val csx = new DedupModules().execute(cs)
+ csx
+ }
+}