diff options
Diffstat (limited to 'src/main/scala/firrtl/transforms/ConstantPropagation.scala')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 49 |
1 files changed, 26 insertions, 23 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 086f1cee..04ad2cb2 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -251,6 +251,24 @@ class ConstantPropagation extends Transform { // Is "a" a "better name" than "b"? private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') + def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) + def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) + private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = { + val old = e map constPropExpression(nodeMap, instMap, constSubOutputs) + val propagated = old match { + case p: DoPrim => constPropPrim(p) + case m: Mux => constPropMux(m) + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, MALE) => + val module = instMap(inst) + // Check constSubOutputs to see if the submodule is driving a constant + constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref) + case x => x + } + propagated + } + /** Constant propagate a Module * * Two pass process @@ -279,7 +297,7 @@ class ConstantPropagation extends Transform { ): (Module, Map[String, Literal], Map[String, Map[String, Seq[Literal]]]) = { var nPropagated = 0L - val nodeMap = mutable.HashMap.empty[String, Expression] + val nodeMap = new NodeMap() // For cases where we are trying to constprop a bad name over a good one, we swap their names // during the second pass val swapMap = mutable.HashMap.empty[String, String] @@ -325,21 +343,6 @@ class ConstantPropagation extends Transform { case other => other map backPropStmt } - def constPropExpression(e: Expression): Expression = { - val old = e map constPropExpression - val propagated = old match { - case p: DoPrim => constPropPrim(p) - case m: Mux => constPropMux(m) - case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) - case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, MALE) => - val module = instMap(inst) - // Check constSubOutputs to see if the submodule is driving a constant - constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref) - case x => x - } - propagated - } // When propagating a reference, check if we want to keep the name that would be deleted def propagateRef(lname: String, value: Expression): Unit = { @@ -354,31 +357,31 @@ class ConstantPropagation extends Transform { } def constPropStmt(s: Statement): Statement = { - val stmtx = s map constPropStmt map constPropExpression + val stmtx = s map constPropStmt map constPropExpression(nodeMap, instMap, constSubOutputs) stmtx match { case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value) case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => - val exprx = constPropExpression(pad(expr, wtpe)) + val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe)) propagateRef(wname, exprx) // Record constants driving outputs case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) => - val paddedLit = constPropExpression(pad(lit, ptpe)).asInstanceOf[Literal] + val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] constOutputs(pname) = paddedLit // Const prop registers that are fed only a constant or a mux between and constant and the // register itself // This requires that reset has been made explicit case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), expr) if !dontTouches.contains(lname) => expr match { case lit: Literal => - nodeMap(lname) = constPropExpression(pad(lit, ltpe)) + nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ltpe)) case Mux(_, tval: WRef, fval: Literal, _) if weq(lref, tval) => - nodeMap(lname) = constPropExpression(pad(fval, ltpe)) + nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(fval, ltpe)) case Mux(_, tval: Literal, fval: WRef, _) if weq(lref, fval) => - nodeMap(lname) = constPropExpression(pad(tval, ltpe)) + nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(tval, ltpe)) case _ => } // Mark instance inputs connected to a constant case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) => - val paddedLit = constPropExpression(pad(lit, ptpe)).asInstanceOf[Literal] + val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] val module = instMap(inst) val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty) portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty) |
