From b25cd542192132161f3c162f7e782a9cbb2d09ae Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Thu, 16 Jul 2020 17:27:52 -0700 Subject: Propagate source locators to register update always blocks (#1743) * [WIP] Propagate source locators to Verilog if-else emission * Add and fix tests for reg update info propagation * Add limited source locator propagation in ConstProp Support propagating source locators on connections or nodes where the right-hand side is simply a reference. This case comes up a lot for registers without a synchronous reset. node _T_1 = x @[MyFile.scala 12:10] node _T_2 = _T_1 z <= x Previousy the source locator would be lost, now the result is: z <= x @[MyFile.scala 12:10] * Address review comments Co-authored-by: Schuyler Eldridge Co-authored-by: Schuyler Eldridge Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>--- src/main/scala/firrtl/Emitter.scala | 48 +++++---- src/main/scala/firrtl/WIR.scala | 38 ++++++++ src/main/scala/firrtl/ir/IR.scala | 12 ++- src/main/scala/firrtl/passes/ExpandWhens.scala | 108 +++++++++------------ .../firrtl/transforms/ConstantPropagation.scala | 56 ++++++++--- .../scala/firrtl/transforms/FlattenRegUpdate.scala | 39 +++++--- src/main/scala/firrtl/transforms/RemoveReset.scala | 10 +- src/test/scala/firrtlTests/InfoSpec.scala | 71 +++++++++++++- 8 files changed, 263 insertions(+), 119 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index b5474769..f9787a48 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -550,17 +550,19 @@ class VerilogEmitter extends SeqTransform with Emitter { this(Seq(), Map.empty, m, moduleMap, "", new EmissionOptions(Seq.empty))(writer) } - val netlist = mutable.LinkedHashMap[WrappedExpression, Expression]() + val netlist = mutable.LinkedHashMap[WrappedExpression, InfoExpr]() val namespace = Namespace(m) namespace.newName("_RAND") // Start rand names at _RAND_0 def build_netlist(s: Statement): Unit = { s.foreach(build_netlist) s match { - case sx: Connect => netlist(sx.loc) = sx.expr + case sx: Connect => netlist(sx.loc) = InfoExpr(sx.info, sx.expr) case sx: IsInvalid => error("Should have removed these!") + // TODO Since only register update and memories use the netlist anymore, I think nodes are + // unnecessary case sx: DefNode => val e = WRef(sx.name, sx.value.tpe, NodeKind, SourceFlow) - netlist(e) = sx.value + netlist(e) = InfoExpr(sx.info, sx.value) case _ => } } @@ -663,6 +665,9 @@ class VerilogEmitter extends SeqTransform with Emitter { def declare(b: String, n: String, t: Type, info: Info): Unit = declare(b, n, t, info, None) + def assign(e: Expression, infoExpr: InfoExpr): Unit = + assign(e, infoExpr.expr, infoExpr.info) + def assign(e: Expression, value: Expression, info: Info): Unit = { assigns += Seq("assign ", e, " = ", value, ";", info) } @@ -684,19 +689,20 @@ class VerilogEmitter extends SeqTransform with Emitter { } def regUpdate(r: Expression, clk: Expression, reset: Expression, init: Expression) = { - def addUpdate(expr: Expression, tabs: String): Seq[Seq[Any]] = expr match { + def addUpdate(info: Info, expr: Expression, tabs: String): Seq[Seq[Any]] = expr match { case m: Mux => if (m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly") if (m.tpe == AsyncResetType) throw EmitterException("Cannot emit async reset muxes directly") - lazy val _if = Seq(tabs, "if (", m.cond, ") begin") + val (eninfo, tinfo, finfo) = MultiInfo.demux(info) + lazy val _if = Seq(tabs, "if (", m.cond, ") begin", eninfo) lazy val _else = Seq(tabs, "end else begin") - lazy val _ifNot = Seq(tabs, "if (!(", m.cond, ")) begin") + lazy val _ifNot = Seq(tabs, "if (!(", m.cond, ")) begin", eninfo) lazy val _end = Seq(tabs, "end") - lazy val _true = addUpdate(m.tval, tabs + tab) - lazy val _false = addUpdate(m.fval, tabs + tab) + lazy val _true = addUpdate(tinfo, m.tval, tabs + tab) + lazy val _false = addUpdate(finfo, m.fval, tabs + tab) lazy val _elseIfFalse = { - val _falsex = addUpdate(m.fval, tabs) // _false, but without an additional tab + val _falsex = addUpdate(finfo, m.fval, tabs) // _false, but without an additional tab Seq(tabs, "end else ", _falsex.head.tail) +: _falsex.tail } @@ -719,15 +725,17 @@ class VerilogEmitter extends SeqTransform with Emitter { case (_, _: Mux) => (_if +: _true) ++ _elseIfFalse case _ => (_if +: _true :+ _else) ++ _false :+ _end } - case e => Seq(Seq(tabs, r, " <= ", e, ";")) + case e => Seq(Seq(tabs, r, " <= ", e, ";", info)) } if (weq(init, r)) { // Synchronous Reset - noResetAlwaysBlocks.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(netlist(r), "") + val InfoExpr(info, e) = netlist(r) + noResetAlwaysBlocks.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(info, e, "") } else { // Asynchronous Reset assert(reset.tpe == AsyncResetType, "Error! Synchronous reset should have been removed!") val tv = init - val fv = netlist(r) - asyncResetAlwaysBlocks += ((clk, reset, addUpdate(Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), ""))) + val InfoExpr(finfo, fv) = netlist(r) + // TODO add register info argument and build a MultiInfo to pass + asyncResetAlwaysBlocks += ((clk, reset, addUpdate(NoInfo, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), ""))) } } @@ -996,7 +1004,7 @@ class VerilogEmitter extends SeqTransform with Emitter { // declare("wire", LowerTypes.loweredName(en), en.tpe) //; Read port - assign(addr, netlist(addr), NoInfo) // Info should come from addr connection + assign(addr, netlist(addr)) // assign(en, netlist(en)) //;Connects value to m.r.en val mem = WRef(sx.name, memType(sx), MemKind, UnknownFlow) val memPort = WSubAccess(mem, addr, sx.dataType, UnknownFlow) @@ -1015,7 +1023,8 @@ class VerilogEmitter extends SeqTransform with Emitter { val mask = memPortField(sx, w, "mask") val en = memPortField(sx, w, "en") //Ports should share an always@posedge, so can't have intermediary wire - val clk = netlist(memPortField(sx, w, "clk")) + // TODO should we use the info here for anything? + val InfoExpr(_, clk) = netlist(memPortField(sx, w, "clk")) declare("wire", LowerTypes.loweredName(data), data.tpe, sx.info) declare("wire", LowerTypes.loweredName(addr), addr.tpe, sx.info) @@ -1023,11 +1032,10 @@ class VerilogEmitter extends SeqTransform with Emitter { declare("wire", LowerTypes.loweredName(en), en.tpe, sx.info) // Write port - // Info should come from netlist - assign(data, netlist(data), NoInfo) - assign(addr, netlist(addr), NoInfo) - assign(mask, netlist(mask), NoInfo) - assign(en, netlist(en), NoInfo) + assign(data, netlist(data)) + assign(addr, netlist(addr)) + assign(mask, netlist(mask)) + assign(en, netlist(en)) val mem = WRef(sx.name, memType(sx), MemKind, UnknownFlow) val memPort = WSubAccess(mem, addr, sx.dataType, UnknownFlow) diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 2f1daadd..cda22d27 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -169,6 +169,44 @@ case object Dshlw extends PrimOp { } } +/** Internal class used for propagating [[Info]] across [[Expression]]s + * + * In particular, this is useful in "Netlist" datastructures mapping node or other [[Statement]]s + * to [[Expression]]s + * + * @note This is not allowed to leak from any transform + */ +private[firrtl] case class InfoExpr(info: Info, expr: Expression) extends Expression { + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = () + def foreachWidth(f: Width => Unit): Unit = () + def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(this.expr)) + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this + def tpe: Type = expr.tpe + + // Members declared in firrtl.ir.FirrtlNode + def serialize: String = s"(${expr.serialize}: ${info.serialize})" +} + +private[firrtl] object InfoExpr { + def wrap(info: Info, expr: Expression): Expression = + if (info == NoInfo) expr else InfoExpr(info, expr) + + def unwrap(expr: Expression): (Info, Expression) = expr match { + case InfoExpr(i, e) => (i, e) + case other => (NoInfo, other) + } + + def orElse(info: Info, alt: => Info): Info = if (info == NoInfo) alt else info + + // TODO this the right name? + def map(expr: Expression)(f: Expression => Expression): Expression = expr match { + case ie: InfoExpr => ie.mapExpr(f) + case e => f(e) + } +} + object WrappedExpression { def apply(e: Expression) = new WrappedExpression(e) def we(e: Expression) = new WrappedExpression(e) diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 734b475d..275cbe51 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -95,10 +95,18 @@ object MultiInfo { val infosx = infos.filterNot(_ == NoInfo) infosx.size match { case 0 => NoInfo - case 1 => infosx.head - case _ => new MultiInfo(infosx) + case 1 => infos.head + case _ => new MultiInfo(infos) } } + + // Internal utility for unpacking implicit MultiInfo structure for muxes + // TODO should this be made into an API? + private[firrtl] def demux(info: Info): (Info, Info, Info) = info match { + case MultiInfo(infos) if infos.lengthCompare(3) == 0 => (infos(0), infos(1), infos(2)) + case other => (other, NoInfo, NoInfo) // if not exactly 3, we don't know what to do + } + private def flattenInfo(infos: Seq[Info]): Seq[FileInfo] = infos.flatMap { case NoInfo => Seq() case f : FileInfo => Seq(f) diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index e7eebb57..ab4c9bfa 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -9,6 +9,7 @@ import firrtl.Mappers._ import firrtl.PrimOps._ import firrtl.WrappedExpression._ import firrtl.options.Dependency +import firrtl.InfoExpr.unwrap import annotation.tailrec import collection.mutable @@ -42,12 +43,7 @@ object ExpandWhens extends Pass { def run(c: Circuit): Circuit = { val modulesx = c.modules map { case m: ExtModule => m - case m: Module => - val (netlist, simlist, attaches, bodyx, sourceInfoMap) = expandWhens(m) - val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet - val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs, sourceInfoMap) ++ - combineAttaches(attaches) ++ simlist) - Module(m.info, m.name, m.ports, newBody) + case m: Module => onModule(m) } Circuit(c.info, modulesx, c.main) } @@ -61,15 +57,6 @@ object ExpandWhens extends Pass { /** Maps a reference to whatever connects to it. Used to resolve last connect semantics */ type Netlist = mutable.LinkedHashMap[WrappedExpression, Expression] - /** Collects Info data serialized names for nodes, aggregating into MultiInfo when necessary */ - class InfoMap extends mutable.HashMap[String, Info] { - override def default(key: String): Info = { - val x = NoInfo - this(key) = x - x - } - } - /** Contains all simulation constructs */ type Simlist = mutable.ArrayBuffer[Statement] @@ -78,37 +65,29 @@ object ExpandWhens extends Pass { */ type Defaults = Seq[mutable.Map[WrappedExpression, Expression]] - - /** Expands a module's when statements - * @param m Module to expand - * @note Netlist maps a reference to whatever connects to it - * @note Simlist contains all simulation constructs in m - * @note Seq[Attach] contains all Attach statements (unsimplified) - * @note Statement contains all declarations in the module (including DefNode's) - */ - def expandWhens(m: Module): (Netlist, Simlist, Seq[Attach], Statement, InfoMap) = { + /** Expands a module's when statements */ + private def onModule(m: Module): Module = { val namespace = Namespace(m) val simlist = new Simlist // Memoizes if an expression contains any WVoids inserted in this pass val memoizedVoid = new mutable.HashSet[WrappedExpression] += WVoid + // Does an expression contain WVoid inserted in this pass? + def containsVoid(e: Expression): Boolean = e match { + case WVoid => true + case ValidIf(_, value, _) => memoizedVoid(value) + case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv) + case _ => false + } + + // Memoizes the node that holds a particular expression, if any val nodes = new NodeLookup // Seq of attaches in order lazy val attaches = mutable.ArrayBuffer.empty[Attach] - val infoMap: InfoMap = new InfoMap - - /* Adds into into map, aggregates info into MultiInfo where necessary - * @param key serialized name of node - * @param info info being recorded - */ - def saveInfo(key: String, info: Info): Unit = { - infoMap(key) = infoMap(key) ++ info - } - /* Removes connections/attaches from the statement * Mutates namespace, simlist, nodes, attaches * Mutates input netlist @@ -133,15 +112,13 @@ object ExpandWhens extends Pass { case w: WDefInstance => netlist ++= (getSinkRefs(w.name, w.tpe, SourceFlow).map(ref => we(ref) -> WVoid)) w - // Update netlist with self reference for each sink reference - // Return self, unchanged case r: DefRegister => - netlist ++= (getSinkRefs(r.name, r.tpe, DuplexFlow) map (ref => we(ref) -> ref)) + // Update netlist with self reference for each sink reference + netlist ++= getSinkRefs(r.name, r.tpe, DuplexFlow).map(ref => we(ref) -> InfoExpr(r.info, ref)) r // For value assignments, update netlist/attaches and return EmptyStmt case c: Connect => - saveInfo(c.loc.serialize, c.info) - netlist(c.loc) = c.expr + netlist(c.loc) = InfoExpr(c.info, c.expr) EmptyStmt case c: IsInvalid => netlist(c.expr) = WInvalid @@ -179,27 +156,20 @@ object ExpandWhens extends Pass { case Some(v) => Some(v) case None => getDefault(lvalue, defaults) } - val res = default match { + // info0 and info1 correspond to Mux infos, use info0 only if ValidIf + val (res, info0, info1) = default match { case Some(defaultValue) => - val trueValue = conseqNetlist getOrElse (lvalue, defaultValue) - val falseValue = altNetlist getOrElse (lvalue, defaultValue) + val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue)) + val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue)) (trueValue, falseValue) match { - case (WInvalid, WInvalid) => WInvalid - case (WInvalid, fv) => ValidIf(NOT(sx.pred), fv, fv.tpe) - case (tv, WInvalid) => ValidIf(sx.pred, tv, tv.tpe) - case (tv, fv) => Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)) //Muxing clocks will be checked during type checking + case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo) + case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo) + case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo) + case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo) } case None => // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt - conseqNetlist getOrElse (lvalue, altNetlist(lvalue)) - } - - // Does an expression contain WVoid inserted in this pass? - def containsVoid(e: Expression): Boolean = e match { - case WVoid => true - case ValidIf(_, value, _) => memoizedVoid(value) - case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv) - case _ => false + (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo) } res match { @@ -217,7 +187,9 @@ object ExpandWhens extends Pass { val name = namespace.newTemp nodes(res) = name netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) - DefNode(sx.info, name, res) + // Use MultiInfo constructor to preserve NoInfos + val info = new MultiInfo(List(sx.info, info0, info1)) + DefNode(info, name, res) } case _ => netlist(lvalue) = res @@ -233,8 +205,13 @@ object ExpandWhens extends Pass { netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) => getSinkRefs(name, tpe, to_flow(dir)) map (ref => we(ref) -> WVoid) }) + // Do traversal and construct mutable datastructures val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) - (netlist, simlist, attaches, bodyx, infoMap) + + val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet + val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++ + combineAttaches(attaches) ++ simlist) + Module(m.info, m.name, m.ports, newBody) } @@ -253,17 +230,20 @@ object ExpandWhens extends Pass { } /** Returns all connections/invalidations in the circuit - * @todo Preserve Info * @note Remove IsInvalids on attached Analog-typed components */ - private def expandNetlist(netlist: Netlist, attached: Set[WrappedExpression], sourceInfoMap: InfoMap) = - netlist map { - case (k, WInvalid) => // Remove IsInvalids on attached Analog types - if (attached.contains(k)) EmptyStmt else IsInvalid(NoInfo, k.e1) + private def expandNetlist(netlist: Netlist, attached: Set[WrappedExpression]) = { + // Remove IsInvalids on attached Analog types + def handleInvalid(k: WrappedExpression, info: Info): Statement = + if (attached.contains(k)) EmptyStmt else IsInvalid(info, k.e1) + netlist.map { + case (k, WInvalid) => handleInvalid(k, NoInfo) + case (k, InfoExpr(info, WInvalid)) => handleInvalid(k, info) case (k, v) => - val info = sourceInfoMap(k.e1.serialize) - Connect(info, k.e1, v) + val (info, expr) = unwrap(v) + Connect(info, k.e1, expr) } + } /** Returns new sequence of combined Attaches * @todo Preserve Info diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 29410c7f..000adc15 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -458,9 +458,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case _ => constPropMuxCond(m) } - private def constPropNodeRef(r: WRef, e: Expression) = e match { - case _: UIntLiteral | _: SIntLiteral | _: WRef => e - case _ => r + private def constPropNodeRef(r: WRef, e: Expression): Expression = { + def doit(ex: Expression) = ex match { + case _: UIntLiteral | _: SIntLiteral | _: WRef => ex + case _ => r + } + InfoExpr.map(e)(doit) } // Is "a" a "better name" than "b"? @@ -475,7 +478,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case p: DoPrim => constPropPrim(p) case m: Mux => constPropMux(m) case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) + constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) => val module = instMap(inst.Instance) // Check constSubOutputs to see if the submodule is driving a constant @@ -487,6 +490,24 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated) } + /** Hacky way of propagating source locators across nodes and connections that have just a + * reference on the right-hand side + * + * @todo generalize source locator propagation across Expressions and delete this method + * @todo is the `orElse` the way we want to do propagation here? + */ + private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String]) + (stmt: Statement): Statement = stmt match { + // We check rname because inlining it would cause the original declaration to go away + case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + node.copy(info = InfoExpr.orElse(info1, info0)) + case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + con.copy(info = InfoExpr.orElse(info1, info0)) + case other => other + } + /* Constant propagate a Module * * Two pass process @@ -547,7 +568,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res ref.copy(name = swapMap(rname)) // Only const prop on the rhs case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) + constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case x => x } if (old ne propagated) { @@ -577,7 +598,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } // When propagating a reference, check if we want to keep the name that would be deleted - def propagateRef(lname: String, value: Expression): Unit = { + def propagateRef(lname: String, value: Expression, info: Info): Unit = { value match { case WRef(rname,_,kind,_) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind => assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a @@ -585,19 +606,22 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res swapMap += (lname -> rname, rname -> lname) case _ => } - nodeMap(lname) = value + nodeMap(lname) = InfoExpr.wrap(info, value) } def constPropStmt(s: Statement): Statement = { - val stmtx = s map constPropStmt map constPropExpression(nodeMap, instMap, constSubOutputs) + val s0 = s.map(constPropStmt) // Statement recurse + val s1 = propagateDirectConnectionInfoOnly(nodeMap, dontTouches)(s0) // hacky source locator propagation + val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs)) // propagate sub-Expressions // Record things that should be propagated stmtx match { - case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value) + case DefNode(info, name, value) if !dontTouches.contains(name) => + propagateRef(name, value, info) case reg: DefRegister if reg.reset.tpe == AsyncResetType => asyncResetRegs(reg.name) = reg - case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => + case Connect(info, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe)) - propagateRef(wname, exprx) + propagateRef(wname, exprx, info) // Record constants driving outputs case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) => val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] @@ -633,7 +657,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case lit: Literal => baseCase.resolve(RegCPEntry(UnboundConstant, BoundConstant(lit))) case WRef(regName, _, RegKind, _) => baseCase.resolve(RegCPEntry(BoundConstant(regName), UnboundConstant)) case WRef(nodeName, _, NodeKind, _) if nodeMap.contains(nodeName) => - val cached = nodeRegCPEntries.getOrElseUpdate(nodeName, { regConstant(nodeMap(nodeName), unbound) }) + val (_, expr) = InfoExpr.unwrap(nodeMap(nodeName)) + val cached = nodeRegCPEntries.getOrElseUpdate(nodeName, { regConstant(expr, unbound) }) baseCase.resolve(cached) case Mux(_, tval, fval, _) => regConstant(tval, baseCase).resolve(regConstant(fval, baseCase)) @@ -676,8 +701,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Actually transform some statements stmtx match { // Propagate connections to references - case Connect(info, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouches.contains(rname) => - Connect(info, lhs, nodeMap(rname)) + case Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouches.contains(rname) => + val (info1, value) = InfoExpr.unwrap(nodeMap(rname)) + // Is this the right info combination/propagation function? + // See propagateDirectConnectionInfoOnly + Connect(InfoExpr.orElse(info1, info0), lhs, value) // If an Attach has at least 1 port, any wires are redundant and can be removed case Attach(info, exprs) if exprs.exists(kind(_) == PortKind) => Attach(info, exprs.filterNot(kind(_) == WireKind)) diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala index ea694719..4bda25ce 100644 --- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -7,11 +7,19 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.Utils._ import firrtl.options.Dependency +import firrtl.InfoExpr.orElse import scala.collection.mutable object FlattenRegUpdate { + // Combination function for dealing with inlining of muxes and the handling of Triples of infos + private def combineInfos(muxInfo: Info, tinfo: Info, finfo: Info): Info = { + val (eninfo, tinfoAlt, finfoAlt) = MultiInfo.demux(muxInfo) + // Use MultiInfo constructor to preserve NoInfos + new MultiInfo(List(eninfo, orElse(tinfo, tinfoAlt), orElse(finfo, finfoAlt))) + } + /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */ type Netlist = mutable.HashMap[WrappedExpression, Expression] @@ -26,10 +34,12 @@ object FlattenRegUpdate { val netlist = new Netlist() def onStmt(stmt: Statement): Statement = { stmt.map(onStmt) match { - case Connect(_, lhs, rhs) => - netlist(lhs) = rhs - case DefNode(_, nname, rhs) => - netlist(WRef(nname)) = rhs + case Connect(info, lhs, rhs) => + val expr = if (info == NoInfo) rhs else InfoExpr(info, rhs) + netlist(lhs) = expr + case DefNode(info, nname, rhs) => + val expr = if (info == NoInfo) rhs else InfoExpr(info, rhs) + netlist(WRef(nname)) = expr case _: IsInvalid => throwInternalError("Unexpected IsInvalid, should have been removed by now") case _ => // Do nothing } @@ -64,19 +74,21 @@ object FlattenRegUpdate { val regUpdates = mutable.ArrayBuffer.empty[Connect] val netlist = buildNetlist(mod) - def constructRegUpdate(e: Expression): Expression = { + def constructRegUpdate(e: Expression): (Info, Expression) = { + import InfoExpr.unwrap // Only walk netlist for nodes and wires, NOT registers or other state - val expr = kind(e) match { - case NodeKind | WireKind => netlist.getOrElse(e, e) - case _ => e + val (info, expr) = kind(e) match { + case NodeKind | WireKind => unwrap(netlist.getOrElse(e, e)) + case _ => unwrap(e) } expr match { case mux: Mux if canFlatten(mux) => - val tvalx = constructRegUpdate(mux.tval) - val fvalx = constructRegUpdate(mux.fval) - mux.copy(tval = tvalx, fval = fvalx) + val (tinfo, tvalx) = constructRegUpdate(mux.tval) + val (finfo, fvalx) = constructRegUpdate(mux.fval) + val infox = combineInfos(info, tinfo, finfo) + (infox, mux.copy(tval = tvalx, fval = fvalx)) // Return the original expression to end flattening - case _ => e + case _ => unwrap(e) } } @@ -85,7 +97,8 @@ object FlattenRegUpdate { assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero, "Synchronous reset should have already been made explicit!") val ref = WRef(reg) - val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref))) + val (info, rhs) = constructRegUpdate(netlist.getOrElse(ref, ref)) + val update = Connect(info, ref, rhs) regUpdates += update reg // Remove connections to Registers so we preserve LowFirrtl single-connection semantics diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index 2db93626..6b3a9d07 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -30,7 +30,7 @@ object RemoveReset extends Transform with DependencyAPIMigration { case _ => false } - private case class Reset(cond: Expression, value: Expression) + private case class Reset(cond: Expression, value: Expression, info: Info) /** Return an immutable set of all invalid expressions in a module * @param m a module @@ -58,14 +58,16 @@ object RemoveReset extends Transform with DependencyAPIMigration { reg.copy(reset = Utils.zero, init = WRef(reg)) case reg @ DefRegister(_, rname, _, _, Utils.zero, _) => reg.copy(init = WRef(reg)) // canonicalize - case reg @ DefRegister(_, rname, _, _, reset, init) if reset.tpe != AsyncResetType => + case reg @ DefRegister(info , rname, _, _, reset, init) if reset.tpe != AsyncResetType => // Add register reset to map - resets(rname) = Reset(reset, init) + resets(rname) = Reset(reset, init, info) reg.copy(reset = Utils.zero, init = WRef(reg)) case Connect(info, ref @ WRef(rname, _, RegKind, _), expr) if resets.contains(rname) => val reset = resets(rname) val muxType = Utils.mux_type_and_widths(reset.value, expr) - Connect(info, ref, Mux(reset.cond, reset.value, expr, muxType)) + // Use reg source locator for mux enable and true value since that's where they're defined + val infox = MultiInfo(reset.info, reset.info, info) + Connect(infox, ref, Mux(reset.cond, reset.value, expr, muxType)) case other => other map onStmt } } diff --git a/src/test/scala/firrtlTests/InfoSpec.scala b/src/test/scala/firrtlTests/InfoSpec.scala index 01e0a0ac..a2410f9d 100644 --- a/src/test/scala/firrtlTests/InfoSpec.scala +++ b/src/test/scala/firrtlTests/InfoSpec.scala @@ -23,6 +23,7 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers { val Info1 = FileInfo(StringLit("Source.scala 1:4")) val Info2 = FileInfo(StringLit("Source.scala 2:4")) val Info3 = FileInfo(StringLit("Source.scala 3:4")) + val Info4 = FileInfo(StringLit("Source.scala 4:4")) "Source locators on module ports" should "be propagated to Verilog" in { val result = compileBody(s""" @@ -119,6 +120,21 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers { result should containLine (s"Child c ( //$Info1") } + it should "be propagated across direct node assignments and connections" in { + val result = compile(s""" + |circuit Test : + | module Test : + | input in : UInt<8> + | output out : UInt<8> + | node a = in $Info1 + | node b = a + | out <= b + |""".stripMargin + ) + result should containTree { case Connect(Info1, Reference("out", _,_,_), Reference("in", _,_,_)) => true } + result should containLine (s"assign out = in; //$Info1") + } + "source locators" should "be propagated through ExpandWhens" in { val input = """ |;buildInfoPackage: chisel3, version: 3.1-SNAPSHOT, scalaVersion: 2.11.7, sbtVersion: 0.13.11, builtAtString: 2016-11-26 18:48:38.030, builtAtMillis: 1480186118030 @@ -155,8 +171,12 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers { """.stripMargin val result = (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) - result should containLine ("x <= _GEN_2 @[GCD.scala 17:22 GCD.scala 19:19]") - result should containLine ("y <= _GEN_3 @[GCD.scala 18:22 GCD.scala 19:30]") + result should containLine ("node _GEN_0 = mux(_T_14, _T_16, x) @[GCD.scala 17:18 GCD.scala 17:22 GCD.scala 15:14]") + result should containLine ("node _GEN_2 = mux(io_e, io_a, _GEN_0) @[GCD.scala 19:15 GCD.scala 19:19]") + result should containLine ("x <= _GEN_2") + result should containLine ("node _GEN_1 = mux(_T_18, _T_20, y) @[GCD.scala 18:18 GCD.scala 18:22 GCD.scala 16:14]") + result should containLine ("node _GEN_3 = mux(io_e, io_b, _GEN_1) @[GCD.scala 19:15 GCD.scala 19:30]") + result should containLine ("y <= _GEN_3") } "source locators for append option" should "use multiinfo" in { @@ -173,6 +193,53 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers { circuitState should containTree { case MultiInfo(`expectedInfos`) => true } } + "source locators for basic register updates" should "be propagated to Verilog" in { + val result = compileBody(s""" + |input clock : Clock + |input reset : UInt<1> + |output io : { flip in : UInt<8>, out : UInt<8>} + |reg r : UInt<8>, clock + |r <= io.in $Info1 + |io.out <= r + |""".stripMargin + ) + result should containLine (s"r <= io_in; //$Info1") + } + + "source locators for register reset" should "be propagated to Verilog" in { + val result = compileBody(s""" + |input clock : Clock + |input reset : UInt<1> + |output io : { flip in : UInt<8>, out : UInt<8>} + |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info3 + |r <= io.in $Info1 + |io.out <= r + |""".stripMargin + ) + result should containLine (s"if (reset) begin //$Info3") + result should containLine (s"r <= 8'h0; //$Info3") + result should containLine (s"r <= io_in; //$Info1") + } + + "source locators for complex register updates" should "be propagated to Verilog" in { + val result = compileBody(s""" + |input clock : Clock + |input reset : UInt<1> + |output io : { flip in : UInt<8>, flip a : UInt<1>, out : UInt<8>} + |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info1 + |r <= UInt<2>(2) $Info2 + |when io.a : $Info3 + | r <= io.in $Info4 + |io.out <= r + |""".stripMargin + ) + result should containLine (s"if (reset) begin //$Info1") + result should containLine (s"r <= 8'h0; //$Info1") + result should containLine (s"end else if (io_a) begin //$Info3") + result should containLine (s"r <= io_in; //$Info4") + result should containLine (s"r <= 8'h2; //$Info2") + } + "FileInfo" should "be able to contain a escaped characters" in { def input(info: String): String = s"""circuit m: @[$info] -- cgit v1.2.3