From 15013df6f6ac2dafeb35d7ed15cf95c7ac8a5bef Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Tue, 15 Dec 2020 16:41:53 -0800 Subject: Improve performance of LowerTypes renaming (#2024) This is done by having LowerTypes uses two RenameMaps instead of one for each module. There is one for renaming instance paths, and one for renaming everything within modules. Also add some utilities: * TargetUtils for dealing with InstanceTargets * RenameMap.fromInstanceRenames--- src/main/scala/firrtl/RenameMap.scala | 54 +++++++++++++++++++++ .../scala/firrtl/annotations/TargetUtils.scala | 46 ++++++++++++++++++ src/main/scala/firrtl/passes/LowerTypes.scala | 42 +++++++++------- src/test/scala/firrtl/RenameMapPrivateSpec.scala | 39 +++++++++++++++ src/test/scala/firrtl/passes/LowerTypesSpec.scala | 32 +++++++++---- src/test/scala/firrtlTests/RenameMapSpec.scala | 2 + .../annotationTests/TargetUtilsSpec.scala | 56 ++++++++++++++++++++++ 7 files changed, 246 insertions(+), 25 deletions(-) create mode 100644 src/main/scala/firrtl/annotations/TargetUtils.scala create mode 100644 src/test/scala/firrtl/RenameMapPrivateSpec.scala create mode 100644 src/test/scala/firrtlTests/annotationTests/TargetUtilsSpec.scala diff --git a/src/main/scala/firrtl/RenameMap.scala b/src/main/scala/firrtl/RenameMap.scala index df98f72f..82c00ca5 100644 --- a/src/main/scala/firrtl/RenameMap.scala +++ b/src/main/scala/firrtl/RenameMap.scala @@ -4,7 +4,9 @@ package firrtl import annotations._ import firrtl.RenameMap.IllegalRenameException +import firrtl.analyses.InstanceKeyGraph import firrtl.annotations.TargetToken.{Field, Index, Instance, OfModule} +import TargetUtils.{instKeyPathToTarget, unfoldInstanceTargets} import scala.collection.mutable @@ -21,6 +23,58 @@ object RenameMap { rm } + /** RenameMap factory for simple renaming of instances + * + * @param graph [[InstanceKeyGraph]] from *before* renaming + * @param renames Mapping of old instance name to new within Modules + */ + private[firrtl] def fromInstanceRenames( + graph: InstanceKeyGraph, + renames: Map[OfModule, Map[Instance, Instance]] + ): RenameMap = { + def renameAll(it: InstanceTarget): InstanceTarget = { + var prevMod = OfModule(it.module) + val pathx = it.path.map { + case (inst, of) => + val instx = renames + .get(prevMod) + .flatMap(_.get(inst)) + .getOrElse(inst) + prevMod = of + instx -> of + } + // Sanity check, the last one should always be a rename (or we wouldn't be calling this method) + val instx = renames(prevMod)(Instance(it.instance)) + it.copy(path = pathx, instance = instx.value) + } + val underlying = new mutable.HashMap[CompleteTarget, Seq[CompleteTarget]] + val instOf: String => Map[String, String] = + graph.getChildInstances.toMap + // Laziness here is desirable, we only access each key once, some we don't access + .mapValues(_.map(k => k.name -> k.module).toMap) + for ((OfModule(module), instMapping) <- renames) { + val modLookup = instOf(module) + val parentInstances = graph.findInstancesInHierarchy(module) + for { + // For every instance of the Module where the renamed instance resides + parent <- parentInstances + parentTarget = instKeyPathToTarget(parent) + // Create the absolute InstanceTarget to be renamed + (Instance(from), _) <- instMapping // The to is given by renameAll + instMod = modLookup(from) + fromTarget = parentTarget.instOf(from, instMod) + // Ensure all renames apply to the InstanceTarget + toTarget = renameAll(fromTarget) + // RenameMap only allows 1 hit when looking up InstanceTargets, so rename all possible + // paths to this instance + (fromx, tox) <- unfoldInstanceTargets(fromTarget).zip(unfoldInstanceTargets(toTarget)) + } yield { + underlying(fromx) = List(tox) + } + } + new RenameMap(underlying) + } + /** Initialize a new RenameMap */ def apply(): RenameMap = new RenameMap diff --git a/src/main/scala/firrtl/annotations/TargetUtils.scala b/src/main/scala/firrtl/annotations/TargetUtils.scala new file mode 100644 index 00000000..164c430b --- /dev/null +++ b/src/main/scala/firrtl/annotations/TargetUtils.scala @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.annotations + +import firrtl._ +import firrtl.analyses.InstanceKeyGraph +import firrtl.analyses.InstanceKeyGraph.InstanceKey +import firrtl.annotations.TargetToken._ + +object TargetUtils { + + /** Turns an instance path into a corresponding [[IsModule]] + * + * @note First InstanceKey is treated as the [[CircuitTarget]] + * @param path Instance path + * @param start Module in instance path to be starting [[ModuleTarget]] + * @return [[IsModule]] corresponding to Instance path + */ + def instKeyPathToTarget(path: Seq[InstanceKey], start: Option[String] = None): IsModule = { + val head = path.head + val startx = start.getOrElse(head.module) + val top: IsModule = CircuitTarget(head.module).module(startx) // ~Top|Start + val pathx = path.dropWhile(_.module != startx) + if (pathx.isEmpty) top + else pathx.tail.foldLeft(top) { case (acc, key) => acc.instOf(key.name, key.module) } + } + + /** Calculates all [[InstanceTarget]]s that refer to the given [[IsModule]] + * + * {{{ + * ~Top|Top/a:A/b:B/c:C unfolds to: + * * ~Top|Top/a:A/b:B/c:C + * * ~Top|A/b:B/c:C + * * ~Top|B/c:C + * }}} + * @note [[ModuleTarget]] arguments return an empty Iterable + */ + def unfoldInstanceTargets(ismod: IsModule): Iterable[InstanceTarget] = { + // concretely use List which is fast in practice + def rec(im: IsModule): List[InstanceTarget] = im match { + case inst: InstanceTarget => inst :: rec(inst.stripHierarchy(1)) + case _ => Nil + } + rec(ismod) + } +} diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 592caf5d..0bd44a8c 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -8,8 +8,10 @@ import firrtl.annotations.{ MemoryInitAnnotation, MemoryRandomInitAnnotation, ModuleTarget, - ReferenceTarget + ReferenceTarget, + TargetToken } +import TargetToken.{Instance, OfModule} import firrtl.{ CircuitForm, CircuitState, @@ -73,16 +75,18 @@ object LowerTypes extends Transform with DependencyAPIMigration { val memInitByModule = memInitAnnos.map(_.asInstanceOf[MemoryInitAnnotation]).groupBy(_.target.encapsulatingModule) val c = CircuitTarget(state.circuit.main) - val resultAndRenames = state.circuit.modules.map(m => onModule(c, m, memInitByModule.getOrElse(m.name, Seq()))) + val refRenameMap = RenameMap() + val resultAndRenames = + state.circuit.modules.map(m => onModule(c, m, memInitByModule.getOrElse(m.name, Seq()), refRenameMap)) val result = state.circuit.copy(modules = resultAndRenames.map(_._1)) // memory init annotations could have been modified val newAnnos = otherAnnos ++ resultAndRenames.flatMap(_._3) - // chain module renames in topological order - val moduleRenames = resultAndRenames.map { case (m, r, _) => m.name -> r }.toMap - val moduleOrderBottomUp = InstanceKeyGraph(result).moduleOrder.reverseIterator - val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a, b) => a.andThen(b)) + // Build RenameMap for instances + val moduleRenames = resultAndRenames.map { case (m, r, _) => OfModule(m.name) -> r }.toMap + val instRenameMap = RenameMap.fromInstanceRenames(InstanceKeyGraph(state.circuit), moduleRenames) + val renames = instRenameMap.andThen(refRenameMap) state.copy(circuit = result, renames = Some(renames), annotations = newAnnos) } @@ -90,9 +94,9 @@ object LowerTypes extends Transform with DependencyAPIMigration { private def onModule( c: CircuitTarget, m: DefModule, - memoryInit: Seq[MemoryInitAnnotation] - ): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = { - val renameMap = RenameMap() + memoryInit: Seq[MemoryInitAnnotation], + renameMap: RenameMap + ): (DefModule, Map[Instance, Instance], Seq[MemoryInitAnnotation]) = { val ref = c.module(m.name) // first we lower the ports in order to ensure that their names are independent of the module body @@ -105,7 +109,9 @@ object LowerTypes extends Transform with DependencyAPIMigration { implicit val memInit: Seq[MemoryInitAnnotation] = memoryInit val newMod = mLoweredPorts.mapStmt(onStatement) - (newMod, renameMap, memInit) + val instRenames = symbols.getInstanceRenames.toMap + + (newMod, instRenames, memInit) } // We lower ports in a separate pass in order to ensure that statements inside the module do not influence port names. @@ -221,6 +227,7 @@ private class LoweringTable( private val namespace = mutable.HashSet[String]() ++ table.getSymbolNames // Serialized old access string to new ground type reference. private val nameToExprs = mutable.HashMap[String, Seq[RefLikeExpression]]() ++ portNameToExprs + private val instRenames = mutable.ListBuffer[(Instance, Instance)]() def lower(mem: DefMemory): Seq[DefMemory] = { val (mems, refs) = DestructTypes.destructMemory(m, mem, namespace, renameMap, portNames) @@ -228,7 +235,7 @@ private class LoweringTable( mems } def lower(inst: DefInstance): DefInstance = { - val (newInst, refs) = DestructTypes.destructInstance(m, inst, namespace, renameMap, portNames) + val (newInst, refs) = DestructTypes.destructInstance(m, inst, namespace, instRenames, portNames) nameToExprs ++= refs.map { case (name, r) => name -> List(r) } newInst } @@ -245,6 +252,7 @@ private class LoweringTable( } def getReferences(expr: RefLikeExpression): Seq[RefLikeExpression] = nameToExprs(serialize(expr)) + def getInstanceRenames: List[(Instance, Instance)] = instRenames.toList // We could just use FirrtlNode.serialize here, but we want to make sure there are not SubAccess nodes left. private def serialize(expr: RefLikeExpression): String = expr match { @@ -296,11 +304,11 @@ private object DestructTypes { * instead of a flat Reference when turning them into access expressions. */ def destructInstance( - m: ModuleTarget, - instance: DefInstance, - namespace: Namespace, - renameMap: RenameMap, - reserved: Set[String] + m: ModuleTarget, + instance: DefInstance, + namespace: Namespace, + instRenames: mutable.ListBuffer[(Instance, Instance)], + reserved: Set[String] ): (DefInstance, Seq[(String, SubField)]) = { val (rename, _) = uniquify(Field(instance.name, Default, instance.tpe), namespace, reserved) val newName = rename.map(_.name).getOrElse(instance.name) @@ -314,7 +322,7 @@ private object DestructTypes { // rename all references to the instance if necessary if (newName != instance.name) { - renameMap.record(m.instOf(instance.name, instance.module), m.instOf(newName, instance.module)) + instRenames += Instance(instance.name) -> Instance(newName) } // The ports do not need to be explicitly renamed here. They are renamed when the module ports are lowered. diff --git a/src/test/scala/firrtl/RenameMapPrivateSpec.scala b/src/test/scala/firrtl/RenameMapPrivateSpec.scala new file mode 100644 index 00000000..d735e6c8 --- /dev/null +++ b/src/test/scala/firrtl/RenameMapPrivateSpec.scala @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl + +import firrtl.annotations.Target +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.analyses.InstanceKeyGraph +import firrtl.testutils.FirrtlFlatSpec + +class RenameMapPrivateSpec extends FirrtlFlatSpec { + "RenameMap.fromInstanceRenames" should "handle instance renames" in { + def tar(str: String): Target = Target.deserialize(str) + val circuit = parse( + """circuit Top : + | module Bar : + | skip + | module Foo : + | inst bar of Bar + | module Top : + | inst foo1 of Foo + | inst foo2 of Foo + | inst bar of Bar + |""".stripMargin + ) + val graph = InstanceKeyGraph(circuit) + val renames = Map( + OfModule("Foo") -> Map(Instance("bar") -> Instance("bbb")), + OfModule("Top") -> Map(Instance("foo1") -> Instance("ffff")) + ) + val rm = RenameMap.fromInstanceRenames(graph, renames) + rm.get(tar("~Top|Top/foo1:Foo")) should be(Some(Seq(tar("~Top|Top/ffff:Foo")))) + rm.get(tar("~Top|Top/foo2:Foo")) should be(None) + // Check of nesting + rm.get(tar("~Top|Top/foo1:Foo/bar:Bar")) should be(Some(Seq(tar("~Top|Top/ffff:Foo/bbb:Bar")))) + rm.get(tar("~Top|Top/foo2:Foo/bar:Bar")) should be(Some(Seq(tar("~Top|Top/foo2:Foo/bbb:Bar")))) + rm.get(tar("~Top|Foo/bar:Bar")) should be(Some(Seq(tar("~Top|Foo/bbb:Bar")))) + rm.get(tar("~Top|Top/bar:Bar")) should be(None) + } +} diff --git a/src/test/scala/firrtl/passes/LowerTypesSpec.scala b/src/test/scala/firrtl/passes/LowerTypesSpec.scala index 70fa51fd..7ca98544 100644 --- a/src/test/scala/firrtl/passes/LowerTypesSpec.scala +++ b/src/test/scala/firrtl/passes/LowerTypesSpec.scala @@ -2,10 +2,13 @@ package firrtl.passes import firrtl.annotations.{CircuitTarget, IsMember} +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.analyses.InstanceKeyGraph import firrtl.{CircuitState, RenameMap, Utils} import firrtl.options.Dependency import firrtl.stage.TransformManager import firrtl.stage.TransformManager.TransformDependency +import firrtl.testutils.FirrtlMatchers import org.scalatest.flatspec.AnyFlatSpec /** Unit test style tests for [[LowerTypes]]. @@ -228,22 +231,35 @@ class LowerTypesRenamingSpec extends AnyFlatSpec { } /** Instances are a special case since they do not get completely destructed but instead become a 1-deep bundle. */ -class LowerTypesOfInstancesSpec extends AnyFlatSpec { +class LowerTypesOfInstancesSpec extends AnyFlatSpec with FirrtlMatchers { import LowerTypesSpecUtils._ private case class Lower(inst: firrtl.ir.DefInstance, fields: Seq[String], renameMap: RenameMap) private val m = CircuitTarget("m").module("m") + private val igraph = InstanceKeyGraph( + parse( + """circuit m: + | module c: + | skip + | module m: + | inst i of c + |""".stripMargin + ) + ) def resultToFieldSeq(res: Seq[(String, firrtl.ir.SubField)]): Seq[String] = res.map(_._2).map(r => s"${r.name} : ${r.tpe.serialize}") private def lower( - n: String, - tpe: String, - module: String, - namespace: Set[String], - renames: RenameMap = RenameMap() + n: String, + tpe: String, + module: String, + namespace: Set[String], + otherRenames: RenameMap = RenameMap() ): Lower = { val ref = firrtl.ir.DefInstance(firrtl.ir.NoInfo, n, module, parseType(tpe)) val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace - val (newInstance, res) = DestructTypes.destructInstance(m, ref, mutableSet, renames, Set()) + val instRenames = scala.collection.mutable.ListBuffer[(Instance, Instance)]() + val (newInstance, res) = DestructTypes.destructInstance(m, ref, mutableSet, instRenames, Set()) + val instMap = Map(OfModule("m") -> instRenames.toMap) + val renames = RenameMap.fromInstanceRenames(igraph, instMap).andThen(otherRenames) Lower(newInstance, resultToFieldSeq(res), renames) } private def get(l: Lower, m: IsMember): Set[IsMember] = l.renameMap.get(m).get.toSet @@ -305,7 +321,7 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { assert(get(l, i) == Set(i_)) // the ports renaming is also noted - val r = portRenames.andThen(otherRenames) + val r = portRenames.andThen(l.renameMap) assert(r.get(i.ref("b")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b").field("c")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b_c")).get == Seq(i_.ref("b_c"))) diff --git a/src/test/scala/firrtlTests/RenameMapSpec.scala b/src/test/scala/firrtlTests/RenameMapSpec.scala index 29466c72..bebeb0bf 100644 --- a/src/test/scala/firrtlTests/RenameMapSpec.scala +++ b/src/test/scala/firrtlTests/RenameMapSpec.scala @@ -5,6 +5,8 @@ package firrtlTests import firrtl.RenameMap import firrtl.RenameMap.IllegalRenameException import firrtl.annotations._ +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.analyses.InstanceKeyGraph import firrtl.testutils._ class RenameMapSpec extends FirrtlFlatSpec { diff --git a/src/test/scala/firrtlTests/annotationTests/TargetUtilsSpec.scala b/src/test/scala/firrtlTests/annotationTests/TargetUtilsSpec.scala new file mode 100644 index 00000000..38266efe --- /dev/null +++ b/src/test/scala/firrtlTests/annotationTests/TargetUtilsSpec.scala @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests.annotationTests + +import firrtl.analyses.InstanceKeyGraph.InstanceKey +import firrtl.annotations._ +import firrtl.annotations.TargetToken._ +import firrtl.annotations.TargetUtils._ +import firrtl.testutils.FirrtlFlatSpec + +class TargetUtilsSpec extends FirrtlFlatSpec { + + behavior.of("instKeyPathToTarget") + + it should "create a ModuleTarget for the top module" in { + val input = InstanceKey("Top", "Top") :: Nil + val expected = ModuleTarget("Top", "Top") + instKeyPathToTarget(input) should be(expected) + } + + it should "create absolute InstanceTargets" in { + val input = InstanceKey("Top", "Top") :: + InstanceKey("foo", "Foo") :: + InstanceKey("bar", "Bar") :: + Nil + val expected = InstanceTarget("Top", "Top", Seq((Instance("foo"), OfModule("Foo"))), "bar", "Bar") + instKeyPathToTarget(input) should be(expected) + } + + it should "support starting somewhere down the path" in { + val input = InstanceKey("Top", "Top") :: + InstanceKey("foo", "Foo") :: + InstanceKey("bar", "Bar") :: + InstanceKey("fizz", "Fizz") :: + Nil + val expected = InstanceTarget("Top", "Bar", Seq(), "fizz", "Fizz") + instKeyPathToTarget(input, Some("Bar")) should be(expected) + } + + behavior.of("unfoldInstanceTargets") + + it should "return nothing for ModuleTargets" in { + val input = ModuleTarget("Top", "Foo") + unfoldInstanceTargets(input) should be(Iterable()) + } + + it should "return all other InstanceTargets to the same instance" in { + val input = ModuleTarget("Top", "Top").instOf("foo", "Foo").instOf("bar", "Bar").instOf("fizz", "Fizz") + val expected = + input :: + ModuleTarget("Top", "Foo").instOf("bar", "Bar").instOf("fizz", "Fizz") :: + ModuleTarget("Top", "Bar").instOf("fizz", "Fizz") :: + Nil + unfoldInstanceTargets(input) should be(expected) + } +} -- cgit v1.2.3