diff options
| author | Jack Koenig | 2020-07-16 17:27:52 -0700 |
|---|---|---|
| committer | GitHub | 2020-07-17 00:27:52 +0000 |
| commit | b25cd542192132161f3c162f7e782a9cbb2d09ae (patch) | |
| tree | 9f30acdc1cbaf112c944169cac812be441a896bd /src/main/scala/firrtl/transforms | |
| parent | c4cc6bc5b614bd7f5383f8a85c7fc81facdc4b20 (diff) | |
Propagate source locators to register update always blocks (#1743)
* [WIP] Propagate source locators to Verilog if-else emission
* Add and fix tests for reg update info propagation
* Add limited source locator propagation in ConstProp
Support propagating source locators on connections or nodes where the
right-hand side is simply a reference. This case comes up a lot for
registers without a synchronous reset.
node _T_1 = x @[MyFile.scala 12:10]
node _T_2 = _T_1
z <= x
Previousy the source locator would be lost, now the result is:
z <= x @[MyFile.scala 12:10]
* Address review comments
Co-authored-by: Schuyler Eldridge <schuyler.eldridge@ibm.com>
Co-authored-by: Schuyler Eldridge <schuyler.eldridge@ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src/main/scala/firrtl/transforms')
3 files changed, 74 insertions, 31 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 29410c7f..000adc15 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -458,9 +458,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case _ => constPropMuxCond(m) } - private def constPropNodeRef(r: WRef, e: Expression) = e match { - case _: UIntLiteral | _: SIntLiteral | _: WRef => e - case _ => r + private def constPropNodeRef(r: WRef, e: Expression): Expression = { + def doit(ex: Expression) = ex match { + case _: UIntLiteral | _: SIntLiteral | _: WRef => ex + case _ => r + } + InfoExpr.map(e)(doit) } // Is "a" a "better name" than "b"? @@ -475,7 +478,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case p: DoPrim => constPropPrim(p) case m: Mux => constPropMux(m) case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) + constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) => val module = instMap(inst.Instance) // Check constSubOutputs to see if the submodule is driving a constant @@ -487,6 +490,24 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated) } + /** Hacky way of propagating source locators across nodes and connections that have just a + * reference on the right-hand side + * + * @todo generalize source locator propagation across Expressions and delete this method + * @todo is the `orElse` the way we want to do propagation here? + */ + private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String]) + (stmt: Statement): Statement = stmt match { + // We check rname because inlining it would cause the original declaration to go away + case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + node.copy(info = InfoExpr.orElse(info1, info0)) + case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + con.copy(info = InfoExpr.orElse(info1, info0)) + case other => other + } + /* Constant propagate a Module * * Two pass process @@ -547,7 +568,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res ref.copy(name = swapMap(rname)) // Only const prop on the rhs case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) + constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case x => x } if (old ne propagated) { @@ -577,7 +598,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } // When propagating a reference, check if we want to keep the name that would be deleted - def propagateRef(lname: String, value: Expression): Unit = { + def propagateRef(lname: String, value: Expression, info: Info): Unit = { value match { case WRef(rname,_,kind,_) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind => assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a @@ -585,19 +606,22 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res swapMap += (lname -> rname, rname -> lname) case _ => } - nodeMap(lname) = value + nodeMap(lname) = InfoExpr.wrap(info, value) } def constPropStmt(s: Statement): Statement = { - val stmtx = s map constPropStmt map constPropExpression(nodeMap, instMap, constSubOutputs) + val s0 = s.map(constPropStmt) // Statement recurse + val s1 = propagateDirectConnectionInfoOnly(nodeMap, dontTouches)(s0) // hacky source locator propagation + val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs)) // propagate sub-Expressions // Record things that should be propagated stmtx match { - case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value) + case DefNode(info, name, value) if !dontTouches.contains(name) => + propagateRef(name, value, info) case reg: DefRegister if reg.reset.tpe == AsyncResetType => asyncResetRegs(reg.name) = reg - case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => + case Connect(info, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe)) - propagateRef(wname, exprx) + propagateRef(wname, exprx, info) // Record constants driving outputs case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) => val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] @@ -633,7 +657,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case lit: Literal => baseCase.resolve(RegCPEntry(UnboundConstant, BoundConstant(lit))) case WRef(regName, _, RegKind, _) => baseCase.resolve(RegCPEntry(BoundConstant(regName), UnboundConstant)) case WRef(nodeName, _, NodeKind, _) if nodeMap.contains(nodeName) => - val cached = nodeRegCPEntries.getOrElseUpdate(nodeName, { regConstant(nodeMap(nodeName), unbound) }) + val (_, expr) = InfoExpr.unwrap(nodeMap(nodeName)) + val cached = nodeRegCPEntries.getOrElseUpdate(nodeName, { regConstant(expr, unbound) }) baseCase.resolve(cached) case Mux(_, tval, fval, _) => regConstant(tval, baseCase).resolve(regConstant(fval, baseCase)) @@ -676,8 +701,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Actually transform some statements stmtx match { // Propagate connections to references - case Connect(info, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouches.contains(rname) => - Connect(info, lhs, nodeMap(rname)) + case Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouches.contains(rname) => + val (info1, value) = InfoExpr.unwrap(nodeMap(rname)) + // Is this the right info combination/propagation function? + // See propagateDirectConnectionInfoOnly + Connect(InfoExpr.orElse(info1, info0), lhs, value) // If an Attach has at least 1 port, any wires are redundant and can be removed case Attach(info, exprs) if exprs.exists(kind(_) == PortKind) => Attach(info, exprs.filterNot(kind(_) == WireKind)) diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala index ea694719..4bda25ce 100644 --- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -7,11 +7,19 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.Utils._ import firrtl.options.Dependency +import firrtl.InfoExpr.orElse import scala.collection.mutable object FlattenRegUpdate { + // Combination function for dealing with inlining of muxes and the handling of Triples of infos + private def combineInfos(muxInfo: Info, tinfo: Info, finfo: Info): Info = { + val (eninfo, tinfoAlt, finfoAlt) = MultiInfo.demux(muxInfo) + // Use MultiInfo constructor to preserve NoInfos + new MultiInfo(List(eninfo, orElse(tinfo, tinfoAlt), orElse(finfo, finfoAlt))) + } + /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */ type Netlist = mutable.HashMap[WrappedExpression, Expression] @@ -26,10 +34,12 @@ object FlattenRegUpdate { val netlist = new Netlist() def onStmt(stmt: Statement): Statement = { stmt.map(onStmt) match { - case Connect(_, lhs, rhs) => - netlist(lhs) = rhs - case DefNode(_, nname, rhs) => - netlist(WRef(nname)) = rhs + case Connect(info, lhs, rhs) => + val expr = if (info == NoInfo) rhs else InfoExpr(info, rhs) + netlist(lhs) = expr + case DefNode(info, nname, rhs) => + val expr = if (info == NoInfo) rhs else InfoExpr(info, rhs) + netlist(WRef(nname)) = expr case _: IsInvalid => throwInternalError("Unexpected IsInvalid, should have been removed by now") case _ => // Do nothing } @@ -64,19 +74,21 @@ object FlattenRegUpdate { val regUpdates = mutable.ArrayBuffer.empty[Connect] val netlist = buildNetlist(mod) - def constructRegUpdate(e: Expression): Expression = { + def constructRegUpdate(e: Expression): (Info, Expression) = { + import InfoExpr.unwrap // Only walk netlist for nodes and wires, NOT registers or other state - val expr = kind(e) match { - case NodeKind | WireKind => netlist.getOrElse(e, e) - case _ => e + val (info, expr) = kind(e) match { + case NodeKind | WireKind => unwrap(netlist.getOrElse(e, e)) + case _ => unwrap(e) } expr match { case mux: Mux if canFlatten(mux) => - val tvalx = constructRegUpdate(mux.tval) - val fvalx = constructRegUpdate(mux.fval) - mux.copy(tval = tvalx, fval = fvalx) + val (tinfo, tvalx) = constructRegUpdate(mux.tval) + val (finfo, fvalx) = constructRegUpdate(mux.fval) + val infox = combineInfos(info, tinfo, finfo) + (infox, mux.copy(tval = tvalx, fval = fvalx)) // Return the original expression to end flattening - case _ => e + case _ => unwrap(e) } } @@ -85,7 +97,8 @@ object FlattenRegUpdate { assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero, "Synchronous reset should have already been made explicit!") val ref = WRef(reg) - val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref))) + val (info, rhs) = constructRegUpdate(netlist.getOrElse(ref, ref)) + val update = Connect(info, ref, rhs) regUpdates += update reg // Remove connections to Registers so we preserve LowFirrtl single-connection semantics diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index 2db93626..6b3a9d07 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -30,7 +30,7 @@ object RemoveReset extends Transform with DependencyAPIMigration { case _ => false } - private case class Reset(cond: Expression, value: Expression) + private case class Reset(cond: Expression, value: Expression, info: Info) /** Return an immutable set of all invalid expressions in a module * @param m a module @@ -58,14 +58,16 @@ object RemoveReset extends Transform with DependencyAPIMigration { reg.copy(reset = Utils.zero, init = WRef(reg)) case reg @ DefRegister(_, rname, _, _, Utils.zero, _) => reg.copy(init = WRef(reg)) // canonicalize - case reg @ DefRegister(_, rname, _, _, reset, init) if reset.tpe != AsyncResetType => + case reg @ DefRegister(info , rname, _, _, reset, init) if reset.tpe != AsyncResetType => // Add register reset to map - resets(rname) = Reset(reset, init) + resets(rname) = Reset(reset, init, info) reg.copy(reset = Utils.zero, init = WRef(reg)) case Connect(info, ref @ WRef(rname, _, RegKind, _), expr) if resets.contains(rname) => val reset = resets(rname) val muxType = Utils.mux_type_and_widths(reset.value, expr) - Connect(info, ref, Mux(reset.cond, reset.value, expr, muxType)) + // Use reg source locator for mux enable and true value since that's where they're defined + val infox = MultiInfo(reset.info, reset.info, info) + Connect(infox, ref, Mux(reset.cond, reset.value, expr, muxType)) case other => other map onStmt } } |
