aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
blob: 33a371a03b7b7dd3315955dd4daa8f5e11f923f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
// See LICENSE for license details.

package firrtl.passes

import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import MemPortUtils._
import MemTransformUtils._
import AnalysisUtils._

/** Replaces memories tagged with a `useMacro` info field by instances of generated
  * black-box wrappers, and records every generated black box through `writer`.
  *
  * For a "prototype" memory (one without a `ref` info field) two modules are added to the
  * circuit: an [[ExtModule]] black box with flattened ports, and a wrapper [[Module]] whose
  * ports follow the memory's bundle type and which adapts between the two. A memory that
  * carries a `ref` info field instead instantiates the module named by that field.
  *
  * @param writer receives one entry per generated black box (see `createMemModule`) and is
  *               serialized once at the end of `run`
  */
class ReplaceMemMacros(writer: ConfWriter) extends Pass {
  def name = "Replace memories with black box wrappers" +
             " (optimizes when write mask isn't needed) + configuration file"

  /** Builds the black box for memory `m` plus the wrapper module that adapts to it.
    *
    * The wrapper (named `wrapperName`) exposes the memory's bundle-typed ports and
    * instantiates the black box (named `m.name`), which exposes flattened ports.
    * Side effect: `m` is appended to `writer`.
    *
    * @param m           memory to replace; its data type must already be known
    * @param wrapperName name to give the generated wrapper module
    * @return the black box followed by the wrapper
    */
  // from Albert
  def createMemModule(m: DefMemory, wrapperName: String): Seq[DefModule] = {
    assert(m.dataType != UnknownType, s"Memory ${m.name} has an unknown data type")
    val wrapperIoType = memToBundle(m)
    // Every top-level port is declared Input; outgoing fields are presumably flipped inside
    // the port bundles built by memToBundle / memToFlattenBundle -- TODO confirm.
    val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
    val bbIoType = memToFlattenBundle(m)
    val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
    val bbRef = createRef(m.name, bbIoType)
    val hasMask = containsInfo(m.info, "maskGran")
    val fillMask = getFillWMask(m)
    def portRef(p: String) = createRef(p, field_type(wrapperIoType, p))
    val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++
      (m.readers flatMap (r => adaptReader(portRef(r), createSubField(bbRef, r)))) ++
      (m.writers flatMap (w => adaptWriter(portRef(w), createSubField(bbRef, w), hasMask, fillMask))) ++
      (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), createSubField(bbRef, rw), hasMask, fillMask)))
    val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts))
    val bb = ExtModule(NoInfo, m.name, bbIoPorts)
    // TODO: Annotate? -- use actual annotation map

    // add to conf file
    writer.append(m)
    Seq(bb, wrapper)
  }

  /** Connects the same-named `clk`, `en` and `addr` fields of a black-box port and a
    * wrapper port; these fields are common to readers, writers and readwriters.
    */
  def defaultConnects(wrapperPort: WRef, bbPort: WSubField) =
    Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f))

  /** Converts a wrapper-side mask for the black box: with `fillMask` it goes through
    * `toBitMask` against `dataType`, otherwise it is flattened with `toBits`.
    */
  def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean) =
    if (fillMask) toBitMask(mask, dataType) else toBits(mask)

  /** Mask connection shared by writers and readwriters: drives black-box field `field`
    * from the wrapper's field of the same name, or yields nothing when the memory has no
    * write mask.
    *
    * @param dataType type of the wrapper's (write) data field, used to size the mask
    */
  private def maskConnects(wrapperPort: WRef,
                           bbPort: WSubField,
                           field: String,
                           dataType: Type,
                           hasMask: Boolean,
                           fillMask: Boolean): Seq[Connect] =
    if (hasMask) {
      Seq(Connect(
        NoInfo,
        createSubField(bbPort, field),
        maskBits(createSubField(wrapperPort, field), dataType, fillMask)))
    } else {
      Nil
    }

  /** Statements adapting a read port: common control fields, then the black box's flat
    * `data` unpacked into the wrapper's typed `data` field via `fromBits`.
    */
  def adaptReader(wrapperPort: WRef, bbPort: WSubField) =
    defaultConnects(wrapperPort, bbPort) :+
    fromBits(createSubField(wrapperPort, "data"), createSubField(bbPort, "data"))

  /** Statements adapting a write port: common control fields, flattened `data`, and the
    * `mask` connection only when `hasMask` is set.
    */
  def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = {
    val wrapperData = createSubField(wrapperPort, "data")
    val dataConnects = defaultConnects(wrapperPort, bbPort) :+
      Connect(NoInfo, createSubField(bbPort, "data"), toBits(wrapperData))
    dataConnects ++ maskConnects(wrapperPort, bbPort, "mask", wrapperData.tpe, hasMask, fillMask)
  }

  /** Statements adapting a read-write port: common control fields, unpacked `rdata`,
    * `wmode`, flattened `wdata`, and the `wmask` connection only when `hasMask` is set.
    */
  def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = {
    val wrapperWData = createSubField(wrapperPort, "wdata")
    val dataConnects = defaultConnects(wrapperPort, bbPort) ++ Seq(
      fromBits(createSubField(wrapperPort, "rdata"), createSubField(bbPort, "rdata")),
      connectFields(bbPort, "wmode", wrapperPort, "wmode"),
      Connect(NoInfo, createSubField(bbPort, "wdata"), toBits(wrapperWData)))
    dataConnects ++ maskConnects(wrapperPort, bbPort, "wmask", wrapperWData.tpe, hasMask, fillMask)
  }

  /** Rewrites `s`, replacing every `DefMemory` whose info contains `useMacro` with a
    * `WDefInstance` of the same name; all other statements are traversed recursively.
    *
    * When such a memory has no `maskGran` info, the `mask`/`wmask` fields of its writers and
    * readwriters are mapped to `EmptyExpression` in `memPortMap` (presumably so that
    * `updateStmtRefs` drops references to them -- verify there).
    *
    * @param namespace  module namespace used to pick fresh wrapper / black-box names
    * @param memPortMap per-module map filled with the rewritten mask references
    * @param memMods    accumulator that receives newly generated modules
    * @throws IllegalArgumentException if the `info` or `ref` info fields hold a value of
    *                                  an unexpected type
    */
  def updateMemStmts(namespace: Namespace,
                     memPortMap: MemPortMap,
                     memMods: Modules)
                     (s: Statement): Statement = s match {
    case m: DefMemory if containsInfo(m.info, "useMacro") =>
      if (!containsInfo(m.info, "maskGran")) {
        m.writers foreach { w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression }
        m.readwriters foreach { w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression }
      }
      val info = getInfo(m.info, "info") match {
        case None => NoInfo
        case Some(p: Info) => p
        case Some(other) => throw new IllegalArgumentException(
          s"Memory ${m.name}: expected an Info under key 'info', found: $other")
      }
      getInfo(m.info, "ref") match {
        case None =>
          // prototype mem: generate a fresh wrapper and black box for it
          val newWrapperName = namespace newName m.name
          val newMemBBName = namespace newName s"${m.name}_ext"
          val newMem = m copy (name = newMemBBName)
          memMods ++= createMemModule(newMem, newWrapperName)
          WDefInstance(info, m.name, newWrapperName, UnknownType)
        case Some(ref: String) =>
          // reuse the module named by `ref` (presumably an equivalent memory's wrapper)
          WDefInstance(info, m.name, ref, UnknownType)
        case Some(other) => throw new IllegalArgumentException(
          s"Memory ${m.name}: expected a module name under key 'ref', found: $other")
      }
    case _ => s map updateMemStmts(namespace, memPortMap, memMods)
  }

  /** Replaces macro memories in module `m`, then rewrites references to their ports using
    * the port map collected for this module.
    */
  def updateMemMods(namespace: Namespace, memMods: Modules)(m: DefModule) = {
    val memPortMap = new MemPortMap

    (m map updateMemStmts(namespace, memPortMap, memMods)
       map updateStmtRefs(memPortMap))
  }

  /** Runs the pass over every module, emits the configuration, and appends the generated
    * black boxes and wrappers to the circuit's modules.
    */
  def run(c: Circuit) = {
    val namespace = Namespace(c)
    val memMods = new Modules
    val modules = c.modules map updateMemMods(namespace, memMods)
    // print conf -- only after every module has been processed, since createMemModule
    // appends to the writer as memories are encountered
    writer.serialize
    c copy (modules = modules ++ memMods)
  }
}