diff options
| author | Jack Koenig | 2021-05-18 14:54:59 -0700 |
|---|---|---|
| committer | GitHub | 2021-05-18 14:54:59 -0700 |
| commit | e0844966cbd2eb44b66c8bf341fa26370e3b4f1c (patch) | |
| tree | ae71ed0eb031e5b292c493c67aa01b1ec85afd3d /src/test | |
| parent | 97273bff5718cbcbce2673d57bce1a76ec909977 (diff) | |
Improve performance of RenameMap in LowerTypes (#2233)
LowerTypes creates a lot of mappings for the RenameMap. The built-in
.distinct of renames becomes a performance program for designs with
deeply nested Aggregates. Because LowerTypes does not create duplicate
renames, it can safely eschew the safety of using .distinct via a
private internal API.
Diffstat (limited to 'src/test')
| -rw-r--r-- | src/test/scala/firrtlTests/LowerTypesSpec.scala | 46 |
1 files changed, 36 insertions, 10 deletions
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 6e774d18..e3b7c869 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -18,19 +18,27 @@ import firrtl.util.TestOptions class LowerTypesSpec extends FirrtlFlatSpec { private val compiler = new TransformManager(Seq(Dependency(LowerTypes))) - private def executeTest(input: String, expected: Seq[String]) = { - val fir = Parser.parse(input.split("\n").toIterator) - val c = compiler.runTransform(CircuitState(fir, Seq())).circuit - val lines = c.serialize.split("\n").map(normalized) + private def executeTest(input: String, expected: Seq[String]): Unit = executeTest(input, expected, Nil, Nil) + private def executeTest( + input: String, + expected: Seq[String], + inputAnnos: Seq[Annotation], + expectedAnnos: Seq[Annotation] + ): Unit = { + val circuit = Parser.parse(input.split("\n").toIterator) + val result = compiler.runTransform(CircuitState(circuit, inputAnnos)) + val lines = result.circuit.serialize.split("\n").map(normalized) - expected.foreach { e => - lines should contain(e) + expected.map(normalized).foreach { e => + assert(lines.contains(e), f"Failed to find $e in ${lines.mkString("\n")}") } + + result.annotations.toSeq should equal(expectedAnnos) } behavior.of("Lower Types") - it should "lower ports" in { + it should "lower ports and rename them appropriately (no duplicates)" in { val input = """circuit Test : | module Test : @@ -39,7 +47,7 @@ class LowerTypesSpec extends FirrtlFlatSpec { | input y : UInt<1>[4] | input z : { c : { d : UInt<1>, e : UInt<1>}, f : UInt<1>[2] }[2] """.stripMargin - val expected = Seq( + val expectedNames = Seq( "w", "x_a", "x_b", @@ -55,9 +63,27 @@ class LowerTypesSpec extends FirrtlFlatSpec { "z_1_c_e", "z_1_f_0", "z_1_f_1" - ).map(x => s"input $x : UInt<1>").map(normalized) + ) + val expected = expectedNames.map(x => s"input $x : UInt<1>").map(normalized) + + // This annotation will error if the RenameMap returns any duplicates, checking the .distinct + // invariant on renames + case class CheckDuplicationAnnotation(ts: Seq[ReferenceTarget]) extends MultiTargetAnnotation { + def targets = ts.map(Seq(_)) + + def duplicate(n: Seq[Seq[Target]]): Annotation = { + val flat: Seq[ReferenceTarget] = n.flatten.map(_.asInstanceOf[ReferenceTarget]) + val distinct = flat.distinct + assert(flat.size == distinct.size, s"There must be no duplication of targets! Got ${flat.map(_.serialize)}") + this.copy(flat) + } + } - executeTest(input, expected) + val m = CircuitTarget("Test").module("Test") + val inputAnnos = CheckDuplicationAnnotation(Seq("w", "x", "y", "z").map(m.ref(_))) :: Nil + val expectedAnnos = CheckDuplicationAnnotation(expectedNames.map(m.ref(_))) :: Nil + + executeTest(input, expected, inputAnnos, expectedAnnos) } it should "lower mixed-direction ports" in { |
