aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlbert Magyar2019-10-07 11:56:30 -0700
committermergify[bot]2019-10-07 18:56:30 +0000
commit357eba4c2b1549de70843899b4dae7d657757d50 (patch)
treefcfb740f9dfda8e4e7bdd24984ae027e871f6e32 /src
parent621c5689ff9b441465a9e6a1f4d92af739603293 (diff)
Absorb some instance analysis into InstanceGraph, use safer boxed Strings (#1186)
* Replace instance analysis code with InstanceGraph API calls * Add convenience implicits for using TargetTokens as safe boxed strings
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/analyses/InstanceGraph.scala31
-rw-r--r--src/main/scala/firrtl/annotations/TargetToken.scala30
-rw-r--r--src/main/scala/firrtl/passes/Inline.scala8
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala59
4 files changed, 87 insertions, 41 deletions
diff --git a/src/main/scala/firrtl/analyses/InstanceGraph.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala
index 22f40359..59eae09b 100644
--- a/src/main/scala/firrtl/analyses/InstanceGraph.scala
+++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala
@@ -8,7 +8,7 @@ import firrtl.ir._
import firrtl.graph._
import firrtl.Utils._
import firrtl.traversals.Foreachers._
-import firrtl.annotations.TargetToken.{Instance, OfModule}
+import firrtl.annotations.TargetToken._
/** A class representing the instance hierarchy of a working IR Circuit
@@ -62,6 +62,19 @@ class InstanceGraph(c: Circuit) {
*/
lazy val fullHierarchy: mutable.LinkedHashMap[WDefInstance,Seq[Seq[WDefInstance]]] = graph.pathsInDAG(trueTopInstance)
+ /** A count of the *static* number of instances of each module. For
+ * any module other than the top module, this is equivalent to the
+ * number of inst statements in the circuit instantiating each
+ * module, irrespective of the number of times (if any) the
+ * enclosing module appears in the hierarchy. Note that top module
+ * of the circuit has an associated count of 1, even though it is
+ * never directly instantiated.
+ */
+ lazy val staticInstanceCount: Map[OfModule, Int] = {
+ val instModules = childInstances.flatMap(_._2.view.map(_.OfModule).toSeq)
+ instModules.foldLeft(Map(c.main.OfModule -> 1)) { case (counts, mod) => counts.updated(mod, counts.getOrElse(mod, 0) + 1) }
+ }
+
/** Finds the absolute paths (each represented by a Seq of instances
* representing the chain of hierarchy) of all instances of a
* particular module.
@@ -107,8 +120,22 @@ class InstanceGraph(c: Circuit) {
* instance/module [[firrtl.annotations.TargetToken]]s
*/
def getChildrenInstanceOfModule: mutable.LinkedHashMap[String, mutable.LinkedHashSet[(Instance, OfModule)]] =
- childInstances.map(kv => kv._1 -> kv._2.map(i => (Instance(i.name), OfModule(i.module))))
+ childInstances.map(kv => kv._1 -> kv._2.map(_.toTokens))
+
+ // Transforms a TraversableOnce input into an order-preserving map
+ // Iterates only once, no intermediate collections
+ // Can possibly be replaced using LinkedHashMap.from(..) or better immutable map in Scala 2.13
+ private def asOrderedMap[K1, K2, V](it: TraversableOnce[K1], f: (K1) => (K2, V)): collection.Map[K2, V] = {
+ val lhmap = new mutable.LinkedHashMap[K2, V]
+ it.foreach { lhmap += f(_) }
+ lhmap
+ }
+ /** Given a circuit, returns a map from module name to a map
+ * in turn mapping instances names to corresponding module names
+ */
+ def getChildrenInstanceMap: collection.Map[OfModule, collection.Map[Instance, OfModule]] =
+ childInstances.map(kv => kv._1.OfModule -> asOrderedMap(kv._2, (i: WDefInstance) => i.toTokens))
}
diff --git a/src/main/scala/firrtl/annotations/TargetToken.scala b/src/main/scala/firrtl/annotations/TargetToken.scala
index 587f30eb..70b64271 100644
--- a/src/main/scala/firrtl/annotations/TargetToken.scala
+++ b/src/main/scala/firrtl/annotations/TargetToken.scala
@@ -2,6 +2,9 @@
package firrtl.annotations
+import firrtl._
+import ir.{DefModule, DefInstance}
+
/** Building block to represent a [[Target]] of a FIRRTL component */
sealed trait TargetToken {
def keyword: String
@@ -32,6 +35,33 @@ case object TargetToken {
case object Init extends TargetToken { override def keyword: String = "init"; val value = "" }
case object Reset extends TargetToken { override def keyword: String = "reset"; val value = "" }
+ implicit class fromStringToTargetToken(s: String) {
+ def Instance: Instance = new TargetToken.Instance(s)
+ def OfModule: OfModule = new TargetToken.OfModule(s)
+ def Ref: Ref = new TargetToken.Ref(s)
+ def Field: Field = new TargetToken.Field(s)
+ }
+
+ implicit class fromIntToTargetToken(i: Int) {
+ def Index: Index = new TargetToken.Index(i)
+ }
+
+ implicit class fromDefModuleToTargetToken(m: DefModule) {
+ def OfModule: OfModule = new TargetToken.OfModule(m.name)
+ }
+
+ implicit class fromDefInstanceToTargetToken(i: DefInstance) {
+ def Instance: Instance = new TargetToken.Instance(i.name)
+ def OfModule: OfModule = new TargetToken.OfModule(i.module)
+ def toTokens: (Instance, OfModule) = (new TargetToken.Instance(i.name), new TargetToken.OfModule(i.module))
+ }
+
+ implicit class fromWDefInstanceToTargetToken(wi: WDefInstance) {
+ def Instance: Instance = new TargetToken.Instance(wi.name)
+ def OfModule: OfModule = new TargetToken.OfModule(wi.module)
+ def toTokens: (Instance, OfModule) = (new TargetToken.Instance(wi.name), new TargetToken.OfModule(wi.module))
+ }
+
val keyword2targettoken = Map(
"inst" -> ((value: String) => Instance(value)),
"of" -> ((value: String) => OfModule(value)),
diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala
index 0ca98ac5..fd001827 100644
--- a/src/main/scala/firrtl/passes/Inline.scala
+++ b/src/main/scala/firrtl/passes/Inline.scala
@@ -134,11 +134,7 @@ class InlineInstances extends Transform with RegisteredTransform {
val iGraph = new InstanceGraph(c)
val namespaceMap = collection.mutable.Map[String, Namespace]()
// Map of Module name to Map of instance name to Module name
- val instMaps: Map[OfModule, Map[Instance, OfModule]] = {
- iGraph.graph.getEdgeMap.view.map { case (mod, children) =>
- OfModule(mod.module) -> children.view.map(i => Instance(i.name) -> OfModule(i.module)).toMap
- }.toMap
- }
+ val instMaps = iGraph.getChildrenInstanceMap
/** Add a prefix to all declarations updating a [[Namespace]] and appending to a [[RenameMap]] */
def appendNamePrefix(
@@ -225,7 +221,7 @@ class InlineInstances extends Transform with RegisteredTransform {
}
def fixupRefs(
- instMap: Map[Instance, OfModule],
+ instMap: collection.Map[Instance, OfModule],
currentModule: IsModule)(e: Expression): Expression = {
e match {
case wsf@ WSubField(wr@ WRef(ref, _, InstanceKind, _), field, tpe, gen) =>
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index b183e059..f224546b 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -5,6 +5,7 @@ package transforms
import firrtl._
import firrtl.annotations._
+import firrtl.annotations.TargetToken._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
@@ -327,10 +328,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
// Is "a" a "better name" than "b"?
private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_')
- def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
- def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
+ def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+ def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
- private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = {
+ private def constPropExpression(nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], constSubOutputs: Map[OfModule, Map[String, Literal]])(e: Expression): Expression = {
val old = e map constPropExpression(nodeMap, instMap, constSubOutputs)
val propagated = old match {
case p: DoPrim => constPropPrim(p)
@@ -338,7 +339,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) =>
constPropNodeRef(ref, nodeMap(rname))
case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) =>
- val module = instMap(inst)
+ val module = instMap(inst.Instance)
// Check constSubOutputs to see if the submodule is driving a constant
constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref)
case x => x
@@ -370,10 +371,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
private def constPropModule(
m: Module,
dontTouches: Set[String],
- instMap: Map[String, String],
+ instMap: collection.Map[Instance, OfModule],
constInputs: Map[String, Literal],
- constSubOutputs: Map[String, Map[String, Literal]]
- ): (Module, Map[String, Literal], Map[String, Map[String, Seq[Literal]]]) = {
+ constSubOutputs: Map[OfModule, Map[String, Literal]]
+ ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = {
var nPropagated = 0L
val nodeMap = new NodeMap()
@@ -384,7 +385,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
val constOutputs = mutable.HashMap.empty[String, Literal]
// Keep track of any submodule inputs we drive with a constant
// (can have more than 1 of the same submodule)
- val constSubInputs = mutable.HashMap.empty[String, mutable.HashMap[String, Seq[Literal]]]
+ val constSubInputs = mutable.HashMap.empty[OfModule, mutable.HashMap[String, Seq[Literal]]]
// AsyncReset registers don't have reset turned into a mux so we must be careful
val asyncResetRegs = mutable.HashSet.empty[String]
@@ -494,7 +495,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
// Mark instance inputs connected to a constant
case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) =>
val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
- val module = instMap(inst)
+ val module = instMap(inst.Instance)
val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty)
portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty)
case _ =>
@@ -525,22 +526,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
}
- private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = {
- val iGraph = (new InstanceGraph(c)).graph
- val moduleDeps = iGraph.getEdgeMap.map({ case (mod, children) =>
- mod.module -> children.map(i => i.name -> i.module).toMap
- })
-
- // This is a *relative* instance count, ie. how many there are when you visit each Module once
- // (even if it is instantiated multiple times)
- val instCount: Map[String, Int] = iGraph.getEdgeMap.foldLeft(Map(c.main -> 1)) {
- case (cs, (_, values)) => values.foldLeft(cs) {
- case (counts, value) => counts.updated(value.module, counts.getOrElse(value.module, 0) + 1)
- }
- }
+
+ private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = {
+ val iGraph = new InstanceGraph(c)
+ val moduleDeps = iGraph.getChildrenInstanceMap
+ val instCount = iGraph.staticInstanceCount
// DiGraph using Module names as nodes, destination of edge is a parent Module
- val parentGraph: DiGraph[String] = iGraph.reverse.transformNodes(_.module)
+ val parentGraph: DiGraph[OfModule] = iGraph.graph.reverse.transformNodes(_.OfModule)
// This outer loop works by applying constant propagation to the modules in a topologically
// sorted order from leaf to root
@@ -550,9 +543,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
// are driven with the same constant value. Then, if we find a Module input where each instance
// is driven with the same constant (and not seen in a previous iteration), we iterate again
@tailrec
- def iterate(toVisit: Set[String],
- modules: Map[String, Module],
- constInputs: Map[String, Map[String, Literal]]): Map[String, DefModule] = {
+ def iterate(toVisit: Set[OfModule],
+ modules: Map[OfModule, Module],
+ constInputs: Map[OfModule, Map[String, Literal]]): Map[OfModule, DefModule] = {
if (toVisit.isEmpty) modules
else {
// Order from leaf modules to root so that any module driving an output
@@ -564,8 +557,8 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
// Aggregate submodule inputs driven constant for checking later
val (modulesx, _, constInputsx) =
order.foldLeft((modules,
- Map[String, Map[String, Literal]](),
- Map[String, Map[String, Seq[Literal]]]())) {
+ Map[OfModule, Map[String, Literal]](),
+ Map[OfModule, Map[String, Seq[Literal]]]())) {
case ((mmap, constOutputs, constInputsAcc), mname) =>
val dontTouches = dontTouchMap.getOrElse(mname, Set.empty)
val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname),
@@ -595,10 +588,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
val modulesx = {
- val nameMap = c.modules.collect { case m: Module => m.name -> m }.toMap
+ val nameMap = c.modules.collect { case m: Module => m.OfModule -> m }.toMap
// We only pass names of Modules, we can't apply const prop to ExtModules
val mmap = iterate(nameMap.keySet, nameMap, Map.empty)
- c.modules.map(m => mmap.getOrElse(m.name, m))
+ c.modules.map(m => mmap.getOrElse(m.OfModule, m))
}
@@ -606,11 +599,11 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
def execute(state: CircuitState): CircuitState = {
- val dontTouches: Seq[(String, String)] = state.annotations.collect {
- case DontTouchAnnotation(Target(_, Some(m), Seq(Ref(c)))) => m -> c
+ val dontTouches: Seq[(OfModule, String)] = state.annotations.collect {
+ case DontTouchAnnotation(Target(_, Some(m), Seq(Ref(c)))) => m.OfModule -> c
}
// Map from module name to component names
- val dontTouchMap: Map[String, Set[String]] =
+ val dontTouchMap: Map[OfModule, Set[String]] =
dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet)
state.copy(circuit = run(state.circuit, dontTouchMap))