diff options
| author | azidar | 2016-08-01 16:24:49 -0700 |
|---|---|---|
| committer | azidar | 2016-08-01 16:24:49 -0700 |
| commit | 59aff494dd9946c0f521705cfc93cc8687c83ec3 (patch) | |
| tree | 6b325d48270480da2921bf329fb3eb2e4c94bb70 /src | |
| parent | 81f631bc87aa22fff8569e96ae5c4e429df9e1d4 (diff) | |
Refactor RemoveAccesses and fix bug #210.
Added corresponding unit test.
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 145 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 166 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/UnitTests.scala | 38 |
3 files changed, 204 insertions, 145 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 6216d2aa..1a40b7c5 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -747,151 +747,6 @@ object ExpandConnects extends Pass { } } -case class Location(base:Expression,guard:Expression) -object RemoveAccesses extends Pass { - private var mname = "" - def name = "Remove Accesses" - def get_locations (e:Expression) : Seq[Location] = { - e match { - case (e:WRef) => create_exps(e).map(Location(_,one)) - case (e:WSubIndex) => { - val ls = get_locations(e.exp) - val start = get_point(e) - val end = start + get_size(tpe(e)) - val stride = get_size(tpe(e.exp)) - val lsx = ArrayBuffer[Location]() - var c = 0 - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { - lsx += ls(i) - } - } - lsx - } - case (e:WSubField) => { - val ls = get_locations(e.exp) - val start = get_point(e) - val end = start + get_size(tpe(e)) - val stride = get_size(tpe(e.exp)) - val lsx = ArrayBuffer[Location]() - var c = 0 - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { lsx += ls(i) } - } - lsx - } - case (e:WSubAccess) => { - val ls = get_locations(e.exp) - val stride = get_size(tpe(e)) - val wrap = tpe(e.exp).asInstanceOf[VectorType].size - val lsx = ArrayBuffer[Location]() - var c = 0 - for (i <- 0 until ls.size) { - if ((c % wrap) == 0) { c = 0 } - val basex = ls(i).base - val guardx = AND(ls(i).guard,EQV(uint(c),e.index)) - lsx += Location(basex,guardx) - if ((i + 1) % stride == 0) { - c = c + 1 - } - } - lsx - } - } - } - def has_access (e:Expression) : Boolean = { - var ret:Boolean = false - def rec_has_access (e:Expression) : Expression = { - e match { - case (e:WSubAccess) => { ret = true; e } - case (e) => e map (rec_has_access) - } - } - rec_has_access(e) - ret - } - def run (c:Circuit): Circuit = { - def remove_m (m:Module) : Module = { - val namespace = Namespace(m) - mname = m.name - def remove_s (s:Statement) : Statement = { - val stmts = ArrayBuffer[Statement]() - def create_temp (e:Expression) : Expression = { - val n = namespace.newTemp - stmts += DefWire(info(s),n,tpe(e)) - WRef(n,tpe(e),kind(e),gender(e)) - } - def remove_e (e:Expression) : Expression = { //NOT RECURSIVE (except primops) INTENTIONALLY! - e match { - case (e:DoPrim) => e map (remove_e) - case (e:Mux) => e map (remove_e) - case (e:ValidIf) => e map (remove_e) - case (e:SIntLiteral) => e - case (e:UIntLiteral) => e - case x => { - val e = x match { - case (w:WSubAccess) => WSubAccess(w.exp,remove_e(w.index),w.tpe,w.gender) - case _ => x - } - if (has_access(e)) { - val rs = get_locations(e) - val foo = rs.find(x => {x.guard != one}) - foo match { - case None => error("Shouldn't be here") - case foo:Some[Location] => { - val temp = create_temp(e) - val temps = create_exps(temp) - def get_temp (i:Int) = temps(i % temps.size) - (rs,0 until rs.size).zipped.foreach { - (x,i) => { - if (i < temps.size) { - stmts += Connect(info(s),get_temp(i),x.base) - } else { - stmts += Conditionally(info(s),x.guard,Connect(info(s),get_temp(i),x.base),EmptyStmt) - } - } - } - temp - } - } - } else { e} - } - } - } - - val sx = s match { - case (s:Connect) => { - if (has_access(s.loc)) { - val ls = get_locations(s.loc) - val locx = - if (ls.size == 1 & weq(ls(0).guard,one)) s.loc - else { - val temp = create_temp(s.loc) - for (x <- ls) { stmts += Conditionally(s.info,x.guard,Connect(s.info,x.base,temp),EmptyStmt) } - temp - } - Connect(s.info,locx,remove_e(s.expr)) - } else { Connect(s.info,s.loc,remove_e(s.expr)) } - } - case (s) => s map (remove_e) map (remove_s) - } - stmts += sx - if (stmts.size != 1) Block(stmts) else stmts(0) - } - Module(m.info,m.name,m.ports,remove_s(m.body)) - } - - val modulesx = c.modules.map{ - m => { - m match { - case (m:ExtModule) => m - case (m:Module) => remove_m(m) - } - } - } - Circuit(c.info,modulesx,c.main) - } -} // Replace shr by amount >= arg width with 0 for UInts and MSB for SInts // TODO replace UInt with zero-width wire instead diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala new file mode 100644 index 00000000..0309e7a7 --- /dev/null +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -0,0 +1,166 @@ +package firrtl.passes + +import firrtl.ir._ +import firrtl.{WRef, WSubAccess, WSubIndex, WSubField} +import firrtl.Mappers._ +import firrtl.Utils._ +import firrtl.WrappedExpression._ +import firrtl.Namespace +import scala.collection.mutable + + +/** Removes all [[firrtl.WSubAccess]] from circuit + */ +object RemoveAccesses extends Pass { + def name = "Remove Accesses" + + /** Container for a base expression and its corresponding guard + */ + case class Location(base: Expression, guard: Expression) + + /** Walks a referencing expression and returns a list of valid references + * (base) and the corresponding guard which, if true, returns that base. + * E.g. if called on a[i] where a: UInt[2], we would return: + * Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1))) + */ + def getLocations(e: Expression): Seq[Location] = e match { + case e: WRef => create_exps(e).map(Location(_,one)) + case e: WSubIndex => + val ls = getLocations(e.exp) + val start = get_point(e) + val end = start + get_size(tpe(e)) + val stride = get_size(tpe(e.exp)) + val lsx = mutable.ArrayBuffer[Location]() + var c = 0 + for (i <- 0 until ls.size) { + if (((i % stride) >= start) & ((i % stride) < end)) { + lsx += ls(i) + } + } + lsx + case e: WSubField => + val ls = getLocations(e.exp) + val start = get_point(e) + val end = start + get_size(tpe(e)) + val stride = get_size(tpe(e.exp)) + val lsx = mutable.ArrayBuffer[Location]() + var c = 0 + for (i <- 0 until ls.size) { + if (((i % stride) >= start) & ((i % stride) < end)) { lsx += ls(i) } + } + lsx + case e: WSubAccess => + val ls = getLocations(e.exp) + val stride = get_size(tpe(e)) + val wrap = tpe(e.exp).asInstanceOf[VectorType].size + val lsx = mutable.ArrayBuffer[Location]() + var c = 0 + for (i <- 0 until ls.size) { + if ((c % wrap) == 0) { c = 0 } + val basex = ls(i).base + val guardx = AND(ls(i).guard,EQV(uint(c),e.index)) + lsx += Location(basex,guardx) + if ((i + 1) % stride == 0) { + c = c + 1 + } + } + lsx + } + /** Returns true if e contains a [[firrtl.WSubAccess]] + */ + def hasAccess(e: Expression): Boolean = { + var ret: Boolean = false + def rec_has_access(e: Expression): Expression = e match { + case (e:WSubAccess) => { ret = true; e } + case (e) => e map (rec_has_access) + } + rec_has_access(e) + ret + } + def run(c: Circuit): Circuit = { + def remove_m(m: Module): Module = { + val namespace = Namespace(m) + def onStmt(s: Statement): Statement = { + val stmts = mutable.ArrayBuffer[Statement]() + def create_temp(e: Expression): Expression = { + val n = namespace.newTemp + stmts += DefWire(info(s), n, tpe(e)) + WRef(n, tpe(e), kind(e), gender(e)) + } + + /** Replaces a subaccess in a given male expression + */ + def removeMale(e: Expression): Expression = e match { + case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if (hasAccess(e)) => + val rs = getLocations(e) + val foo = rs.find(x => {x.guard != one}) + foo match { + case None => error("Shouldn't be here") + case foo: Some[Location] => + val temp = create_temp(e) + val temps = create_exps(temp) + def getTemp(i: Int) = temps(i % temps.size) + (rs,0 until rs.size).zipped.foreach { (x,i) => + if (i < temps.size) { + stmts += Connect(info(s),getTemp(i),x.base) + } else { + stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt) + } + } + temp + } + case _ => e + } + + /** Replaces a subaccess in a given female expression + */ + def removeFemale(info: Info, loc: Expression): Expression = loc match { + case (_: WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if (hasAccess(loc)) => + val ls = getLocations(loc) + if (ls.size == 1 & weq(ls(0).guard,one)) loc + else { + val temp = create_temp(loc) + for (x <- ls) { stmts += Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt) } + temp + } + case _ => loc + } + + /** Recursively walks a male expression and fixes all subaccesses + * If we see a sub-access, replace it. + * Otherwise, map to children. + */ + def fixMale(e: Expression): Expression = e match { + case w: WSubAccess => removeMale(WSubAccess(w.exp, fixMale(w.index), w.tpe, w.gender)) + //case w: WSubIndex => removeMale(w) + //case w: WSubField => removeMale(w) + case x => x map fixMale + } + + /** Recursively walks a female expression and fixes all subaccesses + * If we see a sub-access, its index is a male expression, and we must replace it. + * Otherwise, map to children. + */ + def fixFemale(e: Expression): Expression = e match { + case w: WSubAccess => WSubAccess(fixFemale(w.exp), fixMale(w.index), w.tpe, w.gender) + case x => x map fixFemale + } + + val sx = s match { + case Connect(info, loc, exp) => + Connect(info, removeFemale(info, fixFemale(loc)), fixMale(exp)) + case (s) => s map (fixMale) map (onStmt) + } + stmts += sx + if (stmts.size != 1) Block(stmts) else stmts(0) + } + Module(m.info, m.name, m.ports, onStmt(m.body)) + } + + val newModules = c.modules.map( _ match { + case m: ExtModule => m + case m: Module => remove_m(m) + }) + Circuit(c.info, newModules, c.main) + } +} diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index ead55755..bc8db897 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -194,4 +194,42 @@ class UnitTests extends FirrtlFlatSpec { val check = Seq("c <= mux(pred, a, pad(b, 32))") executeTest(input, check, passes) } + "Indexes into sub-accesses" should "be dealt with" in { + val passes = Seq( + ToWorkingIR, + ResolveKinds, + InferTypes, + ResolveGenders, + InferWidths, + PullMuxes, + ExpandConnects, + RemoveAccesses + ) + val input = + """circuit AssignViaDeref : + | module AssignViaDeref : + | input clk : Clock + | input reset : UInt<1> + | output io : {a : UInt<8>, sel : UInt<1>} + | + | io is invalid + | reg table : {a : UInt<8>}[2], clk + | reg otherTable : {a : UInt<8>}[2], clk + | otherTable[table[UInt<1>("h01")].a].a <= UInt<1>("h00")""".stripMargin + //TODO(azidar): I realize this is brittle, but unfortunately there + // isn't a better way to test this pass + val check = Seq( + """wire GEN_0 : { a : UInt<8>}""", + """GEN_0.a <= table[0].a""", + """when eq(UInt<1>("h1"), UInt<1>("h1")) :""", + """GEN_0.a <= table[1].a""", + """wire GEN_1 : UInt<8>""", + """when eq(UInt<1>("h0"), GEN_0.a) :""", + """otherTable[0].a <= GEN_1""", + """when eq(UInt<1>("h1"), GEN_0.a) :""", + """otherTable[1].a <= GEN_1""", + """GEN_1 <= UInt<1>("h0")""" + ) + executeTest(input, check, passes) + } } |
