diff options
Diffstat (limited to 'src/test')
| -rw-r--r-- | src/test/scala/firrtlTests/RegisterUpdateSpec.scala | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/src/test/scala/firrtlTests/RegisterUpdateSpec.scala b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala new file mode 100644 index 00000000..dfef5955 --- /dev/null +++ b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala @@ -0,0 +1,100 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ +import firrtl.ir._ +import firrtl.transforms.FlattenRegUpdate +import firrtl.annotations.NoTargetAnnotation +import firrtl.stage.transforms.Compiler +import firrtl.options.Dependency +import firrtl.testutils._ +import FirrtlCheckers._ +import scala.util.matching.Regex + +object RegisterUpdateSpec { + case class CaptureStateAnno(value: CircuitState) extends NoTargetAnnotation + // Capture the CircuitState between FlattenRegUpdate and VerilogEmitter + //Emit captured state as FIRRTL for use in testing + class CaptureCircuitState extends Transform with DependencyAPIMigration { + override def prerequisites = Dependency[FlattenRegUpdate] :: Nil + override def optionalPrerequisiteOf = Dependency[VerilogEmitter] :: Nil + override def invalidates(a: Transform): Boolean = false + def execute(state: CircuitState): CircuitState = { + val emittedAnno = EmittedFirrtlCircuitAnnotation( + EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, ".fir")) + val capturedState = state.copy(annotations = emittedAnno +: state.annotations) + state.copy(annotations = CaptureStateAnno(capturedState) +: state.annotations) + } + } +} + +class RegisterUpdateSpec extends FirrtlFlatSpec { + import RegisterUpdateSpec._ + def compile(input: String): CircuitState = { + val compiler = new Compiler(Seq(Dependency[CaptureCircuitState], Dependency[VerilogEmitter])) + compiler.execute(CircuitState(parse(input), EmitCircuitAnnotation(classOf[VerilogEmitter]) :: Nil)) + } + def compileBody(body: String) = { + val str = """ + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + compile(str) + } + + "Register update logic" should "not duplicate common subtrees" in { + val result = compileBody(s""" + |input clock : Clock + |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} + |reg r : UInt<8>, clock + |when io.a : + | r <= io.in + |when io.b : + | when io.c : + | r <= UInt(2) + |io.out <= r""".stripMargin + ) + // Checking intermediate state between FlattenRegUpdate and Verilog emission + val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get + fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") + // Checking the Verilog + val verilog = result.getEmittedCircuit.value + result shouldNot containLine ("r <= io_in;") + verilog shouldNot include ("if (io_a) begin") + result should containLine ("r <= _GEN_0;") + } + + it should "not let duplicate subtrees on one register affect another" in { + + val result = compileBody(s""" + |input clock : Clock + |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} + + |reg r : UInt<8>, clock + |reg r2 : UInt<8>, clock + |when io.a : + | r <= io.in + | r2 <= io.in + |when io.b : + | r2 <= UInt(3) + | when io.c : + | r <= UInt(2) + |io.out <= and(r, r2)""".stripMargin + ) + // Checking intermediate state between FlattenRegUpdate and Verilog emission + val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get + fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") + fstate should containLine ("""r2 <= mux(io_b, UInt<8>("h3"), mux(io_a, io_in, r2))""") + // Checking the Verilog + val verilog = result.getEmittedCircuit.value + result shouldNot containLine ("r <= io_in;") + result should containLine ("r <= _GEN_0;") + result should containLine ("r2 <= io_in;") + verilog should include ("if (io_a) begin") // For r2 + // 1 time for r2, old versions would have 3 occurences + Regex.quote("if (io_a) begin").r.findAllMatchIn(verilog).size should be (1) + } + +} + |
