aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
blob: 7a1a57fbef44007b43d9502163357d5694ab14b7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// SPDX-License-Identifier: Apache-2.0

package firrtl.passes
package memlib

import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import WrappedExpression.weq
import AnalysisUtils._
import MemTransformUtils._

object AnalysisUtils {
  type Connects = collection.mutable.HashMap[String, Expression]

  /** Collects the value assigned to each named component of a module.
    * Keys are the serialized left-hand side of connect and invalidate statements, or the
    * node name for node definitions. Invalidated components map to WInvalid, and an entry
    * recorded later in the traversal replaces an earlier one with the same key.
    */
  def getConnects(m: DefModule): Connects = {
    val table = new Connects
    // Records the statement itself first, then descends into its child statements.
    def record(stmt: Statement): Statement = {
      stmt match {
        case Connect(_, lhs, rhs)      => table.update(lhs.serialize, rhs)
        case DefNode(_, nodeName, rhs) => table.update(nodeName, rhs)
        case IsInvalid(_, target)      => table.update(target.serialize, WInvalid)
        case _                         => ()
      }
      stmt.map(record)
    }
    // map is used purely as a traversal here; the rebuilt module is discarded
    m.map(record)
    table
  }

  /** Traces a named component back to the expression it originates from, using the
    *   connection table of a single module. Works whether or not constant propagation
    *   has been run.
    * Trivial primops and muxes that do not change the origin are looked through.
    * Limitations:
    *   - Only works within a module (the search ends at names with no table entry,
    *     e.g. module inputs)
    *   - Only a handful of trivial primop/mux shapes are recognised (is not complete)
    * TODO(shunshou): implement more equivalence cases (i.e. a + 0 = a)
    */
  def getOrigin(connects: Connects, s: String): Expression = {
    val start = WRef(s, UnknownType, ExpKind, UnknownFlow)
    getOrigin(connects)(start)
  }
  def getOrigin(connects: Connects)(e: Expression): Expression = e match {
    case Mux(sel, whenTrue, whenFalse, _) =>
      val falseSrc = getOrigin(connects)(whenFalse)
      val trueSrc = getOrigin(connects)(whenTrue)
      val selSrc = getOrigin(connects)(sel)
      // mux(c, 1, 0) is c itself
      if (weq(trueSrc, one) && weq(falseSrc, zero)) selSrc
      // a constant select leaves only one live arm
      else if (weq(selSrc, one)) trueSrc
      else if (weq(selSrc, zero)) falseSrc
      // both arms agree, so the select is irrelevant
      else if (weq(trueSrc, falseSrc)) trueSrc
      // mux(c, c, 0) is c itself
      else if (weq(falseSrc, zero) && weq(selSrc, trueSrc)) selSrc
      else e
    // or-ing with a literal one / and-ing with a literal zero fixes the result
    case DoPrim(PrimOps.Or, operands, _, _) if operands.exists(arg => weq(arg, one))   => one
    case DoPrim(PrimOps.And, operands, _, _) if operands.exists(arg => weq(arg, zero)) => zero
    case DoPrim(PrimOps.Bits, operands, Seq(hi, lo), _) =>
      val source = operands.head
      val selectedWidth = (hi - lo) + 1
      val sourceWidth = bitWidth(source.tpe)
      // a bit-select that covers the whole operand is transparent; a partial one is an origin
      if (sourceWidth == selectedWidth) getOrigin(connects)(source) else e
    // casts do not change where a value comes from
    case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), operands, _, _) =>
      getOrigin(connects)(operands.head)
    // Treating ValidIf as a plain connection of its value is a correct optimization
    case ValidIf(_, value, _) => getOrigin(connects)(value)
    // note: the search stops at registers, but will stack overflow on combinational loops (not allowed)
    case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
      connects.get(e.serialize) match {
        case None         => e
        case Some(driver) => getOrigin(connects)(driver)
      }
    case _ => e
  }
}

/** Determines if a write mask is needed (a mask whose bits all trace back to wmode/en,
  * or to constant one, is redundant).
  * Populates the maskGran field of DefAnnotatedMemory
  * Annotations:
  *   - maskGran = (dataType size) / (number of mask bits)
  *      - i.e. 1 if bitmask, 8 if bytemask, absent for no mask
  * TODO(shunshou): Add floorplan info?
  */
object ResolveMaskGranularity extends Pass {

  /** Returns the number of mask bits, if used
    *
    * @param connects connection table of the enclosing module
    * @param wen      write enable of the port (wmode for readwriters, en for writers)
    * @param wmask    mask field of the port; every connection key that starts with its
    *                 serialized form is treated as one of its bits
    * @return None when every matching mask connection originates from `wen` or from
    *         constant one (this includes the case where no key matches at all),
    *         otherwise Some(size of the collected origins)
    */
  def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
    val wenOrigin = getOrigin(connects)(wen)
    // serialize once instead of once per key examined by the filter
    val maskPrefix = wmask.serialize
    // NOTE(review): connects.keys may be backed by a key set, in which case the mapped
    // origins are de-duplicated and size counts distinct origins rather than mask bits.
    // Confirm this is the intended count.
    val wmaskOrigin = connects.keys.filter(_.startsWith(maskPrefix)).map(name => getOrigin(connects, name))
    // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
    val redundantMask = wmaskOrigin.forall(x => weq(x, wenOrigin) || weq(x, one))
    if (redundantMask) None else Some(wmaskOrigin.size)
  }

  /** Only annotate memories that are candidates for memory macro replacement
    * i.e. rw, w + r (read, write 1 cycle delay)
    *
    * Each DefAnnotatedMemory is returned with maskGran filled in; any other statement is
    * traversed recursively. The granularity is taken from the first write-capable port
    * (readwriters are considered before writers).
    */
  def updateStmts(connects: Connects)(s: Statement): Statement = s match {
    case m: DefAnnotatedMemory =>
      val dataBits = bitWidth(m.dataType)
      val rwMasks =
        m.readwriters.map(rw => getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
      val wMasks = m.writers.map(w => getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
      // headOption instead of head: a memory with no writers and no readwriters has no
      // mask to inspect, so it gets no granularity instead of a NoSuchElementException.
      val maskGran = (rwMasks ++ wMasks).headOption match {
        case Some(Some(maskBits)) => Some(dataBits / maskBits)
        case _                    => None
      }
      m.copy(maskGran = maskGran)
    case sx => sx.map(updateStmts(connects))
  }

  /** Annotates every memory in a module, using that module's own connection table. */
  def annotateModMems(m: DefModule): DefModule = m.map(updateStmts(getConnects(m)))

  /** Applies the annotation to each module of the circuit. */
  def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(annotateModMems))
}