diff options
| author | Adam Izraelevitz | 2018-03-21 14:24:25 -0700 |
|---|---|---|
| committer | GitHub | 2018-03-21 14:24:25 -0700 |
| commit | 6ea4ac666e4ce8dfaca1545660f372fccff610f5 (patch) | |
| tree | 8f2125855557962d642386fe8b49ed0396f562c2 | |
| parent | 6b195e4a5348eed2e714e1183024588c5f91a283 (diff) | |
GroupModule Transform (#766)
* Added grouping pass
* Added InfoMagnet and infomappers
* Changed return type of execute to allow final CircuitState inspection
* Updated dedup. Now is name-agnostic
* Added GroupAndDedup transform
| -rw-r--r-- | src/main/scala/firrtl/Mappers.scala | 9 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/analyses/InstanceGraph.scala (renamed from src/main/scala/firrtl/analyses/Netlist.scala) | 44 | ||||
| -rw-r--r-- | src/main/scala/firrtl/graph/DiGraph.scala | 34 | ||||
| -rw-r--r-- | src/main/scala/firrtl/ir/IR.scala | 19 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/MemIR.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/Dedup.scala | 377 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/GroupComponents.scala | 346 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/AttachSpec.scala | 2 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/PassTests.scala | 3 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/transforms/DedupTests.scala | 336 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala | 290 |
12 files changed, 1196 insertions, 269 deletions
diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala index aeb6e6fe..e8283d93 100644 --- a/src/main/scala/firrtl/Mappers.scala +++ b/src/main/scala/firrtl/Mappers.scala @@ -24,6 +24,9 @@ object Mappers { implicit def forString(f: String => String): StmtMagnet = new StmtMagnet { override def map(stmt: Statement): Statement = stmt mapString f } + implicit def forInfo(f: Info => Info): StmtMagnet = new StmtMagnet { + override def map(stmt: Statement): Statement = stmt mapInfo f + } } implicit class StmtMap(val _stmt: Statement) extends AnyVal { // Using implicit types to allow overloading of function type to map, see StmtMagnet above @@ -95,6 +98,9 @@ object Mappers { implicit def forString(f: String => String): ModuleMagnet = new ModuleMagnet { override def map(module: DefModule): DefModule = module mapString f } + implicit def forInfo(f: Info => Info): ModuleMagnet = new ModuleMagnet { + override def map(module: DefModule): DefModule = module mapInfo f + } } implicit class ModuleMap(val _module: DefModule) extends AnyVal { def map[T](f: T => T)(implicit magnet: (T => T) => ModuleMagnet): DefModule = magnet(f).map(_module) @@ -111,6 +117,9 @@ object Mappers { implicit def forString(f: String => String): CircuitMagnet = new CircuitMagnet { override def map(circuit: Circuit): Circuit = circuit mapString f } + implicit def forInfo(f: Info => Info): CircuitMagnet = new CircuitMagnet { + override def map(circuit: Circuit): Circuit = circuit mapInfo f + } } implicit class CircuitMap(val _circuit: Circuit) extends AnyVal { def map[T](f: T => T)(implicit magnet: (T => T) => CircuitMagnet): Circuit = magnet(f).map(_circuit) diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 47e0f321..20da680f 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -87,6 +87,7 @@ case class WDefInstance(info: Info, name: String, module: String, tpe: Type) ext def mapStmt(f: Statement => Statement): Statement = this def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(f(info)) } object WDefInstance { def apply(name: String, module: String): WDefInstance = new WDefInstance(NoInfo, name, module, UnknownType) @@ -104,6 +105,7 @@ case class WDefInstanceConnector( def mapStmt(f: Statement => Statement): Statement = this def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(f(info)) } // Resultant width is the same as the maximum input width @@ -280,6 +282,7 @@ case class CDefMemory( def mapStmt(f: Statement => Statement): Statement = this def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(f(info)) } case class CDefMPort(info: Info, name: String, @@ -295,5 +298,6 @@ case class CDefMPort(info: Info, def mapStmt(f: Statement => Statement): Statement = this def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(f(info)) } diff --git a/src/main/scala/firrtl/analyses/Netlist.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala index 99f3645f..29942cd5 100644 --- a/src/main/scala/firrtl/analyses/Netlist.scala +++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala @@ -18,22 +18,13 @@ import firrtl.Mappers._ */ class InstanceGraph(c: Circuit) { - private def collectInstances(insts: mutable.Set[WDefInstance]) - (s: Statement): Statement = s match { - case i: WDefInstance => - insts += i - i - case _ => - s map collectInstances(insts) - } - private val moduleMap = c.modules.map({m => (m.name,m) }).toMap private val instantiated = new mutable.HashSet[String] private val childInstances = new mutable.HashMap[String,mutable.Set[WDefInstance]] for (m <- c.modules) { childInstances(m.name) = new mutable.HashSet[WDefInstance] - m map collectInstances(childInstances(m.name)) + m map InstanceGraph.collectInstances(childInstances(m.name)) instantiated ++= childInstances(m.name).map(i => i.module) } @@ -44,7 +35,7 @@ class InstanceGraph(c: Circuit) { uninstantiated.foreach({ subTop => val topInstance = WDefInstance(subTop,subTop) instanceQueue.enqueue(topInstance) - while (!instanceQueue.isEmpty) { + while (instanceQueue.nonEmpty) { val current = instanceQueue.dequeue instanceGraph.addVertex(current) for (child <- childInstances(current.module)) { @@ -70,14 +61,14 @@ class InstanceGraph(c: Circuit) { /** A list of absolute paths (each represented by a Seq of instances) * of all module instances in the Circuit. */ - lazy val fullHierarchy = graph.pathsInDAG(trueTopInstance) + lazy val fullHierarchy: collection.Map[WDefInstance,Seq[Seq[WDefInstance]]] = graph.pathsInDAG(trueTopInstance) /** Finds the absolute paths (each represented by a Seq of instances * representing the chain of hierarchy) of all instances of a * particular module. * * @param module the name of the selected module - * @return a Seq[Seq[WDefInstance]] of absolute instance paths + * @return a Seq[ Seq[WDefInstance] ] of absolute instance paths */ def findInstancesInHierarchy(module: String): Seq[Seq[WDefInstance]] = { val instances = graph.getVertices.filter(_.module == module).toSeq @@ -94,4 +85,31 @@ class InstanceGraph(c: Circuit) { moduleB: Seq[WDefInstance]): Seq[WDefInstance] = { tour.rmq(moduleA, moduleB) } + + /** + * Module order from highest module to leaf module + * @return sequence of modules in order from top to leaf + */ + def moduleOrder: Seq[DefModule] = { + graph.transformNodes(_.module).linearize.map(moduleMap(_)) + } +} + +object InstanceGraph { + + /** Returns all WDefInstances in a Statement + * + * @param insts mutable datastructure to append to + * @param s statement to descend + * @return + */ + def collectInstances(insts: mutable.Set[WDefInstance]) + (s: Statement): Statement = s match { + case i: WDefInstance => + insts += i + i + case i: DefInstance => throwInternalError(Some("Expecting WDefInstance, found a DefInstance!")) + case i: WDefInstanceConnector => throwInternalError(Some("Expecting WDefInstance, found a WDefInstanceConnector!")) + case _ => s map collectInstances(insts) + } } diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala index 6dad56d7..450ec4ff 100644 --- a/src/main/scala/firrtl/graph/DiGraph.scala +++ b/src/main/scala/firrtl/graph/DiGraph.scala @@ -66,7 +66,6 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link /** Linearizes (topologically sorts) a DAG * - * @param root the start node * @throws CyclicException if the graph is cyclic * @return a Map[T,T] from each visited node to its predecessor in the * traversal @@ -75,8 +74,8 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link // permanently marked nodes are implicitly held in order val order = new mutable.ArrayBuffer[T] // invariant: no intersection between unmarked and tempMarked - val unmarked = new LinkedHashSet[T] - val tempMarked = new LinkedHashSet[T] + val unmarked = new mutable.LinkedHashSet[T] + val tempMarked = new mutable.LinkedHashSet[T] def visit(n: T): Unit = { if (tempMarked.contains(n)) { @@ -94,7 +93,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link } unmarked ++= getVertices - while (!unmarked.isEmpty) { + while (unmarked.nonEmpty) { visit(unmarked.head) } @@ -108,14 +107,23 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link * @return a Map[T,T] from each visited node to its predecessor in the * traversal */ - def BFS(root: T): Map[T,T] = { - val prev = new LinkedHashMap[T,T] + def BFS(root: T): Map[T,T] = BFS(root, Set.empty[T]) + + /** Performs breadth-first search on the directed graph, with a blacklist of nodes + * + * @param root the start node + * @param blacklist list of nodes to stop searching, if encountered + * @return a Map[T,T] from each visited node to its predecessor in the + * traversal + */ + def BFS(root: T, blacklist: Set[T]): Map[T,T] = { + val prev = new mutable.LinkedHashMap[T,T] val queue = new mutable.Queue[T] queue.enqueue(root) - while (!queue.isEmpty) { + while (queue.nonEmpty) { val u = queue.dequeue for (v <- getEdges(u)) { - if (!prev.contains(v)) { + if (!prev.contains(v) && !blacklist.contains(v)) { prev(v) = u queue.enqueue(v) } @@ -129,7 +137,15 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link * @param root the start node * @return a Set[T] of nodes reachable from the root */ - def reachableFrom(root: T): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root).map({ case (k, v) => k }) + def reachableFrom(root: T): LinkedHashSet[T] = reachableFrom(root, Set.empty[T]) + + /** Finds the set of nodes reachable from a particular node, with a blacklist + * + * @param root the start node + * @param blacklist list of nodes to stop searching, if encountered + * @return a Set[T] of nodes reachable from the root + */ + def reachableFrom(root: T, blacklist: Set[T]): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root, blacklist).map({ case (k, v) => k }) /** Finds a path (if one exists) from one node to another * diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index a3ad4231..53fbb765 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -189,6 +189,7 @@ abstract class Statement extends FirrtlNode { def mapExpr(f: Expression => Expression): Statement def mapType(f: Type => Type): Statement def mapString(f: String => String): Statement + def mapInfo(f: Info => Info): Statement } case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration { def serialize: String = s"wire $name : ${tpe.serialize}" + info.serialize @@ -196,6 +197,7 @@ case class DefWire(info: Info, name: String, tpe: Type) extends Statement with I def mapExpr(f: Expression => Expression): Statement = this def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe)) def mapString(f: String => String): Statement = DefWire(info, f(name), tpe) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case class DefRegister( info: Info, @@ -212,6 +214,7 @@ case class DefRegister( DefRegister(info, name, tpe, f(clock), f(reset), f(init)) def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case class DefInstance(info: Info, name: String, module: String) extends Statement with IsDeclaration { @@ -220,6 +223,7 @@ case class DefInstance(info: Info, name: String, module: String) extends Stateme def mapExpr(f: Expression => Expression): Statement = this def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = DefInstance(info, f(name), module) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case class DefMemory( info: Info, @@ -248,6 +252,7 @@ case class DefMemory( def mapExpr(f: Expression => Expression): Statement = this def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration { def serialize: String = s"node $name = ${value.serialize}" + info.serialize @@ -255,6 +260,7 @@ case class DefNode(info: Info, name: String, value: Expression) extends Statemen def mapExpr(f: Expression => Expression): Statement = DefNode(info, name, f(value)) def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = DefNode(info, f(name), value) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case class Conditionally( info: Info, @@ -270,6 +276,7 @@ case class Conditionally( def mapExpr(f: Expression => Expression): Statement = Conditionally(info, f(pred), conseq, alt) def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case class Block(stmts: Seq[Statement]) extends Statement { def serialize: String = stmts map (_.serialize) mkString "\n" @@ -277,6 +284,7 @@ case class Block(stmts: Seq[Statement]) extends Statement { def mapExpr(f: Expression => Expression): Statement = this def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this } case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${loc.serialize} <- ${expr.serialize}" + info.serialize @@ -284,6 +292,7 @@ case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends def mapExpr(f: Expression => Expression): Statement = PartialConnect(info, f(loc), f(expr)) def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${loc.serialize} <= ${expr.serialize}" + info.serialize @@ -291,6 +300,7 @@ case class Connect(info: Info, loc: Expression, expr: Expression) extends Statem def mapExpr(f: Expression => Expression): Statement = Connect(info, f(loc), f(expr)) def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${expr.serialize} is invalid" + info.serialize @@ -298,6 +308,7 @@ case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInf def mapExpr(f: Expression => Expression): Statement = IsInvalid(info, f(expr)) def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with HasInfo { def serialize: String = "attach " + exprs.map(_.serialize).mkString("(", ", ", ")") @@ -305,6 +316,7 @@ case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with Has def mapExpr(f: Expression => Expression): Statement = Attach(info, exprs map f) def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Statement with HasInfo { def serialize: String = s"stop(${clk.serialize}, ${en.serialize}, $ret)" + info.serialize @@ -312,6 +324,7 @@ case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends S def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en)) def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case class Print( info: Info, @@ -328,6 +341,7 @@ case class Print( def mapExpr(f: Expression => Expression): Statement = Print(info, string, args map f, f(clk), f(en)) def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } case object EmptyStmt extends Statement { def serialize: String = "skip" @@ -335,6 +349,7 @@ case object EmptyStmt extends Statement { def mapExpr(f: Expression => Expression): Statement = this def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this } abstract class Width extends FirrtlNode { @@ -514,6 +529,7 @@ abstract class DefModule extends FirrtlNode with IsDeclaration { def mapStmt(f: Statement => Statement): DefModule def mapPort(f: Port => Port): DefModule def mapString(f: String => String): DefModule + def mapInfo(f: Info => Info): DefModule } /** Internal Module * @@ -524,6 +540,7 @@ case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) e def mapStmt(f: Statement => Statement): DefModule = this.copy(body = f(body)) def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) def mapString(f: String => String): DefModule = this.copy(name = f(name)) + def mapInfo(f: Info => Info): DefModule = this.copy(f(info)) } /** External Module * @@ -541,6 +558,7 @@ case class ExtModule( def mapStmt(f: Statement => Statement): DefModule = this def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) def mapString(f: String => String): DefModule = this.copy(name = f(name)) + def mapInfo(f: Info => Info): DefModule = this.copy(f(info)) } case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends FirrtlNode with HasInfo { @@ -549,4 +567,5 @@ case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends Fi (modules map ("\n" + _.serialize) map indent mkString "\n") + "\n" def mapModule(f: DefModule => DefModule): Circuit = this.copy(modules = modules map f) def mapString(f: String => String): Circuit = this.copy(main = f(main)) + def mapInfo(f: Info => Info): Circuit = this.copy(f(info)) } diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala index 54441481..5fb837c1 100644 --- a/src/main/scala/firrtl/passes/memlib/MemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala @@ -30,4 +30,5 @@ case class DefAnnotatedMemory( def toMem = DefMemory(info, name, dataType, depth, writeLatency, readLatency, readers, writers, readwriters, readUnderWrite) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) } diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index f22415f0..91c82395 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -5,8 +5,9 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ +import firrtl.analyses.InstanceGraph import firrtl.annotations._ -import firrtl.passes.PassException +import firrtl.passes.{InferTypes, MemPortUtils} // Datastructures import scala.collection.mutable @@ -18,134 +19,296 @@ case class NoDedupAnnotation(target: ModuleName) extends SingleTargetAnnotation[ def duplicate(n: ModuleName) = NoDedupAnnotation(n) } -// Only use on legal Firrtl. Specifically, the restriction of -// instance loops must have been checked, or else this pass can -// infinitely recurse +/** Only use on legal Firrtl. + * + * Specifically, the restriction of instance loops must have been checked, or else this pass can + * infinitely recurse + */ class DedupModules extends Transform { - def inputForm = HighForm - def outputForm = HighForm - // Orders the modules of a circuit from leaves to root - // A module will appear *after* all modules it instantiates - private def buildModuleOrder(c: Circuit): Seq[String] = { - val moduleOrder = mutable.ArrayBuffer.empty[String] - def hasInstance(b: Statement): Boolean = { - var has = false - def onStmt(s: Statement): Statement = s map onStmt match { - case DefInstance(i, n, m) => - if(!(moduleOrder contains m)) has = true - s - case WDefInstance(i, n, m, t) => - if(!(moduleOrder contains m)) has = true - s - case _ => s - } - onStmt(b) - has + def inputForm: CircuitForm = HighForm + def outputForm: CircuitForm = HighForm + + /** + * Deduplicate a Circuit + * @param state Input Firrtl AST + * @return A transformed Firrtl AST + */ + def execute(state: CircuitState): CircuitState = { + val noDedups = state.annotations.collect { case NoDedupAnnotation(ModuleName(m, c)) => m } + val (newC, renameMap) = run(state.circuit, noDedups) + state.copy(circuit = newC, renames = Some(renameMap)) + } + + /** + * Deduplicates a circuit, and records renaming + * @param c Circuit to dedup + * @param noDedups Modules not to dedup + * @return Deduped Circuit and corresponding RenameMap + */ + def run(c: Circuit, noDedups: Seq[String]): (Circuit, RenameMap) = { + + // RenameMap + val renameMap = RenameMap() + renameMap.setCircuit(c.main) + + // Maps module name to corresponding dedup module + val dedupMap = DedupModules.deduplicate(c, noDedups.toSet, renameMap) + + // Use old module list to preserve ordering + val dedupedModules = c.modules.map(m => dedupMap(m.name)).distinct + + val cname = CircuitName(c.main) + renameMap.addMap(dedupMap.map { case (from, to) => + logger.debug(s"[Dedup] $from -> ${to.name}") + ModuleName(from, cname) -> List(ModuleName(to.name, cname)) + }) + + (InferTypes.run(c.copy(modules = dedupedModules)), renameMap) + } +} + +/** + * Utility functions for [[DedupModules]] + */ +object DedupModules { + /** + * Change's a module's internal signal names, types, infos, and modules. + * @param rename Function to rename a signal. Called on declaration and references. + * @param retype Function to retype a signal. Called on declaration, references, and subfields + * @param reinfo Function to re-info a statement + * @param renameModule Function to rename an instance's module + * @param module Module to change internals + * @return Changed Module + */ + def changeInternals(rename: String=>String, + retype: String=>Type=>Type, + reinfo: Info=>Info, + renameModule: String=>String + )(module: DefModule): DefModule = { + def onPort(p: Port): Port = Port(reinfo(p.info), rename(p.name), p.direction, retype(p.name)(p.tpe)) + def onExp(e: Expression): Expression = e match { + case WRef(n, t, k, g) => WRef(rename(n), retype(n)(t), k, g) + case WSubField(expr, n, tpe, kind) => + val fieldIndex = expr.tpe.asInstanceOf[BundleType].fields.indexWhere(f => f.name == n) + val newExpr = onExp(expr) + val newField = newExpr.tpe.asInstanceOf[BundleType].fields(fieldIndex) + val finalExpr = WSubField(newExpr, newField.name, newField.tpe, kind) + //TODO: renameMap.rename(e.serialize, finalExpr.serialize) + finalExpr + case other => other map onExp } - def addModule(m: DefModule): DefModule = m match { - case Module(info, n, ps, b) => - if (!hasInstance(b)) moduleOrder += m.name - m - case e: ExtModule => - moduleOrder += m.name - m - case _ => m + def onStmt(s: Statement): Statement = s match { + case WDefInstance(i, n, m, t) => + val newmod = renameModule(m) + WDefInstance(reinfo(i), rename(n), newmod, retype(n)(t)) + case DefInstance(i, n, m) => DefInstance(reinfo(i), rename(n), renameModule(m)) + case d: DefMemory => + val oldType = MemPortUtils.memType(d) + val newType = retype(d.name)(oldType) + val index = oldType + .asInstanceOf[BundleType].fields.headOption + .map(_.tpe.asInstanceOf[BundleType].fields.indexWhere( + { + case Field("data" | "wdata" | "rdata", _, _) => true + case _ => false + })) + val newDataType = index match { + case Some(i) => + //If index nonempty, then there exists a port + newType.asInstanceOf[BundleType].fields.head.tpe.asInstanceOf[BundleType].fields(i).tpe + case None => + //If index is empty, this mem has no ports, and so we don't need to record the dataType + // Thus, call retype with an illegal name, so we can retype the memory's datatype, but not + // associate it with the type of the memory (as the memory type is different than the datatype) + retype(d.name + ";&*^$")(d.dataType) + } + d.copy(dataType = newDataType) map rename map reinfo + case h: IsDeclaration => h map rename map retype(h.name) map onExp map reinfo + case other => other map reinfo map onExp map onStmt } - - while ((moduleOrder.size < c.modules.size)) { - c.modules.foreach(m => if (!moduleOrder.contains(m.name)) addModule(m)) + val finalModule = module match { + case m: Module => m map onPort map onStmt + case other => other } - moduleOrder + finalModule } - // Finds duplicate Modules - // Also changes DefInstances to instantiate the deduplicated module - // Returns (Deduped Module name -> Seq of identical modules, - // Deuplicate Module name -> deduped module name) - private def findDups( - moduleOrder: Seq[String], - moduleMap: Map[String, DefModule], - noDedups: Seq[String]): (Map[String, Seq[DefModule]], Map[String, String]) = { - // Module body -> Module name - val dedupModules = mutable.HashMap.empty[String, String] - // Old module name -> dup module name - val dedupMap = mutable.HashMap.empty[String, String] - // Deduplicated module name -> all identical modules - val oldModuleMap = mutable.HashMap.empty[String, Seq[DefModule]] - - def onModule(m: DefModule): Unit = { - def fixInstance(s: Statement): Statement = s map fixInstance match { - case DefInstance(i, n, m) => DefInstance(i, n, dedupMap.getOrElse(m, m)) - case WDefInstance(i, n, m, t) => WDefInstance(i, n, dedupMap.getOrElse(m, m), t) - case x => x + /** + * Turns a module into a name-agnostic module + * @param module module to change + * @return name-agnostic module + */ + def agnostify(module: DefModule, name2tag: mutable.HashMap[String, String], tag2name: mutable.HashMap[String, String]): DefModule = { + val namespace = Namespace() + val nameMap = mutable.HashMap[String, String]() + val typeMap = mutable.HashMap[String, Type]() + def rename(name: String): String = { + if (nameMap.contains(name)) nameMap(name) else { + val newName = namespace.newTemp + nameMap(name) = newName + newName } - def removeInfo(stmt: Statement): Statement = stmt map removeInfo match { - case sx: HasInfo => sx match { - case s: DefWire => s.copy(info = NoInfo) - case s: DefNode => s.copy(info = NoInfo) - case s: DefRegister => s.copy(info = NoInfo) - case s: DefInstance => s.copy(info = NoInfo) - case s: WDefInstance => s.copy(info = NoInfo) - case s: DefMemory => s.copy(info = NoInfo) - case s: Connect => s.copy(info = NoInfo) - case s: PartialConnect => s.copy(info = NoInfo) - case s: IsInvalid => s.copy(info = NoInfo) - case s: Attach => s.copy(info = NoInfo) - case s: Stop => s.copy(info = NoInfo) - case s: Print => s.copy(info = NoInfo) - case s: Conditionally => s.copy(info = NoInfo) + } + def retype(name: String)(tpe: Type): Type = { + if (typeMap.contains(name)) typeMap(name) else { + def onType(tpe: Type): Type = tpe map onType match { + case BundleType(fields) => BundleType(fields.map(f => Field(rename(f.name), f.flip, f.tpe))) + case other => other } - case sx => sx + val newType = onType(tpe) + typeMap(name) = newType + newType } - def removePortInfo(p: Port): Port = p.copy(info = NoInfo) + } + def remodule(name: String): String = tag2name(name2tag(name)) + changeInternals(rename, retype, {i: Info => NoInfo}, remodule)(module) + } + /** Dedup a module's instances based on dedup map + * + * Will fixes up module if deduped instance's ports are differently named + * + * @param moduleName Module name who's instances will be deduped + * @param moduleMap Map of module name to its original module + * @param name2name Map of module name to the module deduping it. Not mutated in this function. + * @param renameMap Will be modified to keep track of renames in this function + * @return fixed up module deduped instances + */ + def dedupInstances(moduleName: String, moduleMap: Map[String, DefModule], name2name: mutable.Map[String, String], renameMap: RenameMap): DefModule = { + val module = moduleMap(moduleName) - val mx = m map fixInstance - val mxx = (mx map removeInfo) map removePortInfo + // If black box, return it (it has no instances) + if (module.isInstanceOf[ExtModule]) return module - // If shouldn't dedup, just make it fail to be the same to any other modules - val unique = if (!noDedups.contains(mxx.name)) "" else mxx.name - val string = mxx match { - case Module(i, n, ps, b) => - ps.map(_.serialize).mkString + b.serialize + unique - case ExtModule(i, n, ps, dn, p) => - ps.map(_.serialize).mkString + dn + p.map(_.serialize).mkString + unique - } - dedupModules.get(string) match { - case Some(dupname) => - dedupMap(mx.name) = dupname - oldModuleMap(dupname) = oldModuleMap(dupname) :+ mx - case None => - dedupModules(string) = mx.name - oldModuleMap(mx.name) = Seq(mx) + // Get all instances to know what to rename in the module + val instances = mutable.Set[WDefInstance]() + InstanceGraph.collectInstances(instances)(module.asInstanceOf[Module].body) + val instanceModuleMap = instances.map(i => i.name -> i.module).toMap + val moduleNames = instances.map(_.module) + + def getNewModule(old: String): DefModule = { + moduleMap(name2name(old)) + } + // Define rename functions + def renameModule(name: String): String = getNewModule(name).name + val typeMap = mutable.HashMap[String, Type]() + def retype(name: String)(tpe: Type): Type = { + if (typeMap.contains(name)) typeMap(name) else { + if (instanceModuleMap.contains(name)) { + val newType = Utils.module_type(getNewModule(instanceModuleMap(name))) + typeMap(name) = newType + getAffectedExpressions(WRef(name, tpe)).zip(getAffectedExpressions(WRef(name, newType))).foreach { + case (old, nuu) => renameMap.rename(old.serialize, nuu.serialize) + } + newType + } else tpe } } - moduleOrder.foreach(n => onModule(moduleMap(n))) - (oldModuleMap.toMap, dedupMap.toMap) + + renameMap.setModule(module.name) + // Change module internals + changeInternals({n => n}, retype, {i => i}, renameModule)(module) } - def run(c: Circuit, noDedups: Seq[String]): (Circuit, RenameMap) = { - val moduleOrder = buildModuleOrder(c) - val moduleMap = c.modules.map(m => m.name -> m).toMap + /** + * Deduplicate + * @param circuit Circuit + * @param noDedups list of modules to not dedup + * @param renameMap rename map to populate when deduping + * @return Map of original Module name -> Deduped Module + */ + def deduplicate(circuit: Circuit, + noDedups: Set[String], + renameMap: RenameMap): Map[String, DefModule] = { - val (oldModuleMap, dedupMap) = findDups(moduleOrder, moduleMap, noDedups) + // Order of modules, from leaf to top + val moduleLinearization = new InstanceGraph(circuit).moduleOrder.map(_.name).reverse - // Use old module list to preserve ordering - val dedupedModules = c.modules.flatMap(m => oldModuleMap.get(m.name).map(_.head)) + // Maps module name to original module + val moduleMap = circuit.modules.map(m => m.name -> m).toMap - val cname = CircuitName(c.main) - val renameMap = RenameMap(dedupMap.map { case (from, to) => - logger.debug(s"[Dedup] $from -> $to") - ModuleName(from, cname) -> List(ModuleName(to, cname)) - }) + // Maps a module's tag to its deduplicated module + val tag2name = mutable.HashMap.empty[String, String] + + // Maps a module's name to its tag + val name2tag = mutable.HashMap.empty[String, String] + + // Maps a tag to all matching module names + val tag2all = mutable.HashMap.empty[String, mutable.Set[String]] + + // Build dedupMap + moduleLinearization.foreach { moduleName => + // Get original module + val originalModule = moduleMap(moduleName) + + // Replace instance references to new deduped modules + val dontcare = RenameMap() + dontcare.setCircuit("dontcare") + //val fixedModule = DedupModules.dedupInstances(originalModule, tag2module, name2tag, name2module, dontcare) + + if (noDedups.contains(originalModule.name)) { + // Don't dedup. Set dedup module to be the same as fixed module + name2tag(originalModule.name) = originalModule.name + tag2name(originalModule.name) = originalModule.name + //templateModules += originalModule.name + } else { // Try to dedup + + // Build name-agnostic module + val agnosticModule = DedupModules.agnostify(originalModule, name2tag, tag2name) + + // Build tag + val tag = (agnosticModule match { + case Module(i, n, ps, b) => + ps.map(_.serialize).mkString + b.serialize + case ExtModule(i, n, ps, dn, p) => + ps.map(_.serialize).mkString + dn + p.map(_.serialize).mkString + }).hashCode().toString + + // Match old module name to its tag + name2tag(originalModule.name) = tag + + // Set tag's module to be the first matching module + if (!tag2name.contains(tag)) { + tag2name(tag) = originalModule.name + tag2all(tag) = mutable.Set(originalModule.name) + } else { + tag2all(tag) += originalModule.name + } + } + } + + + // Set tag2name to be the best dedup module name + val moduleIndex = circuit.modules.zipWithIndex.map{case (m, i) => m.name -> i}.toMap + def order(l: String, r: String): String = if (moduleIndex(l) < moduleIndex(r)) l else r + tag2all.foreach { case (tag, all) => tag2name(tag) = all.reduce(order)} - (c.copy(modules = dedupedModules), renameMap) + // Create map from original to dedup name + val name2name = name2tag.map({ case (name, tag) => name -> tag2name(tag) }) + + // Build Remap for modules with deduped module references + val tag2module = tag2name.map({ case (tag, name) => tag -> DedupModules.dedupInstances(name, moduleMap, name2name, renameMap) }) + + // Build map from original name to corresponding deduped module + val name2module = name2tag.map({ case (name, tag) => name -> tag2module(tag) }) + + name2module.toMap } - def execute(state: CircuitState): CircuitState = { - val noDedups = state.annotations.collect { case NoDedupAnnotation(ModuleName(m, c)) => m } - val (newC, renameMap) = run(state.circuit, noDedups) - state.copy(circuit = newC, renames = Some(renameMap)) + def getAffectedExpressions(root: Expression): Seq[Expression] = { + val all = mutable.ArrayBuffer[Expression]() + + def onExp(expr: Expression): Unit = { + expr.tpe match { + case _: GroundType => + case b: BundleType => b.fields.foreach { f => onExp(WSubField(expr, f.name, f.tpe)) } + case v: VectorType => (0 until v.size).foreach { i => onExp(WSubIndex(expr, i, v.tpe, UNKNOWNGENDER)) } + } + all += expr + } + + onExp(root) + all } } diff --git a/src/main/scala/firrtl/transforms/GroupComponents.scala b/src/main/scala/firrtl/transforms/GroupComponents.scala new file mode 100644 index 00000000..43053e3d --- /dev/null +++ b/src/main/scala/firrtl/transforms/GroupComponents.scala @@ -0,0 +1,346 @@ +package firrtl.transforms + +import firrtl._ +import firrtl.Mappers._ +import firrtl.ir._ +import firrtl.annotations.{Annotation, ComponentName} +import firrtl.passes.{InferTypes, LowerTypes, MemPortUtils} +import firrtl.Utils.{kind, throwInternalError} +import firrtl.graph.{DiGraph, MutableDiGraph} + +import scala.collection.mutable + + +/** + * Specifies a group of components, within a module, to pull out into their own module + * Components that are only connected to a group's components will also be included + * + * @param components components in this group + * @param newModule suggested name of the new module + * @param newInstance suggested name of the instance of the new module + * @param outputSuffix suggested suffix of any output ports of the new module + * @param inputSuffix suggested suffix of any input ports of the new module + */ +case class GroupAnnotation(components: Seq[ComponentName], newModule: String, newInstance: String, outputSuffix: Option[String] = None, inputSuffix: Option[String] = None) extends Annotation { + if(components.nonEmpty) { + require(components.forall(_.module == components.head.module), "All components must be in the same module.") + require(components.forall(!_.name.contains('.')), "No components can be a subcomponent.") + } + + /** + * The module that all components are located in + * @return + */ + def currentModule: String = components.head.module.name + + /* Only keeps components renamed to components */ + def update(renames: RenameMap): Seq[Annotation] = { + val newComponents = components.flatMap{c => renames.get(c).getOrElse(Seq(c))}.collect { + case c: ComponentName => c + } + Seq(GroupAnnotation(newComponents, newModule, newInstance, outputSuffix, inputSuffix)) + } +} + +/** + * Splits a module into multiple modules by grouping its components via [[GroupAnnotation]]'s + */ +class GroupComponents extends firrtl.Transform { + type MSet[T] = mutable.Set[T] + + def inputForm: CircuitForm = MidForm + def outputForm: CircuitForm = MidForm + + override def execute(state: CircuitState): CircuitState = { + val groups = state.annotations.collect {case g: GroupAnnotation => g} + val module2group = groups.groupBy(_.currentModule) + val mnamespace = Namespace(state.circuit) + val newModules = state.circuit.modules.flatMap { + case m: Module if module2group.contains(m.name) => + // do stuff + groupModule(m, module2group(m.name).filter(_.components.nonEmpty), mnamespace) + case other => Seq(other) + } + val cs = state.copy(circuit = state.circuit.copy(modules = newModules)) + val csx = InferTypes.execute(cs) + csx + } + + def groupModule(m: Module, groups: Seq[GroupAnnotation], mnamespace: Namespace): Seq[Module] = { + val namespace = Namespace(m) + val groupRoots = groups.map(_.components.map(_.name)) + val totalSum = groupRoots.map(_.size).sum + val union = groupRoots.foldLeft(Set.empty[String]){(all, set) => all.union(set.toSet)} + + require(groupRoots.forall{_.forall{namespace.contains}}, "All names should be in this module") + require(totalSum == union.size, "No name can be in more than one group") + require(groupRoots.forall(_.nonEmpty), "All groupRoots must by non-empty") + + + // Order of groups, according to their label. The label is the first root in the group + val labelOrder = groups.collect({ case g: GroupAnnotation => g.components.head.name }) + + // Annotations, by label + val label2annotation = groups.collect({ case g: GroupAnnotation => g.components.head.name -> g }).toMap + + // Group roots, by label + // The label "" indicates the original module, and components belonging to that group will remain + // in the original module (not get moved into a new module) + val label2group: Map[String, MSet[String]] = groups.collect{ + case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name):_*) + }.toMap + ("" -> mutable.Set("")) + + // Name of new module containing each group, by label + val label2module: Map[String, String] = + groups.map(a => a.components.head.name -> mnamespace.newName(a.newModule)).toMap + + // Name of instance of new module, by label + val label2instance: Map[String, String] = + groups.map(a => a.components.head.name -> namespace.newName(a.newInstance)).toMap + + // Build set of components not in set + val notSet = label2group.map { case (key, value) => key -> union.diff(value) } + + + // Get all dependencies between components + val deps = getComponentConnectivity(m) + + // For each node not in the set, which group (by label) can reach it + val reachableNodes = new mutable.HashMap[String, MSet[String]]() + + // For each group (by label), add connectivity between nodes in set + // Populate reachableNodes with reachability, where blacklist is their notSet + label2group.foreach { case (label, set) => + set.foreach { x => + deps.addPairWithEdge(label, x) + } + deps.reachableFrom(label, notSet(label)) foreach { node => + reachableNodes.getOrElseUpdate(node, mutable.Set.empty[String]) += label + } + } + + // Add nodes who are reached by a single group, to that group + reachableNodes.foreach { case (node, membership) => + if(membership.size == 1) { + label2group(membership.head) += node + } else { + label2group("") += node + } + } + + applyGrouping(m, labelOrder, label2group, label2module, label2instance, label2annotation) + } + + /** + * Applies datastructures to a module, to group its components into distinct modules + * @param m module to split apart + * @param labelOrder order of groups in SeqAnnotation, to make the grouping more deterministic + * @param label2group group components, by label + * @param label2module module name, by label + * @param label2instance instance name of the group's module, by label + * @param label2annotation annotation specifying the group, by label + * @return new modules, including each group's module and the new split module + */ + def applyGrouping( m: Module, + labelOrder: Seq[String], + label2group: Map[String, MSet[String]], + label2module: Map[String, String], + label2instance: Map[String, String], + label2annotation: Map[String, GroupAnnotation] + ): Seq[Module] = { + // Maps node to group + val byNode = mutable.HashMap[String, String]() + label2group.foreach { case (group, nodes) => + nodes.foreach { node => + byNode(node) = group + } + } + val groupNamespace = label2group.map { case (head, set) => head -> Namespace(set.toSeq) } + + val groupStatements = mutable.HashMap[String, mutable.ArrayBuffer[Statement]]() + val groupPorts = mutable.HashMap[String, mutable.ArrayBuffer[Port]]() + val groupPortNames = mutable.HashMap[String, mutable.HashMap[String, String]]() + label2group.keys.foreach { group => + groupStatements(group) = new mutable.ArrayBuffer[Statement]() + groupPorts(group) = new mutable.ArrayBuffer[Port]() + groupPortNames(group) = new mutable.HashMap[String, String]() + } + + def addPort(group: String, exp: Expression, d: Direction): String = { + val source = LowerTypes.loweredName(exp) + val portNames = groupPortNames(group) + val suffix = d match { + case Output => label2annotation(group).outputSuffix.getOrElse("") + case Input => label2annotation(group).inputSuffix.getOrElse("") + } + val newName = groupNamespace(group).newName(source + suffix) + val portName = portNames.getOrElseUpdate(source, newName) + groupPorts(group) += Port(NoInfo, portName, d, exp.tpe) + portName + } + + def punchSignalOut(group: String, exp: Expression): String = { + val portName = addPort(group, exp, Output) + groupStatements(group) += Connect(NoInfo, WRef(portName), exp) + portName + } + + // Given the sink is in a group, tidy up source references + def inGroupFixExps(group: String, added: mutable.ArrayBuffer[Statement])(e: Expression): Expression = e match { + case _: Literal => e + case _: DoPrim | _: Mux | _: ValidIf => e map inGroupFixExps(group, added) + case otherExp: Expression => + val wref = getWRef(otherExp) + val source = wref.name + byNode(source) match { + // case 1: source in the same group as sink + case `group` => otherExp //do nothing + + // case 2: source in top + case "" => + // Add port to group's Module + val toPort = addPort(group, otherExp, Input) + + // Add connection in Top to group's Module port + added += Connect(NoInfo, WSubField(WRef(label2instance(group)), toPort), otherExp) + + // Return WRef with new kind (its inside the group Module now) + WRef(toPort, otherExp.tpe, PortKind, MALE) + + // case 3: source in different group + case otherGroup => + // Add port to otherGroup's Module + val fromPort = punchSignalOut(otherGroup, otherExp) + val toPort = addPort(group, otherExp, Input) + + // Add connection in Top from otherGroup's port to group's port + val groupInst = label2instance(group) + val otherInst = label2instance(otherGroup) + added += Connect(NoInfo, WSubField(WRef(groupInst), toPort), WSubField(WRef(otherInst), fromPort)) + + // Return WRef with new kind (its inside the group Module now) + WRef(toPort, otherExp.tpe, PortKind, MALE) + } + } + + // Given the sink is in the parent module, tidy up source references belonging to groups + def inTopFixExps(e: Expression): Expression = e match { + case _: DoPrim | _: Mux | _: ValidIf => e map inTopFixExps + case otherExp: Expression => + val wref = getWRef(otherExp) + if(byNode(wref.name) != "") { + // Get the name of source's group + val otherGroup = byNode(wref.name) + + // Add port to otherGroup's Module + val otherPortName = punchSignalOut(otherGroup, otherExp) + + // Return WSubField (its inside the top Module still) + WSubField(WRef(label2instance(otherGroup)), otherPortName) + + } else otherExp + } + + def onStmt(s: Statement): Statement = { + s match { + // Sink is in a group + case r: IsDeclaration if byNode(r.name) != "" => + val topStmts = mutable.ArrayBuffer[Statement]() + val group = byNode(r.name) + groupStatements(group) += r mapExpr inGroupFixExps(group, topStmts) + Block(topStmts) + case c: Connect if byNode(getWRef(c.loc).name) != "" => + // Sink is in a group + val topStmts = mutable.ArrayBuffer[Statement]() + val group = byNode(getWRef(c.loc).name) + groupStatements(group) += Connect(c.info, c.loc, inGroupFixExps(group, topStmts)(c.expr)) + Block(topStmts) + // TODO Attach if all are in a group? + case _: IsDeclaration | _: Connect | _: Attach => + // Sink is in Top + val ret = s mapExpr inTopFixExps + ret + case other => other map onStmt + } + } + + + // Build datastructures + val newTopBody = Block(labelOrder.map(g => WDefInstance(NoInfo, label2instance(g), label2module(g), UnknownType)) ++ Seq(onStmt(m.body))) + val finalTopBody = Block(Utils.squashEmpty(newTopBody).asInstanceOf[Block].stmts.distinct) + + // For all group labels (not including the original module label), return a new Module. + val newModules = labelOrder.filter(_ != "") map { group => + Module(NoInfo, label2module(group), groupPorts(group).distinct, Block(groupStatements(group).distinct)) + } + Seq(m.copy(body = finalTopBody)) ++ newModules + } + + def getWRef(e: Expression): WRef = e match { + case w: WRef => w + case other => + var w = WRef("") + other mapExpr { e => w = getWRef(e); e} + w + } + + /** + * Compute how each component connects to each other component + * It is non-directioned; there is an edge from source to sink and from sink to souce + * @param m module to compute connectivity + * @return a bi-directional representation of component connectivity + */ + def getComponentConnectivity(m: Module): MutableDiGraph[String] = { + val bidirGraph = new MutableDiGraph[String] + val simNamespace = Namespace() + val simulations = new mutable.HashMap[String, Statement] + def onExpr(sink: WRef)(e: Expression): Expression = e match { + case w @ WRef(name, _, _, _) => + bidirGraph.addPairWithEdge(sink.name, name) + bidirGraph.addPairWithEdge(name, sink.name) + w + case other => other map onExpr(sink) + } + def onStmt(stmt: Statement): Unit = stmt match { + case w: WDefInstance => + case h: IsDeclaration => h map onExpr(WRef(h.name)) + case Attach(_, exprs) => // Add edge between each expression + exprs.tail map onExpr(getWRef(exprs.head)) + case Connect(_, loc, expr) => + onExpr(getWRef(loc))(expr) + case q @ Stop(_,_, clk, en) => + val simName = simNamespace.newTemp + simulations(simName) = q + Seq(clk, en) map onExpr(WRef(simName)) + case q @ Print(_, _, args, clk, en) => + val simName = simNamespace.newTemp + simulations(simName) = q + (args :+ clk :+ en) map onExpr(WRef(simName)) + case Block(stmts) => stmts.foreach(onStmt) + case ignore @ (_: IsInvalid | EmptyStmt) => // do nothing + case other => throw new Exception(s"Unexpected Statement $other") + } + + onStmt(m.body) + m.ports.foreach { p => + bidirGraph.addPairWithEdge("", p.name) + bidirGraph.addPairWithEdge(p.name, "") + } + bidirGraph + } +} + +/** + * Splits a module into multiple modules by grouping its components via [[GroupAnnotation]]'s + * Tries to deduplicate the resulting circuit + */ +class GroupAndDedup extends Transform { + def inputForm: CircuitForm = MidForm + def outputForm: CircuitForm = MidForm + + override def execute(state: CircuitState): CircuitState = { + val cs = new GroupComponents().execute(state) + val csx = new DedupModules().execute(cs) + csx + } +} diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala index cf92ec1c..c9c609df 100644 --- a/src/test/scala/firrtlTests/AttachSpec.scala +++ b/src/test/scala/firrtlTests/AttachSpec.scala @@ -48,7 +48,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec { |); |endmodule |""".stripMargin.split("\n") map normalized - executeTest(input, check, compiler) + executeTest(input, check, compiler, Seq(dontDedup("A"), dontDedup("B"))) } it should "attach two instances" in { diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala index 847643ef..6f94275e 100644 --- a/src/test/scala/firrtlTests/PassTests.scala +++ b/src/test/scala/firrtlTests/PassTests.scala @@ -20,13 +20,14 @@ abstract class SimpleTransformSpec extends FlatSpec with FirrtlMatchers with Com // Executes the test. Call in tests. // annotations cannot have default value because scalatest trait Suite has a default value - def execute(input: String, check: String, annotations: Seq[Annotation]): Unit = { + def execute(input: String, check: String, annotations: Seq[Annotation]): CircuitState = { val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize val expected = parse(check).serialize logger.debug(actual) logger.debug(expected) (actual) should be (expected) + finalState } // Executes the test, should throw an error // No default to be consistent with execute diff --git a/src/test/scala/firrtlTests/transforms/DedupTests.scala b/src/test/scala/firrtlTests/transforms/DedupTests.scala index e88bd506..6fab902e 100644 --- a/src/test/scala/firrtlTests/transforms/DedupTests.scala +++ b/src/test/scala/firrtlTests/transforms/DedupTests.scala @@ -3,150 +3,210 @@ package firrtlTests package transform -import org.scalatest.FlatSpec -import org.scalatest.Matchers -import org.scalatest.junit.JUnitRunner - -import firrtl.ir.Circuit -import firrtl.Parser -import firrtl.passes.PassExceptions -import firrtl.annotations.{ - Named, - CircuitName, - ModuleName, - Annotation -} -import firrtl.transforms.{DedupModules, NoDedupAnnotation} +import firrtl.annotations._ +import firrtl.transforms.{DedupModules} /** * Tests inline instances transformation */ class DedupModuleTests extends HighTransformSpec { - def transform = new DedupModules - "The module A" should "be deduped" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A - | module A : - | output x: UInt<1> - | x <= UInt(1) - """.stripMargin - execute(input, check, Seq.empty) - } - "The module A and B" should "be deduped" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | inst b of B - | x <= b.x - | module A_ : - | output x: UInt<1> - | inst b of B_ - | x <= b.x - | module B : - | output x: UInt<1> - | x <= UInt(1) - | module B_ : - | output x: UInt<1> - | x <= UInt(1) - """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A - | module A : - | output x: UInt<1> - | inst b of B - | x <= b.x - | module B : - | output x: UInt<1> - | x <= UInt(1) - """.stripMargin - execute(input, check, Seq.empty) - } - "The module A and B with comments" should "be deduped" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B @[yy 2:2] - | x <= b.x @[yy 2:2] - | module A_ : @[xx 1:1] - | output x: UInt<1> @[xx 1:1] - | inst b of B_ @[xx 1:1] - | x <= b.x @[xx 1:1] - | module B : - | output x: UInt<1> - | x <= UInt(1) - | module B_ : - | output x: UInt<1> - | x <= UInt(1) - """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B @[yy 2:2] - | x <= b.x @[yy 2:2] - | module B : - | output x: UInt<1> - | x <= UInt(1) - """.stripMargin - execute(input, check, Seq.empty) - } - "The module B, but not A, with comments" should "be deduped if not annotated" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | x <= UInt(1) - | module A_ : @[xx 1:1] - | output x: UInt<1> @[xx 1:1] - | x <= UInt(1) - """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | x <= UInt(1) - | module A_ : @[xx 1:1] - | output x: UInt<1> @[xx 1:1] - | x <= UInt(1) - """.stripMargin - execute(input, check, Seq(dontDedup("A"))) - } + def transform = new DedupModules + "The module A" should "be deduped" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + execute(input, check, Seq.empty) + } + "The module A and B" should "be deduped" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module A_ : + | output x: UInt<1> + | inst b of B_ + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + | module B_ : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + execute(input, check, Seq.empty) + } + "The module A and B with comments" should "be deduped" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B @[yy 2:2] + | x <= b.x @[yy 2:2] + | module A_ : @[xx 1:1] + | output x: UInt<1> @[xx 1:1] + | inst b of B_ @[xx 1:1] + | x <= b.x @[xx 1:1] + | module B : + | output x: UInt<1> + | x <= UInt(1) + | module B_ : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B @[yy 2:2] + | x <= b.x @[yy 2:2] + | module B : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + execute(input, check, Seq.empty) + } + "A_ but not A" should "be deduped if not annotated" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) + | module A_ : @[xx 1:1] + | output x: UInt<1> @[xx 1:1] + | x <= UInt(1) + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) + | module A_ : @[xx 1:1] + | output x: UInt<1> @[xx 1:1] + | x <= UInt(1) + """.stripMargin + execute(input, check, Seq(dontDedup("A"))) + } + "The module A and A_" should "be deduped even with different port names and info, and annotations should remap" in { + val input = + """circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A_ + | out <= and(a1.x, a2.y) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) + | module A_ : @[xx 1:1] + | output y: UInt<1> @[xx 1:1] + | y <= UInt(1) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A + | out <= and(a1.x, a2.x) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) + """.stripMargin + case class DummyAnnotation(target: ComponentName) extends SingleTargetAnnotation[ComponentName] { + override def duplicate(n: ComponentName): Annotation = DummyAnnotation(n) + } + + val mname = ModuleName("Top", CircuitName("Top")) + val finalState = execute(input, check, Seq(DummyAnnotation(ComponentName("a2.y", mname)))) + + finalState.annotations.collect({ case d: DummyAnnotation => d }).head should be(DummyAnnotation(ComponentName("a2.x", mname))) + + } + "The module A and B" should "be deduped with the first module in order" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | inst b of B_ + | x <= b.x + | module A_ : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + | module B_ : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + execute(input, check, Seq.empty) + } } // Execution driven tests for inlining modules diff --git a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala new file mode 100644 index 00000000..3a32ec71 --- /dev/null +++ b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala @@ -0,0 +1,290 @@ +package firrtlTests +package transforms + +import firrtl.annotations.{CircuitName, ComponentName, ModuleName} +import firrtl.transforms.{GroupAnnotation, GroupComponents} + +class GroupComponentsSpec extends LowTransformSpec { + def transform = new GroupComponents() + val top = "Top" + def topComp(name: String): ComponentName = ComponentName(name, ModuleName(top, CircuitName(top))) + "The register r" should "be grouped" in { + val input = + s"""circuit $top : + | module $top : + | input clk: Clock + | input data: UInt<16> + | output out: UInt<16> + | reg r: UInt<16>, clk + | r <= data + | out <= r + """.stripMargin + val groups = Seq( + GroupAnnotation(Seq(topComp("r")), "MyReg", "rInst", Some("_OUT"), Some("_IN")) + ) + val check = + s"""circuit Top : + | module $top : + | input clk: Clock + | input data: UInt<16> + | output out: UInt<16> + | inst rInst of MyReg + | rInst.clk_IN <= clk + | out <= rInst.r_OUT + | rInst.data_IN <= data + | module MyReg : + | input clk_IN: Clock + | output r_OUT: UInt<16> + | input data_IN: UInt<16> + | reg r: UInt<16>, clk_IN + | r_OUT <= r + | r <= data_IN + """.stripMargin + execute(input, check, groups) + } + + "The two sets of instances" should "be grouped" in { + val input = + s"""circuit $top : + | module $top : + | output out: UInt<16> + | inst c1a of Const1A + | inst c2a of Const2A + | inst c1b of Const1B + | inst c2b of Const2B + | node asum = add(c1a.out, c2a.out) + | node bsum = add(c1b.out, c2b.out) + | out <= add(asum, bsum) + | module Const1A : + | output out: UInt<8> + | out <= UInt(1) + | module Const2A : + | output out: UInt<8> + | out <= UInt(2) + | module Const1B : + | output out: UInt<8> + | out <= UInt(1) + | module Const2B : + | output out: UInt<8> + | out <= UInt(2) + """.stripMargin + val groups = Seq( + GroupAnnotation(Seq(topComp("c1a"), topComp("c2a")/*, topComp("asum")*/), "A", "cA", Some("_OUT"), Some("_IN")), + GroupAnnotation(Seq(topComp("c1b"), topComp("c2b")/*, topComp("bsum")*/), "B", "cB", Some("_OUT"), Some("_IN")) + ) + val check = + s"""circuit Top : + | module $top : + | output out: UInt<16> + | inst cA of A + | inst cB of B + | node asum = add(cA.c1a_out_OUT, cA.c2a_out_OUT) + | node bsum = add(cB.c1b_out_OUT, cB.c2b_out_OUT) + | out <= add(asum, bsum) + | module A : + | output c1a_out_OUT: UInt<8> + | output c2a_out_OUT: UInt<8> + | inst c1a of Const1A + | inst c2a of Const2A + | c1a_out_OUT <= c1a.out + | c2a_out_OUT <= c2a.out + | module B : + | output c1b_out_OUT: UInt<8> + | output c2b_out_OUT: UInt<8> + | inst c1b of Const1B + | inst c2b of Const2B + | c1b_out_OUT <= c1b.out + | c2b_out_OUT <= c2b.out + | module Const1A : + | output out: UInt<8> + | out <= UInt(1) + | module Const2A : + | output out: UInt<8> + | out <= UInt(2) + | module Const1B : + | output out: UInt<8> + | out <= UInt(1) + | module Const2B : + | output out: UInt<8> + | out <= UInt(2) + """.stripMargin + execute(input, check, groups) + } + "The two sets of instances" should "be grouped with their nodes" in { + val input = + s"""circuit $top : + | module $top : + | output out: UInt<16> + | inst c1a of Const1A + | inst c2a of Const2A + | inst c1b of Const1B + | inst c2b of Const2B + | node asum = add(c1a.out, c2a.out) + | node bsum = add(c1b.out, c2b.out) + | out <= add(asum, bsum) + | module Const1A : + | output out: UInt<8> + | out <= UInt(1) + | module Const2A : + | output out: UInt<8> + | out <= UInt(2) + | module Const1B : + | output out: UInt<8> + | out <= UInt(1) + | module Const2B : + | output out: UInt<8> + | out <= UInt(2) + """.stripMargin + val groups = Seq( + GroupAnnotation(Seq(topComp("c1a"), topComp("c2a"), topComp("asum")), "A", "cA", Some("_OUT"), Some("_IN")), + GroupAnnotation(Seq(topComp("c1b"), topComp("c2b"), topComp("bsum")), "B", "cB", Some("_OUT"), Some("_IN")) + ) + val check = + s"""circuit Top : + | module $top : + | output out: UInt<16> + | inst cA of A + | inst cB of B + | out <= add(cA.asum_OUT, cB.bsum_OUT) + | module A : + | output asum_OUT: UInt<9> + | inst c1a of Const1A + | inst c2a of Const2A + | node asum = add(c1a.out, c2a.out) + | asum_OUT <= asum + | module B : + | output bsum_OUT: UInt<9> + | inst c1b of Const1B + | inst c2b of Const2B + | node bsum = add(c1b.out, c2b.out) + | bsum_OUT <= bsum + | module Const1A : + | output out: UInt<8> + | out <= UInt(1) + | module Const2A : + | output out: UInt<8> + | out <= UInt(2) + | module Const1B : + | output out: UInt<8> + | out <= UInt(1) + | module Const2B : + | output out: UInt<8> + | out <= UInt(2) + """.stripMargin + execute(input, check, groups) + } + + "The two sets of instances" should "be grouped with one not grouped" in { + val input = + s"""circuit $top : + | module $top : + | output out: UInt<16> + | inst c1a of Const1A + | inst c2a of Const2A + | inst c1b of Const1B + | inst c2b of Const2B + | node asum = add(c1a.out, c2a.out) + | node bsum = add(c1b.out, c2b.out) + | inst pass of PassThrough + | pass.in <= add(asum, bsum) + | out <= pass.out + | module Const1A : + | output out: UInt<8> + | out <= UInt(1) + | module Const2A : + | output out: UInt<8> + | out <= UInt(2) + | module Const1B : + | output out: UInt<8> + | out <= UInt(1) + | module Const2B : + | output out: UInt<8> + | out <= UInt(2) + | module PassThrough : + | input in: UInt + | output out: UInt + | out <= in + """.stripMargin + val groups = Seq( + GroupAnnotation(Seq(topComp("c1a"), topComp("c2a"), topComp("asum")), "A", "cA", Some("_OUT"), Some("_IN")), + GroupAnnotation(Seq(topComp("c1b"), topComp("c2b"), topComp("bsum")), "B", "cB", Some("_OUT"), Some("_IN")) + ) + val check = + s"""circuit Top : + | module $top : + | output out: UInt<16> + | inst cA of A + | inst cB of B + | inst pass of PassThrough + | out <= pass.out + | pass.in <= add(cA.asum_OUT, cB.bsum_OUT) + | module A : + | output asum_OUT: UInt<9> + | inst c1a of Const1A + | inst c2a of Const2A + | node asum = add(c1a.out, c2a.out) + | asum_OUT <= asum + | module B : + | output bsum_OUT: UInt<9> + | inst c1b of Const1B + | inst c2b of Const2B + | node bsum = add(c1b.out, c2b.out) + | bsum_OUT <= bsum + | module Const1A : + | output out: UInt<8> + | out <= UInt(1) + | module Const2A : + | output out: UInt<8> + | out <= UInt(2) + | module Const1B : + | output out: UInt<8> + | out <= UInt(1) + | module Const2B : + | output out: UInt<8> + | out <= UInt(2) + | module PassThrough : + | input in: UInt<10> + | output out: UInt<10> + | out <= in + """.stripMargin + execute(input, check, groups) + } + + "The two sets of instances" should "be grouped with a connection between them" in { + val input = + s"""circuit $top : + | module $top : + | input in: UInt<16> + | output out: UInt<16> + | node first = in + | node second = not(first) + | out <= second + """.stripMargin + val groups = Seq( + GroupAnnotation(Seq(topComp("first")), "First", "first"), + GroupAnnotation(Seq(topComp("second")), "Second", "second") + ) + val check = + s"""circuit $top : + | module $top : + | input in: UInt<16> + | output out: UInt<16> + | inst first_0 of First + | inst second_0 of Second + | first_0.in <= in + | second_0.first <= first_0.first_0 + | out <= second_0.second_0 + | module First : + | input in: UInt<16> + | output first_0: UInt<16> + | node first = in + | first_0 <= first + | module Second : + | input first: UInt<16> + | output second_0: UInt<16> + | node second = not(first) + | second_0 <= second + """.stripMargin + execute(input, check, groups) + } +} |
