aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlbert Magyar2019-11-04 13:12:59 -0800
committermergify[bot]2019-11-04 21:12:59 +0000
commit0d7defc81b02c41e416237ad226adc5f1ab0f8f2 (patch)
treebc9ba3e5b4bc145ff2857431fbd27ae47ae64539 /src
parent8f108c1aa8cac656da56b2505519db47080d5a26 (diff)
Ignore extmodule instances in Flatten (#1218)
* Closes #1162 * Instances of extmodules remain in the final hierarchy * Extmodule definitions are not renamed or duplicated * The rest of the pass may proceed as normal
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/transforms/Flatten.scala56
-rw-r--r--src/test/scala/firrtlTests/FlattenTests.scala78
2 files changed, 93 insertions, 41 deletions
diff --git a/src/main/scala/firrtl/transforms/Flatten.scala b/src/main/scala/firrtl/transforms/Flatten.scala
index 658f0987..26d2b06d 100644
--- a/src/main/scala/firrtl/transforms/Flatten.scala
+++ b/src/main/scala/firrtl/transforms/Flatten.scala
@@ -15,16 +15,19 @@ case class FlattenAnnotation(target: Named) extends SingleTargetAnnotation[Named
}
/**
- * Takes flatten annotations for module instances and modules and inline the entire hierarchy of modules down from the annotations.
- * This transformation instantiates and is based on the InlineInstances transformation.
- * Note: Inlining a module means inlining all its children module instances
- */
+ * Takes flatten annotations for module instances and modules and inline the entire hierarchy of
+ * modules down from the annotations. This transformation instantiates and is based on the
+ * InlineInstances transformation.
+ *
+ * @note Flattening a module means inlining all its fully-defined child instances
+ * @note Instances of extmodules are not (and cannot be) inlined
+ */
class Flatten extends Transform {
def inputForm = LowForm
def outputForm = LowForm
val inlineTransform = new InlineInstances
-
+
private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) {
case ((modNames, instNames), ann) => ann match {
@@ -40,7 +43,7 @@ class Flatten extends Transform {
/**
* Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and
- * by rewriting the original circuit to refer to the new modules that will be inlined later.
+ * by rewriting the original circuit to refer to the new modules that will be inlined later.
* @return modified circuit and ModuleNames to inline
*/
def duplicateSubCircuitsFromAnno(c: Circuit, mods: Set[ModuleName], insts: Set[ComponentName]): (Circuit, Set[ModuleName]) = {
@@ -49,10 +52,10 @@ class Flatten extends Transform {
val newModDefs = mutable.Set.empty[DefModule]
val nsp = Namespace(c)
- /**
- * We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later.
- * It populates seedMods where we capture the mapping between the original module name of the instances came from annotation
- * to a new module name that we will create as a replica of the original one.
+ /**
+ * We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later.
+ * It populates seedMods where we capture the mapping between the original module name of the instances came from annotation
+ * to a new module name that we will create as a replica of the original one.
* Note: We replace old modules with it replicas so that other instances of the same module can be left unchanged.
*/
def rewriteMod(parent: DefModule)(x: Statement): Statement = x match {
@@ -66,12 +69,12 @@ class Flatten extends Transform {
} else x
case _ => x
}
-
+
val modifMods = c.modules map { m => m map rewriteMod(m) }
-
- /**
- * Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations).
- * Populates newModDefs, which are replicated modules used in the subcircuit that we create
+
+ /**
+ * Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations).
+ * Populates newModDefs, which are replicated modules used in the subcircuit that we create
* by recursively traversing modules captured inside seedMods and replicating them
*/
def recDupMods(mods: Map[String, String]): Unit = {
@@ -79,20 +82,23 @@ class Flatten extends Transform {
def dupMod(x: Statement): Statement = x match {
case _: Block => x map dupMod
- case WDefInstance(info, instName, moduleName, instTpe) =>
- val newModName = if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN")
- replMods += moduleName -> newModName
- WDefInstance(info, instName, newModName, instTpe)
- case _ => x
+ case WDefInstance(info, instName, moduleName, instTpe) => modMap(moduleName) match {
+ case m: Module =>
+ val newModName = if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN")
+ replMods += moduleName -> newModName
+ WDefInstance(info, instName, newModName, instTpe)
+ case _ => x // Ignore extmodules
+ }
+ case _ => x
}
-
+
def dupName(name: String): String = mods(name)
val newMods = mods map { case (origName, newName) => modMap(origName) map dupMod map dupName }
-
+
newModDefs ++= newMods
-
+
if(replMods.size > 0) recDupMods(replMods.toMap)
-
+
}
recDupMods(seedMods.toMap)
@@ -100,7 +106,7 @@ class Flatten extends Transform {
val modsToInline = newModDefs map { m => ModuleName(m.name, CircuitName(c.main)) }
(c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet)
}
-
+
override def execute(state: CircuitState): CircuitState = {
val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a }
annos match {
diff --git a/src/test/scala/firrtlTests/FlattenTests.scala b/src/test/scala/firrtlTests/FlattenTests.scala
index 82c3ebdc..468cc1c4 100644
--- a/src/test/scala/firrtlTests/FlattenTests.scala
+++ b/src/test/scala/firrtlTests/FlattenTests.scala
@@ -25,7 +25,7 @@ class FlattenTests extends LowTransformSpec {
val name = if (parts.size == 1) modName else ComponentName(parts.tail.mkString("."), modName)
FlattenAnnotation(name)
}
-
+
"The modules inside Top " should "be inlined" in {
val input =
"""circuit Top :
@@ -55,7 +55,7 @@ class FlattenTests extends LowTransformSpec {
| b <= a""".stripMargin
execute(input, check, Seq(flatten("Top")))
}
-
+
"Two instances of the same module inside Top " should "be inlined" in {
val input =
"""circuit Top :
@@ -112,14 +112,14 @@ class FlattenTests extends LowTransformSpec {
| input a : UInt<32>
| output b : UInt<32>
| inst i of Inline2
- | i.a <= a
- | b <= i.a
+ | i.a <= a
+ | b <= i.a
| module Inline1 :
| input a : UInt<32>
| output b : UInt<32>
| inst i of Inline2
- | i.a <= a
- | b <= i.a
+ | i.a <= a
+ | b <= i.a
| module Inline2 :
| input a : UInt<32>
| output b : UInt<32>
@@ -147,13 +147,13 @@ class FlattenTests extends LowTransformSpec {
| input a : UInt<32>
| output b : UInt<32>
| inst i of Inline2
- | b <= i.a
- | i.a <= a
+ | b <= i.a
+ | i.a <= a
| module Inline1 :
| input a : UInt<32>
| output b : UInt<32>
- | inst i of Inline2
- | b <= i.a
+ | inst i of Inline2
+ | b <= i.a
| i.a <= a
| module Inline2 :
| input a : UInt<32>
@@ -179,14 +179,14 @@ class FlattenTests extends LowTransformSpec {
| input a : UInt<32>
| output b : UInt<32>
| inst i of Inline2
- | i.a <= a
- | b <= i.a
+ | i.a <= a
+ | b <= i.a
| module Inline1 :
| input a : UInt<32>
| output b : UInt<32>
| inst i of Inline2
- | i.a <= a
- | b <= i.a
+ | i.a <= a
+ | b <= i.a
| module Inline2 :
| input a : UInt<32>
| output b : UInt<32>
@@ -208,8 +208,8 @@ class FlattenTests extends LowTransformSpec {
| input a : UInt<32>
| output b : UInt<32>
| inst i of Inline2
- | b <= i.a
- | i.a <= a
+ | b <= i.a
+ | i.a <= a
| module Inline1 :
| input a : UInt<32>
| output b : UInt<32>
@@ -234,4 +234,50 @@ class FlattenTests extends LowTransformSpec {
|""".stripMargin
execute(input, input, Seq.empty)
}
+
+ "The Flatten transform" should "ignore extmodules" in {
+ val input = """
+ |circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of Inline
+ | i.a <= a
+ | b <= i.b
+ | module Inline :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of ExternalMod
+ | i.a <= a
+ | b <= i.b
+ | extmodule ExternalMod :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | defname = ExternalMod
+ """.stripMargin
+ val check = """
+ |circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | wire i_a : UInt<32>
+ | wire i_b : UInt<32>
+ | inst i_i of ExternalMod
+ | i_b <= i_i.b
+ | i_i.a <= i_a
+ | b <= i_b
+ | i_a <= a
+ | module Inline :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of ExternalMod
+ | b <= i.b
+ | i.a <= a
+ | extmodule ExternalMod :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | defname = ExternalMod
+ """.stripMargin
+ execute(input, check, Seq(flatten("Top")))
+ }
}