aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/ExpandWhens.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/ExpandWhens.scala')
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala173
1 files changed, 90 insertions, 83 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index ab7f02db..14d5d3ef 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -28,21 +28,23 @@ import collection.mutable
object ExpandWhens extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Resolved
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses)
+ ) ++ firrtl.stage.Forms.Resolved
override def invalidates(a: Transform): Boolean = a match {
case CheckInitialization | ResolveKinds | InferTypes => true
- case _ => false
+ case _ => false
}
/** Returns circuit with when and last connection semantics resolved */
def run(c: Circuit): Circuit = {
- val modulesx = c.modules map {
+ val modulesx = c.modules.map {
case m: ExtModule => m
- case m: Module => onModule(m)
+ case m: Module => onModule(m)
}
Circuit(c.info, modulesx, c.main)
}
@@ -74,13 +76,12 @@ object ExpandWhens extends Pass {
// Does an expression contain WVoid inserted in this pass?
def containsVoid(e: Expression): Boolean = e match {
- case WVoid => true
+ case WVoid => true
case ValidIf(_, value, _) => memoizedVoid(value)
- case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv)
- case _ => false
+ case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv)
+ case _ => false
}
-
// Memoizes the node that holds a particular expression, if any
val nodes = new NodeLookup
@@ -95,18 +96,15 @@ object ExpandWhens extends Pass {
* @param p predicate so far, used to update simulation constructs
* @param s statement to expand
*/
- def expandWhens(netlist: Netlist,
- defaults: Defaults,
- p: Expression)
- (s: Statement): Statement = s match {
+ def expandWhens(netlist: Netlist, defaults: Defaults, p: Expression)(s: Statement): Statement = s match {
// For each non-register declaration, update netlist with value WVoid for each sink reference
// Return self, unchanged
case stmt @ (_: DefNode | EmptyStmt) => stmt
case w: DefWire =>
- netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow) map (ref => we(ref) -> WVoid))
+ netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow).map(ref => we(ref) -> WVoid))
w
case w: DefMemory =>
- netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow) map (ref => we(ref) -> WVoid))
+ netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow).map(ref => we(ref) -> WVoid))
w
case w: WDefInstance =>
netlist ++= (getSinkRefs(w.name, w.tpe, SourceFlow).map(ref => we(ref) -> WVoid))
@@ -151,82 +149,88 @@ object ExpandWhens extends Pass {
// Process combined maps because we only want to create 1 mux for each node
// present in the conseq and/or alt
- val memos = (conseqNetlist ++ altNetlist) map { case (lvalue, _) =>
- // Defaults in netlist get priority over those in defaults
- val default = netlist get lvalue match {
- case Some(v) => Some(v)
- case None => getDefault(lvalue, defaults)
- }
- // info0 and info1 correspond to Mux infos, use info0 only if ValidIf
- val (res, info0, info1) = default match {
- case Some(defaultValue) =>
- val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue))
- val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue))
- (trueValue, falseValue) match {
- case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo)
- case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo)
- case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo)
- case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo)
- }
- case None =>
- // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
- (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo)
- }
+ val memos = (conseqNetlist ++ altNetlist).map {
+ case (lvalue, _) =>
+ // Defaults in netlist get priority over those in defaults
+ val default = netlist.get(lvalue) match {
+ case Some(v) => Some(v)
+ case None => getDefault(lvalue, defaults)
+ }
+ // info0 and info1 correspond to Mux infos, use info0 only if ValidIf
+ val (res, info0, info1) = default match {
+ case Some(defaultValue) =>
+ val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue))
+ val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue))
+ (trueValue, falseValue) match {
+ case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo)
+ case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo)
+ case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo)
+ case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo)
+ }
+ case None =>
+ // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
+ (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo)
+ }
- res match {
- // Don't create a node to hold mux trees with void values
- // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches)
- case e if containsVoid(e) =>
- netlist(lvalue) = e
- memoizedVoid += e // remember that this was void
- EmptyStmt
- case _: ValidIf | _: Mux | _: DoPrim => nodes get res match {
- case Some(name) =>
- netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
+ res match {
+ // Don't create a node to hold mux trees with void values
+ // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches)
+ case e if containsVoid(e) =>
+ netlist(lvalue) = e
+ memoizedVoid += e // remember that this was void
+ EmptyStmt
+ case _: ValidIf | _: Mux | _: DoPrim =>
+ nodes.get(res) match {
+ case Some(name) =>
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
+ EmptyStmt
+ case None =>
+ val name = namespace.newTemp
+ nodes(res) = name
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
+ // Use MultiInfo constructor to preserve NoInfos
+ val info = new MultiInfo(List(sx.info, info0, info1))
+ DefNode(info, name, res)
+ }
+ case _ =>
+ netlist(lvalue) = res
EmptyStmt
- case None =>
- val name = namespace.newTemp
- nodes(res) = name
- netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
- // Use MultiInfo constructor to preserve NoInfos
- val info = new MultiInfo(List(sx.info, info0, info1))
- DefNode(info, name, res)
}
- case _ =>
- netlist(lvalue) = res
- EmptyStmt
- }
}
Block(Seq(conseqStmt, altStmt) ++ memos)
- case block: Block => block map expandWhens(netlist, defaults, p)
+ case block: Block => block.map(expandWhens(netlist, defaults, p))
case _ => throwInternalError()
}
val netlist = new Netlist
// Add ports to netlist
- netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) =>
- getSinkRefs(name, tpe, to_flow(dir)) map (ref => we(ref) -> WVoid)
+ netlist ++= (m.ports.flatMap {
+ case Port(_, name, dir, tpe) =>
+ getSinkRefs(name, tpe, to_flow(dir)).map(ref => we(ref) -> WVoid)
})
// Do traversal and construct mutable datastructures
val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body)
val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet
- val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++
- combineAttaches(attaches.toSeq) ++ simlist)
+ val newBody = Block(
+ Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++
+ combineAttaches(attaches.toSeq) ++ simlist
+ )
Module(m.info, m.name, m.ports, newBody)
}
-
/** Returns all references to all sink leaf subcomponents of a reference */
private def getSinkRefs(n: String, t: Type, g: Flow): Seq[Expression] = {
val exps = create_exps(WRef(n, t, ExpKind, g))
- exps.flatMap { case exp =>
- exp.tpe match {
- case AnalogType(w) => None
- case _ => flow(exp) match {
- case (DuplexFlow | SinkFlow) => Some(exp)
- case _ => None
+ exps.flatMap {
+ case exp =>
+ exp.tpe match {
+ case AnalogType(w) => None
+ case _ =>
+ flow(exp) match {
+ case (DuplexFlow | SinkFlow) => Some(exp)
+ case _ => None
+ }
}
- }
}
}
@@ -238,7 +242,7 @@ object ExpandWhens extends Pass {
def handleInvalid(k: WrappedExpression, info: Info): Statement =
if (attached.contains(k)) EmptyStmt else IsInvalid(info, k.e1)
netlist.map {
- case (k, WInvalid) => handleInvalid(k, NoInfo)
+ case (k, WInvalid) => handleInvalid(k, NoInfo)
case (k, InfoExpr(info, WInvalid)) => handleInvalid(k, info)
case (k, v) =>
val (info, expr) = unwrap(v)
@@ -261,7 +265,7 @@ object ExpandWhens extends Pass {
case Seq() => // None of these expressions is present in the attachMap
AttachAcc(exprs, attachMap.size)
case accs => // At least one expression present in the attachMap
- val sorted = accs sortBy (_.idx)
+ val sorted = accs.sortBy(_.idx)
AttachAcc((sorted.map(_.exprs) :+ exprs).flatten.distinct, sorted.head.idx)
}
attachMap ++= acc.exprs.map(_ -> acc)
@@ -274,10 +278,11 @@ object ExpandWhens extends Pass {
private def getDefault(lvalue: WrappedExpression, defaults: Defaults): Option[Expression] = {
defaults match {
case Nil => None
- case head :: tail => head get lvalue match {
- case Some(p) => Some(p)
- case None => getDefault(lvalue, tail)
- }
+ case head :: tail =>
+ head.get(lvalue) match {
+ case Some(p) => Some(p)
+ case None => getDefault(lvalue, tail)
+ }
}
}
@@ -290,10 +295,12 @@ object ExpandWhens extends Pass {
class ExpandWhensAndCheck extends Transform with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses)
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform): Boolean = a match {
case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true
@@ -301,6 +308,6 @@ class ExpandWhensAndCheck extends Transform with DependencyAPIMigration {
}
override def execute(a: CircuitState): CircuitState =
- Seq(ExpandWhens, CheckInitialization).foldLeft(a){ case (acc, tx) => tx.transform(acc) }
+ Seq(ExpandWhens, CheckInitialization).foldLeft(a) { case (acc, tx) => tx.transform(acc) }
}