aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala
blob: e5226226d7f0df3bce22a67f127c5776d4ecd62f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
package firrtl.backends.experimental.smt.random

import firrtl.options.Dependency
import firrtl.testutils.LeanTransformSpec

/** Tests for [[InvalidToRandomPass]].
  *
  * The expectations below check that an `is invalid` connection is turned into a `rand`
  * statement guarded by the "invalid" condition, and that the random value is selected
  * with a `mux` only under that condition.
  */
class InvalidToRandomSpec extends LeanTransformSpec(Seq(Dependency(InvalidToRandomPass))) {
  behavior.of("InvalidToRandomPass")

  /** Circuit covering three cases:
    *  - `o`:  invalid unless an inline condition holds (the condition is not a reference yet),
    *  - `o2`: invalid unless a condition that is already a named node holds,
    *  - `o3`: invalidated but then unconditionally connected, so it is trivially valid.
    */
  val src1: String =
    """
      |circuit Test:
      |  module Test:
      |    input a : UInt<2>
      |    output o : UInt<8>
      |    output o2 : UInt<8>
      |    output o3 : UInt<8>
      |
      |    o is invalid
      |
      |    when eq(a, UInt(3)):
      |      o <= UInt(5)
      |
      |    o2 is invalid
      |    node o2_valid = eq(a, UInt(2))
      |    when o2_valid:
      |      o2 <= UInt(7)
      |
      |    o3 is invalid
      |    o3 <= UInt(3)
      |""".stripMargin

  it should "model invalid signals as random" in {
    val serialized = compile(src1, List()).circuit.serialize
    val result = serialized.split('\n').map(_.trim)

    // Include the whole serialized circuit in the failure message so a mismatch can be
    // diagnosed directly, instead of re-running with a debug println.
    def assertLine(expected: String): Unit =
      assert(result.contains(expected), s"\nexpected line: $expected\nin circuit:\n$serialized")

    // the condition should end up as a new node if it wasn't a reference already
    assertLine("node _GEN_0_invalid_cond = not(eq(a, UInt<2>(\"h3\")))")
    assertLine("node o2_valid = eq(a, UInt<2>(\"h2\"))")

    // every invalid results in a random statement
    assertLine("rand _GEN_0_invalid : UInt<3> when _GEN_0_invalid_cond")
    assertLine("rand _GEN_1_invalid : UInt<3> when not(o2_valid)")

    // the random value is conditionally assigned
    assertLine("node _GEN_0 = mux(_GEN_0_invalid_cond, _GEN_0_invalid, UInt<3>(\"h5\"))")
    assertLine("node _GEN_1 = mux(not(o2_valid), _GEN_1_invalid, UInt<3>(\"h7\"))")

    // expressions that are trivially valid do not get randomized
    assertLine("o3 <= UInt<8>(\"h3\")")

    // lines are trimmed, so a rand statement always starts with "rand "; matching the prefix
    // avoids counting unrelated lines that merely contain that substring
    val defRandCount = result.count(_.startsWith("rand "))
    assert(defRandCount == 2, s"\nexpected exactly 2 rand statements in circuit:\n$serialized")
  }

}