diff options
| author | Schuyler Eldridge | 2021-01-19 22:49:31 -0500 |
|---|---|---|
| committer | GitHub | 2021-01-20 03:49:31 +0000 |
| commit | 698a9dca52f819aca6309e3b03f2420a71bc89a6 (patch) | |
| tree | 4e7afa7beec5a176bd65922a5d29a334486b774e /src | |
| parent | 6d8e9041e000f9ea5fb3d069d1f9ec06d2158575 (diff) | |
Add --dont-fold option to disable folding prim ops (#2040)
This adds a --dont-fold options (backed by a DisableFold annotation)
that lets a user specify primitive operations which should never be
folded. This feature lets a user disable certain folds which may be
allowable in FIRRTL (or by any sane synthesis tool), but due to inane
Verilog language design causes formal equivalence tools to fail due to
the fold.
Add a test that a user can disable `a / a -> 1` with a
DisableFold(PrimOps.Div) annotation.
Signed-off-by: Schuyler Eldridge <schuyler.eldridge@sifive.com>
Diffstat (limited to 'src')
5 files changed, 75 insertions, 18 deletions
diff --git a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala index 99a6e9c3..2ac74de2 100644 --- a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala +++ b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala @@ -280,3 +280,29 @@ case object PrettyNoExprInlining extends NoTargetAnnotation with FirrtlOption wi ) ) } + +/** Turn off folding a specific primitive operand + * @param op the op that should never be folded + */ +case class DisableFold(op: ir.PrimOp) extends NoTargetAnnotation with FirrtlOption + +object DisableFold extends HasShellOptions { + + private val mapping: Map[String, ir.PrimOp] = PrimOps.builtinPrimOps.map { case op => op.toString -> op }.toMap + + override val options = Seq( + new ShellOption[String]( + longOption = "dont-fold", + toAnnotationSeq = a => { + mapping + .get(a) + .orElse(throw new OptionsException(s"Unknown primop '$a'. (Did you misspell it?)")) + .map(DisableFold(_)) + .toSeq + }, + helpText = "Disable folding of specific primitive operations", + helpValueName = Some("<primop>") + ) + ) + +} diff --git a/src/main/scala/firrtl/stage/FirrtlCli.scala b/src/main/scala/firrtl/stage/FirrtlCli.scala index 18f14107..8be5fb74 100644 --- a/src/main/scala/firrtl/stage/FirrtlCli.scala +++ b/src/main/scala/firrtl/stage/FirrtlCli.scala @@ -21,7 +21,8 @@ trait FirrtlCli { this: Shell => firrtl.EmitAllModulesAnnotation, NoCircuitDedupAnnotation, WarnNoScalaVersionDeprecation, - PrettyNoExprInlining + PrettyNoExprInlining, + DisableFold ) .map(_.addOptions(parser)) diff --git a/src/main/scala/firrtl/stage/package.scala b/src/main/scala/firrtl/stage/package.scala index c159f852..68e7a9c5 100644 --- a/src/main/scala/firrtl/stage/package.scala +++ b/src/main/scala/firrtl/stage/package.scala @@ -34,6 +34,7 @@ package object stage { case a: CompilerAnnotation => logger.warn(s"Use of CompilerAnnotation is deprecated. Ignoring $a"); c case WarnNoScalaVersionDeprecation => c case PrettyNoExprInlining => c + case _: DisableFold => c } } } diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 5d57de3a..c89fcff1 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -14,6 +14,7 @@ import firrtl.graph.DiGraph import firrtl.analyses.InstanceKeyGraph import firrtl.annotations.TargetToken.Ref import firrtl.options.Dependency +import firrtl.stage.DisableFold import annotation.tailrec import collection.mutable @@ -401,7 +402,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { override def reduce = (a: Boolean, b: Boolean) => a ^ b } - private def constPropPrim(e: DoPrim): Expression = e.op match { + private def constPropPrim(e: DoPrim, disabledOps: Set[PrimOp]): Expression = e.op match { + case a if disabledOps(a) => e case Shl => foldShiftLeft(e) case Dshl => foldDynamicShiftLeft(e) case Shr => foldShiftRight(e) @@ -495,19 +497,25 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') def optimize(e: Expression): Expression = - constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + constPropExpression( + new NodeMap(), + Map.empty[Instance, OfModule], + Map.empty[OfModule, Map[String, Literal]], + Set.empty + )(e) def optimize(e: Expression, nodeMap: NodeMap): Expression = - constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]], Set.empty)(e) private def constPropExpression( nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], - constSubOutputs: Map[OfModule, Map[String, Literal]] + constSubOutputs: Map[OfModule, Map[String, Literal]], + disabledOps: Set[PrimOp] )(e: Expression ): Expression = { - val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs)) + val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)) val propagated = old match { - case p: DoPrim => constPropPrim(p) + case p: DoPrim => constPropPrim(p, disabledOps) case m: Mux => constPropMux(m) case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) => constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) @@ -519,7 +527,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { } // We're done when the Expression no longer changes if (propagated eq old) propagated - else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated) + else constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(propagated) } /** Hacky way of propagating source locators across nodes and connections that have just a @@ -555,6 +563,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { * @param instMap map of instance names to Module name * @param constInputs map of names of m's input ports to literal driving it (if applicable) * @param constSubOutputs Map of Module name to Map of output port name to literal driving it + * @param disabledOps a Set of any PrimOps that should not be folded * @return (Constpropped Module, Map of output port names to literal value, * Map of submodule modulenames to Map of input port names to literal values) */ @@ -564,7 +573,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { dontTouches: Set[String], instMap: collection.Map[Instance, OfModule], constInputs: Map[String, Literal], - constSubOutputs: Map[OfModule, Map[String, Literal]] + constSubOutputs: Map[OfModule, Map[String, Literal]], + disabledOps: Set[PrimOp] ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = { var nPropagated = 0L @@ -646,7 +656,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { def constPropStmt(s: Statement): Statement = { val s0 = s.map(constPropStmt) // Statement recurse val s1 = propagateDirectConnectionInfoOnly(nodeMap, dontTouches)(s0) // hacky source locator propagation - val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs)) // propagate sub-Expressions + // propagate sub-Expressions + val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)) // Record things that should be propagated stmtx match { case DefNode(info, name, value) if !dontTouches.contains(name) => @@ -654,11 +665,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { case reg: DefRegister if reg.reset.tpe == AsyncResetType => asyncResetRegs(reg.name) = reg case Connect(info, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => - val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe)) + val exprx = constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(expr, wtpe)) propagateRef(wname, exprx, info) // Record constants driving outputs case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) => - val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] + val paddedLit = + constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal] constOutputs(pname) = paddedLit // Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns // This requires that reset has been made explicit @@ -714,7 +726,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { case _ => } - def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe)) + def padCPExp(e: Expression) = + constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(e, ltpe)) asyncResetRegs.get(lname) match { // Normal Register @@ -725,7 +738,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { // Mark instance inputs connected to a constant case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) => - val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] + val paddedLit = + constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal] val module = instMap(inst.Instance) val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty) portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty) @@ -750,7 +764,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { // When we call this function again, constOutputs and constSubInputs are reconstructed and // strictly a superset of the versions here - if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs) + if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs, disabledOps) else (modx, constOutputs.toMap, constSubInputs.mapValues(_.toMap).toMap) } @@ -761,7 +775,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) } - private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = { + private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]], disabledOps: Set[PrimOp]): Circuit = { val iGraph = InstanceKeyGraph(c) val moduleDeps = iGraph.getChildInstanceMap val instCount = iGraph.staticInstanceCount @@ -800,7 +814,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { dontTouches, moduleDeps(mname), constInputs.getOrElse(mname, Map.empty), - constOutputs + constOutputs, + disabledOps ) // Accumulate all Literals used to drive a particular Module port val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d)) @@ -852,6 +867,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { val dontTouchMap: Map[OfModule, Set[String]] = dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet).toMap - state.copy(circuit = run(state.circuit, dontTouchMap)) + val disabledOps = state.annotations.collect { case DisableFold(op) => op }.toSet + + state.copy(circuit = run(state.circuit, dontTouchMap, disabledOps)) } } diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 94973042..28c1d823 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -7,6 +7,7 @@ import firrtl.passes._ import firrtl.transforms._ import firrtl.testutils._ import firrtl.annotations.Annotation +import firrtl.stage.DisableFold class ConstantPropagationSpec extends FirrtlFlatSpec { val transforms: Seq[Transform] = @@ -798,6 +799,17 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { castCheck("Clock", "asClock") castCheck("AsyncReset", "asAsyncReset") } + + /* */ + "The rule a / a -> 1" should "be ignored if division folds are disabled" in { + val input = + """circuit foo: + | module foo: + | input a: UInt<8> + | output b: UInt<8> + | b <= div(a, a)""".stripMargin + (parse(exec(input, Seq(DisableFold(PrimOps.Div))))) should be(parse(input)) + } } // More sophisticated tests of the full compiler |
