diff options
| author | Albert Magyar | 2019-11-04 13:12:59 -0800 |
|---|---|---|
| committer | mergify[bot] | 2019-11-04 21:12:59 +0000 |
| commit | 0d7defc81b02c41e416237ad226adc5f1ab0f8f2 (patch) | |
| tree | bc9ba3e5b4bc145ff2857431fbd27ae47ae64539 /src | |
| parent | 8f108c1aa8cac656da56b2505519db47080d5a26 (diff) | |
Ignore extmodule instances in Flatten (#1218)
* Closes #1162
* Instances of extmodules remain in the final hierarchy
* Extmodule definitions are not renamed or duplicated
* The rest of the pass may proceed as normal
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/transforms/Flatten.scala | 56 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/FlattenTests.scala | 78 |
2 files changed, 93 insertions, 41 deletions
diff --git a/src/main/scala/firrtl/transforms/Flatten.scala b/src/main/scala/firrtl/transforms/Flatten.scala index 658f0987..26d2b06d 100644 --- a/src/main/scala/firrtl/transforms/Flatten.scala +++ b/src/main/scala/firrtl/transforms/Flatten.scala @@ -15,16 +15,19 @@ case class FlattenAnnotation(target: Named) extends SingleTargetAnnotation[Named } /** - * Takes flatten annotations for module instances and modules and inline the entire hierarchy of modules down from the annotations. - * This transformation instantiates and is based on the InlineInstances transformation. - * Note: Inlining a module means inlining all its children module instances - */ + * Takes flatten annotations for module instances and modules and inline the entire hierarchy of + * modules down from the annotations. This transformation instantiates and is based on the + * InlineInstances transformation. + * + * @note Flattening a module means inlining all its fully-defined child instances + * @note Instances of extmodules are not (and cannot be) inlined + */ class Flatten extends Transform { def inputForm = LowForm def outputForm = LowForm val inlineTransform = new InlineInstances - + private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) { case ((modNames, instNames), ann) => ann match { @@ -40,7 +43,7 @@ class Flatten extends Transform { /** * Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and - * by rewriting the original circuit to refer to the new modules that will be inlined later. + * by rewriting the original circuit to refer to the new modules that will be inlined later. * @return modified circuit and ModuleNames to inline */ def duplicateSubCircuitsFromAnno(c: Circuit, mods: Set[ModuleName], insts: Set[ComponentName]): (Circuit, Set[ModuleName]) = { @@ -49,10 +52,10 @@ class Flatten extends Transform { val newModDefs = mutable.Set.empty[DefModule] val nsp = Namespace(c) - /** - * We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later. - * It populates seedMods where we capture the mapping between the original module name of the instances came from annotation - * to a new module name that we will create as a replica of the original one. + /** + * We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later. + * It populates seedMods where we capture the mapping between the original module name of the instances came from annotation + * to a new module name that we will create as a replica of the original one. * Note: We replace old modules with it replicas so that other instances of the same module can be left unchanged. */ def rewriteMod(parent: DefModule)(x: Statement): Statement = x match { @@ -66,12 +69,12 @@ class Flatten extends Transform { } else x case _ => x } - + val modifMods = c.modules map { m => m map rewriteMod(m) } - - /** - * Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations). - * Populates newModDefs, which are replicated modules used in the subcircuit that we create + + /** + * Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations). + * Populates newModDefs, which are replicated modules used in the subcircuit that we create * by recursively traversing modules captured inside seedMods and replicating them */ def recDupMods(mods: Map[String, String]): Unit = { @@ -79,20 +82,23 @@ class Flatten extends Transform { def dupMod(x: Statement): Statement = x match { case _: Block => x map dupMod - case WDefInstance(info, instName, moduleName, instTpe) => - val newModName = if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN") - replMods += moduleName -> newModName - WDefInstance(info, instName, newModName, instTpe) - case _ => x + case WDefInstance(info, instName, moduleName, instTpe) => modMap(moduleName) match { + case m: Module => + val newModName = if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN") + replMods += moduleName -> newModName + WDefInstance(info, instName, newModName, instTpe) + case _ => x // Ignore extmodules + } + case _ => x } - + def dupName(name: String): String = mods(name) val newMods = mods map { case (origName, newName) => modMap(origName) map dupMod map dupName } - + newModDefs ++= newMods - + if(replMods.size > 0) recDupMods(replMods.toMap) - + } recDupMods(seedMods.toMap) @@ -100,7 +106,7 @@ class Flatten extends Transform { val modsToInline = newModDefs map { m => ModuleName(m.name, CircuitName(c.main)) } (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet) } - + override def execute(state: CircuitState): CircuitState = { val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a } annos match { diff --git a/src/test/scala/firrtlTests/FlattenTests.scala b/src/test/scala/firrtlTests/FlattenTests.scala index 82c3ebdc..468cc1c4 100644 --- a/src/test/scala/firrtlTests/FlattenTests.scala +++ b/src/test/scala/firrtlTests/FlattenTests.scala @@ -25,7 +25,7 @@ class FlattenTests extends LowTransformSpec { val name = if (parts.size == 1) modName else ComponentName(parts.tail.mkString("."), modName) FlattenAnnotation(name) } - + "The modules inside Top " should "be inlined" in { val input = """circuit Top : @@ -55,7 +55,7 @@ class FlattenTests extends LowTransformSpec { | b <= a""".stripMargin execute(input, check, Seq(flatten("Top"))) } - + "Two instances of the same module inside Top " should "be inlined" in { val input = """circuit Top : @@ -112,14 +112,14 @@ class FlattenTests extends LowTransformSpec { | input a : UInt<32> | output b : UInt<32> | inst i of Inline2 - | i.a <= a - | b <= i.a + | i.a <= a + | b <= i.a | module Inline1 : | input a : UInt<32> | output b : UInt<32> | inst i of Inline2 - | i.a <= a - | b <= i.a + | i.a <= a + | b <= i.a | module Inline2 : | input a : UInt<32> | output b : UInt<32> @@ -147,13 +147,13 @@ class FlattenTests extends LowTransformSpec { | input a : UInt<32> | output b : UInt<32> | inst i of Inline2 - | b <= i.a - | i.a <= a + | b <= i.a + | i.a <= a | module Inline1 : | input a : UInt<32> | output b : UInt<32> - | inst i of Inline2 - | b <= i.a + | inst i of Inline2 + | b <= i.a | i.a <= a | module Inline2 : | input a : UInt<32> @@ -179,14 +179,14 @@ class FlattenTests extends LowTransformSpec { | input a : UInt<32> | output b : UInt<32> | inst i of Inline2 - | i.a <= a - | b <= i.a + | i.a <= a + | b <= i.a | module Inline1 : | input a : UInt<32> | output b : UInt<32> | inst i of Inline2 - | i.a <= a - | b <= i.a + | i.a <= a + | b <= i.a | module Inline2 : | input a : UInt<32> | output b : UInt<32> @@ -208,8 +208,8 @@ class FlattenTests extends LowTransformSpec { | input a : UInt<32> | output b : UInt<32> | inst i of Inline2 - | b <= i.a - | i.a <= a + | b <= i.a + | i.a <= a | module Inline1 : | input a : UInt<32> | output b : UInt<32> @@ -234,4 +234,50 @@ class FlattenTests extends LowTransformSpec { |""".stripMargin execute(input, input, Seq.empty) } + + "The Flatten transform" should "ignore extmodules" in { + val input = """ + |circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst i of ExternalMod + | i.a <= a + | b <= i.b + | extmodule ExternalMod : + | input a : UInt<32> + | output b : UInt<32> + | defname = ExternalMod + """.stripMargin + val check = """ + |circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | inst i_i of ExternalMod + | i_b <= i_i.b + | i_i.a <= i_a + | b <= i_b + | i_a <= a + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst i of ExternalMod + | b <= i.b + | i.a <= a + | extmodule ExternalMod : + | input a : UInt<32> + | output b : UInt<32> + | defname = ExternalMod + """.stripMargin + execute(input, check, Seq(flatten("Top"))) + } } |
