diff options
| author | Jack Koenig | 2016-09-07 11:24:08 -0700 |
|---|---|---|
| committer | GitHub | 2016-09-07 11:24:08 -0700 |
| commit | 0c6db9ef0669e3fb92fcc0bda2085f934d065f0b (patch) | |
| tree | cfff6e46fad44cc0c20eb079863b2a0d6d4aa993 /src/main/scala/firrtl/passes | |
| parent | 6a05468ed0ece1ace3019666b16f2ae83ef76ef9 (diff) | |
| parent | 6255d5e398ae21dbc75db907bb9a9b24bc09d2b3 (diff) | |
Merge pull request #256 from ucb-bar/fix_boom_errors
Fix performance bug with remove accesses
Diffstat (limited to 'src/main/scala/firrtl/passes')
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 32 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 97 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ReplaceSubAccess.scala | 32 |
3 files changed, 94 insertions, 67 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 1b6c76f4..7b4f9aa2 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -1061,24 +1061,20 @@ case class DataRef( val exp : Expression, val male : String, val female : String object RemoveCHIRRTL extends Pass { def name = "Remove CHIRRTL" var mname = "" - def create_exps (e:Expression) : Seq[Expression] = { - (e) match { - case (e:Mux)=> - (create_exps(e.tval),create_exps(e.fval)).zipped.map((e1,e2) => { - Mux(e.cond,e1,e2,mux_type(e1,e2)) - }) - case (e:ValidIf) => - create_exps(e.value).map(e1 => { - ValidIf(e.cond,e1,tpe(e1)) - }) - case (e) => (tpe(e)) match { - case (_:UIntType|_:SIntType|ClockType) => Seq(e) - case (t:BundleType) => - t.fields.flatMap(f => create_exps(SubField(e,f.name,f.tpe))) - case (t:VectorType)=> - (0 until t.size).flatMap(i => create_exps(SubIndex(e,i,t.tpe))) - case UnknownType => Seq(e) - } + def create_exps (e:Expression) : Seq[Expression] = e match { + case (e:Mux) => + val e1s = create_exps(e.tval) + val e2s = create_exps(e.fval) + (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2))) + case (e:ValidIf) => + create_exps(e.value) map (e1 => ValidIf(e.cond,e1,tpe(e1))) + case (e) => (tpe(e)) match { + case (_:GroundType) => Seq(e) + case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => + exps ++ create_exps(SubField(e,f.name,f.tpe))) + case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => + exps ++ create_exps(SubIndex(e,i,t.tpe))) + case UnknownType => Seq(e) } } def run (c:Circuit) : Circuit = { diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index d3340f2d..a3ce49f7 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -13,93 +13,90 @@ import scala.collection.mutable */ object RemoveAccesses extends Pass { def name = "Remove Accesses" - /** Container for a base expression and its corresponding guard */ - case class Location(base: Expression, guard: Expression) + private case class Location(base: Expression, guard: Expression) /** Walks a referencing expression and returns a list of valid references * (base) and the corresponding guard which, if true, returns that base. * E.g. if called on a[i] where a: UInt[2], we would return: * Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1))) */ - def getLocations(e: Expression): Seq[Location] = e match { + private def getLocations(e: Expression): Seq[Location] = e match { case e: WRef => create_exps(e).map(Location(_,one)) case e: WSubIndex => val ls = getLocations(e.exp) val start = get_point(e) - val end = start + get_size(tpe(e)) - val stride = get_size(tpe(e.exp)) - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { - lsx += ls(i) - } - } - lsx + val end = start + get_size(e.tpe) + val stride = get_size(e.exp.tpe) + for ((l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end)) yield l case e: WSubField => val ls = getLocations(e.exp) val start = get_point(e) - val end = start + get_size(tpe(e)) - val stride = get_size(tpe(e.exp)) - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { lsx += ls(i) } - } - lsx + val end = start + get_size(e.tpe) + val stride = get_size(e.exp.tpe) + for ((l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end)) yield l case e: WSubAccess => val ls = getLocations(e.exp) - val stride = get_size(tpe(e)) - val wrap = tpe(e.exp).asInstanceOf[VectorType].size - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { + val stride = get_size(e.tpe) + val wrap = e.exp.tpe.asInstanceOf[VectorType].size + ls.zipWithIndex map {case (l, i) => val c = (i / stride) % wrap - val basex = ls(i).base - val guardx = AND(ls(i).guard,EQV(uint(c),e.index)) - lsx += Location(basex,guardx) + val basex = l.base + val guardx = AND(l.guard,EQV(uint(c),e.index)) + Location(basex,guardx) } - lsx } + /** Returns true if e contains a [[firrtl.WSubAccess]] */ - def hasAccess(e: Expression): Boolean = { - var ret: Boolean = false - def rec_has_access(e: Expression): Expression = e match { - case (e:WSubAccess) => { ret = true; e } - case (e) => e map (rec_has_access) + private def hasAccess(e: Expression): Boolean = { + var ret: Boolean = false + def rec_has_access(e: Expression): Expression = { + e match { + case e : WSubAccess => ret = true + case e => + } + e map rec_has_access } rec_has_access(e) ret } + + // This improves the performance of this pass + private val createExpsCache = mutable.HashMap[Expression, Seq[Expression]]() + private def create_exps(e: Expression) = + createExpsCache getOrElseUpdate (e, firrtl.Utils.create_exps(e)) + def run(c: Circuit): Circuit = { def remove_m(m: Module): Module = { val namespace = Namespace(m) def onStmt(s: Statement): Statement = { - val stmts = mutable.ArrayBuffer[Statement]() - def create_temp(e: Expression): Expression = { + def create_temp(e: Expression): (Statement, Expression) = { val n = namespace.newTemp - stmts += DefWire(info(s), n, tpe(e)) - WRef(n, tpe(e), kind(e), gender(e)) + (DefWire(info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e))) } /** Replaces a subaccess in a given male expression */ + val stmts = mutable.ArrayBuffer[Statement]() def removeMale(e: Expression): Expression = e match { case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if (hasAccess(e)) => val rs = getLocations(e) - val foo = rs.find(x => {x.guard != one}) - foo match { + rs find (x => x.guard != one) match { case None => error("Shouldn't be here") - case foo: Some[Location] => - val temp = create_temp(e) + case Some(_) => + val (wire, temp) = create_temp(e) val temps = create_exps(temp) def getTemp(i: Int) = temps(i % temps.size) - for((x, i) <- rs.zipWithIndex) { - if (i < temps.size) { + stmts += wire + rs.zipWithIndex foreach { + case (x, i) if i < temps.size => stmts += Connect(info(s),getTemp(i),x.base) - } else { + case (x, i) => stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt) - } } temp } @@ -113,8 +110,10 @@ object RemoveAccesses extends Pass { val ls = getLocations(loc) if (ls.size == 1 & weq(ls(0).guard,one)) loc else { - val temp = create_temp(loc) - for (x <- ls) { stmts += Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt) } + val (wire, temp) = create_temp(loc) + stmts += wire + ls foreach (x => stmts += + Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt)) temp } case _ => loc @@ -148,13 +147,13 @@ object RemoveAccesses extends Pass { stmts += sx if (stmts.size != 1) Block(stmts) else stmts(0) } - Module(m.info, m.name, m.ports, onStmt(m.body)) + Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body))) } - val newModules = c.modules.map( _ match { + val newModules = c.modules.map { case m: ExtModule => m case m: Module => remove_m(m) - }) + } Circuit(c.info, newModules, c.main) } } diff --git a/src/main/scala/firrtl/passes/ReplaceSubAccess.scala b/src/main/scala/firrtl/passes/ReplaceSubAccess.scala new file mode 100644 index 00000000..8e911a96 --- /dev/null +++ b/src/main/scala/firrtl/passes/ReplaceSubAccess.scala @@ -0,0 +1,32 @@ +package firrtl.passes + +import firrtl.ir._ +import firrtl.{WRef, WSubAccess, WSubIndex, WSubField} +import firrtl.Mappers._ +import firrtl.Utils._ +import firrtl.WrappedExpression._ +import firrtl.Namespace +import scala.collection.mutable + + +/** Replaces constant [[firrtl.WSubAccess]] with [[firrtl.WSubIndex]] + * TODO Fold in to High Firrtl Const Prop + */ +object ReplaceAccesses extends Pass { + def name = "Replace Accesses" + + def run(c: Circuit): Circuit = { + def onStmt(s: Statement): Statement = s map onStmt map onExp + def onExp(e: Expression): Expression = e match { + case WSubAccess(e, UIntLiteral(value, width), t, g) => WSubIndex(e, value.toInt, t, g) + case e => e map onExp + } + + val newModules = c.modules map { + case m: ExtModule => m + case Module(i, n, ps, b) => Module(i, n, ps, onStmt(b)) + } + + Circuit(c.info, newModules, c.main) + } +} |
