aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/ReplSeqMem.scala
blob: 3457febb8e346b8f0b7b9c610e861c4cdeef59b6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
package firrtl.passes

import com.typesafe.scalalogging.LazyLogging
import scala.collection.mutable.{ArrayBuffer,HashMap}

import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import Annotations._
import firrtl.PrimOps._
import firrtl.WrappedExpression._

import java.io.Writer

import scala.util.matching.Regex

/** Annotation carrying the sub-arguments of `--replSeqMem` for one circuit.
  *
  * The raw argument string `t` is split on `:` and consumed as flag/value pairs
  * (`-c`, `-o`, `-i`). Parsing happens eagerly at construction time, so an
  * unknown flag or a missing required flag throws an `Exception` whose message
  * includes the usage text.
  *
  * @param t   raw, colon-delimited option string, e.g. `-c:Top:-o:mems.conf`
  * @param tID ID of the transform this annotation is addressed to
  */
case class ReplSeqMemAnnotation(t: String, tID: TransID)
    extends Annotation with Loose with Unstable {

  // Flags and their values are separate `:`-delimited tokens; the usage text
  // below is written to match what nextPassOption actually accepts.
  val usage = """
[Optional] ReplSeqMem
  Pass to replace sequential memories with blackboxes + configuration file

Usage: 
  --replSeqMem -c:<circuit>:-i:<filename>:-o:<filename>
  *** Note: sub-arguments to --replSeqMem should be delimited by : and not white space!

Required Arguments:
  -o:<filename>        Specify the output configuration file
  -c:<circuit>         Specify the target circuit

Optional Arguments:
  -i:<filename>        Specify the input configuration file
"""

  sealed trait PassOption
  case object InputConfigFileName extends PassOption
  case object OutputConfigFileName extends PassOption
  case object PassCircuitName extends PassOption

  type PassOptionMap = Map[PassOption, String]

  // can't use space to delimit sub arguments (otherwise, Driver.scala will throw error)
  val passArgList = t.split(":").toList

  /** Consumes `list` two tokens at a time (flag, value) and accumulates them into `map`.
    *
    * A token that is not a known flag followed by a value (including a known
    * flag that is the last token) is reported as an unknown option.
    */
  def nextPassOption(map: PassOptionMap, list: List[String]): PassOptionMap = {
    list match {
      case Nil => map
      case "-i" :: value :: tail =>
        nextPassOption(map + (InputConfigFileName -> value), tail)
      case "-o" :: value :: tail =>
        nextPassOption(map + (OutputConfigFileName -> value), tail)
      case "-c" :: value :: tail =>
        nextPassOption(map + (PassCircuitName -> value), tail)
      case option :: _ =>
        throw new Exception("Unknown option " + option + usage)
    }
  }

  val passOptions = nextPassOption(Map[PassOption, String](), passArgList)
  // The input configuration file is documented as optional: fall back to an
  // empty file name instead of rejecting the annotation when -i is absent.
  val inputConfig = passOptions.getOrElse(InputConfigFileName, "")
  val outputConfig = passOptions.getOrElse(OutputConfigFileName, throw new Exception("No output config file provided for ReplSeqMem!" + usage))
  val passCircuit = passOptions.getOrElse(PassCircuitName, throw new Exception("No circuit name specified for ReplSeqMem!" + usage))

  val target = CircuitName(passCircuit)

  /** Returns a copy of this annotation whose `-c` value is replaced by `n.name`. */
  def duplicate(n: Named) = this.copy(t = t.replace("-c:" + passCircuit, "-c:" + n.name))

}

/** Pass that collects the sequential memories (read latency > 0) of a circuit
  * and groups them by configuration.
  *
  * As written, `run` only analyzes the circuit and prints the resulting
  * name -> memory association; the circuit itself is returned unchanged.
  */
object ReplSeqMem extends Pass {

  def name = "Replace Sequential Memories with Blackboxes + Configuration File"

  /** Characteristics shared by every port that is able to write. */
  trait WritePortChar {
    def name: String
    def useMask: Boolean
    def maskGran: Option[BigInt]
    // A masked port is only meaningful together with a mask granularity.
    require(!useMask || maskGran != None, "Must specify a mask granularity if write mask is desired")
  }

  case class PortForWrite(
    name: String,
    useMask: Boolean = false,
    maskGran: Option[BigInt] = None
  ) extends WritePortChar

  case class PortForReadWrite(
    name: String,
    useMask: Boolean = false,
    maskGran: Option[BigInt] = None
  ) extends WritePortChar

  case class PortForRead(
    name: String
  )

