aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/transforms/Flatten.scala120
-rw-r--r--src/test/scala/firrtlTests/FlattenTests.scala190
2 files changed, 310 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/Flatten.scala b/src/main/scala/firrtl/transforms/Flatten.scala
new file mode 100644
index 00000000..748ea00c
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/Flatten.scala
@@ -0,0 +1,120 @@
+// See LICENSE for license details.
+
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.annotations._
+import scala.collection.mutable
+import firrtl.passes.{InlineInstances,PassException}
+
+/** Tags an annotation to be consumed by this transform */
+object FlattenAnnotation {
+ def apply(target: Named): Annotation = Annotation(target, classOf[Flatten], "")
+
+ def unapply(a: Annotation): Option[Named] = a match {
+ case Annotation(named, t, _) if t == classOf[Flatten] => Some(named)
+ case _ => None
+ }
+}
+
+/**
+ * Takes flatten annotations for module instances and modules and inline the entire hierarchy of modules down from the annotations.
+ * This transformation instantiates and is based on the InlineInstances transformation.
+ * Note: Inlining a module means inlining all its children module instances
+ */
+class Flatten extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
+ val inlineTransform = new InlineInstances
+
+ private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
+ anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) {
+ case ((modNames, instNames), ann) => ann match {
+ case FlattenAnnotation(CircuitName(c)) =>
+ (circuit.modules.collect {
+ case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
+ }.toSet, instNames)
+ case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
+ case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
+ case _ => throw new PassException("Annotation must be InlineDeepAnnotation")
+ }
+ }
+
+ /**
+ * Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and
+ * by rewriting the original circuit to refer to the new modules that will be inlined later.
+ * @return modified circuit and ModuleNames to inline
+ */
+ def duplicateSubCircuitsFromAnno(c: Circuit, mods: Set[ModuleName], insts: Set[ComponentName]): (Circuit, Set[ModuleName]) = {
+ val modMap = c.modules.map(m => m.name->m) toMap
+ val seedMods = mutable.Map.empty[String, String]
+ val newModDefs = mutable.Set.empty[DefModule]
+ val nsp = Namespace(c)
+
+ /**
+ * We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later.
+ * It populates seedMods where we capture the mapping between the original module name of the instances came from annotation
+ * to a new module name that we will create as a replica of the original one.
+ * Note: We replace old modules with it replicas so that other instances of the same module can be left unchanged.
+ */
+ def rewriteMod(parent: DefModule)(x: Statement): Statement = x match {
+ case _: Block => x map rewriteMod(parent)
+ case WDefInstance(info, instName, moduleName, instTpe) =>
+ if (insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main))))
+ || mods.contains(ModuleName(parent.name, CircuitName(c.main)))) {
+ val newModName = nsp.newName(moduleName+"_TO_FLATTEN")
+ seedMods += moduleName -> newModName
+ WDefInstance(info, instName, newModName, instTpe)
+ } else x
+ case _ => x
+ }
+
+ val modifMods = c.modules map { m => m map rewriteMod(m) }
+
+ /**
+ * Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations).
+ * Populates newModDefs, which are replicated modules used in the subcircuit that we create
+ * by recursively traversing modules captured inside seedMods and replicating them
+ */
+ def recDupMods(mods: Map[String, String]): Unit = {
+ val replMods = mutable.Map.empty[String, String]
+
+ def dupMod(x: Statement): Statement = x match {
+ case _: Block => x map dupMod
+ case WDefInstance(info, instName, moduleName, instTpe) =>
+ val newModName = nsp.newName(moduleName+"_TO_FLATTEN")
+ replMods += moduleName -> newModName
+ WDefInstance(info, instName, newModName, instTpe)
+ case _ => x
+ }
+
+ def dupName(name: String): String = mods(name)
+ val newMods = mods map { case (origName, newName) => modMap(origName) map dupMod map dupName }
+
+ newModDefs ++= newMods
+
+ if(replMods.size > 0) recDupMods(replMods.toMap)
+
+ }
+ recDupMods(seedMods.toMap)
+
+ //convert newly created modules to ModuleName for inlining next (outside this function)
+ val modsToInline = newModDefs map { m => ModuleName(m.name, CircuitName(c.main)) }
+ (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet)
+ }
+
+ override def execute(state: CircuitState): CircuitState = {
+ getMyAnnotations(state) match {
+ case Nil => CircuitState(state.circuit, state.form)
+ case myAnnotations =>
+ val c = state.circuit
+ val (modNames, instNames) = collectAnns(state.circuit, myAnnotations)
+ // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline
+ val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames)
+ inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations)
+ }
+ }
+}
diff --git a/src/test/scala/firrtlTests/FlattenTests.scala b/src/test/scala/firrtlTests/FlattenTests.scala
new file mode 100644
index 00000000..f695cf4a
--- /dev/null
+++ b/src/test/scala/firrtlTests/FlattenTests.scala
@@ -0,0 +1,190 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import org.scalatest.FlatSpec
+import org.scalatest.Matchers
+import org.scalatest.junit.JUnitRunner
+import firrtl.ir.Circuit
+import firrtl.Parser
+import firrtl.passes.PassExceptions
+import firrtl.annotations.{Annotation, CircuitName, ComponentName, ModuleName, Named}
+import firrtl.transforms.{FlattenAnnotation, Flatten}
+import logger.{LogLevel, Logger}
+import logger.LogLevel.Debug
+
+
+/**
+ * Tests deep inline transformation
+ */
+class FlattenTests extends LowTransformSpec {
+ def transform = new Flatten
+ def flatten(mod: String): Annotation = {
+ val parts = mod.split('.')
+ val modName = ModuleName(parts.head, CircuitName("Top")) // If this fails, bad input
+ val name = if (parts.size == 1) modName else ComponentName(parts.tail.mkString("."), modName)
+ FlattenAnnotation(name)
+ }
+
+
+ "The modules inside Top " should "be inlined" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of Inline1
+ | i.a <= a
+ | b <= i.b
+ | module Inline1 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | b <= a""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | wire i$a : UInt<32>
+ | wire i$b : UInt<32>
+ | i$b <= i$a
+ | b <= i$b
+ | i$a <= a
+ | module Inline1 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | b <= a""".stripMargin
+ execute(input, check, Seq(flatten("Top")))
+ }
+
+ "The module instance i in Top " should "be inlined" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | input na : UInt<32>
+ | output b : UInt<32>
+ | output nb : UInt<32>
+ | inst i of Inline1
+ | inst ni of NotInline1
+ | i.a <= a
+ | b <= i.b
+ | ni.a <= na
+ | nb <= ni.b
+ | module NotInline1 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of Inline2
+ | i.a <= a
+ | b <= i.a
+ | module Inline1 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of Inline2
+ | i.a <= a
+ | b <= i.a
+ | module Inline2 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | b <= a""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | input na : UInt<32>
+ | output b : UInt<32>
+ | output nb : UInt<32>
+ | wire i$a : UInt<32>
+ | wire i$b : UInt<32>
+ | wire i$i$a : UInt<32>
+ | wire i$i$b : UInt<32>
+ | i$i$b <= i$i$a
+ | i$b <= i$i$a
+ | i$i$a <= i$a
+ | inst ni of NotInline1
+ | b <= i$b
+ | nb <= ni.b
+ | i$a <= a
+ | ni.a <= na
+ | module NotInline1 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of Inline2
+ | b <= i.a
+ | i.a <= a
+ | module Inline1 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of Inline2
+ | b <= i.a
+ | i.a <= a
+ | module Inline2 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | b <= a""".stripMargin
+ execute(input, check, Seq(flatten("Top.i")))
+ }
+ "The module Inline1" should "be inlined" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | input na : UInt<32>
+ | output b : UInt<32>
+ | output nb : UInt<32>
+ | inst i of Inline1
+ | inst ni of NotInline1
+ | i.a <= a
+ | b <= i.b
+ | ni.a <= na
+ | nb <= ni.b
+ | module NotInline1 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of Inline2
+ | i.a <= a
+ | b <= i.a
+ | module Inline1 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of Inline2
+ | i.a <= a
+ | b <= i.a
+ | module Inline2 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | b <= a""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | input na : UInt<32>
+ | output b : UInt<32>
+ | output nb : UInt<32>
+ | inst i of Inline1
+ | inst ni of NotInline1
+ | b <= i.b
+ | nb <= ni.b
+ | i.a <= a
+ | ni.a <= na
+ | module NotInline1 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of Inline2
+ | b <= i.a
+ | i.a <= a
+ | module Inline1 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | wire i$a : UInt<32>
+ | wire i$b : UInt<32>
+ | i$b <= i$a
+ | b <= i$a
+ | i$a <= a
+ | module Inline2 :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | b <= a""".stripMargin
+ execute(input, check, Seq(flatten("Inline1")))
+ }
+}