diff options
| author | Kevin Laeufer | 2020-07-17 18:07:54 -0700 |
|---|---|---|
| committer | GitHub | 2020-07-18 01:07:54 +0000 |
| commit | 1b9f4ddff4102fee72ae4dd8c111c82c32e42d5d (patch) | |
| tree | fee6379c83e4026edcf3d577a2a9474024ed0b59 /src | |
| parent | 5f70175d24cbeeef2ffae3fb00b99e06c5462bd0 (diff) | |
Faster dedup instance graph (#1732)
* dedup: add faster InstanceGraph implementation and use it in dedup
The new implementation takes care not to hash the instance
types contained in DefInstance nodes.
This should make dedup considerably faster.
* FastInstanceGraph: cache vertices for faster findInstancesInHierarchy
* FastInstanceGraph: remove the parent name field since it isn't actually necessary
* FastInstanceGraph -> InstanceKeyGraph
* InstanceGraph: describe performance problems.
* InstanceKeyGraph: turn moduleMap into a def instead of a val
This will make changing implementation details much easier
in the future.
* InstanceKeyGraph: return childInstances as Seq instead of Map
This ensures a deterministic iteration order and it
can easily be turned into a Map for O(1) accesses.
* InstanceKeyGraph: add tests for public methods
* InstanceKeyGraph: group public methods together
* InstanceKeyGraphSpec: fix wording of a comment
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src')
4 files changed, 264 insertions, 18 deletions
diff --git a/src/main/scala/firrtl/analyses/InstanceGraph.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala index 7b60b110..ddd5eb8b 100644 --- a/src/main/scala/firrtl/analyses/InstanceGraph.scala +++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala @@ -15,6 +15,15 @@ import firrtl.annotations.TargetToken._ * * @constructor constructs an instance graph from a Circuit * @param c the Circuit to analyze + * @note The current implementation has some performance problems, which is why [[InstanceKeyGraph]] + * exists and should be preferred for new use cases. Eventually the old class will be deprecated + * in favor of the new implementation. + * The performance problems in the old implementation stem from the fact that DefInstance is used as the + * key to the underlying Map. DefInstance contains the type of the module besides the module and instance names. + * This type is not needed as it can be inferred from the module name. If the module name is the same, + * the type will be the same and vice versa. + * Hashing and comparing deep bundle types however is inefficient which can manifest in slower then necessary + * lookups and insertions. */ class InstanceGraph(c: Circuit) { diff --git a/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala b/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala new file mode 100644 index 00000000..ab3c9742 --- /dev/null +++ b/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala @@ -0,0 +1,116 @@ +// See LICENSE for license details. + +package firrtl.analyses + +import firrtl.annotations._ +import firrtl.annotations.TargetToken._ +import firrtl.graph.{DiGraph, MutableDiGraph} +import firrtl.ir + +import scala.collection.mutable + +/** A class representing the instance hierarchy of firrtl Circuit + * This is a faster version of the old `InstanceGraph` which only uses + * pairs of InstanceName and Module name as vertex keys instead of using WDefInstance + * which will hash the instance type causing some performance issues. + */ +class InstanceKeyGraph(c: ir.Circuit) { + import InstanceKeyGraph._ + + private val nameToModule: Map[String, ir.DefModule] = c.modules.map({m => (m.name,m) }).toMap + private val childInstances: Seq[(String, Seq[InstanceKey])] = c.modules.map { m => + m.name -> InstanceKeyGraph.collectInstances(m) + } + private val instantiated = childInstances.flatMap(_._2).map(_.module).toSet + private val roots = c.modules.map(_.name).filterNot(instantiated) + private val graph = buildGraph(childInstances, roots) + private val circuitTopInstance = topKey(c.main) + // cache vertices to speed up repeat calls to findInstancesInHierarchy + private lazy val vertices = graph.getVertices + + /** A list of absolute paths (each represented by a Seq of instances) of all module instances in the Circuit. */ + private lazy val fullHierarchy: mutable.LinkedHashMap[InstanceKey, Seq[Seq[InstanceKey]]] = + graph.pathsInDAG(circuitTopInstance) + + /** maps module names to the DefModule node */ + def moduleMap: Map[String, ir.DefModule] = nameToModule + + /** Module order from highest module to leaf module */ + def moduleOrder: Seq[ir.DefModule] = graph.transformNodes(_.module).linearize.map(nameToModule(_)) + + /** Returns a sequence that can be turned into a map from module name to instances defined in said module. */ + def getChildInstances: Seq[(String, Seq[InstanceKey])] = childInstances + + /** Finds the absolute paths (each represented by a Seq of instances + * representing the chain of hierarchy) of all instances of a particular + * module. Note that this includes one implicit instance of the top (main) + * module of the circuit. If the module is not instantiated within the + * hierarchy of the top module of the circuit, it will return Nil. + * + * @param module the name of the selected module + * @return a Seq[ Seq[WDefInstance] ] of absolute instance paths + */ + def findInstancesInHierarchy(module: String): Seq[Seq[InstanceKey]] = { + val instances = vertices.filter(_.module == module).toSeq + instances.flatMap{ i => fullHierarchy.getOrElse(i, Nil) } + } +} + + +object InstanceKeyGraph { + /** We want to only use this untyped version as key because hashing bundle types is expensive + * @param name the name of the instance + * @param module the name of the module that is instantiated + */ + case class InstanceKey(name: String, module: String) { + def Instance: Instance = TargetToken.Instance(name) + def OfModule: OfModule = TargetToken.OfModule(module) + def toTokens: (Instance, OfModule) = (Instance, OfModule) + } + + /** Finds all instance definitions in a firrtl Module. */ + def collectInstances(m: ir.DefModule): Seq[InstanceKey] = m match { + case _ : ir.ExtModule => Seq() + case ir.Module(_, _, _, body) => { + val instances = mutable.ArrayBuffer[InstanceKey]() + def onStmt(s: ir.Statement): Unit = s match { + case firrtl.WDefInstance(_, name, module, _) => instances += InstanceKey(name, module) + case ir.DefInstance(_, name, module, _) => instances += InstanceKey(name, module) + case _: firrtl.WDefInstanceConnector => + firrtl.Utils.throwInternalError("Expecting WDefInstance, found a WDefInstanceConnector!") + case other => other.foreachStmt(onStmt) + } + onStmt(body) + instances + } + } + + private def topKey(module: String): InstanceKey = InstanceKey(module, module) + + private def buildGraph(childInstances: Seq[(String, Seq[InstanceKey])], roots: Iterable[String]): + DiGraph[InstanceKey] = { + val instanceGraph = new MutableDiGraph[InstanceKey] + val childInstanceMap = childInstances.toMap + + // iterate over all modules that are not instantiated and thus act as a root + roots.foreach { subTop => + // create a root node + val topInstance = topKey(subTop) + // graph traversal + val instanceQueue = new mutable.Queue[InstanceKey] + instanceQueue.enqueue(topInstance) + while (instanceQueue.nonEmpty) { + val current = instanceQueue.dequeue + instanceGraph.addVertex(current) + for (child <- childInstanceMap(current.module)) { + if (!instanceGraph.contains(child)) { + instanceQueue.enqueue(child) + instanceGraph.addVertex(child) + } + instanceGraph.addEdge(current, child) + } + } + } + instanceGraph + } +} diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index ba06ba4b..03b5faa9 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -6,12 +6,12 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ -import firrtl.analyses.InstanceGraph +import firrtl.analyses.InstanceKeyGraph import firrtl.annotations._ import firrtl.passes.{InferTypes, MemPortUtils} import firrtl.Utils.{kind, splitRef, throwInternalError} import firrtl.annotations.transforms.DupedResult -import firrtl.annotations.TargetToken.{OfModule, Instance} +import firrtl.annotations.TargetToken.{Instance, OfModule} import firrtl.options.{HasShellOptions, ShellOption} import logger.LazyLogging @@ -157,7 +157,7 @@ class DedupModules extends Transform with DependencyAPIMigration { moduleRenameMap.recordAll(map) // Build instanceify renaming map - val instanceGraph = new InstanceGraph(c) + val instanceGraph = new InstanceKeyGraph(c) val instanceify = RenameMap() val moduleName2Index = c.modules.map(_.name).zipWithIndex.map { case (n, i) => { @@ -171,16 +171,14 @@ class DedupModules extends Transform with DependencyAPIMigration { // get the ordered set of instances a module, includes new Deduped modules val getChildrenInstances = { - val childrenMap = instanceGraph.getChildrenInstances - val newModsMap: Map[String, mutable.LinkedHashSet[WDefInstance]] = dedupMap.map { - case (name, m: Module) => - val set = new mutable.LinkedHashSet[WDefInstance] - InstanceGraph.collectInstances(set)(m.body) - m.name -> set - case (name, m: DefModule) => - m.name -> mutable.LinkedHashSet.empty[WDefInstance] - }.toMap - (mod: String) => childrenMap.get(mod).getOrElse(newModsMap(mod)) + val childrenMap = instanceGraph.getChildInstances.toMap + val newModsMap = dedupMap.map { + case (_, m: Module) => + m.name -> InstanceKeyGraph.collectInstances(m) + case (_, m: DefModule) => + m.name -> List() + } + (mod: String) => childrenMap.getOrElse(mod, newModsMap(mod)) } val instanceNameMap: Map[OfModule, Map[Instance, Instance]] = { @@ -200,7 +198,7 @@ class DedupModules extends Transform with DependencyAPIMigration { // If dedupedAnnos is exactly annos, contains is because dedupedAnnos is type Option val newTargets = paths.map { path => val root: IsModule = ct.module(c) - path.foldLeft(root -> root) { case ((oldRelPath, newRelPath), WDefInstance(_, name, mod, _)) => + path.foldLeft(root -> root) { case ((oldRelPath, newRelPath), InstanceKeyGraph.InstanceKey(name, mod)) => if(mod == c) { val mod = CircuitTarget(c).module(c) mod -> mod @@ -333,9 +331,8 @@ object DedupModules extends LazyLogging { // If black box, return it (it has no instances) if (module.isInstanceOf[ExtModule]) return module - // Get all instances to know what to rename in the module - val instances = mutable.Set[WDefInstance]() - InstanceGraph.collectInstances(instances)(module.asInstanceOf[Module].body) + // Get all instances to know what to rename in the module s + val instances = InstanceKeyGraph.collectInstances(module) val instanceModuleMap = instances.map(i => i.name -> i.module).toMap def getNewModule(old: String): DefModule = { @@ -470,7 +467,7 @@ object DedupModules extends LazyLogging { renameMap: RenameMap): Map[String, DefModule] = { val (moduleMap, moduleLinearization) = { - val iGraph = new InstanceGraph(circuit) + val iGraph = new InstanceKeyGraph(circuit) (iGraph.moduleMap, iGraph.moduleOrder.reverse) } val main = circuit.main diff --git a/src/test/scala/firrtlTests/analyses/InstanceKeyGraphSpec.scala b/src/test/scala/firrtlTests/analyses/InstanceKeyGraphSpec.scala new file mode 100644 index 00000000..8d073ecb --- /dev/null +++ b/src/test/scala/firrtlTests/analyses/InstanceKeyGraphSpec.scala @@ -0,0 +1,124 @@ +// See LICENSE for license details. + +package firrtlTests.analyses + +import firrtl.analyses.InstanceKeyGraph +import firrtl.analyses.InstanceKeyGraph.InstanceKey +import firrtl.testutils.FirrtlFlatSpec + +class InstanceKeyGraphSpec extends FirrtlFlatSpec { + behavior of "InstanceKeyGraph" + + // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to + // experience non-determinism + it should "preserve Module declaration order" in { + val input = """ + |circuit Top : + | module Top : + | inst c1 of Child1 + | inst c2 of Child2 + | module Child1 : + | inst a of Child1a + | inst b of Child1b + | skip + | module Child1a : + | skip + | module Child1b : + | skip + | module Child2 : + | skip + |""".stripMargin + val circuit = parse(input) + val instGraph = new InstanceKeyGraph(circuit) + val childMap = instGraph.getChildInstances + childMap.map(_._1) should equal (Seq("Top", "Child1", "Child1a", "Child1b", "Child2")) + } + + // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to + // experience non-determinism + it should "preserve Instance declaration order" in { + val input = """ + |circuit Top : + | module Top : + | inst a of Child + | inst b of Child + | inst c of Child + | inst d of Child + | inst e of Child + | inst f of Child + | module Child : + | skip + |""".stripMargin + val circuit = parse(input) + val instGraph = new InstanceKeyGraph(circuit) + val childMap = instGraph.getChildInstances.toMap + val insts = childMap("Top").map(_.name) + insts should equal (Seq("a", "b", "c", "d", "e", "f")) + } + + it should "compute a correct and deterministic module order" in { + val input = """ + |circuit Top : + | module Top : + | inst c1 of Child1 + | inst c2 of Child2 + | inst c4 of Child4 + | inst c3 of Child3 + | module Child1 : + | inst a of Child1a + | inst b of Child1b + | skip + | module Child1a : + | skip + | module Child1b : + | skip + | module Child2 : + | skip + | module Child3 : + | skip + | module Child4 : + | skip + |""".stripMargin + val circuit = parse(input) + val instGraph = new InstanceKeyGraph(circuit) + val order = instGraph.moduleOrder.map(_.name) + // Where it has freedom, the instance declaration order will be reversed. + order should equal (Seq("Top", "Child3", "Child4", "Child2", "Child1", "Child1b", "Child1a")) + } + + it should "find hierarchical instances correctly in disconnected hierarchies" in { + val input = + """circuit Top : + | module Top : + | inst c of Child1 + | module Child1 : + | skip + | + | module Top2 : + | inst a of Child2 + | inst b of Child3 + | skip + | module Child2 : + | inst a of Child2a + | inst b of Child2b + | skip + | module Child2a : + | skip + | module Child2b : + | skip + | module Child3 : + | skip + |""".stripMargin + + val circuit = parse(input) + val iGraph = new InstanceKeyGraph(circuit) + iGraph.findInstancesInHierarchy("Top") shouldBe Seq(Seq(InstanceKey("Top", "Top"))) + iGraph.findInstancesInHierarchy("Child1") shouldBe Seq(Seq(InstanceKey("Top", "Top"), InstanceKey("c", "Child1"))) + iGraph.findInstancesInHierarchy("Top2") shouldBe Nil + iGraph.findInstancesInHierarchy("Child2") shouldBe Nil + iGraph.findInstancesInHierarchy("Child2a") shouldBe Nil + iGraph.findInstancesInHierarchy("Child2b") shouldBe Nil + iGraph.findInstancesInHierarchy("Child3") shouldBe Nil + } + +} |
