aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
blob: c1d54964b1f17423cf8b530aa2b84a853e4f5e1d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package firrtl.passes

import scala.collection.mutable.{HashMap,ArrayBuffer}
import firrtl.ir._
import AnalysisUtils._
import MemTransformUtils._
import firrtl._
import firrtl.Utils._

/** Pass that replaces every memory carrying a "useMacro" annotation with an instance of a
  * generated wrapper module. Each wrapper instantiates an external (black box) module named
  * `<wrapper name>_ext` and wires its ports to the wrapper's own ports.
  *
  * Annotations consulted on a memory's `info` (via `containsInfo` / `getInfo`):
  *  - "useMacro": the memory is replaced at all.
  *  - "maskGran": when absent, the write-mask ports of the memory are mapped to
  *    `EmptyExpression` and no mask connection is emitted inside the wrapper.
  *  - "info": source locator used for the replacement instance (`NoInfo` when absent).
  *  - "ref": name of an already generated wrapper to instantiate; when absent the memory is a
  *    prototype and a new wrapper / black box pair is generated for it.
  */
object ReplaceMemMacros extends Pass {

  def name: String = "Replace memories with black box wrappers (optimizes when write mask isn't needed)"

  /** Rewrites all `Module`s of the circuit and appends the generated wrapper and black box
    * modules after the existing ones. `ExtModule`s are passed through unchanged.
    *
    * @param c circuit to transform
    * @return a copy of `c` whose module list is the rewritten modules followed by the
    *         generated ones, in discovery order
    */
  def run(c: Circuit): Circuit = {

    // Lazy so that the namespace is only built once a prototype memory is actually found.
    lazy val moduleNamespace = Namespace(c)
    // Modules generated while walking the circuit (black box first, then its wrapper).
    val memMods = ArrayBuffer[DefModule]()

    def updateMemMods(mod: Module): Module = {
      // Port references ("<mem>.<port>.mask" / "<mem>.<port>.wmask") that are handed to
      // updateStmtRefs once the memories of this module have been replaced.
      val memPortMap = HashMap[String, Expression]()

      // Only Block nesting is traversed; every other statement kind is returned untouched.
      def updateMemStmts(s: Statement): Statement = s match {
        case mem: DefMemory if containsInfo(mem.info, "useMacro") =>
          if (!containsInfo(mem.info, "maskGran")) {
            mem.writers foreach { w => memPortMap(s"${mem.name}.${w}.mask") = EmptyExpression }
            mem.readwriters foreach { rw => memPortMap(s"${mem.name}.${rw}.wmask") = EmptyExpression }
          }
          // A value stored under "info" that is not an Info fails with a MatchError.
          val info: Info = getInfo(mem.info, "info").map { case i: Info => i }.getOrElse(NoInfo)

          getInfo(mem.info, "ref") match {
            case None =>
              // Prototype memory: generate its wrapper / black box under a fresh module name.
              val proto = mem.copy(name = moduleNamespace.newName(mem.name))
              memMods ++= createMemModule(proto)
              WDefInstance(info, mem.name, proto.name, UnknownType)
            case Some(ref) =>
              // Reuse an existing wrapper; a non-String reference fails with a MatchError.
              val target = ref match { case refName: String => refName }
              WDefInstance(info, mem.name, target, UnknownType)
          }
        case blk: Block => Block(blk.stmts map updateMemStmts)
        case other => other
      }

      // The statement walk must finish before memPortMap is snapshotted.
      val replaced = updateMemStmts(mod.body)
      mod.copy(body = updateStmtRefs(replaced, memPortMap.toMap))
    }

    val updatedMods = c.modules map {
      case mod: Module => updateMemMods(mod)
      case ext: ExtModule => ext
    }

    c.copy(modules = updatedMods ++ memMods.toSeq)
  }

