1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
// SPDX-License-Identifier: Apache-2.0
package firrtl.passes
package memlib
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import WrappedExpression.weq
import AnalysisUtils._
import MemTransformUtils._
object AnalysisUtils {
  type Connects = collection.mutable.HashMap[String, Expression]

  /** Collects the value assigned to every named component of a module.
    *
    * Keys are the serialized left-hand side of a connection, the name of a node, or the
    * serialized target of an invalidation (recorded with the value `WInvalid`).
    * A later assignment to the same key overwrites an earlier one.
    *
    * @param m module whose statements are scanned
    * @return mutable map from component name to the expression assigned to it
    */
  def getConnects(m: DefModule): Connects = {
    val table = new Connects

    // Records the assignment made by `stmt` (if any), then visits its child statements.
    def record(stmt: Statement): Statement = {
      stmt match {
        case Connect(_, lhs, rhs) =>
          table.update(lhs.serialize, rhs)
        case DefNode(_, nodeName, rhs) =>
          table.update(nodeName, rhs)
        case IsInvalid(_, target) =>
          table.update(target.serialize, WInvalid)
        case _ => // nothing to record for other statements
      }
      stmt.map(record)
    }

    m.map(record)
    table
  }

  /** Traces the component named `s` back to the expression it originates from.
    *
    * Convenience overload: wraps the name in a reference of unknown type, kind and flow and
    * delegates to the expression-based `getOrigin`.
    */
  def getOrigin(connects: Connects, s: String): Expression =
    getOrigin(connects)(WRef(s, UnknownType, ExpKind, UnknownFlow))

  /** Traces an expression back to its origin through the connections of one module,
    * regardless of whether constant propagation has been run.
    *
    * The search looks through trivial primops and muxes that do not change the origin.
    * Limitations:
    *  - Only works inside a module (a name with no entry in `connects` ends the search)
    *  - Only trivial primop/mux cases are handled (the analysis is not complete)
    * TODO(shunshou): implement more equivalence cases (i.e. a + 0 = a)
    *
    * @param connects map produced by `getConnects` for the enclosing module
    * @param e expression to trace
    * @return the origin expression, or `e` itself when no simplification applies
    */
  def getOrigin(connects: Connects)(e: Expression): Expression = e match {
    case Mux(sel, whenTrue, whenFalse, _) =>
      val falseSrc = getOrigin(connects)(whenFalse)
      val trueSrc = getOrigin(connects)(whenTrue)
      val selSrc = getOrigin(connects)(sel)
      // The rules are tried in this order; the first one that matches decides the result.
      if (weq(trueSrc, one) && weq(falseSrc, zero)) selSrc // mux(c, 1, 0) behaves like c
      else if (weq(selSrc, one)) trueSrc // selector is constant one
      else if (weq(selSrc, zero)) falseSrc // selector is constant zero
      else if (weq(trueSrc, falseSrc)) trueSrc // both branches share an origin
      else if (weq(falseSrc, zero) && weq(selSrc, trueSrc)) selSrc // mux(c, c, 0) behaves like c
      else e
    case DoPrim(PrimOps.Or, args, _, _) if args.exists(weq(_, one)) => one
    case DoPrim(PrimOps.And, args, _, _) if args.exists(weq(_, zero)) => zero
    case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), _) =>
      val operand = args.head
      val extractedBits = (msb - lsb) + 1
      val operandBits = bitWidth(operand.tpe)
      // Extracting every bit of the operand is a no-op, so keep tracing through it.
      if (operandBits == extractedBits) getOrigin(connects)(operand) else e
    case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
      getOrigin(connects)(args.head)
    // It is a correct optimization to treat ValidIf as a connection
    case ValidIf(_, value, _) => getOrigin(connects)(value)
    // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
    case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
      connects.get(e.serialize) match {
        case Some(driver) => getOrigin(connects)(driver)
        case None         => e
      }
    case _ => e
  }
}
/** Determines if a write mask is needed (wmode/en and wmask are equivalent).
  * Populates the maskGran field of DefAnnotatedMemory
  * Annotations:
  *  - maskGran = (dataType size) / (number of mask bits)
  *  - i.e. 1 if bitmask, 8 if bytemask, absent for no mask
  * TODO(shunshou): Add floorplan info?
  */
object ResolveMaskGranularity extends Pass {

  /** Returns the number of mask bits, if the mask is actually used.
    *
    * Every entry of `connects` whose key starts with the serialized `wmask` expression is
    * treated as one mask bit and traced back to its origin.
    *
    * @param connects map from component name to assigned expression for the enclosing module
    * @param wen write-enable expression of the port (`en` for writers, `wmode` for readwriters)
    * @param wmask mask expression of the port
    * @return `None` when every mask bit originates from the write enable or from constant one
    *         (the mask is redundant), otherwise `Some(number of mask bits found)`
    */
  def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
    val wenOrigin = getOrigin(connects)(wen)
    val wmaskOrigin = connects.keys.filter(_.startsWith(wmask.serialize)).map { (s: String) => getOrigin(connects, s) }
    // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
    val redundantMask = wmaskOrigin.forall(x => weq(x, wenOrigin) || weq(x, one))
    if (redundantMask) None else Some(wmaskOrigin.size)
  }

  /** Only annotate memories that are candidates for memory macro replacements
    * i.e. rw, w + r (read, write 1 cycle delay)
    *
    * For each [[DefAnnotatedMemory]] the mask of the first write-capable port (readwriters are
    * considered before writers) decides the granularity. Other statements are traversed
    * recursively and otherwise left as they are.
    *
    * @param connects map from component name to assigned expression for the enclosing module
    * @param s statement to rewrite
    * @return the statement with `maskGran` filled in on every annotated memory it contains
    */
  def updateStmts(connects: Connects)(s: Statement): Statement = s match {
    case m: DefAnnotatedMemory =>
      val dataBits = bitWidth(m.dataType)
      val rwMasks =
        m.readwriters.map(rw => getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
      val wMasks = m.writers.map(w => getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
      // headOption rather than head: a memory without any writer or readwriter has no mask to
      // inspect, and must end up with no mask granularity instead of failing the whole pass.
      val maskGran = (rwMasks ++ wMasks).headOption match {
        case Some(Some(maskBits)) => Some(dataBits / maskBits)
        case _                    => None
      }
      m.copy(maskGran = maskGran)
    case sx => sx.map(updateStmts(connects))
  }

  /** Annotates every memory of one module, using the connections collected from that module. */
  def annotateModMems(m: DefModule): DefModule = m.map(updateStmts(getConnects(m)))

  /** Runs the pass over every module of the circuit. */
  def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(annotateModMems))
}
|