diff options
| -rw-r--r-- | src/main/scala/firrtl/passes/Inline.scala | 137 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/InlineInstancesTests.scala | 85 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/PassTests.scala | 8 |
3 files changed, 186 insertions, 44 deletions
diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index b9e67c82..0ca98ac5 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -6,7 +6,9 @@ package passes import firrtl.ir._ import firrtl.Mappers._ import firrtl.annotations._ -import firrtl.analyses.InstanceGraph +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.analyses.{InstanceGraph} +import firrtl.graph.{DiGraph, MutableDiGraph} import firrtl.stage.RunFirrtlTransformAnnotation import firrtl.options.{RegisteredTransform, ShellOption} @@ -108,15 +110,15 @@ class InlineInstances extends Transform with RegisteredTransform { def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName], annos: AnnotationSeq): CircuitState = { - def getInstancesOf(c: Circuit, modules: Set[String]): Set[String] = - c.modules.foldLeft(Set[String]()) { (set, d) => + def getInstancesOf(c: Circuit, modules: Set[String]): Set[(OfModule, Instance)] = + c.modules.foldLeft(Set[(OfModule, Instance)]()) { (set, d) => d match { case e: ExtModule => set case m: Module => - val instances = mutable.HashSet[String]() + val instances = mutable.HashSet[(OfModule, Instance)]() def findInstances(s: Statement): Statement = s match { case WDefInstance(info, instName, moduleName, instTpe) if modules.contains(moduleName) => - instances += m.name + "." + instName + instances += (OfModule(m.name) -> Instance(instName)) s case sx => sx map findInstances } @@ -128,19 +130,20 @@ class InlineInstances extends Transform with RegisteredTransform { // Check annotations and circuit match up check(c, modsToInline, instsToInline) val flatModules = modsToInline.map(m => m.name) - val flatInstances = instsToInline.map(i => i.module.name + "." + i.name) ++ getInstancesOf(c, flatModules) + val flatInstances: Set[(OfModule, Instance)] = instsToInline.map(i => OfModule(i.module.name) -> Instance(i.name)) ++ getInstancesOf(c, flatModules) val iGraph = new InstanceGraph(c) val namespaceMap = collection.mutable.Map[String, Namespace]() // Map of Module name to Map of instance name to Module name - val instMaps: Map[String, Map[String, String]] = { + val instMaps: Map[OfModule, Map[Instance, OfModule]] = { iGraph.graph.getEdgeMap.view.map { case (mod, children) => - mod.module -> children.view.map(i => i.name -> i.module).toMap + OfModule(mod.module) -> children.view.map(i => Instance(i.name) -> OfModule(i.module)).toMap }.toMap } /** Add a prefix to all declarations updating a [[Namespace]] and appending to a [[RenameMap]] */ def appendNamePrefix( currentModule: IsModule, + nextModule: IsModule, prefix: String, ns: Namespace, renames: mutable.HashMap[String, String], @@ -151,17 +154,17 @@ class InlineInstances extends Transform with RegisteredTransform { } ofModuleOpt match { case None => - renameMap.record(currentModule.ref(name), currentModule.ref(prefix + name)) + renameMap.record(currentModule.ref(name), nextModule.ref(prefix + name)) case Some(ofModule) => - renameMap.record(currentModule.instOf(name, ofModule), currentModule.instOf(prefix + name, ofModule)) + renameMap.record(currentModule.instOf(name, ofModule), nextModule.instOf(prefix + name, ofModule)) } renames(name) = prefix + name prefix + name } s match { - case s: WDefInstance => s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, prefix, ns, renames, renameMap)) - case other => s.map(onName(None)).map(appendNamePrefix(currentModule, prefix, ns, renames, renameMap)) + case s: WDefInstance => s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap)) + case other => s.map(onName(None)).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap)) } } @@ -180,38 +183,82 @@ class InlineInstances extends Transform with RegisteredTransform { s.map(onExpr).map(appendRefPrefix(currentModule, renames)) } - def fixupRefs( - instMap: Map[String, String], - currentModule: IsModule, - renames: RenameMap)(e: Expression): Expression = e match { - case wsf@ WSubField(wr@ WRef(ref, _, InstanceKind, _), field, tpe, gen) => - val inst = currentModule.instOf(ref, instMap(ref)) - val port = inst.ref(field) - renames.get(port) match { - case Some(Seq(p)) => - p match { - case ReferenceTarget(_, _, Seq((TargetToken.Instance(r), TargetToken.OfModule(_))), f, Nil) => - wsf.copy(expr = wr.copy(name = r), name = f) - case ReferenceTarget(_, _, Nil, r, Nil) => WRef(r, tpe, WireKind, gen) - } - case None => wsf + val cache = mutable.HashMap.empty[ModuleTarget, Statement] + + /** renamesMap is a map of instances to [[RenameMap]]. + * The keys are pairs of enclosing [[OfModule]] and [[Instance]] + * The [[RenameMap]]s in renamesMap are appear in renamesSeq + * in the order that they should be applied + */ + val (renamesMap, renamesSeq) = { + val mutableDiGraph = new MutableDiGraph[(OfModule, Instance)] + // compute instance graph + instMaps.foreach { case (grandParentOfMod, parents) => + parents.foreach { case (parentInst, parentOfMod) => + val from = grandParentOfMod -> parentInst + mutableDiGraph.addVertex(from) + instMaps(parentOfMod).foreach { case (childInst, _) => + val to = parentOfMod -> childInst + mutableDiGraph.addVertex(to) + mutableDiGraph.addEdge(from, to) + } } - case wr@ WRef(name, _, _, _) => - val comp = currentModule.ref(name) //ComponentName(name, currentModule) - renames.get(comp).getOrElse(Seq(comp)) match { - case Seq(car: ReferenceTarget) => wr.copy(name=car.ref) + } + + val diGraph = DiGraph(mutableDiGraph) + val subgraph = diGraph.simplify(flatInstances) + val edges = subgraph.getEdgeMap + + // calculate which [[RenameMap]] should be associated with each instance + val indexMap = new mutable.HashMap[(OfModule, Instance), Int] + flatInstances.foreach(v => indexMap(v) = 0) + subgraph.linearize.foreach { parent => + edges(parent).foreach { child => + indexMap(child) = indexMap(parent) + 1 } - case ex => ex.map(fixupRefs(instMap, currentModule, renames)) + } + + val maxIdx = indexMap.values.max + val resultSeq = Seq.fill(maxIdx + 1)(RenameMap()) + val resultMap = indexMap.mapValues(idx => resultSeq(maxIdx - idx)) + (resultMap, resultSeq) } - var renames = RenameMap() + def fixupRefs( + instMap: Map[Instance, OfModule], + currentModule: IsModule)(e: Expression): Expression = { + e match { + case wsf@ WSubField(wr@ WRef(ref, _, InstanceKind, _), field, tpe, gen) => + val inst = currentModule.instOf(ref, instMap(Instance(ref)).value) + val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance)) + val port = inst.ref(field) + renamesOpt.flatMap(_.get(port)) match { + case Some(Seq(p)) => + p match { + case ReferenceTarget(_, _, Seq((TargetToken.Instance(r), TargetToken.OfModule(_))), f, Nil) => + wsf.copy(expr = wr.copy(name = r), name = f) + case ReferenceTarget(_, _, Nil, r, Nil) => WRef(r, tpe, WireKind, gen) + } + case None => wsf + } + case wr@ WRef(name, _, InstanceKind, _) => + val inst = currentModule.instOf(name, instMap(Instance(name)).value) + val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance)) + val comp = currentModule.ref(name) + renamesOpt.flatMap(_.get(comp)).getOrElse(Seq(comp)) match { + case Seq(car: ReferenceTarget) => wr.copy(name=car.ref) + } + case ex => ex.map(fixupRefs(instMap, currentModule)) + } + } def onStmt(currentModule: ModuleTarget)(s: Statement): Statement = { val currentModuleName = currentModule.module val ns = namespaceMap.getOrElseUpdate(currentModuleName, Namespace(iGraph.moduleMap(currentModuleName))) - val instMap = instMaps(currentModuleName) + val instMap = instMaps(OfModule(currentModuleName)) s match { - case wDef@ WDefInstance(_, instName, modName, _) if flatInstances.contains(s"${currentModuleName}.$instName") => + case wDef@ WDefInstance(_, instName, modName, _) if flatInstances.contains(OfModule(currentModuleName) -> Instance(instName)) => + val renames = renamesMap(OfModule(currentModuleName) -> Instance(instName)) val toInline = iGraph.moduleMap(modName) match { case m: ExtModule => throw new PassException(s"Cannot inline external module ${m.name}") case m: Module => m @@ -219,7 +266,10 @@ class InlineInstances extends Transform with RegisteredTransform { val ports = toInline.ports.map(p => DefWire(p.info, p.name, p.tpe)) - val bodyx = Block(ports :+ toInline.body) map onStmt(currentModule.copy(module = modName)) + val bodyx = { + val module = currentModule.copy(module = modName) + cache.getOrElseUpdate(module, Block(ports :+ toInline.body) map onStmt(module)) + } val names = "" +: Uniquify .enumerateNames(Uniquify.stmtToType(bodyx)(NoInfo, "")) @@ -232,31 +282,30 @@ class InlineInstances extends Transform with RegisteredTransform { */ val safePrefix = Uniquify.findValidPrefix(instName + inlineDelim, names, ns.cloneUnderlying - instName) - val prefixRenames = RenameMap() val prefixMap = mutable.HashMap.empty[String, String] val inlineTarget = currentModule.instOf(instName, modName) val renamedBody = bodyx - .map(appendNamePrefix(inlineTarget, safePrefix, ns, prefixMap, prefixRenames)) + .map(appendNamePrefix(inlineTarget, currentModule, safePrefix, ns, prefixMap, renames)) .map(appendRefPrefix(inlineTarget, prefixMap)) - val inlineRenames = RenameMap() - inlineRenames.record(inlineTarget, currentModule) - - renames = renames.andThen(prefixRenames).andThen(inlineRenames) + renames.record(inlineTarget, currentModule) renamedBody case sx => sx - .map(fixupRefs(instMap, currentModule, renames)) + .map(fixupRefs(instMap, currentModule)) .map(onStmt(currentModule)) } } val flatCircuit = c.copy(modules = c.modules.flatMap { case m if flatModules.contains(m.name) => None - case m => Some(m.map(onStmt(ModuleName(m.name, CircuitName(c.main))))) + case m => + Some(m.map(onStmt(ModuleName(m.name, CircuitName(c.main))))) }) + val renames = renamesSeq.tail.foldLeft(renamesSeq.head)(_ andThen _) + CircuitState(flatCircuit, LowForm, annos, Some(renames)) } } diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 5f48c883..36469064 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -459,6 +459,91 @@ class InlineInstancesTests extends LowTransformSpec { ) ) } + + "inlining both grandparent and grandchild" should "should work" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst foo of NestedInline + | inst bar of NestedNoInline + | foo.a <= a + | bar.a <= foo.b + | b <= bar.b + | module NestedInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | inst foo of NestedInline + | foo.a <= a + | b <= foo.b + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | wire i_foo_a : UInt<32> + | wire i_foo_b : UInt<32> + | i_foo_b <= i_foo_a + | inst i_bar of NestedNoInline + | i_b <= i_bar.b + | i_foo_a <= i_a + | i_bar.a <= i_foo_b + | b <= i_b + | i_a <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | wire foo_a : UInt<32> + | wire foo_b : UInt<32> + | foo_b <= foo_a + | b <= foo_b + | foo_a <= a + |""".stripMargin + val top = CircuitTarget("Top").module("Top") + val inlined = top.instOf("i", "Inline") + val nestedInlined = inlined.instOf("foo", "NestedInline") + val nestedNotInlined = inlined.instOf("bar", "NestedNoInline") + val innerNestedInlined = nestedNotInlined.instOf("foo", "NestedInline") + + executeWithAnnos(input, check, + Seq( + inline("Inline"), + inline("NestedInline"), + DummyAnno(inlined.ref("a")), + DummyAnno(inlined.ref("b")), + DummyAnno(nestedInlined.ref("a")), + DummyAnno(nestedInlined.ref("b")), + DummyAnno(nestedNotInlined.ref("a")), + DummyAnno(nestedNotInlined.ref("b")), + DummyAnno(innerNestedInlined.ref("a")), + DummyAnno(innerNestedInlined.ref("b")) + ), + Seq( + DummyAnno(top.ref("i_a")), + DummyAnno(top.ref("i_b")), + DummyAnno(top.ref("i_foo_a")), + DummyAnno(top.ref("i_foo_b")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("a")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("b")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_a")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_b")) + ) + ) + } } // Execution driven tests for inlining modules diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala index 7fe154ec..6e12dd5b 100644 --- a/src/test/scala/firrtlTests/PassTests.scala +++ b/src/test/scala/firrtlTests/PassTests.scala @@ -38,6 +38,14 @@ abstract class SimpleTransformSpec extends FlatSpec with FirrtlMatchers with Com logger.debug(actual) logger.debug(expected) (actual) should be (expected) + + annotations.foreach { anno => + logger.debug(anno.serialize) + } + + finalState.annotations.toSeq.foreach { anno => + logger.debug(anno.serialize) + } checkAnnotations.foreach { check => (finalState.annotations.toSeq) should contain (check) } |
