aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/ConstantPropagation.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/transforms/ConstantPropagation.scala')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala49
1 files changed, 26 insertions, 23 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 086f1cee..04ad2cb2 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -251,6 +251,24 @@ class ConstantPropagation extends Transform {
// Is "a" a "better name" than "b"?
private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_')
+ def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
+ def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
+ private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = {
+ val old = e map constPropExpression(nodeMap, instMap, constSubOutputs)
+ val propagated = old match {
+ case p: DoPrim => constPropPrim(p)
+ case m: Mux => constPropMux(m)
+ case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
+ constPropNodeRef(ref, nodeMap(rname))
+ case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, MALE) =>
+ val module = instMap(inst)
+ // Check constSubOutputs to see if the submodule is driving a constant
+ constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref)
+ case x => x
+ }
+ propagated
+ }
+
/** Constant propagate a Module
*
* Two pass process
@@ -279,7 +297,7 @@ class ConstantPropagation extends Transform {
): (Module, Map[String, Literal], Map[String, Map[String, Seq[Literal]]]) = {
var nPropagated = 0L
- val nodeMap = mutable.HashMap.empty[String, Expression]
+ val nodeMap = new NodeMap()
// For cases where we are trying to constprop a bad name over a good one, we swap their names
// during the second pass
val swapMap = mutable.HashMap.empty[String, String]
@@ -325,21 +343,6 @@ class ConstantPropagation extends Transform {
case other => other map backPropStmt
}
- def constPropExpression(e: Expression): Expression = {
- val old = e map constPropExpression
- val propagated = old match {
- case p: DoPrim => constPropPrim(p)
- case m: Mux => constPropMux(m)
- case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
- constPropNodeRef(ref, nodeMap(rname))
- case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, MALE) =>
- val module = instMap(inst)
- // Check constSubOutputs to see if the submodule is driving a constant
- constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref)
- case x => x
- }
- propagated
- }
// When propagating a reference, check if we want to keep the name that would be deleted
def propagateRef(lname: String, value: Expression): Unit = {
@@ -354,31 +357,31 @@ class ConstantPropagation extends Transform {
}
def constPropStmt(s: Statement): Statement = {
- val stmtx = s map constPropStmt map constPropExpression
+ val stmtx = s map constPropStmt map constPropExpression(nodeMap, instMap, constSubOutputs)
stmtx match {
case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value)
case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) =>
- val exprx = constPropExpression(pad(expr, wtpe))
+ val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe))
propagateRef(wname, exprx)
// Record constants driving outputs
case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) =>
- val paddedLit = constPropExpression(pad(lit, ptpe)).asInstanceOf[Literal]
+ val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
constOutputs(pname) = paddedLit
// Const prop registers that are fed only a constant or a mux between and constant and the
// register itself
// This requires that reset has been made explicit
case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), expr) if !dontTouches.contains(lname) => expr match {
case lit: Literal =>
- nodeMap(lname) = constPropExpression(pad(lit, ltpe))
+ nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ltpe))
case Mux(_, tval: WRef, fval: Literal, _) if weq(lref, tval) =>
- nodeMap(lname) = constPropExpression(pad(fval, ltpe))
+ nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(fval, ltpe))
case Mux(_, tval: Literal, fval: WRef, _) if weq(lref, fval) =>
- nodeMap(lname) = constPropExpression(pad(tval, ltpe))
+ nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(tval, ltpe))
case _ =>
}
// Mark instance inputs connected to a constant
case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) =>
- val paddedLit = constPropExpression(pad(lit, ptpe)).asInstanceOf[Literal]
+ val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
val module = instMap(inst)
val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty)
portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty)