aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/CustomTransformSpec.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/test/scala/firrtlTests/CustomTransformSpec.scala')
-rw-r--r--src/test/scala/firrtlTests/CustomTransformSpec.scala153
1 files changed, 120 insertions, 33 deletions
diff --git a/src/test/scala/firrtlTests/CustomTransformSpec.scala b/src/test/scala/firrtlTests/CustomTransformSpec.scala
index 04cbf276..809f2b1e 100644
--- a/src/test/scala/firrtlTests/CustomTransformSpec.scala
+++ b/src/test/scala/firrtlTests/CustomTransformSpec.scala
@@ -6,12 +6,15 @@ import firrtl.ir.Circuit
import firrtl._
import firrtl.passes.Pass
import firrtl.ir._
-import firrtl.stage.{FirrtlSourceAnnotation, FirrtlStage, RunFirrtlTransformAnnotation}
+import firrtl.stage.{FirrtlSourceAnnotation, FirrtlStage, Forms, RunFirrtlTransformAnnotation}
+import firrtl.options.Dependency
+import firrtl.transforms.IdentityTransform
-class CustomTransformSpec extends FirrtlFlatSpec {
- behavior of "Custom Transforms"
+import scala.reflect.runtime
- they should "be able to introduce high firrtl" in {
+object CustomTransformSpec {
+
+ class ReplaceExtModuleTransform extends SeqTransform with FirrtlMatchers {
// Simple module
val delayModuleString = """
|circuit Delay :
@@ -31,38 +34,99 @@ class CustomTransformSpec extends FirrtlFlatSpec {
val delayModuleCircuit = parse(delayModuleString)
val delayModule = delayModuleCircuit.modules.find(_.name == delayModuleCircuit.main).get
- class ReplaceExtModuleTransform extends SeqTransform {
- class ReplaceExtModule extends Pass {
- def run(c: Circuit): Circuit = c.copy(
- modules = c.modules map {
- case ExtModule(_, "Delay", _, _, _) => delayModule
- case other => other
- }
- )
- }
- def transforms = Seq(new ReplaceExtModule)
- def inputForm = LowForm
- def outputForm = HighForm
+ class ReplaceExtModule extends Pass {
+ def run(c: Circuit): Circuit = c.copy(
+ modules = c.modules map {
+ case ExtModule(_, "Delay", _, _, _) => delayModule
+ case other => other
+ }
+ )
}
-
- runFirrtlTest("CustomTransform", "/features", customTransforms = List(new ReplaceExtModuleTransform))
+ def transforms = Seq(new ReplaceExtModule)
+ def inputForm = LowForm
+ def outputForm = HighForm
}
- they should "not cause \"Internal Errors\"" in {
- val input = """
+ val input = """
|circuit test :
| module test :
| output out : UInt
| out <= UInt(123)""".stripMargin
- val errorString = "My Custom Transform failed!"
- class ErroringTransform extends Transform {
+ val errorString = "My Custom Transform failed!"
+ class ErroringTransform extends Transform {
+ def inputForm = HighForm
+ def outputForm = HighForm
+ def execute(state: CircuitState): CircuitState = {
+ require(false, errorString)
+ state
+ }
+ }
+
+ object MutableState {
+ var count: Int = 0
+ }
+
+ class FirstTransform extends Transform {
+ def inputForm = HighForm
+ def outputForm = HighForm
+
+ def execute(state: CircuitState): CircuitState = {
+ require(MutableState.count == 0, s"Count was ${MutableState.count}, expected 0")
+ MutableState.count = 1
+ state
+ }
+ }
+
+ class SecondTransform extends Transform {
+ def inputForm = HighForm
+ def outputForm = HighForm
+
+ def execute(state: CircuitState): CircuitState = {
+ require(MutableState.count == 1, s"Count was ${MutableState.count}, expected 1")
+ MutableState.count = 2
+ state
+ }
+ }
+
+ class ThirdTransform extends Transform {
+ def inputForm = HighForm
+ def outputForm = HighForm
+
+ def execute(state: CircuitState): CircuitState = {
+ require(MutableState.count == 2, s"Count was ${MutableState.count}, expected 2")
+ MutableState.count = 3
+ state
+ }
+ }
+
+ class IdentityLowForm extends IdentityTransform(LowForm) {
+ override val name = ">>>>> IdentityLowForm <<<<<"
+ }
+
+ object Foo {
+ class A extends Transform {
def inputForm = HighForm
def outputForm = HighForm
- def execute(state: CircuitState): CircuitState = {
- require(false, errorString)
- state
+ def execute(s: CircuitState) = {
+ assert(name.endsWith("A"))
+ s
}
}
+ }
+
+}
+
+class CustomTransformSpec extends FirrtlFlatSpec {
+
+ import CustomTransformSpec._
+
+ behavior of "Custom Transforms"
+
+ they should "be able to introduce high firrtl" in {
+ runFirrtlTest("CustomTransform", "/features", customTransforms = List(new ReplaceExtModuleTransform))
+ }
+
+ they should "not cause \"Internal Errors\"" in {
val optionsManager = new ExecutionOptionsManager("test") with HasFirrtlOptions {
firrtlOptions = FirrtlExecutionOptions(
firrtlSource = Some(input),
@@ -73,15 +137,38 @@ class CustomTransformSpec extends FirrtlFlatSpec {
}).getMessage should include (errorString)
}
- object Foo {
- class A extends Transform {
- def inputForm = HighForm
- def outputForm = HighForm
- def execute(s: CircuitState) = {
- assert(name.endsWith("A"))
- s
- }
+ they should "preserve the input order" in {
+ runFirrtlTest("CustomTransform", "/features", customTransforms = List(
+ new FirstTransform,
+ new SecondTransform,
+ new ThirdTransform,
+ new ReplaceExtModuleTransform
+ ))
+ }
+
+ they should "run right before the emitter when inputForm=LowForm" in {
+
+ val custom = Dependency[IdentityLowForm]
+
+ def testOrder(emitter: Dependency[Emitter], preceders: Seq[Dependency[Transform]]): Unit = {
+ info(s"""${preceders.map(_.getSimpleName).mkString(" -> ")} -> ${custom.getSimpleName} -> ${emitter.getSimpleName} ok!""")
+
+ val compiler = new firrtl.stage.transforms.Compiler(Seq(custom, emitter))
+ info("Transform Order: \n" + compiler.prettyPrint(" "))
+
+ val expectedSlice = preceders ++ Seq(custom, emitter)
+
+ compiler
+ .flattenedTransformOrder
+ .map(Dependency.fromTransform(_))
+ .containsSlice(expectedSlice) should be (true)
}
+
+ Seq( (Dependency[LowFirrtlEmitter], Seq(Forms.LowForm.last) ),
+ (Dependency[MinimumVerilogEmitter], Seq(Forms.LowFormMinimumOptimized.last) ),
+ (Dependency[VerilogEmitter], Seq(Forms.LowFormOptimized.last) ),
+ (Dependency[SystemVerilogEmitter], Seq(Forms.LowFormOptimized.last) )
+ ).foreach((testOrder _).tupled)
}
they should "work if placed inside an object" in {