diff options
| author | jackkoenig | 2016-04-26 16:26:31 -0700 |
|---|---|---|
| committer | jackkoenig | 2016-05-03 16:56:52 -0700 |
| commit | 41e0f6da3d60528241a46520b949c15bcbc29957 (patch) | |
| tree | 9149fa413b50935f0a2574f1a0fb75b5387b905a /src | |
| parent | a5526c177563b2c4de2a9c2b39a5b51a05697292 (diff) | |
Rewrite ExpandWhens to memoize complex default values
Fixes #113 and Fixes #150
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandWhens.scala | 225 |
1 files changed, 105 insertions, 120 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 1b6030e2..540aab9f 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -35,178 +35,163 @@ import firrtl.WrappedExpression._ // Datastructures import scala.collection.mutable.HashMap +import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ArrayBuffer +import annotation.tailrec + /** Expand Whens * * @note This pass does three things: remove last connect semantics, * remove conditional blocks, and eliminate concept of scoping. +* @note Assumes bulk connects and isInvalids have been expanded +* @note Assumes all references are declared */ object ExpandWhens extends Pass { def name = "Expand Whens" - var mname = "" + // ========== Expand When Utilz ========== - def getEntries( - hash: HashMap[WrappedExpression, Expression], - exps: Seq[Expression]): HashMap[WrappedExpression, Expression] = { - val hashx = HashMap[WrappedExpression, Expression]() + private def getEntries( + hash: LinkedHashMap[WrappedExpression, Expression], + exps: Seq[Expression]): LinkedHashMap[WrappedExpression, Expression] = { + val hashx = LinkedHashMap[WrappedExpression, Expression]() exps foreach (e => if (hash.contains(e)) hashx(e) = hash(e)) hashx } - def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = { + private def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = { def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, DEFAULT)) val exps = create_exps(WRef(n, t, ExpKind(), g)) val expsx = ArrayBuffer[Expression]() - for (i <- 0 until exps.size) { - getGender(t, i, g) match { - case (BIGENDER | FEMALE) => expsx += exps(i) + for (j <- 0 until exps.size) { + getGender(t, j, g) match { + case (BIGENDER | FEMALE) => expsx += exps(j) case _ => } } expsx } - - // ------------ Pass ------------------- - def run(c: Circuit): Circuit = { - def voidAll(m: InModule): InModule = { - mname = m.name - def voidAllStmt(s: Stmt): Stmt = s match { - case (_: DefWire | _: DefRegister | _: WDefInstance |_: DefMemory) => - val voids = ArrayBuffer[Stmt]() - for (e <- getFemaleRefs(get_name(s),get_type(s),get_gender(s))) { - voids += Connect(get_info(s),e,WVoid()) - } - Begin(Seq(s,Begin(voids))) - case s => s map voidAllStmt - } - val voids = ArrayBuffer[Stmt]() - for (p <- m.ports) { - for (e <- getFemaleRefs(p.name,p.tpe,get_gender(p))) { - voids += Connect(p.info,e,WVoid()) + private def squashEmpty(s: Stmt): Stmt = { + s map squashEmpty match { + case Begin(stmts) => + val newStmts = stmts filter (_ != Empty()) + newStmts.size match { + case 0 => Empty() + case 1 => newStmts.head + case _ => Begin(newStmts) } + case s => s + } + } + private def expandNetlist(netlist: LinkedHashMap[WrappedExpression, Expression]) = + netlist map { case (k, v) => + v match { + case WInvalid() => IsInvalid(NoInfo, k.e1) + case _ => Connect(NoInfo, k.e1, v) } - val bodyx = voidAllStmt(m.body) - InModule(m.info, m.name, m.ports, Begin(Seq(Begin(voids),bodyx))) } - def expandWhens(m: InModule): (HashMap[WrappedExpression, Expression], ArrayBuffer[Stmt]) = { + // Searches nested scopes of defaults for lvalue + // defaults uses mutable Map because we are searching LinkedHashMaps and conversion to immutable is VERY slow + @tailrec + private def getDefault( + lvalue: WrappedExpression, + defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]]): Option[Expression] = { + if (defaults.isEmpty) None + else if (defaults.head.contains(lvalue)) defaults.head.get(lvalue) + else getDefault(lvalue, defaults.tail) + } + + // ------------ Pass ------------------- + def run(c: Circuit): Circuit = { + def expandWhens(m: InModule): (LinkedHashMap[WrappedExpression, Expression], ArrayBuffer[Stmt], Stmt) = { + val namespace = Namespace(m) val simlist = ArrayBuffer[Stmt]() - mname = m.name - def expandWhens(netlist: HashMap[WrappedExpression, Expression], p: Expression)(s: Stmt): Stmt = { + + // defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow + def expandWhens( + netlist: LinkedHashMap[WrappedExpression, Expression], + defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]], + p: Expression) + (s: Stmt): Stmt = { s match { - case s: Connect => netlist(s.loc) = s.exp - case s: IsInvalid => netlist(s.exp) = WInvalid() + case w: DefWire => + getFemaleRefs(w.name, w.tpe, BIGENDER) foreach (ref => netlist(ref) = WVoid()) + w + case r: DefRegister => + getFemaleRefs(r.name, r.tpe, BIGENDER) foreach (ref => netlist(ref) = ref) + r + case c: Connect => + netlist(c.loc) = c.exp + Empty() + case c: IsInvalid => + netlist(c.exp) = WInvalid() + Empty() case s: Conditionally => - val exps = ArrayBuffer[Expression]() - def prefetch(s: Stmt): Stmt = s match { - case s: Connect => exps += s.loc; s - case s => s map prefetch - } - prefetch(s.conseq) - val c_netlist = getEntries(netlist,exps) - expandWhens(c_netlist, AND(p, s.pred))(s.conseq) - expandWhens(netlist, AND(p, NOT(s.pred)))(s.alt) - for (lvalue <- c_netlist.keys) { - val value = netlist.get(lvalue) - value match { - case value: Some[Expression] => - val tv = c_netlist(lvalue) - val fv = value.get - val res = (tv, fv) match { - case (tv:WInvalid, fv:WInvalid) => WInvalid() - case (tv:WInvalid, fv) => ValidIf(NOT(s.pred), fv,tpe(fv)) - case (tv, fv:WInvalid) => ValidIf(s.pred, tv, tpe(tv)) + val memos = ArrayBuffer[Stmt]() + + val conseqNetlist = LinkedHashMap[WrappedExpression, Expression]() + val altNetlist = LinkedHashMap[WrappedExpression, Expression]() + val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, s.pred))(s.conseq) + val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(s.pred)))(s.alt) + + (conseqNetlist.keySet ++ altNetlist.keySet) foreach { lvalue => + // Defaults in netlist get priority over those in defaults + val default = if (netlist.contains(lvalue)) netlist.get(lvalue) else getDefault(lvalue, defaults) + val res = default match { + case Some(defaultValue) => + val trueValue = conseqNetlist.getOrElse(lvalue, defaultValue) + val falseValue = altNetlist.getOrElse(lvalue, defaultValue) + (trueValue, falseValue) match { + case (WInvalid(), WInvalid()) => WInvalid() + case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, tpe(fv)) + case (tv, WInvalid()) => ValidIf(s.pred, tv, tpe(tv)) case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv)) } - netlist(lvalue) = res - case None => netlist(lvalue) = c_netlist(lvalue) + case None => + // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt + conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)) } + + val memoNode = DefNode(s.info, namespace.newTemp, res) + val memoExpr = WRef(memoNode.name, res.tpe, NodeKind(), MALE) + memos += memoNode + netlist(lvalue) = memoExpr } + Begin(Seq(conseqStmt, altStmt) ++ memos) + case s: Print => if(weq(p, one)) { simlist += s } else { simlist += Print(s.info, s.string, s.args, s.clk, AND(p, s.en)) } + Empty() case s: Stop => if (weq(p, one)) { simlist += s } else { simlist += Stop(s.info, s.ret, s.clk, AND(p, s.en)) } - case s => s map expandWhens(netlist, p) + Empty() + case s => s map expandWhens(netlist, defaults, p) } - s } - val netlist = HashMap[WrappedExpression, Expression]() - expandWhens(netlist, one)(m.body) + val netlist = LinkedHashMap[WrappedExpression, Expression]() - (netlist, simlist) - } - - def createModule(netlist: HashMap[WrappedExpression,Expression], simlist: ArrayBuffer[Stmt], m: InModule): InModule = { - mname = m.name - val stmts = ArrayBuffer[Stmt]() - val connections = ArrayBuffer[Stmt]() - def replace_void(e: Expression)(rvalue: Expression): Expression = rvalue match { - case rv: WVoid => e - case rv => rv map replace_void(e) - } - def create(s: Stmt): Stmt = { - s match { - case (_: DefWire | _: WDefInstance | _: DefMemory) => - stmts += s - for (e <- getFemaleRefs(get_name(s), get_type(s), get_gender(s))) { - val rvalue = netlist(e) - val con = rvalue match { - case rvalue: WInvalid => IsInvalid(get_info(s), e) - case rvalue => Connect(get_info(s), e, rvalue) - } - connections += con - } - case s: DefRegister => - stmts += s - for (e <- getFemaleRefs(get_name(s), get_type(s), get_gender(s))) { - val rvalue = replace_void(e)(netlist(e)) - val con = rvalue match { - case rvalue: WInvalid => IsInvalid(get_info(s), e) - case rvalue => Connect(get_info(s), e, rvalue) - } - connections += con - } - case (_: DefPoison | _: DefNode) => stmts += s - case s => s map create - } - s + // Add ports to netlist + m.ports foreach { port => + getFemaleRefs(port.name, port.tpe, to_gender(port.direction)) foreach (ref => netlist(ref) = WVoid()) } - create(m.body) - for (p <- m.ports) { - for (e <- getFemaleRefs(p.name, p.tpe, get_gender(p))) { - val rvalue = netlist(e) - val con = rvalue match { - case rvalue: WInvalid => IsInvalid(p.info, e) - case rvalue => Connect(p.info, e, rvalue) - } - connections += con - } - } - for (x <- simlist) { stmts += x } - InModule(m.info, m.name, m.ports, Begin(Seq(Begin(stmts), Begin(connections)))) - } + val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) - val voided_modules = c.modules map { m => - m match { - case m: ExModule => m - case m: InModule => voidAll(m) - } + (netlist, simlist, bodyx) } - - val modulesx = voided_modules map { m => + val modulesx = c.modules map { m => m match { case m: ExModule => m case m: InModule => - val (netlist, simlist) = expandWhens(m) - createModule(netlist, simlist, m) - + val (netlist, simlist, bodyx) = expandWhens(m) + val newBody = Begin(Seq(bodyx map squashEmpty) ++ expandNetlist(netlist) ++ simlist) + InModule(m.info, m.name, m.ports, newBody) } } Circuit(c.info, modulesx, c.main) |
