1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
|
package firrtl.passes
import com.typesafe.scalalogging.LazyLogging
import scala.collection.mutable.{ArrayBuffer,HashMap}
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import Annotations._
import firrtl.PrimOps._
import firrtl.WrappedExpression._
import java.io.Writer
import scala.util.matching.Regex
/** Annotation carrying the command-line configuration of the ReplSeqMem pass.
  *
  * The raw option string `t` is split on `:` and read as a flat list of
  * `flag:value` pairs. Recognised flags are `-c` (target circuit), `-o` (output
  * configuration file) and `-i` (input configuration file).
  *
  * Construction throws an `Exception` when an unknown flag is found or when one of
  * the required flags (`-c`, `-o`) is missing.
  *
  * @param t   colon-delimited option string, e.g. `-c:Top:-o:mems.conf`
  * @param tID id of the transform this annotation belongs to
  */
case class ReplSeqMemAnnotation(t: String, tID: TransID)
    extends Annotation with Loose with Unstable {

  // The usage line shows the form the parser below actually accepts: every flag and
  // its value are separate `:`-delimited tokens.
  val usage = """
[Optional] ReplSeqMem
Pass to replace sequential memories with blackboxes + configuration file
Usage:
--replSeqMem -c:<circuit>:-i:<filename>:-o:<filename>
*** Note: sub-arguments to --replSeqMem should be delimited by : and not white space!
Required Arguments:
-o<filename> Specify the output configuration file
-c<circuit> Specify the target circuit
Optional Arguments:
-i<filename> Specify the input configuration file
"""

  sealed trait PassOption
  case object InputConfigFileName extends PassOption
  case object OutputConfigFileName extends PassOption
  case object PassCircuitName extends PassOption

  type PassOptionMap = Map[PassOption, String]

  // can't use space to delimit sub arguments (otherwise, Driver.scala will throw error)
  val passArgList = t.split(":").toList

  /** Consumes `list` two tokens at a time (flag, value) and accumulates them in `map`.
    *
    * @param map  options collected so far
    * @param list remaining tokens
    * @return the completed option map
    * @throws Exception if the next token is not a known flag followed by a value
    */
  def nextPassOption(map: PassOptionMap, list: List[String]): PassOptionMap = {
    list match {
      case Nil => map
      case "-i" :: value :: tail =>
        nextPassOption(map + (InputConfigFileName -> value), tail)
      case "-o" :: value :: tail =>
        nextPassOption(map + (OutputConfigFileName -> value), tail)
      case "-c" :: value :: tail =>
        nextPassOption(map + (PassCircuitName -> value), tail)
      case option :: _ =>
        throw new Exception("Unknown option " + option + usage)
    }
  }

  val passOptions = nextPassOption(Map[PassOption, String](), passArgList)

  // -i is documented as optional, so a missing input file yields an empty name
  // instead of aborting construction.
  val inputConfig = passOptions.getOrElse(InputConfigFileName, "")
  val outputConfig = passOptions.getOrElse(
    OutputConfigFileName,
    throw new Exception("No output config file provided for ReplSeqMem!" + usage))
  val passCircuit = passOptions.getOrElse(
    PassCircuitName,
    throw new Exception("No circuit name specified for ReplSeqMem!" + usage))

  val target = CircuitName(passCircuit)

  /** Returns a copy of this annotation whose `-c` option names `n` instead of the
    * current circuit; all other options are kept.
    */
  def duplicate(n: Named) = this.copy(t = t.replace("-c:" + passCircuit, "-c:" + n.name))
}
object ReplSeqMem extends Pass {
/** Human-readable name of this pass. */
def name: String = "Replace Sequential Memories with Blackboxes + Configuration File"
/** Characteristics shared by memory ports that are able to write.
  *
  * While an instance is being initialised the trait checks that a mask granularity
  * is present whenever the write mask is enabled; otherwise `require` fails with an
  * `IllegalArgumentException`.
  */
trait WritePortChar {
  /** Name of the port. */
  def name: String

  /** Whether the port uses a write mask. */
  def useMask: Boolean

  /** Mask granularity; must not be `None` when `useMask` is true. */
  def maskGran: Option[BigInt]

