1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
|
package midas
import Utils._
import firrtl._
import firrtl.Utils._
/** FAME-1 Transformation
*
* This pass takes a lowered-to-ground circuit and performs a FAME-1 (Decoupled) transformation
* to the circuit
* It does this by creating a simulation module wrapper around the circuit. If we can gate the
* clock, then there need be no modification to the target RTL; if we can't, then the target
* RTL will have to be modified (by adding a midasFire input and using this signal to gate
* register enables).
*
* ALGORITHM
* 1. Flatten RTL
* a. Create NewTop
* b. Instantiate Top in NewTop
* c. Iteratively pull all sim tagged instances out of hierarchy to NewTop
* i. Move instance declaration from child module to parent module
* ii. Create io in child module corresponding to io of instance
* iii. Connect original instance io in child to new child io
* iv. Connect child io in parent to instance io
* v. Repeat until instance is in SimTop
* (if black box, repeat until completely removed from design)
* * Post-flattening invariants
* - No combinational logic on NewTop
* - All and only instances in NewTop are sim tagged
* - No black boxes in design
* 2. Simulation Transformation
* a. Perform Decoupled Transformation on every RTL Module
* i. Add targetFire signal input
* ii. Find all DefInst nodes and propagate targetFire to the instances
* iii. Find all registers and add when statement connecting regIn to regOut when !targetFire
* b. Iteratively transform each inst in Top (see example firrtl in dram_midas top dir)
* i. Create wrapper class
* ii. Create input and output (ready, valid) pairs for every other sim module this module connects to
* * Note that TopIO counts as a "sim module"
* iii. Create targetFire as AND of inputs.valid and outputs.ready
* iv. Connect targetFire to targetFire input of target rtl inst
* v. Connect target IO to wrapper IO, except connect target clock to simClock
*
* TODO
* - Change from clock gating to reg enable, don't forget to change sequential memory read enable
* - Implement Flatten RTL
* - Refactor important strings/naming to API (eg. "topIO" needs to be a constant defined somewhere or something)
* - Check that circuit is in LowFIRRTL?
*
* NOTES
* - How do we only transform the necessary modules? Should there be a MIDAS list of modules
* or something?
* * YES, this will be done by requiring the user to instantiate modules that should be split
* with something like: val module = new MIDASModule(class... etc.)
* - There cannot be nested DecoupledIO or ValidIO
* - How do output consumes tie in to MIDAS fire? If all of our outputs are not consumed
* in a given cycle, do we block midas$fire on the next cycle? Perhaps there should be
* a register for not having consumed all outputs last cycle
* - If our outputs are not consumed we also need to be sure not to consume our inputs,
* so the logic for this must depend on the previous cycle being consumed as well
* - We also need a way to determine the difference between the MIDAS modules and their
* connecting Queues, perhaps they should be MIDAS queues, which then perhaps prints
* out a listing of all queues so that they can be properly transformed
* * What do these MIDAS queues look like since we're enforcing true decoupled
* interfaces?
*/
object Fame1 {
// Constants, common nodes, and common types used throughout
// Aliases for the connection bookkeeping structures. Each alias is paired with a same-named value
// holding the empty map of that type, which the rest of the pass uses as a default / starting value.
//   PortMap: port name -> names of the endpoints connected to that port
//   ConnMap: instance name (or "topIO") -> that instance's PortMap
//   SimQMap: (source name, destination name) -> module generated for the queue between them
private type PortMap = Map[String, Seq[String]]
private type ConnMap = Map[String, PortMap]
private type SimQMap = Map[(String, String), Module]
private val PortMap: PortMap = Map.empty
private val ConnMap: ConnMap = Map.empty
private val SimQMap: SimQMap = Map.empty
// One-bit unsigned type shared by the host handshake fields and control ports below
private def hostBit = UIntType(IntWidth(1))
// Handshake fields of a host-decoupled bundle; hostReady is declared in the Reverse direction
private val hostReady = Field("hostReady", Reverse, hostBit)
private val hostValid = Field("hostValid", Default, hostBit)
// Ports added to generated modules
private val hostClock = Port(NoInfo, "hostClock", Input, ClockType)
private val hostReset = Port(NoInfo, "hostReset", Input, hostBit)
private val targetFire = Port(NoInfo, "targetFire", Input, hostBit)
// Naming conventions for generated modules and instances.
// wrapName / instName add a fixed prefix; unwrapName / unInstName remove that prefix when it is
// present and return the name unchanged otherwise. queueName joins source and destination names.
private def wrapName(name: String): String = s"SimWrap_$name"
private def unwrapName(name: String): String =
  if (name.startsWith("SimWrap_")) name.drop("SimWrap_".length) else name
private def queueName(src: String, dst: String): String = "SimQueue_" + src + "_" + dst
private def instName(name: String): String = s"inst_$name"
private def unInstName(name: String): String =
  if (name.startsWith("inst_")) name.drop("inst_".length) else name
// Builds a host-decoupled bundle: the hostReady and hostValid handshake fields followed by a
// "hostBits" field whose type is a bundle of the given payload fields.
private def genHostDecoupled(fields: Seq[Field]): BundleType = {
  val payload = Field("hostBits", Default, BundleType(fields))
  BundleType(Seq(hostReady, hostValid, payload))
}
// ********** findPortConn **********
// This takes lowFIRRTL top module that follows invariants described above and returns a connection Map
// of instanceName -> (instancePorts -> portEndpoint)
// It honestly feels kind of brittle given it assumes there will be no intermediate nodes or anything in
// the way of direct connections between IO of module instances
// Resolves one side of a Connect to an (owner, port) pair:
//   - a plain reference `x`           -> ("topIO", "x")
//   - a subfield of a reference `i.p` -> ("i", "p")
// Any other expression shape throws an Exception naming the offending expression.
private def processConnectExp(exp: Exp): (String, String) = {
  // Built on demand only: constructing an exception captures a stack trace and stringifies exp,
  // which is wasted work on the (common) successful path.
  def unsupportedExp = new Exception("Unsupported Exp for finding port connections: " + exp)
  exp match {
    case ref: Ref => ("topIO", ref.name)
    case sub: Subfield =>
      sub.exp match {
        case ref: Ref => (ref.name, sub.name)
        case _ => throw unsupportedExp // deeper nesting (e.g. a.b.c) is not supported
      }
    case _ => throw unsupportedExp
  }
}
// Turns a single Connect into a two-way connection record:
//   owner of lhs -> (lhs port -> Seq(owner of rhs))
//   owner of rhs -> (rhs port -> Seq(owner of lhs))
// The result yields an empty PortMap for any owner not involved in the connection.
// Throws whatever processConnectExp throws for unsupported expressions.
private def processConnect(conn: Connect): ConnMap = {
  val (sinkOwner, sinkPort) = processConnectExp(conn.lhs)
  val (srcOwner, srcPort) = processConnectExp(conn.rhs)
  val entries: ConnMap =
    if (sinkOwner == srcOwner)
      // Both endpoints belong to the same owner: a two-entry Map literal would have duplicate keys
      // and keep only the last one, so record both ports under the single owner instead.
      Map(sinkOwner -> Map(sinkPort -> Seq(srcOwner), srcPort -> Seq(sinkOwner)))
    else
      Map(sinkOwner -> Map(sinkPort -> Seq(srcOwner)),
          srcOwner -> Map(srcPort -> Seq(sinkOwner)))
  entries.withDefaultValue(PortMap)
}
// Accumulates the connections described by every Connect statement in stmts into connMap.
// Only owners already present as keys of connMap are updated; for each of them the port map found
// in the statement is merged in, concatenating endpoint lists when a port appears on both sides.
// Statements that are not Connects are ignored.
private def findPortConn(connMap: ConnMap, stmts: Seq[Stmt]): ConnMap =
  stmts.foldLeft(connMap) {
    case (known, conn: Connect) =>
      val found = processConnect(conn)
      known map { case (owner, ports) =>
        // found(owner) falls back to an empty PortMap when this owner is not part of conn
        owner -> merge(Seq(ports, found(owner))) { (_, existing, added) => existing ++ added }
      }
    case (known, _) => known
  }
// Entry point for connection discovery: seeds an empty PortMap for every instance plus "topIO",
// then scans the statements of the top module for Connects.
private def findPortConn(top: Module, insts: Seq[DefInst]): ConnMap = {
  val owners = insts.map(_.name) :+ "topIO"
  val emptyConns = owners.map(owner => owner -> PortMap).toMap
  val body = top.stmt match {
    case block: Block => block.stmts
    case single: Stmt => Seq(single) // a lone statement instead of a Block; not expected, but handled
  }
  findPortConn(emptyConns, body)
}
// Removes clocks from a portmap
// Returns portMap without the entries whose key is the name of a clock-typed port in ports.
private def scrubClocks(ports: Seq[Port], portMap: PortMap): PortMap = {
  val clockNames = ports.collect { case p if p.tpe == ClockType => p.name }.toSet
  portMap.filterNot { case (portName, _) => clockNames(portName) }
}
// ********** transformRTL **********
// Takes an RTL module and gives it a targetFire input, propagates targetFire to all child instances,
// and puts targetFire on regEnable for all registers
// TODO
// - Add smem support
// Rewrites one RTL module for FAME-1:
//   - appends the targetFire input port,
//   - drives the targetFire port of every child instance from this module's targetFire,
//   - for every register, appends `when not(targetFire): reg <= reg`.
// The new statements are placed after the module's original body.
private def transformRTL(m: Module): Module = {
  def fireRef = buildExp(targetFire.name)
  val forwardFire = for (child <- getDefInsts(m)) yield
    Connect(NoInfo, buildExp(Seq(child.name, targetFire.name)), fireRef)
  val holdRegs = for (reg <- getDefRegs(m)) yield {
    val notFiring = DoPrimop(Not, Seq(fireRef), Seq(), UnknownType)
    val keepValue = Connect(NoInfo, buildExp(reg.name), buildExp(reg.name))
    When(NoInfo, notFiring, keepValue, EmptyStmt)
  }
  val body = Block(m.stmt +: (forwardFire ++ holdRegs))
  Module(m.info, m.name, m.ports :+ targetFire, body)
}
// ********** genWrapperModule **********
// Generates FAME-1 Decoupled wrappers for simulation module instances
// Builds the FAME-1 decoupled wrapper module (named wrapName(inst.name)) around one instance.
//   inst    - instance declaration taken from the flattened top module
//   portMap - for this instance: port name -> names of the endpoints (other instances or "topIO")
//             connected to that port
// The wrapper has hostClock, hostReset and one bundle port per endpoint (named after the endpoint)
// holding an optional "hostIn" and an optional "hostOut" host-decoupled sub-bundle.
// Throws an Exception if an input of the instance is connected to more than one endpoint.
private def genWrapperModule(inst: DefInst, portMap: PortMap): Module = {
  val instIO = getDefInstType(inst)
  val nameToField = (instIO.fields map (f => f.name -> f)).toMap
  // Every distinct endpoint this instance connects to
  val connections = portMap.values.toSeq.flatten.distinct

  // One wrapper port per endpoint, grouping the instance ports that talk to that endpoint
  // TODO This whole chunk really ought to be rewritten or made a function
  val connPorts = connections map { endpoint =>
    val portNames = portMap.collect { case (port, ends) if ends.contains(endpoint) => port }.toSeq.sorted
    val dataFields = portNames map nameToField filter (_.tpe != ClockType) // clocks are not queued
    // Inputs are re-declared with Default direction inside hostBits; the enclosing hostIn field
    // carries the Reverse direction instead
    val inputSet = dataFields filter (_.dir == Reverse) map (f => Field(f.name, Default, f.tpe))
    val outputSet = dataFields filter (_.dir == Default)
    val hostIn =
      if (inputSet.isEmpty) Seq()
      else Seq(Field("hostIn", Reverse, genHostDecoupled(inputSet)))
    val hostOut =
      if (outputSet.isEmpty) Seq()
      else Seq(Field("hostOut", Default, genHostDecoupled(outputSet)))
    Port(inst.info, endpoint, Output, BundleType(hostIn ++ hostOut))
  }
  val ports = hostClock +: hostReset +: connPorts

  // targetFire indicates when the simulation module can execute: all of its inputs are valid and
  // all of its outputs are ready
  val targetFireInputs = connPorts flatMap { port =>
    getFields(port) map { field =>
      field.dir match {
        case Reverse => buildExp(Seq(port.name, field.name, hostValid.name))
        case Default => buildExp(Seq(port.name, field.name, hostReady.name))
      }
    }
  }
  val defTargetFire = DefNode(inst.info, targetFire.name, genPrimopReduce(And, targetFireInputs))
  val connectTargetFire =
    Connect(NoInfo, buildExp(Seq(inst.name, targetFire.name)), buildExp(targetFire.name))

  // As a simple RTL module, we're always ready
  val inputsReady = connPorts flatMap { port =>
    getFields(port) filter (_.dir == Reverse) map { field =>
      Connect(inst.info, buildExp(Seq(port.name, field.name, hostReady.name)), UIntValue(1, IntWidth(1)))
    }
  }
  // Outputs are valid on cycles where we fire
  val outputsValid = connPorts flatMap { port =>
    getFields(port) filter (_.dir == Default) map { field =>
      Connect(inst.info, buildExp(Seq(port.name, field.name, hostValid.name)), buildExp(targetFire.name))
    }
  }

  // Connect the IO of the RTL instance to the wrapper IO; clock inputs are instead connected to
  // hostClock. This relies on the same hostIn/hostOut/hostBits naming used to build connPorts.
  val connectedInstIOFields = instIO.fields filter (field => portMap.contains(field.name)) // skip unconnected IO
  val instIOConnect = connectedInstIOFields flatMap { field =>
    val endpoints = portMap(field.name)
    field.tpe match {
      case ClockType =>
        Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)), Ref(hostClock.name, ClockType)))
      case _ => field.dir match {
        // An output fans out to every endpoint it is connected to
        case Default => endpoints map { endpoint =>
          Connect(inst.info, buildExp(Seq(endpoint, "hostOut", "hostBits", field.name)),
                             buildExp(Seq(inst.name, field.name)))
        }
        case Reverse =>
          if (endpoints.length > 1)
            throw new Exception(s"It is illegal to have more than 1 connection to a single input: $field")
          Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)),
                                 buildExp(Seq(endpoints.head, "hostIn", "hostBits", field.name))))
      }
    }
  }

  val stmts = Block(Seq(defTargetFire) ++ inputsReady ++ outputsValid ++ Seq(inst) ++
                    Seq(connectTargetFire) ++ instIOConnect)
  Module(inst.info, wrapName(inst.name), ports, stmts)
}
// ********** generateSimQueues **********
// Takes Seq of SimWrapper modules
// Returns Map of (src, dest) -> SimQueue
// To prevent duplicates, instead of creating a map with (src, dest) as the key, we could instead
// create only one direction of the queue for each simport of each module. The only problem with
// this is that it won't create queues for TopIO, since that isn't a real module
// Walks the wrappers in order and creates one queue module per (source, destination) pair that
// has not been created by an earlier wrapper. For each sim port of a wrapper, a Default-direction
// field means the wrapper is the source and the port's endpoint the destination; any other
// direction means the reverse. The queue carries the type of the field's host bits.
private def generateSimQueues(wrappers: Seq[Module]): SimQMap =
  wrappers.foldLeft(SimQMap) { (queues, wrapper) =>
    val owner = unwrapName(wrapper.name)
    val additions = for {
      port <- wrapper.ports if isSimPort(port)
      field <- splitSimPort(port)
      key = if (field.dir == Default) (owner, port.name) else (port.name, owner)
      if !queues.contains(key) // already produced while processing an earlier wrapper
    } yield key -> buildSimQueue(queueName(key._1, key._2), getHostBits(field).tpe)
    queues ++ additions
  }
