aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/InferReadWriteSpec.scala
blob: 4268bd2ba69b90c410d2879303447260947af6e8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
// See LICENSE for license details.

package firrtlTests

import firrtl._
import firrtl.ir._
import firrtl.options.PreservesAll
import firrtl.passes._
import firrtl.stage.Forms
import firrtl.testutils._
import firrtl.testutils.FirrtlCheckers._

/** Tests for `memlib.InferReadWrite` (enabled via `memlib.InferReadWriteAnnotation`).
  *
  * Every compilation runs [[InferReadWriteCheck]] after the inference transform, so a compile
  * that completes without throwing already proves a readwrite port named `rw` was inferred on
  * memory `mem`. The negative test expects that check to throw instead.
  */
class InferReadWriteSpec extends SimpleTransformSpec {
  /** Raised by [[InferReadWriteCheck]] when no inferred readwrite port is present. */
  class InferReadWriteCheckException extends PassException(
    "Readwrite ports are not found!")

  /** Verification pass scheduled after `memlib.InferReadWrite`.
    *
    * Fails compilation unless some module contains a memory named `mem` with a nonzero read
    * latency and exactly one readwrite port, named `rw`.
    */
  object InferReadWriteCheck extends Pass with PreservesAll[Transform] {
    override def prerequisites = Forms.MidForm
    override def optionalPrerequisites = Seq.empty
    override def dependents = Forms.MidEmitters

    /** Returns true if `s` is, or (through nested `Block`s) contains, the expected memory.
      *
      * Only `Block`s are descended into. This presumably suffices because the pass runs on
      * MidForm, where no `when` (Conditionally) statements should remain.
      */
    def findReadWrite(s: Statement): Boolean = s match {
      case s: DefMemory if s.readLatency > 0 && s.readwriters.size == 1 =>
        s.name == "mem" && s.readwriters.head == "rw"
      case s: Block =>
        s.stmts exists findReadWrite
      case _ => false
    }

    /** Returns `c` unchanged, or throws [[InferReadWriteCheckException]] if no module
      * contains the expected readwrite memory.
      */
    def run(c: Circuit): Circuit = {
      val foundReadWrite = c.modules exists {
        case m: Module => findReadWrite(m.body)
        // External (or any other non-`Module`) modules have no body to search.
        case _ => false
      }
      if (!foundReadWrite) {
        val errors = new Errors
        errors append new InferReadWriteCheckException
        // `trigger()` is side-effecting (it throws), so it is called with explicit parens.
        errors.trigger()
      }
      c
    }
  }

  def emitter: Emitter = new MiddleFirrtlEmitter
  def transforms: Seq[Transform] = Seq(
    new ChirrtlToHighFirrtl,
    new IRToWorkingIR,
    new ResolveAndCheck,
    new HighFirrtlToMiddleFirrtl,
    new memlib.InferReadWrite,
    InferReadWriteCheck
  )

  "Infer ReadWrite Ports" should "infer readwrite ports for the same clock" in {
    val input = """
circuit sram6t :
  module sram6t :
    input clock : Clock
    input reset : UInt<1>
    output io : {flip en : UInt<1>, flip wen : UInt<1>, flip waddr : UInt<8>, flip wdata : UInt<32>, flip raddr : UInt<8>, rdata : UInt<32>}

    io is invalid
    smem mem : UInt<32>[128]
    node T_0 = eq(io.wen, UInt<1>("h00"))
    node T_1 = and(io.en, T_0)
    wire T_2 : UInt
    T_2 is invalid
    when T_1 :
      T_2 <= io.raddr
    read mport T_3 = mem[T_2], clock
    io.rdata <= T_3
    node T_4 = and(io.en, io.wen)
    when T_4 :
      write mport T_5 = mem[io.waddr], clock
      T_5 <= io.wdata
""".stripMargin

    val annos = Seq(memlib.InferReadWriteAnnotation)
    // InferReadWriteCheck (part of `transforms`) throws if no readwrite port was inferred.
    val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
    // Check that the emitted FIRRTL is well-formed by re-parsing it.
    parse(res.getEmittedCircuit.value)
  }

