diff options
Diffstat (limited to 'src/main/scala/firrtl/transforms/FlattenRegUpdate.scala')
| -rw-r--r-- | src/main/scala/firrtl/transforms/FlattenRegUpdate.scala | 39 |
1 files changed, 26 insertions, 13 deletions
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala index ea694719..4bda25ce 100644 --- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -7,11 +7,19 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.Utils._ import firrtl.options.Dependency +import firrtl.InfoExpr.orElse import scala.collection.mutable object FlattenRegUpdate { + // Combination function for dealing with inlining of muxes and the handling of Triples of infos + private def combineInfos(muxInfo: Info, tinfo: Info, finfo: Info): Info = { + val (eninfo, tinfoAlt, finfoAlt) = MultiInfo.demux(muxInfo) + // Use MultiInfo constructor to preserve NoInfos + new MultiInfo(List(eninfo, orElse(tinfo, tinfoAlt), orElse(finfo, finfoAlt))) + } + /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */ type Netlist = mutable.HashMap[WrappedExpression, Expression] @@ -26,10 +34,12 @@ object FlattenRegUpdate { val netlist = new Netlist() def onStmt(stmt: Statement): Statement = { stmt.map(onStmt) match { - case Connect(_, lhs, rhs) => - netlist(lhs) = rhs - case DefNode(_, nname, rhs) => - netlist(WRef(nname)) = rhs + case Connect(info, lhs, rhs) => + val expr = if (info == NoInfo) rhs else InfoExpr(info, rhs) + netlist(lhs) = expr + case DefNode(info, nname, rhs) => + val expr = if (info == NoInfo) rhs else InfoExpr(info, rhs) + netlist(WRef(nname)) = expr case _: IsInvalid => throwInternalError("Unexpected IsInvalid, should have been removed by now") case _ => // Do nothing } @@ -64,19 +74,21 @@ object FlattenRegUpdate { val regUpdates = mutable.ArrayBuffer.empty[Connect] val netlist = buildNetlist(mod) - def constructRegUpdate(e: Expression): Expression = { + def constructRegUpdate(e: Expression): (Info, Expression) = { + import InfoExpr.unwrap // Only walk netlist for nodes and wires, NOT registers or other state - val expr = kind(e) match { - case NodeKind | WireKind => netlist.getOrElse(e, e) - case _ => e + val (info, expr) = kind(e) match { + case NodeKind | WireKind => unwrap(netlist.getOrElse(e, e)) + case _ => unwrap(e) } expr match { case mux: Mux if canFlatten(mux) => - val tvalx = constructRegUpdate(mux.tval) - val fvalx = constructRegUpdate(mux.fval) - mux.copy(tval = tvalx, fval = fvalx) + val (tinfo, tvalx) = constructRegUpdate(mux.tval) + val (finfo, fvalx) = constructRegUpdate(mux.fval) + val infox = combineInfos(info, tinfo, finfo) + (infox, mux.copy(tval = tvalx, fval = fvalx)) // Return the original expression to end flattening - case _ => e + case _ => unwrap(e) } } @@ -85,7 +97,8 @@ object FlattenRegUpdate { assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero, "Synchronous reset should have already been made explicit!") val ref = WRef(reg) - val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref))) + val (info, rhs) = constructRegUpdate(netlist.getOrElse(ref, ref)) + val update = Connect(info, ref, rhs) regUpdates += update reg // Remove connections to Registers so we preserve LowFirrtl single-connection semantics |
