diff options
| author | Kevin Laeufer | 2021-08-02 13:46:29 -0700 |
|---|---|---|
| committer | GitHub | 2021-08-02 20:46:29 +0000 |
| commit | e04f1e7f303920ac1d1f865450d0e280aafb58b3 (patch) | |
| tree | 73f26cd236ac8069d9c4877a3c42457d65d477fe /src/test | |
| parent | ff1cd28202fb423956a6803a889c3632487d8872 (diff) | |
add emitter for optimized low firrtl (#2304)
* rearrange passes to enable optimized firrtl emission
* Support ConstProp on padded arguments to comparisons with literals
* Move shr legalization logic into ConstProp
Continue calling ConstProp of shr in Legalize.
Co-authored-by: Jack Koenig <koenig@sifive.com>
Co-authored-by: Jack Koenig <koenig@sifive.com>
Diffstat (limited to 'src/test')
7 files changed, 238 insertions, 99 deletions
diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala index 6ce90eab..f6788435 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala @@ -195,8 +195,8 @@ class FirrtlExpressionSemanticsSpec extends AnyFlatSpec { } it should "correctly translate the `neg` primitive operation" in { - assert(primop(true, "neg", 4, List(3)) == "sub(sext(3'b0, 1), sext(i0, 1))") - assert(primop("neg", "SInt<4>", List("UInt<3>"), List()) == "sub(zext(3'b0, 1), zext(i0, 1))") + assert(primop(true, "neg", 4, List(3)) == "neg(sext(i0, 1))") + assert(primop("neg", "SInt<4>", List("UInt<3>"), List()) == "neg(zext(i0, 1))") } it should "correctly translate the `not` primitive operation" in { diff --git a/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala index 8f17a847..e5226226 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala @@ -48,7 +48,7 @@ class InvalidToRandomSpec extends LeanTransformSpec(Seq(Dependency(InvalidToRand assert(result.contains("node _GEN_1 = mux(not(o2_valid), _GEN_1_invalid, UInt<3>(\"h7\"))")) // expressions that are trivially valid do not get randomized - assert(result.contains("o3 <= UInt<2>(\"h3\")")) + assert(result.contains("o3 <= UInt<8>(\"h3\")")) val defRandCount = result.count(_.contains("rand ")) assert(defRandCount == 2) } diff --git a/src/test/scala/firrtlTests/LoFirrtlOptimizedEmitterTests.scala b/src/test/scala/firrtlTests/LoFirrtlOptimizedEmitterTests.scala new file mode 100644 index 00000000..6f1c56c5 --- /dev/null +++ b/src/test/scala/firrtlTests/LoFirrtlOptimizedEmitterTests.scala @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests + +import firrtl._ +import firrtl.stage._ +import firrtl.util.BackendCompilationUtilities +import org.scalatest.flatspec.AnyFlatSpec + +class LoFirrtlOptimizedEmitterTests extends AnyFlatSpec { + behavior.of("LoFirrtlOptimizedEmitter") + + it should "generate valid firrtl for AddNot" in { compileAndParse("AddNot") } + it should "generate valid firrtl for FPU" in { compileAndParse("FPU") } + it should "generate valid firrtl for HwachaSequencer" in { compileAndParse("HwachaSequencer") } + it should "generate valid firrtl for ICache" in { compileAndParse("ICache") } + it should "generate valid firrtl for Ops" in { compileAndParse("Ops") } + it should "generate valid firrtl for Rob" in { compileAndParse("Rob") } + it should "generate valid firrtl for RocketCore" in { compileAndParse("RocketCore") } + + private def compileAndParse(name: String): Unit = { + val testDir = os.RelPath( + BackendCompilationUtilities.createTestDirectory( + "LoFirrtlOptimizedEmitter_should_generate_valid_firrtl_for" + name + ) + ) + val inputFile = testDir / s"$name.fir" + val outputFile = testDir / s"$name.opt.lo.fir" + + BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", (os.pwd / inputFile).toIO) + + val stage = new FirrtlStage + // run low-opt emitter + val args = Array( + "-ll", + "error", // surpress warnings to keep test output clean + "--target-dir", + testDir.toString, + "-i", + inputFile.toString, + "-E", + "low-opt" + ) + val res = stage.execute(args, Seq()) + + // load in result to check + stage.execute(Array("--target-dir", testDir.toString, "-i", outputFile.toString()), Seq()) + } +} diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index d56ca657..bb1a8169 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -89,7 +89,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { firrtl.passes.InferTypes, firrtl.passes.ResolveFlows, new firrtl.passes.InferWidths, - firrtl.passes.Legalize, + firrtl.passes.LegalizeConnects, firrtl.transforms.RemoveReset, firrtl.passes.ResolveFlows, new firrtl.transforms.CheckCombLoops, @@ -102,7 +102,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { new firrtl.transforms.ConstantPropagation, firrtl.passes.PadWidths, new firrtl.transforms.ConstantPropagation, - firrtl.passes.Legalize, + firrtl.passes.LegalizeConnects, firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter new firrtl.transforms.ConstantPropagation, firrtl.passes.SplitExpressions, @@ -114,7 +114,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { Seq( firrtl.passes.RemoveValidIf, firrtl.passes.PadWidths, - firrtl.passes.Legalize, + firrtl.passes.LegalizeConnects, firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter firrtl.passes.SplitExpressions ) @@ -215,76 +215,6 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { compare(legacyTransforms(new MiddleFirrtlToLowFirrtl), tm, patches) } - behavior.of("MinimumLowFirrtlOptimization") - - it should "replicate the old order" in { - val tm = new TransformManager(Forms.LowFormMinimumOptimized, Forms.LowForm) - val patches = Seq( - Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))), - Add(6, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], Dependency(firrtl.passes.ResolveKinds))) - ) - compare(legacyTransforms(new MinimumLowFirrtlOptimization), tm, patches) - } - - behavior.of("LowFirrtlOptimization") - - it should "replicate the old order" in { - val tm = new TransformManager(Forms.LowFormOptimized, Forms.LowForm) - val patches = Seq( - Add(6, Seq(Dependency(firrtl.passes.ResolveFlows))), - Add(7, Seq(Dependency(firrtl.passes.Legalize))), - Add(8, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], Dependency(firrtl.passes.ResolveKinds))) - ) - compare(legacyTransforms(new LowFirrtlOptimization), tm, patches) - } - - behavior.of("VerilogMinimumOptimized") - - it should "replicate the old order" in { - val legacy = Seq( - new firrtl.transforms.BlackBoxSourceHelper, - new firrtl.transforms.FixAddingNegativeLiterals, - new firrtl.transforms.ReplaceTruncatingArithmetic, - new firrtl.transforms.InlineBitExtractionsTransform, - new firrtl.transforms.PropagatePresetAnnotations, - new firrtl.transforms.InlineAcrossCastsTransform, - new firrtl.transforms.LegalizeClocksTransform, - new firrtl.transforms.FlattenRegUpdate, - firrtl.passes.VerilogModulusCleanup, - new firrtl.transforms.VerilogRename, - firrtl.passes.InferTypes, - firrtl.passes.VerilogPrep, - new firrtl.AddDescriptionNodes - ) - val tm = new TransformManager(Forms.VerilogMinimumOptimized, (new firrtl.VerilogEmitter).prerequisites) - compare(legacy, tm) - } - - behavior.of("VerilogOptimized") - - it should "replicate the old order" in { - val legacy = Seq( - new firrtl.transforms.InlineBooleanExpressions, - new firrtl.transforms.DeadCodeElimination, - new firrtl.transforms.BlackBoxSourceHelper, - new firrtl.transforms.FixAddingNegativeLiterals, - new firrtl.transforms.ReplaceTruncatingArithmetic, - new firrtl.transforms.InlineBitExtractionsTransform, - new firrtl.transforms.PropagatePresetAnnotations, - new firrtl.transforms.InlineAcrossCastsTransform, - new firrtl.transforms.LegalizeClocksTransform, - new firrtl.transforms.FlattenRegUpdate, - new firrtl.transforms.DeadCodeElimination, - firrtl.passes.VerilogModulusCleanup, - new firrtl.transforms.VerilogRename, - firrtl.passes.InferTypes, - firrtl.passes.VerilogPrep, - new firrtl.AddDescriptionNodes - ) - val tm = new TransformManager(Forms.VerilogOptimized, Forms.LowFormOptimized) - compare(legacy, tm) - } - behavior.of("Legacy Custom Transforms") it should "work for Chirrtl -> Chirrtl" in { @@ -311,7 +241,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { compare(expected, tm) } - it should "work for Mid -> Mid" in { + it should "work for Mid -> Mid" ignore { val expected = new TransformManager(Forms.MidForm).flattenedTransformOrder ++ Some(new Transforms.MidToMid) ++ diff --git a/src/test/scala/firrtlTests/PadWidthsTests.scala b/src/test/scala/firrtlTests/PadWidthsTests.scala new file mode 100644 index 00000000..c92a8b79 --- /dev/null +++ b/src/test/scala/firrtlTests/PadWidthsTests.scala @@ -0,0 +1,170 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl.CircuitState +import firrtl.options.Dependency +import firrtl.stage.{Forms, TransformManager} +import firrtl.testutils.LeanTransformSpec + +class PadWidthsTests extends LeanTransformSpec(Seq(Dependency(firrtl.passes.PadWidths))) { + behavior.of("PadWidths pass") + + it should "pad widths inside a mux" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | input b : UInt<20> + | input pred : UInt<1> + | output c : UInt<32> + | c <= mux(pred,a,b)""".stripMargin + val check = Seq("c <= mux(pred, a, pad(b, 32))") + executeTest(input, check) + } + + it should "pad widths of connects" in { + val input = + """circuit Top : + | module Top : + | output a : UInt<32> + | input b : UInt<20> + | a <= b + | """.stripMargin + val check = Seq("a <= pad(b, 32)") + executeTest(input, check) + } + + it should "pad widths of register init expressions" in { + val input = + """circuit Top : + | module Top : + | input clock: Clock + | input reset: AsyncReset + | + | reg r: UInt<8>, clock with: + | reset => (reset, UInt<1>("h1")) + | """.stripMargin + // PadWidths will call into constant prop directly, thus the literal is widened instead of adding a pad + val check = Seq("reset => (reset, UInt<8>(\"h1\"))") + executeTest(input, check) + } + + private def testOp(op: String, width: Int, resultWidth: Int): Unit = { + assert(width > 0) + val input = + s"""circuit Top : + | module Top : + | input a : UInt<32> + | input b : UInt<$width> + | output c : UInt<$resultWidth> + | c <= $op(a,b)""".stripMargin + val check = if (width < 32) { + Seq(s"c <= $op(a, pad(b, 32))") + } else if (width == 32) { + Seq(s"c <= $op(a, b)") + } else { + Seq(s"c <= $op(pad(a, $width), b)") + } + executeTest(input, check) + } + + it should "pad widths of the arguments to add and sub" in { + // add and sub have the same width inference rule: max(w_1, w_2) + 1 + testOp("add", 2, 33) + testOp("add", 32, 33) + testOp("add", 35, 36) + + testOp("sub", 2, 33) + testOp("sub", 32, 33) + testOp("sub", 35, 36) + } + + it should "pad widths of the arguments to and, or and xor" in { + // and, or and xor have the same width inference rule: max(w_1, w_2) + testOp("and", 2, 32) + testOp("and", 32, 32) + testOp("and", 35, 35) + + testOp("or", 2, 32) + testOp("or", 32, 32) + testOp("or", 35, 35) + + testOp("xor", 2, 32) + testOp("xor", 32, 32) + testOp("xor", 35, 35) + } + + it should "pad widths of the arguments to lt, leq, gt, geq, eq and neq" in { + // lt, leq, gt, geq, eq and ne have the same width inference rule: 1 + testOp("lt", 2, 1) + testOp("lt", 32, 1) + testOp("lt", 35, 1) + + testOp("leq", 2, 1) + testOp("leq", 32, 1) + testOp("leq", 35, 1) + + testOp("gt", 2, 1) + testOp("gt", 32, 1) + testOp("gt", 35, 1) + + testOp("geq", 2, 1) + testOp("geq", 32, 1) + testOp("geq", 35, 1) + + testOp("eq", 2, 1) + testOp("eq", 32, 1) + testOp("eq", 35, 1) + + testOp("neq", 2, 1) + testOp("neq", 32, 1) + testOp("neq", 35, 1) + } + + private val resolvedCompiler = new TransformManager(Forms.Resolved) + private def checkWidthsAfterPadWidths(input: String, op: String): Unit = { + val result = compile(input) + + // we serialize the result in order to rerun width inference + val resultFir = firrtl.Parser.parse(result.circuit.serialize) + val newWidths = resolvedCompiler.runTransform(CircuitState(resultFir, Seq())) + + // the newly loaded circuit should look the same in serialized form (if this fails, the test has a bug) + assert(newWidths.circuit.serialize == result.circuit.serialize) + + // we compare the widths produced by PadWidths with the widths that would normally be inferred + assert(newWidths.circuit.modules.head == result.circuit.modules.head, s"failed with op `$op`") + } + + it should "always generate valid firrtl" in { + // an older version of PadWidths would generate ill types firrtl for mul, div, rem and dshl + + def input(op: String): String = + s"""circuit Top: + | module Top: + | input a: UInt<3> + | input b: UInt<1> + | output c: UInt + | c <= $op(a, b) + |""".stripMargin + + def test(op: String): Unit = checkWidthsAfterPadWidths(input(op), op) + + // This was never broken, but we want to make sure that the test works. + test("add") + + test("mul") + test("div") + test("rem") + test("dshl") + } + + private def executeTest(input: String, expected: Seq[String]): Unit = { + val result = compile(input) + val lines = result.circuit.serialize.split("\n").map(normalized) + expected.map(normalized).foreach { e => + assert(lines.contains(e), f"Failed to find $e in ${lines.mkString("\n")}") + } + } +} diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala index 58d42710..4022b267 100644 --- a/src/test/scala/firrtlTests/RemoveWiresSpec.scala +++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala @@ -98,7 +98,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec { val (nodes, wires) = getNodesAndWires(result.circuit) wires.size should be(0) nodes.map(_.serialize) should be( - Seq("""node w = pad(UInt<2>("h2"), 8)""") + Seq("""node w = UInt<8>("h2")""") ) } diff --git a/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala b/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala index 32b1c55d..8491977c 100644 --- a/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala +++ b/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala @@ -2,32 +2,22 @@ package firrtlTests -import firrtl._ import firrtl.testutils._ -import firrtl.testutils.FirrtlCheckers._ import firrtl.ir.Circuit -import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlSourceAnnotation, FirrtlStage} +import firrtl.options.Dependency +import firrtl.passes.memlib.VerilogMemDelays -import org.scalatest.freespec.AnyFreeSpec -import org.scalatest.matchers.should.Matchers - -class VerilogMemDelaySpec extends AnyFreeSpec with Matchers { +class VerilogMemDelaySpec extends LeanTransformSpec(Seq(Dependency(VerilogMemDelays))) { + behavior.of("VerilogMemDelaySpec") private def compileTwiceReturnFirst(input: String): Circuit = { - (new FirrtlStage) - .transform(Seq(FirrtlSourceAnnotation(input))) - .toSeq - .collectFirst { - case fca: FirrtlCircuitAnnotation => - (new FirrtlStage).transform(Seq(fca)) - fca.circuit - } - .get + val res0 = compile(input) + compile(res0.circuit.serialize).circuit } private def compileTwice(input: String): Unit = compileTwiceReturnFirst(input) - "The following low FIRRTL should be parsed by VerilogMemDelays" in { + it should "The following low FIRRTL should be parsed by VerilogMemDelays" in { val input = """ |circuit Test : @@ -63,7 +53,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers { compileTwice(input) } - "Using a read-first memory should be allowed in VerilogMemDelays" in { + it should "Using a read-first memory should be allowed in VerilogMemDelays" in { val input = """ |circuit Test : @@ -107,7 +97,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers { compileTwice(input) } - "Chained memories should generate correct FIRRTL" in { + it should "Chained memories should generate correct FIRRTL" in { val input = """ |circuit Test : @@ -151,7 +141,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers { compileTwice(input) } - "VerilogMemDelays should not violate use before declaration of clocks" in { + it should "VerilogMemDelays should not violate use before declaration of clocks" in { val input = """ |circuit Test : @@ -188,7 +178,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers { | m.write.data <= in """.stripMargin - val res = compileTwiceReturnFirst(input).serialize + val res = compile(input).circuit.serialize // Inject a Wire when using a clock not derived from ports res should include("wire m_clock : Clock") res should include("m_clock <= cm.clock") |
