aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/Inline.scala149
-rw-r--r--src/test/scala/firrtlTests/InlineInstancesTests.scala17
2 files changed, 59 insertions, 107 deletions
diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala
index 40bc7d7d..5c9d4367 100644
--- a/src/main/scala/firrtl/passes/Inline.scala
+++ b/src/main/scala/firrtl/passes/Inline.scala
@@ -25,9 +25,13 @@ class InlineInstances extends Transform {
val inlineDelim = "$"
override def name = "Inline Instances"
- private def collectAnns(anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
+ private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) {
case ((modNames, instNames), ann) => ann match {
+ case InlineAnnotation(CircuitName(c)) =>
+ (circuit.modules.collect {
+ case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
+ }.toSet, instNames)
case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
case _ => throw new PassException("Annotation must be InlineAnnotation")
@@ -40,7 +44,7 @@ class InlineInstances extends Transform {
getMyAnnotations(state) match {
case Nil => CircuitState(state.circuit, state.form)
case myAnnotations =>
- val (modNames, instNames) = collectAnns(myAnnotations)
+ val (modNames, instNames) = collectAnns(state.circuit, myAnnotations)
run(state.circuit, modNames, instNames)
}
}
@@ -84,99 +88,60 @@ class InlineInstances extends Transform {
if (errors.nonEmpty) throw new PassExceptions(errors)
}
- def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): CircuitState = {
- // Check annotations and circuit match up
- check(c, modsToInline, instsToInline)
- // ---- Rename functions/data ----
- val renameMap = mutable.HashMap[Named,Seq[Named]]()
- // Updates renameMap with new names
- def update(name: Named, rename: Named) = {
- val existing = renameMap.getOrElse(name, Seq[Named]())
- if (!existing.contains(rename)) renameMap(name) = existing.:+(rename)
+ def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): CircuitState = {
+ def getInstancesOf(c: Circuit, modules: Set[String]): Set[String] =
+ c.modules.foldLeft(Set[String]()) { (set, d) =>
+ d match {
+ case e: ExtModule => set
+ case m: Module =>
+ val instances = mutable.HashSet[String]()
+ def findInstances(s: Statement): Statement = s match {
+ case WDefInstance(info, instName, moduleName, instTpe) if modules.contains(moduleName) =>
+ instances += m.name + "." + instName
+ s
+ case sx => sx map findInstances
+ }
+ findInstances(m.body)
+ instances.toSet ++ set
+ }
}
- def set(name: Named, renames: Seq[Named]) = renameMap(name) = renames
- // ---- Pass functions/data ----
- // Contains all unaltered modules
- val originalModules = mutable.HashMap[String,DefModule]()
- // Contains modules whose direct/indirect children modules have been inlined, and whose tagged instances have been inlined.
- val inlinedModules = mutable.HashMap[String,DefModule]()
- val cname = CircuitName(c.main)
+ // Check annotations and circuit match up
+ check(c, modsToInline, instsToInline)
+ val flatModules = modsToInline.map(m => m.name)
+ val flatInstances = instsToInline.map(i => i.module.name + "." + i.name) ++ getInstancesOf(c, flatModules)
+ val moduleMap = c.modules.foldLeft(Map[String, DefModule]()) { (map, m) => map + (m.name -> m) }
- // Recursive.
- def onModule(m: DefModule): DefModule = {
- val inlinedInstances = mutable.ArrayBuffer[String]()
- // Recursive. Replaces inst.port with inst$port
- def onExp(e: Expression): Expression = e match {
- case WSubField(WRef(ref, _, _, _), field, tpe, gen) =>
- // Relies on instance declaration before any instance references
- if (inlinedInstances.contains(ref)) {
- val newName = ref + inlineDelim + field
- set(ComponentName(ref, ModuleName(m.name, cname)), Seq.empty)
- WRef(newName, tpe, WireKind, gen)
- }
- else e
- case ex => ex map onExp
- }
- // Recursive. Inlines tagged instances
- def onStmt(s: Statement): Statement = s match {
- case WDefInstance(info, instName, moduleName, instTpe) =>
- def rename(name:String): String = {
- val newName = instName + inlineDelim + name
- update(ComponentName(name, ModuleName(moduleName, cname)), ComponentName(newName, ModuleName(m.name, cname)))
- newName
- }
- // Rewrites references in inlined statements from ref to inst$ref
- def renameStmt(s: Statement): Statement = {
- def renameExp(e: Expression): Expression = {
- e map renameExp match {
- case WRef(name, tpe, kind, gen) => WRef(rename(name), tpe, kind, gen)
- case ex => ex
- }
- }
- s map rename map renameStmt map renameExp
- }
- val shouldInline =
- modsToInline.contains(ModuleName(moduleName, cname)) ||
- instsToInline.contains(ComponentName(instName, ModuleName(m.name, cname)))
- // Used memoized instance if available
- val instModule =
- if (inlinedModules.contains(name)) inlinedModules(name)
- else {
- // Warning - can infinitely recurse if there is an instance loop
- onModule(originalModules(moduleName))
- }
- if (shouldInline) {
- inlinedInstances += instName
- val instInModule = instModule match {
- case m: ExtModule => throw new PassException("Cannot inline external module")
- case m: Module => m
- }
- val stmts = mutable.ArrayBuffer[Statement]()
- for (p <- instInModule.ports) {
- stmts += DefWire(p.info, rename(p.name), p.tpe)
- }
- stmts += renameStmt(instInModule.body)
- Block(stmts.toSeq)
- } else s
- case sx => sx map onExp map onStmt
- }
- m match {
- case Module(info, name, ports, body) =>
- val mx = Module(info, name, ports, onStmt(body))
- inlinedModules(name) = mx
- mx
- case mx: ExtModule =>
- inlinedModules(mx.name) = mx
- mx
- }
- }
+ def appendNamePrefix(prefix: String)(name:String): String = prefix + name
+ def appendRefPrefix(prefix: String, currentModule: String)(e: Expression): Expression = e match {
+ case WSubField(WRef(ref, _, InstanceKind, _), field, tpe, gen) if flatInstances.contains(currentModule + "." + ref) =>
+ WRef(prefix + ref + inlineDelim + field, tpe, WireKind, gen)
+ case WRef(name, tpe, kind, gen) => WRef(prefix + name, tpe, kind, gen)
+ case ex => ex map appendRefPrefix(prefix, currentModule)
+ }
- c.modules.foreach{ m => originalModules(m.name) = m}
- val top = c.modules.find(m => m.name == c.main).get
- onModule(top)
- val modulesx = c.modules.map(m => inlinedModules(m.name))
- CircuitState(Circuit(c.info, modulesx, c.main), LowForm, None, Some(RenameMap(renameMap.toMap)))
- }
+ def onStmt(prefix: String, currentModule: String)(s: Statement): Statement = s match {
+ case WDefInstance(info, instName, moduleName, instTpe) =>
+ // Rewrites references in inlined statements from ref to inst$ref
+ val shouldInline = flatInstances.contains(currentModule + "." + instName)
+ // Used memoized instance if available
+ if (shouldInline) {
+ val toInline = moduleMap(moduleName) match {
+ case m: ExtModule => throw new PassException("Cannot inline external module")
+ case m: Module => m
+ }
+ val stmts = toInline.ports.map(p => DefWire(p.info, p.name, p.tpe)) :+ toInline.body
+ onStmt(prefix + instName + inlineDelim, moduleName)(Block(stmts))
+ } else s
+ case sx => sx map appendRefPrefix(prefix, currentModule) map onStmt(prefix, currentModule) map appendNamePrefix(prefix)
+ }
+
+ val flatCircuit = c.copy(modules = c.modules.flatMap {
+ case m if flatModules.contains(m.name) => None
+ case m =>
+ Some(m map onStmt("", m.name))
+ })
+ CircuitState(flatCircuit, LowForm, None, None)
+ }
}
diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala
index 36b007e8..92ed1195 100644
--- a/src/test/scala/firrtlTests/InlineInstancesTests.scala
+++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala
@@ -49,11 +49,7 @@ class InlineInstancesTests extends LowTransformSpec {
| wire i$b : UInt<32>
| i$b <= i$a
| b <= i$b
- | i$a <= a
- | module Inline :
- | input a : UInt<32>
- | output b : UInt<32>
- | b <= a""".stripMargin
+ | i$a <= a""".stripMargin
val writer = new StringWriter()
val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Inline", CircuitName("Top")))))
execute(writer, aMap, input, check)
@@ -87,11 +83,7 @@ class InlineInstancesTests extends LowTransformSpec {
| i1$b <= i1$a
| b <= i1$b
| i0$a <= a
- | i1$a <= i0$b
- | module Simple :
- | input a : UInt<32>
- | output b : UInt<32>
- | b <= a""".stripMargin
+ | i1$a <= i0$b""".stripMargin
val writer = new StringWriter()
val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Simple", CircuitName("Top")))))
execute(writer, aMap, input, check)
@@ -166,10 +158,6 @@ class InlineInstancesTests extends LowTransformSpec {
| b <= i1.b
| i0$a <= a
| i1.a <= i0$b
- | module A :
- | input a : UInt<32>
- | output b : UInt<32>
- | b <= a
| module B :
| input a : UInt<32>
| output b : UInt<32>
@@ -183,7 +171,6 @@ class InlineInstancesTests extends LowTransformSpec {
execute(writer, aMap, input, check)
}
-
// ---- Errors ----
// 1) ext module
"External module" should "not be inlined" in {