aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack Koenig2021-05-18 14:54:59 -0700
committerGitHub2021-05-18 14:54:59 -0700
commite0844966cbd2eb44b66c8bf341fa26370e3b4f1c (patch)
treeae71ed0eb031e5b292c493c67aa01b1ec85afd3d
parent97273bff5718cbcbce2673d57bce1a76ec909977 (diff)
Improve performance of RenameMap in LowerTypes (#2233)
LowerTypes creates a lot of mappings for the RenameMap. The built-in .distinct of renames becomes a performance program for designs with deeply nested Aggregates. Because LowerTypes does not create duplicate renames, it can safely eschew the safety of using .distinct via a private internal API.
-rw-r--r--src/main/scala/firrtl/RenameMap.scala16
-rw-r--r--src/main/scala/firrtl/passes/LowerTypes.scala7
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala46
3 files changed, 56 insertions, 13 deletions
diff --git a/src/main/scala/firrtl/RenameMap.scala b/src/main/scala/firrtl/RenameMap.scala
index 82c00ca5..d39d8106 100644
--- a/src/main/scala/firrtl/RenameMap.scala
+++ b/src/main/scala/firrtl/RenameMap.scala
@@ -78,6 +78,11 @@ object RenameMap {
/** Initialize a new RenameMap */
def apply(): RenameMap = new RenameMap
+ // This is a private internal API for transforms where the .distinct operation is very expensive
+ // (eg. LowerTypes). The onus is on the user of this API to be very careful and not inject
+ // duplicates. This is a bad, hacky API that no one should use
+ private[firrtl] def noDistinct(): RenameMap = new RenameMap(doDistinct = false)
+
abstract class RenameTargetException(reason: String) extends Exception(reason)
case class IllegalRenameException(reason: String) extends RenameTargetException(reason)
case class CircularRenameException(reason: String) extends RenameTargetException(reason)
@@ -94,7 +99,11 @@ object RenameMap {
final class RenameMap private (
val underlying: mutable.HashMap[CompleteTarget, Seq[CompleteTarget]] =
mutable.HashMap[CompleteTarget, Seq[CompleteTarget]](),
- val chained: Option[RenameMap] = None) {
+ val chained: Option[RenameMap] = None,
+ // This is a private internal API for transforms where the .distinct operation is very expensive
+ // (eg. LowerTypes). The onus is on the user of this API to be very careful and not inject
+ // duplicates. This is a bad, hacky API that no one should use
+ doDistinct: Boolean = true) {
/** Chain a [[RenameMap]] with this [[RenameMap]]
* @param next the map to chain with this map
@@ -662,7 +671,10 @@ final class RenameMap private (
private def completeRename(from: CompleteTarget, tos: Seq[CompleteTarget]): Unit = {
tos.foreach { recordSensitivity(from, _) }
val existing = underlying.getOrElse(from, Vector.empty)
- val updated = (existing ++ tos).distinct
+ val updated = {
+ val all = (existing ++ tos)
+ if (doDistinct) all.distinct else all
+ }
underlying(from) = updated
traverseTokensCache.clear()
traverseHierarchyCache.clear()
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala
index 0bd44a8c..7ba320d0 100644
--- a/src/main/scala/firrtl/passes/LowerTypes.scala
+++ b/src/main/scala/firrtl/passes/LowerTypes.scala
@@ -75,7 +75,12 @@ object LowerTypes extends Transform with DependencyAPIMigration {
val memInitByModule = memInitAnnos.map(_.asInstanceOf[MemoryInitAnnotation]).groupBy(_.target.encapsulatingModule)
val c = CircuitTarget(state.circuit.main)
- val refRenameMap = RenameMap()
+ // By default, the RenameMap enforces a .distinct invariant for renames. This helps transform
+ // writers not mess up because violating that invariant can cause problems for transform
+ // writers. Unfortunately, when you have lots of renames, this is very expensive
+ // performance-wise. We use a private internal API that does not run .distinct to improve
+ // performance, but we must be careful to not insert any duplicates.
+ val refRenameMap = RenameMap.noDistinct()
val resultAndRenames =
state.circuit.modules.map(m => onModule(c, m, memInitByModule.getOrElse(m.name, Seq()), refRenameMap))
val result = state.circuit.copy(modules = resultAndRenames.map(_._1))
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index 6e774d18..e3b7c869 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -18,19 +18,27 @@ import firrtl.util.TestOptions
class LowerTypesSpec extends FirrtlFlatSpec {
private val compiler = new TransformManager(Seq(Dependency(LowerTypes)))
- private def executeTest(input: String, expected: Seq[String]) = {
- val fir = Parser.parse(input.split("\n").toIterator)
- val c = compiler.runTransform(CircuitState(fir, Seq())).circuit
- val lines = c.serialize.split("\n").map(normalized)
+ private def executeTest(input: String, expected: Seq[String]): Unit = executeTest(input, expected, Nil, Nil)
+ private def executeTest(
+ input: String,
+ expected: Seq[String],
+ inputAnnos: Seq[Annotation],
+ expectedAnnos: Seq[Annotation]
+ ): Unit = {
+ val circuit = Parser.parse(input.split("\n").toIterator)
+ val result = compiler.runTransform(CircuitState(circuit, inputAnnos))
+ val lines = result.circuit.serialize.split("\n").map(normalized)
- expected.foreach { e =>
- lines should contain(e)
+ expected.map(normalized).foreach { e =>
+ assert(lines.contains(e), f"Failed to find $e in ${lines.mkString("\n")}")
}
+
+ result.annotations.toSeq should equal(expectedAnnos)
}
behavior.of("Lower Types")
- it should "lower ports" in {
+ it should "lower ports and rename them appropriately (no duplicates)" in {
val input =
"""circuit Test :
| module Test :
@@ -39,7 +47,7 @@ class LowerTypesSpec extends FirrtlFlatSpec {
| input y : UInt<1>[4]
| input z : { c : { d : UInt<1>, e : UInt<1>}, f : UInt<1>[2] }[2]
""".stripMargin
- val expected = Seq(
+ val expectedNames = Seq(
"w",
"x_a",
"x_b",
@@ -55,9 +63,27 @@ class LowerTypesSpec extends FirrtlFlatSpec {
"z_1_c_e",
"z_1_f_0",
"z_1_f_1"
- ).map(x => s"input $x : UInt<1>").map(normalized)
+ )
+ val expected = expectedNames.map(x => s"input $x : UInt<1>").map(normalized)
+
+ // This annotation will error if the RenameMap returns any duplicates, checking the .distinct
+ // invariant on renames
+ case class CheckDuplicationAnnotation(ts: Seq[ReferenceTarget]) extends MultiTargetAnnotation {
+ def targets = ts.map(Seq(_))
+
+ def duplicate(n: Seq[Seq[Target]]): Annotation = {
+ val flat: Seq[ReferenceTarget] = n.flatten.map(_.asInstanceOf[ReferenceTarget])
+ val distinct = flat.distinct
+ assert(flat.size == distinct.size, s"There must be no duplication of targets! Got ${flat.map(_.serialize)}")
+ this.copy(flat)
+ }
+ }
- executeTest(input, expected)
+ val m = CircuitTarget("Test").module("Test")
+ val inputAnnos = CheckDuplicationAnnotation(Seq("w", "x", "y", "z").map(m.ref(_))) :: Nil
+ val expectedAnnos = CheckDuplicationAnnotation(expectedNames.map(m.ref(_))) :: Nil
+
+ executeTest(input, expected, inputAnnos, expectedAnnos)
}
it should "lower mixed-direction ports" in {