aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Laeufer2021-03-08 17:37:35 -0800
committerGitHub2021-03-09 01:37:35 +0000
commit8a4c156f401c8bfab5f2d595c32c20534f0722d7 (patch)
tree06c8ea2221e94d2bc2e281ffbc79a8aee177cf3f
parent29d57a612df69ae4a6db4b3755fc292e5a539e11 (diff)
SMT Backend: model Invalid and Division by Zero with DefRandom nodes (#2104)
This finally removes all randomization code from the transition system conversion and into a separate pass using DefRandom nodes.
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala17
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala49
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala125
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala56
4 files changed, 201 insertions, 46 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
index d85fbfe5..13e0c312 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
@@ -9,8 +9,6 @@ import firrtl.passes.CheckWidths.WidthTooBig
private trait TranslationContext {
def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, FirrtlExpressionSemantics.getWidth(tpe))
- def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe))
- def getRandom(width: Int): BVExpr
}
private object FirrtlExpressionSemantics {
@@ -34,9 +32,8 @@ private object FirrtlExpressionSemantics {
case ir.Mux(cond, tval, fval, _) =>
val width = List(tval, fval).map(getWidth).max
BVIte(toSMT(cond), toSMT(tval, width), toSMT(fval, width))
- case ir.ValidIf(cond, value, tpe) =>
- val tru = toSMT(value)
- BVIte(toSMT(cond), tru, ctx.getRandom(tpe))
+ case v: ir.ValidIf =>
+ throw new RuntimeException(s"Unsupported expression: ValidIf ${v.serialize}")
}
assert(
eSMT.width == getWidth(e),
@@ -81,15 +78,7 @@ private object FirrtlExpressionSemantics {
val (width, op) = if (isSigned(num)) {
(getWidth(num) + 1, Op.SignedDiv)
} else { (getWidth(num), Op.UnsignedDiv) }
- // "The result of a division where den is zero is undefined."
- val undef = ctx.getRandom(width)
- val denSMT = toSMT(den)
- val denIsZero = BVEqual(denSMT, BVLiteral(0, denSMT.width))
- val numByDen = BVOp(op, toSMT(num, width), forceWidth(denSMT, isSigned(den), width))
- BVIte(denIsZero, undef, numByDen)
- case (PrimOps.Div, Seq(num, den), _) if isSigned(num) =>
- val width = getWidth(num) + 1
- BVOp(Op.SignedDiv, toSMT(num, width), toSMT(den, width))
+ BVOp(op, toSMT(num, width), forceWidth(toSMT(den), isSigned(den), width))
case (PrimOps.Rem, Seq(num, den), _) =>
val op = if (isSigned(num)) Op.SignedRem else Op.UnsignedRem
val width = args.map(getWidth).max
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
index d3a1ed68..fea92c75 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
@@ -62,7 +62,7 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration {
// TODO: We also would like to run some optimization passes, but RemoveValidIf won't allow us to model DontCare
// precisely and PadWidths emits ill-typed firrtl.
override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++
- Seq(Dependency(UndefinedMemoryBehaviorPass), Dependency(VerilogMemDelays))
+ Seq(Dependency(InvalidToRandomPass), Dependency(UndefinedMemoryBehaviorPass), Dependency(VerilogMemDelays))
override def invalidates(a: Transform): Boolean = false
// since this pass only runs on the main module, inlining needs to happen before
override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances])
@@ -147,7 +147,7 @@ private class ModuleToTransitionSystem extends LazyLogging {
uninterpreted: Map[String, UninterpretedModuleAnnotation] = Map()
): TransitionSystem = {
// first pass over the module to convert expressions; discover state and I/O
- val scan = new ModuleScanner(makeRandom, uninterpreted)
+ val scan = new ModuleScanner(uninterpreted)
m.foreachPort(scan.onPort)
m.foreachStmt(scan.onStatement)
@@ -200,8 +200,8 @@ private class ModuleToTransitionSystem extends LazyLogging {
}
}
- // inputs are original module inputs and any "random" signal we need for modelling
- val inputs = scan.inputs ++ randoms.values
+ // inputs are original module inputs and any DefRandom signal
+ val inputs = scan.inputs
// module info to the comment header
val header = serializeInfo(m.info).map(InfoPrefix + _).toArray
@@ -350,21 +350,10 @@ private class ModuleToTransitionSystem extends LazyLogging {
if (infos.isEmpty) { None }
else { Some(infos.map(_.escaped).mkString(InfoSeparator)) }
}
-
- private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]()
- private def makeRandom(baseName: String, width: Int): BVExpr = {
- // TODO: actually ensure that there cannot be any name clashes with other identifiers
- val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii)
- val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get
- val sym = BVSymbol(name, width)
- randoms(name) = sym
- sym
- }
}
// performas a first pass over the module collecting all connections, wires, registers, input and outputs
private class ModuleScanner(
- makeRandom: (String, Int) => BVExpr,
uninterpreted: Map[String, UninterpretedModuleAnnotation])
extends LazyLogging {
import FirrtlExpressionSemantics.getWidth
@@ -429,7 +418,7 @@ private class ModuleScanner(
if (!isClock(expr.tpe)) {
insertDummyAssignsForUnusedOutputs(expr)
infos.append(name -> info)
- val e = onExpression(expr, name)
+ val e = onExpression(expr)
nodes.append(name)
connects.append((name, e))
}
@@ -439,8 +428,8 @@ private class ModuleScanner(
insertDummyAssignsForUnusedOutputs(init)
infos.append(name -> info)
val width = getWidth(tpe)
- val resetExpr = onExpression(reset, 1, name + "_reset")
- val initExpr = onExpression(init, width, name + "_init")
+ val resetExpr = onExpression(reset, 1)
+ val initExpr = onExpression(init, width)
registers.append((name, width, resetExpr, initExpr))
case m: ir.DefMemory =>
namespace.newName(m.name)
@@ -456,13 +445,11 @@ private class ModuleScanner(
val name = loc.serialize
insertDummyAssignsForUnusedOutputs(expr)
infos.append(name -> info)
- connects.append((name, onExpression(expr, getWidth(loc.tpe), name)))
+ connects.append((name, onExpression(expr, getWidth(loc.tpe))))
}
- case ir.IsInvalid(info, loc) =>
+ case i @ ir.IsInvalid(info, loc) =>
if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
- val name = loc.serialize
- infos.append(name -> info)
- connects.append((name, makeRandom(name + "_INVALID", getWidth(loc.tpe))))
+ throw new UnsupportedFeatureException(s"IsInvalid statements are not supported: ${i.serialize}")
case ir.DefInstance(info, name, module, tpe) => onInstance(info, name, module, tpe)
case s @ ir.Verification(op, info, _, pred, en, msg) =>
if (op == ir.Formal.Cover) {
@@ -471,8 +458,8 @@ private class ModuleScanner(
insertDummyAssignsForUnusedOutputs(pred)
insertDummyAssignsForUnusedOutputs(en)
val name = namespace.newName(msgToName(op.toString, msg.string))
- val predicate = onExpression(pred, name + "_predicate")
- val enabled = onExpression(en, name + "_enabled")
+ val predicate = onExpression(pred)
+ val enabled = onExpression(en)
val e = BVImplies(enabled, predicate)
infos.append(name -> info)
connects.append(name -> e)
@@ -596,16 +583,14 @@ private class ModuleScanner(
case other => other.foreachExpr(findUnusedOutputUse)
}
- private case class Context(baseName: String) extends TranslationContext {
- override def getRandom(width: Int): BVExpr = makeRandom(baseName, width)
- }
+ private case class Context() extends TranslationContext {}
- private def onExpression(e: ir.Expression, width: Int, randomPrefix: String): BVExpr = {
- implicit val ctx: TranslationContext = Context(randomPrefix)
+ private def onExpression(e: ir.Expression, width: Int): BVExpr = {
+ implicit val ctx: TranslationContext = Context()
FirrtlExpressionSemantics.toSMT(e, width, allowNarrow = false)
}
- private def onExpression(e: ir.Expression, randomPrefix: String): BVExpr = {
- implicit val ctx: TranslationContext = Context(randomPrefix)
+ private def onExpression(e: ir.Expression): BVExpr = {
+ implicit val ctx: TranslationContext = Context()
FirrtlExpressionSemantics.toSMT(e)
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala
new file mode 100644
index 00000000..c7eaad74
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala
@@ -0,0 +1,125 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtl.backends.experimental.smt.random
+
+import firrtl._
+import firrtl.annotations.NoTargetAnnotation
+import firrtl.ir._
+import firrtl.passes._
+import firrtl.options.Dependency
+import firrtl.stage.Forms
+import firrtl.transforms.RemoveWires
+
+import scala.collection.mutable
+
+/** Chooses how to model explicit and implicit invalid values in the circuit */
+case class InvalidToRandomOptions(
+ randomizeInvalidSignals: Boolean = true,
+ randomizeDivisionByZero: Boolean = true)
+ extends NoTargetAnnotation
+
+/** Replaces all explicit and implicit "invalid" values with random values.
+ * Explicit invalids are:
+ * - signal is invalid
+ * - signal <= valid(..., expr)
+ * Implicit invalids are:
+ * - a / b when eq(b, 0)
+ */
+object InvalidToRandomPass extends Transform with DependencyAPIMigration {
+ override def prerequisites = Forms.LowForm
+ // once ValidIf has been removed, we can no longer detect and randomize them
+ override def optionalPrerequisiteOf = Seq(Dependency(RemoveValidIf))
+ override def invalidates(a: Transform) = a match {
+ // this pass might destroy SSA form, as we add a wire for the data field of every read port
+ case _: RemoveWires => true
+ // TODO: should we add some optimization passes here? we could be generating some dead code.
+ case _ => false
+ }
+
+ override protected def execute(state: CircuitState): CircuitState = {
+ val opts = state.annotations.collect { case o: InvalidToRandomOptions => o }
+ require(opts.size < 2, s"Multiple options: $opts")
+ val opt = opts.headOption.getOrElse(InvalidToRandomOptions())
+
+ // quick exit if we just want to skip this pass
+ if (!opt.randomizeDivisionByZero && !opt.randomizeInvalidSignals) {
+ state
+ } else {
+ val c = state.circuit.mapModule(onModule(_, opt))
+ state.copy(circuit = c)
+ }
+ }
+
+ private def onModule(m: DefModule, opt: InvalidToRandomOptions): DefModule = m match {
+ case d: DescribedMod =>
+ throw new RuntimeException(s"CompilerError: Unexpected internal node: ${d.serialize}")
+ case e: ExtModule => e
+ case mod: Module =>
+ val namespace = Namespace(mod)
+ mod.mapStmt(onStmt(namespace, opt, _))
+ }
+
+ private def onStmt(namespace: Namespace, opt: InvalidToRandomOptions, s: Statement): Statement = s match {
+ case IsInvalid(info, loc: RefLikeExpression) if opt.randomizeInvalidSignals =>
+ val name = namespace.newName(loc.serialize.replace('.', '_') + "_invalid")
+ val rand = DefRandom(info, name, loc.tpe, None)
+ Block(List(rand, Connect(info, loc, Reference(rand))))
+ case other =>
+ val info = other match {
+ case h: HasInfo => h.info
+ case _ => NoInfo
+ }
+ val prefix = other match {
+ case c: Connect => c.loc.serialize.replace('.', '_')
+ case h: HasName => h.name
+ case _ => ""
+ }
+ val ctx = ExprCtx(namespace, opt, prefix, info, mutable.ListBuffer[Statement]())
+ val stmt = other.mapExpr(onExpr(ctx, _)).mapStmt(onStmt(namespace, opt, _))
+ if (ctx.rands.isEmpty) { stmt }
+ else { Block(Block(ctx.rands.toList), stmt) }
+ }
+
+ private case class ExprCtx(
+ namespace: Namespace,
+ opt: InvalidToRandomOptions,
+ prefix: String,
+ info: Info,
+ rands: mutable.ListBuffer[Statement])
+
+ private def onExpr(ctx: ExprCtx, e: Expression): Expression =
+ e.mapExpr(onExpr(ctx, _)) match {
+ case ValidIf(_, value, tpe) if tpe == ClockType =>
+ // we currently assume that clocks are always valid
+ // TODO: is that a good assumption?
+ value
+ case ValidIf(cond, value, tpe) if ctx.opt.randomizeInvalidSignals =>
+ makeRand(ctx, cond, tpe, value, invert = true)
+ case d @ DoPrim(PrimOps.Div, Seq(_, den), _, tpe) if ctx.opt.randomizeDivisionByZero =>
+ val denIsZero = Utils.eq(den, Utils.getGroundZero(den.tpe.asInstanceOf[GroundType]))
+ makeRand(ctx, denIsZero, tpe, d, invert = false)
+ case other => other
+ }
+
+ private def makeRand(
+ ctx: ExprCtx,
+ cond: Expression,
+ tpe: Type,
+ value: Expression,
+ invert: Boolean
+ ): Expression = {
+ val name = ctx.namespace.newName(if (ctx.prefix.isEmpty) "invalid" else ctx.prefix + "_invalid")
+ // create a condition node if the condition isn't a reference already
+ val condRef = cond match {
+ case r: RefLikeExpression => if (invert) Utils.not(r) else r
+ case other =>
+ val cond = if (invert) Utils.not(other) else other
+ val condNode = DefNode(ctx.info, ctx.namespace.newName(name + "_cond"), cond)
+ ctx.rands.append(condNode)
+ Reference(condNode)
+ }
+ val rand = DefRandom(ctx.info, name, tpe, None, condRef)
+ ctx.rands.append(rand)
+ Utils.mux(condRef, Reference(rand), value)
+ }
+}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala
new file mode 100644
index 00000000..8f17a847
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala
@@ -0,0 +1,56 @@
+package firrtl.backends.experimental.smt.random
+
+import firrtl.options.Dependency
+import firrtl.testutils.LeanTransformSpec
+
+class InvalidToRandomSpec extends LeanTransformSpec(Seq(Dependency(InvalidToRandomPass))) {
+ behavior.of("InvalidToRandomPass")
+
+ val src1 =
+ s"""
+ |circuit Test:
+ | module Test:
+ | input a : UInt<2>
+ | output o : UInt<8>
+ | output o2 : UInt<8>
+ | output o3 : UInt<8>
+ |
+ | o is invalid
+ |
+ | when eq(a, UInt(3)):
+ | o <= UInt(5)
+ |
+ | o2 is invalid
+ | node o2_valid = eq(a, UInt(2))
+ | when o2_valid:
+ | o2 <= UInt(7)
+ |
+ | o3 is invalid
+ | o3 <= UInt(3)
+ |""".stripMargin
+
+ it should "model invalid signals as random" in {
+
+ val circuit = compile(src1, List()).circuit
+ //println(circuit.serialize)
+ val result = circuit.serialize.split('\n').map(_.trim)
+
+ // the condition should end up as a new node if it wasn't a reference already
+ assert(result.contains("node _GEN_0_invalid_cond = not(eq(a, UInt<2>(\"h3\")))"))
+ assert(result.contains("node o2_valid = eq(a, UInt<2>(\"h2\"))"))
+
+ // every invalid results in a random statement
+ assert(result.contains("rand _GEN_0_invalid : UInt<3> when _GEN_0_invalid_cond"))
+ assert(result.contains("rand _GEN_1_invalid : UInt<3> when not(o2_valid)"))
+
+ // the random value is conditionally assigned
+ assert(result.contains("node _GEN_0 = mux(_GEN_0_invalid_cond, _GEN_0_invalid, UInt<3>(\"h5\"))"))
+ assert(result.contains("node _GEN_1 = mux(not(o2_valid), _GEN_1_invalid, UInt<3>(\"h7\"))"))
+
+ // expressions that are trivially valid do not get randomized
+ assert(result.contains("o3 <= UInt<2>(\"h3\")"))
+ val defRandCount = result.count(_.contains("rand "))
+ assert(defRandCount == 2)
+ }
+
+}