aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/RegisterUpdateSpec.scala
blob: 3b03e12768a27c22b006fea45f9070a4b18987f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// SPDX-License-Identifier: Apache-2.0

package firrtlTests

import firrtl._
import firrtl.ir._
import firrtl.transforms.FlattenRegUpdate
import firrtl.annotations.NoTargetAnnotation
import firrtl.stage.transforms.Compiler
import firrtl.options.Dependency
import firrtl.testutils._
import FirrtlCheckers._
import scala.util.matching.Regex

object RegisterUpdateSpec {

  /** Annotation holding a snapshot of the [[CircuitState]] taken when [[CaptureCircuitState]] ran. */
  case class CaptureStateAnno(value: CircuitState) extends NoTargetAnnotation

  /** Test-only transform that runs after FlattenRegUpdate and (optionally) before VerilogEmitter.
    *
    * It builds a snapshot of the incoming state with the circuit serialized as FIRRTL attached as
    * an emitted-circuit annotation, so FIRRTL-text matchers can be applied to the snapshot. The
    * snapshot is prepended to the original annotations as a [[CaptureStateAnno]]; the circuit
    * itself is passed through unchanged.
    */
  class CaptureCircuitState extends Transform with DependencyAPIMigration {
    override def prerequisites = Seq(Dependency[FlattenRegUpdate])
    override def optionalPrerequisiteOf = Seq(Dependency[VerilogEmitter])
    // execute only adds annotations, so nothing computed by other transforms is invalidated.
    override def invalidates(a: Transform): Boolean = false

    def execute(state: CircuitState): CircuitState = {
      val circuit = state.circuit
      val asFirrtl = EmittedFirrtlCircuitAnnotation(EmittedFirrtlCircuit(circuit.main, circuit.serialize, ".fir"))
      val snapshot = state.copy(annotations = asFirrtl +: state.annotations)
      state.copy(annotations = CaptureStateAnno(snapshot) +: state.annotations)
    }
  }
}

/** Checks that register update logic built by FlattenRegUpdate does not duplicate common
  * mux subtrees.
  *
  * Each test compiles a small module to Verilog and inspects both the intermediate FIRRTL
  * (snapshotted by `CaptureCircuitState` right after FlattenRegUpdate) and the final Verilog.
  */
class RegisterUpdateSpec extends FirrtlFlatSpec {
  import RegisterUpdateSpec._

  /** Compiles FIRRTL text to Verilog, running `CaptureCircuitState` along the way.
    *
    * @param input complete FIRRTL circuit text
    * @return the final state; its annotations hold the emitted Verilog and a `CaptureStateAnno`
    *         with the state as it was after FlattenRegUpdate
    */
  def compile(input: String): CircuitState = {
    val compiler = new Compiler(Seq(Dependency[CaptureCircuitState], Dependency[VerilogEmitter]))
    compiler.execute(CircuitState(parse(input), EmitCircuitAnnotation(classOf[VerilogEmitter]) :: Nil))
  }

  /** Wraps module body statements in a circuit `Test` with a single module `Test` and compiles it.
    *
    * @param body module body, written without indentation; every line is indented by four spaces
    * @return the result of `compile` on the wrapped circuit
    */
  def compileBody(body: String): CircuitState = {
    val str = """
                |circuit Test :
                |  module Test :
                |""".stripMargin + body.split("\n").mkString("    ", "\n    ", "")
    compile(str)
  }

  /** Returns the state snapshot recorded by `CaptureCircuitState`.
    *
    * Fails the test with a descriptive message (instead of a bare NoSuchElementException from
    * `Option.get`) when no `CaptureStateAnno` is present.
    */
  private def capturedState(result: CircuitState): CircuitState =
    result.annotations
      .collectFirst { case CaptureStateAnno(x) => x }
      .getOrElse(fail("No CaptureStateAnno found in annotations: CaptureCircuitState did not run"))

  "Register update logic" should "not duplicate common subtrees" in {
    val result = compileBody("""
                               |input clock : Clock
                               |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>}
                               |reg r : UInt<8>, clock
                               |when io.a :
                               |  r <= io.in
                               |when io.b :
                               |  when io.c :
                               |    r <= UInt(2)
                               |io.out <= r""".stripMargin)
    // Checking intermediate state between FlattenRegUpdate and Verilog emission.
    // _GEN_0 appears in both arms of the outer mux; presumably it names the shared
    // `when io.a` update, which should be referenced rather than duplicated.
    val fstate = capturedState(result)
    fstate should containLine("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""")
    // Checking the Verilog
    val verilog = result.getEmittedCircuit.value
    result shouldNot containLine("r <= io_in;")
    verilog shouldNot include("if (io_a) begin")
    result should containLine("r <= _GEN_0;")
  }

  it should "not let duplicate subtrees on one register affect another" in {
    val result = compileBody("""
                               |input clock : Clock
                               |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>}
                               |reg r : UInt<8>, clock
                               |reg r2 : UInt<8>, clock
                               |when io.a :
                               |  r <= io.in
                               |  r2 <= io.in
                               |when io.b :
                               |  r2 <= UInt(3)
                               |  when io.c :
                               |    r <= UInt(2)
                               |io.out <= and(r, r2)""".stripMargin)
    // Checking intermediate state between FlattenRegUpdate and Verilog emission.
    // r shares a subtree (_GEN_0); r2 has no duplicated subtree and keeps its mux inline.
    val fstate = capturedState(result)
    fstate should containLine("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""")
    fstate should containLine("""r2 <= mux(io_b, UInt<8>("h3"), mux(io_a, io_in, r2))""")
    // Checking the Verilog
    val verilog = result.getEmittedCircuit.value
    result shouldNot containLine("r <= io_in;")
    result should containLine("r <= _GEN_0;")
    result should containLine("r2 <= io_in;")
    verilog should include("if (io_a) begin") // For r2
    // 1 time for r2, old versions would have 3 occurrences
    Regex.quote("if (io_a) begin").r.findAllMatchIn(verilog).size should be(1)
  }

}