aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/CustomTransformSpec.scala
blob: 1b0e8190f9c7154fd62dd699fec52cd4ea5aef4f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// See LICENSE for license details.

package firrtlTests

import firrtl.ir.Circuit
import firrtl._
import firrtl.passes.Pass
import firrtl.ir._

/** Tests for user-supplied ("custom") transforms passed to the FIRRTL compiler.
  *
  * Covers two behaviours:
  *  - a custom transform declared with `inputForm = LowForm` and `outputForm = HighForm` may
  *    introduce high FIRRTL constructs (here a `when` block) into a circuit;
  *  - an exception thrown inside a custom transform reaches the caller of `Driver.execute`
  *    as the original exception, carrying the transform's own message.
  */
class CustomTransformSpec extends FirrtlFlatSpec {
  behavior of "Custom Transforms"

  they should "be able to introduce high firrtl" in {
    // Simple module
    val delayModuleString = """
      |circuit Delay :
      |  module Delay :
      |    input clock : Clock
      |    input reset : UInt<1>
      |    input a : UInt<32>
      |    input en : UInt<1>
      |    output b : UInt<32>
      |
      |    reg r : UInt<32>, clock
      |    r <= r
      |    when en :
      |      r <= a
      |    b <= r
      |""".stripMargin
    val delayModuleCircuit = parse(delayModuleString)
    // Fail the test with a readable message (instead of a bare `None.get`) if the parsed
    // circuit unexpectedly does not contain its own main module.
    val delayModule = delayModuleCircuit.modules
      .find(_.name == delayModuleCircuit.main)
      .getOrElse(fail(s"Main module ${delayModuleCircuit.main} not found in the parsed Delay circuit"))

    /** Swaps every ExtModule named `Delay` for the high-FIRRTL `delayModule` parsed above.
      *
      * It accepts LowForm but declares HighForm output, because the substituted module still
      * contains a `when` block.
      */
    class ReplaceExtModuleTransform extends SeqTransform {
      class ReplaceExtModule extends Pass {
        def run(c: Circuit): Circuit = c.copy(
          modules = c.modules map {
            case ExtModule(_, "Delay", _, _, _) => delayModule
            case other => other
          }
        )
      }
      def transforms = Seq(new ReplaceExtModule)
      def inputForm = LowForm
      def outputForm = HighForm
    }

    // NOTE(review): presumably the "/features/CustomTransform" test resource instantiates an
    // ExtModule named Delay that this transform fills in — confirm against the resource file.
    runFirrtlTest("CustomTransform", "/features", customTransforms = List(new ReplaceExtModuleTransform))
  }

  they should "not cause \"Internal Errors\"" in {
    val input = """
      |circuit test :
      |  module test :
      |    output out : UInt
      |    out <= UInt(123)""".stripMargin
    val errorString = "My Custom Transform failed!"

    /** A transform that always fails via `require`, which throws `IllegalArgumentException`. */
    class ErroringTransform extends Transform {
      def inputForm = HighForm
      def outputForm = HighForm
      def execute(state: CircuitState): CircuitState = {
        require(false, errorString)
        // Never reached; only here so the method has a CircuitState result type.
        state
      }
    }
    val optionsManager = new ExecutionOptionsManager("test") with HasFirrtlOptions {
      firrtlOptions = FirrtlExecutionOptions(
        firrtlSource = Some(input),
        customTransforms = List(new ErroringTransform))
    }
    // The transform's own exception type and message must surface from Driver.execute.
    (the [java.lang.IllegalArgumentException] thrownBy {
      Driver.execute(optionsManager)
    }).getMessage should include (errorString)
  }
}