aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/RemoveAccesses.scala
blob: 0309e7a73aaab70f0832a9484e6b7fbfc4493903 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
package firrtl.passes

import firrtl.ir._
import firrtl.{WRef, WSubAccess, WSubIndex, WSubField}
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.WrappedExpression._
import firrtl.Namespace
import scala.collection.mutable


/** Removes all [[firrtl.WSubAccess]] from circuit
  */
object RemoveAccesses extends Pass {
  def name = "Remove Accesses"

  /** Container for a base expression and its corresponding guard
    *
    * @param base  a statically-indexed reference, produced by `create_exps` on the
    *              root [[firrtl.WRef]] of the walked expression
    * @param guard condition under which the walked expression denotes `base`
    *              (`one` when no dynamic index is involved)
    */
  case class Location(base: Expression, guard: Expression)

  /** Walks a referencing expression and returns every statically-indexed
    * reference (base) it may denote, each paired with the guard under which
    * the expression denotes that base.
    *
    * Locations are listed in the same order as `create_exps` of the root reference.
    * E.g. if called on `a[i]` where `a: UInt[2]`, this returns:
    * {{{
    *   Seq(Location(a[0], AND(one, EQV(uint(0), i))),
    *       Location(a[1], AND(one, EQV(uint(1), i))))
    * }}}
    * Only [[firrtl.WRef]], [[firrtl.WSubIndex]], [[firrtl.WSubField]] and
    * [[firrtl.WSubAccess]] are handled; any other expression throws a MatchError.
    */
  def getLocations(e: Expression): Seq[Location] = e match {
    case e: WRef => create_exps(e).map(Location(_, one))
    case e: WSubIndex =>
      sliceLocations(getLocations(e.exp), get_point(e), get_size(tpe(e)), get_size(tpe(e.exp)))
    case e: WSubField =>
      sliceLocations(getLocations(e.exp), get_point(e), get_size(tpe(e)), get_size(tpe(e.exp)))
    case e: WSubAccess =>
      val ls = getLocations(e.exp)
      // size of one vector element, as counted by get_size
      val stride = get_size(tpe(e))
      // number of elements in the accessed vector
      val wrap = tpe(e.exp).asInstanceOf[VectorType].size
      // Location i belongs to element (i / stride) of the vector; the modulo wraps
      // around when `ls` spans several copies of the vector (enclosing aggregates).
      ls.zipWithIndex.map { case (l, i) =>
        Location(l.base, AND(l.guard, EQV(uint((i / stride) % wrap), e.index)))
      }
  }

  /** Selects the locations belonging to a statically selected sub-element
    * (a [[firrtl.WSubIndex]] or [[firrtl.WSubField]]).
    *
    * `ls` holds the locations of the enclosing aggregate, treated as repeating with
    * period `stride`; the sub-element occupies offsets `[start, start + size)` of
    * each period.
    */
  private def sliceLocations(ls: Seq[Location], start: Int, size: Int, stride: Int): Seq[Location] =
    ls.zipWithIndex.collect {
      case (l, i) if (i % stride) >= start && (i % stride) < start + size => l
    }

  /** Returns true if e contains a [[firrtl.WSubAccess]] */
  def hasAccess(e: Expression): Boolean = {
    var found = false
    // Stops descending at the first WSubAccess found on each path.
    def visit(x: Expression): Expression = x match {
      case a: WSubAccess => found = true; a
      case other => other map visit
    }
    visit(e)
    found
  }

  def run(c: Circuit): Circuit = {
    /** Rewrites the body of module `m`, removing every [[firrtl.WSubAccess]]. */
    def remove_m(m: Module): Module = {
      val namespace = Namespace(m)

      /** Lowers the accesses in `s`, returning `s` preceded by any generated statements. */
      def onStmt(s: Statement): Statement = {
        // Statements generated while lowering `s`; the rewritten `s` is appended last.
        val stmts = mutable.ArrayBuffer[Statement]()

        /** Declares a fresh wire of `e`'s type and returns a reference to it. */
        def create_temp(e: Expression): Expression = {
          val n = namespace.newTemp
          stmts += DefWire(info(s), n, tpe(e))
          WRef(n, tpe(e), kind(e), gender(e))
        }

        /** Replaces a subaccess in a given male expression
          *
          * Declares a temporary of `e`'s type and connects it as follows:
          *  - The first `temps.size` locations (one copy of the element) are connected
          *    unconditionally. This is intended as the default, relying on
          *    last-connect semantics.
          *  - Every later location is connected under that location's guard.
          *
          * Returns a reference to the temporary.
          */
        def removeMale(e: Expression): Expression = e match {
          case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(e) =>
            val rs = getLocations(e)
            // hasAccess(e) implies at least one location carries a dynamic guard
            if (!rs.exists(x => !weq(x.guard, one))) error("Shouldn't be here")
            val temp = create_temp(e)
            val temps = create_exps(temp)
            def getTemp(i: Int) = temps(i % temps.size)
            rs.zipWithIndex.foreach { case (x, i) =>
              if (i < temps.size) {
                stmts += Connect(info(s), getTemp(i), x.base)
              } else {
                stmts += Conditionally(info(s), x.guard, Connect(info(s), getTemp(i), x.base), EmptyStmt)
              }
            }
            temp
          case _ => e
        }

        /** Replaces a subaccess in a given female expression
          *
          * Unless `loc` resolves to a single unguarded location, returns a fresh
          * temporary and, for every candidate location, connects that location from
          * the temporary under the location's guard.
          */
        def removeFemale(info: Info, loc: Expression): Expression = loc match {
          case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(loc) =>
            val ls = getLocations(loc)
            // `&&` so that ls.head is only read when ls is non-empty
            if (ls.size == 1 && weq(ls.head.guard, one)) loc
            else {
              val temp = create_temp(loc)
              ls.foreach { x =>
                stmts += Conditionally(info, x.guard, Connect(info, x.base, temp), EmptyStmt)
              }
              temp
            }
          case _ => loc
        }

        /** Recursively walks a male expression and fixes all subaccesses
          * If we see a sub-access, replace it.
          * Otherwise, map to children.
          *
          * The base of a sub-access may itself contain sub-accesses (e.g. `a[b[i]][j]`).
          * Their indices are fixed first via `fixFemale`, which rewrites indices but
          * leaves the accesses for `getLocations` to expand. Without this step,
          * un-lowered accesses would leak into the generated guards.
          */
        def fixMale(e: Expression): Expression = e match {
          case w: WSubAccess => removeMale(fixFemale(w))
          case x => x map fixMale
        }

        /** Recursively walks a female expression and fixes all subaccesses
          * If we see a sub-access, its index is a male expression, and we must replace it.
          * Otherwise, map to children.
          */
        def fixFemale(e: Expression): Expression = e match {
          case w: WSubAccess => WSubAccess(fixFemale(w.exp), fixMale(w.index), w.tpe, w.gender)
          case x => x map fixFemale
        }

        val sx = s match {
          case Connect(info, loc, exp) =>
            Connect(info, removeFemale(info, fixFemale(loc)), fixMale(exp))
          case other => other map fixMale map onStmt
        }
        stmts += sx
        if (stmts.size != 1) Block(stmts) else stmts.head
      }

      Module(m.info, m.name, m.ports, onStmt(m.body))
    }

    val newModules = c.modules.map {
      case m: ExtModule => m
      case m: Module => remove_m(m)
    }
    Circuit(c.info, newModules, c.main)
  }
}