diff options
| author | Kevin Laeufer | 2021-03-08 17:37:35 -0800 |
|---|---|---|
| committer | GitHub | 2021-03-09 01:37:35 +0000 |
| commit | 8a4c156f401c8bfab5f2d595c32c20534f0722d7 (patch) | |
| tree | 06c8ea2221e94d2bc2e281ffbc79a8aee177cf3f | |
| parent | 29d57a612df69ae4a6db4b3755fc292e5a539e11 (diff) | |
SMT Backend: model Invalid and Division by Zero with DefRandom nodes (#2104)
This finally removes all randomization code from the transition
system conversion and into a separate pass using DefRandom nodes.
4 files changed, 201 insertions, 46 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala index d85fbfe5..13e0c312 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala @@ -9,8 +9,6 @@ import firrtl.passes.CheckWidths.WidthTooBig private trait TranslationContext { def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, FirrtlExpressionSemantics.getWidth(tpe)) - def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe)) - def getRandom(width: Int): BVExpr } private object FirrtlExpressionSemantics { @@ -34,9 +32,8 @@ private object FirrtlExpressionSemantics { case ir.Mux(cond, tval, fval, _) => val width = List(tval, fval).map(getWidth).max BVIte(toSMT(cond), toSMT(tval, width), toSMT(fval, width)) - case ir.ValidIf(cond, value, tpe) => - val tru = toSMT(value) - BVIte(toSMT(cond), tru, ctx.getRandom(tpe)) + case v: ir.ValidIf => + throw new RuntimeException(s"Unsupported expression: ValidIf ${v.serialize}") } assert( eSMT.width == getWidth(e), @@ -81,15 +78,7 @@ private object FirrtlExpressionSemantics { val (width, op) = if (isSigned(num)) { (getWidth(num) + 1, Op.SignedDiv) } else { (getWidth(num), Op.UnsignedDiv) } - // "The result of a division where den is zero is undefined." - val undef = ctx.getRandom(width) - val denSMT = toSMT(den) - val denIsZero = BVEqual(denSMT, BVLiteral(0, denSMT.width)) - val numByDen = BVOp(op, toSMT(num, width), forceWidth(denSMT, isSigned(den), width)) - BVIte(denIsZero, undef, numByDen) - case (PrimOps.Div, Seq(num, den), _) if isSigned(num) => - val width = getWidth(num) + 1 - BVOp(Op.SignedDiv, toSMT(num, width), toSMT(den, width)) + BVOp(op, toSMT(num, width), forceWidth(toSMT(den), isSigned(den), width)) case (PrimOps.Rem, Seq(num, den), _) => val op = if (isSigned(num)) Op.SignedRem else Op.UnsignedRem val width = args.map(getWidth).max diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala index d3a1ed68..fea92c75 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -62,7 +62,7 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { // TODO: We also would like to run some optimization passes, but RemoveValidIf won't allow us to model DontCare // precisely and PadWidths emits ill-typed firrtl. override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++ - Seq(Dependency(UndefinedMemoryBehaviorPass), Dependency(VerilogMemDelays)) + Seq(Dependency(InvalidToRandomPass), Dependency(UndefinedMemoryBehaviorPass), Dependency(VerilogMemDelays)) override def invalidates(a: Transform): Boolean = false // since this pass only runs on the main module, inlining needs to happen before override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) @@ -147,7 +147,7 @@ private class ModuleToTransitionSystem extends LazyLogging { uninterpreted: Map[String, UninterpretedModuleAnnotation] = Map() ): TransitionSystem = { // first pass over the module to convert expressions; discover state and I/O - val scan = new ModuleScanner(makeRandom, uninterpreted) + val scan = new ModuleScanner(uninterpreted) m.foreachPort(scan.onPort) m.foreachStmt(scan.onStatement) @@ -200,8 +200,8 @@ private class ModuleToTransitionSystem extends LazyLogging { } } - // inputs are original module inputs and any "random" signal we need for modelling - val inputs = scan.inputs ++ randoms.values + // inputs are original module inputs and any DefRandom signal + val inputs = scan.inputs // module info to the comment header val header = serializeInfo(m.info).map(InfoPrefix + _).toArray @@ -350,21 +350,10 @@ private class ModuleToTransitionSystem extends LazyLogging { if (infos.isEmpty) { None } else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } } - - private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]() - private def makeRandom(baseName: String, width: Int): BVExpr = { - // TODO: actually ensure that there cannot be any name clashes with other identifiers - val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii) - val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get - val sym = BVSymbol(name, width) - randoms(name) = sym - sym - } } // performas a first pass over the module collecting all connections, wires, registers, input and outputs private class ModuleScanner( - makeRandom: (String, Int) => BVExpr, uninterpreted: Map[String, UninterpretedModuleAnnotation]) extends LazyLogging { import FirrtlExpressionSemantics.getWidth @@ -429,7 +418,7 @@ private class ModuleScanner( if (!isClock(expr.tpe)) { insertDummyAssignsForUnusedOutputs(expr) infos.append(name -> info) - val e = onExpression(expr, name) + val e = onExpression(expr) nodes.append(name) connects.append((name, e)) } @@ -439,8 +428,8 @@ private class ModuleScanner( insertDummyAssignsForUnusedOutputs(init) infos.append(name -> info) val width = getWidth(tpe) - val resetExpr = onExpression(reset, 1, name + "_reset") - val initExpr = onExpression(init, width, name + "_init") + val resetExpr = onExpression(reset, 1) + val initExpr = onExpression(init, width) registers.append((name, width, resetExpr, initExpr)) case m: ir.DefMemory => namespace.newName(m.name) @@ -456,13 +445,11 @@ private class ModuleScanner( val name = loc.serialize insertDummyAssignsForUnusedOutputs(expr) infos.append(name -> info) - connects.append((name, onExpression(expr, getWidth(loc.tpe), name))) + connects.append((name, onExpression(expr, getWidth(loc.tpe)))) } - case ir.IsInvalid(info, loc) => + case i @ ir.IsInvalid(info, loc) => if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") - val name = loc.serialize - infos.append(name -> info) - connects.append((name, makeRandom(name + "_INVALID", getWidth(loc.tpe)))) + throw new UnsupportedFeatureException(s"IsInvalid statements are not supported: ${i.serialize}") case ir.DefInstance(info, name, module, tpe) => onInstance(info, name, module, tpe) case s @ ir.Verification(op, info, _, pred, en, msg) => if (op == ir.Formal.Cover) { @@ -471,8 +458,8 @@ private class ModuleScanner( insertDummyAssignsForUnusedOutputs(pred) insertDummyAssignsForUnusedOutputs(en) val name = namespace.newName(msgToName(op.toString, msg.string)) - val predicate = onExpression(pred, name + "_predicate") - val enabled = onExpression(en, name + "_enabled") + val predicate = onExpression(pred) + val enabled = onExpression(en) val e = BVImplies(enabled, predicate) infos.append(name -> info) connects.append(name -> e) @@ -596,16 +583,14 @@ private class ModuleScanner( case other => other.foreachExpr(findUnusedOutputUse) } - private case class Context(baseName: String) extends TranslationContext { - override def getRandom(width: Int): BVExpr = makeRandom(baseName, width) - } + private case class Context() extends TranslationContext {} - private def onExpression(e: ir.Expression, width: Int, randomPrefix: String): BVExpr = { - implicit val ctx: TranslationContext = Context(randomPrefix) + private def onExpression(e: ir.Expression, width: Int): BVExpr = { + implicit val ctx: TranslationContext = Context() FirrtlExpressionSemantics.toSMT(e, width, allowNarrow = false) } - private def onExpression(e: ir.Expression, randomPrefix: String): BVExpr = { - implicit val ctx: TranslationContext = Context(randomPrefix) + private def onExpression(e: ir.Expression): BVExpr = { + implicit val ctx: TranslationContext = Context() FirrtlExpressionSemantics.toSMT(e) } diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala new file mode 100644 index 00000000..c7eaad74 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.backends.experimental.smt.random + +import firrtl._ +import firrtl.annotations.NoTargetAnnotation +import firrtl.ir._ +import firrtl.passes._ +import firrtl.options.Dependency +import firrtl.stage.Forms +import firrtl.transforms.RemoveWires + +import scala.collection.mutable + +/** Chooses how to model explicit and implicit invalid values in the circuit */ +case class InvalidToRandomOptions( + randomizeInvalidSignals: Boolean = true, + randomizeDivisionByZero: Boolean = true) + extends NoTargetAnnotation + +/** Replaces all explicit and implicit "invalid" values with random values. + * Explicit invalids are: + * - signal is invalid + * - signal <= valid(..., expr) + * Implicit invalids are: + * - a / b when eq(b, 0) + */ +object InvalidToRandomPass extends Transform with DependencyAPIMigration { + override def prerequisites = Forms.LowForm + // once ValidIf has been removed, we can no longer detect and randomize them + override def optionalPrerequisiteOf = Seq(Dependency(RemoveValidIf)) + override def invalidates(a: Transform) = a match { + // this pass might destroy SSA form, as we add a wire for the data field of every read port + case _: RemoveWires => true + // TODO: should we add some optimization passes here? we could be generating some dead code. + case _ => false + } + + override protected def execute(state: CircuitState): CircuitState = { + val opts = state.annotations.collect { case o: InvalidToRandomOptions => o } + require(opts.size < 2, s"Multiple options: $opts") + val opt = opts.headOption.getOrElse(InvalidToRandomOptions()) + + // quick exit if we just want to skip this pass + if (!opt.randomizeDivisionByZero && !opt.randomizeInvalidSignals) { + state + } else { + val c = state.circuit.mapModule(onModule(_, opt)) + state.copy(circuit = c) + } + } + + private def onModule(m: DefModule, opt: InvalidToRandomOptions): DefModule = m match { + case d: DescribedMod => + throw new RuntimeException(s"CompilerError: Unexpected internal node: ${d.serialize}") + case e: ExtModule => e + case mod: Module => + val namespace = Namespace(mod) + mod.mapStmt(onStmt(namespace, opt, _)) + } + + private def onStmt(namespace: Namespace, opt: InvalidToRandomOptions, s: Statement): Statement = s match { + case IsInvalid(info, loc: RefLikeExpression) if opt.randomizeInvalidSignals => + val name = namespace.newName(loc.serialize.replace('.', '_') + "_invalid") + val rand = DefRandom(info, name, loc.tpe, None) + Block(List(rand, Connect(info, loc, Reference(rand)))) + case other => + val info = other match { + case h: HasInfo => h.info + case _ => NoInfo + } + val prefix = other match { + case c: Connect => c.loc.serialize.replace('.', '_') + case h: HasName => h.name + case _ => "" + } + val ctx = ExprCtx(namespace, opt, prefix, info, mutable.ListBuffer[Statement]()) + val stmt = other.mapExpr(onExpr(ctx, _)).mapStmt(onStmt(namespace, opt, _)) + if (ctx.rands.isEmpty) { stmt } + else { Block(Block(ctx.rands.toList), stmt) } + } + + private case class ExprCtx( + namespace: Namespace, + opt: InvalidToRandomOptions, + prefix: String, + info: Info, + rands: mutable.ListBuffer[Statement]) + + private def onExpr(ctx: ExprCtx, e: Expression): Expression = + e.mapExpr(onExpr(ctx, _)) match { + case ValidIf(_, value, tpe) if tpe == ClockType => + // we currently assume that clocks are always valid + // TODO: is that a good assumption? + value + case ValidIf(cond, value, tpe) if ctx.opt.randomizeInvalidSignals => + makeRand(ctx, cond, tpe, value, invert = true) + case d @ DoPrim(PrimOps.Div, Seq(_, den), _, tpe) if ctx.opt.randomizeDivisionByZero => + val denIsZero = Utils.eq(den, Utils.getGroundZero(den.tpe.asInstanceOf[GroundType])) + makeRand(ctx, denIsZero, tpe, d, invert = false) + case other => other + } + + private def makeRand( + ctx: ExprCtx, + cond: Expression, + tpe: Type, + value: Expression, + invert: Boolean + ): Expression = { + val name = ctx.namespace.newName(if (ctx.prefix.isEmpty) "invalid" else ctx.prefix + "_invalid") + // create a condition node if the condition isn't a reference already + val condRef = cond match { + case r: RefLikeExpression => if (invert) Utils.not(r) else r + case other => + val cond = if (invert) Utils.not(other) else other + val condNode = DefNode(ctx.info, ctx.namespace.newName(name + "_cond"), cond) + ctx.rands.append(condNode) + Reference(condNode) + } + val rand = DefRandom(ctx.info, name, tpe, None, condRef) + ctx.rands.append(rand) + Utils.mux(condRef, Reference(rand), value) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala new file mode 100644 index 00000000..8f17a847 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala @@ -0,0 +1,56 @@ +package firrtl.backends.experimental.smt.random + +import firrtl.options.Dependency +import firrtl.testutils.LeanTransformSpec + +class InvalidToRandomSpec extends LeanTransformSpec(Seq(Dependency(InvalidToRandomPass))) { + behavior.of("InvalidToRandomPass") + + val src1 = + s""" + |circuit Test: + | module Test: + | input a : UInt<2> + | output o : UInt<8> + | output o2 : UInt<8> + | output o3 : UInt<8> + | + | o is invalid + | + | when eq(a, UInt(3)): + | o <= UInt(5) + | + | o2 is invalid + | node o2_valid = eq(a, UInt(2)) + | when o2_valid: + | o2 <= UInt(7) + | + | o3 is invalid + | o3 <= UInt(3) + |""".stripMargin + + it should "model invalid signals as random" in { + + val circuit = compile(src1, List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // the condition should end up as a new node if it wasn't a reference already + assert(result.contains("node _GEN_0_invalid_cond = not(eq(a, UInt<2>(\"h3\")))")) + assert(result.contains("node o2_valid = eq(a, UInt<2>(\"h2\"))")) + + // every invalid results in a random statement + assert(result.contains("rand _GEN_0_invalid : UInt<3> when _GEN_0_invalid_cond")) + assert(result.contains("rand _GEN_1_invalid : UInt<3> when not(o2_valid)")) + + // the random value is conditionally assigned + assert(result.contains("node _GEN_0 = mux(_GEN_0_invalid_cond, _GEN_0_invalid, UInt<3>(\"h5\"))")) + assert(result.contains("node _GEN_1 = mux(not(o2_valid), _GEN_1_invalid, UInt<3>(\"h7\"))")) + + // expressions that are trivially valid do not get randomized + assert(result.contains("o3 <= UInt<2>(\"h3\")")) + val defRandCount = result.count(_.contains("rand ")) + assert(defRandCount == 2) + } + +} |
