aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/RemoveAccesses.scala
blob: 539971f559e9cc333f1894f14a39af26411a64b2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
package firrtl.passes

import firrtl.{WRef, WSubAccess, WSubIndex, WSubField, Namespace}
import firrtl.PrimOps.{And, Eq}
import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.WrappedExpression._
import scala.collection.mutable


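/** Removes all [[firrtl.WSubAccess]] (dynamic vector indexing) from the circuit.
  *
  * Each access is replaced by a fresh temporary wire: a read becomes the wire being connected
  * from the candidate vector elements, and a write becomes the candidate elements being connected
  * from the wire, with the connections guarded by comparing the index against each element's position.
  */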
object RemoveAccesses extends Pass {
  def name = "Remove Accesses"

  /** 1-bit conjunction of two guards. */
  private def AND(e1: Expression, e2: Expression) =
    DoPrim(And, Seq(e1, e2), Nil, BoolType)

  /** Equality comparison of two expressions.
    *
    * The result of `eq` is always a single bit, so it is typed [[BoolType]] (like the guards built
    * by [[AND]]) rather than inheriting the type of an operand, which for the literal index
    * operand used by this pass is generally wider than one bit.
    */
  private def EQV(e1: Expression, e2: Expression): Expression =
    DoPrim(Eq, Seq(e1, e2), Nil, BoolType)

  /** Container for a base expression and its corresponding guard
    */
  private case class Location(base: Expression, guard: Expression)

  /** Memoizing wrapper around [[firrtl.Utils.create_exps]], which this pass calls repeatedly on
    * the same expressions (this improves the performance of the pass).
    *
    * A fresh instance is created for every call to [[run]]. A cache stored directly in this
    * singleton object would be shared by concurrent compilations (`mutable.HashMap` is not
    * thread-safe) and would retain every expression it ever saw for the lifetime of the JVM.
    */
  private class CreateExpsCache {
    private val cache = mutable.HashMap[Expression, Seq[Expression]]()
    def apply(e: Expression): Seq[Expression] =
      cache getOrElseUpdate (e, firrtl.Utils.create_exps(e))
  }

  /** Walks a referencing expression and returns a list of valid references
    * (base) and the corresponding guard which, if true, returns that base.
    * E.g. if called on a[i] where a: UInt[2], we would return:
    *   Seq(Location(a[0], and(UInt(1), eq(UInt(0), i))),
    *       Location(a[1], and(UInt(1), eq(UInt(1), i))))
    *
    * @param e    a reference built from [[firrtl.WRef]], [[firrtl.WSubIndex]],
    *             [[firrtl.WSubField]] and [[firrtl.WSubAccess]]; any other expression is a MatchError
    * @param exps per-run cache used to flatten references
    */
  private def getLocations(e: Expression, exps: CreateExpsCache): Seq[Location] = e match {
    case e: WRef => exps(e) map (Location(_, one))
    case e: WSubIndex => selectStatic(e, e.exp, exps)
    case e: WSubField => selectStatic(e, e.exp, exps)
    case e: WSubAccess =>
      val ls = getLocations(e.exp, exps)
      val stride = get_size(e.tpe)
      val wrap = e.exp.tpe.asInstanceOf[VectorType].size
      ls.zipWithIndex map { case (l, i) =>
        // Vector element that the i-th flattened location belongs to
        val c = (i / stride) % wrap
        Location(l.base, AND(l.guard, EQV(uint(c), e.index)))
      }
  }

  /** Locations of `e`, a statically selected sub-element (constant index or field) of `parent`.
    *
    * Within each run of `get_size(parent.tpe)` consecutive locations of `parent`, keeps those
    * at offsets `[get_point(e), get_point(e) + get_size(e.tpe))`.
    */
  private def selectStatic(e: Expression, parent: Expression, exps: CreateExpsCache): Seq[Location] = {
    val ls = getLocations(parent, exps)
    val start = get_point(e)
    val end = start + get_size(e.tpe)
    val stride = get_size(parent.tpe)
    ls.zipWithIndex collect { case (l, i) if (i % stride) >= start && (i % stride) < end => l }
  }

  /** Returns true if e contains a [[firrtl.WSubAccess]]
    */
  private def hasAccess(e: Expression): Boolean = {
    var ret: Boolean = false
    def rec_has_access(e: Expression): Expression = {
      e match {
        case _ : WSubAccess => ret = true
        case _ =>
      }
      e map rec_has_access
    }
    rec_has_access(e)
    ret
  }

  def run(c: Circuit): Circuit = {
    // One cache per run: shared by all modules of this circuit, dropped when the run ends
    val exps = new CreateExpsCache

    def remove_m(m: Module): Module = {
      val namespace = Namespace(m)
      def onStmt(s: Statement): Statement = {
        /** Declares a fresh wire of `e`'s type; returns the declaration and a reference to it.
          * NOTE(review): the reference copies kind(e) rather than a wire kind — presumably
          * repaired by a later kind-resolution pass; confirm.
          */
        def create_temp(e: Expression): (Statement, Expression) = {
          val n = namespace.newTemp
          (DefWire(get_info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e)))
        }

        // Statements emitted ahead of the rewritten `s`: temporary wires and their connections
        val stmts = mutable.ArrayBuffer[Statement]()

        /** Replaces a subaccess in a given male expression
          */
        def removeMale(e: Expression): Expression = e match {
          case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(e) =>
            val rs = getLocations(e, exps)
            rs find (x => x.guard != one) match {
              case None => error("Shouldn't be here")
              case Some(_) =>
                val (wire, temp) = create_temp(e)
                val temps = exps(temp)
                def getTemp(i: Int) = temps(i % temps.size)
                stmts += wire
                rs.zipWithIndex foreach {
                  // The first candidate is connected unconditionally as the default value;
                  // later guarded connects override it (FIRRTL last-connect semantics)
                  case (x, i) if i < temps.size =>
                    stmts += Connect(get_info(s),getTemp(i),x.base)
                  case (x, i) =>
                    stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt)
                }
                temp
            }
          case _ => e
        }

        /** Replaces a subaccess in a given female expression
          */
        def removeFemale(info: Info, loc: Expression): Expression = loc match {
          case (_: WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(loc) =>
            val ls = getLocations(loc, exps)
            // Short-circuit `&&` so that `ls.head` is never evaluated on an empty `ls`
            if (ls.size == 1 && weq(ls.head.guard,one)) loc
            else {
              val (wire, temp) = create_temp(loc)
              stmts += wire
              ls foreach (x => stmts +=
                Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt))
              temp
            }
          case _ => loc
        }

        /** Recursively walks a male expression and fixes all subaccesses
          * If we see a sub-access, replace it.
          * Otherwise, map to children.
          */
        def fixMale(e: Expression): Expression = e match {
          case w: WSubAccess => removeMale(WSubAccess(w.exp, fixMale(w.index), w.tpe, w.gender))
          case x => x map fixMale
        }

        /** Recursively walks a female expression and fixes all subaccesses
          * If we see a sub-access, its index is a male expression, and we must replace it.
          * Otherwise, map to children.
          */
        def fixFemale(e: Expression): Expression = e match {
          case w: WSubAccess => WSubAccess(fixFemale(w.exp), fixMale(w.index), w.tpe, w.gender)
          case x => x map fixFemale
        }

        val sx = s match {
          case Connect(info, loc, exp) =>
            Connect(info, removeFemale(info, fixFemale(loc)), fixMale(exp))
          case sxx => sxx map fixMale map onStmt
        }
        stmts += sx
        if (stmts.size != 1) Block(stmts) else stmts(0)
      }
      Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body)))
    }

    c copy (modules = c.modules map {
      case m: ExtModule => m
      case m: Module => remove_m(m)
    })
  }
}