1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
|
// SPDX-License-Identifier: Apache-2.0
package firrtl.passes
import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubField, WSubIndex}
import firrtl.PrimOps.{And, Eq}
import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.WrappedExpression._
import firrtl.options.Dependency
import scala.collection.mutable
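/** Removes all [[firrtl.WSubAccess]] expressions from a circuit.
  *
  * Each dynamic access is replaced by a temporary wire. For a source access, the temporary is
  * driven from every location the access could refer to. For a sink access, the temporary drives
  * every such location. Each of these connects is wrapped in a `when` on that location's guard
  * (an equality test against the index value).
  */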
object RemoveAccesses extends Pass {

  override def prerequisites =
    Seq(
      Dependency(PullMuxes),
      Dependency(ZeroLengthVecs),
      Dependency(ReplaceAccesses),
      Dependency(ExpandConnects)
    ) ++ firrtl.stage.Forms.Deduped

  override def invalidates(a: Transform): Boolean = a match {
    case ResolveKinds | ResolveFlows => true
    case _                           => false
  }

  /** Boolean AND of two guards that folds away an operand equal to the literal `one`. */
  private def AND(e1: Expression, e2: Expression) =
    if (e1 == one) e2
    else if (e2 == one) e1
    else DoPrim(And, Seq(e1, e2), Nil, BoolType)

  /** Builds an `eq` primitive comparing `e1` and `e2`, typed as [[BoolType]]. */
  private def EQV(e1: Expression, e2: Expression): Expression =
    DoPrim(Eq, Seq(e1, e2), Nil, BoolType)

  /** Container for a base expression and its corresponding guard
    */
  private case class Location(base: Expression, guard: Expression)

  /** Memo table for [[firrtl.Utils.create_exps]], which this pass calls repeatedly on the same
    * expressions.
    *
    * A new instance is created by every call to [[run]] instead of living on this singleton
    * object. That way expressions from one circuit are not retained after the run finishes, and
    * runs executing on different threads never share (unsynchronized) mutable state.
    */
  private final class CreateExpsCache {
    private val table = mutable.HashMap[Expression, Seq[Expression]]()

    def apply(e: Expression): Seq[Expression] =
      table.getOrElseUpdate(e, firrtl.Utils.create_exps(e))
  }

  /** Walks a referencing expression and returns a list of valid references
    * (base) and the corresponding guard which, if true, returns that base.
    * E.g. if called on a[i] where a: UInt[2], we would return:
    *   Seq(Location(a[0], eq(UInt(0), i)), Location(a[1], eq(UInt(1), i)))
    * A guard that is always true is represented by `one`.
    *
    * @param e          reference-like expression, possibly containing sub-accesses
    * @param createExps per-run memo table for [[firrtl.Utils.create_exps]]
    */
  private def getLocations(e: Expression, createExps: CreateExpsCache): Seq[Location] = e match {
    case e: SubIndex => selectSlice(e, e.expr, createExps)
    case e: SubField => selectSlice(e, e.expr, createExps)
    case SubAccess(expr, index, tpe, _) =>
      getLocations(expr, createExps).zipWithIndex.flatMap {
        case (Location(exprBase, exprGuard), exprIndex) =>
          getLocations(index, createExps).map {
            case Location(indexBase, indexGuard) =>
              Location(
                exprBase,
                AND(
                  AND(
                    indexGuard,
                    exprGuard
                  ),
                  // exprIndex enumerates ground elements of `expr`; dividing by the element size
                  // yields the vector slot this location belongs to.
                  EQV(
                    UIntLiteral((exprIndex / get_size(tpe)) % expr.tpe.asInstanceOf[VectorType].size),
                    indexBase
                  )
                )
              )
          }
      }
    case other => createExps(other).map(Location(_, one))
  }

  /** Returns the locations of `parent` that fall inside the ground-element slice selected by `e`,
    * where `e` is a [[SubIndex]] or [[SubField]] whose `expr` is `parent`.
    *
    * The slice is `[get_point(e), get_point(e) + get_size(e.tpe))`. Positions are taken modulo
    * `get_size(parent.tpe)`, so every `parent`-sized chunk of locations is filtered identically.
    */
  private def selectSlice(e: Expression, parent: Expression, createExps: CreateExpsCache): Seq[Location] = {
    val ls = getLocations(parent, createExps)
    val start = get_point(e)
    val end = start + get_size(e.tpe)
    val stride = get_size(parent.tpe)
    ls.zipWithIndex.collect {
      case (l, i) if (i % stride) >= start && (i % stride) < end => l
    }
  }

  /** Returns true if e contains a [[firrtl.WSubAccess]]
    */
  private def hasAccess(e: Expression): Boolean = {
    var ret: Boolean = false
    def rec_has_access(e: Expression): Expression = {
      e match {
        case _: WSubAccess => ret = true
        case _ =>
      }
      e.map(rec_has_access)
    }
    rec_has_access(e)
    ret
  }

  def run(c: Circuit): Circuit = {
    // Scoped to this run; see CreateExpsCache for why it is not a field of the object.
    val createExps = new CreateExpsCache

    def remove_m(m: Module): Module = {
      val namespace = Namespace(m)

      def onStmt(s: Statement): Statement = {
        /** Declares a fresh wire with `e`'s type and returns it together with a reference to it. */
        def create_temp(e: Expression): (Statement, Expression) = {
          val n = namespace.newName(niceName(e))
          (DefWire(get_info(s), n, e.tpe), WRef(n, e.tpe, kind(e), flow(e)))
        }

        // Statements generated while lowering `s`; the rewritten `s` is appended last.
        val stmts = mutable.ArrayBuffer[Statement]()

        /** Replaces a subaccess in a given source expression
          *
          * Only called on RefLikes that definitely have a SubAccess.
          * Must accept Expression because that's the output type of fixIndices.
          */
        def removeSource(e: Expression): Expression = {
          val rs = getLocations(e, createExps)
          rs.find(x => x.guard != one) match {
            case None => throwInternalError(s"removeSource: shouldn't be here - $e")
            case Some(_) =>
              val (wire, temp) = create_temp(e)
              // `temp` has a fresh name, so memoizing its expansion would only grow the cache.
              val temps = firrtl.Utils.create_exps(temp)
              def getTemp(i: Int) = temps(i % temps.size)
              stmts += wire
              rs.zipWithIndex.foreach {
                case (x, i) =>
                  // Invalidate each ground element of `temp` once, before its first guarded
                  // connect, so it is still initialized when no guard holds.
                  if (i < temps.size) stmts += IsInvalid(get_info(s), getTemp(i))
                  stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
              }
              temp
          }
        }

        /** Replaces a subaccess in a given sink expression
          */
        def removeSink(info: Info, loc: Expression): Expression = loc match {
          case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(loc) =>
            val ls = getLocations(loc, createExps)
            // Short-circuit `&&` so `ls.head` is never evaluated when `ls` is empty.
            if (ls.size == 1 && weq(ls.head.guard, one)) loc
            else {
              val (wire, temp) = create_temp(loc)
              stmts += wire
              ls.foreach(x => stmts += Conditionally(info, x.guard, Connect(info, x.base, temp), EmptyStmt))
              temp
            }
          case _ => loc
        }

        /** Recurse until find SubAccess and call fixSource on its index
          * @note this only accepts [[RefLikeExpression]]s but we can't enforce it because map
          * requires Expression => Expression
          */
        def fixIndices(e: Expression): Expression = e match {
          case e: SubAccess => e.copy(index = fixSource(e.index))
          case other => other.map(fixIndices)
        }

        /** Recursively walks a source expression and fixes all subaccesses
          *
          * If we see a RefLikeExpression that contains a SubAccess, we recursively remove
          * subaccesses from the indices of any SubAccesses, then process modified RefLikeExpression
          */
        def fixSource(e: Expression): Expression = e match {
          case ref: RefLikeExpression =>
            if (hasAccess(ref)) removeSource(fixIndices(ref)) else ref
          case x => x.map(fixSource)
        }

        /** Recursively walks a sink expression and fixes all subaccesses
          * If we see a sub-access, its index is a source expression, and we must replace it.
          * Otherwise, map to children.
          */
        def fixSink(e: Expression): Expression = e match {
          case w: WSubAccess => WSubAccess(fixSink(w.expr), fixSource(w.index), w.tpe, w.flow)
          case x => x.map(fixSink)
        }

        val sx = s match {
          case Connect(info, loc, exp) =>
            Connect(info, removeSink(info, fixSink(loc)), fixSource(exp))
          case sxx => sxx.map(fixSource).map(onStmt)
        }
        stmts += sx
        if (stmts.size != 1) Block(stmts.toSeq) else stmts.head
      }

      m.copy(body = squashEmpty(onStmt(m.body)))
    }

    // Only Modules have a body to rewrite; every other DefModule kind (e.g. ExtModule) passes
    // through unchanged instead of raising a MatchError.
    c.copy(modules = c.modules.map {
      case m: Module => remove_m(m)
      case other     => other
    })
  }
}
|