diff options
Diffstat (limited to 'src/main/scala/firrtl/transforms/Flatten.scala')
| -rw-r--r-- | src/main/scala/firrtl/transforms/Flatten.scala | 179 |
1 files changed, 96 insertions, 83 deletions
diff --git a/src/main/scala/firrtl/transforms/Flatten.scala b/src/main/scala/firrtl/transforms/Flatten.scala index cc5b3504..36e71470 100644 --- a/src/main/scala/firrtl/transforms/Flatten.scala +++ b/src/main/scala/firrtl/transforms/Flatten.scala @@ -7,7 +7,7 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.annotations._ import scala.collection.mutable -import firrtl.passes.{InlineInstances,PassException} +import firrtl.passes.{InlineInstances, PassException} import firrtl.stage.Forms /** Tags an annotation to be consumed by this transform */ @@ -25,101 +25,114 @@ case class FlattenAnnotation(target: Named) extends SingleTargetAnnotation[Named */ class Flatten extends Transform with DependencyAPIMigration { - override def prerequisites = Forms.LowForm - override def optionalPrerequisites = Seq.empty - override def optionalPrerequisiteOf = Forms.LowEmitters + override def prerequisites = Forms.LowForm + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Forms.LowEmitters override def invalidates(a: Transform) = false - val inlineTransform = new InlineInstances - - private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = - anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) { - case ((modNames, instNames), ann) => ann match { - case FlattenAnnotation(CircuitName(c)) => - (circuit.modules.collect { - case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) - }.toSet, instNames) - case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) - case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) - case _ => throw new PassException("Annotation must be a FlattenAnnotation") - } - } - - /** + val inlineTransform = new InlineInstances + + private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = + anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) { + case ((modNames, instNames), ann) => + ann match { + case FlattenAnnotation(CircuitName(c)) => + ( + circuit.modules.collect { + case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) + }.toSet, + instNames + ) + case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) + case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) + case _ => throw new PassException("Annotation must be a FlattenAnnotation") + } + } + + /** * Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and * by rewriting the original circuit to refer to the new modules that will be inlined later. * @return modified circuit and ModuleNames to inline */ - def duplicateSubCircuitsFromAnno(c: Circuit, mods: Set[ModuleName], insts: Set[ComponentName]): (Circuit, Set[ModuleName]) = { - val modMap = c.modules.map(m => m.name->m).toMap - val seedMods = mutable.Map.empty[String, String] - val newModDefs = mutable.Set.empty[DefModule] - val nsp = Namespace(c) - - /** + def duplicateSubCircuitsFromAnno( + c: Circuit, + mods: Set[ModuleName], + insts: Set[ComponentName] + ): (Circuit, Set[ModuleName]) = { + val modMap = c.modules.map(m => m.name -> m).toMap + val seedMods = mutable.Map.empty[String, String] + val newModDefs = mutable.Set.empty[DefModule] + val nsp = Namespace(c) + + /** * We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later. * It populates seedMods where we capture the mapping between the original module name of the instances came from annotation * to a new module name that we will create as a replica of the original one. * Note: We replace old modules with it replicas so that other instances of the same module can be left unchanged. */ - def rewriteMod(parent: DefModule)(x: Statement): Statement = x match { - case _: Block => x map rewriteMod(parent) - case WDefInstance(info, instName, moduleName, instTpe) => - if (insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main)))) - || mods.contains(ModuleName(parent.name, CircuitName(c.main)))) { - val newModName = if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN") - seedMods += moduleName -> newModName - WDefInstance(info, instName, newModName, instTpe) - } else x - case _ => x - } - - val modifMods = c.modules map { m => m map rewriteMod(m) } - - /** + def rewriteMod(parent: DefModule)(x: Statement): Statement = x match { + case _: Block => x.map(rewriteMod(parent)) + case WDefInstance(info, instName, moduleName, instTpe) => + if ( + insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main)))) + || mods.contains(ModuleName(parent.name, CircuitName(c.main))) + ) { + val newModName = + if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN") + seedMods += moduleName -> newModName + WDefInstance(info, instName, newModName, instTpe) + } else x + case _ => x + } + + val modifMods = c.modules.map { m => m.map(rewriteMod(m)) } + + /** * Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations). * Populates newModDefs, which are replicated modules used in the subcircuit that we create * by recursively traversing modules captured inside seedMods and replicating them */ - def recDupMods(mods: Map[String, String]): Unit = { - val replMods = mutable.Map.empty[String, String] - - def dupMod(x: Statement): Statement = x match { - case _: Block => x map dupMod - case WDefInstance(info, instName, moduleName, instTpe) => modMap(moduleName) match { - case m: Module => - val newModName = if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN") - replMods += moduleName -> newModName - WDefInstance(info, instName, newModName, instTpe) - case _ => x // Ignore extmodules - } - case _ => x - } - - def dupName(name: String): String = mods(name) - val newMods = mods map { case (origName, newName) => modMap(origName) map dupMod map dupName } - - newModDefs ++= newMods - - if(replMods.size > 0) recDupMods(replMods.toMap) - - } - recDupMods(seedMods.toMap) - - //convert newly created modules to ModuleName for inlining next (outside this function) - val modsToInline = newModDefs map { m => ModuleName(m.name, CircuitName(c.main)) } - (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet) - } - - override def execute(state: CircuitState): CircuitState = { - val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a } - annos match { - case Nil => state - case myAnnotations => - val (modNames, instNames) = collectAnns(state.circuit, myAnnotations) - // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline - val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames) - inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations) - } - } + def recDupMods(mods: Map[String, String]): Unit = { + val replMods = mutable.Map.empty[String, String] + + def dupMod(x: Statement): Statement = x match { + case _: Block => x.map(dupMod) + case WDefInstance(info, instName, moduleName, instTpe) => + modMap(moduleName) match { + case m: Module => + val newModName = + if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN") + replMods += moduleName -> newModName + WDefInstance(info, instName, newModName, instTpe) + case _ => x // Ignore extmodules + } + case _ => x + } + + def dupName(name: String): String = mods(name) + val newMods = mods.map { case (origName, newName) => modMap(origName).map(dupMod).map(dupName) } + + newModDefs ++= newMods + + if (replMods.size > 0) recDupMods(replMods.toMap) + + } + recDupMods(seedMods.toMap) + + //convert newly created modules to ModuleName for inlining next (outside this function) + val modsToInline = newModDefs.map { m => ModuleName(m.name, CircuitName(c.main)) } + (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet) + } + + override def execute(state: CircuitState): CircuitState = { + val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a } + annos match { + case Nil => state + case myAnnotations => + val (modNames, instNames) = collectAnns(state.circuit, myAnnotations) + // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline + val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames) + inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations) + } + } } |
