aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSchuyler Eldridge2021-01-19 22:49:31 -0500
committerGitHub2021-01-20 03:49:31 +0000
commit698a9dca52f819aca6309e3b03f2420a71bc89a6 (patch)
tree4e7afa7beec5a176bd65922a5d29a334486b774e /src
parent6d8e9041e000f9ea5fb3d069d1f9ec06d2158575 (diff)
Add --dont-fold option to disable folding prim ops (#2040)
This adds a --dont-fold options (backed by a DisableFold annotation) that lets a user specify primitive operations which should never be folded. This feature lets a user disable certain folds which may be allowable in FIRRTL (or by any sane synthesis tool), but due to inane Verilog language design causes formal equivalence tools to fail due to the fold. Add a test that a user can disable `a / a -> 1` with a DisableFold(PrimOps.Div) annotation. Signed-off-by: Schuyler Eldridge <schuyler.eldridge@sifive.com>
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/stage/FirrtlAnnotations.scala26
-rw-r--r--src/main/scala/firrtl/stage/FirrtlCli.scala3
-rw-r--r--src/main/scala/firrtl/stage/package.scala1
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala51
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala12
5 files changed, 75 insertions, 18 deletions
diff --git a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala
index 99a6e9c3..2ac74de2 100644
--- a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala
+++ b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala
@@ -280,3 +280,29 @@ case object PrettyNoExprInlining extends NoTargetAnnotation with FirrtlOption wi
)
)
}
+
+/** Turn off folding a specific primitive operand
+ * @param op the op that should never be folded
+ */
+case class DisableFold(op: ir.PrimOp) extends NoTargetAnnotation with FirrtlOption
+
+object DisableFold extends HasShellOptions {
+
+ private val mapping: Map[String, ir.PrimOp] = PrimOps.builtinPrimOps.map { case op => op.toString -> op }.toMap
+
+ override val options = Seq(
+ new ShellOption[String](
+ longOption = "dont-fold",
+ toAnnotationSeq = a => {
+ mapping
+ .get(a)
+ .orElse(throw new OptionsException(s"Unknown primop '$a'. (Did you misspell it?)"))
+ .map(DisableFold(_))
+ .toSeq
+ },
+ helpText = "Disable folding of specific primitive operations",
+ helpValueName = Some("<primop>")
+ )
+ )
+
+}
diff --git a/src/main/scala/firrtl/stage/FirrtlCli.scala b/src/main/scala/firrtl/stage/FirrtlCli.scala
index 18f14107..8be5fb74 100644
--- a/src/main/scala/firrtl/stage/FirrtlCli.scala
+++ b/src/main/scala/firrtl/stage/FirrtlCli.scala
@@ -21,7 +21,8 @@ trait FirrtlCli { this: Shell =>
firrtl.EmitAllModulesAnnotation,
NoCircuitDedupAnnotation,
WarnNoScalaVersionDeprecation,
- PrettyNoExprInlining
+ PrettyNoExprInlining,
+ DisableFold
)
.map(_.addOptions(parser))
diff --git a/src/main/scala/firrtl/stage/package.scala b/src/main/scala/firrtl/stage/package.scala
index c159f852..68e7a9c5 100644
--- a/src/main/scala/firrtl/stage/package.scala
+++ b/src/main/scala/firrtl/stage/package.scala
@@ -34,6 +34,7 @@ package object stage {
case a: CompilerAnnotation => logger.warn(s"Use of CompilerAnnotation is deprecated. Ignoring $a"); c
case WarnNoScalaVersionDeprecation => c
case PrettyNoExprInlining => c
+ case _: DisableFold => c
}
}
}
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 5d57de3a..c89fcff1 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -14,6 +14,7 @@ import firrtl.graph.DiGraph
import firrtl.analyses.InstanceKeyGraph
import firrtl.annotations.TargetToken.Ref
import firrtl.options.Dependency
+import firrtl.stage.DisableFold
import annotation.tailrec
import collection.mutable
@@ -401,7 +402,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
override def reduce = (a: Boolean, b: Boolean) => a ^ b
}
- private def constPropPrim(e: DoPrim): Expression = e.op match {
+ private def constPropPrim(e: DoPrim, disabledOps: Set[PrimOp]): Expression = e.op match {
+ case a if disabledOps(a) => e
case Shl => foldShiftLeft(e)
case Dshl => foldDynamicShiftLeft(e)
case Shr => foldShiftRight(e)
@@ -495,19 +497,25 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_')
def optimize(e: Expression): Expression =
- constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+ constPropExpression(
+ new NodeMap(),
+ Map.empty[Instance, OfModule],
+ Map.empty[OfModule, Map[String, Literal]],
+ Set.empty
+ )(e)
def optimize(e: Expression, nodeMap: NodeMap): Expression =
- constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+ constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]], Set.empty)(e)
private def constPropExpression(
nodeMap: NodeMap,
instMap: collection.Map[Instance, OfModule],
- constSubOutputs: Map[OfModule, Map[String, Literal]]
+ constSubOutputs: Map[OfModule, Map[String, Literal]],
+ disabledOps: Set[PrimOp]
)(e: Expression
): Expression = {
- val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs))
+ val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps))
val propagated = old match {
- case p: DoPrim => constPropPrim(p)
+ case p: DoPrim => constPropPrim(p, disabledOps)
case m: Mux => constPropMux(m)
case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) =>
constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2)
@@ -519,7 +527,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
}
// We're done when the Expression no longer changes
if (propagated eq old) propagated
- else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated)
+ else constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(propagated)
}
/** Hacky way of propagating source locators across nodes and connections that have just a
@@ -555,6 +563,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
* @param instMap map of instance names to Module name
* @param constInputs map of names of m's input ports to literal driving it (if applicable)
* @param constSubOutputs Map of Module name to Map of output port name to literal driving it
+ * @param disabledOps a Set of any PrimOps that should not be folded
* @return (Constpropped Module, Map of output port names to literal value,
* Map of submodule modulenames to Map of input port names to literal values)
*/
@@ -564,7 +573,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
dontTouches: Set[String],
instMap: collection.Map[Instance, OfModule],
constInputs: Map[String, Literal],
- constSubOutputs: Map[OfModule, Map[String, Literal]]
+ constSubOutputs: Map[OfModule, Map[String, Literal]],
+ disabledOps: Set[PrimOp]
): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = {
var nPropagated = 0L
@@ -646,7 +656,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
def constPropStmt(s: Statement): Statement = {
val s0 = s.map(constPropStmt) // Statement recurse
val s1 = propagateDirectConnectionInfoOnly(nodeMap, dontTouches)(s0) // hacky source locator propagation
- val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs)) // propagate sub-Expressions
+ // propagate sub-Expressions
+ val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps))
// Record things that should be propagated
stmtx match {
case DefNode(info, name, value) if !dontTouches.contains(name) =>
@@ -654,11 +665,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
case reg: DefRegister if reg.reset.tpe == AsyncResetType =>
asyncResetRegs(reg.name) = reg
case Connect(info, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) =>
- val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe))
+ val exprx = constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(expr, wtpe))
propagateRef(wname, exprx, info)
// Record constants driving outputs
case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) =>
- val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
+ val paddedLit =
+ constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal]
constOutputs(pname) = paddedLit
// Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns
// This requires that reset has been made explicit
@@ -714,7 +726,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
case _ =>
}
- def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe))
+ def padCPExp(e: Expression) =
+ constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(e, ltpe))
asyncResetRegs.get(lname) match {
// Normal Register
@@ -725,7 +738,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
// Mark instance inputs connected to a constant
case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) =>
- val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
+ val paddedLit =
+ constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal]
val module = instMap(inst.Instance)
val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty)
portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty)
@@ -750,7 +764,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
// When we call this function again, constOutputs and constSubInputs are reconstructed and
// strictly a superset of the versions here
- if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs)
+ if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs, disabledOps)
else (modx, constOutputs.toMap, constSubInputs.mapValues(_.toMap).toMap)
}
@@ -761,7 +775,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
}
- private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = {
+ private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]], disabledOps: Set[PrimOp]): Circuit = {
val iGraph = InstanceKeyGraph(c)
val moduleDeps = iGraph.getChildInstanceMap
val instCount = iGraph.staticInstanceCount
@@ -800,7 +814,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
dontTouches,
moduleDeps(mname),
constInputs.getOrElse(mname, Map.empty),
- constOutputs
+ constOutputs,
+ disabledOps
)
// Accumulate all Literals used to drive a particular Module port
val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d))
@@ -852,6 +867,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
val dontTouchMap: Map[OfModule, Set[String]] =
dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet).toMap
- state.copy(circuit = run(state.circuit, dontTouchMap))
+ val disabledOps = state.annotations.collect { case DisableFold(op) => op }.toSet
+
+ state.copy(circuit = run(state.circuit, dontTouchMap, disabledOps))
}
}
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 94973042..28c1d823 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -7,6 +7,7 @@ import firrtl.passes._
import firrtl.transforms._
import firrtl.testutils._
import firrtl.annotations.Annotation
+import firrtl.stage.DisableFold
class ConstantPropagationSpec extends FirrtlFlatSpec {
val transforms: Seq[Transform] =
@@ -798,6 +799,17 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec {
castCheck("Clock", "asClock")
castCheck("AsyncReset", "asAsyncReset")
}
+
+ /* */
+ "The rule a / a -> 1" should "be ignored if division folds are disabled" in {
+ val input =
+ """circuit foo:
+ | module foo:
+ | input a: UInt<8>
+ | output b: UInt<8>
+ | b <= div(a, a)""".stripMargin
+ (parse(exec(input, Seq(DisableFold(PrimOps.Div))))) should be(parse(input))
+ }
}
// More sophisticated tests of the full compiler