  // from Albert
  /** Builds the black box and the wrapper module for a memory.
    *
    * The wrapper is named after `m` and exposes the ports derived from `m` itself; the black
    * box is named `<m.name>_ext` and exposes the ports derived from a copy of `m` whose data
    * type is a single `UIntType` of width `bitWidth(m.dataType)`. All ports on both modules
    * are declared with direction `Input` and the type of the corresponding bundle field.
    *
    * @param m memory to wrap; its data type must not be `UnknownType`
    * @return the black box `ExtModule` followed by the wrapper `Module`
    */
  def createMemModule(m: DefMemory): Seq[DefModule] = {
    assert(m.dataType != UnknownType)
    val bbName = m.name + "_ext"

    def portsOf(mem: DefMemory) =
      MemPortUtils.memToBundle(mem).fields.map(f => Port(NoInfo, f.name, Input, f.tpe))

    val wrapperioPorts = portsOf(m)
    val bbProto = m.copy(dataType = UIntType(IntWidth(bitWidth(m.dataType))))
    val bbioPorts = portsOf(bbProto)
    val bbRef = createRef(bbName)

    // Port names are paired positionally between the wrapper memory and the black box copy.
    val readerStmts = (m.readers zip bbProto.readers) flatMap {
      case (agg, ground) => adaptReader(createRef(agg), m, createSubField(bbRef, ground), bbProto)
    }
    val writerStmts = (m.writers zip bbProto.writers) flatMap {
      case (agg, ground) => adaptWriter(createRef(agg), m, createSubField(bbRef, ground), bbProto)
    }
    val readWriterStmts = (m.readwriters zip bbProto.readwriters) flatMap {
      case (agg, ground) => adaptReadWriter(createRef(agg), m, createSubField(bbRef, ground), bbProto)
    }

    // The black box instance comes first, followed by reader, writer and readwriter wiring.
    val body = Seq[Statement](WDefInstance(m.info, bbName, bbName, UnknownType)) ++
      readerStmts ++ writerStmts ++ readWriterStmts
    val wrapper = Module(m.info, m.name, wrapperioPorts, Block(body))

    val bb = ExtModule(m.info, bbName, bbioPorts)
    Seq(bb, wrapper)
  }

  /** Statements wiring one read port of the wrapper (`aggPort`) to the matching black box
    * port (`groundPort`): `addr`, `en`, `clk`, then the `data` path via `fromBits`.
    */
  def adaptReader(aggPort: Expression, aggMem: DefMemory, groundPort: Expression, groundMem: DefMemory): Seq[Statement] =
    controlConnects(aggPort, groundPort) :+
      readDataConnect("data", aggPort, aggMem, groundPort, groundMem)

  /** Statements wiring one write port of the wrapper to the matching black box port:
    * `addr`, `en`, `clk`, the `data` path via `toBits`, and, only when `aggMem` carries a
    * "maskGran" annotation, the `mask` path via `toBitMask`.
    */
  def adaptWriter(aggPort: Expression, aggMem: DefMemory, groundPort: Expression, groundMem: DefMemory): Seq[Statement] =
    (controlConnects(aggPort, groundPort) :+
      writeDataConnect("data", aggPort, aggMem, groundPort, groundMem)) ++
      maskConnect("mask", aggPort, aggMem, groundPort, groundMem)

  /** Statements wiring one readwrite port of the wrapper to the matching black box port:
    * `addr`, `en`, `clk`, `wmode`, the `wdata` path via `toBits`, the `rdata` path via
    * `fromBits`, and, only when `aggMem` carries a "maskGran" annotation, the `wmask` path.
    */
  def adaptReadWriter(aggPort: Expression, aggMem: DefMemory, groundPort: Expression, groundMem: DefMemory): Seq[Statement] =
    controlConnects(aggPort, groundPort) ++
      Seq[Statement](
        connectFields(groundPort, "wmode", aggPort, "wmode"),
        writeDataConnect("wdata", aggPort, aggMem, groundPort, groundMem),
        readDataConnect("rdata", aggPort, aggMem, groundPort, groundMem)
      ) ++
      maskConnect("wmask", aggPort, aggMem, groundPort, groundMem)

  /** Connections shared by every port kind, in the order `addr`, `en`, `clk`. */
  private def controlConnects(aggPort: Expression, groundPort: Expression): Seq[Statement] =
    Seq("addr", "en", "clk") map (field => connectFields(groundPort, field, aggPort, field))

  /** Drives the black box `field` from `toBits` of the wrapper's `field`. */
  private def writeDataConnect(field: String, aggPort: Expression, aggMem: DefMemory,
                               groundPort: Expression, groundMem: DefMemory): Statement =
    Connect(
      NoInfo,
      WSubField(groundPort, field, groundMem.dataType, UNKNOWNGENDER),
      toBits(WSubField(aggPort, field, aggMem.dataType, UNKNOWNGENDER))
    )

  /** Relates the wrapper's `field` to the black box `field` through `fromBits`. */
  private def readDataConnect(field: String, aggPort: Expression, aggMem: DefMemory,
                              groundPort: Expression, groundMem: DefMemory): Statement =
    fromBits(
      WSubField(aggPort, field, aggMem.dataType, UNKNOWNGENDER),
      WSubField(groundPort, field, groundMem.dataType, UNKNOWNGENDER)
    )

  /** Mask connection for `field`, or nothing when `aggMem` has no "maskGran" annotation. */
  private def maskConnect(field: String, aggPort: Expression, aggMem: DefMemory,
                          groundPort: Expression, groundMem: DefMemory): Seq[Statement] =
    if (containsInfo(aggMem.info, "maskGran"))
      Seq(Connect(
        NoInfo,
        WSubField(groundPort, field, create_mask(groundMem.dataType), UNKNOWNGENDER),
        toBitMask(WSubField(aggPort, field, create_mask(aggMem.dataType), UNKNOWNGENDER), aggMem.dataType)
      ))
    else Seq.empty

}