aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAdam Izraelevitz2018-02-05 15:08:05 -0800
committerJack Koenig2018-02-05 15:08:05 -0800
commit1fe1b6671a02de613f3cab87dd81526ac1417d39 (patch)
treeea5254f3116dbc360d977de8e38e100b505ad914 /src
parent57025111d3bc872da726e31e3e9a1e4895593266 (diff)
Added comments to ExpandWhens (#716)
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala272
1 files changed, 159 insertions, 113 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index 6ff0debe..959e824a 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -25,13 +25,163 @@ import collection.immutable.ListSet
* @note Assumes all references are declared
*/
object ExpandWhens extends Pass {
+ /** Returns circuit with when and last connection semantics resolved */
+ def run(c: Circuit): Circuit = {
+ val modulesx = c.modules map {
+ case m: ExtModule => m
+ case m: Module =>
+ val (netlist, simlist, attaches, bodyx) = expandWhens(m)
+ val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet
+ val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++
+ combineAttaches(attaches) ++ simlist)
+ Module(m.info, m.name, m.ports, newBody)
+ }
+ Circuit(c.info, modulesx, c.main)
+ }
+
+ /** Maps an expression to a declared node name. Used to memoize predicates */
type NodeMap = mutable.HashMap[MemoizedHash[Expression], String]
+
+ /** Maps a reference to whatever connects to it. Used to resolve last connect semantics */
type Netlist = mutable.LinkedHashMap[WrappedExpression, Expression]
+
+ /** Contains all simulation constructs */
type Simlist = mutable.ArrayBuffer[Statement]
- // Defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow
+
+ /** List of all netlists of each declared scope, ordered from closest to farthest
+ * @note Note immutable.Map because conversion from mutable.LinkedHashMap to mutable.Map is VERY slow
+ */
type Defaults = Seq[mutable.Map[WrappedExpression, Expression]]
- // ========== Expand When Utilz ==========
+
+ /** Expands a module's when statements
+ * @param m Module to expand
+ * @note Netlist maps a reference to whatever connects to it
+ * @note Simlist contains all simulation constructs in m
+ * @note Seq[Attach] contains all Attach statements (unsimplified)
+ * @note Statement contains all declarations in the module (including DefNode's)
+ */
+ def expandWhens(m: Module): (Netlist, Simlist, Seq[Attach], Statement) = {
+ val namespace = Namespace(m)
+ val simlist = new Simlist
+ val nodes = new NodeMap
+ // Seq of attaches in order
+ lazy val attaches = mutable.ArrayBuffer.empty[Attach]
+
+ /** Removes connections/attaches from the statement
+ * Mutates namespace, simlist, nodes, attaches
+ * Mutates input netlist
+ * @param netlist maps references to their values for a given immediate scope
+ * @param defaults sequence of netlists of surrouding scopes, ordered closest to farthest
+ * @param p predicate so far, used to update simulation constructs
+ * @param s statement to expand
+ */
+ def expandWhens(netlist: Netlist,
+ defaults: Defaults,
+ p: Expression)
+ (s: Statement): Statement = s match {
+ // For each non-register declaration, update netlist with value WVoid for each female reference
+ // Return self, unchanged
+ case stmt @ (_: DefNode | EmptyStmt) => stmt
+ case w: DefWire =>
+ netlist ++= (getFemaleRefs(w.name, w.tpe, BIGENDER) map (ref => we(ref) -> WVoid))
+ w
+ case w: DefMemory =>
+ netlist ++= (getFemaleRefs(w.name, MemPortUtils.memType(w), MALE) map (ref => we(ref) -> WVoid))
+ w
+ case w: WDefInstance =>
+ netlist ++= (getFemaleRefs(w.name, w.tpe, MALE).map(ref => we(ref) -> WVoid))
+ w
+ // Update netlist with self reference for each female reference
+ // Return self, unchanged
+ case r: DefRegister =>
+ netlist ++= (getFemaleRefs(r.name, r.tpe, BIGENDER) map (ref => we(ref) -> ref))
+ r
+ // For value assignments, update netlist/attaches and return EmptyStmt
+ case c: Connect =>
+ netlist(c.loc) = c.expr
+ EmptyStmt
+ case c: IsInvalid =>
+ netlist(c.expr) = WInvalid
+ EmptyStmt
+ case a: Attach =>
+ attaches += a
+ EmptyStmt
+ // For simulation constructs, update simlist with predicated statement and return EmptyStmt
+ case sx: Print =>
+ simlist += (if (weq(p, one)) sx else Print(sx.info, sx.string, sx.args, sx.clk, AND(p, sx.en)))
+ EmptyStmt
+ case sx: Stop =>
+ simlist += (if (weq(p, one)) sx else Stop(sx.info, sx.ret, sx.clk, AND(p, sx.en)))
+ EmptyStmt
+ // Expand conditionally, see comments below
+ case sx: Conditionally =>
+ /** 1) Recurse into conseq and alt with empty netlist, updated defaults, updated predicate
+ * 2) For each assigned reference (lvalue) in either conseq or alt, get merged value
+ * a) Find default value from defaults
+ * b) Create Mux, ValidIf or WInvalid, depending which (or both) conseq/alt assigned lvalue
+ * 3) If a merged value has been memoized, update netlist. Otherwise, memoize then update netlist.
+ * 4) Return conseq and alt declarations, followed by memoized nodes
+ */
+ val conseqNetlist = new Netlist
+ val altNetlist = new Netlist
+ val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, sx.pred))(sx.conseq)
+ val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(sx.pred)))(sx.alt)
+
+ // Process combined maps because we only want to create 1 mux for each node
+ // present in the conseq and/or alt
+ val memos = (conseqNetlist ++ altNetlist) map { case (lvalue, _) =>
+ // Defaults in netlist get priority over those in defaults
+ val default = netlist get lvalue match {
+ case Some(v) => Some(v)
+ case None => getDefault(lvalue, defaults)
+ }
+ val res = default match {
+ case Some(defaultValue) =>
+ val trueValue = conseqNetlist getOrElse (lvalue, defaultValue)
+ val falseValue = altNetlist getOrElse (lvalue, defaultValue)
+ (trueValue, falseValue) match {
+ case (WInvalid, WInvalid) => WInvalid
+ case (WInvalid, fv) => ValidIf(NOT(sx.pred), fv, fv.tpe)
+ case (tv, WInvalid) => ValidIf(sx.pred, tv, tv.tpe)
+ case (tv, fv) => Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)) //Muxing clocks will be checked during type checking
+ }
+ case None =>
+ // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
+ conseqNetlist getOrElse (lvalue, altNetlist(lvalue))
+ }
+
+ res match {
+ case _: ValidIf | _: Mux | _: DoPrim => nodes get res match {
+ case Some(name) =>
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind, MALE)
+ EmptyStmt
+ case None =>
+ val name = namespace.newTemp
+ nodes(res) = name
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind, MALE)
+ DefNode(sx.info, name, res)
+ }
+ case _ =>
+ netlist(lvalue) = res
+ EmptyStmt
+ }
+ }
+ Block(Seq(conseqStmt, altStmt) ++ memos)
+ case block: Block => block map expandWhens(netlist, defaults, p)
+ case _ => throwInternalError
+ }
+ val netlist = new Netlist
+ // Add ports to netlist
+ netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) =>
+ getFemaleRefs(name, tpe, to_gender(dir)) map (ref => we(ref) -> WVoid)
+ })
+ val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body)
+ (netlist, simlist, attaches, bodyx)
+ }
+
+
+ /** Returns all references to all Female leaf subcomponents of a reference */
private def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = {
def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, Default))
val exps = create_exps(WRef(n, t, ExpKind, g))
@@ -45,13 +195,19 @@ object ExpandWhens extends Pass {
}
}
}
+
+ /** Returns all connections/invalidations in the circuit
+ * @todo Preserve Info
+ * @note Remove IsInvalids on attached Analog-typed components
+ */
private def expandNetlist(netlist: Netlist, attached: Set[WrappedExpression]) =
netlist map {
case (k, WInvalid) => // Remove IsInvalids on attached Analog types
if (attached.contains(k)) EmptyStmt else IsInvalid(NoInfo, k.e1)
case (k, v) => Connect(NoInfo, k.e1, v)
}
- /** Combines Attaches
+
+ /** Returns new sequence of combined Attaches
* @todo Preserve Info
*/
private def combineAttaches(attaches: Seq[Attach]): Seq[Attach] = {
@@ -89,115 +245,5 @@ object ExpandWhens extends Pass {
DoPrim(And, Seq(e1, e2), Nil, BoolType)
private def NOT(e: Expression) =
DoPrim(Eq, Seq(e, zero), Nil, BoolType)
-
- // ------------ Pass -------------------
- def run(c: Circuit): Circuit = {
- def expandWhens(m: Module): (Netlist, Simlist, Seq[Attach], Statement) = {
- val namespace = Namespace(m)
- val simlist = new Simlist
- val nodes = new NodeMap
- // Seq of attaches in order
- lazy val attaches = mutable.ArrayBuffer.empty[Attach]
-
- def expandWhens(netlist: Netlist,
- defaults: Defaults,
- p: Expression)
- (s: Statement): Statement = s match {
- case stmt @ (_: DefNode | EmptyStmt) => stmt
- case w: DefWire =>
- netlist ++= (getFemaleRefs(w.name, w.tpe, BIGENDER) map (ref => we(ref) -> WVoid))
- w
- case w: DefMemory =>
- netlist ++= (getFemaleRefs(w.name, MemPortUtils.memType(w), MALE) map (ref => we(ref) -> WVoid))
- w
- case w: WDefInstance =>
- netlist ++= (getFemaleRefs(w.name, w.tpe, MALE).map(ref => we(ref) -> WVoid))
- w
- case r: DefRegister =>
- netlist ++= (getFemaleRefs(r.name, r.tpe, BIGENDER) map (ref => we(ref) -> ref))
- r
- case c: Connect =>
- netlist(c.loc) = c.expr
- EmptyStmt
- case c: IsInvalid =>
- netlist(c.expr) = WInvalid
- EmptyStmt
- case a: Attach =>
- attaches += a
- EmptyStmt
- case sx: Conditionally =>
- val conseqNetlist = new Netlist
- val altNetlist = new Netlist
- val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, sx.pred))(sx.conseq)
- val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(sx.pred)))(sx.alt)
-
- // Process combined maps because we only want to create 1 mux for each node
- // present in the conseq and/or alt
- val memos = (conseqNetlist ++ altNetlist) map { case (lvalue, _) =>
- // Defaults in netlist get priority over those in defaults
- val default = netlist get lvalue match {
- case Some(v) => Some(v)
- case None => getDefault(lvalue, defaults)
- }
- val res = default match {
- case Some(defaultValue) =>
- val trueValue = conseqNetlist getOrElse (lvalue, defaultValue)
- val falseValue = altNetlist getOrElse (lvalue, defaultValue)
- (trueValue, falseValue) match {
- case (WInvalid, WInvalid) => WInvalid
- case (WInvalid, fv) => ValidIf(NOT(sx.pred), fv, fv.tpe)
- case (tv, WInvalid) => ValidIf(sx.pred, tv, tv.tpe)
- case (tv, fv) => Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv))
- }
- case None =>
- // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
- conseqNetlist getOrElse (lvalue, altNetlist(lvalue))
- }
-
- res match {
- case _: ValidIf | _: Mux | _: DoPrim => nodes get res match {
- case Some(name) =>
- netlist(lvalue) = WRef(name, res.tpe, NodeKind, MALE)
- EmptyStmt
- case None =>
- val name = namespace.newTemp
- nodes(res) = name
- netlist(lvalue) = WRef(name, res.tpe, NodeKind, MALE)
- DefNode(sx.info, name, res)
- }
- case _ =>
- netlist(lvalue) = res
- EmptyStmt
- }
- }
- Block(Seq(conseqStmt, altStmt) ++ memos)
- case sx: Print =>
- simlist += (if (weq(p, one)) sx else Print(sx.info, sx.string, sx.args, sx.clk, AND(p, sx.en)))
- EmptyStmt
- case sx: Stop =>
- simlist += (if (weq(p, one)) sx else Stop(sx.info, sx.ret, sx.clk, AND(p, sx.en)))
- EmptyStmt
- case block: Block => block map expandWhens(netlist, defaults, p)
- case _ => throwInternalError
- }
- val netlist = new Netlist
- // Add ports to netlist
- netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) =>
- getFemaleRefs(name, tpe, to_gender(dir)) map (ref => we(ref) -> WVoid)
- })
- val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body)
- (netlist, simlist, attaches, bodyx)
- }
- val modulesx = c.modules map {
- case m: ExtModule => m
- case m: Module =>
- val (netlist, simlist, attaches, bodyx) = expandWhens(m)
- val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet
- val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++
- combineAttaches(attaches) ++ simlist)
- Module(m.info, m.name, m.ports, newBody)
- }
- Circuit(c.info, modulesx, c.main)
- }
}