diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/ExpandWhens.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandWhens.scala | 173 |
1 files changed, 90 insertions, 83 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index ab7f02db..14d5d3ef 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -28,21 +28,23 @@ import collection.mutable object ExpandWhens extends Pass { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Resolved + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses) + ) ++ firrtl.stage.Forms.Resolved override def invalidates(a: Transform): Boolean = a match { case CheckInitialization | ResolveKinds | InferTypes => true - case _ => false + case _ => false } /** Returns circuit with when and last connection semantics resolved */ def run(c: Circuit): Circuit = { - val modulesx = c.modules map { + val modulesx = c.modules.map { case m: ExtModule => m - case m: Module => onModule(m) + case m: Module => onModule(m) } Circuit(c.info, modulesx, c.main) } @@ -74,13 +76,12 @@ object ExpandWhens extends Pass { // Does an expression contain WVoid inserted in this pass? def containsVoid(e: Expression): Boolean = e match { - case WVoid => true + case WVoid => true case ValidIf(_, value, _) => memoizedVoid(value) - case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv) - case _ => false + case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv) + case _ => false } - // Memoizes the node that holds a particular expression, if any val nodes = new NodeLookup @@ -95,18 +96,15 @@ object ExpandWhens extends Pass { * @param p predicate so far, used to update simulation constructs * @param s statement to expand */ - def expandWhens(netlist: Netlist, - defaults: Defaults, - p: Expression) - (s: Statement): Statement = s match { + def expandWhens(netlist: Netlist, defaults: Defaults, p: Expression)(s: Statement): Statement = s match { // For each non-register declaration, update netlist with value WVoid for each sink reference // Return self, unchanged case stmt @ (_: DefNode | EmptyStmt) => stmt case w: DefWire => - netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow) map (ref => we(ref) -> WVoid)) + netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow).map(ref => we(ref) -> WVoid)) w case w: DefMemory => - netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow) map (ref => we(ref) -> WVoid)) + netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow).map(ref => we(ref) -> WVoid)) w case w: WDefInstance => netlist ++= (getSinkRefs(w.name, w.tpe, SourceFlow).map(ref => we(ref) -> WVoid)) @@ -151,82 +149,88 @@ object ExpandWhens extends Pass { // Process combined maps because we only want to create 1 mux for each node // present in the conseq and/or alt - val memos = (conseqNetlist ++ altNetlist) map { case (lvalue, _) => - // Defaults in netlist get priority over those in defaults - val default = netlist get lvalue match { - case Some(v) => Some(v) - case None => getDefault(lvalue, defaults) - } - // info0 and info1 correspond to Mux infos, use info0 only if ValidIf - val (res, info0, info1) = default match { - case Some(defaultValue) => - val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue)) - val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue)) - (trueValue, falseValue) match { - case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo) - case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo) - case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo) - case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo) - } - case None => - // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt - (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo) - } + val memos = (conseqNetlist ++ altNetlist).map { + case (lvalue, _) => + // Defaults in netlist get priority over those in defaults + val default = netlist.get(lvalue) match { + case Some(v) => Some(v) + case None => getDefault(lvalue, defaults) + } + // info0 and info1 correspond to Mux infos, use info0 only if ValidIf + val (res, info0, info1) = default match { + case Some(defaultValue) => + val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue)) + val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue)) + (trueValue, falseValue) match { + case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo) + case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo) + case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo) + case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo) + } + case None => + // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt + (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo) + } - res match { - // Don't create a node to hold mux trees with void values - // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches) - case e if containsVoid(e) => - netlist(lvalue) = e - memoizedVoid += e // remember that this was void - EmptyStmt - case _: ValidIf | _: Mux | _: DoPrim => nodes get res match { - case Some(name) => - netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) + res match { + // Don't create a node to hold mux trees with void values + // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches) + case e if containsVoid(e) => + netlist(lvalue) = e + memoizedVoid += e // remember that this was void + EmptyStmt + case _: ValidIf | _: Mux | _: DoPrim => + nodes.get(res) match { + case Some(name) => + netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) + EmptyStmt + case None => + val name = namespace.newTemp + nodes(res) = name + netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) + // Use MultiInfo constructor to preserve NoInfos + val info = new MultiInfo(List(sx.info, info0, info1)) + DefNode(info, name, res) + } + case _ => + netlist(lvalue) = res EmptyStmt - case None => - val name = namespace.newTemp - nodes(res) = name - netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) - // Use MultiInfo constructor to preserve NoInfos - val info = new MultiInfo(List(sx.info, info0, info1)) - DefNode(info, name, res) } - case _ => - netlist(lvalue) = res - EmptyStmt - } } Block(Seq(conseqStmt, altStmt) ++ memos) - case block: Block => block map expandWhens(netlist, defaults, p) + case block: Block => block.map(expandWhens(netlist, defaults, p)) case _ => throwInternalError() } val netlist = new Netlist // Add ports to netlist - netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) => - getSinkRefs(name, tpe, to_flow(dir)) map (ref => we(ref) -> WVoid) + netlist ++= (m.ports.flatMap { + case Port(_, name, dir, tpe) => + getSinkRefs(name, tpe, to_flow(dir)).map(ref => we(ref) -> WVoid) }) // Do traversal and construct mutable datastructures val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet - val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++ - combineAttaches(attaches.toSeq) ++ simlist) + val newBody = Block( + Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++ + combineAttaches(attaches.toSeq) ++ simlist + ) Module(m.info, m.name, m.ports, newBody) } - /** Returns all references to all sink leaf subcomponents of a reference */ private def getSinkRefs(n: String, t: Type, g: Flow): Seq[Expression] = { val exps = create_exps(WRef(n, t, ExpKind, g)) - exps.flatMap { case exp => - exp.tpe match { - case AnalogType(w) => None - case _ => flow(exp) match { - case (DuplexFlow | SinkFlow) => Some(exp) - case _ => None + exps.flatMap { + case exp => + exp.tpe match { + case AnalogType(w) => None + case _ => + flow(exp) match { + case (DuplexFlow | SinkFlow) => Some(exp) + case _ => None + } } - } } } @@ -238,7 +242,7 @@ object ExpandWhens extends Pass { def handleInvalid(k: WrappedExpression, info: Info): Statement = if (attached.contains(k)) EmptyStmt else IsInvalid(info, k.e1) netlist.map { - case (k, WInvalid) => handleInvalid(k, NoInfo) + case (k, WInvalid) => handleInvalid(k, NoInfo) case (k, InfoExpr(info, WInvalid)) => handleInvalid(k, info) case (k, v) => val (info, expr) = unwrap(v) @@ -261,7 +265,7 @@ object ExpandWhens extends Pass { case Seq() => // None of these expressions is present in the attachMap AttachAcc(exprs, attachMap.size) case accs => // At least one expression present in the attachMap - val sorted = accs sortBy (_.idx) + val sorted = accs.sortBy(_.idx) AttachAcc((sorted.map(_.exprs) :+ exprs).flatten.distinct, sorted.head.idx) } attachMap ++= acc.exprs.map(_ -> acc) @@ -274,10 +278,11 @@ object ExpandWhens extends Pass { private def getDefault(lvalue: WrappedExpression, defaults: Defaults): Option[Expression] = { defaults match { case Nil => None - case head :: tail => head get lvalue match { - case Some(p) => Some(p) - case None => getDefault(lvalue, tail) - } + case head :: tail => + head.get(lvalue) match { + case Some(p) => Some(p) + case None => getDefault(lvalue, tail) + } } } @@ -290,10 +295,12 @@ object ExpandWhens extends Pass { class ExpandWhensAndCheck extends Transform with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses) + ) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform): Boolean = a match { case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true @@ -301,6 +308,6 @@ class ExpandWhensAndCheck extends Transform with DependencyAPIMigration { } override def execute(a: CircuitState): CircuitState = - Seq(ExpandWhens, CheckInitialization).foldLeft(a){ case (acc, tx) => tx.transform(acc) } + Seq(ExpandWhens, CheckInitialization).foldLeft(a) { case (acc, tx) => tx.transform(acc) } } |
