diff options
| author | Jack Koenig | 2016-09-07 11:24:08 -0700 |
|---|---|---|
| committer | GitHub | 2016-09-07 11:24:08 -0700 |
| commit | 0c6db9ef0669e3fb92fcc0bda2085f934d065f0b (patch) | |
| tree | cfff6e46fad44cc0c20eb079863b2a0d6d4aa993 /src | |
| parent | 6a05468ed0ece1ace3019666b16f2ae83ef76ef9 (diff) | |
| parent | 6255d5e398ae21dbc75db907bb9a9b24bc09d2b3 (diff) | |
Merge pull request #256 from ucb-bar/fix_boom_errors
Fix performance bug with remove accesses
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Namespace.scala | 22 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 33 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 27 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 32 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 97 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ReplaceSubAccess.scala | 32 |
7 files changed, 125 insertions, 119 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index f9a5864c..c8430d2b 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -104,6 +104,7 @@ class ResolveAndCheck () extends Transform with SimpleRun { class HighFirrtlToMiddleFirrtl () extends Transform with SimpleRun { val passSeq = Seq( passes.PullMuxes, + passes.ReplaceAccesses, passes.ExpandConnects, passes.RemoveAccesses, passes.ExpandWhens, diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala index e7a1cd10..952670cf 100644 --- a/src/main/scala/firrtl/Namespace.scala +++ b/src/main/scala/firrtl/Namespace.scala @@ -63,22 +63,16 @@ object Namespace { def apply(m: DefModule): Namespace = { val namespace = new Namespace - def buildNamespaceStmt(s: Statement): Statement = - s map buildNamespaceStmt match { - case dec: IsDeclaration => - namespace.namespace += dec.name - dec - case x => x - } - def buildNamespacePort(p: Port): Port = p match { - case dec: IsDeclaration => - namespace.namespace += dec.name - dec - case x => x + def buildNamespaceStmt(s: Statement): Seq[String] = s match { + case s: IsDeclaration => Seq(s.name) + case s: Conditionally => buildNamespaceStmt(s.conseq) ++ buildNamespaceStmt(s.alt) + case s: Block => s.stmts flatMap buildNamespaceStmt + case _ => Nil } - m.ports map buildNamespacePort + namespace.namespace ++= (m.ports collect { case dec: IsDeclaration => dec.name }) m match { - case in: Module => buildNamespaceStmt(in.body) + case in: Module => + namespace.namespace ++= buildNamespaceStmt(in.body) case _ => // Do nothing } diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index e6db4b2d..1db8ce78 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -151,27 +151,18 @@ object Utils extends LazyLogging { } def create_exps (n:String, t:Type) : Seq[Expression] = create_exps(WRef(n,t,ExpKind(),UNKNOWNGENDER)) - def create_exps (e:Expression) : Seq[Expression] = { - e match { - case (e:Mux) => { - val e1s = create_exps(e.tval) - val e2s = create_exps(e.fval) - (e1s, e2s).zipped.map { (e1,e2) => Mux(e.cond,e1,e2,mux_type_and_widths(e1,e2)) } - } - case (e:ValidIf) => create_exps(e.value).map { e1 => ValidIf(e.cond,e1,tpe(e1)) } - case (e) => { - tpe(e) match { - case (t:UIntType) => Seq(e) - case (t:SIntType) => Seq(e) - case ClockType => Seq(e) - case (t:BundleType) => { - t.fields.flatMap { f => create_exps(WSubField(e,f.name,f.tpe,times(gender(e), f.flip))) } - } - case (t:VectorType) => { - (0 until t.size).flatMap { i => create_exps(WSubIndex(e,i,t.tpe,gender(e))) } - } - } - } + def create_exps (e:Expression) : Seq[Expression] = e match { + case (e:Mux) => + val e1s = create_exps(e.tval) + val e2s = create_exps(e.fval) + (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type_and_widths(e1,e2))) + case (e:ValidIf) => create_exps(e.value) map (e1 => ValidIf(e.cond,e1,tpe(e1))) + case (e) => tpe(e) match { + case (_:GroundType) => Seq(e) + case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => + exps ++ create_exps(WSubField(e,f.name,f.tpe,times(gender(e), f.flip)))) + case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => + exps ++ create_exps(WSubIndex(e,i,t.tpe,gender(e)))) } } def get_flip (t:Type, i:Int, f:Orientation) : Orientation = { diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 4ed639da..eddd723b 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -105,12 +105,9 @@ class WrappedExpression (val e1:Expression) { case (e1:WSubAccess,e2:WSubAccess) => weq(e1.index,e2.index) && weq(e1.exp,e2.exp) case (e1:WVoid,e2:WVoid) => true case (e1:WInvalid,e2:WInvalid) => true - case (e1:DoPrim,e2:DoPrim) => { - var are_equal = e1.op == e2.op - (e1.args,e2.args).zipped.foreach{ (x,y) => { if (!weq(x,y)) are_equal = false }} - (e1.consts,e2.consts).zipped.foreach{ (x,y) => { if (x != y) are_equal = false }} - are_equal - } + case (e1:DoPrim,e2:DoPrim) => e1.op == e2.op && + ((e1.consts zip e2.consts) forall {case (x, y) => x == y}) && + ((e1.args zip e2.args) forall {case (x, y) => weq(x, y)}) case (e1:Mux,e2:Mux) => weq(e1.cond,e2.cond) && weq(e1.tval,e2.tval) && weq(e1.fval,e2.fval) case (e1:ValidIf,e2:ValidIf) => weq(e1.cond,e2.cond) && weq(e1.value,e2.value) case (e1,e2) => false @@ -156,17 +153,13 @@ class WrappedType (val t:Type) { case (t1:UIntType,t2:UIntType) => true case (t1:SIntType,t2:SIntType) => true case (ClockType, ClockType) => true - case (t1:VectorType,t2:VectorType) => (wt(t1.tpe) == wt(t2.tpe) && t1.size == t2.size) - case (t1:BundleType,t2:BundleType) => { - var ret = true - (t1.fields,t2.fields).zipped.foreach{ (f1,f2) => { - if (f1.flip != f2.flip) ret = false - if (f1.name != f2.name) ret = false - if (wt(f1.tpe) != wt(f2.tpe)) ret = false - }} - if (t1.fields.size != t2.fields.size) ret = false - ret - } + case (t1:VectorType,t2:VectorType) => + t1.size == t2.size && wt(t1.tpe) == wt(t2.tpe) + case (t1:BundleType,t2:BundleType) => + t1.fields.size == t2.fields.size && ( + (t1.fields zip t2.fields) forall {case (f1, f2) => + f1.flip == f2.flip && f1.name == f2.name && wt(f1.tpe) == wt(f2.tpe) + }) case (t1,t2) => false } } diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 1b6c76f4..7b4f9aa2 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -1061,24 +1061,20 @@ case class DataRef( val exp : Expression, val male : String, val female : String object RemoveCHIRRTL extends Pass { def name = "Remove CHIRRTL" var mname = "" - def create_exps (e:Expression) : Seq[Expression] = { - (e) match { - case (e:Mux)=> - (create_exps(e.tval),create_exps(e.fval)).zipped.map((e1,e2) => { - Mux(e.cond,e1,e2,mux_type(e1,e2)) - }) - case (e:ValidIf) => - create_exps(e.value).map(e1 => { - ValidIf(e.cond,e1,tpe(e1)) - }) - case (e) => (tpe(e)) match { - case (_:UIntType|_:SIntType|ClockType) => Seq(e) - case (t:BundleType) => - t.fields.flatMap(f => create_exps(SubField(e,f.name,f.tpe))) - case (t:VectorType)=> - (0 until t.size).flatMap(i => create_exps(SubIndex(e,i,t.tpe))) - case UnknownType => Seq(e) - } + def create_exps (e:Expression) : Seq[Expression] = e match { + case (e:Mux) => + val e1s = create_exps(e.tval) + val e2s = create_exps(e.fval) + (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2))) + case (e:ValidIf) => + create_exps(e.value) map (e1 => ValidIf(e.cond,e1,tpe(e1))) + case (e) => (tpe(e)) match { + case (_:GroundType) => Seq(e) + case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => + exps ++ create_exps(SubField(e,f.name,f.tpe))) + case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => + exps ++ create_exps(SubIndex(e,i,t.tpe))) + case UnknownType => Seq(e) } } def run (c:Circuit) : Circuit = { diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index d3340f2d..a3ce49f7 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -13,93 +13,90 @@ import scala.collection.mutable */ object RemoveAccesses extends Pass { def name = "Remove Accesses" - /** Container for a base expression and its corresponding guard */ - case class Location(base: Expression, guard: Expression) + private case class Location(base: Expression, guard: Expression) /** Walks a referencing expression and returns a list of valid references * (base) and the corresponding guard which, if true, returns that base. * E.g. if called on a[i] where a: UInt[2], we would return: * Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1))) */ - def getLocations(e: Expression): Seq[Location] = e match { + private def getLocations(e: Expression): Seq[Location] = e match { case e: WRef => create_exps(e).map(Location(_,one)) case e: WSubIndex => val ls = getLocations(e.exp) val start = get_point(e) - val end = start + get_size(tpe(e)) - val stride = get_size(tpe(e.exp)) - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { - lsx += ls(i) - } - } - lsx + val end = start + get_size(e.tpe) + val stride = get_size(e.exp.tpe) + for ((l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end)) yield l case e: WSubField => val ls = getLocations(e.exp) val start = get_point(e) - val end = start + get_size(tpe(e)) - val stride = get_size(tpe(e.exp)) - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { lsx += ls(i) } - } - lsx + val end = start + get_size(e.tpe) + val stride = get_size(e.exp.tpe) + for ((l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end)) yield l case e: WSubAccess => val ls = getLocations(e.exp) - val stride = get_size(tpe(e)) - val wrap = tpe(e.exp).asInstanceOf[VectorType].size - val lsx = mutable.ArrayBuffer[Location]() - for (i <- 0 until ls.size) { + val stride = get_size(e.tpe) + val wrap = e.exp.tpe.asInstanceOf[VectorType].size + ls.zipWithIndex map {case (l, i) => val c = (i / stride) % wrap - val basex = ls(i).base - val guardx = AND(ls(i).guard,EQV(uint(c),e.index)) - lsx += Location(basex,guardx) + val basex = l.base + val guardx = AND(l.guard,EQV(uint(c),e.index)) + Location(basex,guardx) } - lsx } + /** Returns true if e contains a [[firrtl.WSubAccess]] */ - def hasAccess(e: Expression): Boolean = { - var ret: Boolean = false - def rec_has_access(e: Expression): Expression = e match { - case (e:WSubAccess) => { ret = true; e } - case (e) => e map (rec_has_access) + private def hasAccess(e: Expression): Boolean = { + var ret: Boolean = false + def rec_has_access(e: Expression): Expression = { + e match { + case e : WSubAccess => ret = true + case e => + } + e map rec_has_access } rec_has_access(e) ret } + + // This improves the performance of this pass + private val createExpsCache = mutable.HashMap[Expression, Seq[Expression]]() + private def create_exps(e: Expression) = + createExpsCache getOrElseUpdate (e, firrtl.Utils.create_exps(e)) + def run(c: Circuit): Circuit = { def remove_m(m: Module): Module = { val namespace = Namespace(m) def onStmt(s: Statement): Statement = { - val stmts = mutable.ArrayBuffer[Statement]() - def create_temp(e: Expression): Expression = { + def create_temp(e: Expression): (Statement, Expression) = { val n = namespace.newTemp - stmts += DefWire(info(s), n, tpe(e)) - WRef(n, tpe(e), kind(e), gender(e)) + (DefWire(info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e))) } /** Replaces a subaccess in a given male expression */ + val stmts = mutable.ArrayBuffer[Statement]() def removeMale(e: Expression): Expression = e match { case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if (hasAccess(e)) => val rs = getLocations(e) - val foo = rs.find(x => {x.guard != one}) - foo match { + rs find (x => x.guard != one) match { case None => error("Shouldn't be here") - case foo: Some[Location] => - val temp = create_temp(e) + case Some(_) => + val (wire, temp) = create_temp(e) val temps = create_exps(temp) def getTemp(i: Int) = temps(i % temps.size) - for((x, i) <- rs.zipWithIndex) { - if (i < temps.size) { + stmts += wire + rs.zipWithIndex foreach { + case (x, i) if i < temps.size => stmts += Connect(info(s),getTemp(i),x.base) - } else { + case (x, i) => stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt) - } } temp } @@ -113,8 +110,10 @@ object RemoveAccesses extends Pass { val ls = getLocations(loc) if (ls.size == 1 & weq(ls(0).guard,one)) loc else { - val temp = create_temp(loc) - for (x <- ls) { stmts += Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt) } + val (wire, temp) = create_temp(loc) + stmts += wire + ls foreach (x => stmts += + Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt)) temp } case _ => loc @@ -148,13 +147,13 @@ object RemoveAccesses extends Pass { stmts += sx if (stmts.size != 1) Block(stmts) else stmts(0) } - Module(m.info, m.name, m.ports, onStmt(m.body)) + Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body))) } - val newModules = c.modules.map( _ match { + val newModules = c.modules.map { case m: ExtModule => m case m: Module => remove_m(m) - }) + } Circuit(c.info, newModules, c.main) } } diff --git a/src/main/scala/firrtl/passes/ReplaceSubAccess.scala b/src/main/scala/firrtl/passes/ReplaceSubAccess.scala new file mode 100644 index 00000000..8e911a96 --- /dev/null +++ b/src/main/scala/firrtl/passes/ReplaceSubAccess.scala @@ -0,0 +1,32 @@ +package firrtl.passes + +import firrtl.ir._ +import firrtl.{WRef, WSubAccess, WSubIndex, WSubField} +import firrtl.Mappers._ +import firrtl.Utils._ +import firrtl.WrappedExpression._ +import firrtl.Namespace +import scala.collection.mutable + + +/** Replaces constant [[firrtl.WSubAccess]] with [[firrtl.WSubIndex]] + * TODO Fold in to High Firrtl Const Prop + */ +object ReplaceAccesses extends Pass { + def name = "Replace Accesses" + + def run(c: Circuit): Circuit = { + def onStmt(s: Statement): Statement = s map onStmt map onExp + def onExp(e: Expression): Expression = e match { + case WSubAccess(e, UIntLiteral(value, width), t, g) => WSubIndex(e, value.toInt, t, g) + case e => e map onExp + } + + val newModules = c.modules map { + case m: ExtModule => m + case Module(i, n, ps, b) => Module(i, n, ps, onStmt(b)) + } + + Circuit(c.info, newModules, c.main) + } +} |
