aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms
diff options
context:
space:
mode:
authorSchuyler Eldridge2021-01-19 22:49:31 -0500
committerGitHub2021-01-20 03:49:31 +0000
commit698a9dca52f819aca6309e3b03f2420a71bc89a6 (patch)
tree4e7afa7beec5a176bd65922a5d29a334486b774e /src/main/scala/firrtl/transforms
parent6d8e9041e000f9ea5fb3d069d1f9ec06d2158575 (diff)
Add --dont-fold option to disable folding prim ops (#2040)
This adds a --dont-fold options (backed by a DisableFold annotation) that lets a user specify primitive operations which should never be folded. This feature lets a user disable certain folds which may be allowable in FIRRTL (or by any sane synthesis tool), but due to inane Verilog language design causes formal equivalence tools to fail due to the fold. Add a test that a user can disable `a / a -> 1` with a DisableFold(PrimOps.Div) annotation. Signed-off-by: Schuyler Eldridge <schuyler.eldridge@sifive.com>
Diffstat (limited to 'src/main/scala/firrtl/transforms')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala51
1 files changed, 34 insertions, 17 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 5d57de3a..c89fcff1 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -14,6 +14,7 @@ import firrtl.graph.DiGraph
import firrtl.analyses.InstanceKeyGraph
import firrtl.annotations.TargetToken.Ref
import firrtl.options.Dependency
+import firrtl.stage.DisableFold
import annotation.tailrec
import collection.mutable
@@ -401,7 +402,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
override def reduce = (a: Boolean, b: Boolean) => a ^ b
}
- private def constPropPrim(e: DoPrim): Expression = e.op match {
+ private def constPropPrim(e: DoPrim, disabledOps: Set[PrimOp]): Expression = e.op match {
+ case a if disabledOps(a) => e
case Shl => foldShiftLeft(e)
case Dshl => foldDynamicShiftLeft(e)
case Shr => foldShiftRight(e)
@@ -495,19 +497,25 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_')
def optimize(e: Expression): Expression =
- constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+ constPropExpression(
+ new NodeMap(),
+ Map.empty[Instance, OfModule],
+ Map.empty[OfModule, Map[String, Literal]],
+ Set.empty
+ )(e)
def optimize(e: Expression, nodeMap: NodeMap): Expression =
- constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+ constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]], Set.empty)(e)
private def constPropExpression(
nodeMap: NodeMap,
instMap: collection.Map[Instance, OfModule],
- constSubOutputs: Map[OfModule, Map[String, Literal]]
+ constSubOutputs: Map[OfModule, Map[String, Literal]],
+ disabledOps: Set[PrimOp]
)(e: Expression
): Expression = {
- val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs))
+ val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps))
val propagated = old match {
- case p: DoPrim => constPropPrim(p)
+ case p: DoPrim => constPropPrim(p, disabledOps)
case m: Mux => constPropMux(m)
case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) =>
constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2)
@@ -519,7 +527,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
}
// We're done when the Expression no longer changes
if (propagated eq old) propagated
- else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated)
+ else constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(propagated)
}
/** Hacky way of propagating source locators across nodes and connections that have just a
@@ -555,6 +563,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
* @param instMap map of instance names to Module name
* @param constInputs map of names of m's input ports to literal driving it (if applicable)
* @param constSubOutputs Map of Module name to Map of output port name to literal driving it
+ * @param disabledOps a Set of any PrimOps that should not be folded
* @return (Constpropped Module, Map of output port names to literal value,
* Map of submodule modulenames to Map of input port names to literal values)
*/
@@ -564,7 +573,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
dontTouches: Set[String],
instMap: collection.Map[Instance, OfModule],
constInputs: Map[String, Literal],
- constSubOutputs: Map[OfModule, Map[String, Literal]]
+ constSubOutputs: Map[OfModule, Map[String, Literal]],
+ disabledOps: Set[PrimOp]
): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = {
var nPropagated = 0L
@@ -646,7 +656,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
def constPropStmt(s: Statement): Statement = {
val s0 = s.map(constPropStmt) // Statement recurse
val s1 = propagateDirectConnectionInfoOnly(nodeMap, dontTouches)(s0) // hacky source locator propagation
- val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs)) // propagate sub-Expressions
+ // propagate sub-Expressions
+ val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps))
// Record things that should be propagated
stmtx match {
case DefNode(info, name, value) if !dontTouches.contains(name) =>
@@ -654,11 +665,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
case reg: DefRegister if reg.reset.tpe == AsyncResetType =>
asyncResetRegs(reg.name) = reg
case Connect(info, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) =>
- val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe))
+ val exprx = constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(expr, wtpe))
propagateRef(wname, exprx, info)
// Record constants driving outputs
case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) =>
- val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
+ val paddedLit =
+ constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal]
constOutputs(pname) = paddedLit
// Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns
// This requires that reset has been made explicit
@@ -714,7 +726,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
case _ =>
}
- def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe))
+ def padCPExp(e: Expression) =
+ constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(e, ltpe))
asyncResetRegs.get(lname) match {
// Normal Register
@@ -725,7 +738,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
// Mark instance inputs connected to a constant
case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) =>
- val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
+ val paddedLit =
+ constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal]
val module = instMap(inst.Instance)
val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty)
portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty)
@@ -750,7 +764,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
// When we call this function again, constOutputs and constSubInputs are reconstructed and
// strictly a superset of the versions here
- if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs)
+ if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs, disabledOps)
else (modx, constOutputs.toMap, constSubInputs.mapValues(_.toMap).toMap)
}
@@ -761,7 +775,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
}
- private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = {
+ private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]], disabledOps: Set[PrimOp]): Circuit = {
val iGraph = InstanceKeyGraph(c)
val moduleDeps = iGraph.getChildInstanceMap
val instCount = iGraph.staticInstanceCount
@@ -800,7 +814,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
dontTouches,
moduleDeps(mname),
constInputs.getOrElse(mname, Map.empty),
- constOutputs
+ constOutputs,
+ disabledOps
)
// Accumulate all Literals used to drive a particular Module port
val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d))
@@ -852,6 +867,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
val dontTouchMap: Map[OfModule, Set[String]] =
dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet).toMap
- state.copy(circuit = run(state.circuit, dontTouchMap))
+ val disabledOps = state.annotations.collect { case DisableFold(op) => op }.toSet
+
+ state.copy(circuit = run(state.circuit, dontTouchMap, disabledOps))
}
}