  "Infer ReadWrite Ports" should "infer readwrite ports from exclusive when statements" in {
    val input = """
circuit sram6t :
  module sram6t :
    input clock : Clock
    input reset : UInt<1>
    output io : { flip addr : UInt<11>, flip ren : UInt<1>, flip wen : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>}

    io is invalid
    smem mem : UInt<32> [2048]
    when io.wen :
      write mport _T_14 = mem[io.addr], clock
      _T_14 <= io.dataIn
    node _T_16 = eq(io.wen, UInt<1>("h0"))
    when _T_16 :
      wire _T_18 : UInt
      _T_18 is invalid
      when io.ren :
        _T_18 <= io.addr
        node _T_20 = or(_T_18, UInt<11>("h0"))
        node _T_21 = bits(_T_20, 10, 0)
        read mport _T_22 = mem[_T_21], clock
      io.dataOut <= _T_22
""".stripMargin

    val annos = Seq(memlib.InferReadWriteAnnotation)
    // InferReadWriteCheck (part of `transforms`) throws if no readwrite port was inferred.
    val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
    // Check that the emitted FIRRTL is well-formed by re-parsing it.
    parse(res.getEmittedCircuit.value)
  }

  "Infer ReadWrite Ports" should "not infer readwrite ports for different clocks" in {
    val input = """
circuit sram6t :
  module sram6t :
    input clk1 : Clock
    input clk2 : Clock
    input reset : UInt<1>
    output io : {flip en : UInt<1>, flip wen : UInt<1>, flip waddr : UInt<8>, flip wdata : UInt<32>, flip raddr : UInt<8>, rdata : UInt<32>}

    io is invalid
    smem mem : UInt<32>[128]
    node T_0 = eq(io.wen, UInt<1>("h00"))
    node T_1 = and(io.en, T_0)
    wire T_2 : UInt
    T_2 is invalid
    when T_1 :
      T_2 <= io.raddr
    read mport T_3 = mem[T_2], clk1
    io.rdata <= T_3
    node T_4 = and(io.en, io.wen)
    when T_4 :
      write mport T_5 = mem[io.waddr], clk2
      T_5 <= io.wdata
""".stripMargin

    val annos = Seq(memlib.InferReadWriteAnnotation)
    // The read and write ports use different clocks, so no readwrite port should be inferred
    // and InferReadWriteCheck is expected to reject the result.
    intercept[Exception] {
      compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
    } match {
      case CustomTransformException(_: InferReadWriteCheckException) => // success
      case other => fail(s"Expected InferReadWriteCheckException, but got: $other")
    }
  }

  "wmode" should "be simplified" in {
    val input = """
circuit sram6t :
  module sram6t :
    input clock : Clock
    input reset : UInt<1>
    output io : { flip addr : UInt<11>, flip valid : UInt<1>, flip write : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>}

    io is invalid
    smem mem : UInt<32> [2048]
    node wen = and(io.valid, io.write)
    node ren = and(io.valid, not(io.write))
    when wen :
      write mport _T_14 = mem[io.addr], clock
      _T_14 <= io.dataIn
    node _T_16 = eq(wen, UInt<1>("h0"))
    when _T_16 :
      wire _T_18 : UInt
      _T_18 is invalid
      when ren :
        _T_18 <= io.addr
        node _T_20 = or(_T_18, UInt<11>("h0"))
        node _T_21 = bits(_T_20, 10, 0)
        read mport _T_22 = mem[_T_21], clock
      io.dataOut <= _T_22
""".stripMargin

    val annos = Seq(memlib.InferReadWriteAnnotation)
    val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
    // The inferred readwrite port's write mode should be driven directly by `wen`.
    res should containLine ("mem.rw.wmode <= wen")
  }
}