aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/SimplifyMems.scala
blob: 90c26efc52132550024ea7c78043cc5b661af733 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// SPDX-License-Identifier: Apache-2.0

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
import firrtl.options.Dependency
import firrtl.passes._
import firrtl.passes.memlib._
import firrtl.stage.Forms
import firrtl.renamemap.MutableRenameMap
import scala.collection.mutable

import AnalysisUtils._
import MemPortUtils._
import ResolveMaskGranularity._

/**
  * Lowers memories without splitting them, but without the complexity of ReplaceMemMacros
  */
class SimplifyMems extends Transform with DependencyAPIMigration {

  override def prerequisites = Forms.MidForm
  override def optionalPrerequisites = Seq(Dependency[InferReadWrite])
  override def optionalPrerequisiteOf = Forms.MidEmitters

  // InferTypes is the only transform this pass reports as invalidated.
  override def invalidates(a: Transform) = a match {
    case InferTypes => true
    case _          => false
  }

  /** Legacy entry point accepting the general [[RenameMap]]; forwards to the [[MutableRenameMap]] overload. */
  @deprecated("Use version that accepts renamemap.MutableRenameMap", "FIRRTL 1.5")
  def onModule(c: Circuit, renames: RenameMap)(m: DefModule): DefModule =
    // The downcast relies on RenameMap being sealed with MutableRenameMap as its only subclass.
    onModule(c, renames.asInstanceOf[MutableRenameMap])(m)

  /** Replaces each simplifiable aggregate-typed memory in `m` with a flattened memory plus an adapter wire.
    *
    * The adapter wire keeps the memory's original name and port-bundle type, so existing references stay
    * valid (their kind is switched from MemKind to WireKind). The new memory uses `flattenType` on its
    * data type, gets a fresh `<name>_flattened` name, and has its mask inputs tied to one.
    *
    * @param c       the enclosing circuit; its `main` is used to build rename targets
    * @param renames receives the per-port target deletions and the memory-level rename
    * @param m       the module to rewrite
    * @return the rewritten module
    */
  def onModule(c: Circuit, renames: MutableRenameMap)(m: DefModule): DefModule = {
    val namespace = Namespace(m)
    val moduleConnects = getConnects(m)
    // Memory name -> adapter wire that now occupies that name.
    val adapters = new mutable.LinkedHashMap[String, DefWire]
    val moduleTarget = ModuleTarget(c.main, m.name)

    // Re-kind references to replaced memories: the name now denotes the adapter wire.
    def retargetRefs(expr: Expression): Expression = {
      val mapped = expr.map(retargetRefs)
      mapped match {
        case ref @ WRef(name, _, MemKind, _) if adapters.contains(name) => ref.copy(kind = WireKind)
        case other                                                     => other
      }
    }

    def simplifyMem(mem: DefMemory): Statement = {
      val adapter = DefWire(mem.info, mem.name, memType(mem))
      val flatName = namespace.newName(s"${mem.name}_flattened")
      val flatMem = mem.copy(name = flatName, dataType = flattenType(mem.dataType))
      val oldTarget = moduleTarget.ref(mem.name)

      // Wire one field of a port between the flattened memory and the adapter wire.
      def connectPortField(memPort: Expression, adapterPort: Expression, field: Field): Statement = {
        val memSide = WSubField(memPort, field.name)
        val adapterSide = WSubField(adapterPort, field.name)
        (field.flip, field.name) match {
          // read data: unpack the flat memory bits into the aggregate adapter field
          case (Flip, n) if n.contains("data") => fromBits(adapterSide, memSide)
          // write data: pack the aggregate adapter field into flat memory bits
          case (Default, n) if n.contains("data") => Connect(mem.info, memSide, toBits(adapterSide))
          // masks are tied high on the flattened memory
          case (Default, n) if n.contains("mask") => Connect(mem.info, memSide, Utils.one)
          // addr / en / clk / wmode and anything else pass straight through
          case _ => Connect(mem.info, memSide, adapterSide)
        }
      }

      val portConnects: Seq[Statement] = memType(flatMem).fields.flatMap {
        case Field(portName, Flip, portType: BundleType) =>
          val memPort = WSubField(WRef(flatMem), portName)
          val adapterPort = WSubField(WRef(adapter), portName)
          renames.delete(oldTarget.field(portName))
          portType.fields.map(f => connectPortField(memPort, adapterPort, f))
      }

      adapters(mem.name) = adapter
      renames.record(oldTarget, oldTarget.copy(ref = flatMem.name))
      Block(Seq(adapter, flatMem) ++ portConnects)
    }

    // A memory qualifies only if its data type is an aggregate and getMaskBits yields nothing
    // for any writer or readwriter port.
    def canSimplify(mem: DefMemory): Boolean = mem.dataType match {
      case _: AggregateType =>
        val writerMasks = mem.writers.map { w =>
          getMaskBits(moduleConnects, memPortField(mem, w, "en"), memPortField(mem, w, "mask"))
        }
        val readwriterMasks = mem.readwriters.map { rw =>
          getMaskBits(moduleConnects, memPortField(mem, rw, "wmode"), memPortField(mem, rw, "wmask"))
        }
        (writerMasks ++ readwriterMasks).forall(_.isEmpty)
      case _ => false
    }

    def rewriteStmt(stmt: Statement): Statement = stmt match {
      case mem: DefMemory if canSimplify(mem) => simplifyMem(mem)
      case other                              => other.map(rewriteStmt).map(retargetRefs)
    }

    m.map(rewriteStmt)
  }

  override def execute(state: CircuitState): CircuitState = {
    val circuit = state.circuit
    val renameMap = MutableRenameMap()
    val transformed = circuit.map(onModule(circuit, renameMap))
    state.copy(circuit = transformed, renames = Some(renameMap))
  }
}