diff options
| author | Donggyu Kim | 2016-08-19 17:38:14 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-07 10:40:35 -0700 |
| commit | abedfbd9dde6e2985f9bc93b53f53853a5ac82d6 (patch) | |
| tree | 7a4ae6a9001bae05205b42d11a49be2cd1e483bd /src | |
| parent | d97fb73a6ea96e32689814326d47b39f55eff773 (diff) | |
clean up RemoveAccesses
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 73 |
1 files changed, 33 insertions, 40 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index d3340f2d..4b37f37a 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -16,57 +16,50 @@ object RemoveAccesses extends Pass { /** Container for a base expression and its corresponding guard */ - case class Location(base: Expression, guard: Expression) + private case class Location(base: Expression, guard: Expression) /** Walks a referencing expression and returns a list of valid references * (base) and the corresponding guard which, if true, returns that base. * E.g. if called on a[i] where a: UInt[2], we would return: * Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1))) */ - def getLocations(e: Expression): Seq[Location] = e match { + private def getLocations(e: Expression): Seq[Location] = e match { case e: WRef => create_exps(e).map(Location(_,one)) case e: WSubIndex => val ls = getLocations(e.exp) val start = get_point(e) val end = start + get_size(tpe(e)) val stride = get_size(tpe(e.exp)) - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { - lsx += ls(i) - } - } - lsx + for ((l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end)) yield l case e: WSubField => val ls = getLocations(e.exp) val start = get_point(e) val end = start + get_size(tpe(e)) val stride = get_size(tpe(e.exp)) - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { lsx += ls(i) } - } - lsx + for ((l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end)) yield l case e: WSubAccess => val ls = getLocations(e.exp) val stride = get_size(tpe(e)) val wrap = tpe(e.exp).asInstanceOf[VectorType].size - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { + ls.zipWithIndex map {case (l, i) => val c = (i / stride) % wrap - val basex = ls(i).base - val guardx = AND(ls(i).guard,EQV(uint(c),e.index)) - lsx += Location(basex,guardx) + val basex = l.base + val guardx = AND(l.guard,EQV(uint(c),e.index)) + Location(basex,guardx) } - lsx } /** Returns true if e contains a [[firrtl.WSubAccess]] */ - def hasAccess(e: Expression): Boolean = { - var ret: Boolean = false - def rec_has_access(e: Expression): Expression = e match { - case (e:WSubAccess) => { ret = true; e } - case (e) => e map (rec_has_access) + private def hasAccess(e: Expression): Boolean = { + var ret: Boolean = false + def rec_has_access(e: Expression): Expression = { + e match { + case e : WSubAccess => ret = true + case e => + } + e map rec_has_access } rec_has_access(e) ret @@ -75,31 +68,29 @@ object RemoveAccesses extends Pass { def remove_m(m: Module): Module = { val namespace = Namespace(m) def onStmt(s: Statement): Statement = { - val stmts = mutable.ArrayBuffer[Statement]() - def create_temp(e: Expression): Expression = { + def create_temp(e: Expression): (Statement, Expression) = { val n = namespace.newTemp - stmts += DefWire(info(s), n, tpe(e)) - WRef(n, tpe(e), kind(e), gender(e)) + (DefWire(info(s), n, tpe(e)), WRef(n, tpe(e), kind(e), gender(e))) } /** Replaces a subaccess in a given male expression */ + val stmts = mutable.ArrayBuffer[Statement]() def removeMale(e: Expression): Expression = e match { case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if (hasAccess(e)) => val rs = getLocations(e) - val foo = rs.find(x => {x.guard != one}) - foo match { + rs find (x => x.guard != one) match { case None => error("Shouldn't be here") - case foo: Some[Location] => - val temp = create_temp(e) + case Some(_) => + val (wire, temp) = create_temp(e) val temps = create_exps(temp) def getTemp(i: Int) = temps(i % temps.size) - for((x, i) <- rs.zipWithIndex) { - if (i < temps.size) { + stmts += wire + rs.zipWithIndex foreach { + case (x, i) if i < temps.size => stmts += Connect(info(s),getTemp(i),x.base) - } else { + case (x, i) => stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt) - } } temp } @@ -113,8 +104,10 @@ object RemoveAccesses extends Pass { val ls = getLocations(loc) if (ls.size == 1 & weq(ls(0).guard,one)) loc else { - val temp = create_temp(loc) - for (x <- ls) { stmts += Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt) } + val (wire, temp) = create_temp(loc) + stmts += wire + ls foreach (x => stmts += + Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt)) temp } case _ => loc @@ -148,7 +141,7 @@ object RemoveAccesses extends Pass { stmts += sx if (stmts.size != 1) Block(stmts) else stmts(0) } - Module(m.info, m.name, m.ports, onStmt(m.body)) + Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body))) } val newModules = c.modules.map( _ match { |
