aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack Koenig2017-08-14 10:59:05 -0700
committerGitHub2017-08-14 10:59:05 -0700
commit672162b4bf6ca4a4a4ed7a4a9ffaadfea428ede0 (patch)
tree01bb0a5ef3ce91af8803fdac6caf2e72cdc6479a
parenta84956afa36dbe29e87dd6c2168848a426ec42d3 (diff)
Constant propagation across module boundaries (#633)
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala162
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala200
2 files changed, 327 insertions, 35 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index d5a4b7e1..69502911 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -9,7 +9,9 @@ import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import firrtl.PrimOps._
+import firrtl.graph.DiGraph
import firrtl.WrappedExpression.weq
+import firrtl.analyses.InstanceGraph
import annotation.tailrec
import collection.mutable
@@ -244,21 +246,50 @@ class ConstantPropagation extends Transform {
// Is "a" a "better name" than "b"?
private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_')
- // Two pass process
- // 1. Propagate constants in expressions and forward propagate references
- // 2. Propagate references again for backwards reference (Wires)
- // TODO Replacing all wires with nodes makes the second pass unnecessary
- // However, preserving decent names DOES require a second pass
- // Replacing all wires with nodes makes it unnecessary for preserving decent names to trigger an
- // extra iteration though
+ /** Constant propagate a Module
+ *
+ * Two pass process
+ * 1. Propagate constants in expressions and forward propagate references
+ * 2. Propagate references again for backwards reference (Wires)
+ * TODO Replacing all wires with nodes makes the second pass unnecessary
+ * However, preserving decent names DOES require a second pass
+ * Replacing all wires with nodes makes it unnecessary for preserving decent names to trigger an
+ * extra iteration though
+ *
+ * @param m the Module to run constant propagation on
+ * @param dontTouches names of components local to m that should not be propagated across
+ * @param instMap map of instance names to Module name
+ * @param constInputs map of names of m's input ports to literal driving it (if applicable)
+ * @param constSubOutputs Map of Module name to Map of output port name to literal driving it
+ * @return (Constpropped Module, Map of output port names to literal value,
+ * Map of submodule modulenames to Map of input port names to literal values)
+ */
@tailrec
- private def constPropModule(m: Module, dontTouches: Set[String]): Module = {
+ private def constPropModule(
+ m: Module,
+ dontTouches: Set[String],
+ instMap: Map[String, String],
+ constInputs: Map[String, Literal],
+ constSubOutputs: Map[String, Map[String, Literal]]
+ ): (Module, Map[String, Literal], Map[String, Map[String, Seq[Literal]]]) = {
+
var nPropagated = 0L
val nodeMap = mutable.HashMap.empty[String, Expression]
// For cases where we are trying to constprop a bad name over a good one, we swap their names
// during the second pass
val swapMap = mutable.HashMap.empty[String, String]
-
+ // Keep track of any outputs we drive with a constant
+ val constOutputs = mutable.HashMap.empty[String, Literal]
+ // Keep track of any submodule inputs we drive with a constant
+ // (can have more than 1 of the same submodule)
+ val constSubInputs = mutable.HashMap.empty[String, mutable.HashMap[String, Seq[Literal]]]
+
+ // Copy constant mapping for constant inputs (except ones marked dontTouch!)
+ nodeMap ++= constInputs.filterNot { case (pname, _) => dontTouches.contains(pname) }
+
+ // Note that on back propagation we *only* worry about swapping names and propagating references
+ // to constant wires, we don't need to worry about propagating primops or muxes since we'll do
+ // that on the next iteration if necessary
def backPropExpr(expr: Expression): Expression = {
val old = expr map backPropExpr
val propagated = old match {
@@ -296,6 +327,10 @@ class ConstantPropagation extends Transform {
case m: Mux => constPropMux(m)
case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
constPropNodeRef(ref, nodeMap(rname))
+ case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, MALE) =>
+ val module = instMap(inst)
+ // Check constSubOutputs to see if the submodule is driving a constant
+ constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref)
case x => x
}
propagated
@@ -320,32 +355,119 @@ class ConstantPropagation extends Transform {
case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) =>
val exprx = constPropExpression(pad(expr, wtpe))
propagateRef(wname, exprx)
+ // Record constants driving outputs
+ case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) =>
+ val paddedLit = constPropExpression(pad(lit, ptpe)).asInstanceOf[Literal]
+ constOutputs(pname) = paddedLit
// Const prop registers that are fed only a constant or a mux between and constant and the
// register itself
// This requires that reset has been made explicit
- case Connect(_, rref @ WRef(rname, rtpe, RegKind, _), expr) => expr match {
+ case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), expr) => expr match {
case lit: Literal =>
- nodeMap(rname) = constPropExpression(pad(lit, rtpe))
- case Mux(_, tval: WRef, fval: Literal, _) if weq(rref, tval) =>
- nodeMap(rname) = constPropExpression(pad(fval, rtpe))
- case Mux(_, tval: Literal, fval: WRef, _) if weq(rref, fval) =>
- nodeMap(rname) = constPropExpression(pad(tval, rtpe))
+ nodeMap(lname) = constPropExpression(pad(lit, ltpe))
+ case Mux(_, tval: WRef, fval: Literal, _) if weq(lref, tval) =>
+ nodeMap(lname) = constPropExpression(pad(fval, ltpe))
+ case Mux(_, tval: Literal, fval: WRef, _) if weq(lref, fval) =>
+ nodeMap(lname) = constPropExpression(pad(tval, ltpe))
case _ =>
}
+ // Mark instance inputs connected to a constant
+ case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) =>
+ val paddedLit = constPropExpression(pad(lit, ptpe)).asInstanceOf[Literal]
+ val module = instMap(inst)
+ val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty)
+ portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty)
case _ =>
}
stmtx
}
- val res = m.copy(body = backPropStmt(constPropStmt(m.body)))
- if (nPropagated > 0) constPropModule(res, dontTouches) else res
+ val modx = m.copy(body = backPropStmt(constPropStmt(m.body)))
+
+ // When we call this function again, constOutputs and constSubInputs are reconstructed and
+ // strictly a superset of the versions here
+ if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs)
+ else (modx, constOutputs.toMap, constSubInputs.mapValues(_.toMap).toMap)
}
+ // Unify two maps using f to combine values of duplicate keys
+ private def unify[K, V](a: Map[K, V], b: Map[K, V])(f: (V, V) => V): Map[K, V] =
+ b.foldLeft(a) { case (acc, (k, v)) =>
+ acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
+ }
+
private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = {
- val modulesx = c.modules.map {
- case m: ExtModule => m
- case m: Module => constPropModule(m, dontTouchMap.getOrElse(m.name, Set.empty))
+ val iGraph = (new InstanceGraph(c)).graph
+ val moduleDeps = iGraph.edges.map { case (mod, children) =>
+ mod.module -> children.map(i => i.name -> i.module).toMap
}
+
+ // Module name to number of instances
+ val instCount: Map[String, Int] = iGraph.getVertices.groupBy(_.module).mapValues(_.size)
+
+ // DiGraph using Module names as nodes, destination of edge is a parent Module
+ val parentGraph: DiGraph[String] = iGraph.reverse.transformNodes(_.module)
+
+ // This outer loop works by applying constant propagation to the modules in a topologically
+ // sorted order from leaf to root
+ // Modules will register any outputs they drive with a constant in constOutputs which is then
+ // checked by later modules in the same iteration (since we iterate from leaf to root)
+ // Since Modules can be instantiated multiple times, for inputs we must check that all instances
+ // are driven with the same constant value. Then, if we find a Module input where each instance
+ // is driven with the same constant (and not seen in a previous iteration), we iterate again
+ @tailrec
+ def iterate(toVisit: Set[String],
+ modules: Map[String, Module],
+ constInputs: Map[String, Map[String, Literal]]): Map[String, DefModule] = {
+ if (toVisit.isEmpty) modules
+ else {
+ // Order from leaf modules to root so that any module driving an output
+ // with a constant will be visible to modules that instantiate it
+ // TODO Generating order as we execute constant propagation on each module would be faster
+ val order = parentGraph.subgraph(toVisit).linearize
+ // Execute constant propagation on each module in order
+ // Aggreagte Module outputs that are driven constant for use by instaniating Modules
+ // Aggregate submodule inputs driven constant for checking later
+ val (modulesx, _, constInputsx) =
+ order.foldLeft((modules,
+ Map[String, Map[String, Literal]](),
+ Map[String, Map[String, Seq[Literal]]]())) {
+ case ((mmap, constOutputs, constInputsAcc), mname) =>
+ val dontTouches = dontTouchMap.getOrElse(mname, Set.empty)
+ val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname),
+ constInputs.getOrElse(mname, Map.empty), constOutputs)
+ // Accumulate all Literals used to drive a particular Module port
+ val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d))
+ (mmap + (mname -> mx), constOutputs + (mname -> mco), constInputsx)
+ }
+ // Determine which module inputs have all of the same, new constants driving them
+ val newProppedInputs = constInputsx.flatMap { case (mname, ports) =>
+ val portsx = ports.flatMap { case (pname, lits) =>
+ val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false)
+ val isModule = modules.contains(mname) // ExtModules are not contained in modules
+ val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1
+ if (isModule && newPort && allSameConst) Some(pname -> lits.head)
+ else None
+ }
+ if (portsx.nonEmpty) Some(mname -> portsx) else None
+ }
+ val modsWithConstInputs = newProppedInputs.keySet
+ val newToVisit = modsWithConstInputs ++
+ modsWithConstInputs.flatMap(parentGraph.reachableFrom)
+ // Combine const inputs (there can't be duplicate values in the inner maps)
+ val nextConstInputs = unify(constInputs, newProppedInputs)((a, b) => a ++ b)
+ iterate(newToVisit.toSet, modulesx, nextConstInputs)
+ }
+ }
+
+ val modulesx = {
+ val nameMap = c.modules.collect { case m: Module => m.name -> m }.toMap
+ // We only pass names of Modules, we can't apply const prop to ExtModules
+ val mmap = iterate(nameMap.keySet, nameMap, Map.empty)
+ c.modules.map(m => mmap.getOrElse(m.name, m))
+ }
+
+
Circuit(c.info, modulesx, c.main)
}
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index e42ecfac..b3a25f67 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -508,6 +508,116 @@ class ConstantPropagationSpec extends FirrtlFlatSpec {
"""
(parse(exec(input))) should be (parse(check))
}
+
+ "ConstProp" should "propagate constant outputs" in {
+ val input =
+"""circuit Top :
+ module Child :
+ output out : UInt<1>
+ out <= UInt<1>(0)
+ module Top :
+ input x : UInt<1>
+ output z : UInt<1>
+ inst c of Child
+ z <= and(x, c.out)
+"""
+ val check =
+"""circuit Top :
+ module Child :
+ output out : UInt<1>
+ out <= UInt<1>(0)
+ module Top :
+ input x : UInt<1>
+ output z : UInt<1>
+ inst c of Child
+ z <= UInt<1>(0)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ "ConstProp" should "propagate constant inputs" in {
+ val input =
+"""circuit Top :
+ module Child :
+ input in0 : UInt<1>
+ input in1 : UInt<1>
+ output out : UInt<1>
+ out <= and(in0, in1)
+ module Top :
+ input x : UInt<1>
+ output z : UInt<1>
+ inst c of Child
+ c.in0 <= x
+ c.in1 <= UInt<1>(1)
+ z <= c.out
+"""
+ val check =
+"""circuit Top :
+ module Child :
+ input in0 : UInt<1>
+ input in1 : UInt<1>
+ output out : UInt<1>
+ out <= in0
+ module Top :
+ input x : UInt<1>
+ output z : UInt<1>
+ inst c of Child
+ c.in0 <= x
+ c.in1 <= UInt<1>(1)
+ z <= c.out
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ "ConstProp" should "propagate constant inputs ONLY if ALL instance inputs get the same value" in {
+ def circuit(allSame: Boolean) =
+s"""circuit Top :
+ module Bottom :
+ input in : UInt<1>
+ output out : UInt<1>
+ out <= in
+ module Child :
+ output out : UInt<1>
+ inst b of Bottom
+ b.in <= UInt(1)
+ out <= b.out
+ module Top :
+ input x : UInt<1>
+ output z : UInt<1>
+
+ inst c of Child
+
+ inst b0 of Bottom
+ b0.in <= ${if (allSame) "UInt(1)" else "x"}
+ inst b1 of Bottom
+ b1.in <= UInt(1)
+
+ z <= and(and(b0.out, b1.out), c.out)
+"""
+ val resultFromAllSame =
+"""circuit Top :
+ module Bottom :
+ input in : UInt<1>
+ output out : UInt<1>
+ out <= UInt(1)
+ module Child :
+ output out : UInt<1>
+ inst b of Bottom
+ b.in <= UInt(1)
+ out <= UInt(1)
+ module Top :
+ input x : UInt<1>
+ output z : UInt<1>
+ inst c of Child
+ inst b0 of Bottom
+ b0.in <= UInt(1)
+ inst b1 of Bottom
+ b1.in <= UInt(1)
+ z <= UInt(1)
+"""
+ (parse(exec(circuit(false)))) should be (parse(circuit(false)))
+ (parse(exec(circuit(true)))) should be (parse(resultFromAllSame))
+ }
}
// More sophisticated tests of the full compiler
@@ -522,13 +632,7 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
| output y : UInt<1>
| node z = x
| y <= z""".stripMargin
- val check =
- """circuit Top :
- | module Top :
- | input x : UInt<1>
- | output y : UInt<1>
- | node z = x
- | y <= z""".stripMargin
+ val check = input
execute(input, check, Seq(dontTouch("Top.z")))
}
@@ -541,17 +645,44 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
| wire z : UInt<1>
| y <= z
| z <= x""".stripMargin
- val check =
- """circuit Top :
- | module Top :
- | input x : UInt<1>
- | output y : UInt<1>
- | wire z : UInt<1>
- | y <= z
- | z <= x""".stripMargin
+ val check = input
execute(input, check, Seq(dontTouch("Top.z")))
}
+ it should "NOT optimize across dontTouch on output ports" in {
+ val input =
+ """circuit Top :
+ | module Child :
+ | output out : UInt<1>
+ | out <= UInt<1>(0)
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst c of Child
+ | z <= and(x, c.out)""".stripMargin
+ val check = input
+ execute(input, check, Seq(dontTouch("Child.out")))
+ }
+
+ it should "NOT optimize across dontTouch on input ports" in {
+ val input =
+ """circuit Top :
+ | module Child :
+ | input in0 : UInt<1>
+ | input in1 : UInt<1>
+ | output out : UInt<1>
+ | out <= and(in0, in1)
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst c of Child
+ | z <= c.out
+ | c.in0 <= x
+ | c.in1 <= UInt<1>(1)""".stripMargin
+ val check = input
+ execute(input, check, Seq(dontTouch("Child.in1")))
+ }
+
it should "still propagate constants even when there is name swapping" in {
val input =
"""circuit Top :
@@ -608,6 +739,45 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
execute(input, check, Seq.empty)
}
+ it should "pad constant connections to outputs when propagating" in {
+ val input =
+ """circuit Top :
+ | module Child :
+ | output x : UInt<8>
+ | x <= UInt<2>("h3")
+ | module Top :
+ | output z : UInt<16>
+ | inst c of Child
+ | z <= cat(UInt<2>("h3"), c.x)""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<16>
+ | z <= UInt<16>("h303")""".stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+ it should "pad constant connections to submodule inputs when propagating" in {
+ val input =
+ """circuit Top :
+ | module Child :
+ | input x : UInt<8>
+ | output y : UInt<16>
+ | y <= cat(UInt<2>("h3"), x)
+ | module Top :
+ | output z : UInt<16>
+ | inst c of Child
+ | c.x <= UInt<2>("h3")
+ | z <= c.y""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<16>
+ | z <= UInt<16>("h303")""".stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+
"Registers with no reset or connections" should "be replaced with constant zero" in {
val input =
"""circuit Top :