From 430a8afb2cf42e9eef438c7ea38934113de0bbcf Mon Sep 17 00:00:00 2001 From: chick Date: Thu, 24 May 2018 13:34:25 -0700 Subject: Makes ExpandWhens preserve connect Infos * Collects Infos found for symbols * Merges multiple sources for symbol into MultiInfo * Restores these Infos on connect statements. * Add test showing preserved Infos * Changed ++ methods on the Info sub-classes * Ignore NoInfo being added * Fixed way adding was implemented in MultiInfo * Made InfoMap a class which defines the default value function --- src/main/scala/firrtl/ir/IR.scala | 6 ++-- src/main/scala/firrtl/passes/ExpandWhens.scala | 35 ++++++++++++++++++---- src/test/scala/firrtlTests/InfoSpec.scala | 40 ++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 8 deletions(-) diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 53fbb765..3887f17d 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -21,7 +21,8 @@ case object NoInfo extends Info { } case class FileInfo(info: StringLit) extends Info { override def toString: String = " @[" + info.serialize + "]" - def ++(that: Info): Info = MultiInfo(Seq(this, that)) + //scalastyle:off method.name + def ++(that: Info): Info = if (that == NoInfo) this else MultiInfo(Seq(this, that)) } case class MultiInfo(infos: Seq[Info]) extends Info { private def collectStringLits(info: Info): Seq[StringLit] = info match { @@ -34,7 +35,8 @@ case class MultiInfo(infos: Seq[Info]) extends Info { if (parts.nonEmpty) parts.map(_.serialize).mkString(" @[", " ", "]") else "" } - def ++(that: Info): Info = MultiInfo(Seq(this, that)) + //scalastyle:off method.name + def ++(that: Info): Info = if (that == NoInfo) this else MultiInfo(infos :+ that) } object MultiInfo { def apply(infos: Info*) = { diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 519a1e1a..4abae636 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -30,9 +30,9 @@ object ExpandWhens extends Pass { val modulesx = c.modules map { case m: ExtModule => m case m: Module => - val (netlist, simlist, attaches, bodyx) = expandWhens(m) + val (netlist, simlist, attaches, bodyx, sourceInfoMap) = expandWhens(m) val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet - val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++ + val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs, sourceInfoMap) ++ combineAttaches(attaches) ++ simlist) Module(m.info, m.name, m.ports, newBody) } @@ -45,6 +45,15 @@ object ExpandWhens extends Pass { /** Maps a reference to whatever connects to it. Used to resolve last connect semantics */ type Netlist = mutable.LinkedHashMap[WrappedExpression, Expression] + /** Collects Info data serialized names for nodes, aggregating into MultiInfo when necessary */ + class InfoMap extends mutable.HashMap[String, Info] { + override def default(key: String): Info = { + val x = NoInfo + this(key) = x + x + } + } + /** Contains all simulation constructs */ type Simlist = mutable.ArrayBuffer[Statement] @@ -61,13 +70,24 @@ object ExpandWhens extends Pass { * @note Seq[Attach] contains all Attach statements (unsimplified) * @note Statement contains all declarations in the module (including DefNode's) */ - def expandWhens(m: Module): (Netlist, Simlist, Seq[Attach], Statement) = { + def expandWhens(m: Module): (Netlist, Simlist, Seq[Attach], Statement, InfoMap) = { val namespace = Namespace(m) val simlist = new Simlist val nodes = new NodeMap // Seq of attaches in order lazy val attaches = mutable.ArrayBuffer.empty[Attach] + val infoMap: InfoMap = new InfoMap + + /** + * Adds into into map, aggregates info into MultiInfo where necessary + * @param key serialized name of node + * @param info info being recorded + */ + def saveInfo(key: String, info: Info): Unit = { + infoMap(key) = infoMap(key) ++ info + } + /** Removes connections/attaches from the statement * Mutates namespace, simlist, nodes, attaches * Mutates input netlist @@ -99,6 +119,7 @@ object ExpandWhens extends Pass { r // For value assignments, update netlist/attaches and return EmptyStmt case c: Connect => + saveInfo(c.loc.serialize, c.info) netlist(c.loc) = c.expr EmptyStmt case c: IsInvalid => @@ -177,7 +198,7 @@ object ExpandWhens extends Pass { getFemaleRefs(name, tpe, to_gender(dir)) map (ref => we(ref) -> WVoid) }) val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) - (netlist, simlist, attaches, bodyx) + (netlist, simlist, attaches, bodyx, infoMap) } @@ -200,11 +221,13 @@ object ExpandWhens extends Pass { * @todo Preserve Info * @note Remove IsInvalids on attached Analog-typed components */ - private def expandNetlist(netlist: Netlist, attached: Set[WrappedExpression]) = + private def expandNetlist(netlist: Netlist, attached: Set[WrappedExpression], sourceInfoMap: InfoMap) = netlist map { case (k, WInvalid) => // Remove IsInvalids on attached Analog types if (attached.contains(k)) EmptyStmt else IsInvalid(NoInfo, k.e1) - case (k, v) => Connect(NoInfo, k.e1, v) + case (k, v) => + val info = sourceInfoMap(k.e1.serialize) + Connect(info, k.e1, v) } /** Returns new sequence of combined Attaches diff --git a/src/test/scala/firrtlTests/InfoSpec.scala b/src/test/scala/firrtlTests/InfoSpec.scala index 8d49d753..dbc997cd 100644 --- a/src/test/scala/firrtlTests/InfoSpec.scala +++ b/src/test/scala/firrtlTests/InfoSpec.scala @@ -117,4 +117,44 @@ class InfoSpec extends FirrtlFlatSpec { result should containTree { case WDefInstance(Info1, "c", "Child", _) => true } result should containLine (s"Child c ( //$Info1") } + + "source locators" should "be propagated through ExpandWhens" in { + val input = """ + |;buildInfoPackage: chisel3, version: 3.1-SNAPSHOT, scalaVersion: 2.11.7, sbtVersion: 0.13.11, builtAtString: 2016-11-26 18:48:38.030, builtAtMillis: 1480186118030 + |circuit GCD : + | module GCD : + | input clock : Clock + | input reset : UInt<1> + | output io : {flip a : UInt<32>, flip b : UInt<32>, flip e : UInt<1>, z : UInt<32>, v : UInt<1>} + | + | io is invalid + | io is invalid + | reg x : UInt<32>, clock @[GCD.scala 15:14] + | reg y : UInt<32>, clock @[GCD.scala 16:14] + | node _T_14 = gt(x, y) @[GCD.scala 17:11] + | when _T_14 : @[GCD.scala 17:18] + | node _T_15 = sub(x, y) @[GCD.scala 17:27] + | node _T_16 = tail(_T_15, 1) @[GCD.scala 17:27] + | x <= _T_16 @[GCD.scala 17:22] + | skip @[GCD.scala 17:18] + | node _T_18 = eq(_T_14, UInt<1>("h00")) @[GCD.scala 17:18] + | when _T_18 : @[GCD.scala 18:18] + | node _T_19 = sub(y, x) @[GCD.scala 18:27] + | node _T_20 = tail(_T_19, 1) @[GCD.scala 18:27] + | y <= _T_20 @[GCD.scala 18:22] + | skip @[GCD.scala 18:18] + | when io.e : @[GCD.scala 19:15] + | x <= io.a @[GCD.scala 19:19] + | y <= io.b @[GCD.scala 19:30] + | skip @[GCD.scala 19:15] + | io.z <= x @[GCD.scala 20:8] + | node _T_22 = eq(y, UInt<1>("h00")) @[GCD.scala 21:13] + | io.v <= _T_22 @[GCD.scala 21:8] + | + """.stripMargin + + val result = (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) + result should containLine ("x <= _GEN_2 @[GCD.scala 17:22 GCD.scala 19:19]") + result should containLine ("y <= _GEN_3 @[GCD.scala 18:22 GCD.scala 19:30]") + } } -- cgit v1.2.3