aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala
blob: b9d92a6aff6cdfd50b9f762db47bc54f125e11b3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// See LICENSE for license details.

package firrtlTests.transforms

import org.scalatest.GivenWhenThen

import firrtlTests.FirrtlFlatSpec
import firrtlTests.FirrtlCheckers._

import firrtl.{CircuitState, WRef}
import firrtl.ir.{Connect, Mux}
import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlSourceAnnotation, FirrtlStage}

/** Tests for the `RemoveReset` transform.
  *
  * Each test compiles a small FIRRTL circuit to low FIRRTL and checks which (lowered) register
  * fields end up connected to a reset `Mux`. A register whose reset value is invalid should not
  * get one. When the reset value is an aggregate, only the fields that are actually driven
  * (directly, or through intermediate connects) should get one.
  */
class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen {

  /** Compiles FIRRTL source text to low FIRRTL.
    *
    * Also records a `When` step in the GivenWhenThen narration of the calling test.
    *
    * @param string FIRRTL source text of the circuit to compile
    * @return the lowered circuit, wrapped in a [[CircuitState]] with `UnknownForm`
    * @throws NoSuchElementException if the stage emits no `FirrtlCircuitAnnotation`
    */
  private def toLowFirrtl(string: String): CircuitState = {
    When("the circuit is compiled to low FIRRTL")
    (new FirrtlStage)
      .execute(Array("-X", "low"), Seq(FirrtlSourceAnnotation(string)))
      .collectFirst { case FirrtlCircuitAnnotation(circuit) =>
        CircuitState(circuit, firrtl.UnknownForm)
      }
      // Fail with an explanatory message rather than a bare `None.get`.
      .getOrElse(throw new NoSuchElementException("FirrtlStage did not emit a FirrtlCircuitAnnotation"))
  }

  behavior of "RemoveReset"

  it should "not generate a reset mux for an invalid init" in {
    Given("a 1-bit register 'foo' initialized to invalid, 1-bit wire 'bar'")
    val input =
      """|circuit Example :
         |  module Example :
         |    input clock : Clock
         |    input rst : UInt<1>
         |    input in : UInt<1>
         |    output out : UInt<1>
         |
         |    wire bar : UInt<1>
         |    bar is invalid
         |
         |    reg foo : UInt<1>, clock with : (reset => (rst, bar))
         |    foo <= in
         |    out <= foo""".stripMargin

    val outputState = toLowFirrtl(input)

    Then("'foo' is NOT connected to a reset mux")
    outputState shouldNot containTree { case Connect(_, WRef("foo",_,_,_), Mux(_,_,_,_)) => true }
  }

  it should "generate a reset mux for only the portion of an invalid aggregate that is reset" in {
    Given("aggregate register 'foo' with 2-element vector field 'a' of 1-bit elements and 1-bit field 'b'")
    And("aggregate, invalid wire 'bar' with the same fields")
    And("'foo' is initialized to 'bar'")
    And("'bar.a[1]' connected to zero")
    val input =
      """|circuit Example :
         |  module Example :
         |    input clock : Clock
         |    input rst : UInt<1>
         |    input in :  {a : UInt<1>[2], b : UInt<1>}
         |    output out :  {a : UInt<1>[2], b : UInt<1>}
         |
         |    wire bar : {a : UInt<1>[2], b : UInt<1>}
         |    bar is invalid
         |    bar.a[1] <= UInt<1>(0)
         |
         |    reg foo :  {a : UInt<1>[2], b : UInt<1>}, clock with : (reset => (rst, bar))
         |    foo <= in
         |    out <= foo""".stripMargin

    val outputState = toLowFirrtl(input)

    // The assertions below expect the aggregate register to be flattened into
    // separately named ground-type registers (foo_a_0, foo_a_1, foo_b) after lowering.
    Then("foo.a[0] is NOT connected to a reset mux")
    outputState shouldNot containTree { case Connect(_, WRef("foo_a_0",_,_,_), Mux(_,_,_,_)) => true }
    And("foo.a[1] is connected to a reset mux")
    outputState should    containTree { case Connect(_, WRef("foo_a_1",_,_,_), Mux(_,_,_,_)) => true }
    And("foo.b is NOT connected to a reset mux")
    outputState shouldNot containTree { case Connect(_, WRef("foo_b",_,_,_),   Mux(_,_,_,_)) => true }
  }

  it should "propagate invalidations across connects" in {
    Given("aggregate register 'foo' with 1-bit field 'a' and 1-bit field 'b'")
    And("aggregate, invalid wires 'bar' and 'baz' with the same fields")
    And("'foo' is initialized to 'baz'")
    And("'bar.a' is connected to zero")
    And("'baz' is connected to 'bar'")
    // 'baz' is invalidated and then driven by 'bar'. Its 'a' field should therefore carry
    // bar.a's zero, while its 'b' field should remain invalid.
    val input =
      """|circuit Example :
         |  module Example :
         |    input clock : Clock
         |    input rst : UInt<1>
         |    input in : { a : UInt<1>, b : UInt<1> }
         |    output out : { a : UInt<1>, b : UInt<1> }
         |
         |    wire bar : { a : UInt<1>, b : UInt<1> }
         |    bar is invalid
         |    bar.a <= UInt<1>(0)
         |
         |    wire baz : { a : UInt<1>, b : UInt<1> }
         |    baz is invalid
         |    baz <= bar
         |
         |    reg foo : { a : UInt<1>, b : UInt<1> }, clock with : (reset => (rst, baz))
         |    foo <= in
         |    out <= foo""".stripMargin

    val outputState = toLowFirrtl(input)

    Then("'foo.a' is connected to a reset mux")
    outputState should    containTree { case Connect(_, WRef("foo_a",_,_,_), Mux(_,_,_,_)) => true }
    And("'foo.b' is NOT connected to a reset mux")
    outputState shouldNot containTree { case Connect(_, WRef("foo_b",_,_,_), Mux(_,_,_,_)) => true }
  }

}