diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/Inline.scala | 149 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/InlineInstancesTests.scala | 17 |
2 files changed, 59 insertions, 107 deletions
diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 40bc7d7d..5c9d4367 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -25,9 +25,13 @@ class InlineInstances extends Transform { val inlineDelim = "$" override def name = "Inline Instances" - private def collectAnns(anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = + private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) { case ((modNames, instNames), ann) => ann match { + case InlineAnnotation(CircuitName(c)) => + (circuit.modules.collect { + case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) + }.toSet, instNames) case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) case _ => throw new PassException("Annotation must be InlineAnnotation") @@ -40,7 +44,7 @@ class InlineInstances extends Transform { getMyAnnotations(state) match { case Nil => CircuitState(state.circuit, state.form) case myAnnotations => - val (modNames, instNames) = collectAnns(myAnnotations) + val (modNames, instNames) = collectAnns(state.circuit, myAnnotations) run(state.circuit, modNames, instNames) } } @@ -84,99 +88,60 @@ class InlineInstances extends Transform { if (errors.nonEmpty) throw new PassExceptions(errors) } - def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): CircuitState = { - // Check annotations and circuit match up - check(c, modsToInline, instsToInline) - // ---- Rename functions/data ---- - val renameMap = mutable.HashMap[Named,Seq[Named]]() - // Updates renameMap with new names - def update(name: Named, rename: Named) = { - val existing = renameMap.getOrElse(name, Seq[Named]()) - if (!existing.contains(rename)) renameMap(name) = existing.:+(rename) + def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): CircuitState = { + def getInstancesOf(c: Circuit, modules: Set[String]): Set[String] = + c.modules.foldLeft(Set[String]()) { (set, d) => + d match { + case e: ExtModule => set + case m: Module => + val instances = mutable.HashSet[String]() + def findInstances(s: Statement): Statement = s match { + case WDefInstance(info, instName, moduleName, instTpe) if modules.contains(moduleName) => + instances += m.name + "." + instName + s + case sx => sx map findInstances + } + findInstances(m.body) + instances.toSet ++ set + } } - def set(name: Named, renames: Seq[Named]) = renameMap(name) = renames - // ---- Pass functions/data ---- - // Contains all unaltered modules - val originalModules = mutable.HashMap[String,DefModule]() - // Contains modules whose direct/indirect children modules have been inlined, and whose tagged instances have been inlined. - val inlinedModules = mutable.HashMap[String,DefModule]() - val cname = CircuitName(c.main) + // Check annotations and circuit match up + check(c, modsToInline, instsToInline) + val flatModules = modsToInline.map(m => m.name) + val flatInstances = instsToInline.map(i => i.module.name + "." + i.name) ++ getInstancesOf(c, flatModules) + val moduleMap = c.modules.foldLeft(Map[String, DefModule]()) { (map, m) => map + (m.name -> m) } - // Recursive. - def onModule(m: DefModule): DefModule = { - val inlinedInstances = mutable.ArrayBuffer[String]() - // Recursive. Replaces inst.port with inst$port - def onExp(e: Expression): Expression = e match { - case WSubField(WRef(ref, _, _, _), field, tpe, gen) => - // Relies on instance declaration before any instance references - if (inlinedInstances.contains(ref)) { - val newName = ref + inlineDelim + field - set(ComponentName(ref, ModuleName(m.name, cname)), Seq.empty) - WRef(newName, tpe, WireKind, gen) - } - else e - case ex => ex map onExp - } - // Recursive. Inlines tagged instances - def onStmt(s: Statement): Statement = s match { - case WDefInstance(info, instName, moduleName, instTpe) => - def rename(name:String): String = { - val newName = instName + inlineDelim + name - update(ComponentName(name, ModuleName(moduleName, cname)), ComponentName(newName, ModuleName(m.name, cname))) - newName - } - // Rewrites references in inlined statements from ref to inst$ref - def renameStmt(s: Statement): Statement = { - def renameExp(e: Expression): Expression = { - e map renameExp match { - case WRef(name, tpe, kind, gen) => WRef(rename(name), tpe, kind, gen) - case ex => ex - } - } - s map rename map renameStmt map renameExp - } - val shouldInline = - modsToInline.contains(ModuleName(moduleName, cname)) || - instsToInline.contains(ComponentName(instName, ModuleName(m.name, cname))) - // Used memoized instance if available - val instModule = - if (inlinedModules.contains(name)) inlinedModules(name) - else { - // Warning - can infinitely recurse if there is an instance loop - onModule(originalModules(moduleName)) - } - if (shouldInline) { - inlinedInstances += instName - val instInModule = instModule match { - case m: ExtModule => throw new PassException("Cannot inline external module") - case m: Module => m - } - val stmts = mutable.ArrayBuffer[Statement]() - for (p <- instInModule.ports) { - stmts += DefWire(p.info, rename(p.name), p.tpe) - } - stmts += renameStmt(instInModule.body) - Block(stmts.toSeq) - } else s - case sx => sx map onExp map onStmt - } - m match { - case Module(info, name, ports, body) => - val mx = Module(info, name, ports, onStmt(body)) - inlinedModules(name) = mx - mx - case mx: ExtModule => - inlinedModules(mx.name) = mx - mx - } - } + def appendNamePrefix(prefix: String)(name:String): String = prefix + name + def appendRefPrefix(prefix: String, currentModule: String)(e: Expression): Expression = e match { + case WSubField(WRef(ref, _, InstanceKind, _), field, tpe, gen) if flatInstances.contains(currentModule + "." + ref) => + WRef(prefix + ref + inlineDelim + field, tpe, WireKind, gen) + case WRef(name, tpe, kind, gen) => WRef(prefix + name, tpe, kind, gen) + case ex => ex map appendRefPrefix(prefix, currentModule) + } - c.modules.foreach{ m => originalModules(m.name) = m} - val top = c.modules.find(m => m.name == c.main).get - onModule(top) - val modulesx = c.modules.map(m => inlinedModules(m.name)) - CircuitState(Circuit(c.info, modulesx, c.main), LowForm, None, Some(RenameMap(renameMap.toMap))) - } + def onStmt(prefix: String, currentModule: String)(s: Statement): Statement = s match { + case WDefInstance(info, instName, moduleName, instTpe) => + // Rewrites references in inlined statements from ref to inst$ref + val shouldInline = flatInstances.contains(currentModule + "." + instName) + // Used memoized instance if available + if (shouldInline) { + val toInline = moduleMap(moduleName) match { + case m: ExtModule => throw new PassException("Cannot inline external module") + case m: Module => m + } + val stmts = toInline.ports.map(p => DefWire(p.info, p.name, p.tpe)) :+ toInline.body + onStmt(prefix + instName + inlineDelim, moduleName)(Block(stmts)) + } else s + case sx => sx map appendRefPrefix(prefix, currentModule) map onStmt(prefix, currentModule) map appendNamePrefix(prefix) + } + + val flatCircuit = c.copy(modules = c.modules.flatMap { + case m if flatModules.contains(m.name) => None + case m => + Some(m map onStmt("", m.name)) + }) + CircuitState(flatCircuit, LowForm, None, None) + } } diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 36b007e8..92ed1195 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -49,11 +49,7 @@ class InlineInstancesTests extends LowTransformSpec { | wire i$b : UInt<32> | i$b <= i$a | b <= i$b - | i$a <= a - | module Inline : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin + | i$a <= a""".stripMargin val writer = new StringWriter() val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Inline", CircuitName("Top"))))) execute(writer, aMap, input, check) @@ -87,11 +83,7 @@ class InlineInstancesTests extends LowTransformSpec { | i1$b <= i1$a | b <= i1$b | i0$a <= a - | i1$a <= i0$b - | module Simple : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin + | i1$a <= i0$b""".stripMargin val writer = new StringWriter() val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Simple", CircuitName("Top"))))) execute(writer, aMap, input, check) @@ -166,10 +158,6 @@ class InlineInstancesTests extends LowTransformSpec { | b <= i1.b | i0$a <= a | i1.a <= i0$b - | module A : - | input a : UInt<32> - | output b : UInt<32> - | b <= a | module B : | input a : UInt<32> | output b : UInt<32> @@ -183,7 +171,6 @@ class InlineInstancesTests extends LowTransformSpec { execute(writer, aMap, input, check) } - // ---- Errors ---- // 1) ext module "External module" should "not be inlined" in { |