  /** Vendor agnostic description of one sequential memory.
    *
    * @param m              the memory declaration being described
    * @param readPorts      read ports
    * @param writePorts     write ports
    * @param readWritePorts read/write ports (at most one, and only without separate read/write ports)
    */
  case class SMem(
    m: DefMemory,
    readPorts: Seq[PortForRead],
    writePorts: Seq[PortForWrite],
    readWritePorts: Seq[PortForReadWrite]
  ){
    require (
      if (readWritePorts.isEmpty) writePorts.nonEmpty && readPorts.nonEmpty else writePorts.isEmpty && readPorts.isEmpty,
      "Need at least one set of read, write ports if no RW port is specified. A RW port must be standalone"
    )
    require (readWritePorts.length < 2, "Cannot have more than 1 readwrite port")
    def name = m.name
    def dataType = m.dataType
    def depth = m.depth
    def writeLatency = m.writeLatency
    def readLatency = m.readLatency
    def numReaders = readPorts.length
    def numWriters = writePorts.length
    def numRWriters = readWritePorts.length
    // NOTE(review): keyed by the port object, whereas wPortMap/rwPortMap are keyed
    // by port name -- possibly meant to be p.name; left unchanged to keep the type.
    def rPortMap = readPorts.zipWithIndex map { case (p,i) => p -> s"R$i" }
    def wPortMap = writePorts.zipWithIndex map { case (p,i) => p.name -> s"W$i" }
    def rwPortMap = readWritePorts.zipWithIndex map { case (p,i) => p.name -> s"RW$i" }
    def width = bitWidth(dataType)

    /** Renders this memory as one line of the old conf format. */
    def serialize = {
      // for backwards compatibility with old conf format
      val writers = writePorts map (x => if (x.useMask) "mwrite" else "write")
      val readers = List.fill(numReaders)("read")
      val readwriters = readWritePorts map (x => if (x.useMask) "mrw" else "rw")
      val ports = (writers ++ readers ++ readwriters).mkString(",")
      // old conf file only supported 1 mask_gran: only the first write-capable
      // port is consulted, and a missing granularity is encoded as 0
      val maskGran = (writePorts ++ readWritePorts) map (_.maskGran.getOrElse(BigInt(0)))
      val maskGranConf =
        maskGran.headOption.filter(_ != BigInt(0)).fold("")(gran => s"mask_gran ${gran}")
      s"name ${name} depth ${depth} width ${width} ports ${ports} ${maskGranConf} \n"
    }

    /** Whether `m` has the same configuration as this memory (names are ignored).
      *
      * Compares data type, depth, latencies, the number of each kind of port,
      * and the mask granularity of corresponding write and read/write ports.
      */
    def eq(m: SMem) = {
      // TODO: Condition on read under write
      // zip silently truncates to the shorter side, so port counts are compared
      // explicitly before the per-port mask granularities.
      def sameMaskGran(mine: Seq[WritePortChar], theirs: Seq[WritePortChar]) =
        mine.length == theirs.length &&
          (mine zip theirs).forall { case (a, b) => a.maskGran == b.maskGran }
      (dataType == m.dataType) &&
      (depth == m.depth) &&
      (writeLatency == m.writeLatency) &&
      (readLatency == m.readLatency) &&
      (numReaders == m.numReaders) &&
      sameMaskGran(writePorts, m.writePorts) &&
      sameMaskGran(readWritePorts, m.readWritePorts)
    }
  }

