aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
blob: 43a46bf1abbc5c8cad47144eec158495e80a7a07 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// SPDX-License-Identifier: Apache-2.0

package firrtl.passes
package memlib

import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import MemPortUtils._
import MemTransformUtils._

/** Changes memory port names to standard port names (e.g. RW0 instead of T_408).
  */
object RenameAnnotatedMemoryPorts extends Pass {

  // Port-name prefixes shared by createMemProto (which renames the port declarations) and
  // getMemPortMap (which builds references to the renamed ports). Keeping them in one place
  // guarantees that the rewritten references always name ports that actually exist.
  private val ReaderPrefix = "R"
  private val WriterPrefix = "W"
  private val ReadWriterPrefix = "RW"

  /** Renames memory ports to a standard naming scheme:
    *    - R0, R1, ... for each read port
    *    - W0, W1, ... for each write port
    *    - RW0, RW1, ... for each readwrite port
    *
    * Indices follow the order of the original `readers`, `writers` and `readwriters` lists.
    *
    * @param m the memory whose ports are renamed
    * @return a copy of `m` in which only the three port-name lists are replaced
    */
  def createMemProto(m: DefAnnotatedMemory): DefAnnotatedMemory = {
    val rports = m.readers.indices.map(i => s"$ReaderPrefix$i")
    val wports = m.writers.indices.map(i => s"$WriterPrefix$i")
    val rwports = m.readwriters.indices.map(i => s"$ReadWriterPrefix$i")
    m.copy(readers = rports, writers = wports, readwriters = rwports)
  }

  /** Records, for every port field of `m`, how a reference to it must be rewritten.
    *
    * Keys are the serialized original field names `"<mem>.<oldPort>.<field>"`. Values are the
    * replacement expressions `<mem>.<newPort>.<field>`, where `<newPort>` follows the naming
    * scheme of [[createMemProto]].
    * E.g. for a memory `m` whose first reader is named `read`:
    *    - key "m.read.addr" maps to the expression m.R0.addr
    *
    * @param m          the memory with its ORIGINAL port names
    * @param memPortMap the map to fill; it is updated in place
    */
  def getMemPortMap(m: DefAnnotatedMemory, memPortMap: MemPortMap): Unit = {
    val defaultFields = Seq("addr", "en", "clk")
    val rFields = defaultFields :+ "data"
    val wFields = rFields :+ "mask"
    val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask")

    // The i-th port of a kind becomes `<prefix><i>`, matching the renaming in createMemProto.
    def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit =
      for {
        (p, i) <- ports.zipWithIndex
        f <- fields
      } {
        val newPort = WSubField(WRef(m.name), s"$newPortKind$i")
        memPortMap(s"${m.name}.$p.$f") = WSubField(newPort, f)
      }

    updateMemPortMap(m.readers, rFields, ReaderPrefix)
    updateMemPortMap(m.writers, wFields, WriterPrefix)
    updateMemPortMap(m.readwriters, rwFields, ReadWriterPrefix)
  }

  /** Replaces candidate memories with memories with standard port names.
    *
    * Recurses through nested statements. As a side effect, it records the old-to-new port
    * field mapping of every replaced memory in `memPortMap`. It does not update the references
    * themselves; that is done by a separate pass with `updateStmtRefs`.
    */
  def updateMemStmts(memPortMap: MemPortMap)(s: Statement): Statement = s match {
    case m: DefAnnotatedMemory =>
      // Build the map from the ORIGINAL memory, since its keys use the old port names.
      getMemPortMap(m, memPortMap)
      createMemProto(m)
    case other => other.map(updateMemStmts(memPortMap))
  }

  /** Replaces candidate memories and their references with standard port names.
    *
    * @param m the module to transform
    * @return the module with renamed memory ports and rewritten port references
    */
  def updateMemMods(m: DefModule): DefModule = {
    val memPortMap = new MemPortMap
    // Order matters: the first pass fills memPortMap, and the second pass consumes it.
    m.map(updateMemStmts(memPortMap))
      .map(updateStmtRefs(memPortMap))
  }

  def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(updateMemMods))
}