diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/transforms/Dedup.scala | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/transforms/Dedup.scala')
| -rw-r--r-- | src/main/scala/firrtl/transforms/Dedup.scala | 297 |
1 files changed, 167 insertions, 130 deletions
diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index 627af11f..18e32cbc 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -20,7 +20,6 @@ import scala.annotation.tailrec // Datastructures import scala.collection.mutable - /** A component, e.g. register etc. Must be declared only once under the TopAnnotation */ case class NoDedupAnnotation(target: ModuleTarget) extends SingleTargetAnnotation[ModuleTarget] { def duplicate(n: ModuleTarget): NoDedupAnnotation = NoDedupAnnotation(n) @@ -36,7 +35,9 @@ case object NoCircuitDedupAnnotation extends NoTargetAnnotation with HasShellOpt new ShellOption[Unit]( longOption = "no-dedup", toAnnotationSeq = _ => Seq(NoCircuitDedupAnnotation), - helpText = "Do NOT dedup modules" ) ) + helpText = "Do NOT dedup modules" + ) + ) } @@ -46,12 +47,13 @@ case object NoCircuitDedupAnnotation extends NoTargetAnnotation with HasShellOpt * @param original Original module * @param index the normalized position of the original module in the original module list, fraction between 0 and 1 */ -case class DedupedResult(original: ModuleTarget, duplicate: Option[IsModule], index: Double) extends MultiTargetAnnotation { +case class DedupedResult(original: ModuleTarget, duplicate: Option[IsModule], index: Double) + extends MultiTargetAnnotation { override val targets: Seq[Seq[Target]] = Seq(Seq(original), duplicate.toList) override def duplicate(n: Seq[Seq[Target]]): Annotation = { n.toList match { case Seq(_, List(dup: IsModule)) => DedupedResult(original, Some(dup), index) - case _ => DedupedResult(original, None, -1) + case _ => DedupedResult(original, None, -1) } } } @@ -96,7 +98,7 @@ class DedupModules extends Transform with DependencyAPIMigration { val noDedups = state.circuit.main +: state.annotations.collect { case NoDedupAnnotation(ModuleTarget(_, m)) => m } val (remainingAnnotations, dupResults) = state.annotations.partition { case _: DupedResult => false - case _ => true + case _ => true } val previouslyDupedMap = dupResults.flatMap { case DupedResult(newModules, original) => @@ -114,9 +116,11 @@ class DedupModules extends Transform with DependencyAPIMigration { * @param noDedups Modules not to dedup * @return Deduped Circuit and corresponding RenameMap */ - def run(c: Circuit, - noDedups: Seq[String], - previouslyDupedMap: Map[String, String]): (Circuit, RenameMap, AnnotationSeq) = { + def run( + c: Circuit, + noDedups: Seq[String], + previouslyDupedMap: Map[String, String] + ): (Circuit, RenameMap, AnnotationSeq) = { // RenameMap val componentRenameMap = RenameMap() @@ -124,13 +128,16 @@ class DedupModules extends Transform with DependencyAPIMigration { // Maps module name to corresponding dedup module val dedupMap = DedupModules.deduplicate(c, noDedups.toSet, previouslyDupedMap, componentRenameMap) - val dedupCliques = dedupMap.foldLeft(Map.empty[String, Set[String]]) { - case (dedupCliqueMap, (orig: String, dupMod: DefModule)) => - val set = dedupCliqueMap.getOrElse(dupMod.name, Set.empty[String]) + dupMod.name + orig - dedupCliqueMap + (dupMod.name -> set) - }.flatMap { case (dedupName, set) => - set.map { _ -> set } - } + val dedupCliques = dedupMap + .foldLeft(Map.empty[String, Set[String]]) { + case (dedupCliqueMap, (orig: String, dupMod: DefModule)) => + val set = dedupCliqueMap.getOrElse(dupMod.name, Set.empty[String]) + dupMod.name + orig + dedupCliqueMap + (dupMod.name -> set) + } + .flatMap { + case (dedupName, set) => + set.map { _ -> set } + } // Use old module list to preserve ordering // Lookup what a module deduped to, if its a duplicate, remove it @@ -149,9 +156,10 @@ class DedupModules extends Transform with DependencyAPIMigration { val ct = CircuitTarget(c.main) - val map = dedupMap.map { case (from, to) => - logger.debug(s"[Dedup] $from -> ${to.name}") - ct.module(from).asInstanceOf[CompleteTarget] -> Seq(ct.module(to.name)) + val map = dedupMap.map { + case (from, to) => + logger.debug(s"[Dedup] $from -> ${to.name}") + ct.module(from).asInstanceOf[CompleteTarget] -> Seq(ct.module(to.name)) } val moduleRenameMap = RenameMap() moduleRenameMap.recordAll(map) @@ -159,15 +167,19 @@ class DedupModules extends Transform with DependencyAPIMigration { // Build instanceify renaming map val instanceGraph = InstanceKeyGraph(c) val instanceify = RenameMap() - val moduleName2Index = c.modules.map(_.name).zipWithIndex.map { case (n, i) => - { - c.modules.size match { - case 0 => (n, 0.0) - case 1 => (n, 1.0) - case d => (n, i.toDouble / (d - 1)) + val moduleName2Index = c.modules + .map(_.name) + .zipWithIndex + .map { + case (n, i) => { + c.modules.size match { + case 0 => (n, 0.0) + case 1 => (n, 1.0) + case d => (n, i.toDouble / (d - 1)) + } } } - }.toMap + .toMap // get the ordered set of instances a module, includes new Deduped modules val getChildrenInstances = { @@ -182,56 +194,62 @@ class DedupModules extends Transform with DependencyAPIMigration { } val instanceNameMap: Map[OfModule, Map[Instance, Instance]] = { - dedupMap.map { case (oldName, dedupedMod) => - val key = OfModule(oldName) - val value = getChildrenInstances(oldName).zip(getChildrenInstances(dedupedMod.name)).map { - case (oldInst, newInst) => Instance(oldInst.name) -> Instance(newInst.name) - }.toMap - key -> value + dedupMap.map { + case (oldName, dedupedMod) => + val key = OfModule(oldName) + val value = getChildrenInstances(oldName) + .zip(getChildrenInstances(dedupedMod.name)) + .map { + case (oldInst, newInst) => Instance(oldInst.name) -> Instance(newInst.name) + } + .toMap + key -> value }.toMap } - val dedupAnnotations = c.modules.map(_.name).map(ct.module).flatMap { case mt@ModuleTarget(c, m) if dedupCliques(m).size > 1 => - dedupMap.get(m) match { - case None => Nil - case Some(module: DefModule) => - val paths = instanceGraph.findInstancesInHierarchy(m) - // If dedupedAnnos is exactly annos, contains is because dedupedAnnos is type Option - val newTargets = paths.map { path => - val root: IsModule = ct.module(c) - path.foldLeft(root -> root) { case ((oldRelPath, newRelPath), InstanceKeyGraph.InstanceKey(name, mod)) => - if(mod == c) { - val mod = CircuitTarget(c).module(c) - mod -> mod - } else { - val enclosingMod = oldRelPath match { - case i: InstanceTarget => i.ofModule - case m: ModuleTarget => m.module - } - val instMap = instanceNameMap(OfModule(enclosingMod)) - val newInstName = instMap(Instance(name)).value - val old = oldRelPath.instOf(name, mod) - old -> newRelPath.instOf(newInstName, mod) + val dedupAnnotations = c.modules.map(_.name).map(ct.module).flatMap { + case mt @ ModuleTarget(c, m) if dedupCliques(m).size > 1 => + dedupMap.get(m) match { + case None => Nil + case Some(module: DefModule) => + val paths = instanceGraph.findInstancesInHierarchy(m) + // If dedupedAnnos is exactly annos, contains is because dedupedAnnos is type Option + val newTargets = paths.map { path => + val root: IsModule = ct.module(c) + path.foldLeft(root -> root) { + case ((oldRelPath, newRelPath), InstanceKeyGraph.InstanceKey(name, mod)) => + if (mod == c) { + val mod = CircuitTarget(c).module(c) + mod -> mod + } else { + val enclosingMod = oldRelPath match { + case i: InstanceTarget => i.ofModule + case m: ModuleTarget => m.module + } + val instMap = instanceNameMap(OfModule(enclosingMod)) + val newInstName = instMap(Instance(name)).value + val old = oldRelPath.instOf(name, mod) + old -> newRelPath.instOf(newInstName, mod) + } } } - } - // Add all relative paths to referredModule to map to new instances - def addRecord(from: IsMember, to: IsMember): Unit = from match { - case x: ModuleTarget => - instanceify.record(x, to) - case x: IsComponent => - instanceify.record(x, to) - addRecord(x.stripHierarchy(1), to) - } - // Instanceify deduped Modules! - if (dedupCliques(module.name).size > 1) { - newTargets.foreach { case (from, to) => addRecord(from, to) } - } - // Return Deduped Results - if (newTargets.size == 1) { - Seq(DedupedResult(mt, newTargets.headOption.map(_._1), moduleName2Index(m))) - } else Nil - } + // Add all relative paths to referredModule to map to new instances + def addRecord(from: IsMember, to: IsMember): Unit = from match { + case x: ModuleTarget => + instanceify.record(x, to) + case x: IsComponent => + instanceify.record(x, to) + addRecord(x.stripHierarchy(1), to) + } + // Instanceify deduped Modules! + if (dedupCliques(module.name).size > 1) { + newTargets.foreach { case (from, to) => addRecord(from, to) } + } + // Return Deduped Results + if (newTargets.size == 1) { + Seq(DedupedResult(mt, newTargets.headOption.map(_._1), moduleName2Index(m))) + } else Nil + } case noDedups => Nil } @@ -242,6 +260,7 @@ class DedupModules extends Transform with DependencyAPIMigration { /** Utility functions for [[DedupModules]] */ object DedupModules extends LazyLogging { + /** Change's a module's internal signal names, types, infos, and modules. * @param rename Function to rename a signal. Called on declaration and references. * @param retype Function to retype a signal. Called on declaration, references, and subfields @@ -250,14 +269,16 @@ object DedupModules extends LazyLogging { * @param module Module to change internals * @return Changed Module */ - def changeInternals(rename: String=>String, - retype: String=>Type=>Type, - reinfo: Info=>Info, - renameOfModule: (String, String)=>String, - renameExps: Boolean = true - )(module: DefModule): DefModule = { + def changeInternals( + rename: String => String, + retype: String => Type => Type, + reinfo: Info => Info, + renameOfModule: (String, String) => String, + renameExps: Boolean = true + )(module: DefModule + ): DefModule = { def onPort(p: Port): Port = Port(reinfo(p.info), rename(p.name), p.direction, retype(p.name)(p.tpe)) - def onExp(e: Expression): Expression = e match { + def onExp(e: Expression): Expression = e match { case WRef(n, t, k, g) => WRef(rename(n), retype(n)(t), k, g) case WSubField(expr, n, tpe, kind) => val fieldIndex = expr.tpe.asInstanceOf[BundleType].fields.indexWhere(f => f.name == n) @@ -266,12 +287,12 @@ object DedupModules extends LazyLogging { val finalExpr = WSubField(newExpr, newField.name, newField.tpe, kind) //TODO: renameMap.rename(e.serialize, finalExpr.serialize) finalExpr - case other => other map onExp + case other => other.map(onExp) } def onStmt(s: Statement): Statement = s match { case DefNode(info, name, value) => retype(name)(value.tpe) - if(renameExps) DefNode(reinfo(info), rename(name), onExp(value)) + if (renameExps) DefNode(reinfo(info), rename(name), onExp(value)) else DefNode(reinfo(info), rename(name), value) case WDefInstance(i, n, m, t) => val newmod = renameOfModule(n, m) @@ -283,12 +304,18 @@ object DedupModules extends LazyLogging { val oldType = MemPortUtils.memType(d) val newType = retype(d.name)(oldType) val index = oldType - .asInstanceOf[BundleType].fields.headOption - .map(_.tpe.asInstanceOf[BundleType].fields.indexWhere( - { - case Field("data" | "wdata" | "rdata", _, _) => true - case _ => false - })) + .asInstanceOf[BundleType] + .fields + .headOption + .map( + _.tpe + .asInstanceOf[BundleType] + .fields + .indexWhere({ + case Field("data" | "wdata" | "rdata", _, _) => true + case _ => false + }) + ) val newDataType = index match { case Some(i) => //If index nonempty, then there exists a port @@ -299,15 +326,15 @@ object DedupModules extends LazyLogging { // associate it with the type of the memory (as the memory type is different than the datatype) retype(d.name + ";&*^$")(d.dataType) } - d.copy(dataType = newDataType) map rename map reinfo + d.copy(dataType = newDataType).map(rename).map(reinfo) case h: IsDeclaration => - val temp = h map rename map retype(h.name) map reinfo - if(renameExps) temp map onExp else temp + val temp = h.map(rename).map(retype(h.name)).map(reinfo) + if (renameExps) temp.map(onExp) else temp case other => - val temp = other map reinfo map onStmt - if(renameExps) temp map onExp else temp + val temp = other.map(reinfo).map(onStmt) + if (renameExps) temp.map(onExp) else temp } - module map onPort map onStmt + module.map(onPort).map(onStmt) } /** Dedup a module's instances based on dedup map @@ -321,11 +348,13 @@ object DedupModules extends LazyLogging { * @param renameMap Will be modified to keep track of renames in this function * @return fixed up module deduped instances */ - def dedupInstances(top: CircuitTarget, - originalModule: String, - moduleMap: Map[String, DefModule], - name2name: Map[String, String], - renameMap: RenameMap): DefModule = { + def dedupInstances( + top: CircuitTarget, + originalModule: String, + moduleMap: Map[String, DefModule], + name2name: Map[String, String], + renameMap: RenameMap + ): DefModule = { val module = moduleMap(originalModule) // If black box, return it (it has no instances) @@ -340,7 +369,8 @@ object DedupModules extends LazyLogging { } val typeMap = mutable.HashMap[String, Type]() def retype(name: String)(tpe: Type): Type = { - if (typeMap.contains(name)) typeMap(name) else { + if (typeMap.contains(name)) typeMap(name) + else { if (instanceModuleMap.contains(name)) { val newType = Utils.module_type(getNewModule(instanceModuleMap(name))) typeMap(name) = newType @@ -360,7 +390,7 @@ object DedupModules extends LazyLogging { def renameOfModule(instance: String, ofModule: String): String = { name2name(ofModule) } - changeInternals({n => n}, retype, {i => i}, renameOfModule)(module) + changeInternals({ n => n }, retype, { i => i }, renameOfModule)(module) } @tailrec @@ -415,10 +445,11 @@ object DedupModules extends LazyLogging { * @return A map from tag to names of modules with the same structure and * a RenameMap which maps Module names to their Tag. */ - def buildRTLTags(top: CircuitTarget, - moduleLinearization: Seq[DefModule], - noDedups: Set[String] - ): (collection.Map[String, collection.Set[String]], RenameMap) = { + def buildRTLTags( + top: CircuitTarget, + moduleLinearization: Seq[DefModule], + noDedups: Set[String] + ): (collection.Map[String, collection.Set[String]], RenameMap) = { // maps hash code to human readable tag val hashToTag = mutable.HashMap[ir.HashCode, String]() @@ -449,9 +480,9 @@ object DedupModules extends LazyLogging { moduleNameToTag(originalModule.name) = hashToTag(hash) } - val tag2all = hashToNames.map{ case (hash, names) => hashToTag(hash) -> names.toSet } + val tag2all = hashToNames.map { case (hash, names) => hashToTag(hash) -> names.toSet } val tagMap = RenameMap() - moduleNameToTag.foreach{ case (name, tag) => tagMap.record(top.module(name), top.module(tag)) } + moduleNameToTag.foreach { case (name, tag) => tagMap.record(top.module(name), top.module(tag)) } (tag2all, tagMap) } @@ -461,10 +492,12 @@ object DedupModules extends LazyLogging { * @param renameMap rename map to populate when deduping * @return Map of original Module name -> Deduped Module */ - def deduplicate(circuit: Circuit, - noDedups: Set[String], - previousDupResults: Map[String, String], - renameMap: RenameMap): Map[String, DefModule] = { + def deduplicate( + circuit: Circuit, + noDedups: Set[String], + previousDupResults: Map[String, String], + renameMap: RenameMap + ): Map[String, DefModule] = { val (moduleMap, moduleLinearization) = { val iGraph = InstanceKeyGraph(circuit) @@ -479,13 +512,14 @@ object DedupModules extends LazyLogging { val (tag2all, tagMap) = buildRTLTags(top, moduleLinearization, noDedups) // Set tag2name to be the best dedup module name - val moduleIndex = circuit.modules.zipWithIndex.map{case (m, i) => m.name -> i}.toMap + val moduleIndex = circuit.modules.zipWithIndex.map { case (m, i) => m.name -> i }.toMap // returns the module matching the circuit name or the module with lower index otherwise def order(l: String, r: String): String = { if (l == main) l else if (r == main) r - else if (moduleIndex(l) < moduleIndex(r)) l else r + else if (moduleIndex(l) < moduleIndex(r)) l + else r } // Maps a module's tag to its deduplicated module @@ -499,7 +533,7 @@ object DedupModules extends LazyLogging { tag2name(tag) = dedupName val dedupModule = moduleMap(dedupWithoutOldName) match { case e: ExtModule => e.copy(name = dedupName) - case e: Module => e.copy(name = dedupName) + case e: Module => e.copy(name = dedupName) } dedupName -> dedupModule }.toMap @@ -508,32 +542,32 @@ object DedupModules extends LazyLogging { val name2name = moduleMap.keysIterator.map { originalModule => tagMap.get(top.module(originalModule)) match { case Some(Seq(Target(_, Some(tag), Nil))) => originalModule -> tag2name(tag) - case None => originalModule -> originalModule - case other => throwInternalError(other.toString) + case None => originalModule -> originalModule + case other => throwInternalError(other.toString) } }.toMap // Build Remap for modules with deduped module references val dedupedName2module = tag2name.map { - case (tag, name) => name -> DedupModules.dedupInstances( - top, name, moduleMapWithOldNames, name2name, renameMap) + case (tag, name) => name -> DedupModules.dedupInstances(top, name, moduleMapWithOldNames, name2name, renameMap) } // Build map from original name to corresponding deduped module // It is important to flatMap before looking up the DefModules so that they aren't hashed val name2module: Map[String, DefModule] = tag2all.flatMap { case (tag, names) => names.map(_ -> tag) } - .mapValues(tag => dedupedName2module(tag2name(tag))) - .toMap + .mapValues(tag => dedupedName2module(tag2name(tag))) + .toMap // Build renameMap val indexedTargets = mutable.HashMap[String, IndexedSeq[ReferenceTarget]]() - name2module.foreach { case (originalName, depModule) => - if(originalName != depModule.name) { - val toSeq = indexedTargets.getOrElseUpdate(depModule.name, computeIndexedNames(circuit.main, depModule)) - val fromSeq = computeIndexedNames(circuit.main, moduleMap(originalName)) - computeRenameMap(fromSeq, toSeq, renameMap) - } + name2module.foreach { + case (originalName, depModule) => + if (originalName != depModule.name) { + val toSeq = indexedTargets.getOrElseUpdate(depModule.name, computeIndexedNames(circuit.main, depModule)) + val fromSeq = computeIndexedNames(circuit.main, moduleMap(originalName)) + computeRenameMap(fromSeq, toSeq, renameMap) + } } name2module @@ -549,18 +583,21 @@ object DedupModules extends LazyLogging { tpe } - changeInternals(rename, retype, {i => i}, {(x, y) => x}, renameExps = false)(m) + changeInternals(rename, retype, { i => i }, { (x, y) => x }, renameExps = false)(m) refs.toIndexedSeq } - def computeRenameMap(originalNames: IndexedSeq[ReferenceTarget], - dedupedNames: IndexedSeq[ReferenceTarget], - renameMap: RenameMap): Unit = { + def computeRenameMap( + originalNames: IndexedSeq[ReferenceTarget], + dedupedNames: IndexedSeq[ReferenceTarget], + renameMap: RenameMap + ): Unit = { originalNames.zip(dedupedNames).foreach { - case (o, d) => if (o.component != d.component || o.ref != d.ref) { - renameMap.record(o, d.copy(module = o.module)) - } + case (o, d) => + if (o.component != d.component || o.ref != d.ref) { + renameMap.record(o, d.copy(module = o.module)) + } } } |
