aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/stage/FirrtlAnnotations.scala26
-rw-r--r--src/main/scala/firrtl/stage/FirrtlCli.scala3
-rw-r--r--src/main/scala/firrtl/stage/package.scala1
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala51
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala12
5 files changed, 75 insertions, 18 deletions
diff --git a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala
index 99a6e9c3..2ac74de2 100644
--- a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala
+++ b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala
@@ -280,3 +280,29 @@ case object PrettyNoExprInlining extends NoTargetAnnotation with FirrtlOption wi
)
)
}
+
+/** Turn off folding a specific primitive operand
+ * @param op the op that should never be folded
+ */
+case class DisableFold(op: ir.PrimOp) extends NoTargetAnnotation with FirrtlOption
+
+object DisableFold extends HasShellOptions {
+
+ private val mapping: Map[String, ir.PrimOp] = PrimOps.builtinPrimOps.map { case op => op.toString -> op }.toMap
+
+ override val options = Seq(
+ new ShellOption[String](
+ longOption = "dont-fold",
+ toAnnotationSeq = a => {
+ mapping
+ .get(a)
+ .orElse(throw new OptionsException(s"Unknown primop '$a'. (Did you misspell it?)"))
+ .map(DisableFold(_))
+ .toSeq
+ },
+ helpText = "Disable folding of specific primitive operations",
+ helpValueName = Some("<primop>")
+ )
+ )
+
+}
diff --git a/src/main/scala/firrtl/stage/FirrtlCli.scala b/src/main/scala/firrtl/stage/FirrtlCli.scala
index 18f14107..8be5fb74 100644
--- a/src/main/scala/firrtl/stage/FirrtlCli.scala
+++ b/src/main/scala/firrtl/stage/FirrtlCli.scala
@@ -21,7 +21,8 @@ trait FirrtlCli { this: Shell =>
firrtl.EmitAllModulesAnnotation,
NoCircuitDedupAnnotation,
WarnNoScalaVersionDeprecation,
- PrettyNoExprInlining
+ PrettyNoExprInlining,
+ DisableFold
)
.map(_.addOptions(parser))
diff --git a/src/main/scala/firrtl/stage/package.scala b/src/main/scala/firrtl/stage/package.scala
index c159f852..68e7a9c5 100644
--- a/src/main/scala/firrtl/stage/package.scala
+++ b/src/main/scala/firrtl/stage/package.scala
@@ -34,6 +34,7 @@ package object stage {
case a: CompilerAnnotation => logger.warn(s"Use of CompilerAnnotation is deprecated. Ignoring $a"); c
case WarnNoScalaVersionDeprecation => c
case PrettyNoExprInlining => c
+ case _: DisableFold => c
}
}
}
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 5d57de3a..c89fcff1 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -14,6 +14,7 @@ import firrtl.graph.DiGraph
import firrtl.analyses.InstanceKeyGraph
import firrtl.annotations.TargetToken.Ref
import firrtl.options.Dependency
+import firrtl.stage.DisableFold
import annotation.tailrec
import collection.mutable
@@ -401,7 +402,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
override def reduce = (a: Boolean, b: Boolean) => a ^ b
}
- private def constPropPrim(e: DoPrim): Expression = e.op match {
+ private def constPropPrim(e: DoPrim, disabledOps: Set[PrimOp]): Expression = e.op match {
+ case a if disabledOps(a) => e
case Shl => foldShiftLeft(e)
case Dshl => foldDynamicShiftLeft(e)
case Shr => foldShiftRight(e)
@@ -495,19 +497,25 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_')
def optimize(e: Expression): Expression =
- constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+ constPropExpression(
+ new NodeMap(),
+ Map.empty[Instance, OfModule],
+ Map.empty[OfModule, Map[String, Literal]],
+ Set.empty
+ )(e)
def optimize(e: Expression, nodeMap: NodeMap): Expression =
- constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+ constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]], Set.empty)(e)
private def constPropExpression(
nodeMap: NodeMap,
instMap: collection.Map[Instance, OfModule],
- constSubOutputs: Map[OfModule, Map[String, Literal]]
+ constSubOutputs: Map[OfModule, Map[String, Literal]],
+ disabledOps: Set[PrimOp]
)(e: Expression
): Expression = {
- val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs))
+ val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps))
val propagated = old match {
- case p: DoPrim => constPropPrim(p)
+ case p: DoPrim => constPropPrim(p, disabledOps)
case m: Mux => constPropMux(m)
case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) =>
constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2)
@@ -519,7 +527,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
}
// We're done when the Expression no longer changes
if (propagated eq old) propagated
- else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated)
+ else constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(propagated)
}
/** Hacky way of propagating source locators across nodes and connections that have just a
@@ -555,6 +563,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
* @param instMap map of instance names to Module name
* @param constInputs map of names of m's input ports to literal driving it (if applicable)
* @param constSubOutputs Map of Module name to Map of output port name to literal driving it
+ * @param disabledOps a Set of any PrimOps that should not be folded
* @return (Constpropped Module, Map of output port names to literal value,
* Map of submodule modulenames to Map of input port names to literal values)
*/
@@ -564,7 +573,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
dontTouches: Set[String],
instMap: collection.Map[Instance, OfModule],
constInputs: Map[String, Literal],
- constSubOutputs: Map[OfModule, Map[String, Literal]]
+ constSubOutputs: Map[OfModule, Map[String, Literal]],
+ disabledOps: Set[PrimOp]
): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = {
var nPropagated = 0L
@@ -646,7 +656,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
def constPropStmt(s: Statement): Statement = {
val s0 = s.map(constPropStmt) // Statement recurse
val s1 = propagateDirectConnectionInfoOnly(nodeMap, dontTouches)(s0) // hacky source locator propagation
- val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs)) // propagate sub-Expressions
+ // propagate sub-Expressions
+ val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps))
// Record things that should be propagated
stmtx match {
case DefNode(info, name, value) if !dontTouches.contains(name) =>
@@ -654,11 +665,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
case reg: DefRegister if reg.reset.tpe == AsyncResetType =>
asyncResetRegs(reg.name) = reg
case Connect(info, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) =>
- val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe))
+ val exprx = constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(expr, wtpe))
propagateRef(wname, exprx, info)
// Record constants driving outputs
case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) =>
- val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
+ val paddedLit =
+ constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal]
constOutputs(pname) = paddedLit
// Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns
// This requires that reset has been made explicit
@@ -714,7 +726,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
case _ =>
}
- def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe))
+ def padCPExp(e: Expression) =
+ constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(e, ltpe))
asyncResetRegs.get(lname) match {
// Normal Register
@@ -725,7 +738,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
// Mark instance inputs connected to a constant
case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) =>
- val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal]
+ val paddedLit =
+ constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal]
val module = instMap(inst.Instance)
val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty)
portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty)
@@ -750,7 +764,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
// When we call this function again, constOutputs and constSubInputs are reconstructed and
// strictly a superset of the versions here
- if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs)
+ if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs, disabledOps)
else (modx, constOutputs.toMap, constSubInputs.mapValues(_.toMap).toMap)
}
@@ -761,7 +775,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
}
- private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = {
+ private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]], disabledOps: Set[PrimOp]): Circuit = {
val iGraph = InstanceKeyGraph(c)
val moduleDeps = iGraph.getChildInstanceMap
val instCount = iGraph.staticInstanceCount
@@ -800,7 +814,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
dontTouches,
moduleDeps(mname),
constInputs.getOrElse(mname, Map.empty),
- constOutputs
+ constOutputs,
+ disabledOps
)
// Accumulate all Literals used to drive a particular Module port
val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d))
@@ -852,6 +867,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration {
val dontTouchMap: Map[OfModule, Set[String]] =
dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet).toMap
- state.copy(circuit = run(state.circuit, dontTouchMap))
+ val disabledOps = state.annotations.collect { case DisableFold(op) => op }.toSet
+
+ state.copy(circuit = run(state.circuit, dontTouchMap, disabledOps))
}
}
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 94973042..28c1d823 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -7,6 +7,7 @@ import firrtl.passes._
import firrtl.transforms._
import firrtl.testutils._
import firrtl.annotations.Annotation
+import firrtl.stage.DisableFold
class ConstantPropagationSpec extends FirrtlFlatSpec {
val transforms: Seq[Transform] =
@@ -798,6 +799,17 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec {
castCheck("Clock", "asClock")
castCheck("AsyncReset", "asAsyncReset")
}
+
+ /* */
+ "The rule a / a -> 1" should "be ignored if division folds are disabled" in {
+ val input =
+ """circuit foo:
+ | module foo:
+ | input a: UInt<8>
+ | output b: UInt<8>
+ | b <= div(a, a)""".stripMargin
+ (parse(exec(input, Seq(DisableFold(PrimOps.Div))))) should be(parse(input))
+ }
}
// More sophisticated tests of the full compiler