diff options
| author | Adam Izraelevitz | 2018-02-05 15:08:05 -0800 |
|---|---|---|
| committer | Jack Koenig | 2018-02-05 15:08:05 -0800 |
| commit | 1fe1b6671a02de613f3cab87dd81526ac1417d39 (patch) | |
| tree | ea5254f3116dbc360d977de8e38e100b505ad914 /src | |
| parent | 57025111d3bc872da726e31e3e9a1e4895593266 (diff) | |
Added comments to ExpandWhens (#716)
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandWhens.scala | 272 |
1 files changed, 159 insertions, 113 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 6ff0debe..959e824a 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -25,13 +25,163 @@ import collection.immutable.ListSet * @note Assumes all references are declared */ object ExpandWhens extends Pass { + /** Returns circuit with when and last connection semantics resolved */ + def run(c: Circuit): Circuit = { + val modulesx = c.modules map { + case m: ExtModule => m + case m: Module => + val (netlist, simlist, attaches, bodyx) = expandWhens(m) + val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet + val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++ + combineAttaches(attaches) ++ simlist) + Module(m.info, m.name, m.ports, newBody) + } + Circuit(c.info, modulesx, c.main) + } + + /** Maps an expression to a declared node name. Used to memoize predicates */ type NodeMap = mutable.HashMap[MemoizedHash[Expression], String] + + /** Maps a reference to whatever connects to it. Used to resolve last connect semantics */ type Netlist = mutable.LinkedHashMap[WrappedExpression, Expression] + + /** Contains all simulation constructs */ type Simlist = mutable.ArrayBuffer[Statement] - // Defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow + + /** List of all netlists of each declared scope, ordered from closest to farthest + * @note Note immutable.Map because conversion from mutable.LinkedHashMap to mutable.Map is VERY slow + */ type Defaults = Seq[mutable.Map[WrappedExpression, Expression]] - // ========== Expand When Utilz ========== + + /** Expands a module's when statements + * @param m Module to expand + * @note Netlist maps a reference to whatever connects to it + * @note Simlist contains all simulation constructs in m + * @note Seq[Attach] contains all Attach statements (unsimplified) + * @note Statement contains all declarations in the module (including DefNode's) + */ + def expandWhens(m: Module): (Netlist, Simlist, Seq[Attach], Statement) = { + val namespace = Namespace(m) + val simlist = new Simlist + val nodes = new NodeMap + // Seq of attaches in order + lazy val attaches = mutable.ArrayBuffer.empty[Attach] + + /** Removes connections/attaches from the statement + * Mutates namespace, simlist, nodes, attaches + * Mutates input netlist + * @param netlist maps references to their values for a given immediate scope + * @param defaults sequence of netlists of surrouding scopes, ordered closest to farthest + * @param p predicate so far, used to update simulation constructs + * @param s statement to expand + */ + def expandWhens(netlist: Netlist, + defaults: Defaults, + p: Expression) + (s: Statement): Statement = s match { + // For each non-register declaration, update netlist with value WVoid for each female reference + // Return self, unchanged + case stmt @ (_: DefNode | EmptyStmt) => stmt + case w: DefWire => + netlist ++= (getFemaleRefs(w.name, w.tpe, BIGENDER) map (ref => we(ref) -> WVoid)) + w + case w: DefMemory => + netlist ++= (getFemaleRefs(w.name, MemPortUtils.memType(w), MALE) map (ref => we(ref) -> WVoid)) + w + case w: WDefInstance => + netlist ++= (getFemaleRefs(w.name, w.tpe, MALE).map(ref => we(ref) -> WVoid)) + w + // Update netlist with self reference for each female reference + // Return self, unchanged + case r: DefRegister => + netlist ++= (getFemaleRefs(r.name, r.tpe, BIGENDER) map (ref => we(ref) -> ref)) + r + // For value assignments, update netlist/attaches and return EmptyStmt + case c: Connect => + netlist(c.loc) = c.expr + EmptyStmt + case c: IsInvalid => + netlist(c.expr) = WInvalid + EmptyStmt + case a: Attach => + attaches += a + EmptyStmt + // For simulation constructs, update simlist with predicated statement and return EmptyStmt + case sx: Print => + simlist += (if (weq(p, one)) sx else Print(sx.info, sx.string, sx.args, sx.clk, AND(p, sx.en))) + EmptyStmt + case sx: Stop => + simlist += (if (weq(p, one)) sx else Stop(sx.info, sx.ret, sx.clk, AND(p, sx.en))) + EmptyStmt + // Expand conditionally, see comments below + case sx: Conditionally => + /** 1) Recurse into conseq and alt with empty netlist, updated defaults, updated predicate + * 2) For each assigned reference (lvalue) in either conseq or alt, get merged value + * a) Find default value from defaults + * b) Create Mux, ValidIf or WInvalid, depending which (or both) conseq/alt assigned lvalue + * 3) If a merged value has been memoized, update netlist. Otherwise, memoize then update netlist. + * 4) Return conseq and alt declarations, followed by memoized nodes + */ + val conseqNetlist = new Netlist + val altNetlist = new Netlist + val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, sx.pred))(sx.conseq) + val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(sx.pred)))(sx.alt) + + // Process combined maps because we only want to create 1 mux for each node + // present in the conseq and/or alt + val memos = (conseqNetlist ++ altNetlist) map { case (lvalue, _) => + // Defaults in netlist get priority over those in defaults + val default = netlist get lvalue match { + case Some(v) => Some(v) + case None => getDefault(lvalue, defaults) + } + val res = default match { + case Some(defaultValue) => + val trueValue = conseqNetlist getOrElse (lvalue, defaultValue) + val falseValue = altNetlist getOrElse (lvalue, defaultValue) + (trueValue, falseValue) match { + case (WInvalid, WInvalid) => WInvalid + case (WInvalid, fv) => ValidIf(NOT(sx.pred), fv, fv.tpe) + case (tv, WInvalid) => ValidIf(sx.pred, tv, tv.tpe) + case (tv, fv) => Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)) //Muxing clocks will be checked during type checking + } + case None => + // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt + conseqNetlist getOrElse (lvalue, altNetlist(lvalue)) + } + + res match { + case _: ValidIf | _: Mux | _: DoPrim => nodes get res match { + case Some(name) => + netlist(lvalue) = WRef(name, res.tpe, NodeKind, MALE) + EmptyStmt + case None => + val name = namespace.newTemp + nodes(res) = name + netlist(lvalue) = WRef(name, res.tpe, NodeKind, MALE) + DefNode(sx.info, name, res) + } + case _ => + netlist(lvalue) = res + EmptyStmt + } + } + Block(Seq(conseqStmt, altStmt) ++ memos) + case block: Block => block map expandWhens(netlist, defaults, p) + case _ => throwInternalError + } + val netlist = new Netlist + // Add ports to netlist + netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) => + getFemaleRefs(name, tpe, to_gender(dir)) map (ref => we(ref) -> WVoid) + }) + val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) + (netlist, simlist, attaches, bodyx) + } + + + /** Returns all references to all Female leaf subcomponents of a reference */ private def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = { def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, Default)) val exps = create_exps(WRef(n, t, ExpKind, g)) @@ -45,13 +195,19 @@ object ExpandWhens extends Pass { } } } + + /** Returns all connections/invalidations in the circuit + * @todo Preserve Info + * @note Remove IsInvalids on attached Analog-typed components + */ private def expandNetlist(netlist: Netlist, attached: Set[WrappedExpression]) = netlist map { case (k, WInvalid) => // Remove IsInvalids on attached Analog types if (attached.contains(k)) EmptyStmt else IsInvalid(NoInfo, k.e1) case (k, v) => Connect(NoInfo, k.e1, v) } - /** Combines Attaches + + /** Returns new sequence of combined Attaches * @todo Preserve Info */ private def combineAttaches(attaches: Seq[Attach]): Seq[Attach] = { @@ -89,115 +245,5 @@ object ExpandWhens extends Pass { DoPrim(And, Seq(e1, e2), Nil, BoolType) private def NOT(e: Expression) = DoPrim(Eq, Seq(e, zero), Nil, BoolType) - - // ------------ Pass ------------------- - def run(c: Circuit): Circuit = { - def expandWhens(m: Module): (Netlist, Simlist, Seq[Attach], Statement) = { - val namespace = Namespace(m) - val simlist = new Simlist - val nodes = new NodeMap - // Seq of attaches in order - lazy val attaches = mutable.ArrayBuffer.empty[Attach] - - def expandWhens(netlist: Netlist, - defaults: Defaults, - p: Expression) - (s: Statement): Statement = s match { - case stmt @ (_: DefNode | EmptyStmt) => stmt - case w: DefWire => - netlist ++= (getFemaleRefs(w.name, w.tpe, BIGENDER) map (ref => we(ref) -> WVoid)) - w - case w: DefMemory => - netlist ++= (getFemaleRefs(w.name, MemPortUtils.memType(w), MALE) map (ref => we(ref) -> WVoid)) - w - case w: WDefInstance => - netlist ++= (getFemaleRefs(w.name, w.tpe, MALE).map(ref => we(ref) -> WVoid)) - w - case r: DefRegister => - netlist ++= (getFemaleRefs(r.name, r.tpe, BIGENDER) map (ref => we(ref) -> ref)) - r - case c: Connect => - netlist(c.loc) = c.expr - EmptyStmt - case c: IsInvalid => - netlist(c.expr) = WInvalid - EmptyStmt - case a: Attach => - attaches += a - EmptyStmt - case sx: Conditionally => - val conseqNetlist = new Netlist - val altNetlist = new Netlist - val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, sx.pred))(sx.conseq) - val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(sx.pred)))(sx.alt) - - // Process combined maps because we only want to create 1 mux for each node - // present in the conseq and/or alt - val memos = (conseqNetlist ++ altNetlist) map { case (lvalue, _) => - // Defaults in netlist get priority over those in defaults - val default = netlist get lvalue match { - case Some(v) => Some(v) - case None => getDefault(lvalue, defaults) - } - val res = default match { - case Some(defaultValue) => - val trueValue = conseqNetlist getOrElse (lvalue, defaultValue) - val falseValue = altNetlist getOrElse (lvalue, defaultValue) - (trueValue, falseValue) match { - case (WInvalid, WInvalid) => WInvalid - case (WInvalid, fv) => ValidIf(NOT(sx.pred), fv, fv.tpe) - case (tv, WInvalid) => ValidIf(sx.pred, tv, tv.tpe) - case (tv, fv) => Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)) - } - case None => - // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt - conseqNetlist getOrElse (lvalue, altNetlist(lvalue)) - } - - res match { - case _: ValidIf | _: Mux | _: DoPrim => nodes get res match { - case Some(name) => - netlist(lvalue) = WRef(name, res.tpe, NodeKind, MALE) - EmptyStmt - case None => - val name = namespace.newTemp - nodes(res) = name - netlist(lvalue) = WRef(name, res.tpe, NodeKind, MALE) - DefNode(sx.info, name, res) - } - case _ => - netlist(lvalue) = res - EmptyStmt - } - } - Block(Seq(conseqStmt, altStmt) ++ memos) - case sx: Print => - simlist += (if (weq(p, one)) sx else Print(sx.info, sx.string, sx.args, sx.clk, AND(p, sx.en))) - EmptyStmt - case sx: Stop => - simlist += (if (weq(p, one)) sx else Stop(sx.info, sx.ret, sx.clk, AND(p, sx.en))) - EmptyStmt - case block: Block => block map expandWhens(netlist, defaults, p) - case _ => throwInternalError - } - val netlist = new Netlist - // Add ports to netlist - netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) => - getFemaleRefs(name, tpe, to_gender(dir)) map (ref => we(ref) -> WVoid) - }) - val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) - (netlist, simlist, attaches, bodyx) - } - val modulesx = c.modules map { - case m: ExtModule => m - case m: Module => - val (netlist, simlist, attaches, bodyx) = expandWhens(m) - val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet - val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++ - combineAttaches(attaches) ++ simlist) - Module(m.info, m.name, m.ports, newBody) - } - Circuit(c.info, modulesx, c.main) - } } |
