aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu Kim2016-08-19 17:38:14 -0700
committerDonggyu Kim2016-09-07 10:40:35 -0700
commitabedfbd9dde6e2985f9bc93b53f53853a5ac82d6 (patch)
tree7a4ae6a9001bae05205b42d11a49be2cd1e483bd /src
parentd97fb73a6ea96e32689814326d47b39f55eff773 (diff)
clean up RemoveAccesses
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala73
1 files changed, 33 insertions, 40 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index d3340f2d..4b37f37a 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -16,57 +16,50 @@ object RemoveAccesses extends Pass {
/** Container for a base expression and its corresponding guard
*/
- case class Location(base: Expression, guard: Expression)
+ private case class Location(base: Expression, guard: Expression)
/** Walks a referencing expression and returns a list of valid references
* (base) and the corresponding guard which, if true, returns that base.
* E.g. if called on a[i] where a: UInt[2], we would return:
* Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1)))
*/
- def getLocations(e: Expression): Seq[Location] = e match {
+ private def getLocations(e: Expression): Seq[Location] = e match {
case e: WRef => create_exps(e).map(Location(_,one))
case e: WSubIndex =>
val ls = getLocations(e.exp)
val start = get_point(e)
val end = start + get_size(tpe(e))
val stride = get_size(tpe(e.exp))
- val lsx = mutable.ArrayBuffer[Location]()
- for (i <- 0 until ls.size) {
- if (((i % stride) >= start) & ((i % stride) < end)) {
- lsx += ls(i)
- }
- }
- lsx
+ for ((l, i) <- ls.zipWithIndex
+ if ((i % stride) >= start) & ((i % stride) < end)) yield l
case e: WSubField =>
val ls = getLocations(e.exp)
val start = get_point(e)
val end = start + get_size(tpe(e))
val stride = get_size(tpe(e.exp))
- val lsx = mutable.ArrayBuffer[Location]()
- for (i <- 0 until ls.size) {
- if (((i % stride) >= start) & ((i % stride) < end)) { lsx += ls(i) }
- }
- lsx
+ for ((l, i) <- ls.zipWithIndex
+ if ((i % stride) >= start) & ((i % stride) < end)) yield l
case e: WSubAccess =>
val ls = getLocations(e.exp)
val stride = get_size(tpe(e))
val wrap = tpe(e.exp).asInstanceOf[VectorType].size
- val lsx = mutable.ArrayBuffer[Location]()
- for (i <- 0 until ls.size) {
+ ls.zipWithIndex map {case (l, i) =>
val c = (i / stride) % wrap
- val basex = ls(i).base
- val guardx = AND(ls(i).guard,EQV(uint(c),e.index))
- lsx += Location(basex,guardx)
+ val basex = l.base
+ val guardx = AND(l.guard,EQV(uint(c),e.index))
+ Location(basex,guardx)
}
- lsx
}
/** Returns true if e contains a [[firrtl.WSubAccess]]
*/
- def hasAccess(e: Expression): Boolean = {
- var ret: Boolean = false
- def rec_has_access(e: Expression): Expression = e match {
- case (e:WSubAccess) => { ret = true; e }
- case (e) => e map (rec_has_access)
+ private def hasAccess(e: Expression): Boolean = {
+ var ret: Boolean = false
+ def rec_has_access(e: Expression): Expression = {
+ e match {
+ case e : WSubAccess => ret = true
+ case e =>
+ }
+ e map rec_has_access
}
rec_has_access(e)
ret
@@ -75,31 +68,29 @@ object RemoveAccesses extends Pass {
def remove_m(m: Module): Module = {
val namespace = Namespace(m)
def onStmt(s: Statement): Statement = {
- val stmts = mutable.ArrayBuffer[Statement]()
- def create_temp(e: Expression): Expression = {
+ def create_temp(e: Expression): (Statement, Expression) = {
val n = namespace.newTemp
- stmts += DefWire(info(s), n, tpe(e))
- WRef(n, tpe(e), kind(e), gender(e))
+ (DefWire(info(s), n, tpe(e)), WRef(n, tpe(e), kind(e), gender(e)))
}
/** Replaces a subaccess in a given male expression
*/
+ val stmts = mutable.ArrayBuffer[Statement]()
def removeMale(e: Expression): Expression = e match {
case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if (hasAccess(e)) =>
val rs = getLocations(e)
- val foo = rs.find(x => {x.guard != one})
- foo match {
+ rs find (x => x.guard != one) match {
case None => error("Shouldn't be here")
- case foo: Some[Location] =>
- val temp = create_temp(e)
+ case Some(_) =>
+ val (wire, temp) = create_temp(e)
val temps = create_exps(temp)
def getTemp(i: Int) = temps(i % temps.size)
- for((x, i) <- rs.zipWithIndex) {
- if (i < temps.size) {
+ stmts += wire
+ rs.zipWithIndex foreach {
+ case (x, i) if i < temps.size =>
stmts += Connect(info(s),getTemp(i),x.base)
- } else {
+ case (x, i) =>
stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt)
- }
}
temp
}
@@ -113,8 +104,10 @@ object RemoveAccesses extends Pass {
val ls = getLocations(loc)
if (ls.size == 1 & weq(ls(0).guard,one)) loc
else {
- val temp = create_temp(loc)
- for (x <- ls) { stmts += Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt) }
+ val (wire, temp) = create_temp(loc)
+ stmts += wire
+ ls foreach (x => stmts +=
+ Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt))
temp
}
case _ => loc
@@ -148,7 +141,7 @@ object RemoveAccesses extends Pass {
stmts += sx
if (stmts.size != 1) Block(stmts) else stmts(0)
}
- Module(m.info, m.name, m.ports, onStmt(m.body))
+ Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body)))
}
val newModules = c.modules.map( _ match {