aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/formal/AssertSubmoduleAssumptionsSpec.scala
blob: 9d7bff3da606e8382feb489d33880af55a31e56c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// SPDX-License-Identifier: Apache-2.0

package firrtlTests.formal

import firrtl.{CircuitState, Parser, Transform, UnknownForm}
import firrtl.testutils.FirrtlFlatSpec
import firrtl.transforms.formal.AssertSubmoduleAssumptions
import firrtl.stage.{Forms, TransformManager}

/** Tests for the [[firrtl.transforms.formal.AssertSubmoduleAssumptions]] transform.
  *
  * Each test parses a small FIRRTL circuit, runs it through `transforms`, and checks that
  * the serialized result contains the expected `assert` statements for the `assume`
  * statements that were written inside non-top modules.
  */
class AssertSubmoduleAssumptionsSpec extends FirrtlFlatSpec {
  behavior.of("AssertSubmoduleAssumptions")

  /** Transform pipeline applied by `run`: the flattened transform order of a
    * `TransformManager` built from `Forms.HighForm` and `Forms.MinimalHighForm`,
    * followed by the transform under test.
    */
  val transforms = new TransformManager(Forms.HighForm, Forms.MinimalHighForm).flattenedTransformOrder ++ Seq(
    new AssertSubmoduleAssumptions
  )

  /** Parses `input`, applies every transform in `transforms` in order, and asserts that
    * each string in `check` appears as a line of the serialized output circuit.
    *
    * @param input FIRRTL source text of the circuit to transform
    * @param check lines expected to be present in the serialized result (compared after
    *              each output line has been passed through `normalized`)
    * @param debug when true, prints the processed output lines before checking
    */
  def run(input: String, check: Seq[String], debug: Boolean = false): Unit = {
    // `.iterator` rather than `.toIterator`: the latter is deprecated since Scala 2.13.
    val circuit = Parser.parse(input.split("\n").iterator)
    val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { (state: CircuitState, xform: Transform) =>
      xform.runTransform(state)
    }
    val lines = result.circuit.serialize.split("\n").map(normalized)

    if (debug) {
      println(lines.mkString("\n"))
    }

    check.foreach { expected =>
      lines should contain(expected)
    }
  }

  it should "convert `assume` to `assert` in a submodule" in {
    val input =
      """circuit Test :
        |  module Test :
        |    input clock : Clock
        |    input reset : UInt<1>
        |    input in : UInt<8>
        |    output out : UInt<8>
        |    inst sub of Sub
        |    sub.clock <= clock
        |    sub.reset <= reset
        |    sub.in <= in
        |    out <= sub.out
        |    assume(clock, eq(in, UInt(0)), UInt(1), "assume0")
        |    assert(clock, eq(out, UInt(0)), UInt(1), "assert0")
        |
        |  module Sub :
        |    input clock : Clock
        |    input reset : UInt<1>
        |    input in : UInt<8>
        |    output out : UInt<8>
        |    out <= in
        |    assume(clock, eq(in, UInt(1)), UInt(1), "assume1")
        |    assert(clock, eq(out, UInt(1)), UInt(1), "assert1")
        |""".stripMargin

    // The assumption declared in `Sub` is expected to show up as an assertion; it keeps
    // its original message string ("assume1").
    val check = Seq(
      "assert(clock, eq(in, UInt<1>(\"h1\")), UInt<1>(\"h1\"), \"assume1\")"
    )
    run(input, check)
  }

  it should "convert `assume` to `assert` in a nested submodule" in {
    val input =
      """circuit Test :
        |  module Test :
        |    input clock : Clock
        |    input reset : UInt<1>
        |    input in : UInt<8>
        |    output out : UInt<8>
        |    inst sub of Sub
        |    sub.clock <= clock
        |    sub.reset <= reset
        |    sub.in <= in
        |    out <= sub.out
        |    assume(clock, eq(in, UInt(0)), UInt(1), "assume0")
        |    assert(clock, eq(out, UInt(0)), UInt(1), "assert0")
        |
        |  module Sub :
        |    input clock : Clock
        |    input reset : UInt<1>
        |    input in : UInt<8>
        |    output out : UInt<8>
        |    inst nestedSub of NestedSub
        |    nestedSub.clock <= clock
        |    nestedSub.reset <= reset
        |    nestedSub.in <= in
        |    out <= nestedSub.out
        |    assume(clock, eq(in, UInt(1)), UInt(1), "assume1")
        |    assert(clock, eq(out, UInt(1)), UInt(1), "assert1")
        |
        |  module NestedSub :
        |    input clock : Clock
        |    input reset : UInt<1>
        |    input in : UInt<8>
        |    output out : UInt<8>
        |    out <= in
        |    assume(clock, eq(in, UInt(2)), UInt(1), "assume2")
        |    assert(clock, eq(out, UInt(2)), UInt(1), "assert2")
        |""".stripMargin

    // Both the directly instantiated module (`Sub`) and the module instantiated inside
    // it (`NestedSub`) are expected to have their assumptions turned into assertions.
    val check = Seq(
      "assert(clock, eq(in, UInt<1>(\"h1\")), UInt<1>(\"h1\"), \"assume1\")",
      "assert(clock, eq(in, UInt<2>(\"h2\")), UInt<1>(\"h1\"), \"assume2\")"
    )
    run(input, check)
  }
}