  /** Collects an [[SMem]] description for every memory with read latency > 0 in `m`. */
  def analyzeMemsInModule(m: Module): Seq[SMem] = {

    // serialized connection target (or node name) -> expression driving it
    val connects = HashMap[String, Expression]()
    val mems = ArrayBuffer[SMem]()

    // swiped from InferRW
    def findConnects(s: Statement): Unit = s match {
      case s: Connect  =>
        connects(s.loc.serialize) = s.expr
      case s: PartialConnect =>
        connects(s.loc.serialize) = s.expr
      case s: DefNode =>
        connects(s.name) = s.value
      case s: Block =>
        s.stmts foreach findConnects
      case _ =>
    }

    def findConnectOriginFromExp(e: Expression): Seq[Expression] = e match {
      // matches how wmode, wmask, write_en are assigned (from Chirrtl)
      // in case no ConstProp is performed before this pass
      case Mux(cond, tv, fv, _) if we(tv) == we(one) && we(fv) == we(zero) =>
        cond +: findConnectOrigin(cond.serialize)
      // visit connected nodes to references
      case _: WRef | _: SubField | _: SubIndex | _: SubAccess =>
        e +: findConnectOrigin(e.serialize)
      // backward searches until a PrimOp or Literal appears -->
      // Literal: you've reached origin
      // PrimOp: you're not simply doing propagation anymore
      // NOTE: not a catch-all!!!
      case _ => List(e)
    }

    // only capable of searching for origin in the same module
    def findConnectOrigin(node: String): Seq[Expression] =
      connects.get(node).fold(Seq.empty[Expression])(findConnectOriginFromExp)

    // returns None if wen = wmask bits or wmask bits all = 1; otherwise returns # of mask bits
    def getMaskBits(wen: String, wmask: String): Option[Int] = {
      val wenOrigin = findConnectOrigin(wen)
      // find all mask bits
      val wmaskOrigin = connects.keys.toSeq filter (_.startsWith(wmask)) map findConnectOrigin
      // when all wmask bits are equal to wmode, wmask is redundant
      val eq = wmaskOrigin forall (origin => (wenOrigin intersect origin).nonEmpty)
      // if all wmask bits = 1, then wmask is redundant
      val wmaskOne = wmaskOrigin forall (_ contains one)
      if (eq || wmaskOne) None else Some(wmaskOrigin.length)
    }

    def findMemInsts(s: Statement): Unit = s match {
      // only find smems
      case m: DefMemory if m.readLatency > 0 =>
        val dataBits = bitWidth(m.dataType)
        val rwPorts = m.readwriters map { w =>
          getMaskBits(s"${m.name}.$w.wmode",s"${m.name}.$w.wmask") match {
            case None => PortForReadWrite(name = w)
            case Some(maskBits) =>
              PortForReadWrite(name = w, useMask = true, maskGran = Some(dataBits/maskBits))
          }
        }
        val wPorts = m.writers map { w =>
          getMaskBits(s"${m.name}.$w.en",s"${m.name}.$w.mask") match {
            case None => PortForWrite(name = w)
            case Some(maskBits) =>
              PortForWrite(name = w, useMask = true, maskGran = Some(dataBits/maskBits))
          }
        }
        mems += SMem(
          m = m,
          readPorts = m.readers map (r => PortForRead(name = r)),
          writePorts = wPorts,
          readWritePorts = rwPorts
        )
      case b: Block => b.stmts foreach findMemInsts
      case _ =>
    }
    // connections must be known before memories are inspected
    findConnects(m.body)
    findMemInsts(m.body)
    mems.toSeq
  }

  /** Analyzes every module of `c`, prints the name -> memory association and returns `c` unchanged. */
  def run(c: Circuit) = {
    // first memory seen for each distinct configuration (as decided by SMem.eq)
    val uniqueMems = ArrayBuffer[SMem]()
    def analyzeMemsInCircuit(c: Circuit) = {
      val mems = ArrayBuffer[SMem]()
      c.modules foreach {
        case m: Module => mems ++= analyzeMemsInModule(m)
        case m: ExtModule =>
      }
      mems map { m =>
        uniqueMems.find(_.eq(m)) match {
          case None =>
            uniqueMems += m
            m.name -> m
          // reuse the prototype's port description, but keep this memory's declaration
          case Some(memProto) => m.name -> memProto.copy(m=m.m)
        }
      }
    }
    val memMap = analyzeMemsInCircuit(c)
    println(memMap)
    c
  }

}

/** Transform wrapper around the ReplSeqMem pass.
  *
  * The pass sequence only runs when the annotation map holds a
  * ReplSeqMemAnnotation for this transform ID and the circuit's main module;
  * otherwise the circuit is passed through untouched. The annotation map is
  * always forwarded in the result.
  */
class ReplSeqMem(transID: TransID) extends Transform with LazyLogging {
  def execute(circuit:Circuit, map: AnnotationMap) = {
    // Look up the annotation addressed to this transform for the main circuit.
    val requested = (map get transID) flatMap (_ get CircuitName(circuit.main)) match {
      case Some(ReplSeqMemAnnotation(_, _)) => true
      case _ => false
    }
    val transformed =
      if (requested) {
        val passSeq = Seq(
          Legalize,
          ReplSeqMem,
          CheckInitialization,
          ResolveKinds,
          InferTypes,
          ResolveGenders)
        // Thread the circuit through each pass in order, timing and logging every step.
        passSeq.foldLeft(circuit) { (current, pass) =>
          val next = Utils.time(pass.name)(pass run current)
          logger debug next.serialize
          next
        }
      }
      else circuit
    TransformResult(transformed, None, Some(map))
  }
}