diff options
| author | Jack Koenig | 2021-05-18 14:54:59 -0700 |
|---|---|---|
| committer | GitHub | 2021-05-18 14:54:59 -0700 |
| commit | e0844966cbd2eb44b66c8bf341fa26370e3b4f1c (patch) | |
| tree | ae71ed0eb031e5b292c493c67aa01b1ec85afd3d /src | |
| parent | 97273bff5718cbcbce2673d57bce1a76ec909977 (diff) | |
Improve performance of RenameMap in LowerTypes (#2233)
LowerTypes creates a lot of mappings for the RenameMap. The built-in
.distinct of renames becomes a performance program for designs with
deeply nested Aggregates. Because LowerTypes does not create duplicate
renames, it can safely eschew the safety of using .distinct via a
private internal API.
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/RenameMap.scala | 16 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/LowerTypes.scala | 7 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/LowerTypesSpec.scala | 46 |
3 files changed, 56 insertions, 13 deletions
diff --git a/src/main/scala/firrtl/RenameMap.scala b/src/main/scala/firrtl/RenameMap.scala index 82c00ca5..d39d8106 100644 --- a/src/main/scala/firrtl/RenameMap.scala +++ b/src/main/scala/firrtl/RenameMap.scala @@ -78,6 +78,11 @@ object RenameMap { /** Initialize a new RenameMap */ def apply(): RenameMap = new RenameMap + // This is a private internal API for transforms where the .distinct operation is very expensive + // (eg. LowerTypes). The onus is on the user of this API to be very careful and not inject + // duplicates. This is a bad, hacky API that no one should use + private[firrtl] def noDistinct(): RenameMap = new RenameMap(doDistinct = false) + abstract class RenameTargetException(reason: String) extends Exception(reason) case class IllegalRenameException(reason: String) extends RenameTargetException(reason) case class CircularRenameException(reason: String) extends RenameTargetException(reason) @@ -94,7 +99,11 @@ object RenameMap { final class RenameMap private ( val underlying: mutable.HashMap[CompleteTarget, Seq[CompleteTarget]] = mutable.HashMap[CompleteTarget, Seq[CompleteTarget]](), - val chained: Option[RenameMap] = None) { + val chained: Option[RenameMap] = None, + // This is a private internal API for transforms where the .distinct operation is very expensive + // (eg. LowerTypes). The onus is on the user of this API to be very careful and not inject + // duplicates. This is a bad, hacky API that no one should use + doDistinct: Boolean = true) { /** Chain a [[RenameMap]] with this [[RenameMap]] * @param next the map to chain with this map @@ -662,7 +671,10 @@ final class RenameMap private ( private def completeRename(from: CompleteTarget, tos: Seq[CompleteTarget]): Unit = { tos.foreach { recordSensitivity(from, _) } val existing = underlying.getOrElse(from, Vector.empty) - val updated = (existing ++ tos).distinct + val updated = { + val all = (existing ++ tos) + if (doDistinct) all.distinct else all + } underlying(from) = updated traverseTokensCache.clear() traverseHierarchyCache.clear() diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 0bd44a8c..7ba320d0 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -75,7 +75,12 @@ object LowerTypes extends Transform with DependencyAPIMigration { val memInitByModule = memInitAnnos.map(_.asInstanceOf[MemoryInitAnnotation]).groupBy(_.target.encapsulatingModule) val c = CircuitTarget(state.circuit.main) - val refRenameMap = RenameMap() + // By default, the RenameMap enforces a .distinct invariant for renames. This helps transform + // writers not mess up because violating that invariant can cause problems for transform + // writers. Unfortunately, when you have lots of renames, this is very expensive + // performance-wise. We use a private internal API that does not run .distinct to improve + // performance, but we must be careful to not insert any duplicates. + val refRenameMap = RenameMap.noDistinct() val resultAndRenames = state.circuit.modules.map(m => onModule(c, m, memInitByModule.getOrElse(m.name, Seq()), refRenameMap)) val result = state.circuit.copy(modules = resultAndRenames.map(_._1)) diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 6e774d18..e3b7c869 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -18,19 +18,27 @@ import firrtl.util.TestOptions class LowerTypesSpec extends FirrtlFlatSpec { private val compiler = new TransformManager(Seq(Dependency(LowerTypes))) - private def executeTest(input: String, expected: Seq[String]) = { - val fir = Parser.parse(input.split("\n").toIterator) - val c = compiler.runTransform(CircuitState(fir, Seq())).circuit - val lines = c.serialize.split("\n").map(normalized) + private def executeTest(input: String, expected: Seq[String]): Unit = executeTest(input, expected, Nil, Nil) + private def executeTest( + input: String, + expected: Seq[String], + inputAnnos: Seq[Annotation], + expectedAnnos: Seq[Annotation] + ): Unit = { + val circuit = Parser.parse(input.split("\n").toIterator) + val result = compiler.runTransform(CircuitState(circuit, inputAnnos)) + val lines = result.circuit.serialize.split("\n").map(normalized) - expected.foreach { e => - lines should contain(e) + expected.map(normalized).foreach { e => + assert(lines.contains(e), f"Failed to find $e in ${lines.mkString("\n")}") } + + result.annotations.toSeq should equal(expectedAnnos) } behavior.of("Lower Types") - it should "lower ports" in { + it should "lower ports and rename them appropriately (no duplicates)" in { val input = """circuit Test : | module Test : @@ -39,7 +47,7 @@ class LowerTypesSpec extends FirrtlFlatSpec { | input y : UInt<1>[4] | input z : { c : { d : UInt<1>, e : UInt<1>}, f : UInt<1>[2] }[2] """.stripMargin - val expected = Seq( + val expectedNames = Seq( "w", "x_a", "x_b", @@ -55,9 +63,27 @@ class LowerTypesSpec extends FirrtlFlatSpec { "z_1_c_e", "z_1_f_0", "z_1_f_1" - ).map(x => s"input $x : UInt<1>").map(normalized) + ) + val expected = expectedNames.map(x => s"input $x : UInt<1>").map(normalized) + + // This annotation will error if the RenameMap returns any duplicates, checking the .distinct + // invariant on renames + case class CheckDuplicationAnnotation(ts: Seq[ReferenceTarget]) extends MultiTargetAnnotation { + def targets = ts.map(Seq(_)) + + def duplicate(n: Seq[Seq[Target]]): Annotation = { + val flat: Seq[ReferenceTarget] = n.flatten.map(_.asInstanceOf[ReferenceTarget]) + val distinct = flat.distinct + assert(flat.size == distinct.size, s"There must be no duplication of targets! Got ${flat.map(_.serialize)}") + this.copy(flat) + } + } - executeTest(input, expected) + val m = CircuitTarget("Test").module("Test") + val inputAnnos = CheckDuplicationAnnotation(Seq("w", "x", "y", "z").map(m.ref(_))) :: Nil + val expectedAnnos = CheckDuplicationAnnotation(expectedNames.map(m.ref(_))) :: Nil + + executeTest(input, expected, inputAnnos, expectedAnnos) } it should "lower mixed-direction ports" in { |
