aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/transforms/FlattenRegUpdate.scala')
-rw-r--r--src/main/scala/firrtl/transforms/FlattenRegUpdate.scala39
1 files changed, 26 insertions, 13 deletions
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
index ea694719..4bda25ce 100644
--- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
+++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
@@ -7,11 +7,19 @@ import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.options.Dependency
+import firrtl.InfoExpr.orElse
import scala.collection.mutable
object FlattenRegUpdate {
+ // Combination function for dealing with inlining of muxes and the handling of Triples of infos
+ private def combineInfos(muxInfo: Info, tinfo: Info, finfo: Info): Info = {
+ val (eninfo, tinfoAlt, finfoAlt) = MultiInfo.demux(muxInfo)
+ // Use MultiInfo constructor to preserve NoInfos
+ new MultiInfo(List(eninfo, orElse(tinfo, tinfoAlt), orElse(finfo, finfoAlt)))
+ }
+
/** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */
type Netlist = mutable.HashMap[WrappedExpression, Expression]
@@ -26,10 +34,12 @@ object FlattenRegUpdate {
val netlist = new Netlist()
def onStmt(stmt: Statement): Statement = {
stmt.map(onStmt) match {
- case Connect(_, lhs, rhs) =>
- netlist(lhs) = rhs
- case DefNode(_, nname, rhs) =>
- netlist(WRef(nname)) = rhs
+ case Connect(info, lhs, rhs) =>
+ val expr = if (info == NoInfo) rhs else InfoExpr(info, rhs)
+ netlist(lhs) = expr
+ case DefNode(info, nname, rhs) =>
+ val expr = if (info == NoInfo) rhs else InfoExpr(info, rhs)
+ netlist(WRef(nname)) = expr
case _: IsInvalid => throwInternalError("Unexpected IsInvalid, should have been removed by now")
case _ => // Do nothing
}
@@ -64,19 +74,21 @@ object FlattenRegUpdate {
val regUpdates = mutable.ArrayBuffer.empty[Connect]
val netlist = buildNetlist(mod)
- def constructRegUpdate(e: Expression): Expression = {
+ def constructRegUpdate(e: Expression): (Info, Expression) = {
+ import InfoExpr.unwrap
// Only walk netlist for nodes and wires, NOT registers or other state
- val expr = kind(e) match {
- case NodeKind | WireKind => netlist.getOrElse(e, e)
- case _ => e
+ val (info, expr) = kind(e) match {
+ case NodeKind | WireKind => unwrap(netlist.getOrElse(e, e))
+ case _ => unwrap(e)
}
expr match {
case mux: Mux if canFlatten(mux) =>
- val tvalx = constructRegUpdate(mux.tval)
- val fvalx = constructRegUpdate(mux.fval)
- mux.copy(tval = tvalx, fval = fvalx)
+ val (tinfo, tvalx) = constructRegUpdate(mux.tval)
+ val (finfo, fvalx) = constructRegUpdate(mux.fval)
+ val infox = combineInfos(info, tinfo, finfo)
+ (infox, mux.copy(tval = tvalx, fval = fvalx))
// Return the original expression to end flattening
- case _ => e
+ case _ => unwrap(e)
}
}
@@ -85,7 +97,8 @@ object FlattenRegUpdate {
assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero,
"Synchronous reset should have already been made explicit!")
val ref = WRef(reg)
- val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref)))
+ val (info, rhs) = constructRegUpdate(netlist.getOrElse(ref, ref))
+ val update = Connect(info, ref, rhs)
regUpdates += update
reg
// Remove connections to Registers so we preserve LowFirrtl single-connection semantics