1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
|
// See LICENSE for license details.
package firrtlTests
import firrtl._
import firrtl.ir._
import firrtl.options.PreservesAll
import firrtl.passes._
import firrtl.stage.Forms
import firrtl.testutils._
import firrtl.testutils.FirrtlCheckers._
/** Tests for readwrite-port inference (`memlib.InferReadWrite`).
  *
  * Each case compiles a CHIRRTL circuit with `memlib.InferReadWriteAnnotation` through
  * [[transforms]]. The last transform, [[InferReadWriteCheck]], fails compilation unless the
  * memory `mem` ended up with exactly one readwrite port named `rw`.
  *
  * The FIRRTL inputs use `|` margins with `.stripMargin`: FIRRTL is indentation-sensitive, so
  * the nesting (module bodies, `when` bodies) must survive any reformatting of this Scala file.
  */
class InferReadWriteSpec extends SimpleTransformSpec {

  /** Raised by [[InferReadWriteCheck]] when the expected readwrite port is missing. */
  class InferReadWriteCheckException extends PassException("Readwrite ports are not found!")

  /** Check pass: succeeds only if some module contains a memory named `mem` whose read latency
    * is positive and which has exactly one readwrite port, named `rw`.
    */
  object InferReadWriteCheck extends Pass with PreservesAll[Transform] {
    override def prerequisites = Forms.MidForm
    override def optionalPrerequisites = Seq.empty
    override def dependents = Forms.MidEmitters

    /** Returns true if `s` is, or (through nested `Block`s) contains, the expected memory.
      *
      * Only `Block`s are searched. NOTE(review): this relies on the MidForm prerequisite having
      * expanded all `when`s, so no `Conditionally` remains to recurse into; confirm if
      * prerequisites change.
      */
    def findReadWrite(s: Statement): Boolean = s match {
      case mem: DefMemory if mem.readLatency > 0 && mem.readwriters.size == 1 =>
        mem.name == "mem" && mem.readwriters.head == "rw"
      case block: Block => block.stmts.exists(findReadWrite)
      case _            => false
    }

    /** Returns `c` unchanged, or raises [[InferReadWriteCheckException]] (via `Errors`) when no
      * module contains the expected readwrite memory.
      */
    def run(c: Circuit): Circuit = {
      val errors = new Errors
      val foundReadWrite = c.modules.exists {
        case m: Module    => findReadWrite(m.body)
        case _: ExtModule => false
      }
      if (!foundReadWrite) {
        errors.append(new InferReadWriteCheckException)
        errors.trigger
      }
      c
    }
  }

  def emitter = new MiddleFirrtlEmitter

  def transforms = Seq(
    new ChirrtlToHighFirrtl,
    new IRToWorkingIR,
    new ResolveAndCheck,
    new HighFirrtlToMiddleFirrtl,
    new memlib.InferReadWrite,
    InferReadWriteCheck
  )

  /** Compiles CHIRRTL `input` through [[transforms]] with readwrite-port inference enabled. */
  private def compileWithInferReadWrite(input: String): CircuitState =
    compileAndEmit(CircuitState(parse(input), ChirrtlForm, Seq(memlib.InferReadWriteAnnotation)))

  "Infer ReadWrite Ports" should "infer readwrite ports for the same clock" in {
    // Read enable comes from the `when T_1` that drives the address wire T_2. Write enable is
    // `when T_4`. T_1 (en & !wen) and T_4 (en & wen) are exclusive.
    val input = """
        |circuit sram6t :
        |  module sram6t :
        |    input clock : Clock
        |    input reset : UInt<1>
        |    output io : {flip en : UInt<1>, flip wen : UInt<1>, flip waddr : UInt<8>, flip wdata : UInt<32>, flip raddr : UInt<8>, rdata : UInt<32>}
        |    io is invalid
        |    smem mem : UInt<32>[128]
        |    node T_0 = eq(io.wen, UInt<1>("h00"))
        |    node T_1 = and(io.en, T_0)
        |    wire T_2 : UInt
        |    T_2 is invalid
        |    when T_1 :
        |      T_2 <= io.raddr
        |    read mport T_3 = mem[T_2], clock
        |    io.rdata <= T_3
        |    node T_4 = and(io.en, io.wen)
        |    when T_4 :
        |      write mport T_5 = mem[io.waddr], clock
        |      T_5 <= io.wdata
        |""".stripMargin
    val res = compileWithInferReadWrite(input)
    // InferReadWriteCheck has already verified the port; also check the emitted FIRRTL re-parses.
    parse(res.getEmittedCircuit.value)
  }

