diff options
| author | Kevin Laeufer | 2020-08-14 18:39:42 -0700 |
|---|---|---|
| committer | GitHub | 2020-08-15 01:39:42 +0000 |
| commit | 2e5f942d25d7afab79ee1263c5d6833cad9d743d (patch) | |
| tree | add86d0b4b090807b48bb2307d10f2b7b38e0bce | |
| parent | 1b48fe5f5e94bdfdef700956e45d478b5706f25e (diff) | |
experimental SMTLib and btor2 emitter (#1826)
This adds an experimental new SMTLib and Btor2 emitter
that converts a firrtl module into a format
suitable for open source model checkers.
The format generally follows the behavior of yosys'
write_smt2 and write_btor commands.
To generate btor2 for the module in m.fir run
> ./utils/bin/firrtl -i m.fir -E experimental-btor2
for SMT:
> ./utils/bin/firrtl -i m.fir -E experimental-smt2
If you have a design with multiple clocks
or an asynchronous reset, try out the new StutteringClockTransform.
You can designate any input of type Clock to be your
global simulation clock using the new GlobalClockAnnotation.
If your toplevel module instantiates submodules,
you need to inline them if you want the submodule
logic to be included in the formal model.
23 files changed, 3062 insertions, 2 deletions
diff --git a/.install_z3.sh b/.install_z3.sh new file mode 100644 index 00000000..13ffb90e --- /dev/null +++ b/.install_z3.sh @@ -0,0 +1,9 @@ +set -e +# Install Z3 (https://github.com/Z3Prover/z3) +if [ ! -f $INSTALL_DIR/bin/z3 ]; then + mkdir -p $INSTALL_DIR + # download prebuilt binary + wget https://github.com/Z3Prover/z3/releases/download/z3-4.8.8/z3-4.8.8-x64-ubuntu-16.04.zip + unzip z3-4.8.8-x64-ubuntu-16.04.zip + mv ./z3-4.8.8-x64-ubuntu-16.04/bin/z3 $INSTALL_DIR/bin/z3 +fi diff --git a/.travis.yml b/.travis.yml index 73987565..9163efa4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -38,12 +38,14 @@ jobs: # Because these write to the same install directory, they must run in the # same script - stage: prepare - name: "Install: [Verilator, Yosys]" + name: "Install: [Verilator, Yosys, Z3]" script: - bash .install_verilator.sh - verilator --version - bash .install_yosys.sh - yosys -V + - bash .install_z3.sh + - z3 -version - stage: prepare name: "Compile FIRRTL to share with subsequent stages" script: diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 432fa59e..ae9a7dad 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -14,7 +14,8 @@ import firrtl.PrimOps._ import firrtl.WrappedExpression._ import Utils._ import MemPortUtils.{memPortField, memType} -import firrtl.options.{HasShellOptions, CustomFileEmission, ShellOption, PhaseException} +import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter} +import firrtl.options.{CustomFileEmission, HasShellOptions, PhaseException, ShellOption} import firrtl.options.Viewer.view import firrtl.stage.{FirrtlFileAnnotation, FirrtlOptions, RunFirrtlTransformAnnotation, TransformManager} // Datastructures @@ -49,9 +50,14 @@ object EmitCircuitAnnotation extends HasShellOptions { EmitCircuitAnnotation(classOf[VerilogEmitter])) case "sverilog" => Seq(RunFirrtlTransformAnnotation(new SystemVerilogEmitter), EmitCircuitAnnotation(classOf[SystemVerilogEmitter])) + case "experimental-btor2" => Seq(RunFirrtlTransformAnnotation(new Btor2Emitter), + EmitCircuitAnnotation(classOf[Btor2Emitter])) + case "experimental-smt2" => Seq(RunFirrtlTransformAnnotation(new SMTLibEmitter), + EmitCircuitAnnotation(classOf[SMTLibEmitter])) case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)") }, helpText = "Run the specified circuit emitter (all modules in one file)", shortOption = Some("E"), + // the experimental options are intentionally excluded from the help message helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>") ) ) } diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala new file mode 100644 index 00000000..f7ab9927 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala @@ -0,0 +1,189 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import scala.collection.mutable + +private object Btor2Serializer { + def serialize(sys: TransitionSystem, skipOutput: Boolean = false): Iterable[String] = { + new Btor2Serializer().run(sys, skipOutput) + } +} + +private class Btor2Serializer private () { + private val symbols = mutable.HashMap[String, Int]() + private val lines = mutable.ArrayBuffer[String]() + private var index = 1 + + private def line(l: String): Int = { + val ii = index + lines += s"$ii $l" + index += 1 + ii + } + + private def comment(c: String): Unit = { lines += s"; $c" } + private def trailingComment(c: String): Unit = { + val lastLine = lines.last + val newLine = if(lastLine.contains(';')) { lastLine + " " + c} else { lastLine + " ; " + c } + lines(lines.size - 1) = newLine + } + + // bit vector type serialization + private val bitVecTypeCache = mutable.HashMap[Int, Int]() + + private def t(width: Int): Int = bitVecTypeCache.getOrElseUpdate(width, line(s"sort bitvec $width")) + + // bit vector expression serialization + private def s(expr: BVExpr): Int = expr match { + case BVLiteral(value, width) => lit(value, width) + case BVSymbol(name, _) => symbols(name) + case BVExtend(e, 0, _) => s(e) + case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by") + case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by") + case BVSlice(e, hi, lo) => + if (lo == 0 && hi == e.width - 1) { s(e) } else { + line(s"slice ${t(expr.width)} ${s(e)} $hi $lo") + } + case BVNot(BVEqual(a, b)) => binary("neq", expr.width, a, b) + case BVNot(BVNot(e)) => s(e) + case BVNot(e) => unary("not", expr.width, e) + case BVNegate(e) => unary("neg", expr.width, e) + case BVReduceAnd(e) => unary("redand", expr.width, e) + case BVReduceOr(e) => unary("redor", expr.width, e) + case BVReduceXor(e) => unary("redxor", expr.width, e) + case BVImplies(BVLiteral(v, 1), b) if v == 1 => s(b) + case BVImplies(a, b) => binary("implies", expr.width, a, b) + case BVEqual(a, b) => binary("eq", expr.width, a, b) + case ArrayEqual(a, b) => line(s"eq ${t(expr.width)} ${s(a)} ${s(b)}") + case BVComparison(Compare.Greater, a, b, false) => binary("ugt", expr.width, a, b) + case BVComparison(Compare.GreaterEqual, a, b, false) => binary("ugte", expr.width, a, b) + case BVComparison(Compare.Greater, a, b, true) => binary("sgt", expr.width, a, b) + case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b) + case BVOp(op, a, b) => binary(s(op), expr.width, a, b) + case BVConcat(a, b) => binary("concat", expr.width, a, b) + case ArrayRead(array, index) => + line(s"read ${t(expr.width)} ${s(array)} ${s(index)}") + case BVIte(cond, tru, fals) => + line(s"ite ${t(expr.width)} ${s(cond)} ${s(tru)} ${s(fals)}") + case r : BVRawExpr => + throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}") + } + + private def s(op: Op.Value): String = op match { + case Op.And => "and" + case Op.Or => "or" + case Op.Xor => "xor" + case Op.ArithmeticShiftRight => "sra" + case Op.ShiftRight => "srl" + case Op.ShiftLeft => "sll" + case Op.Add => "add" + case Op.Mul => "mul" + case Op.Sub => "sub" + case Op.SignedDiv => "sdiv" + case Op.UnsignedDiv => "udiv" + case Op.SignedMod => "smod" + case Op.SignedRem => "srem" + case Op.UnsignedRem => "urem" + } + + private def unary(op: String, width: Int, e: BVExpr): Int = line(s"$op ${t(width)} ${s(e)}") + + private def binary(op: String, width: Int, a: BVExpr, b: BVExpr): Int = + line(s"$op ${t(width)} ${s(a)} ${s(b)}") + + private def lit(value: BigInt, w: Int): Int = { + val typ = t(w) + lazy val mask = (BigInt(1) << w) - 1 + if (value == 0) line(s"zero $typ") + else if (value == 1) line(s"one $typ") + else if (value == mask) line(s"ones $typ") + else { + val digits = value.toString(2) + val padded = digits.reverse.padTo(w, '0').reverse + line(s"const $typ $padded") + } + } + + // array type serialization + private val arrayTypeCache = mutable.HashMap[(Int, Int), Int]() + + private def t(indexWidth: Int, dataWidth: Int): Int = + arrayTypeCache.getOrElseUpdate((indexWidth, dataWidth), line(s"sort array ${t(indexWidth)} ${t(dataWidth)}")) + + // array expression serialization + private def s(expr: ArrayExpr): Int = expr match { + case ArraySymbol(name, _, _) => symbols(name) + case ArrayStore(array, index, data) => + line(s"write ${t(expr.indexWidth, expr.dataWidth)} ${s(array)} ${s(index)} ${s(data)}") + case ArrayIte(cond, tru, fals) => + // println("WARN: ITE on array is probably not supported by btor2") + // While the spec does not seem to allow array ite, it seems to be supported in practice. + // It is essential to model memories, so any support in the wild should be fairly well tested. + line(s"ite ${t(expr.indexWidth, expr.dataWidth)} ${s(cond)} ${s(tru)} ${s(fals)}") + case ArrayConstant(e, _) => s(e) + case r : ArrayRawExpr => + throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}") + } + + private def s(expr: SMTExpr): Int = expr match { + case b: BVExpr => s(b) + case a: ArrayExpr => s(a) + } + + // serialize the type of the expression + private def t(expr: SMTExpr): Int = expr match { + case b: BVExpr => t(b.width) + case a: ArrayExpr => t(a.indexWidth, a.dataWidth) + } + + def run(sys: TransitionSystem, skipOutput: Boolean): Iterable[String] = { + def declare(name: String, expr: => Int): Unit = { + assert(!symbols.contains(name), s"Trying to redeclare `$name`") + val id = expr + symbols(name) = id + if (!skipOutput && sys.outputs.contains(name)) line(s"output $id ; $name") + if (sys.assumes.contains(name)) line(s"constraint $id ; $name") + if (sys.asserts.contains(name)){ + val invertedId = line(s"not ${t(1)} $id") + line(s"bad $invertedId ; $name") + } + if (sys.fair.contains(name)) line(s"fair $id ; $name") + // add trailing comment + sys.comments.get(name).foreach(trailingComment) + } + + // header + sys.header.foreach(comment) + + // declare inputs + sys.inputs.foreach { ii => + declare(ii.name, line(s"input ${t(ii.width)} ${ii.name}")) + } + + // define state init + sys.states.foreach { st => + // calculate init expression before declaring the state + // this is required by btormc (presumably to avoid cycles in the init expression) + val initId = st.init.map { init => comment(s"${st.sym}.init"); s(init) } + declare(st.sym.name, line(s"state ${t(st.sym)} ${st.sym.name}")) + st.init.foreach { init => line(s"init ${t(init)} ${s(st.sym)} ${initId.get}") } + } + + // define all other signals + sys.signals.foreach { signal => + declare(signal.name, s(signal.e)) + } + + // define state next + sys.states.foreach { st => + st.next.foreach { next => + comment(s"${st.sym}.next") + line(s"next ${t(next)} ${s(st.sym)} ${s(next)}") + } + } + + lines + } +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala new file mode 100644 index 00000000..0a223840 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala @@ -0,0 +1,182 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import firrtl.ir +import firrtl.PrimOps +import firrtl.passes.CheckWidths.WidthTooBig + +private trait TranslationContext { + def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, FirrtlExpressionSemantics.getWidth(tpe)) + def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe)) + def getRandom(width: Int): BVExpr +} + +private object FirrtlExpressionSemantics { + def getWidth(tpe: ir.Type): Int = tpe match { + case ir.UIntType(ir.IntWidth(w)) => w.toInt + case ir.SIntType(ir.IntWidth(w)) => w.toInt + case ir.ClockType => 1 + case ir.ResetType => 1 + case ir.AnalogType(ir.IntWidth(w)) => w.toInt + case other => throw new RuntimeException(s"Cannot handle type $other") + } + + def toSMT(e: ir.Expression)(implicit ctx: TranslationContext): BVExpr = { + val eSMT = e match { + case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts) + case r : ir.Reference => ctx.getReference(r.serialize, r.tpe) + case r : ir.SubField => ctx.getReference(r.serialize, r.tpe) + case r : ir.SubIndex => ctx.getReference(r.serialize, r.tpe) + case ir.UIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt) + case ir.SIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt) + case ir.Mux(cond, tval, fval, _) => + val width = List(tval, fval).map(getWidth).max + BVIte(toSMT(cond), toSMT(tval, width), toSMT(fval, width)) + case ir.ValidIf(cond, value, tpe) => + val tru = toSMT(value) + BVIte(toSMT(cond), tru, ctx.getRandom(tpe)) + } + assert(eSMT.width == getWidth(e), "We aim to always produce a SMT expression of the same width as the firrtl expression.") + eSMT + } + + /** Ensures that the result has the desired width by appropriately extending it. */ + def toSMT(e: ir.Expression, width: Int, allowNarrow: Boolean = false)(implicit ctx: TranslationContext): BVExpr = + forceWidth(toSMT(e), isSigned(e), width, allowNarrow) + + private def forceWidth(eSMT: BVExpr, eSigned: Boolean, width: Int, allowNarrow: Boolean = false): BVExpr = { + if(eSMT.width == width) { eSMT } + else if(width < eSMT.width) { + assert(allowNarrow, s"Narrowing from ${eSMT.width} bits to $width bits is not allowed!") + BVSlice(eSMT, width - 1, 0) + } else { + BVExtend(eSMT, width - eSMT.width, eSigned) + } + } + + // see "Primitive Operations" section in the Firrtl Specification + private def onPrim(op: ir.PrimOp, args: Seq[ir.Expression], consts: Seq[BigInt])(implicit ctx: TranslationContext): + BVExpr = { + (op, args, consts) match { + case (PrimOps.Add, Seq(e1, e2), _) => + val width = args.map(getWidth).max + 1 + BVOp(Op.Add, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Sub, Seq(e1, e2), _) => + val width = args.map(getWidth).max + 1 + BVOp(Op.Sub, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Mul, Seq(e1, e2), _) => + val width = args.map(getWidth).sum + BVOp(Op.Mul, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Div, Seq(num, den), _) => + val (width, op) = if(isSigned(num)) { + (getWidth(num) + 1, Op.SignedDiv) + } else { (getWidth(num), Op.UnsignedDiv) } + // "The result of a division where den is zero is undefined." + val undef = ctx.getRandom(width) + val denSMT = toSMT(den) + val denIsZero = BVEqual(denSMT, BVLiteral(0, denSMT.width)) + val numByDen = BVOp(op, toSMT(num, width), forceWidth(denSMT, isSigned(den), width)) + BVIte(denIsZero, undef, numByDen) + case (PrimOps.Div, Seq(num, den), _) if isSigned(num) => + val width = getWidth(num) + 1 + BVOp(Op.SignedDiv, toSMT(num, width), toSMT(den, width)) + case (PrimOps.Rem, Seq(num, den), _) => + val op = if(isSigned(num)) Op.SignedRem else Op.UnsignedRem + val width = args.map(getWidth).max + val resWidth = args.map(getWidth).min + val res = BVOp(op, toSMT(num, width), toSMT(den, width)) + if(res.width > resWidth) { BVSlice(res, resWidth - 1, 0) } else { res } + case (PrimOps.Lt, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVNot(BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1))) + case (PrimOps.Leq, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVNot(BVComparison(Compare.Greater, toSMT(e1, width), toSMT(e2, width), isSigned(e1))) + case (PrimOps.Gt, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVComparison(Compare.Greater, toSMT(e1, width), toSMT(e2, width), isSigned(e1)) + case (PrimOps.Geq, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1)) + case (PrimOps.Eq, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVEqual(toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Neq, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVNot(BVEqual(toSMT(e1, width), toSMT(e2, width))) + case (PrimOps.Pad, Seq(e), Seq(n)) => + val width = getWidth(e) + if(n <= width) { toSMT(e) } else { BVExtend(toSMT(e), n.toInt - width, isSigned(e)) } + case (PrimOps.AsUInt, Seq(e), _) => checkForClockInCast(PrimOps.AsUInt, e) ; toSMT(e) + case (PrimOps.AsSInt, Seq(e), _) => checkForClockInCast(PrimOps.AsSInt, e) ; toSMT(e) + case (PrimOps.AsFixedPoint, Seq(e), _) => throw new AssertionError("Fixed-Point numbers need to be lowered!") + case (PrimOps.AsClock, Seq(e), _) => toSMT(e) + case (PrimOps.AsAsyncReset, Seq(e), _) => + checkForClockInCast(PrimOps.AsAsyncReset, e) + throw new AssertionError(s"Asynchronous resets are not supported! Cannot cast ${e.serialize}.") + case (PrimOps.Shl, Seq(e), Seq(n)) => if(n == 0) { toSMT(e) } else { + val zeros = BVLiteral(0, n.toInt) + BVConcat(toSMT(e), zeros) + } + case (PrimOps.Shr, Seq(e), Seq(n)) => + val width = getWidth(e) + // "If n is greater than or equal to the bit-width of e, + // the resulting value will be zero for unsigned types + // and the sign bit for signed types" + if(n >= width) { + if(isSigned(e)) { BV1BitZero } else { BVSlice(toSMT(e), width - 1, width - 1) } + } else { + BVSlice(toSMT(e), width - 1, n.toInt) + } + case (PrimOps.Dshl, Seq(e1, e2), _) => + val width = getWidth(e1) + (1 << getWidth(e2)) - 1 + BVOp(Op.ShiftLeft, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Dshr, Seq(e1, e2), _) => + val width = getWidth(e1) + val o = if(isSigned(e1)) Op.ArithmeticShiftRight else Op.ShiftRight + BVOp(o, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Cvt, Seq(e), _) => if(isSigned(e)) { toSMT(e) } else { BVConcat(BV1BitZero, toSMT(e)) } + case (PrimOps.Neg, Seq(e), _) => BVNegate(BVExtend(toSMT(e), 1, isSigned(e))) + case (PrimOps.Not, Seq(e), _) => BVNot(toSMT(e)) + case (PrimOps.And, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVOp(Op.And, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Or, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVOp(Op.Or, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Xor, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVOp(Op.Xor, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Andr, Seq(e), _) => BVReduceAnd(toSMT(e)) + case (PrimOps.Orr, Seq(e), _) => BVReduceOr(toSMT(e)) + case (PrimOps.Xorr, Seq(e), _) => BVReduceXor(toSMT(e)) + case (PrimOps.Cat, Seq(e1, e2), _) => BVConcat(toSMT(e1), toSMT(e2)) + case (PrimOps.Bits, Seq(e), Seq(hi, lo)) => BVSlice(toSMT(e), hi.toInt, lo.toInt) + case (PrimOps.Head, Seq(e), Seq(n)) => + val width = getWidth(e) + assert(n >= 0 && n <= width) + BVSlice(toSMT(e), width - 1, width - n.toInt) + case (PrimOps.Tail, Seq(e), Seq(n)) => + val width = getWidth(e) + assert(n >= 0 && n <= width) + assert(n < width, "While allowed by the firrtl standard, we do not support 0-bit values in this backend!") + BVSlice(toSMT(e), width - n.toInt - 1, 0) + } + } + + /** For now we strictly forbid casting clocks to anything else. + * Eventually this should be replaced by a more sophisticated clock analysis pass. */ + private def checkForClockInCast(cast: ir.PrimOp, signal: ir.Expression): Unit = { + assert(signal.tpe != ir.ClockType, s"Cannot cast (${cast.serialize}) clock expression ${signal.serialize}!") + } + + private val BV1BitZero = BVLiteral(0, 1) + + def isSigned(e: ir.Expression): Boolean = e.tpe match { + case _: ir.SIntType => true + case _ => false + } + private def getWidth(e: ir.Expression): Int = getWidth(e.tpe) +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala new file mode 100644 index 00000000..b3a2ff17 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -0,0 +1,605 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation} +import FirrtlExpressionSemantics.getWidth +import firrtl.graph.MutableDiGraph +import firrtl.options.Dependency +import firrtl.passes.PassException +import firrtl.stage.Forms +import firrtl.stage.TransformManager.TransformDependency +import firrtl.transforms.PropagatePresetAnnotations +import firrtl.{CircuitState, DependencyAPIMigration, MemoryArrayInit, MemoryInitValue, MemoryScalarInit, Transform, Utils, ir} +import logger.LazyLogging + +import scala.collection.mutable + +// Contains code to convert a flat firrtl module into a functional transition system which +// can then be exported as SMTLib or Btor2 file. + +private case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr]) +private case class Signal(name: String, e: BVExpr) { def toSymbol: BVSymbol = BVSymbol(name, e.width) } +private case class TransitionSystem( + name: String, inputs: Array[BVSymbol], states: Array[State], signals: Array[Signal], + outputs: Set[String], assumes: Set[String], asserts: Set[String], fair: Set[String], + comments: Map[String, String] = Map(), header: Array[String] = Array()) { + def serialize: String = { + (Iterator(name) ++ + inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++ + signals.map(s => s"${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++ + states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}") + ).mkString("\n") + } +} + +private case class TransitionSystemAnnotation(sys: TransitionSystem) extends NoTargetAnnotation + +object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { + // TODO: We only really need [[Forms.MidForm]] + LowerTypes, but we also want to fail if there are CombLoops + // TODO: We also would like to run some optimization passes, but RemoveValidIf won't allow us to model DontCare + // precisely and PadWidths emits ill-typed firrtl. + override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm + override def invalidates(a: Transform): Boolean = false + // since this pass only runs on the main module, inlining needs to happen before + override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) + + // We run the propagate preset annotations pass manually since we do not want to remove ValidIfs and other + // Verilog emission passes. + // Ideally we would go in and enable the [[PropagatePresetAnnotations]] to only depend on LowForm. + private val presetPass = new PropagatePresetAnnotations + override protected def execute(state: CircuitState): CircuitState = { + // run the preset pass to extract all preset registers and remove preset reset signals + val afterPreset = presetPass.execute(state) + val circuit = afterPreset.circuit + val presetRegs = afterPreset.annotations + .collect { case PresetRegAnnotation(target) if target.module == circuit.main => target.ref }.toSet + + // collect all non-random memory initialization + val memInit = afterPreset.annotations.collect { case a: MemoryInitAnnotation if !a.isRandomInit => a } + .filter(_.target.module == circuit.main).map(a => a.target.ref -> a.initValue).toMap + + // convert the main module + val main = circuit.modules.find(_.name == circuit.main).get + val sys = main match { + case x: ir.ExtModule => + throw new ExtModuleException( + "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog.") + case m: ir.Module => + new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit=memInit) + } + + val sortedSys = TopologicalSort.run(sys) + val anno = TransitionSystemAnnotation(sortedSys) + state.copy(circuit=circuit, annotations = afterPreset.annotations :+ anno ) + } +} + +private object UnsupportedException { + val HowToRunStuttering: String = + """ + |You can run the StutteringClockTransform which + |replaces all clock inputs with a clock enable signal. + |This is required not only for multi-clock designs, but also to + |accurately model asynchronous reset which could happen even if there + |isn't a clock edge. + | If you are using the firrtl CLI, please add: + | -fct firrtl.backends.experimental.smt.StutteringClockTransform + | If you are calling into firrtl programmatically you can use: + | RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]) + | To designate a clock to be the global_clock (i.e. the simulation tick), use: + | GlobalClockAnnotation(CircuitTarget(...).module(...).ref("your_clock"))) + |""".stripMargin +} + +private class ExtModuleException(s: String) extends PassException(s) +private class AsyncResetException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering) +private class MultiClockException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering) +private class MissingFeatureException(s: String) extends PassException("Unfortunately the SMT backend does not yet support: " + s) + +private class ModuleToTransitionSystem extends LazyLogging { + + def run(m: ir.Module, presetRegs: Set[String] = Set(), memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = { + // first pass over the module to convert expressions; discover state and I/O + val scan = new ModuleScanner(makeRandom) + m.foreachPort(scan.onPort) + // multi-clock support requires the StutteringClock transform to be run + if(scan.clocks.size > 1) { + throw new MultiClockException(s"The module ${m.name} has more than one clock: ${scan.clocks.mkString(", ")}") + } + m.foreachStmt(scan.onStatement) + + // turn wires and nodes into signals + val outputs = scan.outputs.toSet + val constraints = scan.assumes.toSet + val bad = scan.asserts.toSet + val isSignal = (scan.wires ++ scan.nodes ++ scan.memSignals).toSet ++ outputs ++ constraints ++ bad + val signals = scan.connects.filter{ case(name, _) => isSignal.contains(name) } + .map { case (name, expr) => Signal(name, expr) } + + // turn registers and memories into states + val registers = scan.registers.map(r => r._1 -> r).toMap + val regStates = scan.connects.filter(s => registers.contains(s._1)).map { case (name, nextExpr) => + val (_, width, resetExpr, initExpr) = registers(name) + onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs) + } + // turn memories into state + val memoryEncoding = new MemoryEncoding(makeRandom) + val memoryStatesAndOutputs = scan.memories.map(m => memoryEncoding.onMemory(m, scan.connects, memInit.get(m.name))) + // replace pseudo assigns for memory outputs + val memOutputs = memoryStatesAndOutputs.flatMap(_._2).toMap + val signalsWithMem = signals.map { s => + if (memOutputs.contains(s.name)) { + s.copy(e = memOutputs(s.name)) + } else { s } + } + // filter out any left-over self assignments (this happens when we have a registered read port) + .filter(s => s match { case Signal(n0, BVSymbol(n1, _)) if n0 == n1 => false case _ => true }) + val states = regStates.toArray ++ memoryStatesAndOutputs.flatMap(_._1) + + // generate comments from infos + val comments = mutable.HashMap[String, String]() + scan.infos.foreach { case (name, info) => + serializeInfo(info).foreach { infoString => + if(comments.contains(name)) { comments(name) += InfoSeparator + infoString } + else { comments(name) = InfoPrefix + infoString } + } + } + + // inputs are original module inputs and any "random" signal we need for modelling + val inputs = scan.inputs ++ randoms.values + + // module info to the comment header + val header = serializeInfo(m.info).map(InfoPrefix + _).toArray + + val fair = Set[String]() // as of firrtl 1.4 we do not support fairness constraints + TransitionSystem(m.name, inputs.toArray, states, signalsWithMem.toArray, outputs, constraints, bad, fair, comments.toMap, header) + } + + private def onRegister(name: String, width: Int, resetExpr: BVExpr, initExpr: BVExpr, + nextExpr: BVExpr, presetRegs: Set[String]): State = { + assert(initExpr.width == width) + assert(nextExpr.width == width) + assert(resetExpr.width == 1) + val sym = BVSymbol(name, width) + val hasReset = initExpr != sym + val isPreset = presetRegs.contains(name) + assert(!isPreset || hasReset, s"Expected preset register $name to have a reset value, not just $initExpr!") + if(hasReset) { + val init = if(isPreset) Some(initExpr) else None + val next = if(isPreset) nextExpr else BVIte(resetExpr, initExpr, nextExpr) + State(sym, next = Some(next), init = init) + } else { + State(sym, next = Some(nextExpr), init = None) + } + } + + private val InfoSeparator = ", " + private val InfoPrefix = "@ " + private def serializeInfo(info: ir.Info): Option[String] = info match { + case ir.NoInfo => None + case f : ir.FileInfo => Some(f.escaped) + case m : ir.MultiInfo => + val infos = m.flatten + if(infos.isEmpty) { None } else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } + } + + private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]() + private def makeRandom(baseName: String, width: Int): BVExpr = { + // TODO: actually ensure that there cannot be any name clashes with other identifiers + val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii) + val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get + val sym = BVSymbol(name, width) + randoms(name) = sym + sym + } +} + +private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLogging { + type Connects = Iterable[(String, BVExpr)] + def onMemory(defMem: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (Iterable[State], Connects) = { + // we can only work on appropriately lowered memories + assert(defMem.dataType.isInstanceOf[ir.GroundType], + s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!") + assert(defMem.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") + + // collect all memory meta-data in a custom class + val m = new MemInfo(defMem) + + // find all connections related to this memory + val inputs = connects.filter(_._1.startsWith(m.prefix)).toMap + + // there could be a constant init + val init = initValue.map(getInit(m, _)) + + // parse and check read and write ports + val writers = defMem.writers.map( w => new WritePort(m, w, inputs)) + val readers = defMem.readers.map( r => new ReadPort(m, r, inputs)) + + // derive next state from all write ports + assert(defMem.writeLatency == 1, "Only memories with write-latency of one are supported.") + val next: ArrayExpr = if(writers.isEmpty) { m.sym } else { + if(writers.length > 2) { + throw new UnsupportedFeatureException(s"memories with 3+ write ports (${m.name})") + } + val validData = writers.foldLeft[ArrayExpr](m.sym) { case (sym, w) => w.writeTo(sym) } + if(writers.length == 1) { validData } else { + assert(writers.length == 2) + val conflict = writers.head.doesConflict(writers.last) + val conflictData = writers.head.makeRandomData("_write_write_collision") + val conflictStore = ArrayStore(m.sym, writers.head.addr, conflictData) + ArrayIte(conflict, conflictStore, validData) + } + } + val state = State(m.sym, init, Some(next)) + + // derive data signals from all read ports + assert(defMem.readLatency >= 0) + if(defMem.readLatency > 1) { + throw new UnsupportedFeatureException(s"memories with read latency 2+ (${m.name})") + } + val readPortSignals = if(defMem.readLatency == 0) { + readers.map { r => + // combinatorial read + if(defMem.readUnderWrite != ir.ReadUnderWrite.New) { + //logger.warn(s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." + + // s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored.") + } + // since we do a combinatorial read, the "old" data is the current data + val data = r.readOld() + r.data.name -> data + } + } else { Seq() } + val readPortStates = if(defMem.readLatency == 1) { + readers.map { r => + // we create a register for the read port data + val next = defMem.readUnderWrite match { + case ir.ReadUnderWrite.New => + throw new UnsupportedFeatureException(s"registered read ports that return the new value (${m.name}.${r.name})") + // the thing that makes this hard is to properly handle write conflicts + case ir.ReadUnderWrite.Undefined => + val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r))) + if(anyWriteToTheSameAddress == False) { r.readOld() } else { + val readUnderWriteData = r.makeRandomData("_read_under_write_undefined") + BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.readOld()) + } + case ir.ReadUnderWrite.Old => r.readOld() + } + State(r.data, init=None, next=Some(next)) + } + } else { Seq() } + + (state +: readPortStates, readPortSignals) + } + + private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match { + case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, m.dataWidth), m.indexWidth) + case MemoryArrayInit(values) => + assert(values.length == m.depth, + s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!") + // in order to get a more compact encoding try to find the most common values + val histogram = mutable.LinkedHashMap[BigInt, Int]() + values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) + val baseValue = histogram.maxBy(_._2)._1 + val base = ArrayConstant(BVLiteral(baseValue, m.dataWidth), m.indexWidth) + values.zipWithIndex.filterNot(_._1 == baseValue) + .foldLeft[ArrayExpr](base) { case (array, (value, index)) => + ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth)) + } + case other => throw new RuntimeException(s"Unsupported memory init option: $other") + } + + private class MemInfo(m: ir.DefMemory) { + val name = m.name + val depth = m.depth + // derrive the type of the memory from the dataType and depth + val dataWidth = getWidth(m.dataType) + val indexWidth = Utils.getUIntWidth(m.depth - 1) max 1 + val sym = ArraySymbol(m.name, indexWidth, dataWidth) + val prefix = m.name + "." + val fullAddressRange = (BigInt(1) << indexWidth) == m.depth + lazy val depthBV = BVLiteral(m.depth, indexWidth) + def isValidAddress(addr: BVExpr): BVExpr = { + if(fullAddressRange) { True } else { + BVComparison(Compare.Greater, depthBV, addr, signed = false) + } + } + } + private abstract class MemPort(memory: MemInfo, val name: String, inputs: String => BVExpr) { + val en: BVSymbol = makeField("en", 1) + val data: BVSymbol = makeField("data", memory.dataWidth) + val addr: BVSymbol = makeField("addr", memory.indexWidth) + protected def makeField(field: String, width: Int): BVSymbol = BVSymbol(memory.prefix + name + "." + field, width) + // make sure that all widths are correct + assert(inputs(en.name).width == en.width) + assert(inputs(addr.name).width == addr.width) + val enIsTrue: Boolean = inputs(en.name) == True + def makeRandomData(suffix: String): BVExpr = + makeRandom(memory.name + "_" + name + suffix, memory.dataWidth) + def readOld(): BVExpr = { + val canBeOutOfRange = !memory.fullAddressRange + val canBeDisabled = !enIsTrue + val data = ArrayRead(memory.sym, addr) + val dataWithRangeCheck = if(canBeOutOfRange) { + val outOfRangeData = makeRandomData("_addr_out_of_range") + BVIte(memory.isValidAddress(addr), data, outOfRangeData) + } else { data } + val dataWithEnabledCheck = if(canBeDisabled) { + val disabledData = makeRandomData("_not_enabled") + BVIte(en, dataWithRangeCheck, disabledData) + } else { dataWithRangeCheck } + dataWithEnabledCheck + } + } + private class WritePort(memory: MemInfo, name: String, inputs: String => BVExpr) + extends MemPort(memory, name, inputs) { + assert(inputs(data.name).width == data.width) + val mask: BVSymbol = makeField("mask", 1) + assert(inputs(mask.name).width == mask.width) + val maskIsTrue: Boolean = inputs(mask.name) == True + val doWrite: BVExpr = (enIsTrue, maskIsTrue) match { + case (true, true) => True + case (true, false) => mask + case (false, true) => en + case (false, false) => and(en, mask) + } + def doesConflict(r: ReadPort): BVExpr = { + val sameAddress = BVEqual(r.addr, addr) + if(doWrite == True) { sameAddress } else { and(doWrite, sameAddress) } + } + def doesConflict(w: WritePort): BVExpr = { + val bothWrite = and(doWrite, w.doWrite) + val sameAddress = BVEqual(addr, w.addr) + if(bothWrite == True) { sameAddress } else { and(doWrite, sameAddress) } + } + def writeTo(array: ArrayExpr): ArrayExpr = { + val doUpdate = if(memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr)) + val update = ArrayStore(array, index=addr, data=data) + if(doUpdate == True) update else ArrayIte(doUpdate, update, array) + } + + } + private class ReadPort(memory: MemInfo, name: String, inputs: String => BVExpr) + extends MemPort(memory, name, inputs) { + } + + private def and(a: BVExpr, b: BVExpr): BVExpr = (a,b) match { + case (True, True) => True + case (True, x) => x + case (x, True) => x + case _ => BVOp(Op.And, a, b) + } + private def or(a: BVExpr, b: BVExpr): BVExpr = BVOp(Op.Or, a, b) + private val True = BVLiteral(1, 1) + private val False = BVLiteral(0, 1) + private def all(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) False else b.reduce((a,b) => and(a,b)) + private def any(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) True else b.reduce((a,b) => or(a,b)) +} + +// performas a first pass over the module collecting all connections, wires, registers, input and outputs +private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLogging { + import FirrtlExpressionSemantics.getWidth + + private[firrtl] val inputs = mutable.ArrayBuffer[BVSymbol]() + private[firrtl] val outputs = mutable.ArrayBuffer[String]() + private[firrtl] val clocks = mutable.LinkedHashSet[String]() + private[firrtl] val wires = mutable.ArrayBuffer[String]() + private[firrtl] val nodes = mutable.ArrayBuffer[String]() + private[firrtl] val memSignals = mutable.ArrayBuffer[String]() + private[firrtl] val registers = mutable.ArrayBuffer[(String, Int, BVExpr, BVExpr)]() + private[firrtl] val memories = mutable.ArrayBuffer[ir.DefMemory]() + // DefNode, Connect, IsInvalid and VerificationStatement connections + private[firrtl] val connects = mutable.ArrayBuffer[(String, BVExpr)]() + private[firrtl] val asserts = mutable.ArrayBuffer[String]() + private[firrtl] val assumes = mutable.ArrayBuffer[String]() + // maps identifiers to their info + private[firrtl] val infos = mutable.ArrayBuffer[(String, ir.Info)]() + // keeps track of unused memory (data) outputs so that we can see where they are first used + private val unusedMemOutputs = mutable.LinkedHashMap[String, Int]() + + private[firrtl] def onPort(p: ir.Port): Unit = { + if(isAsyncReset(p.tpe)) { + throw new AsyncResetException(s"Found AsyncReset ${p.name}.") + } + infos.append(p.name -> p.info) + p.direction match { + case ir.Input => + if(isClock(p.tpe)) { + clocks.add(p.name) + } else { + inputs.append(BVSymbol(p.name, getWidth(p.tpe))) + } + case ir.Output => outputs.append(p.name) + } + } + + private[firrtl] def onStatement(s: ir.Statement): Unit = s match { + case ir.DefWire(info, name, tpe) => + if(!isClock(tpe)) { + infos.append(name -> info) + wires.append(name) + } + case ir.DefNode(info, name, expr) => + if(!isClock(expr.tpe)) { + insertDummyAssignsForMemoryOutputs(expr) + infos.append(name -> info) + val e = onExpression(expr, name) + nodes.append(name) + connects.append((name, e)) + } + case ir.DefRegister(info, name, tpe, _, reset, init) => + insertDummyAssignsForMemoryOutputs(reset) + insertDummyAssignsForMemoryOutputs(init) + infos.append(name -> info) + val width = getWidth(tpe) + val resetExpr = onExpression(reset, 1, name + "_reset") + val initExpr = onExpression(init, width, name + "_init") + registers.append((name, width, resetExpr, initExpr)) + case m : ir.DefMemory => + infos.append(m.name -> m.info) + val outputs = getMemOutputs(m) + (getMemInputs(m) ++ outputs).foreach(memSignals.append(_)) + val dataWidth = getWidth(m.dataType) + outputs.foreach(name => unusedMemOutputs(name) = dataWidth) + memories.append(m) + case ir.Connect(info, loc, expr) => + if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") + val name = loc.serialize + insertDummyAssignsForMemoryOutputs(expr) + infos.append(name -> info) + connects.append((name, onExpression(expr, getWidth(loc.tpe), name))) + case ir.IsInvalid(info, loc) => + if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") + val name = loc.serialize + infos.append(name -> info) + connects.append((name, makeRandom(name + "_INVALID", getWidth(loc.tpe)))) + case ir.DefInstance(info, name, module, tpe) => + if(!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}") + // we treat all instances as blackboxes + logger.warn(s"WARN: treating instance $name of $module as blackbox. " + + "Please flatten your hierarchy if you want to include submodules in the formal model.") + val ports = tpe.asInstanceOf[ir.BundleType].fields + // skip clock and async reset ports + ports.filterNot(p => isClock(p.tpe) || isAsyncReset(p.tpe) ).foreach { p => + if(!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p") + val isOutput = p.flip == ir.Default + val pName = name + "." + p.name + infos.append(pName -> info) + // outputs of the submodule become inputs to our module + if(isOutput) { + inputs.append(BVSymbol(pName, getWidth(p.tpe))) + } else { + outputs.append(pName) + } + } + case s @ ir.Verification(op, info, _, pred, en, msg) => + if(op == ir.Formal.Cover) { + logger.warn(s"WARN: Cover statement was ignored: ${s.serialize}") + } else { + val name = msgToName(op.toString, msg.string) + val predicate = onExpression(pred, name + "_predicate") + val enabled = onExpression(en, name + "_enabled") + val e = BVImplies(enabled, predicate) + infos.append(name -> info) + connects.append(name -> e) + if(op == ir.Formal.Assert) { + asserts.append(name) + } else { + assumes.append(name) + } + } + case s : ir.Conditionally => + error(s"When conditions are not supported. Please run ExpandWhens: ${s.serialize}") + case s : ir.PartialConnect => + error(s"PartialConnects are not supported. Please run ExpandConnects: ${s.serialize}") + case s : ir.Attach => + error(s"Analog wires are not supported in the SMT backend: ${s.serialize}") + case s : ir.Stop => + // we could wire up the stop condition as output for debug reasons + logger.warn(s"WARN: Stop statements are currently not supported. Ignoring: ${s.serialize}") + case s : ir.Print => + logger.warn(s"WARN: Print statements are not supported. Ignoring: ${s.serialize}") + case other => other.foreachStmt(onStatement) + } + + private val readInputFields = List("en", "addr") + private val writeInputFields = List("en", "mask", "addr", "data") + private def getMemInputs(m: ir.DefMemory): Iterable[String] = { + assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!") + val p = m.name + "." + m.writers.flatMap(w => writeInputFields.map(p + w + "." + _)) ++ + m.readers.flatMap(r => readInputFields.map(p + r + "." + _)) + } + private def getMemOutputs(m: ir.DefMemory): Iterable[String] = { + assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!") + val p = m.name + "." + m.readers.map(r => p + r + ".data") + } + // inserts a dummy assign right before a memory output is used for the first time + // example: + // m.r.data <= m.r.data ; this is the dummy assign + // test <= m.r.data ; this is the first use of m.r.data + private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if(unusedMemOutputs.nonEmpty) { + implicit val uses = mutable.ArrayBuffer[String]() + findUnusedMemoryOutputUse(next) + if(uses.nonEmpty) { + val useSet = uses.toSet + unusedMemOutputs.foreach { case (name, width) => + if(useSet.contains(name)) connects.append(name -> BVSymbol(name, width)) + } + useSet.foreach(name => unusedMemOutputs.remove(name)) + } + } + private def findUnusedMemoryOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match { + case s : ir.SubField => + val name = s.serialize + if(unusedMemOutputs.contains(name)) uses.append(name) + case other => other.foreachExpr(findUnusedMemoryOutputUse) + } + + private case class Context(baseName: String) extends TranslationContext { + override def getRandom(width: Int): BVExpr = makeRandom(baseName, width) + } + + private def onExpression(e: ir.Expression, width: Int, randomPrefix: String): BVExpr = { + implicit val ctx: TranslationContext = Context(randomPrefix) + FirrtlExpressionSemantics.toSMT(e, width, allowNarrow = false) + } + private def onExpression(e: ir.Expression, randomPrefix: String): BVExpr = { + implicit val ctx: TranslationContext = Context(randomPrefix) + FirrtlExpressionSemantics.toSMT(e) + } + + private def msgToName(prefix: String, msg: String): String = { + // TODO: ensure that we can generate unique names + prefix + "_" + msg.replace(" ", "_").replace("|", "") + } + private def error(msg: String): Unit = throw new RuntimeException(msg) + private def isGroundType(tpe: ir.Type): Boolean = tpe.isInstanceOf[ir.GroundType] + private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType + private def isAsyncReset(tpe: ir.Type): Boolean = tpe == ir.AsyncResetType +} + +private object TopologicalSort { + /** Ensures that all signals in the resulting system are topologically sorted. + * This is necessary because [[firrtl.transforms.RemoveWires]] does + * not sort assignments to outputs, submodule inputs nor memory ports. + * */ + def run(sys: TransitionSystem): TransitionSystem = { + val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name) + val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates) + // TODO: maybe sort init expressions of states (this should not be needed most of the time) + signalOrder match { + case None => sys + case Some(order) => + val signalMap = sys.signals.map(s => s.name -> s).toMap + // we flatMap over `get` in order to ignore inputs/states in the order + sys.copy(signals = order.flatMap(signalMap.get).toArray) + } + } + + private def sort(signals: Iterable[(String, SMTExpr)], globalSignals: Iterable[String]): Option[Iterable[String]] = { + val known = new mutable.HashSet[String]() ++ globalSignals + var needsReordering = false + val digraph = new MutableDiGraph[String] + signals.foreach { case (name, expr) => + digraph.addVertex(name) + val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr) + uniqueDependencies.foreach { d => + if(!known.contains(d)) { needsReordering = true } + digraph.addPairWithEdge(name, d) + } + known.add(name) + } + if(needsReordering) { + Some(digraph.linearize.reverse) + } else { None } + } + + private def findDependencies(expr: SMTExpr): List[String] = expr match { + case BVSymbol(name, _) => List(name) + case ArraySymbol(name, _, _) => List(name) + case other => other.children.flatMap(findDependencies) + } +}
\ No newline at end of file diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala new file mode 100644 index 00000000..322b8961 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala @@ -0,0 +1,85 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import java.io.Writer + +import firrtl._ +import firrtl.annotations.{Annotation, NoTargetAnnotation} +import firrtl.options.Viewer.view +import firrtl.options.{CustomFileEmission, Dependency} +import firrtl.stage.FirrtlOptions + + +private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform with Emitter with DependencyAPIMigration { + override def prerequisites: Seq[Dependency[Transform]] = Seq(Dependency(FirrtlToTransitionSystem)) + override def invalidates(a: Transform): Boolean = false + + override def emit(state: CircuitState, writer: Writer): Unit = error("Deprecated since firrtl 1.0!") + + protected def serialize(sys: TransitionSystem): Annotation + + private val BleedingEdgeWarning = + """WARNING: The SMT and BTOR2 emitters are experimental preview features. + |- they might be removed in future versions without deprecation warning + |- their behavior and interfaces might change without notice + |- they are unsupported, we won't be able to fix any issues you find in them + |- however, we do accept pull requests: https://github.com/freechipsproject/firrtl/pulls + |""".stripMargin + + override protected def execute(state: CircuitState): CircuitState = { + val emitCircuit = state.annotations.exists { + case EmitCircuitAnnotation(a) if this.getClass == a => true + case EmitAllModulesAnnotation(a) if this.getClass == a => error("EmitAllModulesAnnotation not supported!") + case _ => false + } + + if(!emitCircuit) { return state } + + logger.warn(BleedingEdgeWarning) + + val sys = state.annotations.collectFirst{ case TransitionSystemAnnotation(sys) => sys }.getOrElse { + error("Could not find the transition system!") + } + state.copy(annotations = state.annotations :+ serialize(sys)) + } + + protected def generatedHeader(format: String, name: String): String = + s"; $format description generated by firrtl ${BuildInfo.version} for module $name.\n" + + protected def error(msg: String): Nothing = throw new RuntimeException(msg) +} + +case class EmittedSMTModelAnnotation(name: String, src: String, outputSuffix: String) + extends NoTargetAnnotation with CustomFileEmission { + override protected def baseFileName(annotations: AnnotationSeq): String = + view[FirrtlOptions](annotations).outputFileName.getOrElse(name) + override protected def suffix: Option[String] = Some(outputSuffix) + override def getBytes: Iterable[Byte] = src.getBytes +} + +private[firrtl] class Btor2Emitter extends SMTEmitter { + override def outputSuffix: String = ".btor2" + override protected def serialize(sys: TransitionSystem): Annotation = { + val btor = generatedHeader("BTOR", sys.name) + Btor2Serializer.serialize(sys).mkString("\n") + "\n" + EmittedSMTModelAnnotation(sys.name, btor, outputSuffix) + } +} + +private[firrtl] class SMTLibEmitter extends SMTEmitter { + override def outputSuffix: String = ".smt2" + override protected def serialize(sys: TransitionSystem): Annotation = { + val hasMemory = sys.states.exists(_.sym.isInstanceOf[ArrayExpr]) + val logic = SMTLibSerializer.setLogic(hasMemory) + "\n" + val header = if(hasMemory) { + "; We have to disable the logic for z3 to accept the non-standard \"as const\"\n" + + "; see https://github.com/Z3Prover/z3/issues/1803\n" + + "; for CVC4 you probably want to include the logic\n" + + ";" + logic + } else { logic } + val smt = generatedHeader("SMT-LIBv2", sys.name) + header + + SMTTransitionSystemEncoder.encode(sys).map(SMTLibSerializer.serialize).mkString("\n") + "\n" + EmittedSMTModelAnnotation(sys.name, smt, outputSuffix) + } +}
\ No newline at end of file diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala new file mode 100644 index 00000000..10a89e8d --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala @@ -0,0 +1,196 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> +// Inspired by the uclid5 SMT library (https://github.com/uclid-org/uclid). +// And the btor2 documentation (BTOR2 , BtorMC and Boolector 3.0 by Niemetz et.al.) + +package firrtl.backends.experimental.smt + +private sealed trait SMTExpr { def children: List[SMTExpr] } +private sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr { val name: String } +private object SMTSymbol { + def fromExpr(name: String, e: SMTExpr): SMTSymbol = e match { + case b: BVExpr => BVSymbol(name, b.width) + case a: ArrayExpr => ArraySymbol(name, a.indexWidth, a.dataWidth) + } +} +private sealed trait SMTNullaryExpr extends SMTExpr { + override def children: List[SMTExpr] = List() +} + +private sealed trait BVExpr extends SMTExpr { def width: Int } +private case class BVLiteral(value: BigInt, width: Int) extends BVExpr with SMTNullaryExpr { + private def minWidth = value.bitLength + (if(value <= 0) 1 else 0) + assert(width > 0, "Zero or negative width literals are not allowed!") + assert(width >= minWidth, "Value (" + value.toString + ") too big for BitVector of width " + width + " bits.") + override def toString: String = if(width <= 8) { + width.toString + "'b" + value.toString(2) + } else { width.toString + "'x" + value.toString(16) } +} +private case class BVSymbol(name: String, width: Int) extends BVExpr with SMTSymbol { + assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") + assert(!name.contains("\\"), s"Invalid id $name contains `\\`") + assert(width > 0, "Zero width bit vectors are not supported!") + override def toString: String = name + def toStringWithType: String = name + " : " + SMTExpr.serializeType(this) +} + +private sealed trait BVUnaryExpr extends BVExpr { + def e: BVExpr + override def children: List[BVExpr] = List(e) +} +private case class BVExtend(e: BVExpr, by: Int, signed: Boolean) extends BVUnaryExpr { + assert(by >= 0, "Extension must be non-negative!") + override val width: Int = e.width + by + override def toString: String = if(signed) { s"sext($e, $by)" } else { s"zext($e, $by)" } +} +// also known as bit extract operation +private case class BVSlice(e: BVExpr, hi: Int, lo: Int) extends BVUnaryExpr { + assert(lo >= 0, s"lo (lsb) must be non-negative!") + assert(hi >= lo, s"hi (msb) must not be smaller than lo (lsb): msb: $hi lsb: $lo") + assert(e.width > hi, s"Out off bounds hi (msb) access: width: ${e.width} msb: $hi") + override def width: Int = hi - lo + 1 + override def toString: String = if(hi == lo) s"$e[$hi]" else s"$e[$hi:$lo]" +} +private case class BVNot(e: BVExpr) extends BVUnaryExpr { + override val width: Int = e.width + override def toString: String = s"not($e)" +} +private case class BVNegate(e: BVExpr) extends BVUnaryExpr { + override val width: Int = e.width + override def toString: String = s"neg($e)" +} +private case class BVReduceOr(e: BVExpr) extends BVUnaryExpr { + override def width: Int = 1 + override def toString: String = s"redor($e)" +} +private case class BVReduceAnd(e: BVExpr) extends BVUnaryExpr { + override def width: Int = 1 + override def toString: String = s"redand($e)" +} +private case class BVReduceXor(e: BVExpr) extends BVUnaryExpr { + override def width: Int = 1 + override def toString: String = s"redxor($e)" +} + +private sealed trait BVBinaryExpr extends BVExpr { + def a: BVExpr + def b: BVExpr + override def children: List[BVExpr] = List(a, b) +} +private case class BVImplies(a: BVExpr, b: BVExpr) extends BVBinaryExpr { + assert(a.width == 1 && b.width == 1, s"Both arguments need to be 1-bit!") + override def width: Int = 1 + override def toString: String = s"impl($a, $b)" +} +private case class BVEqual(a: BVExpr, b: BVExpr) extends BVBinaryExpr { + assert(a.width == b.width, s"Both argument need to be the same width!") + override def width: Int = 1 + override def toString: String = s"eq($a, $b)" +} +private object Compare extends Enumeration { + val Greater, GreaterEqual = Value +} +private case class BVComparison(op: Compare.Value, a: BVExpr, b: BVExpr, signed: Boolean) extends BVBinaryExpr { + assert(a.width == b.width, s"Both argument need to be the same width!") + override def width: Int = 1 + override def toString: String = op match { + case Compare.Greater => (if(signed) "sgt" else "ugt") + s"($a, $b)" + case Compare.GreaterEqual => (if(signed) "sgeq" else "ugeq") + s"($a, $b)" + } +} +private object Op extends Enumeration { + val And = Value("and") + val Or = Value("or") + val Xor = Value("xor") + val ShiftLeft = Value("logical_shift_left") + val ArithmeticShiftRight = Value("arithmetic_shift_right") + val ShiftRight = Value("logical_shift_right") + val Add = Value("add") + val Mul = Value("mul") + val SignedDiv = Value("sdiv") + val UnsignedDiv = Value("udiv") + val SignedMod = Value("smod") + val SignedRem = Value("srem") + val UnsignedRem = Value("urem") + val Sub = Value("sub") +} +private case class BVOp(op: Op.Value, a: BVExpr, b: BVExpr) extends BVBinaryExpr { + assert(a.width == b.width, s"Both argument need to be the same width!") + override val width: Int = a.width + override def toString: String = s"$op($a, $b)" +} +private case class BVConcat(a: BVExpr, b: BVExpr) extends BVBinaryExpr { + override val width: Int = a.width + b.width + override def toString: String = s"concat($a, $b)" +} +private case class ArrayRead(array: ArrayExpr, index: BVExpr) extends BVExpr { + assert(array.indexWidth == index.width, "Index with does not match expected array index width!") + override val width: Int = array.dataWidth + override def toString: String = s"$array[$index]" + override def children: List[SMTExpr] = List(array, index) +} +private case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr { + assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!") + assert(tru.width == fals.width, s"Both branches need to be of the same width! ${tru.width} vs ${fals.width}") + override val width: Int = tru.width + override def toString: String = s"ite($cond, $tru, $fals)" + override def children: List[BVExpr] = List(cond, tru, fals) +} + +private sealed trait ArrayExpr extends SMTExpr { val indexWidth: Int; val dataWidth: Int } +private case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol { + assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") + assert(!name.contains("\\"), s"Invalid id $name contains `\\`") + override def toString: String = name + def toStringWithType: String = s"$name : bv<$indexWidth> -> bv<$dataWidth>" +} +private case class ArrayStore(array: ArrayExpr, index: BVExpr, data: BVExpr) extends ArrayExpr { + assert(array.indexWidth == index.width, "Index with does not match expected array index width!") + assert(array.dataWidth == data.width, "Data with does not match expected array data width!") + override val dataWidth: Int = array.dataWidth + override val indexWidth: Int = array.indexWidth + override def toString: String = s"$array[$index := $data]" + override def children: List[SMTExpr] = List(array, index, data) +} +private case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) extends ArrayExpr { + assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!") + assert(tru.indexWidth == fals.indexWidth, + s"Both branches need to be of the same type! ${tru.indexWidth} vs ${fals.indexWidth}") + assert(tru.dataWidth == fals.dataWidth, + s"Both branches need to be of the same type! ${tru.dataWidth} vs ${fals.dataWidth}") + override val dataWidth: Int = tru.dataWidth + override val indexWidth: Int = tru.indexWidth + override def toString: String = s"ite($cond, $tru, $fals)" + override def children: List[SMTExpr] = List(cond, tru, fals) +} +private case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr { + assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!") + assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!") + override def width: Int = 1 + override def toString: String = s"eq($a, $b)" + override def children: List[SMTExpr] = List(a, b) +} +private case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr { + override val dataWidth: Int = e.width + override def toString: String = s"([$e] x ${ (BigInt(1) << indexWidth) })" + override def children: List[SMTExpr] = List(e) +} + +private object SMTEqual { + def apply(a: SMTExpr, b: SMTExpr): BVExpr = (a,b) match { + case (ab : BVExpr, bb : BVExpr) => BVEqual(ab, bb) + case (aa : ArrayExpr, ba: ArrayExpr) => ArrayEqual(aa, ba) + case _ => throw new RuntimeException(s"Cannot compare $a and $b") + } +} + +private object SMTExpr { + def serializeType(e: SMTExpr): String = e match { + case b: BVExpr => s"bv<${b.width}>" + case a: ArrayExpr => s"bv<${a.indexWidth}> -> bv<${a.dataWidth}>" + } +} + +// Raw SMTLib encoded expressions as an escape hatch used in the [[SMTTransitionSystemEncoder]] +private case class BVRawExpr(serialized: String, width: Int) extends BVExpr with SMTNullaryExpr +private case class ArrayRawExpr(serialized: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTNullaryExpr
\ No newline at end of file diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala new file mode 100644 index 00000000..14e73253 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala @@ -0,0 +1,73 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +/** Similar to the mapExpr and foreachExpr methods of the firrtl ir nodes, but external to the case classes */ +private object SMTExprVisitor { + type ArrayFun = ArrayExpr => ArrayExpr + type BVFun = BVExpr => BVExpr + + def map[T <: SMTExpr](bv: BVFun, ar: ArrayFun)(e: T): T = e match { + case b: BVExpr => map(b, bv, ar).asInstanceOf[T] + case a: ArrayExpr => map(a, bv, ar).asInstanceOf[T] + } + def map[T <: SMTExpr](f: SMTExpr => SMTExpr)(e: T): T = + map(b => f(b).asInstanceOf[BVExpr], a => f(a).asInstanceOf[ArrayExpr])(e) + + private def map(e: BVExpr, bv: BVFun, ar: ArrayFun): BVExpr = e match { + // nullary + case old : BVLiteral => bv(old) + case old : BVSymbol => bv(old) + case old : BVRawExpr => bv(old) + // unary + case old @ BVExtend(e, by, signed) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVExtend(n, by, signed)) + case old @ BVSlice(e, hi, lo) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVSlice(n, hi, lo)) + case old @ BVNot(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNot(n)) + case old @ BVNegate(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNegate(n)) + case old @ BVReduceAnd(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceAnd(n)) + case old @ BVReduceOr(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceOr(n)) + case old @ BVReduceXor(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceXor(n)) + // binary + case old @ BVImplies(a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB)) + case old @ BVEqual(a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB)) + case old @ ArrayEqual(a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB)) + case old @ BVComparison(op, a, b, signed) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed)) + case old @ BVOp(op, a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB)) + case old @ BVConcat(a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB)) + case old @ ArrayRead(a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB)) + // ternary + case old @ BVIte(a, b, c) => + val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC)) + } + + + private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = e match { + case old : ArrayRawExpr => ar(old) + case old : ArraySymbol => ar(old) + case old @ ArrayConstant(e, indexWidth) => + val n = map(e, bv, ar) ; ar(if(n.eq(e)) old else ArrayConstant(n, indexWidth)) + case old @ ArrayStore(a, b, c) => + val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) + ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC)) + case old @ ArrayIte(a, b, c) => + val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) + ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC)) + } + +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala new file mode 100644 index 00000000..1993da87 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala @@ -0,0 +1,151 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import scala.util.matching.Regex + +/** Converts STM Expressions to a SMTLib compatible string representation. + * See http://smtlib.cs.uiowa.edu/ + * Assumes well typed expression, so it is advisable to run the TypeChecker + * before serializing! + * Automatically converts 1-bit vectors to bool. + */ +private object SMTLibSerializer { + def setLogic(hasMem: Boolean) = "(set-logic QF_" + (if(hasMem) "A" else "") + "UFBV)" + + def serialize(e: SMTExpr): String = e match { + case b : BVExpr => serialize(b) + case a : ArrayExpr => serialize(a) + } + + def serializeType(e: SMTExpr): String = e match { + case b : BVExpr => serializeBitVectorType(b.width) + case a : ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth) + } + + private def serialize(e: BVExpr): String = e match { + case BVLiteral(value, width) => + val mask = (BigInt(1) << width) - 1 + val twosComplement = if(value < 0) { ((~(-value)) & mask) + 1 } else value + if(width == 1) { + if(twosComplement == 1) "true" else "false" + } else { + s"(_ bv$twosComplement $width)" + } + case BVSymbol(name, _) => escapeIdentifier(name) + case BVExtend(e, 0, _) => serialize(e) + case BVExtend(BVLiteral(value, width), by, false) => serialize(BVLiteral(value, width + by)) + case BVExtend(e, by, signed) => + val foo = if(signed) "sign_extend" else "zero_extend" + s"((_ $foo $by) ${asBitVector(e)})" + case BVSlice(e, hi, lo) => + if(lo == 0 && hi == e.width - 1) { serialize(e) + } else { + val bits = s"((_ extract $hi $lo) ${asBitVector(e)})" + // 1-bit extracts need to be turned into a boolean + if(lo == hi) { toBool(bits) } else { bits } + } + case BVNot(BVEqual(a, b)) if a.width == 1 => s"(distinct ${serialize(a)} ${serialize(b)})" + case BVNot(BVNot(e)) => serialize(e) + case BVNot(e) => if(e.width == 1) { s"(not ${serialize(e)})" } else { s"(bvnot ${serialize(e)})" } + case BVNegate(e) => s"(bvneg ${asBitVector(e)})" + case r: BVReduceAnd => serialize(Expander.expand(r)) + case r: BVReduceOr => serialize(Expander.expand(r)) + case r: BVReduceXor => serialize(Expander.expand(r)) + case BVImplies(BVLiteral(v, 1), b) if v == 1 => serialize(b) + case BVImplies(a, b) => s"(=> ${serialize(a)} ${serialize(b)})" + case BVEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})" + case ArrayEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})" + case BVComparison(Compare.Greater, a, b, false) => s"(bvugt ${asBitVector(a)} ${asBitVector(b)})" + case BVComparison(Compare.GreaterEqual, a, b, false) => s"(bvuge ${asBitVector(a)} ${asBitVector(b)})" + case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})" + case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})" + // boolean operations get a special treatment for 1-bit vectors aka bools + case BVOp(Op.And, a, b) if a.width == 1 => s"(and ${serialize(a)} ${serialize(b)})" + case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})" + case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})" + case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})") + case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})" + case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})" + case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})" + case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" + case BVRawExpr(serialized, _) => serialized + } + + def serialize(e: ArrayExpr): String = e match { + case ArraySymbol(name, _, _) => escapeIdentifier(name) + case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})" + case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" + case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})" + case ArrayRawExpr(serialized, _, _) => serialized + } + + def serialize(c: SMTCommand): String = c match { + case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n") + case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)" + case DefineFunction(name, args, e) => + val aa = args.map(a => s"(${escapeIdentifier(a._1)} ${a._2})").mkString(" ") + s"(define-fun ${escapeIdentifier(name)} ($aa) ${serializeType(e)} ${serialize(e)})" + case DeclareFunction(sym, tpes) => + val aa = tpes.mkString(" ") + s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serializeType(sym)})" + } + + private def serializeArrayType(indexWidth: Int, dataWidth: Int): String = + s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})" + private def serializeBitVectorType(width: Int): String = + if(width == 1) { "Bool" } else { assert(width > 1) ; s"(_ BitVec $width)" } + + private def serialize(op: Op.Value): String = op match { + case Op.And => "bvand" + case Op.Or => "bvor" + case Op.Xor => "bvxor" + case Op.ArithmeticShiftRight => "bvashr" + case Op.ShiftRight => "bvlshr" + case Op.ShiftLeft => "bvshl" + case Op.Add => "bvadd" + case Op.Mul => "bvmul" + case Op.Sub => "bvsub" + case Op.SignedDiv => "bvsdiv" + case Op.UnsignedDiv => "bvudiv" + case Op.SignedMod => "bvsmod" + case Op.SignedRem => "bvsrem" + case Op.UnsignedRem => "bvurem" + } + + private def toBool(e: String): String = s"(= $e (_ bv1 1))" + + private val bvZero = "(_ bv0 1)" + private val bvOne = "(_ bv1 1)" + private def asBitVector(e: BVExpr): String = + if(e.width > 1) { serialize(e) } else { s"(ite ${serialize(e)} $bvOne $bvZero)" } + + // See <simple_symbol> definition in the Concrete Syntax Appendix of the SMTLib Spec + private val simple: Regex = raw"[a-zA-Z\+-/\*\=%\?!\.\$$_~&\^<>@][a-zA-Z0-9\+-/\*\=%\?!\.\$$_~&\^<>@]*".r + def escapeIdentifier(name: String): String = name match { + case simple() => name + case _ => if(name.startsWith("|") && name.endsWith("|")) name else s"|$name|" + } +} + +/** Expands expressions that are not natively supported by SMTLib */ +private object Expander { + def expand(r: BVReduceAnd): BVExpr = { + if(r.e.width == 1) { r.e } else { + val allOnes = (BigInt(1) << r.e.width) - 1 + BVEqual(r.e, BVLiteral(allOnes, r.e.width)) + } + } + def expand(r: BVReduceOr): BVExpr = { + if(r.e.width == 1) { r.e } else { + BVNot(BVEqual(r.e, BVLiteral(0, r.e.width))) + } + } + def expand(r: BVReduceXor): BVExpr = { + if(r.e.width == 1) { r.e } else { + val bits = (0 until r.e.width).map(ii => BVSlice(r.e, ii, ii)) + bits.reduce[BVExpr]((a,b) => BVOp(Op.Xor, a, b)) + } + } +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala new file mode 100644 index 00000000..e9acc05b --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala @@ -0,0 +1,128 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import scala.collection.mutable + +/** This Transition System encoding is directly inspired by yosys' SMT backend: + * https://github.com/YosysHQ/yosys/blob/master/backends/smt2/smt2.cc + * It if fairly compact, but unfortunately, the use of an uninterpreted sort for the state + * prevents this encoding from working with boolector. + * For simplicity reasons, we do not support hierarchical designs (no `_h` function). + * */ +private object SMTTransitionSystemEncoder { + + def encode(sys: TransitionSystem): Iterable[SMTCommand] = { + val cmds = mutable.ArrayBuffer[SMTCommand]() + val name = sys.name + + // emit header as comments + cmds ++= sys.header.map(Comment) + + // declare state type + val stateType = id(name + "_s") + cmds += DeclareUninterpretedSort(stateType) + + // inputs and states are modelled as constants + def declare(sym: SMTSymbol, kind: String): Unit = { + cmds ++= toDescription(sym, kind, sys.comments.get) + val s = SMTSymbol.fromExpr(sym.name + SignalSuffix, sym) + cmds += DeclareFunction(s, List(stateType)) + } + sys.inputs.foreach(i => declare(i, "input")) + sys.states.foreach(s => declare(s.sym, "register")) + + // signals are just functions of other signals, inputs and state + def define(sym: SMTSymbol, e: SMTExpr, suffix: String = SignalSuffix): Unit = { + cmds += DefineFunction(sym.name + suffix, List((State, stateType)), replaceSymbols(e)) + } + sys.signals.foreach { signal => + val kind = if(sys.outputs.contains(signal.name)) { "output" + } else if(sys.assumes.contains(signal.name)) { "assume" + } else if(sys.asserts.contains(signal.name)) { "assert" + } else { "wire" } + val sym = SMTSymbol.fromExpr(signal.name, signal.e) + cmds ++= toDescription(sym, kind, sys.comments.get) + define(sym, signal.e) + } + + // define the next and init functions for all states + sys.states.foreach { state => + assert(state.next.nonEmpty, "Next function required") + define(state.sym, state.next.get, NextSuffix) + // init is optional + state.init.foreach { init => + define(state.sym, init, InitSuffix) + } + } + + def defineConjunction(e: Iterable[BVExpr], suffix: String): Unit = { + define(BVSymbol(name, 1), andReduce(e), suffix) + } + + // the transition relation asserts that the value of the next state is the next value from the previous state + // e.g., (reg state_n) == (reg_next state) + val transitionRelations = sys.states.map { state => + val newState = symbolToFunApp(state.sym, SignalSuffix, StateNext) + val nextOldState = symbolToFunApp(state.sym, NextSuffix, State) + SMTEqual(newState, nextOldState) + } + // the transition relation is over two states + val transitionExpr = replaceSymbols(andReduce(transitionRelations)) + cmds += DefineFunction(name + "_t", List((State, stateType), (StateNext, stateType)), transitionExpr) + + // The init relation just asserts that all init function hold + val initRelations = sys.states.filter(_.init.isDefined).map { state => + val stateSignal = symbolToFunApp(state.sym, SignalSuffix, State) + val initSignal = symbolToFunApp(state.sym, InitSuffix, State) + SMTEqual(stateSignal, initSignal) + } + defineConjunction(initRelations, "_i") + + // assertions and assumptions + val assertions = sys.signals.filter(a => sys.asserts.contains(a.name)).map(a => replaceSymbols(a.toSymbol)) + defineConjunction(assertions, "_a") + val assumptions = sys.signals.filter(a => sys.assumes.contains(a.name)).map(a => replaceSymbols(a.toSymbol)) + defineConjunction(assumptions, "_u") + + cmds + } + + private def id(s: String): String = SMTLibSerializer.escapeIdentifier(s) + private val State = "state" + private val StateNext = "state_n" + private val SignalSuffix = "_f" + private val NextSuffix = "_next" + private val InitSuffix = "_init" + private def toDescription(sym: SMTSymbol, kind: String, comments: String => Option[String]): List[Comment] = { + List(sym match { + case BVSymbol(name, width) => + Comment(s"firrtl-smt2-$kind $name $width") + case ArraySymbol(name, indexWidth, dataWidth) => + Comment(s"firrtl-smt2-$kind $name $indexWidth $dataWidth") + }) ++ comments(sym.name).map(Comment) + } + + private def andReduce(e: Iterable[BVExpr]): BVExpr = + if(e.isEmpty) BVLiteral(1, 1) else e.reduce((a,b) => BVOp(Op.And, a, b)) + + // All signals are modelled with functions that need to be called with the state as argument, + // this replaces all Symbols with function applications to the state. + private def replaceSymbols(e: SMTExpr): SMTExpr = { + SMTExprVisitor.map(symbolToFunApp(_, SignalSuffix, State))(e) + } + private def replaceSymbols(e: BVExpr): BVExpr = replaceSymbols(e.asInstanceOf[SMTExpr]).asInstanceOf[BVExpr] + private def symbolToFunApp(sym: SMTExpr, suffix: String, arg: String): SMTExpr = sym match { + case BVSymbol(name, width) => BVRawExpr(s"(${id(name+suffix)} $arg)", width) + case ArraySymbol(name, indexWidth, dataWidth) => ArrayRawExpr(s"(${id(name+suffix)} $arg)", indexWidth, dataWidth) + case other => other + } +} + +/** minimal set of pseudo SMT commands needed for our encoding */ +private sealed trait SMTCommand +private case class Comment(msg: String) extends SMTCommand +private case class DefineFunction(name: String, args: Seq[(String, String)], e: SMTExpr) extends SMTCommand +private case class DeclareFunction(sym: SMTSymbol, tpes: Seq[String]) extends SMTCommand +private case class DeclareUninterpretedSort(name: String) extends SMTCommand diff --git a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala new file mode 100644 index 00000000..d8e203f8 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala @@ -0,0 +1,253 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import firrtl.{CircuitState, DependencyAPIMigration, Namespace, PrimOps, RenameMap, Transform, Utils, ir} +import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation, ReferenceTarget, SingleTargetAnnotation} +import firrtl.ir.EmptyStmt +import firrtl.options.Dependency +import firrtl.passes.PassException +import firrtl.stage.Forms +import firrtl.stage.TransformManager.TransformDependency + +import scala.collection.mutable + +case class GlobalClockAnnotation(target: ReferenceTarget) extends SingleTargetAnnotation[ReferenceTarget] { + override def duplicate(n: ReferenceTarget): Annotation = this.copy(n) +} + +/** Converts every input clock into a clock enable input and adds a single global clock. + * - all registers and memory ports will be connected to the new global clock + * - all registers and memory ports will be guarded by the enable signal of their original clock + * - the clock enabled signal can be understood as a clock tick or posedge + * - this transform can be used in order to (formally) verify designs with multiple clocks or asynchronous resets + */ +class StutteringClockTransform extends Transform with DependencyAPIMigration { + override def prerequisites: Seq[TransformDependency] = Forms.LowForm + override def invalidates(a: Transform): Boolean = false + + // this pass needs to run *before* converting to a transition system + override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq(Dependency(FirrtlToTransitionSystem)) + // since this pass only runs on the main module, inlining needs to happen before + override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) + + + override protected def execute(state: CircuitState): CircuitState = { + if(state.circuit.modules.size > 1) { + logger.warn("WARN: StutteringClockTransform currently only supports running on a single module.\n" + + s"All submodules of ${state.circuit.main} will be ignored! Please inline all submodules if this is not what you want.") + } + + // get main module + val main = state.circuit.modules.find(_.name == state.circuit.main).get match { + case m: ir.Module => m + case e: ir.ExtModule => unsupportedError(s"Cannot run on extmodule $e") + } + mainName = main.name + + val namespace = Namespace(main) + + // create a global clock + val globalClocks = state.annotations.collect { case GlobalClockAnnotation(c) => c } + assert(globalClocks.size < 2, "There can only be a single global clock: " + globalClocks.mkString(", ")) + val (globalClock, portsWithGlobalClock) = globalClocks.headOption match { + case Some(clock) => + assert(clock.module == main.name, "GlobalClock needs to be an input of the main module!") + assert(main.ports.exists(_.name == clock.ref), "GlobalClock needs to be an input port!") + assert(main.ports.find(_.name == clock.ref).get.direction == ir.Input, "GlobalClock needs to be an input port!") + (clock.ref, main.ports) + case None => + val name = namespace.newName("global_clock") + (name, ir.Port(ir.NoInfo, name, ir.Input, ir.ClockType) +: main.ports) + } + + // replace all other clocks with enable signals, unless they are the global clock + val clocks = portsWithGlobalClock.filter(p => p.tpe == ir.ClockType && p.name != globalClock).map(_.name) + val clockToEnable = clocks.map{c => + c -> ir.Reference(namespace.newName(c + "_en"), Bool, firrtl.PortKind, firrtl.SourceFlow) + }.toMap + val portsWithEnableSignals = portsWithGlobalClock.map { p => + if(clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Bool) } else { p } + } + // replace async reset with synchronous reset (since everything will we synchronous with the global clock) + // unless it is a preset reset + val asyncResets = portsWithEnableSignals.filter(_.tpe == ir.AsyncResetType).map(_.name) + val isPresetReset = state.annotations.collect{ case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet + val resetsToChange = asyncResets.filterNot(isPresetReset).toSet + val portsWithSyncReset = portsWithEnableSignals.map { p => + if(resetsToChange.contains(p.name)) { p.copy(tpe = Bool) } else { p } + } + + // discover clock and reset connections + val scan = scanClocks(main, clockToEnable, resetsToChange) + + // rename clocks to clock enable signals + val mRef = CircuitTarget(state.circuit.main).module(main.name) + val renameMap = RenameMap() + scan.clockToEnable.foreach { case (clk, en) => + renameMap.record(mRef.ref(clk), mRef.ref(en.name)) + } + + // make changes + implicit val ctx: Context = new Context(globalClock, scan) + val newMain = main.copy(ports = portsWithSyncReset).mapStmt(onStatement) + + val nonMainModules = state.circuit.modules.filterNot(_.name == state.circuit.main) + val newCircuit = state.circuit.copy(modules = nonMainModules :+ newMain) + state.copy(circuit = newCircuit, renames = Some(renameMap)) + } + + private def onStatement(s: ir.Statement)(implicit ctx: Context): ir.Statement = { + s.foreachExpr(checkExpr) + s match { + // memory field connects + case c @ ir.Connect(_, ir.SubField(ir.SubField(ir.Reference(mem, _, _, _), port, _, _), field, _, _), _) + if ctx.isMem(mem) && ctx.memPortToClockEnable.contains(mem + "." + port) => + // replace clock with the global clock + if(field == "clk") { + c.copy(expr = ctx.globalClock) + } else if(field == "en") { + val m = ctx.memInfo(mem) + val isWritePort = m.writers.contains(port) + assert(isWritePort || m.readers.contains(port)) + + // for write ports we guard the write enable with the clock enable signal, similar to registers + if(isWritePort) { + val clockEn = ctx.memPortToClockEnable(mem + "." + port) + val guardedEnable = and(clockEn, c.expr) + c.copy(expr = guardedEnable) + } else { c } + } else { c} + // register field connects + case c @ ir.Connect(_, r : ir.Reference, next) if ctx.registerToEnable.contains(r.name) => + val clockEnable = ctx.registerToEnable(r.name) + val guardedNext = mux(clockEnable, next, r) + c.copy(expr = guardedNext) + // remove other clock wires and nodes + case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType && ctx.isRemovedClock(loc.serialize) => EmptyStmt + case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt + case ir.DefWire(_, name, tpe) if tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt + // change async reset to synchronous reset + case ir.Connect(info, loc: ir.Reference, expr: ir.Reference) if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) => + ir.Connect(info, loc.copy(tpe=Bool), expr.copy(tpe=Bool)) + case d @ ir.DefNode(_, name, value: ir.Reference) if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) => + d.copy(value = value.copy(tpe=Bool)) + case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => d.copy(tpe=Bool) + // change memory clock and synchronize reset + case ir.DefRegister(info, name, tpe, clock, reset, init) if ctx.registerToEnable.contains(name) => + val clockEnable = ctx.registerToEnable(name) + val newReset = reset match { + case r @ ir.Reference(name, _, _, _) if ctx.isResetToChange(name) => r.copy(tpe=Bool) + case other => other + } + val synchronizedReset = if(reset.tpe == ir.AsyncResetType) { newReset } else { and(newReset, clockEnable) } + ir.DefRegister(info, name, tpe, ctx.globalClock, synchronizedReset, init) + case other => other.mapStmt(onStatement) + } + } + + private def scanClocks(m: ir.Module, initialClockToEnable: Map[String, ir.Reference], resetsToChange: Set[String]): ScanCtx = { + implicit val ctx: ScanCtx = new ScanCtx(initialClockToEnable, resetsToChange) + m.foreachStmt(scanClocksAndResets) + ctx + } + + private def scanClocksAndResets(s: ir.Statement)(implicit ctx: ScanCtx): Unit = { + s.foreachExpr(checkExpr) + s match { + // track clock aliases + case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType => + val locName = loc.serialize + ctx.clockToEnable.get(expr.serialize).foreach { clockEn => + ctx.clockToEnable(locName) = clockEn + // keep track of memory clocks + if(loc.isInstanceOf[ir.SubField]) { + val parts = locName.split('.') + if(ctx.mems.contains(parts.head)) { + assert(parts.length == 3 && parts.last == "clk") + ctx.memPortToClockEnable.append(parts.dropRight(1).mkString(".") -> clockEn) + } + } + } + case ir.DefNode(_, name, value) if value.tpe == ir.ClockType => + ctx.clockToEnable.get(value.serialize).foreach(c => ctx.clockToEnable(name) = c) + // track reset aliases + case ir.Connect(_, loc, expr) if expr.tpe == ir.AsyncResetType && ctx.resetsToChange(expr.serialize) => + ctx.resetsToChange.add(loc.serialize) + case ir.DefNode(_, name, value) if value.tpe == ir.AsyncResetType && ctx.resetsToChange(value.serialize) => + ctx.resetsToChange.add(name) + // modify clocked elements + case ir.DefRegister(_, name, _, clock, _, _) => + ctx.clockToEnable.get(clock.serialize).foreach { clockEnable => + ctx.registerToEnable.append(name -> clockEnable) + } + case m : ir.DefMemory => + assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!") + assert(m.readLatency == 0 || m.readLatency == 1, "Only read-latency 1 and read latency 0 are supported!") + assert(m.writeLatency == 1, "Only write-latency 1 is supported!") + if(m.readers.nonEmpty && m.readLatency == 1) { + unsupportedError("Registers memory read ports are not properly implemented yet :(") + } + ctx.mems(m.name) = m + case other => other.foreachStmt(scanClocksAndResets) + } + } + + // we rely on people not casting clocks or async resets + private def checkExpr(expr: ir.Expression): Unit = expr match { + case ir.DoPrim(PrimOps.AsUInt, Seq(e), _, _) if e.tpe == ir.ClockType => + unsupportedError(s"Clock casts are not supported: ${expr.serialize}") + case ir.DoPrim(PrimOps.AsSInt, Seq(e), _, _) if e.tpe == ir.ClockType => + unsupportedError(s"Clock casts are not supported: ${expr.serialize}") + case ir.DoPrim(PrimOps.AsUInt, Seq(e), _, _) if e.tpe == ir.AsyncResetType => + unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}") + case ir.DoPrim(PrimOps.AsSInt, Seq(e), _, _) if e.tpe == ir.AsyncResetType => + unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}") + case ir.DoPrim(PrimOps.AsAsyncReset, _, _, _) => + unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}") + case ir.DoPrim(PrimOps.AsClock, _, _, _) => + unsupportedError(s"Clock casts are not supported: ${expr.serialize}") + case other => other.foreachExpr(checkExpr) + } + + private class ScanCtx(initialClockToEnable: Map[String, ir.Reference], initialResetsToChange: Set[String]) { + // keeps track of which clock signals will be replaced by which clock enable signal + val clockToEnable = mutable.HashMap[String, ir.Reference]() ++ initialClockToEnable + // kepp track of asynchronous resets that need to be changed to bool + val resetsToChange = mutable.HashSet[String]() ++ initialResetsToChange + // registers whose next function needs to be guarded with a clock enable + val registerToEnable = mutable.ArrayBuffer[(String, ir.Reference)]() + // memory enables which need to be guarded with clock enables + val memPortToClockEnable = mutable.ArrayBuffer[(String, ir.Reference)]() + // keep track of memory names + val mems = mutable.HashMap[String, ir.DefMemory]() + } + + private class Context(globalClockName: String, scanResults: ScanCtx) { + val globalClock: ir.Reference = ir.Reference(globalClockName, ir.ClockType, firrtl.PortKind, firrtl.SourceFlow) + // keeps track of which clock signals will be replaced by which clock enable signal + val isRemovedClock: String => Boolean = scanResults.clockToEnable.contains + // registers whose next function needs to be guarded with a clock enable + val registerToEnable: Map[String, ir.Reference] = scanResults.registerToEnable.toMap + // memory enables which need to be guarded with clock enables + val memPortToClockEnable: Map[String, ir.Reference] = scanResults.memPortToClockEnable.toMap + // keep track of memory names + val isMem: String => Boolean = scanResults.mems.contains + val memInfo: String => ir.DefMemory = scanResults.mems + val isResetToChange: String => Boolean = scanResults.resetsToChange.contains + } + + private var mainName: String = "" // for debugging + private def unsupportedError(msg: String): Nothing = + throw new UnsupportedFeatureException(s"StutteringClockTransform: [$mainName] $msg") + + private def mux(cond: ir.Expression, a: ir.Expression, b: ir.Expression): ir.Expression = { + ir.Mux(cond, a, b, Utils.mux_type_and_widths(a, b)) + } + private def and(a: ir.Expression, b: ir.Expression): ir.Expression = + ir.DoPrim(PrimOps.And, List(a, b), List(), Bool) + private val Bool = ir.UIntType(ir.IntWidth(1)) +} + +private class UnsupportedFeatureException(s: String) extends PassException(s)
\ No newline at end of file diff --git a/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala new file mode 100644 index 00000000..56f891e6 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala @@ -0,0 +1,61 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class Btor2Spec extends SMTBackendBaseSpec { + + it should "convert a hello world module" in { + val src = + """circuit m: + | module m: + | input clock: Clock + | input a: UInt<8> + | output b: UInt<16> + | b <= a + | assert(clock, eq(a, b), UInt(1), "") + |""".stripMargin + + val expected = + """1 sort bitvec 8 + |2 input 1 a + |3 sort bitvec 16 + |4 uext 3 2 8 + |5 output 4 ; b + |6 sort bitvec 1 + |7 uext 3 2 8 + |8 eq 6 7 4 + |9 not 6 8 + |10 bad 9 ; assert_ + |""".stripMargin + + assert(toBotr2Str(src) == expected) + } + + it should "include FileInfo in the output" in { + val src = + """circuit m: @[circuit 0:0] + | module m: @[module 0:0] + | input clock: Clock @[clock 0:0] + | input a: UInt<8> @[a 0:0] + | output b: UInt<16> @[b 0:0] + | b <= a @[b_a 0:0] + | assert(clock, eq(a, b), UInt(1), "") @[assert 0:0] + |""".stripMargin + + val expected = + """; @ module 0:0 + |1 sort bitvec 8 + |2 input 1 a ; @ a 0:0 + |3 sort bitvec 16 + |4 uext 3 2 8 + |5 output 4 ; b @ b 0:0, b_a 0:0 + |6 sort bitvec 1 + |7 uext 3 2 8 + |8 eq 6 7 4 + |9 not 6 8 + |10 bad 9 ; assert_ @ assert 0:0 + |""".stripMargin + + assert(toBotr2Str(src) == expected) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala new file mode 100644 index 00000000..015ac4a9 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala @@ -0,0 +1,224 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { + + def primop(op: String, resTpe: String, inTpes: Seq[String], consts: Seq[Int]): String = { + val inputs = inTpes.zipWithIndex.map { case (tpe, ii) => s" input i$ii : $tpe" }.mkString("\n") + val args = (inTpes.zipWithIndex.map { case (_, ii) => s"i$ii" } ++ consts.map(_.toString)).mkString(", ") + val src = + s"""circuit m: + | module m: + |$inputs + | output res: $resTpe + | res <= $op($args) + | + |""".stripMargin + val sys = toSys(src) + assert(sys.signals.length == 1) + sys.signals.head.e.toString + } + + def primop(signed: Boolean, op: String, resWidth: Int, inWidth: Seq[Int], consts: Seq[Int] = List(), + resAlwaysUnsigned: Boolean = false): String = { + val tpe = if(signed) "SInt" else "UInt" + val resTpe = if(resAlwaysUnsigned) "UInt" else tpe + val inTpes = inWidth.map(w => s"$tpe<$w>") + primop(op, s"$resTpe<$resWidth>", inTpes, consts) + } + + it should "correctly translate the add primitive operation with different operand sizes" in { + assert(primop(false, "add", 5, List(3, 5)) == "add(zext(i0, 3), zext(i1, 1))[4:0]") + assert(primop(false, "add", 5, List(3, 4)) == "add(zext(i0, 2), zext(i1, 1))") + assert(primop(true, "add", 5, List(3, 5)) == "add(sext(i0, 3), sext(i1, 1))[4:0]") + assert(primop(true, "add", 5, List(3, 4)) == "add(sext(i0, 2), sext(i1, 1))") + + // could be simplified to just `add(i0, i1)` + assert(primop(false, "add", 8, List(8, 8)) == "add(zext(i0, 1), zext(i1, 1))[7:0]") + } + + it should "correctly translate the `add` primitive operation" in { + assert(primop(false, "add", 8, List(7, 7)) == "add(zext(i0, 1), zext(i1, 1))") + } + + it should "correctly translate the `sub` primitive operation" in { + assert(primop(false, "sub", 8, List(7, 7)) == "sub(zext(i0, 1), zext(i1, 1))") + } + + it should "correctly translate the `mul` primitive operation" in { + assert(primop(false, "mul", 8, List(4, 4)) == "mul(zext(i0, 4), zext(i1, 4))") + } + + it should "correctly translate the `div` primitive operation" in { + // division is a little bit more complicated because the result of division by zero is undefined + assert(primop(false, "div", 8, List(8, 8)) == + "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))") + assert(primop(false, "div", 8, List(8, 4)) == + "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))") + + // signed division increases result width by 1 + assert(primop(true, "div", 8, List(7, 7)) == + "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))") + assert(primop(true, "div", 8, List(7, 4)) + == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))") + } + + it should "correctly translate the `rem` primitive operation" in { + // rem can decrease the size of operands, but we should only do that decrease on the result + assert(primop(false, "rem", 4, List(4, 8)) == "urem(zext(i0, 4), i1)[3:0]") + assert(primop(false, "rem", 4, List(8, 4)) == "urem(i0, zext(i1, 4))[3:0]") + assert(primop(true, "rem", 4, List(4, 8)) == "srem(sext(i0, 4), i1)[3:0]") + assert(primop(true, "rem", 4, List(8, 4)) == "srem(i0, sext(i1, 4))[3:0]") + // TODO: add test to make sure we are using the correct mod/rem operation for signed and unsigned + // https://groups.google.com/g/stp-users/c/od43h8q5RSI has some tests that we could copy and + // use with a SMT solver + } + + it should "correctly translate the comparison primitive operations" in { + // some comparisons are represented as the negation of others + assert(primop(false, "lt", 1, List(8, 8)) == "not(ugeq(i0, i1))") + assert(primop(false, "leq", 1, List(8, 8)) == "not(ugt(i0, i1))") + assert(primop(false, "gt", 1, List(8, 8)) == "ugt(i0, i1)") + assert(primop(false, "geq", 1, List(8, 8)) == "ugeq(i0, i1)") + assert(primop(false, "eq", 1, List(8, 8)) == "eq(i0, i1)") + assert(primop(false, "neq", 1, List(8, 8)) == "not(eq(i0, i1))") + + assert(primop(true, "lt", 1, List(8, 8), resAlwaysUnsigned = true) == "not(sgeq(i0, i1))") + assert(primop(true, "leq", 1, List(8, 8), resAlwaysUnsigned = true) == "not(sgt(i0, i1))") + assert(primop(true, "gt", 1, List(8, 8), resAlwaysUnsigned = true) == "sgt(i0, i1)") + assert(primop(true, "geq", 1, List(8, 8), resAlwaysUnsigned = true) == "sgeq(i0, i1)") + assert(primop(true, "eq", 1, List(8, 8), resAlwaysUnsigned = true) == "eq(i0, i1)") + assert(primop(true, "neq", 1, List(8, 8), resAlwaysUnsigned = true) == "not(eq(i0, i1))") + + // it should always extend the width to the max of both + assert(primop(false, "gt", 1, List(7, 8)) == "ugt(zext(i0, 1), i1)") + } + + it should "correctly translate the `pad` primitive operation" in { + // firrtl pad takes new width as argument, whereas the smt zext takes the number of bits to extend by + assert(primop(false, "pad", 8, List(3), List(8)) == "zext(i0, 5)") + assert(primop(false, "pad", 8, List(3), List(5)) == "zext(zext(i0, 2), 3)") + + // there is no negative padding, instead the result is just e + assert(primop(false, "pad", 3, List(3), List(2)) == "i0") + + assert(primop(true, "pad", 8, List(3), List(8)) == "sext(i0, 5)") + assert(primop(true, "pad", 8, List(3), List(5)) == "sext(sext(i0, 2), 3)") + } + + it should "correctly translate the asX primitive operations" in { + // these are all essentially no-ops + assert(primop(false, "asUInt", 3, List(3)) == "i0") + assert(primop(true, "asSInt", 3, List(3)) == "i0") + } + + it should "correctly translate the `shl` primitive operation" in { + assert(primop(false, "shl", 6, List(3), List(3)) == "concat(i0, 3'b0)") + assert(primop(true, "shl", 6, List(3), List(3)) == "concat(i0, 3'b0)") + assert(primop(false, "shl", 3, List(3), List(0)) == "i0") + } + + it should "correctly translate the `shr` primitive operation" in { + assert(primop(false, "shr", 6, List(9), List(3)) == "i0[8:3]") + assert(primop(true, "shr", 6, List(9), List(3)) == "i0[8:3]") + + // "If n is greater than or equal to the bit-width of e, + // the resulting value will be zero for unsigned types and the sign bit for signed types." + assert(primop(false, "shr", 1, List(3), List(3)) == "1'b0") + assert(primop(false, "shr", 1, List(3), List(4)) == "1'b0") + assert(primop(true, "shr", 1, List(3), List(3)) == "i0[2]") + assert(primop(true, "shr", 1, List(3), List(4)) == "i0[2]") + } + + it should "correctly translate the `dshl` primitive operation" in { + assert(primop(false, "dshl", 31, List(16, 4)) == "logical_shift_left(zext(i0, 15), zext(i1, 27))") + assert(primop(false, "dshl", 19, List(16, 2)) == "logical_shift_left(zext(i0, 3), zext(i1, 17))") + assert(primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) == + "logical_shift_left(sext(i0, 3), zext(i1, 17))") + } + + it should "correctly translate the `dshr` primitive operation" in { + assert(primop(false, "dshr", 16, List(16, 4)) == "logical_shift_right(i0, zext(i1, 12))") + assert(primop(false, "dshr", 16, List(16, 2)) == "logical_shift_right(i0, zext(i1, 14))") + assert(primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) == + "arithmetic_shift_right(i0, zext(i1, 14))") + } + + it should "correctly translate the `cvt` primitive operation" in { + // for signed operands, this is a no-op + assert(primop(true, "cvt", 3, List(3)) == "i0") + + // for unsigned, a zero is prepended + assert(primop("cvt", "SInt<16>", List("UInt<15>"), List()) == "concat(1'b0, i0)") + assert(primop("cvt", "SInt<16>", List("UInt<14>"), List()) == "sext(concat(1'b0, i0), 1)") + } + + it should "correctly translate the `neg` primitive operation" in { + assert(primop(true, "neg", 4, List(3)) == "neg(sext(i0, 1))") + assert(primop("neg", "SInt<4>", List("UInt<3>"), List()) == "neg(zext(i0, 1))") + } + + it should "correctly translate the `not` primitive operation" in { + assert(primop(false, "not", 4, List(4)) == "not(i0)") + assert(primop("not", "UInt<4>", List("SInt<4>"), List()) == "not(i0)") + } + + it should "correctly translate the binary bitwise primitive operations" in { + assert(primop(false, "and", 4, List(4, 3)) == "and(i0, zext(i1, 1))") + assert(primop("and", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "and(i0, sext(i1, 1))") + + assert(primop(false, "or", 4, List(4, 3)) == "or(i0, zext(i1, 1))") + assert(primop("or", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "or(i0, sext(i1, 1))") + + assert(primop(false, "xor", 4, List(4, 3)) == "xor(i0, zext(i1, 1))") + assert(primop("xor", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "xor(i0, sext(i1, 1))") + } + + it should "correctly translate the bitwise reduction primitive operation" in { + // zero width special cases are removed by the firrtl compiler + assert(primop(false, "andr", 1, List(0)) == "1'b1") + assert(primop(false, "orr", 1, List(0)) == "redor(1'b0)") + assert(primop(false, "xorr", 1, List(0)) == "redxor(1'b0)") + + assert(primop(false, "andr", 1, List(3)) == "redand(i0)") + assert(primop(true, "andr", 1, List(3), resAlwaysUnsigned = true) == "redand(i0)") + + assert(primop(false, "orr", 1, List(3)) == "redor(i0)") + assert(primop(true, "orr", 1, List(3), resAlwaysUnsigned = true) == "redor(i0)") + + assert(primop(false, "xorr", 1, List(3)) == "redxor(i0)") + assert(primop(true, "xorr", 1, List(3), resAlwaysUnsigned = true) == "redxor(i0)") + } + + it should "correctly translate the `cat` primitive operation" in { + assert(primop(false, "cat", 7, List(4, 3)) == "concat(i0, i1)") + assert(primop(true, "cat", 7, List(4, 3), resAlwaysUnsigned = true) == "concat(i0, i1)") + } + + it should "correctly translate the `bits` primitive operation" in { + assert(primop(false, "bits", 1, List(4), List(2,2)) == "i0[2]") + assert(primop(false, "bits", 2, List(4), List(2,1)) == "i0[2:1]") + assert(primop(false, "bits", 1, List(4), List(2,1)) == "i0[2:1][0]") + assert(primop(false, "bits", 3, List(4), List(2,1)) == "zext(i0[2:1], 1)") + + assert(primop(true, "bits", 1, List(4), List(2,2), resAlwaysUnsigned = true) == "i0[2]") + assert(primop(true, "bits", 2, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1]") + assert(primop(true, "bits", 1, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1][0]") + assert(primop(true, "bits", 3, List(4), List(2,1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)") + } + + it should "correctly translate the `head` primitive operation" in { + // "The result of the head operation are the n most significant bits of e" + assert(primop(false, "head", 1, List(4), List(1)) == "i0[3]") + assert(primop(false, "head", 1, List(5), List(1)) == "i0[4]") + assert(primop(false, "head", 3, List(5), List(3)) == "i0[4:2]") + } + + it should "correctly translate the `tail` primitive operation" in { + // "The tail operation truncates the n most significant bits from e" + assert(primop(false, "tail", 3, List(4), List(1)) == "i0[2:0]") + assert(primop(false, "tail", 4, List(5), List(1)) == "i0[3:0]") + assert(primop(false, "tail", 2, List(5), List(3)) == "i0[1:0]") + } +}
\ No newline at end of file diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala new file mode 100644 index 00000000..ca7974c5 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala @@ -0,0 +1,314 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +import firrtl.{MemoryArrayInit, MemoryScalarInit, Utils} + +private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { + behavior of "ModuleToTransitionSystem.run" + + + it should "model registers as state" in { + // if a signal is invalid, it could take on an arbitrary value in that cycle + val src = + """circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input en : UInt<1> + | input in : UInt<8> + | output out : UInt<8> + | + | reg r : UInt<8>, clock with : (reset => (reset, UInt<8>(0))) + | when en: + | r <= in + | out <= r + | + |""".stripMargin + val sys = toSys(src) + + assert(sys.signals.length == 2) + + // the when is translated as a ITE + val genSignal = sys.signals.filterNot(_.name == "out").head + assert(genSignal.e.toString == "ite(en, in, r)") + + // the reset is synchronous + val r = sys.states.head + assert(r.sym.name == "r") + assert(r.init.isEmpty, "we are not using any preset, so the initial register content is arbitrary") + assert(r.next.get.toString == s"ite(reset, 8'b0, ${genSignal.name})") + } + + private def memCircuit(depth: Int = 32) = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<${Utils.getUIntWidth(depth)}> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => $depth + | reader => r + | writer => w + | read-latency => 0 + | write-latency => 1 + | read-under-write => new + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + it should "model memories as state" in { + val sys = toSys(memCircuit()) + + assert(sys.signals.length == 9-2+1, "9 connects - 2 clock connects + 1 combinatorial read port") + + val sig = sys.signals.map(s => s.name -> s.e).toMap + + // masks and enables should all be true + val True = BVLiteral(1, 1) + assert(sig("m.w.mask") == True) + assert(sig("m.w.en") == True) + assert(sig("m.r.en") == True) + + // read data should always be enabled + assert(sig("m.r.data").toString == "m[m.r.addr]") + + // the memory is modelled as a state + val m = sys.states.find(_.sym.name == "m").get + assert(m.sym.isInstanceOf[ArraySymbol]) + val sym = m.sym.asInstanceOf[ArraySymbol] + assert(sym.indexWidth == 5) + assert(sym.dataWidth == 8) + assert(m.init.isEmpty) + //assert(m.next.get.toString.contains("m[m.w.addr := m.w.data]")) + assert(m.next.get.toString == "m[m.w.addr := m.w.data]") + } + + it should "support scalar initialization of a memory to 0" in { + val sys = toSys(memCircuit(), memInit = Map("m" -> MemoryScalarInit(0))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 32)") + } + + it should "support scalar initialization of a memory to 127" in { + val sys = toSys(memCircuit(31), memInit = Map("m" -> MemoryScalarInit(127))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b1111111] x 32)") + } + + it should "support array initialization of a memory to Seq(0, 1, 2, 3)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(0, 1, 2, 3)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 4)[2'b1 := 8'b1][2'b10 := 8'b10][2'b11 := 8'b11]") + } + + it should "support array initialization of a memory to Seq(1, 0, 1, 0)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(1, 0, 1, 0)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b1] x 4)[2'b1 := 8'b0][2'b11 := 8'b0]") + } + + it should "support array initialization of a memory to Seq(1, 0, 0, 0)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(1, 0, 0, 0)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 4)[2'b0 := 8'b1]") + } + + it should "support array initialization from a file" ignore { + assert(false, "TODO") + } + + it should "support memories with registered read port" in { + def src(readUnderWrite: String) = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w1, w2 + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w1.clk <= clock + | m.w1.mask <= UInt(1) + | m.w1.en <= UInt(1) + | m.w1.data <= in + | m.w1.addr <= addr + | m.w2.clk <= clock + | m.w2.mask <= UInt(1) + | m.w2.en <= UInt(1) + | m.w2.data <= in + | m.w2.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + + val oldValue = toSys(src("old")) + val oldMData = oldValue.states.find(_.sym.name == "m.r.data").get + assert(oldMData.sym.toString == "m.r.data") + assert(oldMData.next.get.toString == "m[m.r.addr]", "we just need to read the current value") + val readDataSignal = oldValue.signals.find(_.name == "m.r.data") + assert(readDataSignal.isEmpty, s"${readDataSignal.map(_.toString)} should not exist") + + val undefinedValue = toSys(src("undefined")) + val undefinedMData = undefinedValue.states.find(_.sym.name == "m.r.data").get + assert(undefinedMData.sym.toString == "m.r.data") + val undefined = "RANDOM.m_r_read_under_write_undefined" + assert(undefinedMData.next.get.toString == + s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])", + "randomize result if there is a write") + } + + it should "support memories with potential write-write conflicts" in { + val src = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w1, w2 + | read-latency => 1 + | write-latency => 1 + | read-under-write => undefined + | + | m.w1.clk <= clock + | m.w1.mask <= UInt(1) + | m.w1.en <= UInt(1) + | m.w1.data <= in + | m.w1.addr <= addr + | m.w2.clk <= clock + | m.w2.mask <= UInt(1) + | m.w2.en <= UInt(1) + | m.w2.data <= in + | m.w2.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + + val sys = toSys(src) + val m = sys.states.find(_.sym.name == "m").get + + val regularUpdate = "m[m.w1.addr := m.w1.data][m.w2.addr := m.w2.data]" + val collision = "eq(m.w1.addr, m.w2.addr)" + val collisionUpdate = "m[m.w1.addr := RANDOM.m_w1_write_write_collision]" + + assert(m.next.get.toString == s"ite($collision, $collisionUpdate, $regularUpdate)") + } + + it should "model invalid signals as inputs" in { + // if a signal is invalid, it could take on an arbitrary value in that cycle + val src = + """circuit m: + | module m: + | input en : UInt<1> + | output o : UInt<8> + | o is invalid + | when en: + | o <= UInt<8>(0) + |""".stripMargin + val sys = toSys(src) + assert(sys.inputs.length == 2) + val random = sys.inputs.filter(_.name.contains("RANDOM")) + assert(random.length == 1) + assert(random.head.width == 8) + } + + it should "throw an error on async reset" in { + val err = intercept[AsyncResetException] { + toSys( + """circuit m: + | module m: + | input reset : AsyncReset + |""".stripMargin + ) + } + assert(err.getMessage.contains("reset")) + } + + it should "throw an error on casting to async reset" in { + val err = intercept[AssertionError] { + toSys( + """circuit m: + | module m: + | input reset : UInt<1> + | node async = asAsyncReset(reset) + |""".stripMargin + ) + } + assert(err.getMessage.contains("reset")) + } + + it should "throw an error on multiple clocks" in { + val err = intercept[MultiClockException] { + toSys( + """circuit m: + | module m: + | input clk1 : Clock + | input clk2 : Clock + |""".stripMargin + ) + } + assert(err.getMessage.contains("clk1, clk2")) + } + + it should "throw an error on using a clock as uInt" in { + // While this could potentially be supported in the future, for now we do not allow + // a clock to be used for anything besides updating registers and memories. + val err = intercept[AssertionError] { + toSys( + """circuit m: + | module m: + | input clk : Clock + | output o : UInt<1> + | o <= asUInt(clk) + | + |""".stripMargin + ) + } + assert(err.getMessage.contains("clk")) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala new file mode 100644 index 00000000..6bfb5437 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala @@ -0,0 +1,38 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +import firrtl.annotations.Annotation +import firrtl.{MemoryInitValue, ir} +import firrtl.stage.{Forms, TransformManager} +import org.scalatest.flatspec.AnyFlatSpec + +private abstract class SMTBackendBaseSpec extends AnyFlatSpec { + private val dependencies = Forms.LowForm + private val compiler = new TransformManager(dependencies) + + protected def compile(src: String, annos: Seq[Annotation] = List()): ir.Circuit = { + val c = firrtl.Parser.parse(src) + compiler.runTransform(firrtl.CircuitState(c, annos)).circuit + } + + protected def toSys(src: String, mod: String = "m", presetRegs: Set[String] = Set(), + memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = { + val circuit = compile(src) + val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module] + // println(module.serialize) + new ModuleToTransitionSystem().run(module, presetRegs = presetRegs, memInit = memInit) + } + + protected def toBotr2(src: String, mod: String = "m"): Iterable[String] = + Btor2Serializer.serialize(toSys(src, mod)) + + protected def toBotr2Str(src: String, mod: String = "m"): String = + toBotr2(src, mod).mkString("\n") + "\n" + + protected def toSMTLib(src: String, mod: String = "m"): Iterable[String] = + SMTTransitionSystemEncoder.encode(toSys(src, mod)).map(SMTLibSerializer.serialize) + + protected def toSMTLibStr(src: String, mod: String = "m"): String = + toSMTLib(src, mod).mkString("\n") + "\n" +}
\ No newline at end of file diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala new file mode 100644 index 00000000..7193474d --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala @@ -0,0 +1,66 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class SMTLibSpec extends SMTBackendBaseSpec { + + it should "convert a hello world module" in { + val src = + """circuit m: + | module m: + | input clock: Clock + | input a: UInt<8> + | output b: UInt<16> + | b <= a + | assert(clock, eq(a, b), UInt(1), "") + |""".stripMargin + + val expected = + """(declare-sort m_s 0) + |; firrtl-smt2-input a 8 + |(declare-fun a_f (m_s) (_ BitVec 8)) + |; firrtl-smt2-output b 16 + |(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state))) + |; firrtl-smt2-assert assert_ 1 + |(define-fun assert__f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state))) + |(define-fun m_t ((state m_s) (state_n m_s)) Bool true) + |(define-fun m_i ((state m_s)) Bool true) + |(define-fun m_a ((state m_s)) Bool (assert__f state)) + |(define-fun m_u ((state m_s)) Bool true) + |""".stripMargin + + assert(toSMTLibStr(src) == expected) + } + + it should "include FileInfo in the output" in { + val src = + """circuit m: @[circuit 0:0] + | module m: @[module 0:0] + | input clock: Clock @[clock 0:0] + | input a: UInt<8> @[a 0:0] + | output b: UInt<16> @[b 0:0] + | b <= a @[b_a 0:0] + | assert(clock, eq(a, b), UInt(1), "") @[assert 0:0] + |""".stripMargin + + val expected = + """; @ module 0:0 + |(declare-sort m_s 0) + |; firrtl-smt2-input a 8 + |; @ a 0:0 + |(declare-fun a_f (m_s) (_ BitVec 8)) + |; firrtl-smt2-output b 16 + |; @ b 0:0, b_a 0:0 + |(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state))) + |; firrtl-smt2-assert assert_ 1 + |; @ assert 0:0 + |(define-fun assert__f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state))) + |(define-fun m_t ((state m_s) (state_n m_s)) Bool true) + |(define-fun m_i ((state m_s)) Bool true) + |(define-fun m_a ((state m_s)) Bool (assert__f state)) + |(define-fun m_u ((state m_s)) Bool true) + |""".stripMargin + + assert(toSMTLibStr(src) == expected) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala new file mode 100644 index 00000000..4c6901ea --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala @@ -0,0 +1,46 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import firrtl.annotations.CircuitTarget +import firrtl.backends.experimental.smt.{GlobalClockAnnotation, StutteringClockTransform} +import firrtl.options.Dependency +import firrtl.stage.RunFirrtlTransformAnnotation + +class AsyncResetSpec extends EndToEndSMTBaseSpec { + def annos(name: String) = Seq( + RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]), + GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock"))) + + "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs(RequiresZ3) in { + def in(resetType: String) = + s"""circuit AsyncReset00: + | module AsyncReset00: + | input global_clock: Clock + | input c: Clock + | input reset: $resetType + | input preset: AsyncReset + | + | ; a register with async reset + | reg r: UInt<4>, c with: (reset => (reset, UInt(3))) + | + | ; a counter/toggler connected to the clock c + | reg count: UInt<1>, c with: (reset => (preset, UInt(0))) + | count <= add(count, UInt(1)) + | + | ; the past machinery and the assertion uses the global clock + | reg past_valid: UInt<1>, global_clock with: (reset => (preset, UInt(0))) + | past_valid <= UInt(1) + | reg past_r: UInt<4>, global_clock + | past_r <= r + | reg past_count: UInt<1>, global_clock + | past_count <= count + | + | ; can the value of r change without the count changing? + | assert(global_clock, or(not(eq(count, past_count)), eq(r, past_r)), past_valid, "count = past(count) |-> r = past(r)") + |""".stripMargin + test(in("AsyncReset"), MCFail(1), kmax=2, annos=annos("AsyncReset00")) + test(in("UInt<1>"), MCSuccess, kmax=2, annos=annos("AsyncReset00")) + } + +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala new file mode 100644 index 00000000..2227719b --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala @@ -0,0 +1,230 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import java.io.{File, PrintWriter} + +import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation} +import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter} +import firrtl.options.TargetDirAnnotation +import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlStage, OutputFileAnnotation, RunFirrtlTransformAnnotation} +import firrtl.util.BackendCompilationUtilities +import logger.{LazyLogging, LogLevel, LogLevelAnnotation} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.must.Matchers + +import scala.sys.process._ + +class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { + "we" should "check if Z3 is available" taggedAs(RequiresZ3) in { + val log = ProcessLogger(_ => (), logger.warn(_)) + val ret = Process(Seq("which", "z3")).run(log).exitValue() + if(ret != 0) { + logger.error( + """The z3 SMT-Solver seems not to be installed. + |You can exclude the end-to-end smt backend tests which rely on z3 like this: + |sbt testOnly -- -l RequiresZ3 + |""".stripMargin) + } + assert(ret == 0) + } + + "Z3" should "be available in version 4" taggedAs(RequiresZ3) in { + assert(Z3ModelChecker.getZ3Version.startsWith("4.")) + } + + "a simple combinatorial check" should "pass" taggedAs(RequiresZ3) in { + val in = + """circuit CC00: + | module CC00: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, lt(add(a, b), UInt(3)), UInt(1), "a + b < 3") + |""".stripMargin + test(in, MCSuccess) + } + + "a simple combinatorial check" should "fail immediately" taggedAs(RequiresZ3) in { + val in = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, gt(add(a, b), UInt(3)), UInt(1), "a + b > 3") + |""".stripMargin + test(in, MCFail(0)) + } + + "adding the right assumption" should "make a test pass" taggedAs(RequiresZ3) in { + val in0 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + |""".stripMargin + val in1 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + | assume(c, neq(a, UInt(0)), UInt(1), "a != 0") + |""".stripMargin + test(in0, MCFail(0)) + test(in1, MCSuccess) + + val in2 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | input en: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + | assume(c, neq(a, UInt(0)), en, "a != 0 if en") + |""".stripMargin + test(in2, MCFail(0)) + } + + "a register connected to preset reset" should "be initialized with the reset value" taggedAs(RequiresZ3) in { + def in(rEq: Int) = + s"""circuit Preset00: + | module Preset00: + | input c: Clock + | input preset: AsyncReset + | reg r: UInt<4>, c with: (reset => (preset, UInt(3))) + | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq") + |""".stripMargin + test(in(3), MCSuccess, kmax = 1) + test(in(2), MCFail(0)) + } + + "a register's initial value" should "should not change" taggedAs(RequiresZ3) in { + val in = + """circuit Preset00: + | module Preset00: + | input c: Clock + | input preset: AsyncReset + | + | ; the past value of our register will only be valid in the 1st unrolling + | reg past_valid: UInt<1>, c with: (reset => (preset, UInt(0))) + | past_valid <= UInt(1) + | + | reg r: UInt<4>, c + | reg r_past: UInt<4>, c + | r_past <= r + | assert(c, eq(r, r_past), past_valid, "past_valid => r == r_past") + |""".stripMargin + test(in, MCSuccess, kmax = 2) + } +} + +abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers { + def test(src: String, expected: MCResult, kmax: Int = 0, clue: String = "", annos: Seq[Annotation] = Seq()): Unit = { + expected match { + case MCFail(k) => assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)") + case _ => + } + val fir = firrtl.Parser.parse(src) + val name = fir.main + val testDir = BackendCompilationUtilities.createTestDirectory("EndToEndSMT." + name) + // we automagically add a preset annotation if an input called preset exists + val presetAnno = if(!src.contains("input preset")) { None } else { + Some(PresetAnnotation(CircuitTarget(name).module(name).ref("preset"))) + } + val res = (new FirrtlStage).execute(Array(), Seq( + LogLevelAnnotation(LogLevel.Error), // silence warnings for tests + RunFirrtlTransformAnnotation(new SMTLibEmitter), + RunFirrtlTransformAnnotation(new Btor2Emitter), + FirrtlCircuitAnnotation(fir), + TargetDirAnnotation(testDir.getAbsolutePath) + ) ++ presetAnno ++ annos) + assert(res.collectFirst{ case _: OutputFileAnnotation => true }.isDefined) + val r = Z3ModelChecker.bmc(testDir, name, kmax) + assert(r == expected, clue + "\n" + s"$testDir") + } +} + +/** Minimal implementation of a Z3 based bounded model checker. + * A more complete version of this with better use feedback should eventually be provided by a + * chisel3 formal verification library. Do not use this implementation outside of the firrtl test suite! + * */ +private object Z3ModelChecker extends LazyLogging { + def getZ3Version: String = { + val (out, ret) = executeCmd("-version") + assert(ret == 0, "failed to call z3") + assert(out.startsWith("Z3 version"), s"$out does not start with 'Z3 version'") + val version = out.split(" ")(2) + version + } + + def bmc(testDir: File, main: String, kmax: Int): MCResult = { + assert(kmax >=0 && kmax < 50, "Trying to keep kmax in a reasonable range.") + val smtFile = new File(testDir, main + ".smt2") + val header = read(smtFile) + val steps = (0 to kmax).map(k => new File(testDir, main + s"_step$k.smt2")).zipWithIndex + steps.foreach { case (f,k) => + writeStep(f, main, header, k) + val success = executeStep(f.getAbsolutePath) + if(!success) return MCFail(k) + } + MCSuccess + } + + private def executeStep(filename: String): Boolean = { + val (out, ret) = executeCmd(filename) + assert(ret == 0, s"expected success (0), not $ret: `$out`\nz3 $filename") + assert(out == "sat" || out == "unsat", s"Unexpected output: $out") + out == "unsat" + } + + private def executeCmd(cmd: String): (String, Int) = { + var out = "" + val log = ProcessLogger(s => out = s, logger.warn(_)) + val ret = Process(Seq("z3", cmd)).run(log).exitValue() + (out, ret) + } + + private def writeStep(f: File, main: String, header: Iterable[String], k: Int): Unit = { + val pw = new PrintWriter(f) + val lines = header ++ step(main, k) ++ List("(check-sat)") + lines.foreach(pw.println) + pw.close() + } + + private def step(main: String, k: Int): Iterable[String] = { + // define all states + (0 to k).map(ii => s"(declare-fun s$ii () $main$StateTpe)") ++ + // assert that init holds in state 0 + List(s"(assert ($main$Init s0))") ++ + // assert transition relation + (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii+1}))") ++ + // assert that assumptions hold in all states + (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++ + // assert that assertions hold for all but last state + (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++ + // check to see if we can violate the assertions in the last state + List(s"(assert (not ($main$Asserts s$k)))") + } + + private def read(f: File): Iterable[String] = { + val source = scala.io.Source.fromFile(f) + try source.getLines().toVector finally source.close() + } + + // the following suffixes have to match the ones in [[SMTTransitionSystemEncoder]] + private val Transition = "_t" + private val Init = "_i" + private val Asserts = "_a" + private val Assumes = "_u" + private val StateTpe = "_s" +} + +private sealed trait MCResult +private case object MCSuccess extends MCResult +private case class MCFail(k: Int) extends MCResult diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala new file mode 100644 index 00000000..10de9cda --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala @@ -0,0 +1,107 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import firrtl.annotations.{CircuitTarget, MemoryArrayInitAnnotation, MemoryScalarInitAnnotation} + +class MemorySpec extends EndToEndSMTBaseSpec { + private def registeredTestMem(name: String, cmds: String, readUnderWrite: String): String = + registeredTestMem(name, cmds.split("\n"), readUnderWrite) + private def registeredTestMem(name: String, cmds: Iterable[String], readUnderWrite: String): String = + s"""circuit $name: + | module $name: + | input reset : UInt<1> + | input clock : Clock + | input preset: AsyncReset + | input write_addr : UInt<5> + | input read_addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= write_addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= read_addr + | + | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0))) + | cycle <= add(cycle, UInt(1)) + | node past_valid = geq(cycle, UInt(1)) + | + | ${cmds.mkString("\n ")} + |""".stripMargin + + "Registered test memory" should "return written data after two cycles" taggedAs(RequiresZ3) in { + val cmds = + """node past_past_valid = geq(cycle, UInt(2)) + |reg past_in: UInt<8>, clock + |past_in <= in + |reg past_past_in: UInt<8>, clock + |past_past_in <= past_in + |reg past_write_addr: UInt<5>, clock + |past_write_addr <= write_addr + | + |assume(clock, eq(read_addr, past_write_addr), past_valid, "read_addr = past(write_addr)") + |assert(clock, eq(out, past_past_in), past_past_valid, "out = past(past(in))") + |""".stripMargin + test(registeredTestMem("Mem00", cmds, "old"), MCSuccess, kmax = 3) + } + + private def readOnlyMem(pred: String, num: Int) = + s"""circuit Mem0$num: + | module Mem0$num: + | input c : Clock + | input read_addr : UInt<2> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 4 + | reader => r + | read-latency => 0 + | write-latency => 1 + | read-under-write => new + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= read_addr + | + | assert(c, $pred, UInt(1), "") + |""".stripMargin + private def m(num: Int) = CircuitTarget(s"Mem0$num").module(s"Mem0$num").ref("m") + + "read-only memory" should "always return 0" taggedAs(RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax=2, + annos=Seq(MemoryScalarInitAnnotation(m(1), 0))) + } + + "read-only memory" should "not always return 1" taggedAs(RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax=2, + annos=Seq(MemoryScalarInitAnnotation(m(2), 0))) + } + + "read-only memory" should "always return 1 or 2" taggedAs(RequiresZ3) in { + test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), MCSuccess, kmax=2, + annos=Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1)))) + } + + "read-only memory" should "not always return 1 or 2 or 3" taggedAs(RequiresZ3) in { + test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), MCFail(0), kmax=2, + annos=Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3)))) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala new file mode 100644 index 00000000..d633a1a0 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala @@ -0,0 +1,9 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import org.scalatest.Tag + +// To disable tests that require the Z3 SMT solver to be installed use the following: +// `sbt testOnly -- -l RequiresZ3` +object RequiresZ3 extends Tag("RequiresZ3") diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala new file mode 100644 index 00000000..cbf194dd --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala @@ -0,0 +1,46 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import java.io.File + +import firrtl.stage.{FirrtlStage, OutputFileAnnotation} +import firrtl.util.BackendCompilationUtilities +import logger.LazyLogging +import org.scalatest.flatspec.AnyFlatSpec + +import scala.sys.process.{Process, ProcessLogger} + +/** compiles the regression tests to SMTLib and parses the result with z3 */ +class SMTCompilationTest extends AnyFlatSpec with LazyLogging { + it should "generate valid SMTLib for AddNot" taggedAs(RequiresZ3) in { compileAndParse("AddNot") } + it should "generate valid SMTLib for FPU" taggedAs(RequiresZ3) in { compileAndParse("FPU") } + // we get a stack overflow in Scala 2.11 because of a deeply nested and(...) expression in the sequencer + it should "generate valid SMTLib for HwachaSequencer" taggedAs(RequiresZ3) ignore { compileAndParse("HwachaSequencer") } + it should "generate valid SMTLib for ICache" taggedAs(RequiresZ3) in { compileAndParse("ICache") } + it should "generate valid SMTLib for Ops" taggedAs(RequiresZ3) in { compileAndParse("Ops") } + // TODO: enable Rob test once we support more than 2 write ports on a memory + it should "generate valid SMTLib for Rob" taggedAs(RequiresZ3) ignore { compileAndParse("Rob") } + it should "generate valid SMTLib for RocketCore" taggedAs(RequiresZ3) in { compileAndParse("RocketCore") } + + private def compileAndParse(name: String): Unit = { + val testDir = BackendCompilationUtilities.createTestDirectory(name + "-smt") + val inputFile = new File(testDir, s"${name}.fir") + BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", inputFile) + + val args = Array( + "-ll", "error", // surpress warnings to keep test output clean + "--target-dir", testDir.toString, + "-i", inputFile.toString, + "-E", "experimental-smt2" + // "-fct", "firrtl.backends.experimental.smt.StutteringClockTransform" + ) + val res = (new FirrtlStage).execute(args, Seq()) + val fileName = res.collectFirst{ case OutputFileAnnotation(file) => file }.get + + val smtFile = testDir.toString + "/" + fileName + ".smt2" + val log = ProcessLogger(_ => (), logger.error(_)) + val z3Ret = Process(Seq("z3", smtFile)).run(log).exitValue() + assert(z3Ret == 0, s"Failed to parse SMTLib file $smtFile generated for $name") + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala new file mode 100644 index 00000000..8fa80b4c --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala @@ -0,0 +1,40 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +/** undefined values in firrtl are modelled as fresh auxiliary variables (inputs) */ +class UndefinedFirrtlSpec extends EndToEndSMTBaseSpec { + + "division by zero" should "result in an arbitrary value" taggedAs(RequiresZ3) in { + // the SMTLib spec defines the result of division by zero to be all 1s + // https://cs.nyu.edu/pipermail/smt-lib/2015/000977.html + def in(dEq: Int) = + s"""circuit CC00: + | module CC00: + | input c: Clock + | input a: UInt<2> + | input b: UInt<2> + | assume(c, eq(b, UInt(0)), UInt(1), "b = 0") + | node d = div(a, b) + | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq") + |""".stripMargin + // we try to assert that (d = a / 0) is any fixed value which should be false + (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"d = a / 0 = $ii") } + } + + // TODO: rem should probably also be undefined, but the spec isn't 100% clear here + + + "invalid signals" should "have an arbitrary values" taggedAs(RequiresZ3) in { + def in(aEq: Int) = + s"""circuit CC00: + | module CC00: + | input c: Clock + | wire a: UInt<2> + | a is invalid + | assert(c, eq(a, UInt($aEq)), UInt(1), "a = $aEq") + |""".stripMargin + // a should not be equivalent to any fixed value (0, 1, 2 or 3) + (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"a = $ii") } + } +} |
