aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2016-09-07 11:24:08 -0700
committerGitHub2016-09-07 11:24:08 -0700
commit0c6db9ef0669e3fb92fcc0bda2085f934d065f0b (patch)
treecfff6e46fad44cc0c20eb079863b2a0d6d4aa993 /src
parent6a05468ed0ece1ace3019666b16f2ae83ef76ef9 (diff)
parent6255d5e398ae21dbc75db907bb9a9b24bc09d2b3 (diff)
Merge pull request #256 from ucb-bar/fix_boom_errors
Fix performance bug with remove accesses
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala1
-rw-r--r--src/main/scala/firrtl/Namespace.scala22
-rw-r--r--src/main/scala/firrtl/Utils.scala33
-rw-r--r--src/main/scala/firrtl/WIR.scala27
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala32
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala97
-rw-r--r--src/main/scala/firrtl/passes/ReplaceSubAccess.scala32
7 files changed, 125 insertions, 119 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index f9a5864c..c8430d2b 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -104,6 +104,7 @@ class ResolveAndCheck () extends Transform with SimpleRun {
class HighFirrtlToMiddleFirrtl () extends Transform with SimpleRun {
val passSeq = Seq(
passes.PullMuxes,
+ passes.ReplaceAccesses,
passes.ExpandConnects,
passes.RemoveAccesses,
passes.ExpandWhens,
diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala
index e7a1cd10..952670cf 100644
--- a/src/main/scala/firrtl/Namespace.scala
+++ b/src/main/scala/firrtl/Namespace.scala
@@ -63,22 +63,16 @@ object Namespace {
def apply(m: DefModule): Namespace = {
val namespace = new Namespace
- def buildNamespaceStmt(s: Statement): Statement =
- s map buildNamespaceStmt match {
- case dec: IsDeclaration =>
- namespace.namespace += dec.name
- dec
- case x => x
- }
- def buildNamespacePort(p: Port): Port = p match {
- case dec: IsDeclaration =>
- namespace.namespace += dec.name
- dec
- case x => x
+ def buildNamespaceStmt(s: Statement): Seq[String] = s match {
+ case s: IsDeclaration => Seq(s.name)
+ case s: Conditionally => buildNamespaceStmt(s.conseq) ++ buildNamespaceStmt(s.alt)
+ case s: Block => s.stmts flatMap buildNamespaceStmt
+ case _ => Nil
}
- m.ports map buildNamespacePort
+ namespace.namespace ++= (m.ports collect { case dec: IsDeclaration => dec.name })
m match {
- case in: Module => buildNamespaceStmt(in.body)
+ case in: Module =>
+ namespace.namespace ++= buildNamespaceStmt(in.body)
case _ => // Do nothing
}
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index e6db4b2d..1db8ce78 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -151,27 +151,18 @@ object Utils extends LazyLogging {
}
def create_exps (n:String, t:Type) : Seq[Expression] =
create_exps(WRef(n,t,ExpKind(),UNKNOWNGENDER))
- def create_exps (e:Expression) : Seq[Expression] = {
- e match {
- case (e:Mux) => {
- val e1s = create_exps(e.tval)
- val e2s = create_exps(e.fval)
- (e1s, e2s).zipped.map { (e1,e2) => Mux(e.cond,e1,e2,mux_type_and_widths(e1,e2)) }
- }
- case (e:ValidIf) => create_exps(e.value).map { e1 => ValidIf(e.cond,e1,tpe(e1)) }
- case (e) => {
- tpe(e) match {
- case (t:UIntType) => Seq(e)
- case (t:SIntType) => Seq(e)
- case ClockType => Seq(e)
- case (t:BundleType) => {
- t.fields.flatMap { f => create_exps(WSubField(e,f.name,f.tpe,times(gender(e), f.flip))) }
- }
- case (t:VectorType) => {
- (0 until t.size).flatMap { i => create_exps(WSubIndex(e,i,t.tpe,gender(e))) }
- }
- }
- }
+ def create_exps (e:Expression) : Seq[Expression] = e match {
+ case (e:Mux) =>
+ val e1s = create_exps(e.tval)
+ val e2s = create_exps(e.fval)
+ (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type_and_widths(e1,e2)))
+ case (e:ValidIf) => create_exps(e.value) map (e1 => ValidIf(e.cond,e1,tpe(e1)))
+ case (e) => tpe(e) match {
+ case (_:GroundType) => Seq(e)
+ case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) =>
+ exps ++ create_exps(WSubField(e,f.name,f.tpe,times(gender(e), f.flip))))
+ case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
+ exps ++ create_exps(WSubIndex(e,i,t.tpe,gender(e))))
}
}
def get_flip (t:Type, i:Int, f:Orientation) : Orientation = {
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index 4ed639da..eddd723b 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -105,12 +105,9 @@ class WrappedExpression (val e1:Expression) {
case (e1:WSubAccess,e2:WSubAccess) => weq(e1.index,e2.index) && weq(e1.exp,e2.exp)
case (e1:WVoid,e2:WVoid) => true
case (e1:WInvalid,e2:WInvalid) => true
- case (e1:DoPrim,e2:DoPrim) => {
- var are_equal = e1.op == e2.op
- (e1.args,e2.args).zipped.foreach{ (x,y) => { if (!weq(x,y)) are_equal = false }}
- (e1.consts,e2.consts).zipped.foreach{ (x,y) => { if (x != y) are_equal = false }}
- are_equal
- }
+ case (e1:DoPrim,e2:DoPrim) => e1.op == e2.op &&
+ ((e1.consts zip e2.consts) forall {case (x, y) => x == y}) &&
+ ((e1.args zip e2.args) forall {case (x, y) => weq(x, y)})
case (e1:Mux,e2:Mux) => weq(e1.cond,e2.cond) && weq(e1.tval,e2.tval) && weq(e1.fval,e2.fval)
case (e1:ValidIf,e2:ValidIf) => weq(e1.cond,e2.cond) && weq(e1.value,e2.value)
case (e1,e2) => false
@@ -156,17 +153,13 @@ class WrappedType (val t:Type) {
case (t1:UIntType,t2:UIntType) => true
case (t1:SIntType,t2:SIntType) => true
case (ClockType, ClockType) => true
- case (t1:VectorType,t2:VectorType) => (wt(t1.tpe) == wt(t2.tpe) && t1.size == t2.size)
- case (t1:BundleType,t2:BundleType) => {
- var ret = true
- (t1.fields,t2.fields).zipped.foreach{ (f1,f2) => {
- if (f1.flip != f2.flip) ret = false
- if (f1.name != f2.name) ret = false
- if (wt(f1.tpe) != wt(f2.tpe)) ret = false
- }}
- if (t1.fields.size != t2.fields.size) ret = false
- ret
- }
+ case (t1:VectorType,t2:VectorType) =>
+ t1.size == t2.size && wt(t1.tpe) == wt(t2.tpe)
+ case (t1:BundleType,t2:BundleType) =>
+ t1.fields.size == t2.fields.size && (
+ (t1.fields zip t2.fields) forall {case (f1, f2) =>
+ f1.flip == f2.flip && f1.name == f2.name && wt(f1.tpe) == wt(f2.tpe)
+ })
case (t1,t2) => false
}
}
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 1b6c76f4..7b4f9aa2 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -1061,24 +1061,20 @@ case class DataRef( val exp : Expression, val male : String, val female : String
object RemoveCHIRRTL extends Pass {
def name = "Remove CHIRRTL"
var mname = ""
- def create_exps (e:Expression) : Seq[Expression] = {
- (e) match {
- case (e:Mux)=>
- (create_exps(e.tval),create_exps(e.fval)).zipped.map((e1,e2) => {
- Mux(e.cond,e1,e2,mux_type(e1,e2))
- })
- case (e:ValidIf) =>
- create_exps(e.value).map(e1 => {
- ValidIf(e.cond,e1,tpe(e1))
- })
- case (e) => (tpe(e)) match {
- case (_:UIntType|_:SIntType|ClockType) => Seq(e)
- case (t:BundleType) =>
- t.fields.flatMap(f => create_exps(SubField(e,f.name,f.tpe)))
- case (t:VectorType)=>
- (0 until t.size).flatMap(i => create_exps(SubIndex(e,i,t.tpe)))
- case UnknownType => Seq(e)
- }
+ def create_exps (e:Expression) : Seq[Expression] = e match {
+ case (e:Mux) =>
+ val e1s = create_exps(e.tval)
+ val e2s = create_exps(e.fval)
+ (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2)))
+ case (e:ValidIf) =>
+ create_exps(e.value) map (e1 => ValidIf(e.cond,e1,tpe(e1)))
+ case (e) => (tpe(e)) match {
+ case (_:GroundType) => Seq(e)
+ case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) =>
+ exps ++ create_exps(SubField(e,f.name,f.tpe)))
+ case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
+ exps ++ create_exps(SubIndex(e,i,t.tpe)))
+ case UnknownType => Seq(e)
}
}
def run (c:Circuit) : Circuit = {
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index d3340f2d..a3ce49f7 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -13,93 +13,90 @@ import scala.collection.mutable
*/
object RemoveAccesses extends Pass {
def name = "Remove Accesses"
-
/** Container for a base expression and its corresponding guard
*/
- case class Location(base: Expression, guard: Expression)
+ private case class Location(base: Expression, guard: Expression)
/** Walks a referencing expression and returns a list of valid references
* (base) and the corresponding guard which, if true, returns that base.
* E.g. if called on a[i] where a: UInt[2], we would return:
* Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1)))
*/
- def getLocations(e: Expression): Seq[Location] = e match {
+ private def getLocations(e: Expression): Seq[Location] = e match {
case e: WRef => create_exps(e).map(Location(_,one))
case e: WSubIndex =>
val ls = getLocations(e.exp)
val start = get_point(e)
- val end = start + get_size(tpe(e))
- val stride = get_size(tpe(e.exp))
- val lsx = mutable.ArrayBuffer[Location]()
- for (i <- 0 until ls.size) {
- if (((i % stride) >= start) & ((i % stride) < end)) {
- lsx += ls(i)
- }
- }
- lsx
+ val end = start + get_size(e.tpe)
+ val stride = get_size(e.exp.tpe)
+ for ((l, i) <- ls.zipWithIndex
+ if ((i % stride) >= start) & ((i % stride) < end)) yield l
case e: WSubField =>
val ls = getLocations(e.exp)
val start = get_point(e)
- val end = start + get_size(tpe(e))
- val stride = get_size(tpe(e.exp))
- val lsx = mutable.ArrayBuffer[Location]()
- for (i <- 0 until ls.size) {
- if (((i % stride) >= start) & ((i % stride) < end)) { lsx += ls(i) }
- }
- lsx
+ val end = start + get_size(e.tpe)
+ val stride = get_size(e.exp.tpe)
+ for ((l, i) <- ls.zipWithIndex
+ if ((i % stride) >= start) & ((i % stride) < end)) yield l
case e: WSubAccess =>
val ls = getLocations(e.exp)
- val stride = get_size(tpe(e))
- val wrap = tpe(e.exp).asInstanceOf[VectorType].size
- val lsx = mutable.ArrayBuffer[Location]()
- for (i <- 0 until ls.size) {
+ val stride = get_size(e.tpe)
+ val wrap = e.exp.tpe.asInstanceOf[VectorType].size
+ ls.zipWithIndex map {case (l, i) =>
val c = (i / stride) % wrap
- val basex = ls(i).base
- val guardx = AND(ls(i).guard,EQV(uint(c),e.index))
- lsx += Location(basex,guardx)
+ val basex = l.base
+ val guardx = AND(l.guard,EQV(uint(c),e.index))
+ Location(basex,guardx)
}
- lsx
}
+
/** Returns true if e contains a [[firrtl.WSubAccess]]
*/
- def hasAccess(e: Expression): Boolean = {
- var ret: Boolean = false
- def rec_has_access(e: Expression): Expression = e match {
- case (e:WSubAccess) => { ret = true; e }
- case (e) => e map (rec_has_access)
+ private def hasAccess(e: Expression): Boolean = {
+ var ret: Boolean = false
+ def rec_has_access(e: Expression): Expression = {
+ e match {
+ case e : WSubAccess => ret = true
+ case e =>
+ }
+ e map rec_has_access
}
rec_has_access(e)
ret
}
+
+ // This improves the performance of this pass
+ private val createExpsCache = mutable.HashMap[Expression, Seq[Expression]]()
+ private def create_exps(e: Expression) =
+ createExpsCache getOrElseUpdate (e, firrtl.Utils.create_exps(e))
+
def run(c: Circuit): Circuit = {
def remove_m(m: Module): Module = {
val namespace = Namespace(m)
def onStmt(s: Statement): Statement = {
- val stmts = mutable.ArrayBuffer[Statement]()
- def create_temp(e: Expression): Expression = {
+ def create_temp(e: Expression): (Statement, Expression) = {
val n = namespace.newTemp
- stmts += DefWire(info(s), n, tpe(e))
- WRef(n, tpe(e), kind(e), gender(e))
+ (DefWire(info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e)))
}
/** Replaces a subaccess in a given male expression
*/
+ val stmts = mutable.ArrayBuffer[Statement]()
def removeMale(e: Expression): Expression = e match {
case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if (hasAccess(e)) =>
val rs = getLocations(e)
- val foo = rs.find(x => {x.guard != one})
- foo match {
+ rs find (x => x.guard != one) match {
case None => error("Shouldn't be here")
- case foo: Some[Location] =>
- val temp = create_temp(e)
+ case Some(_) =>
+ val (wire, temp) = create_temp(e)
val temps = create_exps(temp)
def getTemp(i: Int) = temps(i % temps.size)
- for((x, i) <- rs.zipWithIndex) {
- if (i < temps.size) {
+ stmts += wire
+ rs.zipWithIndex foreach {
+ case (x, i) if i < temps.size =>
stmts += Connect(info(s),getTemp(i),x.base)
- } else {
+ case (x, i) =>
stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt)
- }
}
temp
}
@@ -113,8 +110,10 @@ object RemoveAccesses extends Pass {
val ls = getLocations(loc)
if (ls.size == 1 & weq(ls(0).guard,one)) loc
else {
- val temp = create_temp(loc)
- for (x <- ls) { stmts += Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt) }
+ val (wire, temp) = create_temp(loc)
+ stmts += wire
+ ls foreach (x => stmts +=
+ Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt))
temp
}
case _ => loc
@@ -148,13 +147,13 @@ object RemoveAccesses extends Pass {
stmts += sx
if (stmts.size != 1) Block(stmts) else stmts(0)
}
- Module(m.info, m.name, m.ports, onStmt(m.body))
+ Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body)))
}
- val newModules = c.modules.map( _ match {
+ val newModules = c.modules.map {
case m: ExtModule => m
case m: Module => remove_m(m)
- })
+ }
Circuit(c.info, newModules, c.main)
}
}
diff --git a/src/main/scala/firrtl/passes/ReplaceSubAccess.scala b/src/main/scala/firrtl/passes/ReplaceSubAccess.scala
new file mode 100644
index 00000000..8e911a96
--- /dev/null
+++ b/src/main/scala/firrtl/passes/ReplaceSubAccess.scala
@@ -0,0 +1,32 @@
+package firrtl.passes
+
+import firrtl.ir._
+import firrtl.{WRef, WSubAccess, WSubIndex, WSubField}
+import firrtl.Mappers._
+import firrtl.Utils._
+import firrtl.WrappedExpression._
+import firrtl.Namespace
+import scala.collection.mutable
+
+
+/** Replaces constant [[firrtl.WSubAccess]] with [[firrtl.WSubIndex]]
+ * TODO Fold in to High Firrtl Const Prop
+ */
+object ReplaceAccesses extends Pass {
+ def name = "Replace Accesses"
+
+ def run(c: Circuit): Circuit = {
+ def onStmt(s: Statement): Statement = s map onStmt map onExp
+ def onExp(e: Expression): Expression = e match {
+ case WSubAccess(e, UIntLiteral(value, width), t, g) => WSubIndex(e, value.toInt, t, g)
+ case e => e map onExp
+ }
+
+ val newModules = c.modules map {
+ case m: ExtModule => m
+ case Module(i, n, ps, b) => Module(i, n, ps, onStmt(b))
+ }
+
+ Circuit(c.info, newModules, c.main)
+ }
+}