aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
blob: 0a685c3caf5854f2382bd8ba37f1502011fd58e3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// See LICENSE for license details.

package firrtl.passes

import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import AnalysisUtils._
import MemTransformUtils._

/** Helpers shared by the memory transforms: small expression builders plus
  * utilities for renaming memory ports and rewriting references to them.
  */
object MemTransformUtils {

  type MemPortMap = collection.mutable.HashMap[String, Expression]
  type Memories = collection.mutable.ArrayBuffer[DefMemory]
  type Modules = collection.mutable.ArrayBuffer[DefModule]

  /** Builds a reference named `n` with the given type and kind and unknown gender. */
  def createRef(n: String, t: Type = UnknownType, k: Kind = ExpKind) =
    WRef(n, t, k, UNKNOWNGENDER)

  /** Builds the sub-field `exp.n`; its type is looked up with `field_type` on `exp.tpe`. */
  def createSubField(exp: Expression, n: String) = {
    val fieldTpe = field_type(exp.tpe, n)
    WSubField(exp, n, fieldTpe, UNKNOWNGENDER)
  }

  /** Builds the statement `lref.lname <= rref.rname` with no source info attached. */
  def connectFields(lref: Expression, lname: String, rref: Expression, rname: String) = {
    val lhs = createSubField(lref, lname)
    val rhs = createSubField(rref, rname)
    Connect(NoInfo, lhs, rhs)
  }

  /** Builds the renaming table for the ports of `m`.
    *
    * Each key is the text `<mem>.<original port>.<field>`; each value is the
    * sub-field expression `<mem>.<prefix><index>.<field>`, where the prefix is
    * `R`, `W` or `RW` and the index is the port's position within its own
    * readers / writers / readwriters list.
    */
  def getMemPortMap(m: DefMemory) = {
    val commonFields = Seq("addr", "en", "clk")
    val readFields = commonFields :+ "data"
    val writeFields = readFields :+ "mask"
    val readWriteFields = commonFields ++ Seq("wmode", "wdata", "rdata", "wmask")

    val portMap = new MemPortMap

    // Adds one entry per (port, field) pair for a single category of ports.
    def register(ports: Seq[String], fields: Seq[String], prefix: String): Unit =
      ports.zipWithIndex foreach { case (oldName, idx) =>
        val renamedPort = createSubField(createRef(m.name), s"$prefix$idx")
        fields foreach { f =>
          portMap(s"${m.name}.${oldName}.${f}") = createSubField(renamedPort, f)
        }
      }

    register(m.readers, readFields, "R")
    register(m.writers, writeFields, "W")
    register(m.readwriters, readWriteFields, "RW")
    portMap
  }

  /** Returns a copy of `m` whose ports are renamed positionally to
    * `R0, R1, ...`, `W0, W1, ...` and `RW0, RW1, ...`; all other fields are kept.
    */
  def createMemProto(m: DefMemory) = {
    def numbered(prefix: String, ports: Seq[String]) =
      ports.indices map (i => s"$prefix$i")

    m.copy(
      readers = numbered("R", m.readers),
      writers = numbered("W", m.writers),
      readwriters = numbered("RW", m.readwriters))
  }

  /** Rewrites `s` using the replacement table `repl`.
    *
    * Nested statements are rewritten first, then the expressions of the
    * statement itself. An expression is replaced when its serialized text is a
    * key of `repl`. A `Connect` that contains an `EmptyExpression` after the
    * rewrite is replaced by `EmptyStmt`.
    */
  def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = {
    // Children are rewritten before the lookup, so the key used is the
    // serialization of the already-rewritten expression.
    def updateRef(e: Expression): Expression = {
      val rewritten = e map updateRef
      repl.getOrElse(rewritten.serialize, rewritten)
    }

    // True when any expression reachable from `stmt`'s own expressions is EmptyExpression.
    def hasEmptyExpr(stmt: Statement): Boolean = {
      var found = false
      // `map` is used purely as a traversal; every node is returned unchanged.
      def visit(e: Expression): Expression = {
        if (EmptyExpression == e) found = true
        e map visit
      }
      stmt map visit
      found
    }

    def rewrite(stmt: Statement): Statement = {
      val updated = stmt map rewrite map updateRef
      updated match {
        case conn: Connect if hasEmptyExpr(conn) => EmptyStmt
        case kept => kept
      }
    }

    rewrite(s)
  }

}

/** Pass over memories whose info contains "useMacro": their ports are renamed
  * positionally (via `createMemProto`), references to the old port names are
  * rewritten, and a memory that `eqMems` reports equal to an earlier one in the
  * same module gets a "ref" info entry naming that earlier memory.
  */
object UpdateDuplicateMemMacros extends Pass {

  def name = "Convert memory port names to be more meaningful and tag duplicate memories"

  /** Rewrites the memory definitions inside `s`.
    *
    * @param uniqueMems renamed memories seen so far; a memory with no match is appended
    * @param memPortMap accumulates the old-port to new-port table of every matched memory
    * @param s          statement to rewrite; non-matching statements are recursed into
    */
  def updateMemStmts(uniqueMems: Memories,
                     memPortMap: MemPortMap)
                     (s: Statement): Statement = s match {
    case mem: DefMemory if containsInfo(mem.info, "useMacro") =>
      memPortMap ++= getMemPortMap(mem)
      val renamed = createMemProto(mem)
      uniqueMems.find(seen => eqMems(seen, renamed)) match {
        case Some(proto) =>
          // Duplicate: keep the definition but record which memory it matches.
          renamed.copy(info = appendInfo(renamed.info, "ref" -> proto.name))
        case None =>
          uniqueMems += renamed
          renamed
      }
    case other => other map updateMemStmts(uniqueMems, memPortMap)
  }

  /** Processes one module with a fresh memory list and port table: first the
    * memory definitions are rewritten (which fills the table), then every
    * reference found in the table is replaced.
    */
  def updateMemMods(m: DefModule) = {
    val seenMems = new Memories
    val portRenames = new MemPortMap
    val withRenamedMems = m map updateMemStmts(seenMems, portRenames)
    withRenamedMems map updateStmtRefs(portRenames)
  }

  /** Applies `updateMemMods` to every module of the circuit. */
  def run(c: Circuit) = c.copy(modules = c.modules.map(updateMemMods))
}
// TODO: Module namespace?