// ********** generateSimTop **********
// Creates the Simulation Top module where all sim modules and sim queues are instantiated and connected
// Converts the ports of the original top module into the single "io" port of the simulation top:
// non-clock inputs go (flipped) into a Reverse "hostIn" host-decoupled bundle, non-clock outputs
// into a Default "hostOut" host-decoupled bundle.
private def transformTopIO(ports: Seq[Port]): Seq[Port] = {
  val dataPorts = ports filterNot (_.tpe == ClockType)
  // Flip because the wrapping port is an input
  val inFields = dataPorts.collect { case p if p.dir == Input => p.toField.flip }
  val outFields = dataPorts.collect { case p if p.dir == Output => p.toField }
  val hostIn = Field("hostIn", Reverse, genHostDecoupled(inFields))
  val hostOut = Field("hostOut", Default, genHostDecoupled(outFields))
  Seq(Port(NoInfo, "io", Output, BundleType(Seq(hostIn, hostOut))))
}
// Builds the "SimTop" module: instantiates every wrapper and every queue, wires hostClock and
// hostReset to all of them, connects queues to the wrappers on either side, and connects queues
// whose source or destination is "topIO" to the decoupled "io" port.
//   wrappers  - wrapper modules produced by genWrapperModule
//   simQueues - (src, dst) -> queue module
//   portMap   - connections of "topIO": top-level port name -> connected instance names
//   rtlTop    - the original top module, used for its port list
private def generateSimTop(wrappers: Seq[Module], simQueues: SimQMap, portMap: PortMap, rtlTop: Module): Module = {
  val simModules = wrappers ++ simQueues.values
  val insts = simModules map (m => DefInst(NoInfo, instName(m.name), buildExp(m.name)))

  // Drives the given host port of every instantiated module from SimTop's port of the same name
  def fanOut(hostPort: Port) = simModules map { m =>
    Connect(NoInfo, buildExp(Seq(instName(m.name), hostPort.name)), buildExp(hostPort.name))
  }
  val connectClocks = fanOut(hostClock)
  val connectResets = fanOut(hostReset)

  // Connect queues to simulation modules (excludes IO)
  val connectQueues = simQueues.toSeq flatMap { case ((src, dst), queue) =>
    val queueInst = instName(queue.name)
    val enqSide =
      if (src == "topIO") Seq()
      else Seq(BulkConnect(NoInfo, buildExp(Seq(queueInst, "io", "enq")),
                                   buildExp(Seq(instName(wrapName(src)), dst, "hostOut"))))
    val deqSide =
      if (dst == "topIO") Seq()
      else Seq(BulkConnect(NoInfo, buildExp(Seq(instName(wrapName(dst)), src, "hostIn")),
                                   buildExp(Seq(queueInst, "io", "deq"))))
    enqSide ++ deqSide
  }

  // Connect IO queues; Src means input, Dst means output (i.e. the outside world is the Src or Dst)
  val ioSrcQueues = simQueues.collect { case ((src, _), queue) if src == "topIO" => queue }.toSeq
  val ioDstQueues = simQueues.collect { case ((_, dst), queue) if dst == "topIO" => queue }.toSeq
  val dataPorts = rtlTop.ports filter (_.tpe != ClockType)
  val ioSrcSignals = dataPorts filter (_.dir == Input) map (_.name)
  val ioDstSignals = dataPorts filter (_.dir == Output) map (_.name)

  // NOTE(review): every input queue sees io.hostIn.hostValid directly while io.hostIn.hostReady is
  // the AND of all queue readys (and symmetrically on the output side) -- confirm that one queue
  // cannot transfer while another is stalled.
  val ioSrcQueueConnect =
    if (ioSrcQueues.isEmpty) Seq(EmptyStmt)
    else {
      val enqReady = ioSrcQueues map (q => buildExp(Seq(instName(q.name), "io", "enq", hostReady.name)))
      val enqValid = ioSrcQueues map (q => buildExp(Seq(instName(q.name), "io", "enq", hostValid.name)))
      val bitsConnects = for {
        sig <- ioSrcSignals
        dst <- portMap(sig)
      } yield Connect(NoInfo, buildExp(Seq(instName(queueName("topIO", dst)), "io", "enq", "hostBits", sig)),
                              buildExp(Seq("io", "hostIn", "hostBits", sig)))
      val validConnects = enqValid map { v =>
        Connect(NoInfo, buildExp(v), buildExp(Seq("io", "hostIn", hostValid.name)))
      }
      val readyConnect =
        Connect(NoInfo, buildExp(Seq("io", "hostIn", hostReady.name)), genPrimopReduce(And, enqReady))
      (bitsConnects ++ validConnects) :+ readyConnect
    }

  val ioDstQueueConnect =
    if (ioDstQueues.isEmpty) Seq(EmptyStmt)
    else {
      val deqReady = ioDstQueues map (q => buildExp(Seq(instName(q.name), "io", "deq", hostReady.name)))
      val deqValid = ioDstQueues map (q => buildExp(Seq(instName(q.name), "io", "deq", hostValid.name)))
      val bitsConnects = for {
        sig <- ioDstSignals
        src <- portMap(sig)
      } yield Connect(NoInfo, buildExp(Seq("io", "hostOut", "hostBits", sig)),
                              buildExp(Seq(instName(queueName(src, "topIO")), "io", "deq", "hostBits", sig)))
      val readyConnects = deqReady map { r =>
        Connect(NoInfo, buildExp(r), buildExp(Seq("io", "hostOut", hostReady.name)))
      }
      val validConnect =
        Connect(NoInfo, buildExp(Seq("io", "hostOut", hostValid.name)), genPrimopReduce(And, deqValid))
      (bitsConnects ++ readyConnects) :+ validConnect
    }

  val stmts = Block(insts ++ connectClocks ++ connectResets ++ connectQueues ++ ioSrcQueueConnect ++ ioDstQueueConnect)
  val ports = Seq(hostClock, hostReset) ++ transformTopIO(rtlTop.ports)
  Module(NoInfo, "SimTop", ports, stmts)
}
// ********** transform **********
// Perform FAME-1 Transformation for MIDAS
// Performs the FAME-1 transformation on a circuit whose top module (the module named c.name)
// already satisfies the post-flattening invariants described in the header of this file.
// Returns a new circuit whose top is the generated "SimTop" module and which contains the
// transformed RTL modules, one wrapper per top-level instance, and the generated queues.
// Throws NoSuchElementException if the circuit has no module named c.name.
def transform(c: Circuit): Circuit = {
  // We should check that the invariants mentioned above are true
  // .toMap replaces collection.breakOut, which no longer exists in Scala 2.13
  val nameToModule: Map[String, Module] = c.modules.map(m => m.name -> m).toMap
  val top = nameToModule.getOrElse(c.name,
    throw new NoSuchElementException(s"Top module '${c.name}' not found in circuit"))
  // Every module except the top is treated as target RTL
  val rtlModules = c.modules filter (_.name != top.name) map (transformRTL)
  val insts = getDefInsts(top)
  val portConn = findPortConn(top, insts)
  // Check that port Connections include all ports for each instance?
  val simWrappers = insts map (inst => genWrapperModule(inst, portConn(inst.name)))
  val simQueues = generateSimQueues(simWrappers)
  // Remove duplicate simWrapper and simQueue modules?
  val simTop = generateSimTop(simWrappers, simQueues, portConn("topIO"), top)
  val modules = rtlModules ++ simWrappers ++ simQueues.values.toSeq ++ Seq(simTop)
  Circuit(c.info, simTop.name, modules)
}
}
|