diff options
Diffstat (limited to 'src')
5 files changed, 75 insertions, 18 deletions
diff --git a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala index 99a6e9c3..2ac74de2 100644 --- a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala +++ b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala @@ -280,3 +280,29 @@ case object PrettyNoExprInlining extends NoTargetAnnotation with FirrtlOption wi ) ) } + +/** Turn off folding a specific primitive operand + * @param op the op that should never be folded + */ +case class DisableFold(op: ir.PrimOp) extends NoTargetAnnotation with FirrtlOption + +object DisableFold extends HasShellOptions { + + private val mapping: Map[String, ir.PrimOp] = PrimOps.builtinPrimOps.map { case op => op.toString -> op }.toMap + + override val options = Seq( + new ShellOption[String]( + longOption = "dont-fold", + toAnnotationSeq = a => { + mapping + .get(a) + .orElse(throw new OptionsException(s"Unknown primop '$a'. (Did you misspell it?)")) + .map(DisableFold(_)) + .toSeq + }, + helpText = "Disable folding of specific primitive operations", + helpValueName = Some("<primop>") + ) + ) + +} diff --git a/src/main/scala/firrtl/stage/FirrtlCli.scala b/src/main/scala/firrtl/stage/FirrtlCli.scala index 18f14107..8be5fb74 100644 --- a/src/main/scala/firrtl/stage/FirrtlCli.scala +++ b/src/main/scala/firrtl/stage/FirrtlCli.scala @@ -21,7 +21,8 @@ trait FirrtlCli { this: Shell => firrtl.EmitAllModulesAnnotation, NoCircuitDedupAnnotation, WarnNoScalaVersionDeprecation, - PrettyNoExprInlining + PrettyNoExprInlining, + DisableFold ) .map(_.addOptions(parser)) diff --git a/src/main/scala/firrtl/stage/package.scala b/src/main/scala/firrtl/stage/package.scala index c159f852..68e7a9c5 100644 --- a/src/main/scala/firrtl/stage/package.scala +++ b/src/main/scala/firrtl/stage/package.scala @@ -34,6 +34,7 @@ package object stage { case a: CompilerAnnotation => logger.warn(s"Use of CompilerAnnotation is deprecated. Ignoring $a"); c case WarnNoScalaVersionDeprecation => c case PrettyNoExprInlining => c + case _: DisableFold => c } } } diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 5d57de3a..c89fcff1 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -14,6 +14,7 @@ import firrtl.graph.DiGraph import firrtl.analyses.InstanceKeyGraph import firrtl.annotations.TargetToken.Ref import firrtl.options.Dependency +import firrtl.stage.DisableFold import annotation.tailrec import collection.mutable @@ -401,7 +402,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { override def reduce = (a: Boolean, b: Boolean) => a ^ b } - private def constPropPrim(e: DoPrim): Expression = e.op match { + private def constPropPrim(e: DoPrim, disabledOps: Set[PrimOp]): Expression = e.op match { + case a if disabledOps(a) => e case Shl => foldShiftLeft(e) case Dshl => foldDynamicShiftLeft(e) case Shr => foldShiftRight(e) @@ -495,19 +497,25 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') def optimize(e: Expression): Expression = - constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + constPropExpression( + new NodeMap(), + Map.empty[Instance, OfModule], + Map.empty[OfModule, Map[String, Literal]], + Set.empty + )(e) def optimize(e: Expression, nodeMap: NodeMap): Expression = - constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]], Set.empty)(e) private def constPropExpression( nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], - constSubOutputs: Map[OfModule, Map[String, Literal]] + constSubOutputs: Map[OfModule, Map[String, Literal]], + disabledOps: Set[PrimOp] )(e: Expression ): Expression = { - val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs)) + val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)) val propagated = old match { - case p: DoPrim => constPropPrim(p) + case p: DoPrim => constPropPrim(p, disabledOps) case m: Mux => constPropMux(m) case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) => constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) @@ -519,7 +527,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { } // We're done when the Expression no longer changes if (propagated eq old) propagated - else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated) + else constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(propagated) } /** Hacky way of propagating source locators across nodes and connections that have just a @@ -555,6 +563,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { * @param instMap map of instance names to Module name * @param constInputs map of names of m's input ports to literal driving it (if applicable) * @param constSubOutputs Map of Module name to Map of output port name to literal driving it + * @param disabledOps a Set of any PrimOps that should not be folded * @return (Constpropped Module, Map of output port names to literal value, * Map of submodule modulenames to Map of input port names to literal values) */ @@ -564,7 +573,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { dontTouches: Set[String], instMap: collection.Map[Instance, OfModule], constInputs: Map[String, Literal], - constSubOutputs: Map[OfModule, Map[String, Literal]] + constSubOutputs: Map[OfModule, Map[String, Literal]], + disabledOps: Set[PrimOp] ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = { var nPropagated = 0L @@ -646,7 +656,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { def constPropStmt(s: Statement): Statement = { val s0 = s.map(constPropStmt) // Statement recurse val s1 = propagateDirectConnectionInfoOnly(nodeMap, dontTouches)(s0) // hacky source locator propagation - val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs)) // propagate sub-Expressions + // propagate sub-Expressions + val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)) // Record things that should be propagated stmtx match { case DefNode(info, name, value) if !dontTouches.contains(name) => @@ -654,11 +665,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { case reg: DefRegister if reg.reset.tpe == AsyncResetType => asyncResetRegs(reg.name) = reg case Connect(info, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => - val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe)) + val exprx = constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(expr, wtpe)) propagateRef(wname, exprx, info) // Record constants driving outputs case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) => - val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] + val paddedLit = + constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal] constOutputs(pname) = paddedLit // Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns // This requires that reset has been made explicit @@ -714,7 +726,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { case _ => } - def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe)) + def padCPExp(e: Expression) = + constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(e, ltpe)) asyncResetRegs.get(lname) match { // Normal Register @@ -725,7 +738,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { // Mark instance inputs connected to a constant case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) => - val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] + val paddedLit = + constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal] val module = instMap(inst.Instance) val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty) portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty) @@ -750,7 +764,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { // When we call this function again, constOutputs and constSubInputs are reconstructed and // strictly a superset of the versions here - if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs) + if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs, disabledOps) else (modx, constOutputs.toMap, constSubInputs.mapValues(_.toMap).toMap) } @@ -761,7 +775,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) } - private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = { + private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]], disabledOps: Set[PrimOp]): Circuit = { val iGraph = InstanceKeyGraph(c) val moduleDeps = iGraph.getChildInstanceMap val instCount = iGraph.staticInstanceCount @@ -800,7 +814,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { dontTouches, moduleDeps(mname), constInputs.getOrElse(mname, Map.empty), - constOutputs + constOutputs, + disabledOps ) // Accumulate all Literals used to drive a particular Module port val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d)) @@ -852,6 +867,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { val dontTouchMap: Map[OfModule, Set[String]] = dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet).toMap - state.copy(circuit = run(state.circuit, dontTouchMap)) + val disabledOps = state.annotations.collect { case DisableFold(op) => op }.toSet + + state.copy(circuit = run(state.circuit, dontTouchMap, disabledOps)) } } diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 94973042..28c1d823 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -7,6 +7,7 @@ import firrtl.passes._ import firrtl.transforms._ import firrtl.testutils._ import firrtl.annotations.Annotation +import firrtl.stage.DisableFold class ConstantPropagationSpec extends FirrtlFlatSpec { val transforms: Seq[Transform] = @@ -798,6 +799,17 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { castCheck("Clock", "asClock") castCheck("AsyncReset", "asAsyncReset") } + + /* */ + "The rule a / a -> 1" should "be ignored if division folds are disabled" in { + val input = + """circuit foo: + | module foo: + | input a: UInt<8> + | output b: UInt<8> + | b <= div(a, a)""".stripMargin + (parse(exec(input, Seq(DisableFold(PrimOps.Div))))) should be(parse(input)) + } } // More sophisticated tests of the full compiler |
