aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlbert Magyar2020-02-07 00:26:33 -0800
committerAlbert Magyar2020-02-07 00:26:33 -0800
commit71240a3c832160c66483e91c85223db5b74cea2b (patch)
tree71e481209c35d2a5fd251db2df8d5eb49f35cd36
parent98a23c0c1fe018c0899d26da5bfd18fb1e0bcab5 (diff)
Refactor handling of reg const prop entries to cover more cases
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala43
1 files changed, 27 insertions, 16 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index ad784559..f450f6a6 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -87,6 +87,7 @@ object ConstantPropagation {
// fact that a constant propagation loop can include both self-assignments and consistent literals.
private case class RegCPEntry(r: ConstPropBinding[String], l: ConstPropBinding[Literal]) {
def resolve(that: RegCPEntry) = RegCPEntry(r.resolve(that.r), l.resolve(that.l))
+ def nonConstant: Boolean = r == NonConstant || l == NonConstant
}
}
@@ -498,28 +499,38 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
*
* @return a RegCPEntry describing the constant prop-compatible sources driving this expression
*/
- def regConstant(e: Expression): RegCPEntry = e match {
- case lit: Literal => RegCPEntry(UnboundConstant, BoundConstant(lit))
- case WRef(regName, _, RegKind, _) => RegCPEntry(BoundConstant(regName), UnboundConstant)
+
+ val unbound = RegCPEntry(UnboundConstant, UnboundConstant)
+ val selfBound = RegCPEntry(BoundConstant(lname), UnboundConstant)
+
+ def zero = passes.RemoveValidIf.getGroundZero(ltpe)
+ def regConstant(e: Expression, baseCase: RegCPEntry): RegCPEntry = e match {
+ case lit: Literal => baseCase.resolve(RegCPEntry(UnboundConstant, BoundConstant(lit)))
+ case WRef(regName, _, RegKind, _) => baseCase.resolve(RegCPEntry(BoundConstant(regName), UnboundConstant))
case WRef(nodeName, _, NodeKind, _) if nodeMap.contains(nodeName) =>
- nodeRegCPEntries.getOrElseUpdate(nodeName, { regConstant(nodeMap(nodeName)) })
- case Mux(_, tval, fval, _) => regConstant(tval).resolve(regConstant(fval))
- case DoPrim(Or, Seq(a, b), Nil, BoolType) => regConstant(Mux(a, one, b, BoolType))
- case DoPrim(And, Seq(a, b), Nil, BoolType) => regConstant(Mux(a, b, zero, BoolType))
- case _ => RegCPEntry(NonConstant, NonConstant)
+ val cached = nodeRegCPEntries.getOrElseUpdate(nodeName, { regConstant(nodeMap(nodeName), unbound) })
+ baseCase.resolve(cached)
+ case Mux(_, tval, fval, _) =>
+ regConstant(tval, baseCase).resolve(regConstant(fval, baseCase))
+ case DoPrim(Or, Seq(a, b), _, BoolType) =>
+ val aSel = regConstant(Mux(a, one, b, BoolType), baseCase)
+ if (!aSel.nonConstant) aSel else regConstant(Mux(b, one, a, BoolType), baseCase)
+ case DoPrim(And, Seq(a, b), _, BoolType) =>
+ val aSel = regConstant(Mux(a, b, zero, BoolType), baseCase)
+ if (!aSel.nonConstant) aSel else regConstant(Mux(b, a, zero, BoolType), baseCase)
+ case _ =>
+ RegCPEntry(NonConstant, NonConstant)
}
// Updates nodeMap after analyzing the returned value from regConstant
- def updateNodeMapIfConstant(e: Expression): Unit = regConstant(e) match {
- case RegCPEntry(BoundConstant(`lname`), litBinding) => litBinding match {
- case UnboundConstant => nodeMap(lname) = padCPExp(zero) // only self-assigns -> replace with zero
- case BoundConstant(lit) => nodeMap(lname) = padCPExp(lit) // self + lit assigns -> replace with lit
- case _ =>
- }
- case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit) // only lit assigns
+ def updateNodeMapIfConstant(e: Expression): Unit = regConstant(e, selfBound) match {
+ case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero)
+ case RegCPEntry(BoundConstant(_), UnboundConstant) => nodeMap(lname) = padCPExp(zero)
+ case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit)
+ case RegCPEntry(BoundConstant(_), BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit)
case _ =>
}
- def zero = passes.RemoveValidIf.getGroundZero(ltpe)
+
def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe))
asyncResetRegs.get(lname) match {