aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/Dedup.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/transforms/Dedup.scala')
-rw-r--r--src/main/scala/firrtl/transforms/Dedup.scala377
1 files changed, 270 insertions, 107 deletions
diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala
index f22415f0..91c82395 100644
--- a/src/main/scala/firrtl/transforms/Dedup.scala
+++ b/src/main/scala/firrtl/transforms/Dedup.scala
@@ -5,8 +5,9 @@ package transforms
import firrtl.ir._
import firrtl.Mappers._
+import firrtl.analyses.InstanceGraph
import firrtl.annotations._
-import firrtl.passes.PassException
+import firrtl.passes.{InferTypes, MemPortUtils}
// Datastructures
import scala.collection.mutable
@@ -18,134 +19,296 @@ case class NoDedupAnnotation(target: ModuleName) extends SingleTargetAnnotation[
def duplicate(n: ModuleName) = NoDedupAnnotation(n)
}
-// Only use on legal Firrtl. Specifically, the restriction of
-// instance loops must have been checked, or else this pass can
-// infinitely recurse
+/** Only use on legal Firrtl.
+ *
+ * Specifically, the restriction of instance loops must have been checked, or else this pass can
+ * infinitely recurse
+ */
class DedupModules extends Transform {
- def inputForm = HighForm
- def outputForm = HighForm
- // Orders the modules of a circuit from leaves to root
- // A module will appear *after* all modules it instantiates
- private def buildModuleOrder(c: Circuit): Seq[String] = {
- val moduleOrder = mutable.ArrayBuffer.empty[String]
- def hasInstance(b: Statement): Boolean = {
- var has = false
- def onStmt(s: Statement): Statement = s map onStmt match {
- case DefInstance(i, n, m) =>
- if(!(moduleOrder contains m)) has = true
- s
- case WDefInstance(i, n, m, t) =>
- if(!(moduleOrder contains m)) has = true
- s
- case _ => s
- }
- onStmt(b)
- has
+ def inputForm: CircuitForm = HighForm
+ def outputForm: CircuitForm = HighForm
+
+ /**
+ * Deduplicate a Circuit
+ * @param state Input Firrtl AST
+ * @return A transformed Firrtl AST
+ */
+ def execute(state: CircuitState): CircuitState = {
+ val noDedups = state.annotations.collect { case NoDedupAnnotation(ModuleName(m, c)) => m }
+ val (newC, renameMap) = run(state.circuit, noDedups)
+ state.copy(circuit = newC, renames = Some(renameMap))
+ }
+
+ /**
+ * Deduplicates a circuit, and records renaming
+ * @param c Circuit to dedup
+ * @param noDedups Modules not to dedup
+ * @return Deduped Circuit and corresponding RenameMap
+ */
+ def run(c: Circuit, noDedups: Seq[String]): (Circuit, RenameMap) = {
+
+ // RenameMap
+ val renameMap = RenameMap()
+ renameMap.setCircuit(c.main)
+
+ // Maps module name to corresponding dedup module
+ val dedupMap = DedupModules.deduplicate(c, noDedups.toSet, renameMap)
+
+ // Use old module list to preserve ordering
+ val dedupedModules = c.modules.map(m => dedupMap(m.name)).distinct
+
+ val cname = CircuitName(c.main)
+ renameMap.addMap(dedupMap.map { case (from, to) =>
+ logger.debug(s"[Dedup] $from -> ${to.name}")
+ ModuleName(from, cname) -> List(ModuleName(to.name, cname))
+ })
+
+ (InferTypes.run(c.copy(modules = dedupedModules)), renameMap)
+ }
+}
+
+/**
+ * Utility functions for [[DedupModules]]
+ */
+object DedupModules {
+ /**
+ * Change's a module's internal signal names, types, infos, and modules.
+ * @param rename Function to rename a signal. Called on declaration and references.
+ * @param retype Function to retype a signal. Called on declaration, references, and subfields
+ * @param reinfo Function to re-info a statement
+ * @param renameModule Function to rename an instance's module
+ * @param module Module to change internals
+ * @return Changed Module
+ */
+ def changeInternals(rename: String=>String,
+ retype: String=>Type=>Type,
+ reinfo: Info=>Info,
+ renameModule: String=>String
+ )(module: DefModule): DefModule = {
+ def onPort(p: Port): Port = Port(reinfo(p.info), rename(p.name), p.direction, retype(p.name)(p.tpe))
+ def onExp(e: Expression): Expression = e match {
+ case WRef(n, t, k, g) => WRef(rename(n), retype(n)(t), k, g)
+ case WSubField(expr, n, tpe, kind) =>
+ val fieldIndex = expr.tpe.asInstanceOf[BundleType].fields.indexWhere(f => f.name == n)
+ val newExpr = onExp(expr)
+ val newField = newExpr.tpe.asInstanceOf[BundleType].fields(fieldIndex)
+ val finalExpr = WSubField(newExpr, newField.name, newField.tpe, kind)
+ //TODO: renameMap.rename(e.serialize, finalExpr.serialize)
+ finalExpr
+ case other => other map onExp
}
- def addModule(m: DefModule): DefModule = m match {
- case Module(info, n, ps, b) =>
- if (!hasInstance(b)) moduleOrder += m.name
- m
- case e: ExtModule =>
- moduleOrder += m.name
- m
- case _ => m
+ def onStmt(s: Statement): Statement = s match {
+ case WDefInstance(i, n, m, t) =>
+ val newmod = renameModule(m)
+ WDefInstance(reinfo(i), rename(n), newmod, retype(n)(t))
+ case DefInstance(i, n, m) => DefInstance(reinfo(i), rename(n), renameModule(m))
+ case d: DefMemory =>
+ val oldType = MemPortUtils.memType(d)
+ val newType = retype(d.name)(oldType)
+ val index = oldType
+ .asInstanceOf[BundleType].fields.headOption
+ .map(_.tpe.asInstanceOf[BundleType].fields.indexWhere(
+ {
+ case Field("data" | "wdata" | "rdata", _, _) => true
+ case _ => false
+ }))
+ val newDataType = index match {
+ case Some(i) =>
+ //If index nonempty, then there exists a port
+ newType.asInstanceOf[BundleType].fields.head.tpe.asInstanceOf[BundleType].fields(i).tpe
+ case None =>
+ //If index is empty, this mem has no ports, and so we don't need to record the dataType
+ // Thus, call retype with an illegal name, so we can retype the memory's datatype, but not
+ // associate it with the type of the memory (as the memory type is different than the datatype)
+ retype(d.name + ";&*^$")(d.dataType)
+ }
+ d.copy(dataType = newDataType) map rename map reinfo
+ case h: IsDeclaration => h map rename map retype(h.name) map onExp map reinfo
+ case other => other map reinfo map onExp map onStmt
}
-
- while ((moduleOrder.size < c.modules.size)) {
- c.modules.foreach(m => if (!moduleOrder.contains(m.name)) addModule(m))
+ val finalModule = module match {
+ case m: Module => m map onPort map onStmt
+ case other => other
}
- moduleOrder
+ finalModule
}
- // Finds duplicate Modules
- // Also changes DefInstances to instantiate the deduplicated module
- // Returns (Deduped Module name -> Seq of identical modules,
- // Deuplicate Module name -> deduped module name)
- private def findDups(
- moduleOrder: Seq[String],
- moduleMap: Map[String, DefModule],
- noDedups: Seq[String]): (Map[String, Seq[DefModule]], Map[String, String]) = {
- // Module body -> Module name
- val dedupModules = mutable.HashMap.empty[String, String]
- // Old module name -> dup module name
- val dedupMap = mutable.HashMap.empty[String, String]
- // Deduplicated module name -> all identical modules
- val oldModuleMap = mutable.HashMap.empty[String, Seq[DefModule]]
-
- def onModule(m: DefModule): Unit = {
- def fixInstance(s: Statement): Statement = s map fixInstance match {
- case DefInstance(i, n, m) => DefInstance(i, n, dedupMap.getOrElse(m, m))
- case WDefInstance(i, n, m, t) => WDefInstance(i, n, dedupMap.getOrElse(m, m), t)
- case x => x
+ /**
+ * Turns a module into a name-agnostic module
+ * @param module module to change
+ * @return name-agnostic module
+ */
+ def agnostify(module: DefModule, name2tag: mutable.HashMap[String, String], tag2name: mutable.HashMap[String, String]): DefModule = {
+ val namespace = Namespace()
+ val nameMap = mutable.HashMap[String, String]()
+ val typeMap = mutable.HashMap[String, Type]()
+ def rename(name: String): String = {
+ if (nameMap.contains(name)) nameMap(name) else {
+ val newName = namespace.newTemp
+ nameMap(name) = newName
+ newName
}
- def removeInfo(stmt: Statement): Statement = stmt map removeInfo match {
- case sx: HasInfo => sx match {
- case s: DefWire => s.copy(info = NoInfo)
- case s: DefNode => s.copy(info = NoInfo)
- case s: DefRegister => s.copy(info = NoInfo)
- case s: DefInstance => s.copy(info = NoInfo)
- case s: WDefInstance => s.copy(info = NoInfo)
- case s: DefMemory => s.copy(info = NoInfo)
- case s: Connect => s.copy(info = NoInfo)
- case s: PartialConnect => s.copy(info = NoInfo)
- case s: IsInvalid => s.copy(info = NoInfo)
- case s: Attach => s.copy(info = NoInfo)
- case s: Stop => s.copy(info = NoInfo)
- case s: Print => s.copy(info = NoInfo)
- case s: Conditionally => s.copy(info = NoInfo)
+ }
+ def retype(name: String)(tpe: Type): Type = {
+ if (typeMap.contains(name)) typeMap(name) else {
+ def onType(tpe: Type): Type = tpe map onType match {
+ case BundleType(fields) => BundleType(fields.map(f => Field(rename(f.name), f.flip, f.tpe)))
+ case other => other
}
- case sx => sx
+ val newType = onType(tpe)
+ typeMap(name) = newType
+ newType
}
- def removePortInfo(p: Port): Port = p.copy(info = NoInfo)
+ }
+ def remodule(name: String): String = tag2name(name2tag(name))
+ changeInternals(rename, retype, {i: Info => NoInfo}, remodule)(module)
+ }
+ /** Dedup a module's instances based on dedup map
+ *
+ * Will fixes up module if deduped instance's ports are differently named
+ *
+ * @param moduleName Module name who's instances will be deduped
+ * @param moduleMap Map of module name to its original module
+ * @param name2name Map of module name to the module deduping it. Not mutated in this function.
+ * @param renameMap Will be modified to keep track of renames in this function
+ * @return fixed up module deduped instances
+ */
+ def dedupInstances(moduleName: String, moduleMap: Map[String, DefModule], name2name: mutable.Map[String, String], renameMap: RenameMap): DefModule = {
+ val module = moduleMap(moduleName)
- val mx = m map fixInstance
- val mxx = (mx map removeInfo) map removePortInfo
+ // If black box, return it (it has no instances)
+ if (module.isInstanceOf[ExtModule]) return module
- // If shouldn't dedup, just make it fail to be the same to any other modules
- val unique = if (!noDedups.contains(mxx.name)) "" else mxx.name
- val string = mxx match {
- case Module(i, n, ps, b) =>
- ps.map(_.serialize).mkString + b.serialize + unique
- case ExtModule(i, n, ps, dn, p) =>
- ps.map(_.serialize).mkString + dn + p.map(_.serialize).mkString + unique
- }
- dedupModules.get(string) match {
- case Some(dupname) =>
- dedupMap(mx.name) = dupname
- oldModuleMap(dupname) = oldModuleMap(dupname) :+ mx
- case None =>
- dedupModules(string) = mx.name
- oldModuleMap(mx.name) = Seq(mx)
+ // Get all instances to know what to rename in the module
+ val instances = mutable.Set[WDefInstance]()
+ InstanceGraph.collectInstances(instances)(module.asInstanceOf[Module].body)
+ val instanceModuleMap = instances.map(i => i.name -> i.module).toMap
+ val moduleNames = instances.map(_.module)
+
+ def getNewModule(old: String): DefModule = {
+ moduleMap(name2name(old))
+ }
+ // Define rename functions
+ def renameModule(name: String): String = getNewModule(name).name
+ val typeMap = mutable.HashMap[String, Type]()
+ def retype(name: String)(tpe: Type): Type = {
+ if (typeMap.contains(name)) typeMap(name) else {
+ if (instanceModuleMap.contains(name)) {
+ val newType = Utils.module_type(getNewModule(instanceModuleMap(name)))
+ typeMap(name) = newType
+ getAffectedExpressions(WRef(name, tpe)).zip(getAffectedExpressions(WRef(name, newType))).foreach {
+ case (old, nuu) => renameMap.rename(old.serialize, nuu.serialize)
+ }
+ newType
+ } else tpe
}
}
- moduleOrder.foreach(n => onModule(moduleMap(n)))
- (oldModuleMap.toMap, dedupMap.toMap)
+
+ renameMap.setModule(module.name)
+ // Change module internals
+ changeInternals({n => n}, retype, {i => i}, renameModule)(module)
}
- def run(c: Circuit, noDedups: Seq[String]): (Circuit, RenameMap) = {
- val moduleOrder = buildModuleOrder(c)
- val moduleMap = c.modules.map(m => m.name -> m).toMap
+ /**
+ * Deduplicate
+ * @param circuit Circuit
+ * @param noDedups list of modules to not dedup
+ * @param renameMap rename map to populate when deduping
+ * @return Map of original Module name -> Deduped Module
+ */
+ def deduplicate(circuit: Circuit,
+ noDedups: Set[String],
+ renameMap: RenameMap): Map[String, DefModule] = {
- val (oldModuleMap, dedupMap) = findDups(moduleOrder, moduleMap, noDedups)
+ // Order of modules, from leaf to top
+ val moduleLinearization = new InstanceGraph(circuit).moduleOrder.map(_.name).reverse
- // Use old module list to preserve ordering
- val dedupedModules = c.modules.flatMap(m => oldModuleMap.get(m.name).map(_.head))
+ // Maps module name to original module
+ val moduleMap = circuit.modules.map(m => m.name -> m).toMap
- val cname = CircuitName(c.main)
- val renameMap = RenameMap(dedupMap.map { case (from, to) =>
- logger.debug(s"[Dedup] $from -> $to")
- ModuleName(from, cname) -> List(ModuleName(to, cname))
- })
+ // Maps a module's tag to its deduplicated module
+ val tag2name = mutable.HashMap.empty[String, String]
+
+ // Maps a module's name to its tag
+ val name2tag = mutable.HashMap.empty[String, String]
+
+ // Maps a tag to all matching module names
+ val tag2all = mutable.HashMap.empty[String, mutable.Set[String]]
+
+ // Build dedupMap
+ moduleLinearization.foreach { moduleName =>
+ // Get original module
+ val originalModule = moduleMap(moduleName)
+
+ // Replace instance references to new deduped modules
+ val dontcare = RenameMap()
+ dontcare.setCircuit("dontcare")
+ //val fixedModule = DedupModules.dedupInstances(originalModule, tag2module, name2tag, name2module, dontcare)
+
+ if (noDedups.contains(originalModule.name)) {
+ // Don't dedup. Set dedup module to be the same as fixed module
+ name2tag(originalModule.name) = originalModule.name
+ tag2name(originalModule.name) = originalModule.name
+ //templateModules += originalModule.name
+ } else { // Try to dedup
+
+ // Build name-agnostic module
+ val agnosticModule = DedupModules.agnostify(originalModule, name2tag, tag2name)
+
+ // Build tag
+ val tag = (agnosticModule match {
+ case Module(i, n, ps, b) =>
+ ps.map(_.serialize).mkString + b.serialize
+ case ExtModule(i, n, ps, dn, p) =>
+ ps.map(_.serialize).mkString + dn + p.map(_.serialize).mkString
+ }).hashCode().toString
+
+ // Match old module name to its tag
+ name2tag(originalModule.name) = tag
+
+ // Set tag's module to be the first matching module
+ if (!tag2name.contains(tag)) {
+ tag2name(tag) = originalModule.name
+ tag2all(tag) = mutable.Set(originalModule.name)
+ } else {
+ tag2all(tag) += originalModule.name
+ }
+ }
+ }
+
+
+ // Set tag2name to be the best dedup module name
+ val moduleIndex = circuit.modules.zipWithIndex.map{case (m, i) => m.name -> i}.toMap
+ def order(l: String, r: String): String = if (moduleIndex(l) < moduleIndex(r)) l else r
+ tag2all.foreach { case (tag, all) => tag2name(tag) = all.reduce(order)}
- (c.copy(modules = dedupedModules), renameMap)
+ // Create map from original to dedup name
+ val name2name = name2tag.map({ case (name, tag) => name -> tag2name(tag) })
+
+ // Build Remap for modules with deduped module references
+ val tag2module = tag2name.map({ case (tag, name) => tag -> DedupModules.dedupInstances(name, moduleMap, name2name, renameMap) })
+
+ // Build map from original name to corresponding deduped module
+ val name2module = name2tag.map({ case (name, tag) => name -> tag2module(tag) })
+
+ name2module.toMap
}
- def execute(state: CircuitState): CircuitState = {
- val noDedups = state.annotations.collect { case NoDedupAnnotation(ModuleName(m, c)) => m }
- val (newC, renameMap) = run(state.circuit, noDedups)
- state.copy(circuit = newC, renames = Some(renameMap))
+ def getAffectedExpressions(root: Expression): Seq[Expression] = {
+ val all = mutable.ArrayBuffer[Expression]()
+
+ def onExp(expr: Expression): Unit = {
+ expr.tpe match {
+ case _: GroundType =>
+ case b: BundleType => b.fields.foreach { f => onExp(WSubField(expr, f.name, f.tpe)) }
+ case v: VectorType => (0 until v.size).foreach { i => onExp(WSubIndex(expr, i, v.tpe, UNKNOWNGENDER)) }
+ }
+ all += expr
+ }
+
+ onExp(root)
+ all
}
}