aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
blob: 79ecd9cdbc95d438ed34e32eb252b1b5495406f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// See LICENSE for license details.

package firrtl.passes
package memlib

import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import WrappedExpression.weq
import AnalysisUtils._
import MemTransformUtils._

object AnalysisUtils {
  type Connects = collection.mutable.HashMap[String, Expression]

  /** Collects, for a single module, the value assigned to each named component.
    *
    * Keys are:
    *   - the serialized LHS of every `Connect`,
    *   - the name of every `DefNode`,
    *   - the serialized target of every `IsInvalid` (mapped to `WInvalid`).
    * Nested statements are visited recursively (parent before children); if a
    * key is seen more than once, the entry recorded last in that traversal wins.
    */
  def getConnects(m: DefModule): Connects = {
    val table = new Connects
    def record(stmt: Statement): Statement = {
      stmt match {
        case Connect(_, lhs, rhs)            => table(lhs.serialize) = rhs
        case DefNode(_, nodeName, nodeValue) => table(nodeName) = nodeValue
        case IsInvalid(_, target)            => table(target.serialize) = WInvalid
        case _                               => // not an assignment; nothing to record
      }
      stmt map record
    }
    m map record
    table
  }

  /** Looks up the origin of the component serialized as `s`.
    * The name is wrapped in an untyped `WRef` and resolved through `connects`.
    */
  def getOrigin(connects: Connects, s: String): Expression = {
    val ref = WRef(s, UnknownType, ExpKind, UNKNOWNGENDER)
    getOrigin(connects)(ref)
  }

  /** Follows `e` back through the module's node-to-node connections to the
    * expression that ultimately drives it. This works whether or not constant
    * propagation has been run.
    *
    * Looks through:
    *   - a handful of trivial mux patterns;
    *   - `or` with a literal one and `and` with a literal zero;
    *   - full-width `bits` extractions;
    *   - `asUInt` / `asSInt` / `asClock`;
    *   - clock-typed `ValidIf`.
    * Limitations:
    *   - Only works inside one module (stops at anything not found in `connects`,
    *     e.g. module inputs).
    *   - Only handles trivial primops and muxes (it is not complete).
    *   - NOTE(review): the `or` rule returns a 1-bit one whenever any argument is
    *     the literal one, which is only exact for 1-bit operands.
    * TODO(shunshou): implement more equivalence cases (i.e. a + 0 = a)
    */
  def getOrigin(connects: Connects)(e: Expression): Expression = {
    def trace(x: Expression): Expression = getOrigin(connects)(x)
    e match {
      case Mux(cond, tval, fval, _) =>
        val f = trace(fval)
        val t = trace(tval)
        val c = trace(cond)
        if (weq(t, one) && weq(f, zero)) c        // mux(c, 1, 0) == c
        else if (weq(c, one)) t                    // select is constant 1
        else if (weq(c, zero)) f                   // select is constant 0
        else if (weq(t, f)) t                      // both arms trace to the same origin
        else if (weq(f, zero) && weq(c, t)) c      // mux(c, c, 0) == c
        else e
      case DoPrim(PrimOps.Or, args, _, _) if args.exists(a => weq(a, one)) => one
      case DoPrim(PrimOps.And, args, _, _) if args.exists(a => weq(a, zero)) => zero
      case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), _) =>
        // Only an extraction covering the whole operand is transparent.
        if (bitWidth(args.head.tpe) == msb - lsb + 1) trace(args.head) else e
      case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
        trace(args.head)
      case ValidIf(_, value, ClockType) => trace(value)
      // Stops at registers. There is no cycle detection, so a combinational loop
      // (illegal anyway) would recurse until the stack overflows.
      case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
        connects.get(e.serialize).map(trace).getOrElse(e)
      case _ => e
    }
  }

  /** True when `a` and `b` are identical apart from `info`, `name` and `memRef`,
    * and neither memory's name appears in `noDeDupeMems`.
    */
  def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory, noDeDupeMems: Seq[String]) = {
    val sameExceptName = a == b.copy(info = a.info, name = a.name, memRef = a.memRef)
    val dedupBlocked = noDeDupeMems.contains(a.name) || noDeDupeMems.contains(b.name)
    sameExceptName && !dedupBlocked
  }
}

/** Determines whether a write mask is actually needed, i.e. whether wmask is
  * not already equivalent to wmode/en. Populates the maskGran field of
  * DefAnnotatedMemory accordingly.
  *
  * Annotations:
  *   - maskGran = (dataType size) / (number of mask bits)
  *      - i.e. 1 if bitmask, 8 if bytemask, absent for no mask
  * TODO(shunshou): Add floorplan info?
  */
object ResolveMaskGranularity extends Pass {

  /** Returns the number of mask bits, if a mask is actually used.
    *
    * The mask is treated as redundant (result `None`) when every mask bit traces
    * back either to the write enable's origin or to the constant 1.
    *
    * @param connects map from serialized component names to their assigned values
    * @param wen      the port's write-enable field (`wmode` for readwriters, `en` for writers)
    * @param wmask    the port's mask field; every recorded connection whose key
    *                 starts with `wmask.serialize` is counted as one mask bit
    * @return `None` if the mask is redundant, otherwise `Some(number of mask bits)`
    */
  def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
    val wenOrigin = getOrigin(connects)(wen)
    val maskPrefix = wmask.serialize
    // `connects.keys` is backed by a Set at runtime, so mapping over it directly
    // yields a Set of origins and silently merges mask bits that share an origin.
    // Materialize a List first so that every mask bit is counted.
    val wmaskOrigin = connects.keys.toList
      .filter(_ startsWith maskPrefix)
      .map { s: String => getOrigin(connects, s) }
    // all wmask bits are equal to wmode/wen or all wmask bits = 1 (for redundancy checking)
    val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
    if (redundantMask) None else Some(wmaskOrigin.size)
  }

  /** Fills in `maskGran` on every DefAnnotatedMemory found in `s`.
    *
    * Intended for memories that are candidates for memory macro replacement,
    * i.e. rw, w + r (read, write 1 cycle delay).
    *
    * Only the first write-capable port (readwriters first, then writers) decides
    * the granularity. NOTE(review): ports with differing masks are not reconciled.
    * A memory with no readwriters and no writers gets `maskGran = None` instead
    * of failing.
    */
  def updateStmts(connects: Connects)(s: Statement): Statement = s match {
    case m: DefAnnotatedMemory =>
      val dataBits = bitWidth(m.dataType)
      val rwMasks = m.readwriters map (rw =>
        getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
      val wMasks = m.writers map (w =>
        getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
      // `headOption` rather than `head`: a memory without write-capable ports
      // would otherwise throw NoSuchElementException here.
      val maskGran = (rwMasks ++ wMasks).headOption.flatten map (maskBits => dataBits / maskBits)
      m.copy(maskGran = maskGran)
    case sx => sx map updateStmts(connects)
  }

  /** Annotates every memory in `m` using the connections recorded for `m`. */
  def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m))
  def run(c: Circuit) = c copy (modules = c.modules map annotateModMems)
}