diff options
| author | Albert Magyar | 2019-10-07 11:56:30 -0700 |
|---|---|---|
| committer | mergify[bot] | 2019-10-07 18:56:30 +0000 |
| commit | 357eba4c2b1549de70843899b4dae7d657757d50 (patch) | |
| tree | fcfb740f9dfda8e4e7bdd24984ae027e871f6e32 /src/main/scala/firrtl/transforms/ConstantPropagation.scala | |
| parent | 621c5689ff9b441465a9e6a1f4d92af739603293 (diff) | |
Absorb some instance analysis into InstanceGraph, use safer boxed Strings (#1186)
* Replace instance analysis code with InstanceGraph API calls
* Add convenience implicits for using TargetTokens as safe boxed strings
Diffstat (limited to 'src/main/scala/firrtl/transforms/ConstantPropagation.scala')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 59 |
1 files changed, 26 insertions, 33 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index b183e059..f224546b 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -5,6 +5,7 @@ package transforms import firrtl._ import firrtl.annotations._ +import firrtl.annotations.TargetToken._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ @@ -327,10 +328,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { // Is "a" a "better name" than "b"? private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') - def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) - def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) + def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) - private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = { + private def constPropExpression(nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], constSubOutputs: Map[OfModule, Map[String, Literal]])(e: Expression): Expression = { val old = e map constPropExpression(nodeMap, instMap, constSubOutputs) val propagated = old match { case p: DoPrim => constPropPrim(p) @@ -338,7 +339,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => constPropNodeRef(ref, nodeMap(rname)) case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) => - val module = instMap(inst) + val module = instMap(inst.Instance) // Check constSubOutputs to see if the submodule is driving a constant constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref) case x => x @@ -370,10 +371,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { private def constPropModule( m: Module, dontTouches: Set[String], - instMap: Map[String, String], + instMap: collection.Map[Instance, OfModule], constInputs: Map[String, Literal], - constSubOutputs: Map[String, Map[String, Literal]] - ): (Module, Map[String, Literal], Map[String, Map[String, Seq[Literal]]]) = { + constSubOutputs: Map[OfModule, Map[String, Literal]] + ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = { var nPropagated = 0L val nodeMap = new NodeMap() @@ -384,7 +385,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { val constOutputs = mutable.HashMap.empty[String, Literal] // Keep track of any submodule inputs we drive with a constant // (can have more than 1 of the same submodule) - val constSubInputs = mutable.HashMap.empty[String, mutable.HashMap[String, Seq[Literal]]] + val constSubInputs = mutable.HashMap.empty[OfModule, mutable.HashMap[String, Seq[Literal]]] // AsyncReset registers don't have reset turned into a mux so we must be careful val asyncResetRegs = mutable.HashSet.empty[String] @@ -494,7 +495,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { // Mark instance inputs connected to a constant case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) => val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] - val module = instMap(inst) + val module = instMap(inst.Instance) val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty) portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty) case _ => @@ -525,22 +526,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) } - private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = { - val iGraph = (new InstanceGraph(c)).graph - val moduleDeps = iGraph.getEdgeMap.map({ case (mod, children) => - mod.module -> children.map(i => i.name -> i.module).toMap - }) - - // This is a *relative* instance count, ie. how many there are when you visit each Module once - // (even if it is instantiated multiple times) - val instCount: Map[String, Int] = iGraph.getEdgeMap.foldLeft(Map(c.main -> 1)) { - case (cs, (_, values)) => values.foldLeft(cs) { - case (counts, value) => counts.updated(value.module, counts.getOrElse(value.module, 0) + 1) - } - } + + private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = { + val iGraph = new InstanceGraph(c) + val moduleDeps = iGraph.getChildrenInstanceMap + val instCount = iGraph.staticInstanceCount // DiGraph using Module names as nodes, destination of edge is a parent Module - val parentGraph: DiGraph[String] = iGraph.reverse.transformNodes(_.module) + val parentGraph: DiGraph[OfModule] = iGraph.graph.reverse.transformNodes(_.OfModule) // This outer loop works by applying constant propagation to the modules in a topologically // sorted order from leaf to root @@ -550,9 +543,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { // are driven with the same constant value. Then, if we find a Module input where each instance // is driven with the same constant (and not seen in a previous iteration), we iterate again @tailrec - def iterate(toVisit: Set[String], - modules: Map[String, Module], - constInputs: Map[String, Map[String, Literal]]): Map[String, DefModule] = { + def iterate(toVisit: Set[OfModule], + modules: Map[OfModule, Module], + constInputs: Map[OfModule, Map[String, Literal]]): Map[OfModule, DefModule] = { if (toVisit.isEmpty) modules else { // Order from leaf modules to root so that any module driving an output @@ -564,8 +557,8 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { // Aggregate submodule inputs driven constant for checking later val (modulesx, _, constInputsx) = order.foldLeft((modules, - Map[String, Map[String, Literal]](), - Map[String, Map[String, Seq[Literal]]]())) { + Map[OfModule, Map[String, Literal]](), + Map[OfModule, Map[String, Seq[Literal]]]())) { case ((mmap, constOutputs, constInputsAcc), mname) => val dontTouches = dontTouchMap.getOrElse(mname, Set.empty) val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname), @@ -595,10 +588,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } val modulesx = { - val nameMap = c.modules.collect { case m: Module => m.name -> m }.toMap + val nameMap = c.modules.collect { case m: Module => m.OfModule -> m }.toMap // We only pass names of Modules, we can't apply const prop to ExtModules val mmap = iterate(nameMap.keySet, nameMap, Map.empty) - c.modules.map(m => mmap.getOrElse(m.name, m)) + c.modules.map(m => mmap.getOrElse(m.OfModule, m)) } @@ -606,11 +599,11 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } def execute(state: CircuitState): CircuitState = { - val dontTouches: Seq[(String, String)] = state.annotations.collect { - case DontTouchAnnotation(Target(_, Some(m), Seq(Ref(c)))) => m -> c + val dontTouches: Seq[(OfModule, String)] = state.annotations.collect { + case DontTouchAnnotation(Target(_, Some(m), Seq(Ref(c)))) => m.OfModule -> c } // Map from module name to component names - val dontTouchMap: Map[String, Set[String]] = + val dontTouchMap: Map[OfModule, Set[String]] = dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet) state.copy(circuit = run(state.circuit, dontTouchMap)) |
