aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/ConstantPropagation.scala
diff options
context:
space:
mode:
authorAlbert Magyar2019-10-07 11:56:30 -0700
committermergify[bot]2019-10-07 18:56:30 +0000
commit357eba4c2b1549de70843899b4dae7d657757d50 (patch)
treefcfb740f9dfda8e4e7bdd24984ae027e871f6e32 /src/main/scala/firrtl/transforms/ConstantPropagation.scala
parent621c5689ff9b441465a9e6a1f4d92af739603293 (diff)
Absorb some instance analysis into InstanceGraph, use safer boxed Strings (#1186)
* Replace instance analysis code with InstanceGraph API calls * Add convenience implicits for using TargetTokens as safe boxed strings
Diffstat (limited to 'src/main/scala/firrtl/transforms/ConstantPropagation.scala')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala59
1 files changed, 26 insertions, 33 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index b183e059..f224546b 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -5,6 +5,7 @@ package transforms
import firrtl._
import firrtl.annotations._
+import firrtl.annotations.TargetToken._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
@@ -327,10 +328,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
// Is "a" a "better name" than "b"?
private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_')
- def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
- def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
+ def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+ def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
- private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = {
+ private def constPropExpression(nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], constSubOutputs: Map[OfModule, Map[String, Literal]])(e: Expression): Expression = {
val old = e map constPropExpression(nodeMap, instMap, constSubOutputs)
val propagated = old match {
case p: DoPrim => constPropPrim(p)
@@ -338,7 +339,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) =>
constPropNodeRef(ref, nodeMap(rname))
case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) =>
- val module = instMap(inst)
+ val module = instMap(inst.Instance)
// Check constSubOutputs to see if the submodule is driving a constant
constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref)
case x => x
@@ -370,10 +371,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
private def constPropModule(
m: Module,
dontTouches: Set[String],
- instMap: Map[String, String],
+ instMap: collection.Map[Instance, OfModule],
constInputs: Map[String, Literal],
- constSubOutputs: Map[String, Map[String, Literal]]
- ): (Module, Map[String, Literal], Map[String, Map[String, Seq[Literal]]]) = {
+ constSubOutputs: Map[OfModule, Map[String, Literal]]
+ ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = {
var nPropagated = 0L
val nodeMap = new NodeMap()
@@ -384,7 +385,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
val constOutputs = mutable.HashMap.empty[String, Literal]
// Keep track of any submodule inputs we drive with a constant
// (can have more than 1 of the same submodule)
- val constSubInputs = mutable.HashMap.empty[String, mutable.HashMap[String, Seq[Literal]]]
+ val constSubInputs = mutable.HashMap.empty[OfModule, mutable.HashMap[String, Seq[Literal]]]
// AsyncReset registers don't have reset turned into a mux so we must be careful
val asyncResetRegs = mutable.HashSet.empty[String]
@@ -494,7 +495,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
// Mark instance inputs connected to a constant
case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) =>
val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
- val module = instMap(inst)
+ val module = instMap(inst.Instance)
val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty)
portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty)
case _ =>
@@ -525,22 +526,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
}
- private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = {
- val iGraph = (new InstanceGraph(c)).graph
- val moduleDeps = iGraph.getEdgeMap.map({ case (mod, children) =>
- mod.module -> children.map(i => i.name -> i.module).toMap
- })
-
- // This is a *relative* instance count, ie. how many there are when you visit each Module once
- // (even if it is instantiated multiple times)
- val instCount: Map[String, Int] = iGraph.getEdgeMap.foldLeft(Map(c.main -> 1)) {
- case (cs, (_, values)) => values.foldLeft(cs) {
- case (counts, value) => counts.updated(value.module, counts.getOrElse(value.module, 0) + 1)
- }
- }
+
+ private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = {
+ val iGraph = new InstanceGraph(c)
+ val moduleDeps = iGraph.getChildrenInstanceMap
+ val instCount = iGraph.staticInstanceCount
// DiGraph using Module names as nodes, destination of edge is a parent Module
- val parentGraph: DiGraph[String] = iGraph.reverse.transformNodes(_.module)
+ val parentGraph: DiGraph[OfModule] = iGraph.graph.reverse.transformNodes(_.OfModule)
// This outer loop works by applying constant propagation to the modules in a topologically
// sorted order from leaf to root
@@ -550,9 +543,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
// are driven with the same constant value. Then, if we find a Module input where each instance
// is driven with the same constant (and not seen in a previous iteration), we iterate again
@tailrec
- def iterate(toVisit: Set[String],
- modules: Map[String, Module],
- constInputs: Map[String, Map[String, Literal]]): Map[String, DefModule] = {
+ def iterate(toVisit: Set[OfModule],
+ modules: Map[OfModule, Module],
+ constInputs: Map[OfModule, Map[String, Literal]]): Map[OfModule, DefModule] = {
if (toVisit.isEmpty) modules
else {
// Order from leaf modules to root so that any module driving an output
@@ -564,8 +557,8 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
// Aggregate submodule inputs driven constant for checking later
val (modulesx, _, constInputsx) =
order.foldLeft((modules,
- Map[String, Map[String, Literal]](),
- Map[String, Map[String, Seq[Literal]]]())) {
+ Map[OfModule, Map[String, Literal]](),
+ Map[OfModule, Map[String, Seq[Literal]]]())) {
case ((mmap, constOutputs, constInputsAcc), mname) =>
val dontTouches = dontTouchMap.getOrElse(mname, Set.empty)
val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname),
@@ -595,10 +588,10 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
val modulesx = {
- val nameMap = c.modules.collect { case m: Module => m.name -> m }.toMap
+ val nameMap = c.modules.collect { case m: Module => m.OfModule -> m }.toMap
// We only pass names of Modules, we can't apply const prop to ExtModules
val mmap = iterate(nameMap.keySet, nameMap, Map.empty)
- c.modules.map(m => mmap.getOrElse(m.name, m))
+ c.modules.map(m => mmap.getOrElse(m.OfModule, m))
}
@@ -606,11 +599,11 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
def execute(state: CircuitState): CircuitState = {
- val dontTouches: Seq[(String, String)] = state.annotations.collect {
- case DontTouchAnnotation(Target(_, Some(m), Seq(Ref(c)))) => m -> c
+ val dontTouches: Seq[(OfModule, String)] = state.annotations.collect {
+ case DontTouchAnnotation(Target(_, Some(m), Seq(Ref(c)))) => m.OfModule -> c
}
// Map from module name to component names
- val dontTouchMap: Map[String, Set[String]] =
+ val dontTouchMap: Map[OfModule, Set[String]] =
dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet)
state.copy(circuit = run(state.circuit, dontTouchMap))