diff options
| author | Albert Magyar | 2019-10-07 11:56:30 -0700 |
|---|---|---|
| committer | mergify[bot] | 2019-10-07 18:56:30 +0000 |
| commit | 357eba4c2b1549de70843899b4dae7d657757d50 (patch) | |
| tree | fcfb740f9dfda8e4e7bdd24984ae027e871f6e32 /src | |
| parent | 621c5689ff9b441465a9e6a1f4d92af739603293 (diff) | |
Absorb some instance analysis into InstanceGraph, use safer boxed Strings (#1186)
* Replace instance analysis code with InstanceGraph API calls
* Add convenience implicits for using TargetTokens as safe boxed strings
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/analyses/InstanceGraph.scala | 31 | ||||
| -rw-r--r-- | src/main/scala/firrtl/annotations/TargetToken.scala | 30 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Inline.scala | 8 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 59 |
4 files changed, 87 insertions, 41 deletions
diff --git a/src/main/scala/firrtl/analyses/InstanceGraph.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala index 22f40359..59eae09b 100644 --- a/src/main/scala/firrtl/analyses/InstanceGraph.scala +++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala @@ -8,7 +8,7 @@ import firrtl.ir._ import firrtl.graph._ import firrtl.Utils._ import firrtl.traversals.Foreachers._ -import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.annotations.TargetToken._ /** A class representing the instance hierarchy of a working IR Circuit @@ -62,6 +62,19 @@ class InstanceGraph(c: Circuit) { */ lazy val fullHierarchy: mutable.LinkedHashMap[WDefInstance,Seq[Seq[WDefInstance]]] = graph.pathsInDAG(trueTopInstance) + /** A count of the *static* number of instances of each module. For + * any module other than the top module, this is equivalent to the + * number of inst statements in the circuit instantiating each + * module, irrespective of the number of times (if any) the + * enclosing module appears in the hierarchy. Note that top module + * of the circuit has an associated count of 1, even though it is + * never directly instantiated. + */ + lazy val staticInstanceCount: Map[OfModule, Int] = { + val instModules = childInstances.flatMap(_._2.view.map(_.OfModule).toSeq) + instModules.foldLeft(Map(c.main.OfModule -> 1)) { case (counts, mod) => counts.updated(mod, counts.getOrElse(mod, 0) + 1) } + } + /** Finds the absolute paths (each represented by a Seq of instances * representing the chain of hierarchy) of all instances of a * particular module. @@ -107,8 +120,22 @@ class InstanceGraph(c: Circuit) { * instance/module [[firrtl.annotations.TargetToken]]s */ def getChildrenInstanceOfModule: mutable.LinkedHashMap[String, mutable.LinkedHashSet[(Instance, OfModule)]] = - childInstances.map(kv => kv._1 -> kv._2.map(i => (Instance(i.name), OfModule(i.module)))) + childInstances.map(kv => kv._1 -> kv._2.map(_.toTokens)) + + // Transforms a TraversableOnce input into an order-preserving map + // Iterates only once, no intermediate collections + // Can possibly be replaced using LinkedHashMap.from(..) or better immutable map in Scala 2.13 + private def asOrderedMap[K1, K2, V](it: TraversableOnce[K1], f: (K1) => (K2, V)): collection.Map[K2, V] = { + val lhmap = new mutable.LinkedHashMap[K2, V] + it.foreach { lhmap += f(_) } + lhmap + } + /** Given a circuit, returns a map from module name to a map + * in turn mapping instances names to corresponding module names + */ + def getChildrenInstanceMap: collection.Map[OfModule, collection.Map[Instance, OfModule]] = + childInstances.map(kv => kv._1.OfModule -> asOrderedMap(kv._2, (i: WDefInstance) => i.toTokens)) } diff --git a/src/main/scala/firrtl/annotations/TargetToken.scala b/src/main/scala/firrtl/annotations/TargetToken.scala index 587f30eb..70b64271 100644 --- a/src/main/scala/firrtl/annotations/TargetToken.scala +++ b/src/main/scala/firrtl/annotations/TargetToken.scala @@ -2,6 +2,9 @@ package firrtl.annotations +import firrtl._ +import ir.{DefModule, DefInstance} + /** Building block to represent a [[Target]] of a FIRRTL component */ sealed trait TargetToken { def keyword: String @@ -32,6 +35,33 @@ case object TargetToken { case object Init extends TargetToken { override def keyword: String = "init"; val value = "" } case object Reset extends TargetToken { override def keyword: String = "reset"; val value = "" } + implicit class fromStringToTargetToken(s: String) { + def Instance: Instance = new TargetToken.Instance(s) + def OfModule: OfModule = new TargetToken.OfModule(s) + def Ref: Ref = new TargetToken.Ref(s) + def Field: Field = new TargetToken.Field(s) + } + + implicit class fromIntToTargetToken(i: Int) { + def Index: Index = new TargetToken.Index(i) + } + + implicit class fromDefModuleToTargetToken(m: DefModule) { + def OfModule: OfModule = new TargetToken.OfModule(m.name) + } + + implicit class fromDefInstanceToTargetToken(i: DefInstance) { + def Instance: Instance = new TargetToken.Instance(i.name) + def OfModule: OfModule = new TargetToken.OfModule(i.module) + def toTokens: (Instance, OfModule) = (new TargetToken.Instance(i.name), new TargetToken.OfModule(i.module)) + } + + implicit class fromWDefInstanceToTargetToken(wi: WDefInstance) { + def Instance: Instance = new TargetToken.Instance(wi.name) + def OfModule: OfModule = new TargetToken.OfModule(wi.module) + def toTokens: (Instance, OfModule) = (new TargetToken.Instance(wi.name), new TargetToken.OfModule(wi.module)) + } + val keyword2targettoken = Map( "inst" -> ((value: String) => Instance(value)), "of" -> ((value: String) => OfModule(value)), diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 0ca98ac5..fd001827 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -134,11 +134,7 @@ class InlineInstances extends Transform with RegisteredTransform { val iGraph = new InstanceGraph(c) val namespaceMap = collection.mutable.Map[String, Namespace]() // Map of Module name to Map of instance name to Module name - val instMaps: Map[OfModule, Map[Instance, OfModule]] = { - iGraph.graph.getEdgeMap.view.map { case (mod, children) => - OfModule(mod.module) -> children.view.map(i => Instance(i.name) -> OfModule(i.module)).toMap - }.toMap - } + val instMaps = iGraph.getChildrenInstanceMap /** Add a prefix to all declarations updating a [[Namespace]] and appending to a [[RenameMap]] */ def appendNamePrefix( @@ -225,7 +221,7 @@ class InlineInstances extends Transform with RegisteredTransform { } def fixupRefs( - instMap: Map[Instance, OfModule], + instMap: collection.Map[Instance, OfModule], currentModule: IsModule)(e: Expression): Expression = { e match { case wsf@ WSubField(wr@ WRef(ref, _, InstanceKind, _), field, tpe, gen) => diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index b183e059..f224546b 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -5,6 +5,7 @@ package transforms import firrtl._ import firrtl.annotations._ +import firrtl.annotations.TargetToken._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ @@ -327,10 +328,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { // Is "a" a "better name" than "b"? private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') - def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) - def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) + def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) - private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = { + private def constPropExpression(nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], constSubOutputs: Map[OfModule, Map[String, Literal]])(e: Expression): Expression = { val old = e map constPropExpression(nodeMap, instMap, constSubOutputs) val propagated = old match { case p: DoPrim => constPropPrim(p) @@ -338,7 +339,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => constPropNodeRef(ref, nodeMap(rname)) case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) => - val module = instMap(inst) + val module = instMap(inst.Instance) // Check constSubOutputs to see if the submodule is driving a constant constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref) case x => x @@ -370,10 +371,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { private def constPropModule( m: Module, dontTouches: Set[String], - instMap: Map[String, String], + instMap: collection.Map[Instance, OfModule], constInputs: Map[String, Literal], - constSubOutputs: Map[String, Map[String, Literal]] - ): (Module, Map[String, Literal], Map[String, Map[String, Seq[Literal]]]) = { + constSubOutputs: Map[OfModule, Map[String, Literal]] + ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = { var nPropagated = 0L val nodeMap = new NodeMap() @@ -384,7 +385,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { val constOutputs = mutable.HashMap.empty[String, Literal] // Keep track of any submodule inputs we drive with a constant // (can have more than 1 of the same submodule) - val constSubInputs = mutable.HashMap.empty[String, mutable.HashMap[String, Seq[Literal]]] + val constSubInputs = mutable.HashMap.empty[OfModule, mutable.HashMap[String, Seq[Literal]]] // AsyncReset registers don't have reset turned into a mux so we must be careful val asyncResetRegs = mutable.HashSet.empty[String] @@ -494,7 +495,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { // Mark instance inputs connected to a constant case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) => val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] - val module = instMap(inst) + val module = instMap(inst.Instance) val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty) portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty) case _ => @@ -525,22 +526,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) } - private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = { - val iGraph = (new InstanceGraph(c)).graph - val moduleDeps = iGraph.getEdgeMap.map({ case (mod, children) => - mod.module -> children.map(i => i.name -> i.module).toMap - }) - - // This is a *relative* instance count, ie. how many there are when you visit each Module once - // (even if it is instantiated multiple times) - val instCount: Map[String, Int] = iGraph.getEdgeMap.foldLeft(Map(c.main -> 1)) { - case (cs, (_, values)) => values.foldLeft(cs) { - case (counts, value) => counts.updated(value.module, counts.getOrElse(value.module, 0) + 1) - } - } + + private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = { + val iGraph = new InstanceGraph(c) + val moduleDeps = iGraph.getChildrenInstanceMap + val instCount = iGraph.staticInstanceCount // DiGraph using Module names as nodes, destination of edge is a parent Module - val parentGraph: DiGraph[String] = iGraph.reverse.transformNodes(_.module) + val parentGraph: DiGraph[OfModule] = iGraph.graph.reverse.transformNodes(_.OfModule) // This outer loop works by applying constant propagation to the modules in a topologically // sorted order from leaf to root @@ -550,9 +543,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { // are driven with the same constant value. Then, if we find a Module input where each instance // is driven with the same constant (and not seen in a previous iteration), we iterate again @tailrec - def iterate(toVisit: Set[String], - modules: Map[String, Module], - constInputs: Map[String, Map[String, Literal]]): Map[String, DefModule] = { + def iterate(toVisit: Set[OfModule], + modules: Map[OfModule, Module], + constInputs: Map[OfModule, Map[String, Literal]]): Map[OfModule, DefModule] = { if (toVisit.isEmpty) modules else { // Order from leaf modules to root so that any module driving an output @@ -564,8 +557,8 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { // Aggregate submodule inputs driven constant for checking later val (modulesx, _, constInputsx) = order.foldLeft((modules, - Map[String, Map[String, Literal]](), - Map[String, Map[String, Seq[Literal]]]())) { + Map[OfModule, Map[String, Literal]](), + Map[OfModule, Map[String, Seq[Literal]]]())) { case ((mmap, constOutputs, constInputsAcc), mname) => val dontTouches = dontTouchMap.getOrElse(mname, Set.empty) val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname), @@ -595,10 +588,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } val modulesx = { - val nameMap = c.modules.collect { case m: Module => m.name -> m }.toMap + val nameMap = c.modules.collect { case m: Module => m.OfModule -> m }.toMap // We only pass names of Modules, we can't apply const prop to ExtModules val mmap = iterate(nameMap.keySet, nameMap, Map.empty) - c.modules.map(m => mmap.getOrElse(m.name, m)) + c.modules.map(m => mmap.getOrElse(m.OfModule, m)) } @@ -606,11 +599,11 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } def execute(state: CircuitState): CircuitState = { - val dontTouches: Seq[(String, String)] = state.annotations.collect { - case DontTouchAnnotation(Target(_, Some(m), Seq(Ref(c)))) => m -> c + val dontTouches: Seq[(OfModule, String)] = state.annotations.collect { + case DontTouchAnnotation(Target(_, Some(m), Seq(Ref(c)))) => m.OfModule -> c } // Map from module name to component names - val dontTouchMap: Map[String, Set[String]] = + val dontTouchMap: Map[OfModule, Set[String]] = dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet) state.copy(circuit = run(state.circuit, dontTouchMap)) |
