diff options
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/Flatten.scala | 120 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/FlattenTests.scala | 190 |
3 files changed, 311 insertions, 0 deletions
@@ -55,3 +55,4 @@ project/plugins/project/ .idea/ gen/ project/project/ +/bin/ diff --git a/src/main/scala/firrtl/transforms/Flatten.scala b/src/main/scala/firrtl/transforms/Flatten.scala new file mode 100644 index 00000000..748ea00c --- /dev/null +++ b/src/main/scala/firrtl/transforms/Flatten.scala @@ -0,0 +1,120 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.annotations._ +import scala.collection.mutable +import firrtl.passes.{InlineInstances,PassException} + +/** Tags an annotation to be consumed by this transform */ +object FlattenAnnotation { + def apply(target: Named): Annotation = Annotation(target, classOf[Flatten], "") + + def unapply(a: Annotation): Option[Named] = a match { + case Annotation(named, t, _) if t == classOf[Flatten] => Some(named) + case _ => None + } +} + +/** + * Takes flatten annotations for module instances and modules and inline the entire hierarchy of modules down from the annotations. + * This transformation instantiates and is based on the InlineInstances transformation. + * Note: Inlining a module means inlining all its children module instances + */ +class Flatten extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + val inlineTransform = new InlineInstances + + private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = + anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) { + case ((modNames, instNames), ann) => ann match { + case FlattenAnnotation(CircuitName(c)) => + (circuit.modules.collect { + case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) + }.toSet, instNames) + case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) + case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) + case _ => throw new PassException("Annotation must be InlineDeepAnnotation") + } + } + + /** + * Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and + * by rewriting the original circuit to refer to the new modules that will be inlined later. + * @return modified circuit and ModuleNames to inline + */ + def duplicateSubCircuitsFromAnno(c: Circuit, mods: Set[ModuleName], insts: Set[ComponentName]): (Circuit, Set[ModuleName]) = { + val modMap = c.modules.map(m => m.name->m) toMap + val seedMods = mutable.Map.empty[String, String] + val newModDefs = mutable.Set.empty[DefModule] + val nsp = Namespace(c) + + /** + * We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later. + * It populates seedMods where we capture the mapping between the original module name of the instances came from annotation + * to a new module name that we will create as a replica of the original one. + * Note: We replace old modules with it replicas so that other instances of the same module can be left unchanged. + */ + def rewriteMod(parent: DefModule)(x: Statement): Statement = x match { + case _: Block => x map rewriteMod(parent) + case WDefInstance(info, instName, moduleName, instTpe) => + if (insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main)))) + || mods.contains(ModuleName(parent.name, CircuitName(c.main)))) { + val newModName = nsp.newName(moduleName+"_TO_FLATTEN") + seedMods += moduleName -> newModName + WDefInstance(info, instName, newModName, instTpe) + } else x + case _ => x + } + + val modifMods = c.modules map { m => m map rewriteMod(m) } + + /** + * Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations). + * Populates newModDefs, which are replicated modules used in the subcircuit that we create + * by recursively traversing modules captured inside seedMods and replicating them + */ + def recDupMods(mods: Map[String, String]): Unit = { + val replMods = mutable.Map.empty[String, String] + + def dupMod(x: Statement): Statement = x match { + case _: Block => x map dupMod + case WDefInstance(info, instName, moduleName, instTpe) => + val newModName = nsp.newName(moduleName+"_TO_FLATTEN") + replMods += moduleName -> newModName + WDefInstance(info, instName, newModName, instTpe) + case _ => x + } + + def dupName(name: String): String = mods(name) + val newMods = mods map { case (origName, newName) => modMap(origName) map dupMod map dupName } + + newModDefs ++= newMods + + if(replMods.size > 0) recDupMods(replMods.toMap) + + } + recDupMods(seedMods.toMap) + + //convert newly created modules to ModuleName for inlining next (outside this function) + val modsToInline = newModDefs map { m => ModuleName(m.name, CircuitName(c.main)) } + (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet) + } + + override def execute(state: CircuitState): CircuitState = { + getMyAnnotations(state) match { + case Nil => CircuitState(state.circuit, state.form) + case myAnnotations => + val c = state.circuit + val (modNames, instNames) = collectAnns(state.circuit, myAnnotations) + // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline + val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames) + inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations) + } + } +} diff --git a/src/test/scala/firrtlTests/FlattenTests.scala b/src/test/scala/firrtlTests/FlattenTests.scala new file mode 100644 index 00000000..f695cf4a --- /dev/null +++ b/src/test/scala/firrtlTests/FlattenTests.scala @@ -0,0 +1,190 @@ +// See LICENSE for license details. + +package firrtlTests + +import org.scalatest.FlatSpec +import org.scalatest.Matchers +import org.scalatest.junit.JUnitRunner +import firrtl.ir.Circuit +import firrtl.Parser +import firrtl.passes.PassExceptions +import firrtl.annotations.{Annotation, CircuitName, ComponentName, ModuleName, Named} +import firrtl.transforms.{FlattenAnnotation, Flatten} +import logger.{LogLevel, Logger} +import logger.LogLevel.Debug + + +/** + * Tests deep inline transformation + */ +class FlattenTests extends LowTransformSpec { + def transform = new Flatten + def flatten(mod: String): Annotation = { + val parts = mod.split('.') + val modName = ModuleName(parts.head, CircuitName("Top")) // If this fails, bad input + val name = if (parts.size == 1) modName else ComponentName(parts.tail.mkString("."), modName) + FlattenAnnotation(name) + } + + + "The modules inside Top " should "be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline1 + | i.a <= a + | b <= i.b + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i$a : UInt<32> + | wire i$b : UInt<32> + | i$b <= i$a + | b <= i$b + | i$a <= a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(flatten("Top"))) + } + + "The module instance i in Top " should "be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | inst i of Inline1 + | inst ni of NotInline1 + | i.a <= a + | b <= i.b + | ni.a <= na + | nb <= ni.b + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | wire i$a : UInt<32> + | wire i$b : UInt<32> + | wire i$i$a : UInt<32> + | wire i$i$b : UInt<32> + | i$i$b <= i$i$a + | i$b <= i$i$a + | i$i$a <= i$a + | inst ni of NotInline1 + | b <= i$b + | nb <= ni.b + | i$a <= a + | ni.a <= na + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | b <= i.a + | i.a <= a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | b <= i.a + | i.a <= a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(flatten("Top.i"))) + } + "The module Inline1" should "be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | inst i of Inline1 + | inst ni of NotInline1 + | i.a <= a + | b <= i.b + | ni.a <= na + | nb <= ni.b + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | inst i of Inline1 + | inst ni of NotInline1 + | b <= i.b + | nb <= ni.b + | i.a <= a + | ni.a <= na + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | b <= i.a + | i.a <= a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | wire i$a : UInt<32> + | wire i$b : UInt<32> + | i$b <= i$a + | b <= i$a + | i$a <= a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(flatten("Inline1"))) + } +} |
