// Origin: root/src/main/scala/firrtl/passes/RemoveAccesses.scala
// Blob: 073bf49d9343c5523d3742b0060256f023bf8b42
// SPDX-License-Identifier: Apache-2.0

package firrtl.passes

import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubField, WSubIndex}
import firrtl.PrimOps.{And, Eq}
import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.WrappedExpression._
import firrtl.options.Dependency

import scala.collection.mutable

/** Removes all [[firrtl.WSubAccess]] from circuit
  *
  * Source expressions containing a sub-access are replaced by a fresh temporary wire that is
  * conditionally connected from every location the access may select. Sinks of [[Connect]]
  * statements are replaced by a fresh wire that is conditionally connected to every such
  * location.
  */
object RemoveAccesses extends Pass {

  override def prerequisites =
    Seq(
      Dependency(PullMuxes),
      Dependency(ZeroLengthVecs),
      Dependency(ReplaceAccesses),
      Dependency(ExpandConnects)
    ) ++ firrtl.stage.Forms.Deduped

  /** Only [[ResolveKinds]] and [[ResolveFlows]] are reported as invalidated by this pass. */
  override def invalidates(a: Transform): Boolean = a match {
    case ResolveKinds | ResolveFlows => true
    case _                           => false
  }

  /** Conjunction of two guards that drops an operand equal to the constant `one`. */
  private def AND(e1: Expression, e2: Expression): Expression =
    if (e1 == one) e2
    else if (e2 == one) e1
    else DoPrim(And, Seq(e1, e2), Nil, BoolType)

  /** Builds the boolean expression `e1 == e2`. */
  private def EQV(e1: Expression, e2: Expression): Expression =
    DoPrim(Eq, Seq(e1, e2), Nil, BoolType)

  /** Container for a base expression and its corresponding guard
    */
  private case class Location(base: Expression, guard: Expression)

  /** Shared implementation for static selections (sub-index and sub-field).
    *
    * Keeps the locations of `parent` whose offset inside one `parent`-sized stride falls
    * within the window `[get_point(e), get_point(e) + get_size(e.tpe))`.
    *
    * @param e      the selecting expression
    * @param parent the expression `e` selects from
    */
  private def selectLocations(e: Expression, parent: Expression): Seq[Location] = {
    val parentLocations = getLocations(parent)
    val start = get_point(e)
    val end = start + get_size(e.tpe)
    val stride = get_size(parent.tpe)
    parentLocations.zipWithIndex.collect {
      case (location, i) if (i % stride) >= start && (i % stride) < end => location
    }
  }

  /** Walks a referencing expression and returns a list of valid references
    * (base) and the corresponding guard which, if true, returns that base.
    * E.g. if called on a[i] where a: UInt[2], we would return:
    *   Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1)))
    */
  private def getLocations(e: Expression): Seq[Location] = e match {
    case e: SubIndex => selectLocations(e, e.expr)
    case e: SubField => selectLocations(e, e.expr)
    case SubAccess(expr, index, tpe, _) =>
      // Everything below is invariant across the locations of `expr`, so compute it once.
      val elementSize = get_size(tpe)
      val vecSize = expr.tpe match {
        case v: VectorType => v.size
        case other => throwInternalError(s"getLocations: sub-access into non-vector type $other - $e")
      }
      val indexLocations = getLocations(index)
      getLocations(expr).zipWithIndex.flatMap {
        case (Location(exprBase, exprGuard), exprIndex) =>
          // Vector element that this flattened location belongs to
          val element = UIntLiteral((exprIndex / elementSize) % vecSize)
          indexLocations.map {
            case Location(indexBase, indexGuard) =>
              Location(exprBase, AND(AND(indexGuard, exprGuard), EQV(element, indexBase)))
          }
      }
    case e => create_exps(e).map(Location(_, one))
  }

  /** Returns true if e contains a [[firrtl.WSubAccess]]
    */
  private def hasAccess(e: Expression): Boolean = {
    var ret: Boolean = false
    // The mapper is used purely for traversal; the mapped result is discarded.
    def rec_has_access(e: Expression): Expression = {
      e match {
        case _: WSubAccess => ret = true
        case _ =>
      }
      e.map(rec_has_access)
    }
    rec_has_access(e)
    ret
  }

  // This improves the performance of this pass.
  // The cache lives on the object, so `run` empties it when it finishes; otherwise it would
  // keep the expressions of every processed circuit reachable.
  // NOTE(review): the map is not synchronized - confirm `run` is never invoked concurrently.
  private val createExpsCache = mutable.HashMap[Expression, Seq[Expression]]()
  private def create_exps(e: Expression): Seq[Expression] =
    createExpsCache.getOrElseUpdate(e, firrtl.Utils.create_exps(e))