  "Infer ReadWrite Ports" should "infer readwrite ports from exclusive when statements" in {
    val input = """
        |circuit sram6t :
        |  module sram6t :
        |    input clock : Clock
        |    input reset : UInt<1>
        |    output io : { flip addr : UInt<11>, flip ren : UInt<1>, flip wen : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>}
        |    io is invalid
        |    smem mem : UInt<32> [2048]
        |    when io.wen :
        |      write mport _T_14 = mem[io.addr], clock
        |      _T_14 <= io.dataIn
        |    node _T_16 = eq(io.wen, UInt<1>("h0"))
        |    when _T_16 :
        |      wire _T_18 : UInt
        |      _T_18 is invalid
        |      when io.ren :
        |        _T_18 <= io.addr
        |        node _T_20 = or(_T_18, UInt<11>("h0"))
        |        node _T_21 = bits(_T_20, 10, 0)
        |        read mport _T_22 = mem[_T_21], clock
        |      io.dataOut <= _T_22
        |""".stripMargin
    val res = compileWithInferReadWrite(input)
    // InferReadWriteCheck has already verified the port; also check the emitted FIRRTL re-parses.
    parse(res.getEmittedCircuit.value)
  }

  "Infer ReadWrite Ports" should "not infer readwrite ports for different clocks" in {
    // Same circuit as the same-clock case, but the read and write ports use different clocks.
    val input = """
        |circuit sram6t :
        |  module sram6t :
        |    input clk1 : Clock
        |    input clk2 : Clock
        |    input reset : UInt<1>
        |    output io : {flip en : UInt<1>, flip wen : UInt<1>, flip waddr : UInt<8>, flip wdata : UInt<32>, flip raddr : UInt<8>, rdata : UInt<32>}
        |    io is invalid
        |    smem mem : UInt<32>[128]
        |    node T_0 = eq(io.wen, UInt<1>("h00"))
        |    node T_1 = and(io.en, T_0)
        |    wire T_2 : UInt
        |    T_2 is invalid
        |    when T_1 :
        |      T_2 <= io.raddr
        |    read mport T_3 = mem[T_2], clk1
        |    io.rdata <= T_3
        |    node T_4 = and(io.en, io.wen)
        |    when T_4 :
        |      write mport T_5 = mem[io.waddr], clk2
        |      T_5 <= io.wdata
        |""".stripMargin
    intercept[Exception] {
      compileWithInferReadWrite(input)
    } match {
      case CustomTransformException(_: InferReadWriteCheckException) => // success
      case other => fail(s"Expected InferReadWriteCheckException, but got: $other")
    }
  }

  "wmode" should "be simplified" in {
    val input = """
        |circuit sram6t :
        |  module sram6t :
        |    input clock : Clock
        |    input reset : UInt<1>
        |    output io : { flip addr : UInt<11>, flip valid : UInt<1>, flip write : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>}
        |    io is invalid
        |    smem mem : UInt<32> [2048]
        |    node wen = and(io.valid, io.write)
        |    node ren = and(io.valid, not(io.write))
        |    when wen :
        |      write mport _T_14 = mem[io.addr], clock
        |      _T_14 <= io.dataIn
        |    node _T_16 = eq(wen, UInt<1>("h0"))
        |    when _T_16 :
        |      wire _T_18 : UInt
        |      _T_18 is invalid
        |      when ren :
        |        _T_18 <= io.addr
        |        node _T_20 = or(_T_18, UInt<11>("h0"))
        |        node _T_21 = bits(_T_20, 10, 0)
        |        read mport _T_22 = mem[_T_21], clock
        |      io.dataOut <= _T_22
        |""".stripMargin
    val res = compileWithInferReadWrite(input)
    // The readwrite port's write mode should be driven directly by the `wen` node.
    res should containLine("mem.rw.wmode <= wen")
  }
}
|