  // "mask enabled implies granularity given", written as !A || B
  require(
    !useMask || maskGran != None,
    "Must specify a mask granularity if write mask is desired"
  )
}
/** Description of a write-only memory port.
  *
  * @param name     name of the port
  * @param useMask  whether the port uses a write mask (off by default)
  * @param maskGran mask granularity; has to be given when `useMask` is true
  */
case class PortForWrite(name: String, useMask: Boolean = false, maskGran: Option[BigInt] = None)
    extends WritePortChar
/** Description of a combined read/write memory port.
  *
  * @param name     name of the port
  * @param useMask  whether the port uses a write mask (off by default)
  * @param maskGran mask granularity; has to be given when `useMask` is true
  */
case class PortForReadWrite(name: String, useMask: Boolean = false, maskGran: Option[BigInt] = None)
    extends WritePortChar
/** Description of a read-only memory port.
  *
  * @param name name of the port
  */
case class PortForRead(name: String)
/** Vendor agnostic configuration of one sequential memory.
  *
  * A memory either has a single read/write port and nothing else, or at least one
  * read port together with at least one write port; both rules are enforced on
  * construction via `require`.
  *
  * @param m              the memory declaration being described
  * @param readPorts      read ports
  * @param writePorts     write ports
  * @param readWritePorts read/write ports (at most one)
  */
case class SMem(
    m: DefMemory,
    readPorts: Seq[PortForRead],
    writePorts: Seq[PortForWrite],
    readWritePorts: Seq[PortForReadWrite]) {

  require(
    if (readWritePorts.isEmpty) writePorts.nonEmpty && readPorts.nonEmpty
    else writePorts.isEmpty && readPorts.isEmpty,
    "Need at least one set of read, write ports if no RW port is specified. A RW port must be standalone"
  )
  require(readWritePorts.length < 2, "Cannot have more than 1 readwrite port")

  def name = m.name
  def dataType = m.dataType
  def depth = m.depth
  def writeLatency = m.writeLatency
  def readLatency = m.readLatency
  def numReaders = readPorts.length
  def numWriters = writePorts.length
  def numRWriters = readWritePorts.length

  // NOTE(review): rPortMap is keyed by the PortForRead object, whereas wPortMap and
  // rwPortMap are keyed by port name. This looks inconsistent, but changing it would
  // alter the result type callers see -- confirm before unifying.
  def rPortMap = readPorts.zipWithIndex map { case (p, i) => p -> s"R$i" }
  def wPortMap = writePorts.zipWithIndex map { case (p, i) => p.name -> s"W$i" }
  def rwPortMap = readWritePorts.zipWithIndex map { case (p, i) => p.name -> s"RW$i" }

  def width = bitWidth(dataType)

  /** Renders this memory as one line of the configuration file
    * (`name … depth … width … ports … [mask_gran …]`, newline terminated).
    */
  def serialize = {
    // for backwards compatibility with old conf format
    val writers = writePorts map (p => if (p.useMask) "mwrite" else "write")
    val readers = List.fill(numReaders)("read")
    val readwriters = readWritePorts map (p => if (p.useMask) "mrw" else "rw")
    val ports = (writers ++ readers ++ readwriters).mkString(",")
    // old conf file only supported 1 mask_gran: use the one of the first
    // write-capable port. A missing or zero granularity prints nothing.
    val firstMaskGran: Option[BigInt] =
      (writePorts.map(_.maskGran) ++ readWritePorts.map(_.maskGran)).headOption.flatten
    val maskGranConf = firstMaskGran.filter(_.signum != 0).fold("")(g => s"mask_gran $g")
    s"name ${name} depth ${depth} width ${width} ports ${ports} ${maskGranConf} \n"
  }

  /** Structural comparison used to decide whether two memories can share one
    * configuration: same data type, depth, latencies, port counts and per-port mask
    * granularities. Port and memory names are ignored.
    *
    * @param m the memory to compare against
    * @return true when both memories have the same configuration
    */
  def eq(m: SMem) = {
    // TODO: Condition on read under write
    // zip alone would silently drop the surplus ports of the longer list, so the
    // lengths are compared first.
    def sameMaskGran(mine: Seq[WritePortChar], theirs: Seq[WritePortChar]): Boolean =
      mine.length == theirs.length &&
        (mine zip theirs).forall { case (a, b) => a.maskGran == b.maskGran }

    (dataType == m.dataType) &&
    (depth == m.depth) &&
    (writeLatency == m.writeLatency) &&
    (readLatency == m.readLatency) &&
    (numReaders == m.numReaders) &&
    sameMaskGran(writePorts, m.writePorts) &&
    sameMaskGran(readWritePorts, m.readWritePorts)
  }
}
/** Builds an [[SMem]] description for every memory with a read latency above zero
  * that is declared in the body of `m`.
  *
  * Connections are only collected from the module's own top-level statements and
  * nested `Block`s, so origins are only traced within this module.
  *
  * @param m module to analyse
  * @return one entry per sequential memory, in declaration order
  */
def analyzeMemsInModule(m: Module): Seq[SMem] = {
  // serialized sink expression (or node name) -> expression driving it
  val drivers = HashMap[String, Expression]()
  val found = ArrayBuffer[SMem]()

  // swiped from InferRW
  def findConnects(s: Statement): Unit = s match {
    case con: Connect => drivers(con.loc.serialize) = con.expr
    case con: PartialConnect => drivers(con.loc.serialize) = con.expr
    case node: DefNode => drivers(node.name) = node.value
    case block: Block => block.stmts foreach findConnects
    case _ =>
  }

  def findConnectOriginFromExp(e: Expression): Seq[Expression] = e match {
    // matches how wmode, wmask, write_en are assigned (from Chirrtl)
    // in case no ConstProp is performed before this pass
    case Mux(cond, tv, fv, _) if we(tv) == we(one) && we(fv) == we(zero) =>
      cond +: findConnectOrigin(cond.serialize)
    // visit connected nodes to references
    case _: WRef | _: SubField | _: SubIndex | _: SubAccess =>
      e +: findConnectOrigin(e.serialize)
    // backward searches until a PrimOp or Literal appears -->
    // Literal: you've reached origin
    // PrimOp: you're not simply doing propagation anymore
    // NOTE: not a catch-all!!!
    case _ => List(e)
  }

  // only capable of searching for origin in the same module
  def findConnectOrigin(node: String): Seq[Expression] =
    drivers.get(node) match {
      case Some(driver) => findConnectOriginFromExp(driver)
      case None => Nil
    }

  // returns None if wen = wmask bits or wmask bits all = 1; otherwise returns # of mask bits
  def getMaskBits(wen: String, wmask: String): Option[Int] = {
    val enableOrigins = findConnectOrigin(wen)
    // one origin chain per recorded sink whose name starts with the mask prefix
    val maskOrigins = drivers.keys.toSeq filter (_ startsWith wmask) map findConnectOrigin
    // when all wmask bits are equal to wmode, wmask is redundant
    val maskFollowsEnable =
      maskOrigins forall (origins => (enableOrigins intersect origins).nonEmpty)
    // if all wmask bits = 1, then wmask is redundant
    val maskAllOnes = maskOrigins forall (_ contains one)
    if (maskFollowsEnable || maskAllOnes) None else Some(maskOrigins.length)
  }

  def findMemInsts(s: Statement): Unit = s match {
    // only find smems
    case mem: DefMemory if mem.readLatency > 0 =>
      val dataBits = bitWidth(mem.dataType)
      val rwPorts = mem.readwriters map { port =>
        getMaskBits(s"${mem.name}.$port.wmode", s"${mem.name}.$port.wmask") match {
          case None => PortForReadWrite(name = port)
          case Some(bits) =>
            PortForReadWrite(name = port, useMask = true, maskGran = Some(dataBits / bits))
        }
      }
      val wPorts = mem.writers map { port =>
        getMaskBits(s"${mem.name}.$port.en", s"${mem.name}.$port.mask") match {
          case None => PortForWrite(name = port)
          case Some(bits) =>
            PortForWrite(name = port, useMask = true, maskGran = Some(dataBits / bits))
        }
      }
      found += SMem(
        m = mem,
        readPorts = mem.readers map (port => PortForRead(name = port)),
        writePorts = wPorts,
        readWritePorts = rwPorts
      )
    case block: Block => block.stmts foreach findMemInsts
    case _ =>
  }

  // connections must be known before memories are inspected
  findConnects(m.body)
  findMemInsts(m.body)
  found.toSeq
}
/** Analyses every `Module` of the circuit, pairs each sequential memory's name with
  * an [[SMem]] (reusing the first structurally equal memory seen as a prototype),
  * prints the resulting list and returns the circuit unchanged.
  *
  * @param c circuit to analyse
  * @return `c` itself
  */
def run(c: Circuit) = {
  // first memory seen of every distinct configuration (compared with SMem.eq)
  val prototypes = ArrayBuffer[SMem]()

  def analyzeMemsInCircuit(c: Circuit) = {
    val allMems = ArrayBuffer[SMem]()
    c.modules foreach {
      case mod: Module => allMems ++= analyzeMemsInModule(mod)
      case _: ExtModule => // external modules are not analysed
    }
    allMems map { mem =>
      prototypes.find(_.eq(mem)) match {
        // an equal configuration exists: reuse it, but keep this memory's declaration
        case Some(proto) => mem.name -> proto.copy(m = mem.m)
        case None =>
          prototypes += mem
          mem.name -> mem
      }
    }
  }

  val memMap = analyzeMemsInCircuit(c)
  println(memMap)
  c
}
}
/** Transform that runs the ReplSeqMem pass sequence on a circuit, but only when the
  * annotation map holds a [[ReplSeqMemAnnotation]] for this transform id and the
  * circuit's main module. In every other case the circuit is passed through as is.
  *
  * @param transID id under which this transform's annotations are stored
  */
class ReplSeqMem(transID: TransID) extends Transform with LazyLogging {
  def execute(circuit: Circuit, map: AnnotationMap) = {
    // annotation registered for this transform on the circuit's main module, if any
    val requested = map.get(transID).flatMap(_.get(CircuitName(circuit.main)))
    requested match {
      case Some(ReplSeqMemAnnotation(_, _)) =>
        val passes = Seq(
          Legalize,
          ReplSeqMem,
          CheckInitialization,
          ResolveKinds,
          InferTypes,
          ResolveGenders)
        // run the passes in order, timing each one and logging its result
        val result = passes.foldLeft(circuit) { (c, pass) =>
          val x = Utils.time(pass.name)(pass run c)
          logger debug x.serialize
          x
        }
        TransformResult(result, None, Some(map))
      case _ => TransformResult(circuit, None, Some(map))
    }
  }
}
|