aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack Koenig2020-07-16 17:27:52 -0700
committerGitHub2020-07-17 00:27:52 +0000
commitb25cd542192132161f3c162f7e782a9cbb2d09ae (patch)
tree9f30acdc1cbaf112c944169cac812be441a896bd
parentc4cc6bc5b614bd7f5383f8a85c7fc81facdc4b20 (diff)
Propagate source locators to register update always blocks (#1743)
* [WIP] Propagate source locators to Verilog if-else emission * Add and fix tests for reg update info propagation * Add limited source locator propagation in ConstProp Support propagating source locators on connections or nodes where the right-hand side is simply a reference. This case comes up a lot for registers without a synchronous reset. node _T_1 = x @[MyFile.scala 12:10] node _T_2 = _T_1 z <= x Previousy the source locator would be lost, now the result is: z <= x @[MyFile.scala 12:10] * Address review comments Co-authored-by: Schuyler Eldridge <schuyler.eldridge@ibm.com> Co-authored-by: Schuyler Eldridge <schuyler.eldridge@ibm.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
-rw-r--r--src/main/scala/firrtl/Emitter.scala48
-rw-r--r--src/main/scala/firrtl/WIR.scala38
-rw-r--r--src/main/scala/firrtl/ir/IR.scala12
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala108
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala56
-rw-r--r--src/main/scala/firrtl/transforms/FlattenRegUpdate.scala39
-rw-r--r--src/main/scala/firrtl/transforms/RemoveReset.scala10
-rw-r--r--src/test/scala/firrtlTests/InfoSpec.scala71
8 files changed, 263 insertions, 119 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index b5474769..f9787a48 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -550,17 +550,19 @@ class VerilogEmitter extends SeqTransform with Emitter {
this(Seq(), Map.empty, m, moduleMap, "", new EmissionOptions(Seq.empty))(writer)
}
- val netlist = mutable.LinkedHashMap[WrappedExpression, Expression]()
+ val netlist = mutable.LinkedHashMap[WrappedExpression, InfoExpr]()
val namespace = Namespace(m)
namespace.newName("_RAND") // Start rand names at _RAND_0
def build_netlist(s: Statement): Unit = {
s.foreach(build_netlist)
s match {
- case sx: Connect => netlist(sx.loc) = sx.expr
+ case sx: Connect => netlist(sx.loc) = InfoExpr(sx.info, sx.expr)
case sx: IsInvalid => error("Should have removed these!")
+ // TODO Since only register update and memories use the netlist anymore, I think nodes are
+ // unnecessary
case sx: DefNode =>
val e = WRef(sx.name, sx.value.tpe, NodeKind, SourceFlow)
- netlist(e) = sx.value
+ netlist(e) = InfoExpr(sx.info, sx.value)
case _ =>
}
}
@@ -663,6 +665,9 @@ class VerilogEmitter extends SeqTransform with Emitter {
def declare(b: String, n: String, t: Type, info: Info): Unit =
declare(b, n, t, info, None)
+ def assign(e: Expression, infoExpr: InfoExpr): Unit =
+ assign(e, infoExpr.expr, infoExpr.info)
+
def assign(e: Expression, value: Expression, info: Info): Unit = {
assigns += Seq("assign ", e, " = ", value, ";", info)
}
@@ -684,19 +689,20 @@ class VerilogEmitter extends SeqTransform with Emitter {
}
def regUpdate(r: Expression, clk: Expression, reset: Expression, init: Expression) = {
- def addUpdate(expr: Expression, tabs: String): Seq[Seq[Any]] = expr match {
+ def addUpdate(info: Info, expr: Expression, tabs: String): Seq[Seq[Any]] = expr match {
case m: Mux =>
if (m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly")
if (m.tpe == AsyncResetType) throw EmitterException("Cannot emit async reset muxes directly")
- lazy val _if = Seq(tabs, "if (", m.cond, ") begin")
+ val (eninfo, tinfo, finfo) = MultiInfo.demux(info)
+ lazy val _if = Seq(tabs, "if (", m.cond, ") begin", eninfo)
lazy val _else = Seq(tabs, "end else begin")
- lazy val _ifNot = Seq(tabs, "if (!(", m.cond, ")) begin")
+ lazy val _ifNot = Seq(tabs, "if (!(", m.cond, ")) begin", eninfo)
lazy val _end = Seq(tabs, "end")
- lazy val _true = addUpdate(m.tval, tabs + tab)
- lazy val _false = addUpdate(m.fval, tabs + tab)
+ lazy val _true = addUpdate(tinfo, m.tval, tabs + tab)
+ lazy val _false = addUpdate(finfo, m.fval, tabs + tab)
lazy val _elseIfFalse = {
- val _falsex = addUpdate(m.fval, tabs) // _false, but without an additional tab
+ val _falsex = addUpdate(finfo, m.fval, tabs) // _false, but without an additional tab
Seq(tabs, "end else ", _falsex.head.tail) +: _falsex.tail
}
@@ -719,15 +725,17 @@ class VerilogEmitter extends SeqTransform with Emitter {
case (_, _: Mux) => (_if +: _true) ++ _elseIfFalse
case _ => (_if +: _true :+ _else) ++ _false :+ _end
}
- case e => Seq(Seq(tabs, r, " <= ", e, ";"))
+ case e => Seq(Seq(tabs, r, " <= ", e, ";", info))
}
if (weq(init, r)) { // Synchronous Reset
- noResetAlwaysBlocks.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(netlist(r), "")
+ val InfoExpr(info, e) = netlist(r)
+ noResetAlwaysBlocks.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(info, e, "")
} else { // Asynchronous Reset
assert(reset.tpe == AsyncResetType, "Error! Synchronous reset should have been removed!")
val tv = init
- val fv = netlist(r)
- asyncResetAlwaysBlocks += ((clk, reset, addUpdate(Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), "")))
+ val InfoExpr(finfo, fv) = netlist(r)
+ // TODO add register info argument and build a MultiInfo to pass
+ asyncResetAlwaysBlocks += ((clk, reset, addUpdate(NoInfo, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), "")))
}
}
@@ -996,7 +1004,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
// declare("wire", LowerTypes.loweredName(en), en.tpe)
//; Read port
- assign(addr, netlist(addr), NoInfo) // Info should come from addr connection
+ assign(addr, netlist(addr))
// assign(en, netlist(en)) //;Connects value to m.r.en
val mem = WRef(sx.name, memType(sx), MemKind, UnknownFlow)
val memPort = WSubAccess(mem, addr, sx.dataType, UnknownFlow)
@@ -1015,7 +1023,8 @@ class VerilogEmitter extends SeqTransform with Emitter {
val mask = memPortField(sx, w, "mask")
val en = memPortField(sx, w, "en")
//Ports should share an always@posedge, so can't have intermediary wire
- val clk = netlist(memPortField(sx, w, "clk"))
+ // TODO should we use the info here for anything?
+ val InfoExpr(_, clk) = netlist(memPortField(sx, w, "clk"))
declare("wire", LowerTypes.loweredName(data), data.tpe, sx.info)
declare("wire", LowerTypes.loweredName(addr), addr.tpe, sx.info)
@@ -1023,11 +1032,10 @@ class VerilogEmitter extends SeqTransform with Emitter {
declare("wire", LowerTypes.loweredName(en), en.tpe, sx.info)
// Write port
- // Info should come from netlist
- assign(data, netlist(data), NoInfo)
- assign(addr, netlist(addr), NoInfo)
- assign(mask, netlist(mask), NoInfo)
- assign(en, netlist(en), NoInfo)
+ assign(data, netlist(data))
+ assign(addr, netlist(addr))
+ assign(mask, netlist(mask))
+ assign(en, netlist(en))
val mem = WRef(sx.name, memType(sx), MemKind, UnknownFlow)
val memPort = WSubAccess(mem, addr, sx.dataType, UnknownFlow)
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index 2f1daadd..cda22d27 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -169,6 +169,44 @@ case object Dshlw extends PrimOp {
}
}
+/** Internal class used for propagating [[Info]] across [[Expression]]s
+ *
+ * In particular, this is useful in "Netlist" datastructures mapping node or other [[Statement]]s
+ * to [[Expression]]s
+ *
+ * @note This is not allowed to leak from any transform
+ */
+private[firrtl] case class InfoExpr(info: Info, expr: Expression) extends Expression {
+ def foreachExpr(f: Expression => Unit): Unit = f(expr)
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachWidth(f: Width => Unit): Unit = ()
+ def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(this.expr))
+ def mapType(f: Type => Type): Expression = this
+ def mapWidth(f: Width => Width): Expression = this
+ def tpe: Type = expr.tpe
+
+ // Members declared in firrtl.ir.FirrtlNode
+ def serialize: String = s"(${expr.serialize}: ${info.serialize})"
+}
+
+private[firrtl] object InfoExpr {
+ def wrap(info: Info, expr: Expression): Expression =
+ if (info == NoInfo) expr else InfoExpr(info, expr)
+
+ def unwrap(expr: Expression): (Info, Expression) = expr match {
+ case InfoExpr(i, e) => (i, e)
+ case other => (NoInfo, other)
+ }
+
+ def orElse(info: Info, alt: => Info): Info = if (info == NoInfo) alt else info
+
+ // TODO this the right name?
+ def map(expr: Expression)(f: Expression => Expression): Expression = expr match {
+ case ie: InfoExpr => ie.mapExpr(f)
+ case e => f(e)
+ }
+}
+
object WrappedExpression {
def apply(e: Expression) = new WrappedExpression(e)
def we(e: Expression) = new WrappedExpression(e)
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index 734b475d..275cbe51 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -95,10 +95,18 @@ object MultiInfo {
val infosx = infos.filterNot(_ == NoInfo)
infosx.size match {
case 0 => NoInfo
- case 1 => infosx.head
- case _ => new MultiInfo(infosx)
+ case 1 => infos.head
+ case _ => new MultiInfo(infos)
}
}
+
+ // Internal utility for unpacking implicit MultiInfo structure for muxes
+ // TODO should this be made into an API?
+ private[firrtl] def demux(info: Info): (Info, Info, Info) = info match {
+ case MultiInfo(infos) if infos.lengthCompare(3) == 0 => (infos(0), infos(1), infos(2))
+ case other => (other, NoInfo, NoInfo) // if not exactly 3, we don't know what to do
+ }
+
private def flattenInfo(infos: Seq[Info]): Seq[FileInfo] = infos.flatMap {
case NoInfo => Seq()
case f : FileInfo => Seq(f)
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index e7eebb57..ab4c9bfa 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -9,6 +9,7 @@ import firrtl.Mappers._
import firrtl.PrimOps._
import firrtl.WrappedExpression._
import firrtl.options.Dependency
+import firrtl.InfoExpr.unwrap
import annotation.tailrec
import collection.mutable
@@ -42,12 +43,7 @@ object ExpandWhens extends Pass {
def run(c: Circuit): Circuit = {
val modulesx = c.modules map {
case m: ExtModule => m
- case m: Module =>
- val (netlist, simlist, attaches, bodyx, sourceInfoMap) = expandWhens(m)
- val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet
- val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs, sourceInfoMap) ++
- combineAttaches(attaches) ++ simlist)
- Module(m.info, m.name, m.ports, newBody)
+ case m: Module => onModule(m)
}
Circuit(c.info, modulesx, c.main)
}
@@ -61,15 +57,6 @@ object ExpandWhens extends Pass {
/** Maps a reference to whatever connects to it. Used to resolve last connect semantics */
type Netlist = mutable.LinkedHashMap[WrappedExpression, Expression]
- /** Collects Info data serialized names for nodes, aggregating into MultiInfo when necessary */
- class InfoMap extends mutable.HashMap[String, Info] {
- override def default(key: String): Info = {
- val x = NoInfo
- this(key) = x
- x
- }
- }
-
/** Contains all simulation constructs */
type Simlist = mutable.ArrayBuffer[Statement]
@@ -78,37 +65,29 @@ object ExpandWhens extends Pass {
*/
type Defaults = Seq[mutable.Map[WrappedExpression, Expression]]
-
- /** Expands a module's when statements
- * @param m Module to expand
- * @note Netlist maps a reference to whatever connects to it
- * @note Simlist contains all simulation constructs in m
- * @note Seq[Attach] contains all Attach statements (unsimplified)
- * @note Statement contains all declarations in the module (including DefNode's)
- */
- def expandWhens(m: Module): (Netlist, Simlist, Seq[Attach], Statement, InfoMap) = {
+ /** Expands a module's when statements */
+ private def onModule(m: Module): Module = {
val namespace = Namespace(m)
val simlist = new Simlist
// Memoizes if an expression contains any WVoids inserted in this pass
val memoizedVoid = new mutable.HashSet[WrappedExpression] += WVoid
+ // Does an expression contain WVoid inserted in this pass?
+ def containsVoid(e: Expression): Boolean = e match {
+ case WVoid => true
+ case ValidIf(_, value, _) => memoizedVoid(value)
+ case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv)
+ case _ => false
+ }
+
+
// Memoizes the node that holds a particular expression, if any
val nodes = new NodeLookup
// Seq of attaches in order
lazy val attaches = mutable.ArrayBuffer.empty[Attach]
- val infoMap: InfoMap = new InfoMap
-
- /* Adds into into map, aggregates info into MultiInfo where necessary
- * @param key serialized name of node
- * @param info info being recorded
- */
- def saveInfo(key: String, info: Info): Unit = {
- infoMap(key) = infoMap(key) ++ info
- }
-
/* Removes connections/attaches from the statement
* Mutates namespace, simlist, nodes, attaches
* Mutates input netlist
@@ -133,15 +112,13 @@ object ExpandWhens extends Pass {
case w: WDefInstance =>
netlist ++= (getSinkRefs(w.name, w.tpe, SourceFlow).map(ref => we(ref) -> WVoid))
w
- // Update netlist with self reference for each sink reference
- // Return self, unchanged
case r: DefRegister =>
- netlist ++= (getSinkRefs(r.name, r.tpe, DuplexFlow) map (ref => we(ref) -> ref))
+ // Update netlist with self reference for each sink reference
+ netlist ++= getSinkRefs(r.name, r.tpe, DuplexFlow).map(ref => we(ref) -> InfoExpr(r.info, ref))
r
// For value assignments, update netlist/attaches and return EmptyStmt
case c: Connect =>
- saveInfo(c.loc.serialize, c.info)
- netlist(c.loc) = c.expr
+ netlist(c.loc) = InfoExpr(c.info, c.expr)
EmptyStmt
case c: IsInvalid =>
netlist(c.expr) = WInvalid
@@ -179,27 +156,20 @@ object ExpandWhens extends Pass {
case Some(v) => Some(v)
case None => getDefault(lvalue, defaults)
}
- val res = default match {
+ // info0 and info1 correspond to Mux infos, use info0 only if ValidIf
+ val (res, info0, info1) = default match {
case Some(defaultValue) =>
- val trueValue = conseqNetlist getOrElse (lvalue, defaultValue)
- val falseValue = altNetlist getOrElse (lvalue, defaultValue)
+ val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue))
+ val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue))
(trueValue, falseValue) match {
- case (WInvalid, WInvalid) => WInvalid
- case (WInvalid, fv) => ValidIf(NOT(sx.pred), fv, fv.tpe)
- case (tv, WInvalid) => ValidIf(sx.pred, tv, tv.tpe)
- case (tv, fv) => Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)) //Muxing clocks will be checked during type checking
+ case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo)
+ case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo)
+ case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo)
+ case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo)
}
case None =>
// Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
- conseqNetlist getOrElse (lvalue, altNetlist(lvalue))
- }
-
- // Does an expression contain WVoid inserted in this pass?
- def containsVoid(e: Expression): Boolean = e match {
- case WVoid => true
- case ValidIf(_, value, _) => memoizedVoid(value)
- case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv)
- case _ => false
+ (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo)
}
res match {
@@ -217,7 +187,9 @@ object ExpandWhens extends Pass {
val name = namespace.newTemp
nodes(res) = name
netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
- DefNode(sx.info, name, res)
+ // Use MultiInfo constructor to preserve NoInfos
+ val info = new MultiInfo(List(sx.info, info0, info1))
+ DefNode(info, name, res)
}
case _ =>
netlist(lvalue) = res
@@ -233,8 +205,13 @@ object ExpandWhens extends Pass {
netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) =>
getSinkRefs(name, tpe, to_flow(dir)) map (ref => we(ref) -> WVoid)
})
+ // Do traversal and construct mutable datastructures
val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body)
- (netlist, simlist, attaches, bodyx, infoMap)
+
+ val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet
+ val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++
+ combineAttaches(attaches) ++ simlist)
+ Module(m.info, m.name, m.ports, newBody)
}
@@ -253,17 +230,20 @@ object ExpandWhens extends Pass {
}
/** Returns all connections/invalidations in the circuit
- * @todo Preserve Info
* @note Remove IsInvalids on attached Analog-typed components
*/
- private def expandNetlist(netlist: Netlist, attached: Set[WrappedExpression], sourceInfoMap: InfoMap) =
- netlist map {
- case (k, WInvalid) => // Remove IsInvalids on attached Analog types
- if (attached.contains(k)) EmptyStmt else IsInvalid(NoInfo, k.e1)
+ private def expandNetlist(netlist: Netlist, attached: Set[WrappedExpression]) = {
+ // Remove IsInvalids on attached Analog types
+ def handleInvalid(k: WrappedExpression, info: Info): Statement =
+ if (attached.contains(k)) EmptyStmt else IsInvalid(info, k.e1)
+ netlist.map {
+ case (k, WInvalid) => handleInvalid(k, NoInfo)
+ case (k, InfoExpr(info, WInvalid)) => handleInvalid(k, info)
case (k, v) =>
- val info = sourceInfoMap(k.e1.serialize)
- Connect(info, k.e1, v)
+ val (info, expr) = unwrap(v)
+ Connect(info, k.e1, expr)
}
+ }
/** Returns new sequence of combined Attaches
* @todo Preserve Info
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 29410c7f..000adc15 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -458,9 +458,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case _ => constPropMuxCond(m)
}
- private def constPropNodeRef(r: WRef, e: Expression) = e match {
- case _: UIntLiteral | _: SIntLiteral | _: WRef => e
- case _ => r
+ private def constPropNodeRef(r: WRef, e: Expression): Expression = {
+ def doit(ex: Expression) = ex match {
+ case _: UIntLiteral | _: SIntLiteral | _: WRef => ex
+ case _ => r
+ }
+ InfoExpr.map(e)(doit)
}
// Is "a" a "better name" than "b"?
@@ -475,7 +478,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case p: DoPrim => constPropPrim(p)
case m: Mux => constPropMux(m)
case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) =>
- constPropNodeRef(ref, nodeMap(rname))
+ constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2)
case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) =>
val module = instMap(inst.Instance)
// Check constSubOutputs to see if the submodule is driving a constant
@@ -487,6 +490,24 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated)
}
+ /** Hacky way of propagating source locators across nodes and connections that have just a
+ * reference on the right-hand side
+ *
+ * @todo generalize source locator propagation across Expressions and delete this method
+ * @todo is the `orElse` the way we want to do propagation here?
+ */
+ private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String])
+ (stmt: Statement): Statement = stmt match {
+ // We check rname because inlining it would cause the original declaration to go away
+ case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
+ val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
+ node.copy(info = InfoExpr.orElse(info1, info0))
+ case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
+ val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
+ con.copy(info = InfoExpr.orElse(info1, info0))
+ case other => other
+ }
+
/* Constant propagate a Module
*
* Two pass process
@@ -547,7 +568,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
ref.copy(name = swapMap(rname))
// Only const prop on the rhs
case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) =>
- constPropNodeRef(ref, nodeMap(rname))
+ constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2)
case x => x
}
if (old ne propagated) {
@@ -577,7 +598,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
}
// When propagating a reference, check if we want to keep the name that would be deleted
- def propagateRef(lname: String, value: Expression): Unit = {
+ def propagateRef(lname: String, value: Expression, info: Info): Unit = {
value match {
case WRef(rname,_,kind,_) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind =>
assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a
@@ -585,19 +606,22 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
swapMap += (lname -> rname, rname -> lname)
case _ =>
}
- nodeMap(lname) = value
+ nodeMap(lname) = InfoExpr.wrap(info, value)
}
def constPropStmt(s: Statement): Statement = {
- val stmtx = s map constPropStmt map constPropExpression(nodeMap, instMap, constSubOutputs)
+ val s0 = s.map(constPropStmt) // Statement recurse
+ val s1 = propagateDirectConnectionInfoOnly(nodeMap, dontTouches)(s0) // hacky source locator propagation
+ val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs)) // propagate sub-Expressions
// Record things that should be propagated
stmtx match {
- case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value)
+ case DefNode(info, name, value) if !dontTouches.contains(name) =>
+ propagateRef(name, value, info)
case reg: DefRegister if reg.reset.tpe == AsyncResetType =>
asyncResetRegs(reg.name) = reg
- case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) =>
+ case Connect(info, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) =>
val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe))
- propagateRef(wname, exprx)
+ propagateRef(wname, exprx, info)
// Record constants driving outputs
case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) =>
val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
@@ -633,7 +657,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case lit: Literal => baseCase.resolve(RegCPEntry(UnboundConstant, BoundConstant(lit)))
case WRef(regName, _, RegKind, _) => baseCase.resolve(RegCPEntry(BoundConstant(regName), UnboundConstant))
case WRef(nodeName, _, NodeKind, _) if nodeMap.contains(nodeName) =>
- val cached = nodeRegCPEntries.getOrElseUpdate(nodeName, { regConstant(nodeMap(nodeName), unbound) })
+ val (_, expr) = InfoExpr.unwrap(nodeMap(nodeName))
+ val cached = nodeRegCPEntries.getOrElseUpdate(nodeName, { regConstant(expr, unbound) })
baseCase.resolve(cached)
case Mux(_, tval, fval, _) =>
regConstant(tval, baseCase).resolve(regConstant(fval, baseCase))
@@ -676,8 +701,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Actually transform some statements
stmtx match {
// Propagate connections to references
- case Connect(info, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouches.contains(rname) =>
- Connect(info, lhs, nodeMap(rname))
+ case Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouches.contains(rname) =>
+ val (info1, value) = InfoExpr.unwrap(nodeMap(rname))
+ // Is this the right info combination/propagation function?
+ // See propagateDirectConnectionInfoOnly
+ Connect(InfoExpr.orElse(info1, info0), lhs, value)
// If an Attach has at least 1 port, any wires are redundant and can be removed
case Attach(info, exprs) if exprs.exists(kind(_) == PortKind) =>
Attach(info, exprs.filterNot(kind(_) == WireKind))
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
index ea694719..4bda25ce 100644
--- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
+++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
@@ -7,11 +7,19 @@ import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.options.Dependency
+import firrtl.InfoExpr.orElse
import scala.collection.mutable
object FlattenRegUpdate {
+ // Combination function for dealing with inlining of muxes and the handling of Triples of infos
+ private def combineInfos(muxInfo: Info, tinfo: Info, finfo: Info): Info = {
+ val (eninfo, tinfoAlt, finfoAlt) = MultiInfo.demux(muxInfo)
+ // Use MultiInfo constructor to preserve NoInfos
+ new MultiInfo(List(eninfo, orElse(tinfo, tinfoAlt), orElse(finfo, finfoAlt)))
+ }
+
/** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */
type Netlist = mutable.HashMap[WrappedExpression, Expression]
@@ -26,10 +34,12 @@ object FlattenRegUpdate {
val netlist = new Netlist()
def onStmt(stmt: Statement): Statement = {
stmt.map(onStmt) match {
- case Connect(_, lhs, rhs) =>
- netlist(lhs) = rhs
- case DefNode(_, nname, rhs) =>
- netlist(WRef(nname)) = rhs
+ case Connect(info, lhs, rhs) =>
+ val expr = if (info == NoInfo) rhs else InfoExpr(info, rhs)
+ netlist(lhs) = expr
+ case DefNode(info, nname, rhs) =>
+ val expr = if (info == NoInfo) rhs else InfoExpr(info, rhs)
+ netlist(WRef(nname)) = expr
case _: IsInvalid => throwInternalError("Unexpected IsInvalid, should have been removed by now")
case _ => // Do nothing
}
@@ -64,19 +74,21 @@ object FlattenRegUpdate {
val regUpdates = mutable.ArrayBuffer.empty[Connect]
val netlist = buildNetlist(mod)
- def constructRegUpdate(e: Expression): Expression = {
+ def constructRegUpdate(e: Expression): (Info, Expression) = {
+ import InfoExpr.unwrap
// Only walk netlist for nodes and wires, NOT registers or other state
- val expr = kind(e) match {
- case NodeKind | WireKind => netlist.getOrElse(e, e)
- case _ => e
+ val (info, expr) = kind(e) match {
+ case NodeKind | WireKind => unwrap(netlist.getOrElse(e, e))
+ case _ => unwrap(e)
}
expr match {
case mux: Mux if canFlatten(mux) =>
- val tvalx = constructRegUpdate(mux.tval)
- val fvalx = constructRegUpdate(mux.fval)
- mux.copy(tval = tvalx, fval = fvalx)
+ val (tinfo, tvalx) = constructRegUpdate(mux.tval)
+ val (finfo, fvalx) = constructRegUpdate(mux.fval)
+ val infox = combineInfos(info, tinfo, finfo)
+ (infox, mux.copy(tval = tvalx, fval = fvalx))
// Return the original expression to end flattening
- case _ => e
+ case _ => unwrap(e)
}
}
@@ -85,7 +97,8 @@ object FlattenRegUpdate {
assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero,
"Synchronous reset should have already been made explicit!")
val ref = WRef(reg)
- val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref)))
+ val (info, rhs) = constructRegUpdate(netlist.getOrElse(ref, ref))
+ val update = Connect(info, ref, rhs)
regUpdates += update
reg
// Remove connections to Registers so we preserve LowFirrtl single-connection semantics
diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala
index 2db93626..6b3a9d07 100644
--- a/src/main/scala/firrtl/transforms/RemoveReset.scala
+++ b/src/main/scala/firrtl/transforms/RemoveReset.scala
@@ -30,7 +30,7 @@ object RemoveReset extends Transform with DependencyAPIMigration {
case _ => false
}
- private case class Reset(cond: Expression, value: Expression)
+ private case class Reset(cond: Expression, value: Expression, info: Info)
/** Return an immutable set of all invalid expressions in a module
* @param m a module
@@ -58,14 +58,16 @@ object RemoveReset extends Transform with DependencyAPIMigration {
reg.copy(reset = Utils.zero, init = WRef(reg))
case reg @ DefRegister(_, rname, _, _, Utils.zero, _) =>
reg.copy(init = WRef(reg)) // canonicalize
- case reg @ DefRegister(_, rname, _, _, reset, init) if reset.tpe != AsyncResetType =>
+ case reg @ DefRegister(info , rname, _, _, reset, init) if reset.tpe != AsyncResetType =>
// Add register reset to map
- resets(rname) = Reset(reset, init)
+ resets(rname) = Reset(reset, init, info)
reg.copy(reset = Utils.zero, init = WRef(reg))
case Connect(info, ref @ WRef(rname, _, RegKind, _), expr) if resets.contains(rname) =>
val reset = resets(rname)
val muxType = Utils.mux_type_and_widths(reset.value, expr)
- Connect(info, ref, Mux(reset.cond, reset.value, expr, muxType))
+ // Use reg source locator for mux enable and true value since that's where they're defined
+ val infox = MultiInfo(reset.info, reset.info, info)
+ Connect(infox, ref, Mux(reset.cond, reset.value, expr, muxType))
case other => other map onStmt
}
}
diff --git a/src/test/scala/firrtlTests/InfoSpec.scala b/src/test/scala/firrtlTests/InfoSpec.scala
index 01e0a0ac..a2410f9d 100644
--- a/src/test/scala/firrtlTests/InfoSpec.scala
+++ b/src/test/scala/firrtlTests/InfoSpec.scala
@@ -23,6 +23,7 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers {
val Info1 = FileInfo(StringLit("Source.scala 1:4"))
val Info2 = FileInfo(StringLit("Source.scala 2:4"))
val Info3 = FileInfo(StringLit("Source.scala 3:4"))
+ val Info4 = FileInfo(StringLit("Source.scala 4:4"))
"Source locators on module ports" should "be propagated to Verilog" in {
val result = compileBody(s"""
@@ -119,6 +120,21 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers {
result should containLine (s"Child c ( //$Info1")
}
+ it should "be propagated across direct node assignments and connections" in {
+ val result = compile(s"""
+ |circuit Test :
+ | module Test :
+ | input in : UInt<8>
+ | output out : UInt<8>
+ | node a = in $Info1
+ | node b = a
+ | out <= b
+ |""".stripMargin
+ )
+ result should containTree { case Connect(Info1, Reference("out", _,_,_), Reference("in", _,_,_)) => true }
+ result should containLine (s"assign out = in; //$Info1")
+ }
+
"source locators" should "be propagated through ExpandWhens" in {
val input = """
|;buildInfoPackage: chisel3, version: 3.1-SNAPSHOT, scalaVersion: 2.11.7, sbtVersion: 0.13.11, builtAtString: 2016-11-26 18:48:38.030, builtAtMillis: 1480186118030
@@ -155,8 +171,12 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers {
""".stripMargin
val result = (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty)
- result should containLine ("x <= _GEN_2 @[GCD.scala 17:22 GCD.scala 19:19]")
- result should containLine ("y <= _GEN_3 @[GCD.scala 18:22 GCD.scala 19:30]")
+ result should containLine ("node _GEN_0 = mux(_T_14, _T_16, x) @[GCD.scala 17:18 GCD.scala 17:22 GCD.scala 15:14]")
+ result should containLine ("node _GEN_2 = mux(io_e, io_a, _GEN_0) @[GCD.scala 19:15 GCD.scala 19:19]")
+ result should containLine ("x <= _GEN_2")
+ result should containLine ("node _GEN_1 = mux(_T_18, _T_20, y) @[GCD.scala 18:18 GCD.scala 18:22 GCD.scala 16:14]")
+ result should containLine ("node _GEN_3 = mux(io_e, io_b, _GEN_1) @[GCD.scala 19:15 GCD.scala 19:30]")
+ result should containLine ("y <= _GEN_3")
}
"source locators for append option" should "use multiinfo" in {
@@ -173,6 +193,53 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers {
circuitState should containTree { case MultiInfo(`expectedInfos`) => true }
}
+ "source locators for basic register updates" should "be propagated to Verilog" in {
+ val result = compileBody(s"""
+ |input clock : Clock
+ |input reset : UInt<1>
+ |output io : { flip in : UInt<8>, out : UInt<8>}
+ |reg r : UInt<8>, clock
+ |r <= io.in $Info1
+ |io.out <= r
+ |""".stripMargin
+ )
+ result should containLine (s"r <= io_in; //$Info1")
+ }
+
+ "source locators for register reset" should "be propagated to Verilog" in {
+ val result = compileBody(s"""
+ |input clock : Clock
+ |input reset : UInt<1>
+ |output io : { flip in : UInt<8>, out : UInt<8>}
+ |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info3
+ |r <= io.in $Info1
+ |io.out <= r
+ |""".stripMargin
+ )
+ result should containLine (s"if (reset) begin //$Info3")
+ result should containLine (s"r <= 8'h0; //$Info3")
+ result should containLine (s"r <= io_in; //$Info1")
+ }
+
+ "source locators for complex register updates" should "be propagated to Verilog" in {
+ val result = compileBody(s"""
+ |input clock : Clock
+ |input reset : UInt<1>
+ |output io : { flip in : UInt<8>, flip a : UInt<1>, out : UInt<8>}
+ |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info1
+ |r <= UInt<2>(2) $Info2
+ |when io.a : $Info3
+ | r <= io.in $Info4
+ |io.out <= r
+ |""".stripMargin
+ )
+ result should containLine (s"if (reset) begin //$Info1")
+ result should containLine (s"r <= 8'h0; //$Info1")
+ result should containLine (s"end else if (io_a) begin //$Info3")
+ result should containLine (s"r <= io_in; //$Info4")
+ result should containLine (s"r <= 8'h2; //$Info2")
+ }
+
"FileInfo" should "be able to contain a escaped characters" in {
def input(info: String): String =
s"""circuit m: @[$info]