  /** Rewrites every [[Module]] of the circuit; [[ExtModule]]s are returned untouched.
    *
    * @param c circuit to transform
    * @return a copy of `c` with the rewritten modules
    */
  def run(c: Circuit): Circuit = {
    def remove_m(m: Module): Module = {
      val namespace = Namespace(m)
      def onStmt(s: Statement): Statement = {
        val stmtInfo = get_info(s)
        // Statements that must be emitted before the rewritten `s`
        val stmts = mutable.ArrayBuffer[Statement]()

        /** Declares a fresh wire with the type of `e` and returns it with a reference to it */
        def create_temp(e: Expression): (Statement, Expression) = {
          val n = namespace.newName(niceName(e))
          (DefWire(stmtInfo, n, e.tpe), WRef(n, e.tpe, kind(e), flow(e)))
        }

        /** Replaces a subaccess in a given source expression
          *
          * Only called on RefLikes that definitely have a SubAccess.
          * Must accept Expression because that's the output type of fixIndices.
          */
        def removeSource(e: Expression): Expression = {
          val rs = getLocations(e)
          if (!rs.exists(_.guard != one)) throwInternalError(s"removeSource: shouldn't be here - $e")
          val (wire, temp) = create_temp(e)
          val temps = create_exps(temp)
          def getTemp(i: Int) = temps(i % temps.size)
          stmts += wire
          rs.zipWithIndex.foreach {
            case (x, i) =>
              // The first temps.size locations touch each leaf of the temporary exactly once:
              // invalidate the leaf there, ahead of its conditional connects.
              if (i < temps.size) stmts += IsInvalid(stmtInfo, getTemp(i))
              stmts += Conditionally(stmtInfo, x.guard, Connect(stmtInfo, getTemp(i), x.base), EmptyStmt)
          }
          temp
        }

        /** Replaces a subaccess in a given sink expression
          */
        def removeSink(info: Info, loc: Expression): Expression = loc match {
          case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(loc) =>
            val ls = getLocations(loc)
            // `&&` (not `&`) so that `ls.head` is never evaluated on an empty list
            if (ls.size == 1 && weq(ls.head.guard, one)) loc
            else {
              val (wire, temp) = create_temp(loc)
              stmts += wire
              ls.foreach(x =>
                stmts +=
                  Conditionally(info, x.guard, Connect(info, x.base, temp), EmptyStmt)
              )
              temp
            }
          case _ => loc
        }

        /** Recurse until find SubAccess and call fixSource on its index
          * @note this only accepts [[RefLikeExpression]]s but we can't enforce it because map
          *       requires Expression => Expression
          */
        def fixIndices(e: Expression): Expression = e match {
          case e: SubAccess => e.copy(index = fixSource(e.index))
          case other => other.map(fixIndices)
        }

        /** Recursively walks a source expression and fixes all subaccesses
          *
          * If we see a RefLikeExpression that contains a SubAccess, we recursively remove
          * subaccesses from the indices of any SubAccesses, then process modified RefLikeExpression
          */
        def fixSource(e: Expression): Expression = e match {
          case ref: RefLikeExpression =>
            if (hasAccess(ref)) removeSource(fixIndices(ref)) else ref
          case x => x.map(fixSource)
        }

        /** Recursively walks a sink expression and fixes all subaccesses
          * If we see a sub-access, its index is a source expression, and we must replace it.
          * Otherwise, map to children.
          */
        def fixSink(e: Expression): Expression = e match {
          case w: WSubAccess => WSubAccess(fixSink(w.expr), fixSource(w.index), w.tpe, w.flow)
          case x => x.map(fixSink)
        }

        val sx = s match {
          case Connect(info, loc, exp) =>
            Connect(info, removeSink(info, fixSink(loc)), fixSource(exp))
          case sxx => sxx.map(fixSource).map(onStmt)
        }
        stmts += sx
        // Only wrap in a Block when helper statements were generated
        if (stmts.size != 1) Block(stmts.toSeq) else stmts(0)
      }
      Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body)))
    }

    try {
      c.copy(modules = c.modules.map {
        case m: ExtModule => m
        case m: Module    => remove_m(m)
      })
    } finally {
      // Release the expressions memoized while processing this circuit
      createExpsCache.clear()
    }
  }
}