aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlbert Chen2019-09-19 14:55:15 -0700
committerAlbert Magyar2019-09-19 14:55:14 -0700
commit5e9b286185e98c58e5fde1987c48d085ebdb1e25 (patch)
tree41b292cf71686186442cbffab03396a84a2adfb2
parent932b5d1ea66d3cc2475a22d21c237b0ed2ee9c09 (diff)
Faster inline renaming (#1184)
* dont chain inline and refix RenameMaps * cache already inlined modules * reduce number of chained RenameMaps * InlineInstances: cleanup and add comments
-rw-r--r--src/main/scala/firrtl/passes/Inline.scala137
-rw-r--r--src/test/scala/firrtlTests/InlineInstancesTests.scala85
-rw-r--r--src/test/scala/firrtlTests/PassTests.scala8
3 files changed, 186 insertions, 44 deletions
diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala
index b9e67c82..0ca98ac5 100644
--- a/src/main/scala/firrtl/passes/Inline.scala
+++ b/src/main/scala/firrtl/passes/Inline.scala
@@ -6,7 +6,9 @@ package passes
import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
-import firrtl.analyses.InstanceGraph
+import firrtl.annotations.TargetToken.{Instance, OfModule}
+import firrtl.analyses.{InstanceGraph}
+import firrtl.graph.{DiGraph, MutableDiGraph}
import firrtl.stage.RunFirrtlTransformAnnotation
import firrtl.options.{RegisteredTransform, ShellOption}
@@ -108,15 +110,15 @@ class InlineInstances extends Transform with RegisteredTransform {
def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName], annos: AnnotationSeq): CircuitState = {
- def getInstancesOf(c: Circuit, modules: Set[String]): Set[String] =
- c.modules.foldLeft(Set[String]()) { (set, d) =>
+ def getInstancesOf(c: Circuit, modules: Set[String]): Set[(OfModule, Instance)] =
+ c.modules.foldLeft(Set[(OfModule, Instance)]()) { (set, d) =>
d match {
case e: ExtModule => set
case m: Module =>
- val instances = mutable.HashSet[String]()
+ val instances = mutable.HashSet[(OfModule, Instance)]()
def findInstances(s: Statement): Statement = s match {
case WDefInstance(info, instName, moduleName, instTpe) if modules.contains(moduleName) =>
- instances += m.name + "." + instName
+ instances += (OfModule(m.name) -> Instance(instName))
s
case sx => sx map findInstances
}
@@ -128,19 +130,20 @@ class InlineInstances extends Transform with RegisteredTransform {
// Check annotations and circuit match up
check(c, modsToInline, instsToInline)
val flatModules = modsToInline.map(m => m.name)
- val flatInstances = instsToInline.map(i => i.module.name + "." + i.name) ++ getInstancesOf(c, flatModules)
+ val flatInstances: Set[(OfModule, Instance)] = instsToInline.map(i => OfModule(i.module.name) -> Instance(i.name)) ++ getInstancesOf(c, flatModules)
val iGraph = new InstanceGraph(c)
val namespaceMap = collection.mutable.Map[String, Namespace]()
// Map of Module name to Map of instance name to Module name
- val instMaps: Map[String, Map[String, String]] = {
+ val instMaps: Map[OfModule, Map[Instance, OfModule]] = {
iGraph.graph.getEdgeMap.view.map { case (mod, children) =>
- mod.module -> children.view.map(i => i.name -> i.module).toMap
+ OfModule(mod.module) -> children.view.map(i => Instance(i.name) -> OfModule(i.module)).toMap
}.toMap
}
/** Add a prefix to all declarations updating a [[Namespace]] and appending to a [[RenameMap]] */
def appendNamePrefix(
currentModule: IsModule,
+ nextModule: IsModule,
prefix: String,
ns: Namespace,
renames: mutable.HashMap[String, String],
@@ -151,17 +154,17 @@ class InlineInstances extends Transform with RegisteredTransform {
}
ofModuleOpt match {
case None =>
- renameMap.record(currentModule.ref(name), currentModule.ref(prefix + name))
+ renameMap.record(currentModule.ref(name), nextModule.ref(prefix + name))
case Some(ofModule) =>
- renameMap.record(currentModule.instOf(name, ofModule), currentModule.instOf(prefix + name, ofModule))
+ renameMap.record(currentModule.instOf(name, ofModule), nextModule.instOf(prefix + name, ofModule))
}
renames(name) = prefix + name
prefix + name
}
s match {
- case s: WDefInstance => s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, prefix, ns, renames, renameMap))
- case other => s.map(onName(None)).map(appendNamePrefix(currentModule, prefix, ns, renames, renameMap))
+ case s: WDefInstance => s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap))
+ case other => s.map(onName(None)).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap))
}
}
@@ -180,38 +183,82 @@ class InlineInstances extends Transform with RegisteredTransform {
s.map(onExpr).map(appendRefPrefix(currentModule, renames))
}
- def fixupRefs(
- instMap: Map[String, String],
- currentModule: IsModule,
- renames: RenameMap)(e: Expression): Expression = e match {
- case wsf@ WSubField(wr@ WRef(ref, _, InstanceKind, _), field, tpe, gen) =>
- val inst = currentModule.instOf(ref, instMap(ref))
- val port = inst.ref(field)
- renames.get(port) match {
- case Some(Seq(p)) =>
- p match {
- case ReferenceTarget(_, _, Seq((TargetToken.Instance(r), TargetToken.OfModule(_))), f, Nil) =>
- wsf.copy(expr = wr.copy(name = r), name = f)
- case ReferenceTarget(_, _, Nil, r, Nil) => WRef(r, tpe, WireKind, gen)
- }
- case None => wsf
+ val cache = mutable.HashMap.empty[ModuleTarget, Statement]
+
+ /** renamesMap is a map of instances to [[RenameMap]].
+ * The keys are pairs of enclosing [[OfModule]] and [[Instance]]
+ * The [[RenameMap]]s in renamesMap are appear in renamesSeq
+ * in the order that they should be applied
+ */
+ val (renamesMap, renamesSeq) = {
+ val mutableDiGraph = new MutableDiGraph[(OfModule, Instance)]
+ // compute instance graph
+ instMaps.foreach { case (grandParentOfMod, parents) =>
+ parents.foreach { case (parentInst, parentOfMod) =>
+ val from = grandParentOfMod -> parentInst
+ mutableDiGraph.addVertex(from)
+ instMaps(parentOfMod).foreach { case (childInst, _) =>
+ val to = parentOfMod -> childInst
+ mutableDiGraph.addVertex(to)
+ mutableDiGraph.addEdge(from, to)
+ }
}
- case wr@ WRef(name, _, _, _) =>
- val comp = currentModule.ref(name) //ComponentName(name, currentModule)
- renames.get(comp).getOrElse(Seq(comp)) match {
- case Seq(car: ReferenceTarget) => wr.copy(name=car.ref)
+ }
+
+ val diGraph = DiGraph(mutableDiGraph)
+ val subgraph = diGraph.simplify(flatInstances)
+ val edges = subgraph.getEdgeMap
+
+ // calculate which [[RenameMap]] should be associated with each instance
+ val indexMap = new mutable.HashMap[(OfModule, Instance), Int]
+ flatInstances.foreach(v => indexMap(v) = 0)
+ subgraph.linearize.foreach { parent =>
+ edges(parent).foreach { child =>
+ indexMap(child) = indexMap(parent) + 1
}
- case ex => ex.map(fixupRefs(instMap, currentModule, renames))
+ }
+
+ val maxIdx = indexMap.values.max
+ val resultSeq = Seq.fill(maxIdx + 1)(RenameMap())
+ val resultMap = indexMap.mapValues(idx => resultSeq(maxIdx - idx))
+ (resultMap, resultSeq)
}
- var renames = RenameMap()
+ def fixupRefs(
+ instMap: Map[Instance, OfModule],
+ currentModule: IsModule)(e: Expression): Expression = {
+ e match {
+ case wsf@ WSubField(wr@ WRef(ref, _, InstanceKind, _), field, tpe, gen) =>
+ val inst = currentModule.instOf(ref, instMap(Instance(ref)).value)
+ val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance))
+ val port = inst.ref(field)
+ renamesOpt.flatMap(_.get(port)) match {
+ case Some(Seq(p)) =>
+ p match {
+ case ReferenceTarget(_, _, Seq((TargetToken.Instance(r), TargetToken.OfModule(_))), f, Nil) =>
+ wsf.copy(expr = wr.copy(name = r), name = f)
+ case ReferenceTarget(_, _, Nil, r, Nil) => WRef(r, tpe, WireKind, gen)
+ }
+ case None => wsf
+ }
+ case wr@ WRef(name, _, InstanceKind, _) =>
+ val inst = currentModule.instOf(name, instMap(Instance(name)).value)
+ val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance))
+ val comp = currentModule.ref(name)
+ renamesOpt.flatMap(_.get(comp)).getOrElse(Seq(comp)) match {
+ case Seq(car: ReferenceTarget) => wr.copy(name=car.ref)
+ }
+ case ex => ex.map(fixupRefs(instMap, currentModule))
+ }
+ }
def onStmt(currentModule: ModuleTarget)(s: Statement): Statement = {
val currentModuleName = currentModule.module
val ns = namespaceMap.getOrElseUpdate(currentModuleName, Namespace(iGraph.moduleMap(currentModuleName)))
- val instMap = instMaps(currentModuleName)
+ val instMap = instMaps(OfModule(currentModuleName))
s match {
- case wDef@ WDefInstance(_, instName, modName, _) if flatInstances.contains(s"${currentModuleName}.$instName") =>
+ case wDef@ WDefInstance(_, instName, modName, _) if flatInstances.contains(OfModule(currentModuleName) -> Instance(instName)) =>
+ val renames = renamesMap(OfModule(currentModuleName) -> Instance(instName))
val toInline = iGraph.moduleMap(modName) match {
case m: ExtModule => throw new PassException(s"Cannot inline external module ${m.name}")
case m: Module => m
@@ -219,7 +266,10 @@ class InlineInstances extends Transform with RegisteredTransform {
val ports = toInline.ports.map(p => DefWire(p.info, p.name, p.tpe))
- val bodyx = Block(ports :+ toInline.body) map onStmt(currentModule.copy(module = modName))
+ val bodyx = {
+ val module = currentModule.copy(module = modName)
+ cache.getOrElseUpdate(module, Block(ports :+ toInline.body) map onStmt(module))
+ }
val names = "" +: Uniquify
.enumerateNames(Uniquify.stmtToType(bodyx)(NoInfo, ""))
@@ -232,31 +282,30 @@ class InlineInstances extends Transform with RegisteredTransform {
*/
val safePrefix = Uniquify.findValidPrefix(instName + inlineDelim, names, ns.cloneUnderlying - instName)
- val prefixRenames = RenameMap()
val prefixMap = mutable.HashMap.empty[String, String]
val inlineTarget = currentModule.instOf(instName, modName)
val renamedBody = bodyx
- .map(appendNamePrefix(inlineTarget, safePrefix, ns, prefixMap, prefixRenames))
+ .map(appendNamePrefix(inlineTarget, currentModule, safePrefix, ns, prefixMap, renames))
.map(appendRefPrefix(inlineTarget, prefixMap))
- val inlineRenames = RenameMap()
- inlineRenames.record(inlineTarget, currentModule)
-
- renames = renames.andThen(prefixRenames).andThen(inlineRenames)
+ renames.record(inlineTarget, currentModule)
renamedBody
case sx =>
sx
- .map(fixupRefs(instMap, currentModule, renames))
+ .map(fixupRefs(instMap, currentModule))
.map(onStmt(currentModule))
}
}
val flatCircuit = c.copy(modules = c.modules.flatMap {
case m if flatModules.contains(m.name) => None
- case m => Some(m.map(onStmt(ModuleName(m.name, CircuitName(c.main)))))
+ case m =>
+ Some(m.map(onStmt(ModuleName(m.name, CircuitName(c.main)))))
})
+ val renames = renamesSeq.tail.foldLeft(renamesSeq.head)(_ andThen _)
+
CircuitState(flatCircuit, LowForm, annos, Some(renames))
}
}
diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala
index 5f48c883..36469064 100644
--- a/src/test/scala/firrtlTests/InlineInstancesTests.scala
+++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala
@@ -459,6 +459,91 @@ class InlineInstancesTests extends LowTransformSpec {
)
)
}
+
+ "inlining both grandparent and grandchild" should "should work" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst i of Inline
+ | i.a <= a
+ | b <= i.b
+ | module Inline :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst foo of NestedInline
+ | inst bar of NestedNoInline
+ | foo.a <= a
+ | bar.a <= foo.b
+ | b <= bar.b
+ | module NestedInline :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | b <= a
+ | module NestedNoInline :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst foo of NestedInline
+ | foo.a <= a
+ | b <= foo.b
+ |""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | wire i_a : UInt<32>
+ | wire i_b : UInt<32>
+ | wire i_foo_a : UInt<32>
+ | wire i_foo_b : UInt<32>
+ | i_foo_b <= i_foo_a
+ | inst i_bar of NestedNoInline
+ | i_b <= i_bar.b
+ | i_foo_a <= i_a
+ | i_bar.a <= i_foo_b
+ | b <= i_b
+ | i_a <= a
+ | module NestedNoInline :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | wire foo_a : UInt<32>
+ | wire foo_b : UInt<32>
+ | foo_b <= foo_a
+ | b <= foo_b
+ | foo_a <= a
+ |""".stripMargin
+ val top = CircuitTarget("Top").module("Top")
+ val inlined = top.instOf("i", "Inline")
+ val nestedInlined = inlined.instOf("foo", "NestedInline")
+ val nestedNotInlined = inlined.instOf("bar", "NestedNoInline")
+ val innerNestedInlined = nestedNotInlined.instOf("foo", "NestedInline")
+
+ executeWithAnnos(input, check,
+ Seq(
+ inline("Inline"),
+ inline("NestedInline"),
+ DummyAnno(inlined.ref("a")),
+ DummyAnno(inlined.ref("b")),
+ DummyAnno(nestedInlined.ref("a")),
+ DummyAnno(nestedInlined.ref("b")),
+ DummyAnno(nestedNotInlined.ref("a")),
+ DummyAnno(nestedNotInlined.ref("b")),
+ DummyAnno(innerNestedInlined.ref("a")),
+ DummyAnno(innerNestedInlined.ref("b"))
+ ),
+ Seq(
+ DummyAnno(top.ref("i_a")),
+ DummyAnno(top.ref("i_b")),
+ DummyAnno(top.ref("i_foo_a")),
+ DummyAnno(top.ref("i_foo_b")),
+ DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("a")),
+ DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("b")),
+ DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_a")),
+ DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_b"))
+ )
+ )
+ }
}
// Execution driven tests for inlining modules
diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala
index 7fe154ec..6e12dd5b 100644
--- a/src/test/scala/firrtlTests/PassTests.scala
+++ b/src/test/scala/firrtlTests/PassTests.scala
@@ -38,6 +38,14 @@ abstract class SimpleTransformSpec extends FlatSpec with FirrtlMatchers with Com
logger.debug(actual)
logger.debug(expected)
(actual) should be (expected)
+
+ annotations.foreach { anno =>
+ logger.debug(anno.serialize)
+ }
+
+ finalState.annotations.toSeq.foreach { anno =>
+ logger.debug(anno.serialize)
+ }
checkAnnotations.foreach { check =>
(finalState.annotations.toSeq) should contain (check)
}