diff options
| author | jackkoenig | 2016-10-20 00:19:01 -0700 |
|---|---|---|
| committer | Jack Koenig | 2016-11-04 13:29:09 -0700 |
| commit | 8fa9429a6e916ab2a789f5d81fa803b022805b52 (patch) | |
| tree | fac2efcbd0a68bfb1916f09afc7f003c7a3d6528 /src/test | |
| parent | 62133264a788f46b319ebab9c31424b7e0536101 (diff) | |
Refactor Compilers and Transforms
* Transform Ids now handled by Class[_ <: Transform] instead of magic numbers
* Transforms define inputForm and outputForm
* Custom transforms can be inserted at runtime into compiler or the Driver
* Current "built-in" custom transforms handled via above mechanism
* Verilog-specific passes moved to the Verilog emitter
Diffstat (limited to 'src/test')
24 files changed, 294 insertions, 126 deletions
diff --git a/src/test/resources/features/CustomTransform.fir b/src/test/resources/features/CustomTransform.fir new file mode 100644 index 00000000..941a9e9c --- /dev/null +++ b/src/test/resources/features/CustomTransform.fir @@ -0,0 +1,33 @@ +circuit CustomTransform : + ; Replaced in custom transform by an implementation + extmodule Delay : + input clk : Clock + input reset : UInt<1> + input a : UInt<32> + input en : UInt<1> + output b : UInt<32> + + module CustomTransform : + input clk : Clock + input reset : UInt<1> + + reg cycle : UInt<32>, clk with : (reset => (reset, UInt<32>(0))) + cycle <= tail(add(cycle, UInt<32>(1)), 1) + + inst delay of Delay + delay.clk <= clk + delay.reset <= reset + delay.a <= UInt(0) + delay.en <= UInt(0) + + when eq(cycle, UInt(0)) : + delay.en <= UInt(1) + delay.a <= UInt("hdeadbeef") + when eq(cycle, UInt(1)) : + when neq(delay.b, UInt("hdeadbeef")) : + printf(clk, UInt(1), "Assertion failed!\n") + stop(clk, UInt(1), 1) + when eq(cycle, UInt(2)) : + printf(clk, UInt(1), "Success!\n") + stop(clk, UInt(1), 0) + diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 0312df5d..c395139b 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -9,14 +9,16 @@ import org.scalatest.junit.JUnitRunner import firrtl.ir.Circuit import firrtl.Parser import firrtl.{ + CircuitState, ResolveAndCheck, RenameMap, Compiler, - CompilerResult, - VerilogCompiler + ChirrtlForm, + LowForm, + VerilogCompiler, + Transform } import firrtl.Annotations.{ - TransID, Named, CircuitName, ModuleName, @@ -39,17 +41,17 @@ import firrtl.Annotations.{ */ trait AnnotationSpec extends LowTransformSpec { // Dummy transform - def transform = new ResolveAndCheck() + def transform = new CustomResolveAndCheck(LowForm) // Check if Annotation Exception is thrown override def failingexecute(writer: Writer, annotations: AnnotationMap, input: String) = { intercept[AnnotationException] { - compile(parse(input), annotations, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer) } } def execute(writer: Writer, annotations: AnnotationMap, input: String, check: Annotation) = { - val cr = compile(parse(input), annotations, writer) - (cr.annotationMap.annotations.head) should be (check) + val cr = compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer) + (cr.annotations.get.annotations.head) should be (check) } } @@ -63,7 +65,6 @@ trait AnnotationSpec extends LowTransformSpec { */ class AnnotationTests extends AnnotationSpec with Matchers { def getAMap (a: Annotation): AnnotationMap = new AnnotationMap(Seq(a)) - val tID = TransID(1) val input = """circuit Top : | module Top : @@ -76,11 +77,12 @@ class AnnotationTests extends AnnotationSpec with Matchers { val cName = ComponentName("c", mName) "Loose and Sticky annotation on a node" should "pass through" in { - case class TestAnnotation(target: Named, tID: TransID) extends Annotation with Loose with Sticky { + case class TestAnnotation(target: Named) extends Annotation with Loose with Sticky { def duplicate(to: Named) = this.copy(target=to) + def transform = classOf[Transform] } val w = new StringWriter() - val ta = TestAnnotation(cName, tID) + val ta = TestAnnotation(cName) execute(w, getAMap(ta), input, ta) } } diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala index d1e07eae..3a67bf04 100644 --- a/src/test/scala/firrtlTests/AttachSpec.scala +++ b/src/test/scala/firrtlTests/AttachSpec.scala @@ -37,10 +37,9 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class InoutVerilog extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], compiler: Compiler) = { val writer = new StringWriter() - compiler.compile(parse(input), new AnnotationMap(Seq.empty), writer) + compiler.compile(CircuitState(parse(input), ChirrtlForm), writer) val lines = writer.toString().split("\n") map normalized expected foreach { e => lines should contain(e) @@ -176,7 +175,6 @@ class InoutVerilog extends FirrtlFlatSpec { } class AttachAnalogSpec extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala index 719a3334..51663eaf 100644 --- a/src/test/scala/firrtlTests/CInferMDirSpec.scala +++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala @@ -63,13 +63,12 @@ class CInferMDir extends LowTransformSpec { } } - object CInferMDirCheck extends Transform with SimpleRun { - def execute(c: Circuit, map: AnnotationMap) = - run(c, Seq(ConstProp, CInferMDirCheckPass)) + def transform = new PassBasedTransform { + def inputForm = LowForm + def outputForm = LowForm + def passSeq = Seq(ConstProp, CInferMDirCheckPass) } - def transform = CInferMDirCheck - "Memory" should "have correct mem port directions" in { val input = """ circuit foo : @@ -97,7 +96,7 @@ circuit foo : val annotationMap = AnnotationMap(Nil) val writer = new java.io.StringWriter - compile(parse(input), annotationMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer) // Check correctness of firrtl parse(writer.toString) } diff --git a/src/test/scala/firrtlTests/CheckInitializationSpec.scala b/src/test/scala/firrtlTests/CheckInitializationSpec.scala index e2eaf690..e8dc60ae 100644 --- a/src/test/scala/firrtlTests/CheckInitializationSpec.scala +++ b/src/test/scala/firrtlTests/CheckInitializationSpec.scala @@ -36,7 +36,6 @@ import firrtl.Parser.IgnoreInfo import firrtl.passes._ class CheckInitializationSpec extends FirrtlFlatSpec { - private def parse(input: String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private val passes = Seq( ToWorkingIR, CheckHighForm, diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala index e0691a6b..63397da8 100644 --- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala @@ -76,13 +76,12 @@ class ChirrtlMemSpec extends LowTransformSpec { } } - object MemEnableCheck extends Transform with SimpleRun { - def execute(c: Circuit, map: AnnotationMap) = - run(c, Seq(ConstProp, MemEnableCheckPass)) + def transform = new PassBasedTransform { + def inputForm = LowForm + def outputForm = LowForm + def passSeq = Seq(ConstProp, MemEnableCheckPass) } - def transform = MemEnableCheck - "Sequential Memory" should "have correct enable signals" in { val input = """ circuit foo : @@ -104,7 +103,7 @@ circuit foo : val annotationMap = AnnotationMap(Nil) val writer = new java.io.StringWriter - compile(parse(input), annotationMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer) // Check correctness of firrtl parse(writer.toString) } @@ -131,7 +130,7 @@ circuit foo : val annotationMap = AnnotationMap(Nil) val writer = new java.io.StringWriter - compile(parse(input), annotationMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer) // Check correctness of firrtl parse(writer.toString) } diff --git a/src/test/scala/firrtlTests/CompilerTests.scala b/src/test/scala/firrtlTests/CompilerTests.scala index 2eab6e0f..28d09c2d 100644 --- a/src/test/scala/firrtlTests/CompilerTests.scala +++ b/src/test/scala/firrtlTests/CompilerTests.scala @@ -8,13 +8,14 @@ import org.scalatest.junit.JUnitRunner import firrtl.ir.Circuit import firrtl.{ - HighFirrtlCompiler, - LowFirrtlCompiler, - VerilogCompiler, - Compiler, - Parser + ChirrtlForm, + CircuitState, + Compiler, + HighFirrtlCompiler, + LowFirrtlCompiler, + Parser, + VerilogCompiler } -import firrtl.Annotations.AnnotationMap /** * An example methodology for testing Firrtl compilers. @@ -30,7 +31,7 @@ abstract class CompilerSpec extends FlatSpec { def input: String def check: String def getOutput: String = { - compiler.compile(parse(input), new AnnotationMap(Seq.empty), writer) + compiler.compile(CircuitState(parse(input), ChirrtlForm), writer) writer.toString() } } diff --git a/src/test/scala/firrtlTests/CompilerUtilsSpec.scala b/src/test/scala/firrtlTests/CompilerUtilsSpec.scala new file mode 100644 index 00000000..1d349db1 --- /dev/null +++ b/src/test/scala/firrtlTests/CompilerUtilsSpec.scala @@ -0,0 +1,76 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ +import firrtl.CompilerUtils.mergeTransforms + +class CompilerUtilsSpec extends FirrtlFlatSpec { + + def genTransform(_inputForm: CircuitForm, _outputForm: CircuitForm) = new Transform { + def inputForm = _inputForm + def outputForm = _outputForm + def execute(state: CircuitState): CircuitState = state + } + + // Core lowering transforms + val chirrtlToHigh = genTransform(ChirrtlForm, HighForm) + val highToMid = genTransform(HighForm, MidForm) + val midToLow = genTransform(MidForm, LowForm) + val chirrtlToLowList = List(chirrtlToHigh, highToMid, midToLow) + + // Custom transforms + val chirrtlToChirrtl = genTransform(ChirrtlForm, ChirrtlForm) + val highToHigh = genTransform(HighForm, HighForm) + val midToMid = genTransform(MidForm, MidForm) + val lowToLow = genTransform(LowForm, LowForm) + + val lowToHigh = genTransform(LowForm, HighForm) + + val lowToLowTwo = genTransform(LowForm, LowForm) + + behavior of "mergeTransforms" + + it should "do nothing if there are no custom transforms" in { + mergeTransforms(chirrtlToLowList, List.empty) should be (chirrtlToLowList) + } + + it should "insert transforms at the correct place" in { + mergeTransforms(chirrtlToLowList, List(chirrtlToChirrtl)) should be + (chirrtlToChirrtl +: chirrtlToLowList) + mergeTransforms(chirrtlToLowList, List(highToHigh)) should be + (List(chirrtlToHigh, highToHigh, highToMid, midToLow)) + mergeTransforms(chirrtlToLowList, List(midToMid)) should be + (List(chirrtlToHigh, highToMid, midToMid, midToLow)) + mergeTransforms(chirrtlToLowList, List(lowToLow)) should be + (chirrtlToLowList :+ lowToLow) + } + + it should "insert transforms at the last legal location" in { + lowToLow should not be (lowToLowTwo) // sanity check + mergeTransforms(chirrtlToLowList :+ lowToLow, List(lowToLowTwo)).last should be (lowToLowTwo) + } + + it should "insert multiple transforms correctly" in { + mergeTransforms(chirrtlToLowList, List(highToHigh, lowToLow)) should be + (List(chirrtlToHigh, highToHigh, highToMid, midToLow, lowToLow)) + } + + it should "handle transforms that raise the form" in { + mergeTransforms(chirrtlToLowList, List(lowToHigh)) match { + case chirrtlToHigh :: highToMid :: midToLow :: lowToHigh :: remainder => + // Remainder will be the actual Firrtl lowering transforms + remainder.head.inputForm should be (HighForm) + remainder.last.outputForm should be (LowForm) + case _ => fail() + } + } + + // Order is not always maintained, see note on function Scaladoc + it should "maintain order of custom tranforms" in { + mergeTransforms(chirrtlToLowList, List(lowToLow, lowToLowTwo)) should be + (chirrtlToLowList ++ List(lowToLow, lowToLowTwo)) + } + +} + diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index bfe58a2c..f6bfa5ef 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -22,7 +22,6 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { ResolveGenders, InferWidths, ConstProp) - def parse(input: String): Circuit = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def exec (input: String) = { passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/CustomTransformSpec.scala b/src/test/scala/firrtlTests/CustomTransformSpec.scala new file mode 100644 index 00000000..4a3faf6b --- /dev/null +++ b/src/test/scala/firrtlTests/CustomTransformSpec.scala @@ -0,0 +1,51 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl.ir.Circuit +import firrtl._ +import firrtl.passes.Pass +import firrtl.ir._ + +class CustomTransformSpec extends FirrtlFlatSpec { + behavior of "Custom Transforms" + + they should "be able to introduce high firrtl" in { + // Simple module + val delayModuleString = """ + |circuit Delay : + | module Delay : + | input clk : Clock + | input reset : UInt<1> + | input a : UInt<32> + | input en : UInt<1> + | output b : UInt<32> + | + | reg r : UInt<32>, clk + | r <= r + | when en : + | r <= a + | b <= r + |""".stripMargin + val delayModuleCircuit = parse(delayModuleString) + val delayModule = delayModuleCircuit.modules.find(_.name == delayModuleCircuit.main).get + + class ReplaceExtModuleTransform extends PassBasedTransform { + class ReplaceExtModule extends Pass { + def name = "Replace External Module" + def run(c: Circuit): Circuit = c.copy( + modules = c.modules map { + case ExtModule(_, "Delay", _, _, _) => delayModule + case other => other + } + ) + } + def passSeq = Seq(new ReplaceExtModule) + def inputForm = LowForm + def outputForm = HighForm + } + + runFirrtlTest("CustomTransform", "/features", customTransforms = List(new ReplaceExtModuleTransform)) + } +} + diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala index 8bbecaeb..06963708 100644 --- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala +++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala @@ -36,7 +36,6 @@ import firrtl.ir._ import firrtl.Parser.IgnoreInfo class ExpandWhensSpec extends FirrtlFlatSpec { - private def parse(input: String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, notExpected: String, passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala index f491b0f5..83cccf3b 100644 --- a/src/test/scala/firrtlTests/FirrtlSpec.scala +++ b/src/test/scala/firrtlTests/FirrtlSpec.scala @@ -36,6 +36,7 @@ import org.scalatest.prop._ import scala.io.Source import firrtl._ +import firrtl.Parser.IgnoreInfo import firrtl.Annotations.AnnotationMap // This trait is borrowed from Chisel3, ideally this code should only exist in one location @@ -131,6 +132,7 @@ trait BackendCompilationUtilities { } trait FirrtlRunners extends BackendCompilationUtilities { + def parse(str: String) = Parser.parse(str.split("\n").toIterator, IgnoreInfo) lazy val cppHarness = new File(s"/top.cpp") /** Compile a Firrtl file * @@ -141,6 +143,7 @@ trait FirrtlRunners extends BackendCompilationUtilities { def compileFirrtlTest( prefix: String, srcDir: String, + customTransforms: Seq[Transform] = Seq.empty, annotations: AnnotationMap = new AnnotationMap(Seq.empty)): File = { val testDir = createTempDirectory(prefix) copyResourceToFile(s"${srcDir}/${prefix}.fir", new File(testDir, s"${prefix}.fir")) @@ -150,6 +153,7 @@ trait FirrtlRunners extends BackendCompilationUtilities { s"$testDir/$prefix.v", new VerilogCompiler(), Parser.IgnoreInfo, + customTransforms, annotations) testDir } @@ -164,8 +168,9 @@ trait FirrtlRunners extends BackendCompilationUtilities { prefix: String, srcDir: String, verilogPrefixes: Seq[String] = Seq.empty, + customTransforms: Seq[Transform] = Seq.empty, annotations: AnnotationMap = new AnnotationMap(Seq.empty)) = { - val testDir = compileFirrtlTest(prefix, srcDir, annotations) + val testDir = compileFirrtlTest(prefix, srcDir, customTransforms, annotations) val harness = new File(testDir, s"top.cpp") copyResourceToFile(cppHarness.toString, harness) diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala index be663872..b6e8f726 100644 --- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -61,19 +61,19 @@ class InferReadWriteSpec extends SimpleTransformSpec { } } - object InferReadWriteCheck extends Transform with SimpleRun { - def execute (c: Circuit, map: AnnotationMap) = - run(c, Seq(InferReadWriteCheckPass)) + class InferReadWriteCheck extends PassBasedTransform { + def inputForm = MidForm + def outputForm = MidForm + def passSeq = Seq(InferReadWriteCheckPass) } - def transforms (writer: java.io.Writer) = Seq( - new Chisel3ToHighFirrtl(), - new IRToWorkingIR(), - new ResolveAndCheck(), - new HighFirrtlToMiddleFirrtl(), - new memlib.InferReadWrite(TransID(-1)), - InferReadWriteCheck, - new EmitFirrtl(writer) + def transforms = Seq( + new ChirrtlToHighFirrtl, + new IRToWorkingIR, + new ResolveAndCheck, + new HighFirrtlToMiddleFirrtl, + new memlib.InferReadWrite, + new InferReadWriteCheck ) "Infer ReadWrite Ports" should "infer readwrite ports for the same clock" in { @@ -100,9 +100,9 @@ circuit sram6t : T_5 <= io.wdata """.stripMargin - val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t", TransID(-1)))) + val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t"))) val writer = new java.io.StringWriter - compile(parse(input), annotationMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer) // Check correctness of firrtl parse(writer.toString) } @@ -132,10 +132,10 @@ circuit sram6t : T_5 <= io.wdata """.stripMargin - val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t", TransID(-1)))) + val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t"))) val writer = new java.io.StringWriter intercept[InferReadWriteCheckException] { - compile(parse(input), annotationMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer) } } } diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 5f19af5c..f7845cc7 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -14,7 +14,6 @@ import firrtl.Annotations.{ CircuitName, ModuleName, ComponentName, - TransID, Annotation, AnnotationMap } @@ -24,9 +23,8 @@ import firrtl.passes.{InlineInstances, InlineAnnotation} /** * Tests inline instances transformation */ -class InlineInstancesTests extends HighTransformSpec { - val tID = TransID(0) - val transform = new InlineInstances(tID) +class InlineInstancesTests extends LowTransformSpec { + def transform = new InlineInstances "The module Inline" should "be inlined" in { val input = """circuit Top : @@ -48,14 +46,14 @@ class InlineInstancesTests extends HighTransformSpec { | wire i$a : UInt<32> | wire i$b : UInt<32> | i$b <= i$a - | i$a <= a | b <= i$b + | i$a <= a | module Inline : | input a : UInt<32> | output b : UInt<32> | b <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Inline", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Inline", CircuitName("Top"))))) execute(writer, aMap, input, check) } @@ -85,15 +83,15 @@ class InlineInstancesTests extends HighTransformSpec { | wire i1$a : UInt<32> | wire i1$b : UInt<32> | i1$b <= i1$a + | b <= i1$b | i0$a <= a | i1$a <= i0$b - | b <= i1$b | module Simple : | input a : UInt<32> | output b : UInt<32> | b <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Simple", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Simple", CircuitName("Top"))))) execute(writer, aMap, input, check) } @@ -121,15 +119,15 @@ class InlineInstancesTests extends HighTransformSpec { | wire i0$b : UInt<32> | i0$b <= i0$a | inst i1 of Simple + | b <= i1.b | i0$a <= a | i1.a <= i0$b - | b <= i1.b | module Simple : | input a : UInt<32> | output b : UInt<32> | b <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ComponentName("i0",ModuleName("Top", CircuitName("Top"))), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ComponentName("i0",ModuleName("Top", CircuitName("Top")))))) execute(writer, aMap, input, check) } @@ -163,9 +161,9 @@ class InlineInstancesTests extends HighTransformSpec { | wire i0$b : UInt<32> | i0$b <= i0$a | inst i1 of B + | b <= i1.b | i0$a <= a | i1.a <= i0$b - | b <= i1.b | module A : | input a : UInt<32> | output b : UInt<32> @@ -176,10 +174,10 @@ class InlineInstancesTests extends HighTransformSpec { | wire i$a : UInt<32> | wire i$b : UInt<32> | i$b <= i$a - | i$a <= a - | b <= i$b""".stripMargin + | b <= i$b + | i$a <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) execute(writer, aMap, input, check) } @@ -199,7 +197,7 @@ class InlineInstancesTests extends HighTransformSpec { | input a : UInt<32> | output b : UInt<32>""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) failingexecute(writer, aMap, input) } // 2) ext instance @@ -216,7 +214,7 @@ class InlineInstancesTests extends HighTransformSpec { | input a : UInt<32> | output b : UInt<32>""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) failingexecute(writer, aMap, input) } // 3) no module @@ -228,7 +226,7 @@ class InlineInstancesTests extends HighTransformSpec { | output b : UInt<32> | b <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) failingexecute(writer, aMap, input) } // 4) no inst @@ -240,7 +238,7 @@ class InlineInstancesTests extends HighTransformSpec { | output b : UInt<32> | b <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) failingexecute(writer, aMap, input) } } diff --git a/src/test/scala/firrtlTests/MultiThreadingSpec.scala b/src/test/scala/firrtlTests/MultiThreadingSpec.scala index bfaed330..b2934314 100644 --- a/src/test/scala/firrtlTests/MultiThreadingSpec.scala +++ b/src/test/scala/firrtlTests/MultiThreadingSpec.scala @@ -2,6 +2,8 @@ package firrtlTests +import firrtl.{ChirrtlForm, CircuitState, Compiler, Annotations} + import scala.concurrent.{Future, Await, ExecutionContext} import scala.concurrent.duration.Duration @@ -13,7 +15,7 @@ class MultiThreadingSpec extends FirrtlPropSpec { def runCompiler(input: Seq[String], compiler: firrtl.Compiler): String = { val writer = new java.io.StringWriter val parsedInput = firrtl.Parser.parse(input) - compiler.compile(parsedInput,new firrtl.Annotations.AnnotationMap(Seq.empty), writer) + compiler.compile(CircuitState(parsedInput, ChirrtlForm), writer) writer.toString } // The parameters we're testing with diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala index e5269396..e574d31f 100644 --- a/src/test/scala/firrtlTests/PassTests.scala +++ b/src/test/scala/firrtlTests/PassTests.scala @@ -4,21 +4,27 @@ import com.typesafe.scalalogging.LazyLogging import java.io.{StringWriter,Writer} import org.scalatest.{FlatSpec, Matchers} import org.scalatest.junit.JUnitRunner -import firrtl.{Parser,FIRRTLEmitter} import firrtl.ir.Circuit import firrtl.Parser.IgnoreInfo -import firrtl.passes.{Pass, PassExceptions} +import firrtl.passes.{Pass, PassExceptions, RemoveEmpty} import firrtl.{ Transform, - TransformResult, + PassBasedTransform, + CircuitState, + CircuitForm, + ChirrtlForm, + HighForm, + MidForm, + LowForm, SimpleRun, - Chisel3ToHighFirrtl, + ChirrtlToHighFirrtl, IRToWorkingIR, ResolveAndCheck, HighFirrtlToMiddleFirrtl, MiddleFirrtlToLowFirrtl, - EmitFirrtl, - Compiler + FirrtlEmitter, + Compiler, + Parser } import firrtl.Annotations.AnnotationMap @@ -26,58 +32,66 @@ import firrtl.Annotations.AnnotationMap // An example methodology for testing Firrtl Passes // Spec class should extend this class abstract class SimpleTransformSpec extends FlatSpec with Matchers with Compiler with LazyLogging { + def emitter = new FirrtlEmitter + // Utility function def parse(s: String): Circuit = Parser.parse(s.split("\n").toIterator, infoMode = IgnoreInfo) // Executes the test. Call in tests. def execute(writer: Writer, annotations: AnnotationMap, input: String, check: String) = { - compile(parse(input), annotations, writer) - logger.debug(writer.toString) - logger.debug(check) - (parse(writer.toString)) should be (parse(check)) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer) + val actual = RemoveEmpty.run(parse(writer.toString)).serialize + val expected = parse(check).serialize + logger.debug(actual) + logger.debug(expected) + (actual) should be (expected) } // Executes the test, should throw an error def failingexecute(writer: Writer, annotations: AnnotationMap, input: String): Exception = { intercept[PassExceptions] { - compile(parse(input), annotations, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer) } } } +class CustomResolveAndCheck(form: CircuitForm) extends PassBasedTransform { + private val wrappedTransform = new ResolveAndCheck + def inputForm = form + def outputForm = form + def passSeq = wrappedTransform.passSeq +} + trait LowTransformSpec extends SimpleTransformSpec { def transform: Transform - def transforms (writer: Writer) = Seq( - new Chisel3ToHighFirrtl(), + def transforms = Seq( + new ChirrtlToHighFirrtl(), new IRToWorkingIR(), new ResolveAndCheck(), new HighFirrtlToMiddleFirrtl(), new MiddleFirrtlToLowFirrtl(), - new ResolveAndCheck(), - transform, - new EmitFirrtl(writer) + new CustomResolveAndCheck(LowForm), + transform ) } trait MiddleTransformSpec extends SimpleTransformSpec { def transform: Transform - def transforms (writer: Writer) = Seq( - new Chisel3ToHighFirrtl(), + def transforms = Seq( + new ChirrtlToHighFirrtl(), new IRToWorkingIR(), new ResolveAndCheck(), new HighFirrtlToMiddleFirrtl(), - new ResolveAndCheck(), - transform, - new EmitFirrtl(writer) + new CustomResolveAndCheck(MidForm), + transform ) } trait HighTransformSpec extends SimpleTransformSpec { def transform: Transform - def transforms (writer: Writer) = Seq( - new Chisel3ToHighFirrtl(), + def transforms = Seq( + new ChirrtlToHighFirrtl(), new IRToWorkingIR(), new ResolveAndCheck(), - transform, - new EmitFirrtl(writer) + transform ) } diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 78b3d9f0..e46230ef 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -6,19 +6,19 @@ import firrtl.passes.memlib._ import Annotations._ class ReplSeqMemSpec extends SimpleTransformSpec { - val passSeq = Seq( - ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty) - def transforms (writer: java.io.Writer) = Seq( - new Chisel3ToHighFirrtl(), + def transforms = Seq( + new ChirrtlToHighFirrtl(), new IRToWorkingIR(), new ResolveAndCheck(), new HighFirrtlToMiddleFirrtl(), - new InferReadWrite(TransID(-1)), - new ReplSeqMem(TransID(-2)), + new InferReadWrite(), + new ReplSeqMem(), new MiddleFirrtlToLowFirrtl(), - (new Transform with SimpleRun { - def execute(c: ir.Circuit, a: AnnotationMap) = run(c, passSeq) } ), - new EmitFirrtl(writer) + new PassBasedTransform { + def inputForm = LowForm + def outputForm = LowForm + def passSeq = Seq(ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty) + } ) "ReplSeqMem" should "generate blackbox wrappers for mems of bundle type" in { @@ -58,9 +58,9 @@ circuit Top : io2.commit_entry.bits.info <- R1 """.stripMargin val confLoc = "ReplSeqMemTests.confTEMP" - val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc, TransID(-2)))) + val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc))) val writer = new java.io.StringWriter - compile(parse(input), aMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer) // Check correctness of firrtl parse(writer.toString) (new java.io.File(confLoc)).delete() @@ -81,9 +81,9 @@ circuit Top : write mport T_155 = mem[p_address], clk """.stripMargin val confLoc = "ReplSeqMemTests.confTEMP" - val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc, TransID(-2)))) + val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc))) val writer = new java.io.StringWriter - compile(parse(input), aMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer) // Check correctness of firrtl parse(writer.toString) (new java.io.File(confLoc)).delete() diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 245c32e8..1025c02b 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -36,7 +36,6 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class UnitTests extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) @@ -114,7 +113,7 @@ class UnitTests extends FirrtlFlatSpec { (c: Circuit, p: Pass) => p.run(c) } val writer = new StringWriter() - FIRRTLEmitter.run(c_result,writer) + (new FirrtlEmitter).emit(CircuitState(c_result, HighForm), writer) (parse(writer.toString())) should be (parse(check)) } @@ -136,7 +135,7 @@ class UnitTests extends FirrtlFlatSpec { intercept[PassException] { val c = Parser.parse(splitExpTestCode.split("\n").toIterator) val c2 = passes.foldLeft(c)((c, p) => p run c) - new VerilogEmitter().run(c2, new OutputStreamWriter(new ByteArrayOutputStream)) + (new VerilogEmitter).emit(CircuitState(c2, LowForm), new StringWriter) } } @@ -147,7 +146,7 @@ class UnitTests extends FirrtlFlatSpec { InferTypes) val c = Parser.parse(splitExpTestCode.split("\n").toIterator) val c2 = passes.foldLeft(c)((c, p) => p run c) - new VerilogEmitter().run(c2, new OutputStreamWriter(new ByteArrayOutputStream)) + (new VerilogEmitter).emit(CircuitState(c2, LowForm), new StringWriter) } "Simple compound expressions" should "be split" in { diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala index 1f6142bc..e9bf5429 100644 --- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala +++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala @@ -37,10 +37,9 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class DoPrimVerilog extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], compiler: Compiler) = { val writer = new StringWriter() - compiler.compile(parse(input), new AnnotationMap(Seq.empty), writer) + compiler.compile(CircuitState(parse(input), ChirrtlForm), writer) val lines = writer.toString().split("\n") map normalized expected foreach { e => lines should contain(e) diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala index d1b16bb9..74f6432f 100644 --- a/src/test/scala/firrtlTests/WidthSpec.scala +++ b/src/test/scala/firrtlTests/WidthSpec.scala @@ -36,7 +36,6 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class WidthSpec extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala index 5f40d861..309014d4 100644 --- a/src/test/scala/firrtlTests/WiringTests.scala +++ b/src/test/scala/firrtlTests/WiringTests.scala @@ -12,7 +12,6 @@ import wiring.WiringUtils._ import wiring._ class WiringTests extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala index a9a1bb47..4a87290d 100644 --- a/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala @@ -5,12 +5,11 @@ package firrtlTests.fixed import java.io.StringWriter import firrtl.Annotations.AnnotationMap -import firrtl.{LowFirrtlCompiler, Parser} +import firrtl.{CircuitState, ChirrtlForm, LowFirrtlCompiler, Parser} import firrtl.Parser.IgnoreInfo import firrtlTests.FirrtlFlatSpec class FixedPointMathSpec extends FirrtlFlatSpec { - def parse(input: String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) val SumPattern = """.*output sum.*<(\d+)>.*.*""".r val ProductPattern = """.*output product.*<(\d+)>.*""".r @@ -45,7 +44,7 @@ class FixedPointMathSpec extends FirrtlFlatSpec { val writer = new StringWriter() - lowerer.compile(parse(input), new AnnotationMap(Seq.empty), writer) + lowerer.compile(CircuitState(parse(input), ChirrtlForm), writer) val output = writer.toString.split("\n") diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala index 53b4f4c0..3f465361 100644 --- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala @@ -34,7 +34,6 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class FixedTypeInferenceSpec extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala index 6799a367..27d7e172 100644 --- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala +++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala @@ -35,7 +35,6 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class RemoveFixedTypeSpec extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) @@ -204,14 +203,14 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | io_out <= io_in """.stripMargin - class CheckChirrtlTransform extends Transform with SimpleRun { + class CheckChirrtlTransform extends PassBasedTransform { + def inputForm = ChirrtlForm + def outputForm = ChirrtlForm val passSeq = Seq(passes.CheckChirrtl) - def execute (circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) } val chirrtlTransform = new CheckChirrtlTransform - chirrtlTransform.execute(parse(input), new AnnotationMap(Seq.empty)) + chirrtlTransform.execute(CircuitState(parse(input), ChirrtlForm, Some(new AnnotationMap(Seq.empty)))) } } |
