aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorjackkoenig2016-04-26 16:26:31 -0700
committerjackkoenig2016-05-03 16:56:52 -0700
commit41e0f6da3d60528241a46520b949c15bcbc29957 (patch)
tree9149fa413b50935f0a2574f1a0fb75b5387b905a /src
parenta5526c177563b2c4de2a9c2b39a5b51a05697292 (diff)
Rewrite ExpandWhens to memoize complex default values
Fixes #113 and Fixes #150
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala225
1 files changed, 105 insertions, 120 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index 1b6030e2..540aab9f 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -35,178 +35,163 @@ import firrtl.WrappedExpression._
// Datastructures
import scala.collection.mutable.HashMap
+import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ArrayBuffer
+import annotation.tailrec
+
/** Expand Whens
*
* @note This pass does three things: remove last connect semantics,
* remove conditional blocks, and eliminate concept of scoping.
+* @note Assumes bulk connects and isInvalids have been expanded
+* @note Assumes all references are declared
*/
object ExpandWhens extends Pass {
def name = "Expand Whens"
- var mname = ""
+
// ========== Expand When Utilz ==========
- def getEntries(
- hash: HashMap[WrappedExpression, Expression],
- exps: Seq[Expression]): HashMap[WrappedExpression, Expression] = {
- val hashx = HashMap[WrappedExpression, Expression]()
+ private def getEntries(
+ hash: LinkedHashMap[WrappedExpression, Expression],
+ exps: Seq[Expression]): LinkedHashMap[WrappedExpression, Expression] = {
+ val hashx = LinkedHashMap[WrappedExpression, Expression]()
exps foreach (e => if (hash.contains(e)) hashx(e) = hash(e))
hashx
}
- def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = {
+ private def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = {
def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, DEFAULT))
val exps = create_exps(WRef(n, t, ExpKind(), g))
val expsx = ArrayBuffer[Expression]()
- for (i <- 0 until exps.size) {
- getGender(t, i, g) match {
- case (BIGENDER | FEMALE) => expsx += exps(i)
+ for (j <- 0 until exps.size) {
+ getGender(t, j, g) match {
+ case (BIGENDER | FEMALE) => expsx += exps(j)
case _ =>
}
}
expsx
}
-
- // ------------ Pass -------------------
- def run(c: Circuit): Circuit = {
- def voidAll(m: InModule): InModule = {
- mname = m.name
- def voidAllStmt(s: Stmt): Stmt = s match {
- case (_: DefWire | _: DefRegister | _: WDefInstance |_: DefMemory) =>
- val voids = ArrayBuffer[Stmt]()
- for (e <- getFemaleRefs(get_name(s),get_type(s),get_gender(s))) {
- voids += Connect(get_info(s),e,WVoid())
- }
- Begin(Seq(s,Begin(voids)))
- case s => s map voidAllStmt
- }
- val voids = ArrayBuffer[Stmt]()
- for (p <- m.ports) {
- for (e <- getFemaleRefs(p.name,p.tpe,get_gender(p))) {
- voids += Connect(p.info,e,WVoid())
+ private def squashEmpty(s: Stmt): Stmt = {
+ s map squashEmpty match {
+ case Begin(stmts) =>
+ val newStmts = stmts filter (_ != Empty())
+ newStmts.size match {
+ case 0 => Empty()
+ case 1 => newStmts.head
+ case _ => Begin(newStmts)
}
+ case s => s
+ }
+ }
+ private def expandNetlist(netlist: LinkedHashMap[WrappedExpression, Expression]) =
+ netlist map { case (k, v) =>
+ v match {
+ case WInvalid() => IsInvalid(NoInfo, k.e1)
+ case _ => Connect(NoInfo, k.e1, v)
}
- val bodyx = voidAllStmt(m.body)
- InModule(m.info, m.name, m.ports, Begin(Seq(Begin(voids),bodyx)))
}
- def expandWhens(m: InModule): (HashMap[WrappedExpression, Expression], ArrayBuffer[Stmt]) = {
+ // Searches nested scopes of defaults for lvalue
+ // defaults uses mutable Map because we are searching LinkedHashMaps and conversion to immutable is VERY slow
+ @tailrec
+ private def getDefault(
+ lvalue: WrappedExpression,
+ defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]]): Option[Expression] = {
+ if (defaults.isEmpty) None
+ else if (defaults.head.contains(lvalue)) defaults.head.get(lvalue)
+ else getDefault(lvalue, defaults.tail)
+ }
+
+ // ------------ Pass -------------------
+ def run(c: Circuit): Circuit = {
+ def expandWhens(m: InModule): (LinkedHashMap[WrappedExpression, Expression], ArrayBuffer[Stmt], Stmt) = {
+ val namespace = Namespace(m)
val simlist = ArrayBuffer[Stmt]()
- mname = m.name
- def expandWhens(netlist: HashMap[WrappedExpression, Expression], p: Expression)(s: Stmt): Stmt = {
+
+ // defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow
+ def expandWhens(
+ netlist: LinkedHashMap[WrappedExpression, Expression],
+ defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]],
+ p: Expression)
+ (s: Stmt): Stmt = {
s match {
- case s: Connect => netlist(s.loc) = s.exp
- case s: IsInvalid => netlist(s.exp) = WInvalid()
+ case w: DefWire =>
+ getFemaleRefs(w.name, w.tpe, BIGENDER) foreach (ref => netlist(ref) = WVoid())
+ w
+ case r: DefRegister =>
+ getFemaleRefs(r.name, r.tpe, BIGENDER) foreach (ref => netlist(ref) = ref)
+ r
+ case c: Connect =>
+ netlist(c.loc) = c.exp
+ Empty()
+ case c: IsInvalid =>
+ netlist(c.exp) = WInvalid()
+ Empty()
case s: Conditionally =>
- val exps = ArrayBuffer[Expression]()
- def prefetch(s: Stmt): Stmt = s match {
- case s: Connect => exps += s.loc; s
- case s => s map prefetch
- }
- prefetch(s.conseq)
- val c_netlist = getEntries(netlist,exps)
- expandWhens(c_netlist, AND(p, s.pred))(s.conseq)
- expandWhens(netlist, AND(p, NOT(s.pred)))(s.alt)
- for (lvalue <- c_netlist.keys) {
- val value = netlist.get(lvalue)
- value match {
- case value: Some[Expression] =>
- val tv = c_netlist(lvalue)
- val fv = value.get
- val res = (tv, fv) match {
- case (tv:WInvalid, fv:WInvalid) => WInvalid()
- case (tv:WInvalid, fv) => ValidIf(NOT(s.pred), fv,tpe(fv))
- case (tv, fv:WInvalid) => ValidIf(s.pred, tv, tpe(tv))
+ val memos = ArrayBuffer[Stmt]()
+
+ val conseqNetlist = LinkedHashMap[WrappedExpression, Expression]()
+ val altNetlist = LinkedHashMap[WrappedExpression, Expression]()
+ val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, s.pred))(s.conseq)
+ val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(s.pred)))(s.alt)
+
+ (conseqNetlist.keySet ++ altNetlist.keySet) foreach { lvalue =>
+ // Defaults in netlist get priority over those in defaults
+ val default = if (netlist.contains(lvalue)) netlist.get(lvalue) else getDefault(lvalue, defaults)
+ val res = default match {
+ case Some(defaultValue) =>
+ val trueValue = conseqNetlist.getOrElse(lvalue, defaultValue)
+ val falseValue = altNetlist.getOrElse(lvalue, defaultValue)
+ (trueValue, falseValue) match {
+ case (WInvalid(), WInvalid()) => WInvalid()
+ case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, tpe(fv))
+ case (tv, WInvalid()) => ValidIf(s.pred, tv, tpe(tv))
case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv))
}
- netlist(lvalue) = res
- case None => netlist(lvalue) = c_netlist(lvalue)
+ case None =>
+ // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
+ conseqNetlist.getOrElse(lvalue, altNetlist(lvalue))
}
+
+ val memoNode = DefNode(s.info, namespace.newTemp, res)
+ val memoExpr = WRef(memoNode.name, res.tpe, NodeKind(), MALE)
+ memos += memoNode
+ netlist(lvalue) = memoExpr
}
+ Begin(Seq(conseqStmt, altStmt) ++ memos)
+
case s: Print =>
if(weq(p, one)) {
simlist += s
} else {
simlist += Print(s.info, s.string, s.args, s.clk, AND(p, s.en))
}
+ Empty()
case s: Stop =>
if (weq(p, one)) {
simlist += s
} else {
simlist += Stop(s.info, s.ret, s.clk, AND(p, s.en))
}
- case s => s map expandWhens(netlist, p)
+ Empty()
+ case s => s map expandWhens(netlist, defaults, p)
}
- s
}
- val netlist = HashMap[WrappedExpression, Expression]()
- expandWhens(netlist, one)(m.body)
+ val netlist = LinkedHashMap[WrappedExpression, Expression]()
- (netlist, simlist)
- }
-
- def createModule(netlist: HashMap[WrappedExpression,Expression], simlist: ArrayBuffer[Stmt], m: InModule): InModule = {
- mname = m.name
- val stmts = ArrayBuffer[Stmt]()
- val connections = ArrayBuffer[Stmt]()
- def replace_void(e: Expression)(rvalue: Expression): Expression = rvalue match {
- case rv: WVoid => e
- case rv => rv map replace_void(e)
- }
- def create(s: Stmt): Stmt = {
- s match {
- case (_: DefWire | _: WDefInstance | _: DefMemory) =>
- stmts += s
- for (e <- getFemaleRefs(get_name(s), get_type(s), get_gender(s))) {
- val rvalue = netlist(e)
- val con = rvalue match {
- case rvalue: WInvalid => IsInvalid(get_info(s), e)
- case rvalue => Connect(get_info(s), e, rvalue)
- }
- connections += con
- }
- case s: DefRegister =>
- stmts += s
- for (e <- getFemaleRefs(get_name(s), get_type(s), get_gender(s))) {
- val rvalue = replace_void(e)(netlist(e))
- val con = rvalue match {
- case rvalue: WInvalid => IsInvalid(get_info(s), e)
- case rvalue => Connect(get_info(s), e, rvalue)
- }
- connections += con
- }
- case (_: DefPoison | _: DefNode) => stmts += s
- case s => s map create
- }
- s
+ // Add ports to netlist
+ m.ports foreach { port =>
+ getFemaleRefs(port.name, port.tpe, to_gender(port.direction)) foreach (ref => netlist(ref) = WVoid())
}
- create(m.body)
- for (p <- m.ports) {
- for (e <- getFemaleRefs(p.name, p.tpe, get_gender(p))) {
- val rvalue = netlist(e)
- val con = rvalue match {
- case rvalue: WInvalid => IsInvalid(p.info, e)
- case rvalue => Connect(p.info, e, rvalue)
- }
- connections += con
- }
- }
- for (x <- simlist) { stmts += x }
- InModule(m.info, m.name, m.ports, Begin(Seq(Begin(stmts), Begin(connections))))
- }
+ val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body)
- val voided_modules = c.modules map { m =>
- m match {
- case m: ExModule => m
- case m: InModule => voidAll(m)
- }
+ (netlist, simlist, bodyx)
}
-
- val modulesx = voided_modules map { m =>
+ val modulesx = c.modules map { m =>
m match {
case m: ExModule => m
case m: InModule =>
- val (netlist, simlist) = expandWhens(m)
- createModule(netlist, simlist, m)
-
+ val (netlist, simlist, bodyx) = expandWhens(m)
+ val newBody = Begin(Seq(bodyx map squashEmpty) ++ expandNetlist(netlist) ++ simlist)
+ InModule(m.info, m.name, m.ports, newBody)
}
}
Circuit(c.info, modulesx, c.main)