aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/InferReadWriteSpec.scala
blob: 62969df5f2579af855d9fcb3504fb010faf79ead (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
// SPDX-License-Identifier: Apache-2.0

package firrtlTests

import firrtl._
import firrtl.ir._
import firrtl.passes._
import firrtl.stage.Forms
import firrtl.testutils._
import firrtl.testutils.FirrtlCheckers._

class InferReadWriteSpec extends SimpleTransformSpec {

  /** Raised by [[InferReadWriteCheck]] when the expected readwrite port is absent from the circuit. */
  class InferReadWriteCheckException extends PassException("Readwrite ports are not found!")

  /** Test-only pass that fails compilation unless some module contains a memory named `mem` with a
    * non-zero read latency and exactly one readwrite port named `rw`.
    *
    * It is listed after [[memlib.InferReadWrite]] in [[transforms]], so a failure here means the
    * inference did not produce the expected port.
    */
  object InferReadWriteCheck extends Pass {
    override def prerequisites = Forms.MidForm
    override def optionalPrerequisites = Seq.empty
    override def optionalPrerequisiteOf = Forms.MidEmitters
    override def invalidates(a: Transform) = false

    /** Returns true if `s` is, or contains through nested [[Block]]s, the expected readwrite memory.
      *
      * Only `Block`s are descended into; statements nested in other constructs (e.g. `when` bodies)
      * are not searched. NOTE(review): this assumes when-blocks are already expanded at MidForm.
      */
    def findReadWrite(s: Statement): Boolean = s match {
      case mem: DefMemory if mem.readLatency > 0 && mem.readwriters.size == 1 =>
        mem.name == "mem" && mem.readwriters.head == "rw"
      case block: Block =>
        block.stmts.exists(findReadWrite)
      case _ => false
    }

    /** Returns `c` unchanged when the readwrite port is present.
      *
      * @throws InferReadWriteCheckException if no module contains the expected readwrite memory
      */
    def run(c: Circuit): Circuit = {
      val foundReadWrite = c.modules.exists {
        case m: Module => findReadWrite(m.body)
        // External modules (and any other non-`Module` definition) have no body to inspect.
        case _ => false
      }
      if (!foundReadWrite) throw new InferReadWriteCheckException
      c
    }
  }

  def emitter = new MiddleFirrtlEmitter

  /** Lowers to MidForm, runs readwrite inference, then verifies its result with [[InferReadWriteCheck]]. */
  def transforms = Seq(
    new ChirrtlToHighFirrtl,
    new IRToWorkingIR,
    new ResolveAndCheck,
    new HighFirrtlToMiddleFirrtl,
    new memlib.InferReadWrite,
    InferReadWriteCheck
  )

  /** Compiles `state` and asserts that compilation fails because no readwrite port was inferred.
    *
    * A missing port makes [[InferReadWriteCheck]] throw [[InferReadWriteCheckException]], which is
    * expected to surface wrapped in a `CustomTransformException`. Any other exception fails the test
    * with that exception attached as the cause; no exception at all is reported by `intercept`.
    */
  private def assertNoReadWriteInferred(state: CircuitState): Unit =
    intercept[Exception] {
      compileAndEmit(state)
    } match {
      case CustomTransformException(_: InferReadWriteCheckException) => // success
      case other =>
        fail(s"Expected no readwrite port to be inferred, but compilation failed with: $other", other)
    }

  "Infer ReadWrite Ports" should "infer readwrite ports for the same clock" in {
    val input = """
circuit sram6t :
  module sram6t :
    input clock : Clock
    input reset : UInt<1>
    output io : {flip en : UInt<1>, flip wen : UInt<1>, flip waddr : UInt<8>, flip wdata : UInt<32>, flip raddr : UInt<8>, rdata : UInt<32>}

    io is invalid
    smem mem : UInt<32>[128]
    node T_0 = eq(io.wen, UInt<1>("h00"))
    node T_1 = and(io.en, T_0)
    wire T_2 : UInt
    T_2 is invalid
    when T_1 :
      T_2 <= io.raddr
    read mport T_3 = mem[T_2], clock
    io.rdata <= T_3
    node T_4 = and(io.en, io.wen)
    when T_4 :
      write mport T_5 = mem[io.waddr], clock
      T_5 <= io.wdata
""".stripMargin

    val annos = Seq(memlib.InferReadWriteAnnotation)
    val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
    // InferReadWriteCheck (in `transforms`) has verified the port; also check the emitted FIRRTL re-parses.
    parse(res.getEmittedCircuit.value)
  }

  "Infer ReadWrite Ports" should "infer readwrite ports from exclusive when statements" in {
    val input = """
circuit sram6t :
  module sram6t :
    input clock : Clock
    input reset : UInt<1>
    output io : { flip addr : UInt<11>, flip ren : UInt<1>, flip wen : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>}

    io is invalid
    smem mem : UInt<32> [2048]
    when io.wen :
      write mport _T_14 = mem[io.addr], clock
      _T_14 <= io.dataIn
    node _T_16 = eq(io.wen, UInt<1>("h0"))
    when _T_16 :
      wire _T_18 : UInt
      _T_18 is invalid
      when io.ren :
        _T_18 <= io.addr
        node _T_20 = or(_T_18, UInt<11>("h0"))
        node _T_21 = bits(_T_20, 10, 0)
        read mport _T_22 = mem[_T_21], clock
      io.dataOut <= _T_22
""".stripMargin

    val annos = Seq(memlib.InferReadWriteAnnotation)
    val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
    // InferReadWriteCheck (in `transforms`) has verified the port; also check the emitted FIRRTL re-parses.
    parse(res.getEmittedCircuit.value)
  }

  "Infer ReadWrite Ports" should "not infer readwrite ports for different clocks" in {
    val input = """
circuit sram6t :
  module sram6t :
    input clk1 : Clock
    input clk2 : Clock
    input reset : UInt<1>
    output io : {flip en : UInt<1>, flip wen : UInt<1>, flip waddr : UInt<8>, flip wdata : UInt<32>, flip raddr : UInt<8>, rdata : UInt<32>}

    io is invalid
    smem mem : UInt<32>[128]
    node T_0 = eq(io.wen, UInt<1>("h00"))
    node T_1 = and(io.en, T_0)
    wire T_2 : UInt
    T_2 is invalid
    when T_1 :
      T_2 <= io.raddr
    read mport T_3 = mem[T_2], clk1
    io.rdata <= T_3
    node T_4 = and(io.en, io.wen)
    when T_4 :
      write mport T_5 = mem[io.waddr], clk2
      T_5 <= io.wdata
""".stripMargin

    val annos = Seq(memlib.InferReadWriteAnnotation)
    assertNoReadWriteInferred(CircuitState(parse(input), ChirrtlForm, annos))
  }

  "wmode" should "be simplified" in {
    val input = """
circuit sram6t :
  module sram6t :
    input clock : Clock
    input reset : UInt<1>
    output io : { flip addr : UInt<11>, flip valid : UInt<1>, flip write : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>}

    io is invalid
    smem mem : UInt<32> [2048]
    node wen = and(io.valid, io.write)
    node ren = and(io.valid, not(io.write))
    when wen :
      write mport _T_14 = mem[io.addr], clock
      _T_14 <= io.dataIn
    node _T_16 = eq(wen, UInt<1>("h0"))
    when _T_16 :
      wire _T_18 : UInt
      _T_18 is invalid
      when ren :
        _T_18 <= io.addr
        node _T_20 = or(_T_18, UInt<11>("h0"))
        node _T_21 = bits(_T_20, 10, 0)
        read mport _T_22 = mem[_T_21], clock
      io.dataOut <= _T_22
""".stripMargin

    val annos = Seq(memlib.InferReadWriteAnnotation)
    val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
    // The inferred port's write mode should be driven directly by the `wen` node.
    res should containLine("mem.rw.wmode <= wen")
  }

  /** Builds a HighForm circuit whose memory `mem` has a reader `r` and a writer `w` that share the
    * same clock and the same address (`io.addr`), using `ruw` as the read-under-write policy.
    */
  def sameAddr(ruw: String): String = {
    s"""
       |circuit sram6t :
       |  module sram6t :
       |    input clock : Clock
       |    output io : { flip addr : UInt<11>, flip valid : UInt<1>, flip write : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>}
       |
       |    mem mem:
       |      data-type => UInt<4>
       |      depth => 64
       |      reader => r
       |      writer => w
       |      read-latency => 1
       |      write-latency => 1
       |      read-under-write => ${ruw}
       |
       |    mem.r.clk <= clock
       |    mem.r.addr <= io.addr
       |    mem.r.en <= io.valid
       |    io.dataOut <= mem.r.data
       |
       |    node wen = and(io.valid, io.write)
       |    mem.w.clk <= clock
       |    mem.w.addr <= io.addr
       |    mem.w.en <= wen
       |    mem.w.mask <= UInt(1)
       |    mem.w.data <= io.dataIn""".stripMargin
  }

  "Infer ReadWrite Ports" should "infer readwrite ports from shared addresses with undefined readUnderWrite" in {
    val input = sameAddr("undefined")
    val annos = Seq(memlib.InferReadWriteAnnotation)
    val res = compileAndEmit(CircuitState(parse(input), HighForm, annos))
    // The merged readwrite port's write mode should be driven directly by the `wen` node.
    res should containLine("mem.rw.wmode <= wen")
  }

  Seq("old", "new").foreach { ruw =>
    "Infer ReadWrite Ports" should s"not infer readwrite ports from shared addresses with '${ruw}' readUnderWrite" in {
      // `sameAddr` yields a HighForm circuit (explicit `mem`), so label it HighForm like the 'undefined' case.
      val input = sameAddr(ruw)
      val annos = Seq(memlib.InferReadWriteAnnotation)
      assertNoReadWriteInferred(CircuitState(parse(input), HighForm, annos))
    }
  }

}