diff options
| author | Jack Koenig | 2017-08-14 10:59:05 -0700 |
|---|---|---|
| committer | GitHub | 2017-08-14 10:59:05 -0700 |
| commit | 672162b4bf6ca4a4a4ed7a4a9ffaadfea428ede0 (patch) | |
| tree | 01bb0a5ef3ce91af8803fdac6caf2e72cdc6479a | |
| parent | a84956afa36dbe29e87dd6c2168848a426ec42d3 (diff) | |
Constant propagation across module boundaries (#633)
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 162 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 200 |
2 files changed, 327 insertions, 35 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index d5a4b7e1..69502911 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -9,7 +9,9 @@ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.PrimOps._ +import firrtl.graph.DiGraph import firrtl.WrappedExpression.weq +import firrtl.analyses.InstanceGraph import annotation.tailrec import collection.mutable @@ -244,21 +246,50 @@ class ConstantPropagation extends Transform { // Is "a" a "better name" than "b"? private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') - // Two pass process - // 1. Propagate constants in expressions and forward propagate references - // 2. Propagate references again for backwards reference (Wires) - // TODO Replacing all wires with nodes makes the second pass unnecessary - // However, preserving decent names DOES require a second pass - // Replacing all wires with nodes makes it unnecessary for preserving decent names to trigger an - // extra iteration though + /** Constant propagate a Module + * + * Two pass process + * 1. Propagate constants in expressions and forward propagate references + * 2. Propagate references again for backwards reference (Wires) + * TODO Replacing all wires with nodes makes the second pass unnecessary + * However, preserving decent names DOES require a second pass + * Replacing all wires with nodes makes it unnecessary for preserving decent names to trigger an + * extra iteration though + * + * @param m the Module to run constant propagation on + * @param dontTouches names of components local to m that should not be propagated across + * @param instMap map of instance names to Module name + * @param constInputs map of names of m's input ports to literal driving it (if applicable) + * @param constSubOutputs Map of Module name to Map of output port name to literal driving it + * @return (Constpropped Module, Map of output port names to literal value, + * Map of submodule modulenames to Map of input port names to literal values) + */ @tailrec - private def constPropModule(m: Module, dontTouches: Set[String]): Module = { + private def constPropModule( + m: Module, + dontTouches: Set[String], + instMap: Map[String, String], + constInputs: Map[String, Literal], + constSubOutputs: Map[String, Map[String, Literal]] + ): (Module, Map[String, Literal], Map[String, Map[String, Seq[Literal]]]) = { + var nPropagated = 0L val nodeMap = mutable.HashMap.empty[String, Expression] // For cases where we are trying to constprop a bad name over a good one, we swap their names // during the second pass val swapMap = mutable.HashMap.empty[String, String] - + // Keep track of any outputs we drive with a constant + val constOutputs = mutable.HashMap.empty[String, Literal] + // Keep track of any submodule inputs we drive with a constant + // (can have more than 1 of the same submodule) + val constSubInputs = mutable.HashMap.empty[String, mutable.HashMap[String, Seq[Literal]]] + + // Copy constant mapping for constant inputs (except ones marked dontTouch!) + nodeMap ++= constInputs.filterNot { case (pname, _) => dontTouches.contains(pname) } + + // Note that on back propagation we *only* worry about swapping names and propagating references + // to constant wires, we don't need to worry about propagating primops or muxes since we'll do + // that on the next iteration if necessary def backPropExpr(expr: Expression): Expression = { val old = expr map backPropExpr val propagated = old match { @@ -296,6 +327,10 @@ class ConstantPropagation extends Transform { case m: Mux => constPropMux(m) case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => constPropNodeRef(ref, nodeMap(rname)) + case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, MALE) => + val module = instMap(inst) + // Check constSubOutputs to see if the submodule is driving a constant + constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref) case x => x } propagated @@ -320,32 +355,119 @@ class ConstantPropagation extends Transform { case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => val exprx = constPropExpression(pad(expr, wtpe)) propagateRef(wname, exprx) + // Record constants driving outputs + case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) => + val paddedLit = constPropExpression(pad(lit, ptpe)).asInstanceOf[Literal] + constOutputs(pname) = paddedLit // Const prop registers that are fed only a constant or a mux between and constant and the // register itself // This requires that reset has been made explicit - case Connect(_, rref @ WRef(rname, rtpe, RegKind, _), expr) => expr match { + case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), expr) => expr match { case lit: Literal => - nodeMap(rname) = constPropExpression(pad(lit, rtpe)) - case Mux(_, tval: WRef, fval: Literal, _) if weq(rref, tval) => - nodeMap(rname) = constPropExpression(pad(fval, rtpe)) - case Mux(_, tval: Literal, fval: WRef, _) if weq(rref, fval) => - nodeMap(rname) = constPropExpression(pad(tval, rtpe)) + nodeMap(lname) = constPropExpression(pad(lit, ltpe)) + case Mux(_, tval: WRef, fval: Literal, _) if weq(lref, tval) => + nodeMap(lname) = constPropExpression(pad(fval, ltpe)) + case Mux(_, tval: Literal, fval: WRef, _) if weq(lref, fval) => + nodeMap(lname) = constPropExpression(pad(tval, ltpe)) case _ => } + // Mark instance inputs connected to a constant + case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) => + val paddedLit = constPropExpression(pad(lit, ptpe)).asInstanceOf[Literal] + val module = instMap(inst) + val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty) + portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty) case _ => } stmtx } - val res = m.copy(body = backPropStmt(constPropStmt(m.body))) - if (nPropagated > 0) constPropModule(res, dontTouches) else res + val modx = m.copy(body = backPropStmt(constPropStmt(m.body))) + + // When we call this function again, constOutputs and constSubInputs are reconstructed and + // strictly a superset of the versions here + if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs) + else (modx, constOutputs.toMap, constSubInputs.mapValues(_.toMap).toMap) } + // Unify two maps using f to combine values of duplicate keys + private def unify[K, V](a: Map[K, V], b: Map[K, V])(f: (V, V) => V): Map[K, V] = + b.foldLeft(a) { case (acc, (k, v)) => + acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) + } + private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = { - val modulesx = c.modules.map { - case m: ExtModule => m - case m: Module => constPropModule(m, dontTouchMap.getOrElse(m.name, Set.empty)) + val iGraph = (new InstanceGraph(c)).graph + val moduleDeps = iGraph.edges.map { case (mod, children) => + mod.module -> children.map(i => i.name -> i.module).toMap } + + // Module name to number of instances + val instCount: Map[String, Int] = iGraph.getVertices.groupBy(_.module).mapValues(_.size) + + // DiGraph using Module names as nodes, destination of edge is a parent Module + val parentGraph: DiGraph[String] = iGraph.reverse.transformNodes(_.module) + + // This outer loop works by applying constant propagation to the modules in a topologically + // sorted order from leaf to root + // Modules will register any outputs they drive with a constant in constOutputs which is then + // checked by later modules in the same iteration (since we iterate from leaf to root) + // Since Modules can be instantiated multiple times, for inputs we must check that all instances + // are driven with the same constant value. Then, if we find a Module input where each instance + // is driven with the same constant (and not seen in a previous iteration), we iterate again + @tailrec + def iterate(toVisit: Set[String], + modules: Map[String, Module], + constInputs: Map[String, Map[String, Literal]]): Map[String, DefModule] = { + if (toVisit.isEmpty) modules + else { + // Order from leaf modules to root so that any module driving an output + // with a constant will be visible to modules that instantiate it + // TODO Generating order as we execute constant propagation on each module would be faster + val order = parentGraph.subgraph(toVisit).linearize + // Execute constant propagation on each module in order + // Aggreagte Module outputs that are driven constant for use by instaniating Modules + // Aggregate submodule inputs driven constant for checking later + val (modulesx, _, constInputsx) = + order.foldLeft((modules, + Map[String, Map[String, Literal]](), + Map[String, Map[String, Seq[Literal]]]())) { + case ((mmap, constOutputs, constInputsAcc), mname) => + val dontTouches = dontTouchMap.getOrElse(mname, Set.empty) + val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname), + constInputs.getOrElse(mname, Map.empty), constOutputs) + // Accumulate all Literals used to drive a particular Module port + val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d)) + (mmap + (mname -> mx), constOutputs + (mname -> mco), constInputsx) + } + // Determine which module inputs have all of the same, new constants driving them + val newProppedInputs = constInputsx.flatMap { case (mname, ports) => + val portsx = ports.flatMap { case (pname, lits) => + val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false) + val isModule = modules.contains(mname) // ExtModules are not contained in modules + val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1 + if (isModule && newPort && allSameConst) Some(pname -> lits.head) + else None + } + if (portsx.nonEmpty) Some(mname -> portsx) else None + } + val modsWithConstInputs = newProppedInputs.keySet + val newToVisit = modsWithConstInputs ++ + modsWithConstInputs.flatMap(parentGraph.reachableFrom) + // Combine const inputs (there can't be duplicate values in the inner maps) + val nextConstInputs = unify(constInputs, newProppedInputs)((a, b) => a ++ b) + iterate(newToVisit.toSet, modulesx, nextConstInputs) + } + } + + val modulesx = { + val nameMap = c.modules.collect { case m: Module => m.name -> m }.toMap + // We only pass names of Modules, we can't apply const prop to ExtModules + val mmap = iterate(nameMap.keySet, nameMap, Map.empty) + c.modules.map(m => mmap.getOrElse(m.name, m)) + } + + Circuit(c.info, modulesx, c.main) } diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index e42ecfac..b3a25f67 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -508,6 +508,116 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { """ (parse(exec(input))) should be (parse(check)) } + + "ConstProp" should "propagate constant outputs" in { + val input = +"""circuit Top : + module Child : + output out : UInt<1> + out <= UInt<1>(0) + module Top : + input x : UInt<1> + output z : UInt<1> + inst c of Child + z <= and(x, c.out) +""" + val check = +"""circuit Top : + module Child : + output out : UInt<1> + out <= UInt<1>(0) + module Top : + input x : UInt<1> + output z : UInt<1> + inst c of Child + z <= UInt<1>(0) +""" + (parse(exec(input))) should be (parse(check)) + } + + "ConstProp" should "propagate constant inputs" in { + val input = +"""circuit Top : + module Child : + input in0 : UInt<1> + input in1 : UInt<1> + output out : UInt<1> + out <= and(in0, in1) + module Top : + input x : UInt<1> + output z : UInt<1> + inst c of Child + c.in0 <= x + c.in1 <= UInt<1>(1) + z <= c.out +""" + val check = +"""circuit Top : + module Child : + input in0 : UInt<1> + input in1 : UInt<1> + output out : UInt<1> + out <= in0 + module Top : + input x : UInt<1> + output z : UInt<1> + inst c of Child + c.in0 <= x + c.in1 <= UInt<1>(1) + z <= c.out +""" + (parse(exec(input))) should be (parse(check)) + } + + "ConstProp" should "propagate constant inputs ONLY if ALL instance inputs get the same value" in { + def circuit(allSame: Boolean) = +s"""circuit Top : + module Bottom : + input in : UInt<1> + output out : UInt<1> + out <= in + module Child : + output out : UInt<1> + inst b of Bottom + b.in <= UInt(1) + out <= b.out + module Top : + input x : UInt<1> + output z : UInt<1> + + inst c of Child + + inst b0 of Bottom + b0.in <= ${if (allSame) "UInt(1)" else "x"} + inst b1 of Bottom + b1.in <= UInt(1) + + z <= and(and(b0.out, b1.out), c.out) +""" + val resultFromAllSame = +"""circuit Top : + module Bottom : + input in : UInt<1> + output out : UInt<1> + out <= UInt(1) + module Child : + output out : UInt<1> + inst b of Bottom + b.in <= UInt(1) + out <= UInt(1) + module Top : + input x : UInt<1> + output z : UInt<1> + inst c of Child + inst b0 of Bottom + b0.in <= UInt(1) + inst b1 of Bottom + b1.in <= UInt(1) + z <= UInt(1) +""" + (parse(exec(circuit(false)))) should be (parse(circuit(false))) + (parse(exec(circuit(true)))) should be (parse(resultFromAllSame)) + } } // More sophisticated tests of the full compiler @@ -522,13 +632,7 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { | output y : UInt<1> | node z = x | y <= z""".stripMargin - val check = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | node z = x - | y <= z""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Top.z"))) } @@ -541,17 +645,44 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { | wire z : UInt<1> | y <= z | z <= x""".stripMargin - val check = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | wire z : UInt<1> - | y <= z - | z <= x""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Top.z"))) } + it should "NOT optimize across dontTouch on output ports" in { + val input = + """circuit Top : + | module Child : + | output out : UInt<1> + | out <= UInt<1>(0) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of Child + | z <= and(x, c.out)""".stripMargin + val check = input + execute(input, check, Seq(dontTouch("Child.out"))) + } + + it should "NOT optimize across dontTouch on input ports" in { + val input = + """circuit Top : + | module Child : + | input in0 : UInt<1> + | input in1 : UInt<1> + | output out : UInt<1> + | out <= and(in0, in1) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of Child + | z <= c.out + | c.in0 <= x + | c.in1 <= UInt<1>(1)""".stripMargin + val check = input + execute(input, check, Seq(dontTouch("Child.in1"))) + } + it should "still propagate constants even when there is name swapping" in { val input = """circuit Top : @@ -608,6 +739,45 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input, check, Seq.empty) } + it should "pad constant connections to outputs when propagating" in { + val input = + """circuit Top : + | module Child : + | output x : UInt<8> + | x <= UInt<2>("h3") + | module Top : + | output z : UInt<16> + | inst c of Child + | z <= cat(UInt<2>("h3"), c.x)""".stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin + execute(input, check, Seq.empty) + } + + it should "pad constant connections to submodule inputs when propagating" in { + val input = + """circuit Top : + | module Child : + | input x : UInt<8> + | output y : UInt<16> + | y <= cat(UInt<2>("h3"), x) + | module Top : + | output z : UInt<16> + | inst c of Child + | c.x <= UInt<2>("h3") + | z <= c.y""".stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin + execute(input, check, Seq.empty) + } + + "Registers with no reset or connections" should "be replaced with constant zero" in { val input = """circuit Top : |
