diff options
Diffstat (limited to 'src')
21 files changed, 3050 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 432fa59e..ae9a7dad 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -14,7 +14,8 @@ import firrtl.PrimOps._ import firrtl.WrappedExpression._ import Utils._ import MemPortUtils.{memPortField, memType} -import firrtl.options.{HasShellOptions, CustomFileEmission, ShellOption, PhaseException} +import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter} +import firrtl.options.{CustomFileEmission, HasShellOptions, PhaseException, ShellOption} import firrtl.options.Viewer.view import firrtl.stage.{FirrtlFileAnnotation, FirrtlOptions, RunFirrtlTransformAnnotation, TransformManager} // Datastructures @@ -49,9 +50,14 @@ object EmitCircuitAnnotation extends HasShellOptions { EmitCircuitAnnotation(classOf[VerilogEmitter])) case "sverilog" => Seq(RunFirrtlTransformAnnotation(new SystemVerilogEmitter), EmitCircuitAnnotation(classOf[SystemVerilogEmitter])) + case "experimental-btor2" => Seq(RunFirrtlTransformAnnotation(new Btor2Emitter), + EmitCircuitAnnotation(classOf[Btor2Emitter])) + case "experimental-smt2" => Seq(RunFirrtlTransformAnnotation(new SMTLibEmitter), + EmitCircuitAnnotation(classOf[SMTLibEmitter])) case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)") }, helpText = "Run the specified circuit emitter (all modules in one file)", shortOption = Some("E"), + // the experimental options are intentionally excluded from the help message helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>") ) ) } diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala new file mode 100644 index 00000000..f7ab9927 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala @@ -0,0 +1,189 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import scala.collection.mutable + +private object Btor2Serializer { + def serialize(sys: TransitionSystem, skipOutput: Boolean = false): Iterable[String] = { + new Btor2Serializer().run(sys, skipOutput) + } +} + +private class Btor2Serializer private () { + private val symbols = mutable.HashMap[String, Int]() + private val lines = mutable.ArrayBuffer[String]() + private var index = 1 + + private def line(l: String): Int = { + val ii = index + lines += s"$ii $l" + index += 1 + ii + } + + private def comment(c: String): Unit = { lines += s"; $c" } + private def trailingComment(c: String): Unit = { + val lastLine = lines.last + val newLine = if(lastLine.contains(';')) { lastLine + " " + c} else { lastLine + " ; " + c } + lines(lines.size - 1) = newLine + } + + // bit vector type serialization + private val bitVecTypeCache = mutable.HashMap[Int, Int]() + + private def t(width: Int): Int = bitVecTypeCache.getOrElseUpdate(width, line(s"sort bitvec $width")) + + // bit vector expression serialization + private def s(expr: BVExpr): Int = expr match { + case BVLiteral(value, width) => lit(value, width) + case BVSymbol(name, _) => symbols(name) + case BVExtend(e, 0, _) => s(e) + case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by") + case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by") + case BVSlice(e, hi, lo) => + if (lo == 0 && hi == e.width - 1) { s(e) } else { + line(s"slice ${t(expr.width)} ${s(e)} $hi $lo") + } + case BVNot(BVEqual(a, b)) => binary("neq", expr.width, a, b) + case BVNot(BVNot(e)) => s(e) + case BVNot(e) => unary("not", expr.width, e) + case BVNegate(e) => unary("neg", expr.width, e) + case BVReduceAnd(e) => unary("redand", expr.width, e) + case BVReduceOr(e) => unary("redor", expr.width, e) + case BVReduceXor(e) => unary("redxor", expr.width, e) + case BVImplies(BVLiteral(v, 1), b) if v == 1 => s(b) + case BVImplies(a, b) => binary("implies", expr.width, a, b) + case BVEqual(a, b) => binary("eq", expr.width, a, b) + case ArrayEqual(a, b) => line(s"eq ${t(expr.width)} ${s(a)} ${s(b)}") + case BVComparison(Compare.Greater, a, b, false) => binary("ugt", expr.width, a, b) + case BVComparison(Compare.GreaterEqual, a, b, false) => binary("ugte", expr.width, a, b) + case BVComparison(Compare.Greater, a, b, true) => binary("sgt", expr.width, a, b) + case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b) + case BVOp(op, a, b) => binary(s(op), expr.width, a, b) + case BVConcat(a, b) => binary("concat", expr.width, a, b) + case ArrayRead(array, index) => + line(s"read ${t(expr.width)} ${s(array)} ${s(index)}") + case BVIte(cond, tru, fals) => + line(s"ite ${t(expr.width)} ${s(cond)} ${s(tru)} ${s(fals)}") + case r : BVRawExpr => + throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}") + } + + private def s(op: Op.Value): String = op match { + case Op.And => "and" + case Op.Or => "or" + case Op.Xor => "xor" + case Op.ArithmeticShiftRight => "sra" + case Op.ShiftRight => "srl" + case Op.ShiftLeft => "sll" + case Op.Add => "add" + case Op.Mul => "mul" + case Op.Sub => "sub" + case Op.SignedDiv => "sdiv" + case Op.UnsignedDiv => "udiv" + case Op.SignedMod => "smod" + case Op.SignedRem => "srem" + case Op.UnsignedRem => "urem" + } + + private def unary(op: String, width: Int, e: BVExpr): Int = line(s"$op ${t(width)} ${s(e)}") + + private def binary(op: String, width: Int, a: BVExpr, b: BVExpr): Int = + line(s"$op ${t(width)} ${s(a)} ${s(b)}") + + private def lit(value: BigInt, w: Int): Int = { + val typ = t(w) + lazy val mask = (BigInt(1) << w) - 1 + if (value == 0) line(s"zero $typ") + else if (value == 1) line(s"one $typ") + else if (value == mask) line(s"ones $typ") + else { + val digits = value.toString(2) + val padded = digits.reverse.padTo(w, '0').reverse + line(s"const $typ $padded") + } + } + + // array type serialization + private val arrayTypeCache = mutable.HashMap[(Int, Int), Int]() + + private def t(indexWidth: Int, dataWidth: Int): Int = + arrayTypeCache.getOrElseUpdate((indexWidth, dataWidth), line(s"sort array ${t(indexWidth)} ${t(dataWidth)}")) + + // array expression serialization + private def s(expr: ArrayExpr): Int = expr match { + case ArraySymbol(name, _, _) => symbols(name) + case ArrayStore(array, index, data) => + line(s"write ${t(expr.indexWidth, expr.dataWidth)} ${s(array)} ${s(index)} ${s(data)}") + case ArrayIte(cond, tru, fals) => + // println("WARN: ITE on array is probably not supported by btor2") + // While the spec does not seem to allow array ite, it seems to be supported in practice. + // It is essential to model memories, so any support in the wild should be fairly well tested. + line(s"ite ${t(expr.indexWidth, expr.dataWidth)} ${s(cond)} ${s(tru)} ${s(fals)}") + case ArrayConstant(e, _) => s(e) + case r : ArrayRawExpr => + throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}") + } + + private def s(expr: SMTExpr): Int = expr match { + case b: BVExpr => s(b) + case a: ArrayExpr => s(a) + } + + // serialize the type of the expression + private def t(expr: SMTExpr): Int = expr match { + case b: BVExpr => t(b.width) + case a: ArrayExpr => t(a.indexWidth, a.dataWidth) + } + + def run(sys: TransitionSystem, skipOutput: Boolean): Iterable[String] = { + def declare(name: String, expr: => Int): Unit = { + assert(!symbols.contains(name), s"Trying to redeclare `$name`") + val id = expr + symbols(name) = id + if (!skipOutput && sys.outputs.contains(name)) line(s"output $id ; $name") + if (sys.assumes.contains(name)) line(s"constraint $id ; $name") + if (sys.asserts.contains(name)){ + val invertedId = line(s"not ${t(1)} $id") + line(s"bad $invertedId ; $name") + } + if (sys.fair.contains(name)) line(s"fair $id ; $name") + // add trailing comment + sys.comments.get(name).foreach(trailingComment) + } + + // header + sys.header.foreach(comment) + + // declare inputs + sys.inputs.foreach { ii => + declare(ii.name, line(s"input ${t(ii.width)} ${ii.name}")) + } + + // define state init + sys.states.foreach { st => + // calculate init expression before declaring the state + // this is required by btormc (presumably to avoid cycles in the init expression) + val initId = st.init.map { init => comment(s"${st.sym}.init"); s(init) } + declare(st.sym.name, line(s"state ${t(st.sym)} ${st.sym.name}")) + st.init.foreach { init => line(s"init ${t(init)} ${s(st.sym)} ${initId.get}") } + } + + // define all other signals + sys.signals.foreach { signal => + declare(signal.name, s(signal.e)) + } + + // define state next + sys.states.foreach { st => + st.next.foreach { next => + comment(s"${st.sym}.next") + line(s"next ${t(next)} ${s(st.sym)} ${s(next)}") + } + } + + lines + } +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala new file mode 100644 index 00000000..0a223840 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala @@ -0,0 +1,182 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import firrtl.ir +import firrtl.PrimOps +import firrtl.passes.CheckWidths.WidthTooBig + +private trait TranslationContext { + def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, FirrtlExpressionSemantics.getWidth(tpe)) + def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe)) + def getRandom(width: Int): BVExpr +} + +private object FirrtlExpressionSemantics { + def getWidth(tpe: ir.Type): Int = tpe match { + case ir.UIntType(ir.IntWidth(w)) => w.toInt + case ir.SIntType(ir.IntWidth(w)) => w.toInt + case ir.ClockType => 1 + case ir.ResetType => 1 + case ir.AnalogType(ir.IntWidth(w)) => w.toInt + case other => throw new RuntimeException(s"Cannot handle type $other") + } + + def toSMT(e: ir.Expression)(implicit ctx: TranslationContext): BVExpr = { + val eSMT = e match { + case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts) + case r : ir.Reference => ctx.getReference(r.serialize, r.tpe) + case r : ir.SubField => ctx.getReference(r.serialize, r.tpe) + case r : ir.SubIndex => ctx.getReference(r.serialize, r.tpe) + case ir.UIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt) + case ir.SIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt) + case ir.Mux(cond, tval, fval, _) => + val width = List(tval, fval).map(getWidth).max + BVIte(toSMT(cond), toSMT(tval, width), toSMT(fval, width)) + case ir.ValidIf(cond, value, tpe) => + val tru = toSMT(value) + BVIte(toSMT(cond), tru, ctx.getRandom(tpe)) + } + assert(eSMT.width == getWidth(e), "We aim to always produce a SMT expression of the same width as the firrtl expression.") + eSMT + } + + /** Ensures that the result has the desired width by appropriately extending it. */ + def toSMT(e: ir.Expression, width: Int, allowNarrow: Boolean = false)(implicit ctx: TranslationContext): BVExpr = + forceWidth(toSMT(e), isSigned(e), width, allowNarrow) + + private def forceWidth(eSMT: BVExpr, eSigned: Boolean, width: Int, allowNarrow: Boolean = false): BVExpr = { + if(eSMT.width == width) { eSMT } + else if(width < eSMT.width) { + assert(allowNarrow, s"Narrowing from ${eSMT.width} bits to $width bits is not allowed!") + BVSlice(eSMT, width - 1, 0) + } else { + BVExtend(eSMT, width - eSMT.width, eSigned) + } + } + + // see "Primitive Operations" section in the Firrtl Specification + private def onPrim(op: ir.PrimOp, args: Seq[ir.Expression], consts: Seq[BigInt])(implicit ctx: TranslationContext): + BVExpr = { + (op, args, consts) match { + case (PrimOps.Add, Seq(e1, e2), _) => + val width = args.map(getWidth).max + 1 + BVOp(Op.Add, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Sub, Seq(e1, e2), _) => + val width = args.map(getWidth).max + 1 + BVOp(Op.Sub, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Mul, Seq(e1, e2), _) => + val width = args.map(getWidth).sum + BVOp(Op.Mul, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Div, Seq(num, den), _) => + val (width, op) = if(isSigned(num)) { + (getWidth(num) + 1, Op.SignedDiv) + } else { (getWidth(num), Op.UnsignedDiv) } + // "The result of a division where den is zero is undefined." + val undef = ctx.getRandom(width) + val denSMT = toSMT(den) + val denIsZero = BVEqual(denSMT, BVLiteral(0, denSMT.width)) + val numByDen = BVOp(op, toSMT(num, width), forceWidth(denSMT, isSigned(den), width)) + BVIte(denIsZero, undef, numByDen) + case (PrimOps.Div, Seq(num, den), _) if isSigned(num) => + val width = getWidth(num) + 1 + BVOp(Op.SignedDiv, toSMT(num, width), toSMT(den, width)) + case (PrimOps.Rem, Seq(num, den), _) => + val op = if(isSigned(num)) Op.SignedRem else Op.UnsignedRem + val width = args.map(getWidth).max + val resWidth = args.map(getWidth).min + val res = BVOp(op, toSMT(num, width), toSMT(den, width)) + if(res.width > resWidth) { BVSlice(res, resWidth - 1, 0) } else { res } + case (PrimOps.Lt, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVNot(BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1))) + case (PrimOps.Leq, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVNot(BVComparison(Compare.Greater, toSMT(e1, width), toSMT(e2, width), isSigned(e1))) + case (PrimOps.Gt, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVComparison(Compare.Greater, toSMT(e1, width), toSMT(e2, width), isSigned(e1)) + case (PrimOps.Geq, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1)) + case (PrimOps.Eq, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVEqual(toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Neq, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVNot(BVEqual(toSMT(e1, width), toSMT(e2, width))) + case (PrimOps.Pad, Seq(e), Seq(n)) => + val width = getWidth(e) + if(n <= width) { toSMT(e) } else { BVExtend(toSMT(e), n.toInt - width, isSigned(e)) } + case (PrimOps.AsUInt, Seq(e), _) => checkForClockInCast(PrimOps.AsUInt, e) ; toSMT(e) + case (PrimOps.AsSInt, Seq(e), _) => checkForClockInCast(PrimOps.AsSInt, e) ; toSMT(e) + case (PrimOps.AsFixedPoint, Seq(e), _) => throw new AssertionError("Fixed-Point numbers need to be lowered!") + case (PrimOps.AsClock, Seq(e), _) => toSMT(e) + case (PrimOps.AsAsyncReset, Seq(e), _) => + checkForClockInCast(PrimOps.AsAsyncReset, e) + throw new AssertionError(s"Asynchronous resets are not supported! Cannot cast ${e.serialize}.") + case (PrimOps.Shl, Seq(e), Seq(n)) => if(n == 0) { toSMT(e) } else { + val zeros = BVLiteral(0, n.toInt) + BVConcat(toSMT(e), zeros) + } + case (PrimOps.Shr, Seq(e), Seq(n)) => + val width = getWidth(e) + // "If n is greater than or equal to the bit-width of e, + // the resulting value will be zero for unsigned types + // and the sign bit for signed types" + if(n >= width) { + if(isSigned(e)) { BV1BitZero } else { BVSlice(toSMT(e), width - 1, width - 1) } + } else { + BVSlice(toSMT(e), width - 1, n.toInt) + } + case (PrimOps.Dshl, Seq(e1, e2), _) => + val width = getWidth(e1) + (1 << getWidth(e2)) - 1 + BVOp(Op.ShiftLeft, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Dshr, Seq(e1, e2), _) => + val width = getWidth(e1) + val o = if(isSigned(e1)) Op.ArithmeticShiftRight else Op.ShiftRight + BVOp(o, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Cvt, Seq(e), _) => if(isSigned(e)) { toSMT(e) } else { BVConcat(BV1BitZero, toSMT(e)) } + case (PrimOps.Neg, Seq(e), _) => BVNegate(BVExtend(toSMT(e), 1, isSigned(e))) + case (PrimOps.Not, Seq(e), _) => BVNot(toSMT(e)) + case (PrimOps.And, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVOp(Op.And, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Or, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVOp(Op.Or, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Xor, Seq(e1, e2), _) => + val width = args.map(getWidth).max + BVOp(Op.Xor, toSMT(e1, width), toSMT(e2, width)) + case (PrimOps.Andr, Seq(e), _) => BVReduceAnd(toSMT(e)) + case (PrimOps.Orr, Seq(e), _) => BVReduceOr(toSMT(e)) + case (PrimOps.Xorr, Seq(e), _) => BVReduceXor(toSMT(e)) + case (PrimOps.Cat, Seq(e1, e2), _) => BVConcat(toSMT(e1), toSMT(e2)) + case (PrimOps.Bits, Seq(e), Seq(hi, lo)) => BVSlice(toSMT(e), hi.toInt, lo.toInt) + case (PrimOps.Head, Seq(e), Seq(n)) => + val width = getWidth(e) + assert(n >= 0 && n <= width) + BVSlice(toSMT(e), width - 1, width - n.toInt) + case (PrimOps.Tail, Seq(e), Seq(n)) => + val width = getWidth(e) + assert(n >= 0 && n <= width) + assert(n < width, "While allowed by the firrtl standard, we do not support 0-bit values in this backend!") + BVSlice(toSMT(e), width - n.toInt - 1, 0) + } + } + + /** For now we strictly forbid casting clocks to anything else. + * Eventually this should be replaced by a more sophisticated clock analysis pass. */ + private def checkForClockInCast(cast: ir.PrimOp, signal: ir.Expression): Unit = { + assert(signal.tpe != ir.ClockType, s"Cannot cast (${cast.serialize}) clock expression ${signal.serialize}!") + } + + private val BV1BitZero = BVLiteral(0, 1) + + def isSigned(e: ir.Expression): Boolean = e.tpe match { + case _: ir.SIntType => true + case _ => false + } + private def getWidth(e: ir.Expression): Int = getWidth(e.tpe) +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala new file mode 100644 index 00000000..b3a2ff17 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -0,0 +1,605 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation} +import FirrtlExpressionSemantics.getWidth +import firrtl.graph.MutableDiGraph +import firrtl.options.Dependency +import firrtl.passes.PassException +import firrtl.stage.Forms +import firrtl.stage.TransformManager.TransformDependency +import firrtl.transforms.PropagatePresetAnnotations +import firrtl.{CircuitState, DependencyAPIMigration, MemoryArrayInit, MemoryInitValue, MemoryScalarInit, Transform, Utils, ir} +import logger.LazyLogging + +import scala.collection.mutable + +// Contains code to convert a flat firrtl module into a functional transition system which +// can then be exported as SMTLib or Btor2 file. + +private case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr]) +private case class Signal(name: String, e: BVExpr) { def toSymbol: BVSymbol = BVSymbol(name, e.width) } +private case class TransitionSystem( + name: String, inputs: Array[BVSymbol], states: Array[State], signals: Array[Signal], + outputs: Set[String], assumes: Set[String], asserts: Set[String], fair: Set[String], + comments: Map[String, String] = Map(), header: Array[String] = Array()) { + def serialize: String = { + (Iterator(name) ++ + inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++ + signals.map(s => s"${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++ + states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}") + ).mkString("\n") + } +} + +private case class TransitionSystemAnnotation(sys: TransitionSystem) extends NoTargetAnnotation + +object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { + // TODO: We only really need [[Forms.MidForm]] + LowerTypes, but we also want to fail if there are CombLoops + // TODO: We also would like to run some optimization passes, but RemoveValidIf won't allow us to model DontCare + // precisely and PadWidths emits ill-typed firrtl. + override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm + override def invalidates(a: Transform): Boolean = false + // since this pass only runs on the main module, inlining needs to happen before + override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) + + // We run the propagate preset annotations pass manually since we do not want to remove ValidIfs and other + // Verilog emission passes. + // Ideally we would go in and enable the [[PropagatePresetAnnotations]] to only depend on LowForm. + private val presetPass = new PropagatePresetAnnotations + override protected def execute(state: CircuitState): CircuitState = { + // run the preset pass to extract all preset registers and remove preset reset signals + val afterPreset = presetPass.execute(state) + val circuit = afterPreset.circuit + val presetRegs = afterPreset.annotations + .collect { case PresetRegAnnotation(target) if target.module == circuit.main => target.ref }.toSet + + // collect all non-random memory initialization + val memInit = afterPreset.annotations.collect { case a: MemoryInitAnnotation if !a.isRandomInit => a } + .filter(_.target.module == circuit.main).map(a => a.target.ref -> a.initValue).toMap + + // convert the main module + val main = circuit.modules.find(_.name == circuit.main).get + val sys = main match { + case x: ir.ExtModule => + throw new ExtModuleException( + "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog.") + case m: ir.Module => + new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit=memInit) + } + + val sortedSys = TopologicalSort.run(sys) + val anno = TransitionSystemAnnotation(sortedSys) + state.copy(circuit=circuit, annotations = afterPreset.annotations :+ anno ) + } +} + +private object UnsupportedException { + val HowToRunStuttering: String = + """ + |You can run the StutteringClockTransform which + |replaces all clock inputs with a clock enable signal. + |This is required not only for multi-clock designs, but also to + |accurately model asynchronous reset which could happen even if there + |isn't a clock edge. + | If you are using the firrtl CLI, please add: + | -fct firrtl.backends.experimental.smt.StutteringClockTransform + | If you are calling into firrtl programmatically you can use: + | RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]) + | To designate a clock to be the global_clock (i.e. the simulation tick), use: + | GlobalClockAnnotation(CircuitTarget(...).module(...).ref("your_clock"))) + |""".stripMargin +} + +private class ExtModuleException(s: String) extends PassException(s) +private class AsyncResetException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering) +private class MultiClockException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering) +private class MissingFeatureException(s: String) extends PassException("Unfortunately the SMT backend does not yet support: " + s) + +private class ModuleToTransitionSystem extends LazyLogging { + + def run(m: ir.Module, presetRegs: Set[String] = Set(), memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = { + // first pass over the module to convert expressions; discover state and I/O + val scan = new ModuleScanner(makeRandom) + m.foreachPort(scan.onPort) + // multi-clock support requires the StutteringClock transform to be run + if(scan.clocks.size > 1) { + throw new MultiClockException(s"The module ${m.name} has more than one clock: ${scan.clocks.mkString(", ")}") + } + m.foreachStmt(scan.onStatement) + + // turn wires and nodes into signals + val outputs = scan.outputs.toSet + val constraints = scan.assumes.toSet + val bad = scan.asserts.toSet + val isSignal = (scan.wires ++ scan.nodes ++ scan.memSignals).toSet ++ outputs ++ constraints ++ bad + val signals = scan.connects.filter{ case(name, _) => isSignal.contains(name) } + .map { case (name, expr) => Signal(name, expr) } + + // turn registers and memories into states + val registers = scan.registers.map(r => r._1 -> r).toMap + val regStates = scan.connects.filter(s => registers.contains(s._1)).map { case (name, nextExpr) => + val (_, width, resetExpr, initExpr) = registers(name) + onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs) + } + // turn memories into state + val memoryEncoding = new MemoryEncoding(makeRandom) + val memoryStatesAndOutputs = scan.memories.map(m => memoryEncoding.onMemory(m, scan.connects, memInit.get(m.name))) + // replace pseudo assigns for memory outputs + val memOutputs = memoryStatesAndOutputs.flatMap(_._2).toMap + val signalsWithMem = signals.map { s => + if (memOutputs.contains(s.name)) { + s.copy(e = memOutputs(s.name)) + } else { s } + } + // filter out any left-over self assignments (this happens when we have a registered read port) + .filter(s => s match { case Signal(n0, BVSymbol(n1, _)) if n0 == n1 => false case _ => true }) + val states = regStates.toArray ++ memoryStatesAndOutputs.flatMap(_._1) + + // generate comments from infos + val comments = mutable.HashMap[String, String]() + scan.infos.foreach { case (name, info) => + serializeInfo(info).foreach { infoString => + if(comments.contains(name)) { comments(name) += InfoSeparator + infoString } + else { comments(name) = InfoPrefix + infoString } + } + } + + // inputs are original module inputs and any "random" signal we need for modelling + val inputs = scan.inputs ++ randoms.values + + // module info to the comment header + val header = serializeInfo(m.info).map(InfoPrefix + _).toArray + + val fair = Set[String]() // as of firrtl 1.4 we do not support fairness constraints + TransitionSystem(m.name, inputs.toArray, states, signalsWithMem.toArray, outputs, constraints, bad, fair, comments.toMap, header) + } + + private def onRegister(name: String, width: Int, resetExpr: BVExpr, initExpr: BVExpr, + nextExpr: BVExpr, presetRegs: Set[String]): State = { + assert(initExpr.width == width) + assert(nextExpr.width == width) + assert(resetExpr.width == 1) + val sym = BVSymbol(name, width) + val hasReset = initExpr != sym + val isPreset = presetRegs.contains(name) + assert(!isPreset || hasReset, s"Expected preset register $name to have a reset value, not just $initExpr!") + if(hasReset) { + val init = if(isPreset) Some(initExpr) else None + val next = if(isPreset) nextExpr else BVIte(resetExpr, initExpr, nextExpr) + State(sym, next = Some(next), init = init) + } else { + State(sym, next = Some(nextExpr), init = None) + } + } + + private val InfoSeparator = ", " + private val InfoPrefix = "@ " + private def serializeInfo(info: ir.Info): Option[String] = info match { + case ir.NoInfo => None + case f : ir.FileInfo => Some(f.escaped) + case m : ir.MultiInfo => + val infos = m.flatten + if(infos.isEmpty) { None } else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } + } + + private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]() + private def makeRandom(baseName: String, width: Int): BVExpr = { + // TODO: actually ensure that there cannot be any name clashes with other identifiers + val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii) + val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get + val sym = BVSymbol(name, width) + randoms(name) = sym + sym + } +} + +private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLogging { + type Connects = Iterable[(String, BVExpr)] + def onMemory(defMem: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (Iterable[State], Connects) = { + // we can only work on appropriately lowered memories + assert(defMem.dataType.isInstanceOf[ir.GroundType], + s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!") + assert(defMem.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") + + // collect all memory meta-data in a custom class + val m = new MemInfo(defMem) + + // find all connections related to this memory + val inputs = connects.filter(_._1.startsWith(m.prefix)).toMap + + // there could be a constant init + val init = initValue.map(getInit(m, _)) + + // parse and check read and write ports + val writers = defMem.writers.map( w => new WritePort(m, w, inputs)) + val readers = defMem.readers.map( r => new ReadPort(m, r, inputs)) + + // derive next state from all write ports + assert(defMem.writeLatency == 1, "Only memories with write-latency of one are supported.") + val next: ArrayExpr = if(writers.isEmpty) { m.sym } else { + if(writers.length > 2) { + throw new UnsupportedFeatureException(s"memories with 3+ write ports (${m.name})") + } + val validData = writers.foldLeft[ArrayExpr](m.sym) { case (sym, w) => w.writeTo(sym) } + if(writers.length == 1) { validData } else { + assert(writers.length == 2) + val conflict = writers.head.doesConflict(writers.last) + val conflictData = writers.head.makeRandomData("_write_write_collision") + val conflictStore = ArrayStore(m.sym, writers.head.addr, conflictData) + ArrayIte(conflict, conflictStore, validData) + } + } + val state = State(m.sym, init, Some(next)) + + // derive data signals from all read ports + assert(defMem.readLatency >= 0) + if(defMem.readLatency > 1) { + throw new UnsupportedFeatureException(s"memories with read latency 2+ (${m.name})") + } + val readPortSignals = if(defMem.readLatency == 0) { + readers.map { r => + // combinatorial read + if(defMem.readUnderWrite != ir.ReadUnderWrite.New) { + //logger.warn(s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." + + // s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored.") + } + // since we do a combinatorial read, the "old" data is the current data + val data = r.readOld() + r.data.name -> data + } + } else { Seq() } + val readPortStates = if(defMem.readLatency == 1) { + readers.map { r => + // we create a register for the read port data + val next = defMem.readUnderWrite match { + case ir.ReadUnderWrite.New => + throw new UnsupportedFeatureException(s"registered read ports that return the new value (${m.name}.${r.name})") + // the thing that makes this hard is to properly handle write conflicts + case ir.ReadUnderWrite.Undefined => + val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r))) + if(anyWriteToTheSameAddress == False) { r.readOld() } else { + val readUnderWriteData = r.makeRandomData("_read_under_write_undefined") + BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.readOld()) + } + case ir.ReadUnderWrite.Old => r.readOld() + } + State(r.data, init=None, next=Some(next)) + } + } else { Seq() } + + (state +: readPortStates, readPortSignals) + } + + private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match { + case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, m.dataWidth), m.indexWidth) + case MemoryArrayInit(values) => + assert(values.length == m.depth, + s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!") + // in order to get a more compact encoding try to find the most common values + val histogram = mutable.LinkedHashMap[BigInt, Int]() + values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) + val baseValue = histogram.maxBy(_._2)._1 + val base = ArrayConstant(BVLiteral(baseValue, m.dataWidth), m.indexWidth) + values.zipWithIndex.filterNot(_._1 == baseValue) + .foldLeft[ArrayExpr](base) { case (array, (value, index)) => + ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth)) + } + case other => throw new RuntimeException(s"Unsupported memory init option: $other") + } + + private class MemInfo(m: ir.DefMemory) { + val name = m.name + val depth = m.depth + // derrive the type of the memory from the dataType and depth + val dataWidth = getWidth(m.dataType) + val indexWidth = Utils.getUIntWidth(m.depth - 1) max 1 + val sym = ArraySymbol(m.name, indexWidth, dataWidth) + val prefix = m.name + "." + val fullAddressRange = (BigInt(1) << indexWidth) == m.depth + lazy val depthBV = BVLiteral(m.depth, indexWidth) + def isValidAddress(addr: BVExpr): BVExpr = { + if(fullAddressRange) { True } else { + BVComparison(Compare.Greater, depthBV, addr, signed = false) + } + } + } + private abstract class MemPort(memory: MemInfo, val name: String, inputs: String => BVExpr) { + val en: BVSymbol = makeField("en", 1) + val data: BVSymbol = makeField("data", memory.dataWidth) + val addr: BVSymbol = makeField("addr", memory.indexWidth) + protected def makeField(field: String, width: Int): BVSymbol = BVSymbol(memory.prefix + name + "." + field, width) + // make sure that all widths are correct + assert(inputs(en.name).width == en.width) + assert(inputs(addr.name).width == addr.width) + val enIsTrue: Boolean = inputs(en.name) == True + def makeRandomData(suffix: String): BVExpr = + makeRandom(memory.name + "_" + name + suffix, memory.dataWidth) + def readOld(): BVExpr = { + val canBeOutOfRange = !memory.fullAddressRange + val canBeDisabled = !enIsTrue + val data = ArrayRead(memory.sym, addr) + val dataWithRangeCheck = if(canBeOutOfRange) { + val outOfRangeData = makeRandomData("_addr_out_of_range") + BVIte(memory.isValidAddress(addr), data, outOfRangeData) + } else { data } + val dataWithEnabledCheck = if(canBeDisabled) { + val disabledData = makeRandomData("_not_enabled") + BVIte(en, dataWithRangeCheck, disabledData) + } else { dataWithRangeCheck } + dataWithEnabledCheck + } + } + private class WritePort(memory: MemInfo, name: String, inputs: String => BVExpr) + extends MemPort(memory, name, inputs) { + assert(inputs(data.name).width == data.width) + val mask: BVSymbol = makeField("mask", 1) + assert(inputs(mask.name).width == mask.width) + val maskIsTrue: Boolean = inputs(mask.name) == True + val doWrite: BVExpr = (enIsTrue, maskIsTrue) match { + case (true, true) => True + case (true, false) => mask + case (false, true) => en + case (false, false) => and(en, mask) + } + def doesConflict(r: ReadPort): BVExpr = { + val sameAddress = BVEqual(r.addr, addr) + if(doWrite == True) { sameAddress } else { and(doWrite, sameAddress) } + } + def doesConflict(w: WritePort): BVExpr = { + val bothWrite = and(doWrite, w.doWrite) + val sameAddress = BVEqual(addr, w.addr) + if(bothWrite == True) { sameAddress } else { and(doWrite, sameAddress) } + } + def writeTo(array: ArrayExpr): ArrayExpr = { + val doUpdate = if(memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr)) + val update = ArrayStore(array, index=addr, data=data) + if(doUpdate == True) update else ArrayIte(doUpdate, update, array) + } + + } + private class ReadPort(memory: MemInfo, name: String, inputs: String => BVExpr) + extends MemPort(memory, name, inputs) { + } + + private def and(a: BVExpr, b: BVExpr): BVExpr = (a,b) match { + case (True, True) => True + case (True, x) => x + case (x, True) => x + case _ => BVOp(Op.And, a, b) + } + private def or(a: BVExpr, b: BVExpr): BVExpr = BVOp(Op.Or, a, b) + private val True = BVLiteral(1, 1) + private val False = BVLiteral(0, 1) + private def all(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) False else b.reduce((a,b) => and(a,b)) + private def any(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) True else b.reduce((a,b) => or(a,b)) +} + +// performas a first pass over the module collecting all connections, wires, registers, input and outputs +private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLogging { + import FirrtlExpressionSemantics.getWidth + + private[firrtl] val inputs = mutable.ArrayBuffer[BVSymbol]() + private[firrtl] val outputs = mutable.ArrayBuffer[String]() + private[firrtl] val clocks = mutable.LinkedHashSet[String]() + private[firrtl] val wires = mutable.ArrayBuffer[String]() + private[firrtl] val nodes = mutable.ArrayBuffer[String]() + private[firrtl] val memSignals = mutable.ArrayBuffer[String]() + private[firrtl] val registers = mutable.ArrayBuffer[(String, Int, BVExpr, BVExpr)]() + private[firrtl] val memories = mutable.ArrayBuffer[ir.DefMemory]() + // DefNode, Connect, IsInvalid and VerificationStatement connections + private[firrtl] val connects = mutable.ArrayBuffer[(String, BVExpr)]() + private[firrtl] val asserts = mutable.ArrayBuffer[String]() + private[firrtl] val assumes = mutable.ArrayBuffer[String]() + // maps identifiers to their info + private[firrtl] val infos = mutable.ArrayBuffer[(String, ir.Info)]() + // keeps track of unused memory (data) outputs so that we can see where they are first used + private val unusedMemOutputs = mutable.LinkedHashMap[String, Int]() + + private[firrtl] def onPort(p: ir.Port): Unit = { + if(isAsyncReset(p.tpe)) { + throw new AsyncResetException(s"Found AsyncReset ${p.name}.") + } + infos.append(p.name -> p.info) + p.direction match { + case ir.Input => + if(isClock(p.tpe)) { + clocks.add(p.name) + } else { + inputs.append(BVSymbol(p.name, getWidth(p.tpe))) + } + case ir.Output => outputs.append(p.name) + } + } + + private[firrtl] def onStatement(s: ir.Statement): Unit = s match { + case ir.DefWire(info, name, tpe) => + if(!isClock(tpe)) { + infos.append(name -> info) + wires.append(name) + } + case ir.DefNode(info, name, expr) => + if(!isClock(expr.tpe)) { + insertDummyAssignsForMemoryOutputs(expr) + infos.append(name -> info) + val e = onExpression(expr, name) + nodes.append(name) + connects.append((name, e)) + } + case ir.DefRegister(info, name, tpe, _, reset, init) => + insertDummyAssignsForMemoryOutputs(reset) + insertDummyAssignsForMemoryOutputs(init) + infos.append(name -> info) + val width = getWidth(tpe) + val resetExpr = onExpression(reset, 1, name + "_reset") + val initExpr = onExpression(init, width, name + "_init") + registers.append((name, width, resetExpr, initExpr)) + case m : ir.DefMemory => + infos.append(m.name -> m.info) + val outputs = getMemOutputs(m) + (getMemInputs(m) ++ outputs).foreach(memSignals.append(_)) + val dataWidth = getWidth(m.dataType) + outputs.foreach(name => unusedMemOutputs(name) = dataWidth) + memories.append(m) + case ir.Connect(info, loc, expr) => + if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") + val name = loc.serialize + insertDummyAssignsForMemoryOutputs(expr) + infos.append(name -> info) + connects.append((name, onExpression(expr, getWidth(loc.tpe), name))) + case ir.IsInvalid(info, loc) => + if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") + val name = loc.serialize + infos.append(name -> info) + connects.append((name, makeRandom(name + "_INVALID", getWidth(loc.tpe)))) + case ir.DefInstance(info, name, module, tpe) => + if(!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}") + // we treat all instances as blackboxes + logger.warn(s"WARN: treating instance $name of $module as blackbox. " + + "Please flatten your hierarchy if you want to include submodules in the formal model.") + val ports = tpe.asInstanceOf[ir.BundleType].fields + // skip clock and async reset ports + ports.filterNot(p => isClock(p.tpe) || isAsyncReset(p.tpe) ).foreach { p => + if(!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p") + val isOutput = p.flip == ir.Default + val pName = name + "." + p.name + infos.append(pName -> info) + // outputs of the submodule become inputs to our module + if(isOutput) { + inputs.append(BVSymbol(pName, getWidth(p.tpe))) + } else { + outputs.append(pName) + } + } + case s @ ir.Verification(op, info, _, pred, en, msg) => + if(op == ir.Formal.Cover) { + logger.warn(s"WARN: Cover statement was ignored: ${s.serialize}") + } else { + val name = msgToName(op.toString, msg.string) + val predicate = onExpression(pred, name + "_predicate") + val enabled = onExpression(en, name + "_enabled") + val e = BVImplies(enabled, predicate) + infos.append(name -> info) + connects.append(name -> e) + if(op == ir.Formal.Assert) { + asserts.append(name) + } else { + assumes.append(name) + } + } + case s : ir.Conditionally => + error(s"When conditions are not supported. Please run ExpandWhens: ${s.serialize}") + case s : ir.PartialConnect => + error(s"PartialConnects are not supported. Please run ExpandConnects: ${s.serialize}") + case s : ir.Attach => + error(s"Analog wires are not supported in the SMT backend: ${s.serialize}") + case s : ir.Stop => + // we could wire up the stop condition as output for debug reasons + logger.warn(s"WARN: Stop statements are currently not supported. Ignoring: ${s.serialize}") + case s : ir.Print => + logger.warn(s"WARN: Print statements are not supported. Ignoring: ${s.serialize}") + case other => other.foreachStmt(onStatement) + } + + private val readInputFields = List("en", "addr") + private val writeInputFields = List("en", "mask", "addr", "data") + private def getMemInputs(m: ir.DefMemory): Iterable[String] = { + assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!") + val p = m.name + "." + m.writers.flatMap(w => writeInputFields.map(p + w + "." + _)) ++ + m.readers.flatMap(r => readInputFields.map(p + r + "." + _)) + } + private def getMemOutputs(m: ir.DefMemory): Iterable[String] = { + assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!") + val p = m.name + "." + m.readers.map(r => p + r + ".data") + } + // inserts a dummy assign right before a memory output is used for the first time + // example: + // m.r.data <= m.r.data ; this is the dummy assign + // test <= m.r.data ; this is the first use of m.r.data + private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if(unusedMemOutputs.nonEmpty) { + implicit val uses = mutable.ArrayBuffer[String]() + findUnusedMemoryOutputUse(next) + if(uses.nonEmpty) { + val useSet = uses.toSet + unusedMemOutputs.foreach { case (name, width) => + if(useSet.contains(name)) connects.append(name -> BVSymbol(name, width)) + } + useSet.foreach(name => unusedMemOutputs.remove(name)) + } + } + private def findUnusedMemoryOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match { + case s : ir.SubField => + val name = s.serialize + if(unusedMemOutputs.contains(name)) uses.append(name) + case other => other.foreachExpr(findUnusedMemoryOutputUse) + } + + private case class Context(baseName: String) extends TranslationContext { + override def getRandom(width: Int): BVExpr = makeRandom(baseName, width) + } + + private def onExpression(e: ir.Expression, width: Int, randomPrefix: String): BVExpr = { + implicit val ctx: TranslationContext = Context(randomPrefix) + FirrtlExpressionSemantics.toSMT(e, width, allowNarrow = false) + } + private def onExpression(e: ir.Expression, randomPrefix: String): BVExpr = { + implicit val ctx: TranslationContext = Context(randomPrefix) + FirrtlExpressionSemantics.toSMT(e) + } + + private def msgToName(prefix: String, msg: String): String = { + // TODO: ensure that we can generate unique names + prefix + "_" + msg.replace(" ", "_").replace("|", "") + } + private def error(msg: String): Unit = throw new RuntimeException(msg) + private def isGroundType(tpe: ir.Type): Boolean = tpe.isInstanceOf[ir.GroundType] + private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType + private def isAsyncReset(tpe: ir.Type): Boolean = tpe == ir.AsyncResetType +} + +private object TopologicalSort { + /** Ensures that all signals in the resulting system are topologically sorted. + * This is necessary because [[firrtl.transforms.RemoveWires]] does + * not sort assignments to outputs, submodule inputs nor memory ports. + * */ + def run(sys: TransitionSystem): TransitionSystem = { + val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name) + val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates) + // TODO: maybe sort init expressions of states (this should not be needed most of the time) + signalOrder match { + case None => sys + case Some(order) => + val signalMap = sys.signals.map(s => s.name -> s).toMap + // we flatMap over `get` in order to ignore inputs/states in the order + sys.copy(signals = order.flatMap(signalMap.get).toArray) + } + } + + private def sort(signals: Iterable[(String, SMTExpr)], globalSignals: Iterable[String]): Option[Iterable[String]] = { + val known = new mutable.HashSet[String]() ++ globalSignals + var needsReordering = false + val digraph = new MutableDiGraph[String] + signals.foreach { case (name, expr) => + digraph.addVertex(name) + val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr) + uniqueDependencies.foreach { d => + if(!known.contains(d)) { needsReordering = true } + digraph.addPairWithEdge(name, d) + } + known.add(name) + } + if(needsReordering) { + Some(digraph.linearize.reverse) + } else { None } + } + + private def findDependencies(expr: SMTExpr): List[String] = expr match { + case BVSymbol(name, _) => List(name) + case ArraySymbol(name, _, _) => List(name) + case other => other.children.flatMap(findDependencies) + } +}
\ No newline at end of file diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala new file mode 100644 index 00000000..322b8961 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala @@ -0,0 +1,85 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import java.io.Writer + +import firrtl._ +import firrtl.annotations.{Annotation, NoTargetAnnotation} +import firrtl.options.Viewer.view +import firrtl.options.{CustomFileEmission, Dependency} +import firrtl.stage.FirrtlOptions + + +private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform with Emitter with DependencyAPIMigration { + override def prerequisites: Seq[Dependency[Transform]] = Seq(Dependency(FirrtlToTransitionSystem)) + override def invalidates(a: Transform): Boolean = false + + override def emit(state: CircuitState, writer: Writer): Unit = error("Deprecated since firrtl 1.0!") + + protected def serialize(sys: TransitionSystem): Annotation + + private val BleedingEdgeWarning = + """WARNING: The SMT and BTOR2 emitters are experimental preview features. + |- they might be removed in future versions without deprecation warning + |- their behavior and interfaces might change without notice + |- they are unsupported, we won't be able to fix any issues you find in them + |- however, we do accept pull requests: https://github.com/freechipsproject/firrtl/pulls + |""".stripMargin + + override protected def execute(state: CircuitState): CircuitState = { + val emitCircuit = state.annotations.exists { + case EmitCircuitAnnotation(a) if this.getClass == a => true + case EmitAllModulesAnnotation(a) if this.getClass == a => error("EmitAllModulesAnnotation not supported!") + case _ => false + } + + if(!emitCircuit) { return state } + + logger.warn(BleedingEdgeWarning) + + val sys = state.annotations.collectFirst{ case TransitionSystemAnnotation(sys) => sys }.getOrElse { + error("Could not find the transition system!") + } + state.copy(annotations = state.annotations :+ serialize(sys)) + } + + protected def generatedHeader(format: String, name: String): String = + s"; $format description generated by firrtl ${BuildInfo.version} for module $name.\n" + + protected def error(msg: String): Nothing = throw new RuntimeException(msg) +} + +case class EmittedSMTModelAnnotation(name: String, src: String, outputSuffix: String) + extends NoTargetAnnotation with CustomFileEmission { + override protected def baseFileName(annotations: AnnotationSeq): String = + view[FirrtlOptions](annotations).outputFileName.getOrElse(name) + override protected def suffix: Option[String] = Some(outputSuffix) + override def getBytes: Iterable[Byte] = src.getBytes +} + +private[firrtl] class Btor2Emitter extends SMTEmitter { + override def outputSuffix: String = ".btor2" + override protected def serialize(sys: TransitionSystem): Annotation = { + val btor = generatedHeader("BTOR", sys.name) + Btor2Serializer.serialize(sys).mkString("\n") + "\n" + EmittedSMTModelAnnotation(sys.name, btor, outputSuffix) + } +} + +private[firrtl] class SMTLibEmitter extends SMTEmitter { + override def outputSuffix: String = ".smt2" + override protected def serialize(sys: TransitionSystem): Annotation = { + val hasMemory = sys.states.exists(_.sym.isInstanceOf[ArrayExpr]) + val logic = SMTLibSerializer.setLogic(hasMemory) + "\n" + val header = if(hasMemory) { + "; We have to disable the logic for z3 to accept the non-standard \"as const\"\n" + + "; see https://github.com/Z3Prover/z3/issues/1803\n" + + "; for CVC4 you probably want to include the logic\n" + + ";" + logic + } else { logic } + val smt = generatedHeader("SMT-LIBv2", sys.name) + header + + SMTTransitionSystemEncoder.encode(sys).map(SMTLibSerializer.serialize).mkString("\n") + "\n" + EmittedSMTModelAnnotation(sys.name, smt, outputSuffix) + } +}
\ No newline at end of file diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala new file mode 100644 index 00000000..10a89e8d --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala @@ -0,0 +1,196 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> +// Inspired by the uclid5 SMT library (https://github.com/uclid-org/uclid). +// And the btor2 documentation (BTOR2 , BtorMC and Boolector 3.0 by Niemetz et.al.) + +package firrtl.backends.experimental.smt + +private sealed trait SMTExpr { def children: List[SMTExpr] } +private sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr { val name: String } +private object SMTSymbol { + def fromExpr(name: String, e: SMTExpr): SMTSymbol = e match { + case b: BVExpr => BVSymbol(name, b.width) + case a: ArrayExpr => ArraySymbol(name, a.indexWidth, a.dataWidth) + } +} +private sealed trait SMTNullaryExpr extends SMTExpr { + override def children: List[SMTExpr] = List() +} + +private sealed trait BVExpr extends SMTExpr { def width: Int } +private case class BVLiteral(value: BigInt, width: Int) extends BVExpr with SMTNullaryExpr { + private def minWidth = value.bitLength + (if(value <= 0) 1 else 0) + assert(width > 0, "Zero or negative width literals are not allowed!") + assert(width >= minWidth, "Value (" + value.toString + ") too big for BitVector of width " + width + " bits.") + override def toString: String = if(width <= 8) { + width.toString + "'b" + value.toString(2) + } else { width.toString + "'x" + value.toString(16) } +} +private case class BVSymbol(name: String, width: Int) extends BVExpr with SMTSymbol { + assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") + assert(!name.contains("\\"), s"Invalid id $name contains `\\`") + assert(width > 0, "Zero width bit vectors are not supported!") + override def toString: String = name + def toStringWithType: String = name + " : " + SMTExpr.serializeType(this) +} + +private sealed trait BVUnaryExpr extends BVExpr { + def e: BVExpr + override def children: List[BVExpr] = List(e) +} +private case class BVExtend(e: BVExpr, by: Int, signed: Boolean) extends BVUnaryExpr { + assert(by >= 0, "Extension must be non-negative!") + override val width: Int = e.width + by + override def toString: String = if(signed) { s"sext($e, $by)" } else { s"zext($e, $by)" } +} +// also known as bit extract operation +private case class BVSlice(e: BVExpr, hi: Int, lo: Int) extends BVUnaryExpr { + assert(lo >= 0, s"lo (lsb) must be non-negative!") + assert(hi >= lo, s"hi (msb) must not be smaller than lo (lsb): msb: $hi lsb: $lo") + assert(e.width > hi, s"Out off bounds hi (msb) access: width: ${e.width} msb: $hi") + override def width: Int = hi - lo + 1 + override def toString: String = if(hi == lo) s"$e[$hi]" else s"$e[$hi:$lo]" +} +private case class BVNot(e: BVExpr) extends BVUnaryExpr { + override val width: Int = e.width + override def toString: String = s"not($e)" +} +private case class BVNegate(e: BVExpr) extends BVUnaryExpr { + override val width: Int = e.width + override def toString: String = s"neg($e)" +} +private case class BVReduceOr(e: BVExpr) extends BVUnaryExpr { + override def width: Int = 1 + override def toString: String = s"redor($e)" +} +private case class BVReduceAnd(e: BVExpr) extends BVUnaryExpr { + override def width: Int = 1 + override def toString: String = s"redand($e)" +} +private case class BVReduceXor(e: BVExpr) extends BVUnaryExpr { + override def width: Int = 1 + override def toString: String = s"redxor($e)" +} + +private sealed trait BVBinaryExpr extends BVExpr { + def a: BVExpr + def b: BVExpr + override def children: List[BVExpr] = List(a, b) +} +private case class BVImplies(a: BVExpr, b: BVExpr) extends BVBinaryExpr { + assert(a.width == 1 && b.width == 1, s"Both arguments need to be 1-bit!") + override def width: Int = 1 + override def toString: String = s"impl($a, $b)" +} +private case class BVEqual(a: BVExpr, b: BVExpr) extends BVBinaryExpr { + assert(a.width == b.width, s"Both argument need to be the same width!") + override def width: Int = 1 + override def toString: String = s"eq($a, $b)" +} +private object Compare extends Enumeration { + val Greater, GreaterEqual = Value +} +private case class BVComparison(op: Compare.Value, a: BVExpr, b: BVExpr, signed: Boolean) extends BVBinaryExpr { + assert(a.width == b.width, s"Both argument need to be the same width!") + override def width: Int = 1 + override def toString: String = op match { + case Compare.Greater => (if(signed) "sgt" else "ugt") + s"($a, $b)" + case Compare.GreaterEqual => (if(signed) "sgeq" else "ugeq") + s"($a, $b)" + } +} +private object Op extends Enumeration { + val And = Value("and") + val Or = Value("or") + val Xor = Value("xor") + val ShiftLeft = Value("logical_shift_left") + val ArithmeticShiftRight = Value("arithmetic_shift_right") + val ShiftRight = Value("logical_shift_right") + val Add = Value("add") + val Mul = Value("mul") + val SignedDiv = Value("sdiv") + val UnsignedDiv = Value("udiv") + val SignedMod = Value("smod") + val SignedRem = Value("srem") + val UnsignedRem = Value("urem") + val Sub = Value("sub") +} +private case class BVOp(op: Op.Value, a: BVExpr, b: BVExpr) extends BVBinaryExpr { + assert(a.width == b.width, s"Both argument need to be the same width!") + override val width: Int = a.width + override def toString: String = s"$op($a, $b)" +} +private case class BVConcat(a: BVExpr, b: BVExpr) extends BVBinaryExpr { + override val width: Int = a.width + b.width + override def toString: String = s"concat($a, $b)" +} +private case class ArrayRead(array: ArrayExpr, index: BVExpr) extends BVExpr { + assert(array.indexWidth == index.width, "Index with does not match expected array index width!") + override val width: Int = array.dataWidth + override def toString: String = s"$array[$index]" + override def children: List[SMTExpr] = List(array, index) +} +private case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr { + assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!") + assert(tru.width == fals.width, s"Both branches need to be of the same width! ${tru.width} vs ${fals.width}") + override val width: Int = tru.width + override def toString: String = s"ite($cond, $tru, $fals)" + override def children: List[BVExpr] = List(cond, tru, fals) +} + +private sealed trait ArrayExpr extends SMTExpr { val indexWidth: Int; val dataWidth: Int } +private case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol { + assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") + assert(!name.contains("\\"), s"Invalid id $name contains `\\`") + override def toString: String = name + def toStringWithType: String = s"$name : bv<$indexWidth> -> bv<$dataWidth>" +} +private case class ArrayStore(array: ArrayExpr, index: BVExpr, data: BVExpr) extends ArrayExpr { + assert(array.indexWidth == index.width, "Index with does not match expected array index width!") + assert(array.dataWidth == data.width, "Data with does not match expected array data width!") + override val dataWidth: Int = array.dataWidth + override val indexWidth: Int = array.indexWidth + override def toString: String = s"$array[$index := $data]" + override def children: List[SMTExpr] = List(array, index, data) +} +private case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) extends ArrayExpr { + assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!") + assert(tru.indexWidth == fals.indexWidth, + s"Both branches need to be of the same type! ${tru.indexWidth} vs ${fals.indexWidth}") + assert(tru.dataWidth == fals.dataWidth, + s"Both branches need to be of the same type! ${tru.dataWidth} vs ${fals.dataWidth}") + override val dataWidth: Int = tru.dataWidth + override val indexWidth: Int = tru.indexWidth + override def toString: String = s"ite($cond, $tru, $fals)" + override def children: List[SMTExpr] = List(cond, tru, fals) +} +private case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr { + assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!") + assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!") + override def width: Int = 1 + override def toString: String = s"eq($a, $b)" + override def children: List[SMTExpr] = List(a, b) +} +private case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr { + override val dataWidth: Int = e.width + override def toString: String = s"([$e] x ${ (BigInt(1) << indexWidth) })" + override def children: List[SMTExpr] = List(e) +} + +private object SMTEqual { + def apply(a: SMTExpr, b: SMTExpr): BVExpr = (a,b) match { + case (ab : BVExpr, bb : BVExpr) => BVEqual(ab, bb) + case (aa : ArrayExpr, ba: ArrayExpr) => ArrayEqual(aa, ba) + case _ => throw new RuntimeException(s"Cannot compare $a and $b") + } +} + +private object SMTExpr { + def serializeType(e: SMTExpr): String = e match { + case b: BVExpr => s"bv<${b.width}>" + case a: ArrayExpr => s"bv<${a.indexWidth}> -> bv<${a.dataWidth}>" + } +} + +// Raw SMTLib encoded expressions as an escape hatch used in the [[SMTTransitionSystemEncoder]] +private case class BVRawExpr(serialized: String, width: Int) extends BVExpr with SMTNullaryExpr +private case class ArrayRawExpr(serialized: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTNullaryExpr
\ No newline at end of file diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala new file mode 100644 index 00000000..14e73253 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala @@ -0,0 +1,73 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +/** Similar to the mapExpr and foreachExpr methods of the firrtl ir nodes, but external to the case classes */ +private object SMTExprVisitor { + type ArrayFun = ArrayExpr => ArrayExpr + type BVFun = BVExpr => BVExpr + + def map[T <: SMTExpr](bv: BVFun, ar: ArrayFun)(e: T): T = e match { + case b: BVExpr => map(b, bv, ar).asInstanceOf[T] + case a: ArrayExpr => map(a, bv, ar).asInstanceOf[T] + } + def map[T <: SMTExpr](f: SMTExpr => SMTExpr)(e: T): T = + map(b => f(b).asInstanceOf[BVExpr], a => f(a).asInstanceOf[ArrayExpr])(e) + + private def map(e: BVExpr, bv: BVFun, ar: ArrayFun): BVExpr = e match { + // nullary + case old : BVLiteral => bv(old) + case old : BVSymbol => bv(old) + case old : BVRawExpr => bv(old) + // unary + case old @ BVExtend(e, by, signed) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVExtend(n, by, signed)) + case old @ BVSlice(e, hi, lo) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVSlice(n, hi, lo)) + case old @ BVNot(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNot(n)) + case old @ BVNegate(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNegate(n)) + case old @ BVReduceAnd(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceAnd(n)) + case old @ BVReduceOr(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceOr(n)) + case old @ BVReduceXor(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceXor(n)) + // binary + case old @ BVImplies(a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB)) + case old @ BVEqual(a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB)) + case old @ ArrayEqual(a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB)) + case old @ BVComparison(op, a, b, signed) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed)) + case old @ BVOp(op, a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB)) + case old @ BVConcat(a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB)) + case old @ ArrayRead(a, b) => + val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB)) + // ternary + case old @ BVIte(a, b, c) => + val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) + bv(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC)) + } + + + private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = e match { + case old : ArrayRawExpr => ar(old) + case old : ArraySymbol => ar(old) + case old @ ArrayConstant(e, indexWidth) => + val n = map(e, bv, ar) ; ar(if(n.eq(e)) old else ArrayConstant(n, indexWidth)) + case old @ ArrayStore(a, b, c) => + val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) + ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC)) + case old @ ArrayIte(a, b, c) => + val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) + ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC)) + } + +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala new file mode 100644 index 00000000..1993da87 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala @@ -0,0 +1,151 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import scala.util.matching.Regex + +/** Converts STM Expressions to a SMTLib compatible string representation. + * See http://smtlib.cs.uiowa.edu/ + * Assumes well typed expression, so it is advisable to run the TypeChecker + * before serializing! + * Automatically converts 1-bit vectors to bool. + */ +private object SMTLibSerializer { + def setLogic(hasMem: Boolean) = "(set-logic QF_" + (if(hasMem) "A" else "") + "UFBV)" + + def serialize(e: SMTExpr): String = e match { + case b : BVExpr => serialize(b) + case a : ArrayExpr => serialize(a) + } + + def serializeType(e: SMTExpr): String = e match { + case b : BVExpr => serializeBitVectorType(b.width) + case a : ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth) + } + + private def serialize(e: BVExpr): String = e match { + case BVLiteral(value, width) => + val mask = (BigInt(1) << width) - 1 + val twosComplement = if(value < 0) { ((~(-value)) & mask) + 1 } else value + if(width == 1) { + if(twosComplement == 1) "true" else "false" + } else { + s"(_ bv$twosComplement $width)" + } + case BVSymbol(name, _) => escapeIdentifier(name) + case BVExtend(e, 0, _) => serialize(e) + case BVExtend(BVLiteral(value, width), by, false) => serialize(BVLiteral(value, width + by)) + case BVExtend(e, by, signed) => + val foo = if(signed) "sign_extend" else "zero_extend" + s"((_ $foo $by) ${asBitVector(e)})" + case BVSlice(e, hi, lo) => + if(lo == 0 && hi == e.width - 1) { serialize(e) + } else { + val bits = s"((_ extract $hi $lo) ${asBitVector(e)})" + // 1-bit extracts need to be turned into a boolean + if(lo == hi) { toBool(bits) } else { bits } + } + case BVNot(BVEqual(a, b)) if a.width == 1 => s"(distinct ${serialize(a)} ${serialize(b)})" + case BVNot(BVNot(e)) => serialize(e) + case BVNot(e) => if(e.width == 1) { s"(not ${serialize(e)})" } else { s"(bvnot ${serialize(e)})" } + case BVNegate(e) => s"(bvneg ${asBitVector(e)})" + case r: BVReduceAnd => serialize(Expander.expand(r)) + case r: BVReduceOr => serialize(Expander.expand(r)) + case r: BVReduceXor => serialize(Expander.expand(r)) + case BVImplies(BVLiteral(v, 1), b) if v == 1 => serialize(b) + case BVImplies(a, b) => s"(=> ${serialize(a)} ${serialize(b)})" + case BVEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})" + case ArrayEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})" + case BVComparison(Compare.Greater, a, b, false) => s"(bvugt ${asBitVector(a)} ${asBitVector(b)})" + case BVComparison(Compare.GreaterEqual, a, b, false) => s"(bvuge ${asBitVector(a)} ${asBitVector(b)})" + case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})" + case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})" + // boolean operations get a special treatment for 1-bit vectors aka bools + case BVOp(Op.And, a, b) if a.width == 1 => s"(and ${serialize(a)} ${serialize(b)})" + case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})" + case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})" + case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})") + case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})" + case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})" + case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})" + case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" + case BVRawExpr(serialized, _) => serialized + } + + def serialize(e: ArrayExpr): String = e match { + case ArraySymbol(name, _, _) => escapeIdentifier(name) + case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})" + case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" + case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})" + case ArrayRawExpr(serialized, _, _) => serialized + } + + def serialize(c: SMTCommand): String = c match { + case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n") + case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)" + case DefineFunction(name, args, e) => + val aa = args.map(a => s"(${escapeIdentifier(a._1)} ${a._2})").mkString(" ") + s"(define-fun ${escapeIdentifier(name)} ($aa) ${serializeType(e)} ${serialize(e)})" + case DeclareFunction(sym, tpes) => + val aa = tpes.mkString(" ") + s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serializeType(sym)})" + } + + private def serializeArrayType(indexWidth: Int, dataWidth: Int): String = + s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})" + private def serializeBitVectorType(width: Int): String = + if(width == 1) { "Bool" } else { assert(width > 1) ; s"(_ BitVec $width)" } + + private def serialize(op: Op.Value): String = op match { + case Op.And => "bvand" + case Op.Or => "bvor" + case Op.Xor => "bvxor" + case Op.ArithmeticShiftRight => "bvashr" + case Op.ShiftRight => "bvlshr" + case Op.ShiftLeft => "bvshl" + case Op.Add => "bvadd" + case Op.Mul => "bvmul" + case Op.Sub => "bvsub" + case Op.SignedDiv => "bvsdiv" + case Op.UnsignedDiv => "bvudiv" + case Op.SignedMod => "bvsmod" + case Op.SignedRem => "bvsrem" + case Op.UnsignedRem => "bvurem" + } + + private def toBool(e: String): String = s"(= $e (_ bv1 1))" + + private val bvZero = "(_ bv0 1)" + private val bvOne = "(_ bv1 1)" + private def asBitVector(e: BVExpr): String = + if(e.width > 1) { serialize(e) } else { s"(ite ${serialize(e)} $bvOne $bvZero)" } + + // See <simple_symbol> definition in the Concrete Syntax Appendix of the SMTLib Spec + private val simple: Regex = raw"[a-zA-Z\+-/\*\=%\?!\.\$$_~&\^<>@][a-zA-Z0-9\+-/\*\=%\?!\.\$$_~&\^<>@]*".r + def escapeIdentifier(name: String): String = name match { + case simple() => name + case _ => if(name.startsWith("|") && name.endsWith("|")) name else s"|$name|" + } +} + +/** Expands expressions that are not natively supported by SMTLib */ +private object Expander { + def expand(r: BVReduceAnd): BVExpr = { + if(r.e.width == 1) { r.e } else { + val allOnes = (BigInt(1) << r.e.width) - 1 + BVEqual(r.e, BVLiteral(allOnes, r.e.width)) + } + } + def expand(r: BVReduceOr): BVExpr = { + if(r.e.width == 1) { r.e } else { + BVNot(BVEqual(r.e, BVLiteral(0, r.e.width))) + } + } + def expand(r: BVReduceXor): BVExpr = { + if(r.e.width == 1) { r.e } else { + val bits = (0 until r.e.width).map(ii => BVSlice(r.e, ii, ii)) + bits.reduce[BVExpr]((a,b) => BVOp(Op.Xor, a, b)) + } + } +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala new file mode 100644 index 00000000..e9acc05b --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala @@ -0,0 +1,128 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import scala.collection.mutable + +/** This Transition System encoding is directly inspired by yosys' SMT backend: + * https://github.com/YosysHQ/yosys/blob/master/backends/smt2/smt2.cc + * It if fairly compact, but unfortunately, the use of an uninterpreted sort for the state + * prevents this encoding from working with boolector. + * For simplicity reasons, we do not support hierarchical designs (no `_h` function). + * */ +private object SMTTransitionSystemEncoder { + + def encode(sys: TransitionSystem): Iterable[SMTCommand] = { + val cmds = mutable.ArrayBuffer[SMTCommand]() + val name = sys.name + + // emit header as comments + cmds ++= sys.header.map(Comment) + + // declare state type + val stateType = id(name + "_s") + cmds += DeclareUninterpretedSort(stateType) + + // inputs and states are modelled as constants + def declare(sym: SMTSymbol, kind: String): Unit = { + cmds ++= toDescription(sym, kind, sys.comments.get) + val s = SMTSymbol.fromExpr(sym.name + SignalSuffix, sym) + cmds += DeclareFunction(s, List(stateType)) + } + sys.inputs.foreach(i => declare(i, "input")) + sys.states.foreach(s => declare(s.sym, "register")) + + // signals are just functions of other signals, inputs and state + def define(sym: SMTSymbol, e: SMTExpr, suffix: String = SignalSuffix): Unit = { + cmds += DefineFunction(sym.name + suffix, List((State, stateType)), replaceSymbols(e)) + } + sys.signals.foreach { signal => + val kind = if(sys.outputs.contains(signal.name)) { "output" + } else if(sys.assumes.contains(signal.name)) { "assume" + } else if(sys.asserts.contains(signal.name)) { "assert" + } else { "wire" } + val sym = SMTSymbol.fromExpr(signal.name, signal.e) + cmds ++= toDescription(sym, kind, sys.comments.get) + define(sym, signal.e) + } + + // define the next and init functions for all states + sys.states.foreach { state => + assert(state.next.nonEmpty, "Next function required") + define(state.sym, state.next.get, NextSuffix) + // init is optional + state.init.foreach { init => + define(state.sym, init, InitSuffix) + } + } + + def defineConjunction(e: Iterable[BVExpr], suffix: String): Unit = { + define(BVSymbol(name, 1), andReduce(e), suffix) + } + + // the transition relation asserts that the value of the next state is the next value from the previous state + // e.g., (reg state_n) == (reg_next state) + val transitionRelations = sys.states.map { state => + val newState = symbolToFunApp(state.sym, SignalSuffix, StateNext) + val nextOldState = symbolToFunApp(state.sym, NextSuffix, State) + SMTEqual(newState, nextOldState) + } + // the transition relation is over two states + val transitionExpr = replaceSymbols(andReduce(transitionRelations)) + cmds += DefineFunction(name + "_t", List((State, stateType), (StateNext, stateType)), transitionExpr) + + // The init relation just asserts that all init function hold + val initRelations = sys.states.filter(_.init.isDefined).map { state => + val stateSignal = symbolToFunApp(state.sym, SignalSuffix, State) + val initSignal = symbolToFunApp(state.sym, InitSuffix, State) + SMTEqual(stateSignal, initSignal) + } + defineConjunction(initRelations, "_i") + + // assertions and assumptions + val assertions = sys.signals.filter(a => sys.asserts.contains(a.name)).map(a => replaceSymbols(a.toSymbol)) + defineConjunction(assertions, "_a") + val assumptions = sys.signals.filter(a => sys.assumes.contains(a.name)).map(a => replaceSymbols(a.toSymbol)) + defineConjunction(assumptions, "_u") + + cmds + } + + private def id(s: String): String = SMTLibSerializer.escapeIdentifier(s) + private val State = "state" + private val StateNext = "state_n" + private val SignalSuffix = "_f" + private val NextSuffix = "_next" + private val InitSuffix = "_init" + private def toDescription(sym: SMTSymbol, kind: String, comments: String => Option[String]): List[Comment] = { + List(sym match { + case BVSymbol(name, width) => + Comment(s"firrtl-smt2-$kind $name $width") + case ArraySymbol(name, indexWidth, dataWidth) => + Comment(s"firrtl-smt2-$kind $name $indexWidth $dataWidth") + }) ++ comments(sym.name).map(Comment) + } + + private def andReduce(e: Iterable[BVExpr]): BVExpr = + if(e.isEmpty) BVLiteral(1, 1) else e.reduce((a,b) => BVOp(Op.And, a, b)) + + // All signals are modelled with functions that need to be called with the state as argument, + // this replaces all Symbols with function applications to the state. + private def replaceSymbols(e: SMTExpr): SMTExpr = { + SMTExprVisitor.map(symbolToFunApp(_, SignalSuffix, State))(e) + } + private def replaceSymbols(e: BVExpr): BVExpr = replaceSymbols(e.asInstanceOf[SMTExpr]).asInstanceOf[BVExpr] + private def symbolToFunApp(sym: SMTExpr, suffix: String, arg: String): SMTExpr = sym match { + case BVSymbol(name, width) => BVRawExpr(s"(${id(name+suffix)} $arg)", width) + case ArraySymbol(name, indexWidth, dataWidth) => ArrayRawExpr(s"(${id(name+suffix)} $arg)", indexWidth, dataWidth) + case other => other + } +} + +/** minimal set of pseudo SMT commands needed for our encoding */ +private sealed trait SMTCommand +private case class Comment(msg: String) extends SMTCommand +private case class DefineFunction(name: String, args: Seq[(String, String)], e: SMTExpr) extends SMTCommand +private case class DeclareFunction(sym: SMTSymbol, tpes: Seq[String]) extends SMTCommand +private case class DeclareUninterpretedSort(name: String) extends SMTCommand diff --git a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala new file mode 100644 index 00000000..d8e203f8 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala @@ -0,0 +1,253 @@ +// See LICENSE for license details. +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import firrtl.{CircuitState, DependencyAPIMigration, Namespace, PrimOps, RenameMap, Transform, Utils, ir} +import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation, ReferenceTarget, SingleTargetAnnotation} +import firrtl.ir.EmptyStmt +import firrtl.options.Dependency +import firrtl.passes.PassException +import firrtl.stage.Forms +import firrtl.stage.TransformManager.TransformDependency + +import scala.collection.mutable + +case class GlobalClockAnnotation(target: ReferenceTarget) extends SingleTargetAnnotation[ReferenceTarget] { + override def duplicate(n: ReferenceTarget): Annotation = this.copy(n) +} + +/** Converts every input clock into a clock enable input and adds a single global clock. + * - all registers and memory ports will be connected to the new global clock + * - all registers and memory ports will be guarded by the enable signal of their original clock + * - the clock enabled signal can be understood as a clock tick or posedge + * - this transform can be used in order to (formally) verify designs with multiple clocks or asynchronous resets + */ +class StutteringClockTransform extends Transform with DependencyAPIMigration { + override def prerequisites: Seq[TransformDependency] = Forms.LowForm + override def invalidates(a: Transform): Boolean = false + + // this pass needs to run *before* converting to a transition system + override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq(Dependency(FirrtlToTransitionSystem)) + // since this pass only runs on the main module, inlining needs to happen before + override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) + + + override protected def execute(state: CircuitState): CircuitState = { + if(state.circuit.modules.size > 1) { + logger.warn("WARN: StutteringClockTransform currently only supports running on a single module.\n" + + s"All submodules of ${state.circuit.main} will be ignored! Please inline all submodules if this is not what you want.") + } + + // get main module + val main = state.circuit.modules.find(_.name == state.circuit.main).get match { + case m: ir.Module => m + case e: ir.ExtModule => unsupportedError(s"Cannot run on extmodule $e") + } + mainName = main.name + + val namespace = Namespace(main) + + // create a global clock + val globalClocks = state.annotations.collect { case GlobalClockAnnotation(c) => c } + assert(globalClocks.size < 2, "There can only be a single global clock: " + globalClocks.mkString(", ")) + val (globalClock, portsWithGlobalClock) = globalClocks.headOption match { + case Some(clock) => + assert(clock.module == main.name, "GlobalClock needs to be an input of the main module!") + assert(main.ports.exists(_.name == clock.ref), "GlobalClock needs to be an input port!") + assert(main.ports.find(_.name == clock.ref).get.direction == ir.Input, "GlobalClock needs to be an input port!") + (clock.ref, main.ports) + case None => + val name = namespace.newName("global_clock") + (name, ir.Port(ir.NoInfo, name, ir.Input, ir.ClockType) +: main.ports) + } + + // replace all other clocks with enable signals, unless they are the global clock + val clocks = portsWithGlobalClock.filter(p => p.tpe == ir.ClockType && p.name != globalClock).map(_.name) + val clockToEnable = clocks.map{c => + c -> ir.Reference(namespace.newName(c + "_en"), Bool, firrtl.PortKind, firrtl.SourceFlow) + }.toMap + val portsWithEnableSignals = portsWithGlobalClock.map { p => + if(clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Bool) } else { p } + } + // replace async reset with synchronous reset (since everything will we synchronous with the global clock) + // unless it is a preset reset + val asyncResets = portsWithEnableSignals.filter(_.tpe == ir.AsyncResetType).map(_.name) + val isPresetReset = state.annotations.collect{ case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet + val resetsToChange = asyncResets.filterNot(isPresetReset).toSet + val portsWithSyncReset = portsWithEnableSignals.map { p => + if(resetsToChange.contains(p.name)) { p.copy(tpe = Bool) } else { p } + } + + // discover clock and reset connections + val scan = scanClocks(main, clockToEnable, resetsToChange) + + // rename clocks to clock enable signals + val mRef = CircuitTarget(state.circuit.main).module(main.name) + val renameMap = RenameMap() + scan.clockToEnable.foreach { case (clk, en) => + renameMap.record(mRef.ref(clk), mRef.ref(en.name)) + } + + // make changes + implicit val ctx: Context = new Context(globalClock, scan) + val newMain = main.copy(ports = portsWithSyncReset).mapStmt(onStatement) + + val nonMainModules = state.circuit.modules.filterNot(_.name == state.circuit.main) + val newCircuit = state.circuit.copy(modules = nonMainModules :+ newMain) + state.copy(circuit = newCircuit, renames = Some(renameMap)) + } + + private def onStatement(s: ir.Statement)(implicit ctx: Context): ir.Statement = { + s.foreachExpr(checkExpr) + s match { + // memory field connects + case c @ ir.Connect(_, ir.SubField(ir.SubField(ir.Reference(mem, _, _, _), port, _, _), field, _, _), _) + if ctx.isMem(mem) && ctx.memPortToClockEnable.contains(mem + "." + port) => + // replace clock with the global clock + if(field == "clk") { + c.copy(expr = ctx.globalClock) + } else if(field == "en") { + val m = ctx.memInfo(mem) + val isWritePort = m.writers.contains(port) + assert(isWritePort || m.readers.contains(port)) + + // for write ports we guard the write enable with the clock enable signal, similar to registers + if(isWritePort) { + val clockEn = ctx.memPortToClockEnable(mem + "." + port) + val guardedEnable = and(clockEn, c.expr) + c.copy(expr = guardedEnable) + } else { c } + } else { c} + // register field connects + case c @ ir.Connect(_, r : ir.Reference, next) if ctx.registerToEnable.contains(r.name) => + val clockEnable = ctx.registerToEnable(r.name) + val guardedNext = mux(clockEnable, next, r) + c.copy(expr = guardedNext) + // remove other clock wires and nodes + case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType && ctx.isRemovedClock(loc.serialize) => EmptyStmt + case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt + case ir.DefWire(_, name, tpe) if tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt + // change async reset to synchronous reset + case ir.Connect(info, loc: ir.Reference, expr: ir.Reference) if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) => + ir.Connect(info, loc.copy(tpe=Bool), expr.copy(tpe=Bool)) + case d @ ir.DefNode(_, name, value: ir.Reference) if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) => + d.copy(value = value.copy(tpe=Bool)) + case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => d.copy(tpe=Bool) + // change memory clock and synchronize reset + case ir.DefRegister(info, name, tpe, clock, reset, init) if ctx.registerToEnable.contains(name) => + val clockEnable = ctx.registerToEnable(name) + val newReset = reset match { + case r @ ir.Reference(name, _, _, _) if ctx.isResetToChange(name) => r.copy(tpe=Bool) + case other => other + } + val synchronizedReset = if(reset.tpe == ir.AsyncResetType) { newReset } else { and(newReset, clockEnable) } + ir.DefRegister(info, name, tpe, ctx.globalClock, synchronizedReset, init) + case other => other.mapStmt(onStatement) + } + } + + private def scanClocks(m: ir.Module, initialClockToEnable: Map[String, ir.Reference], resetsToChange: Set[String]): ScanCtx = { + implicit val ctx: ScanCtx = new ScanCtx(initialClockToEnable, resetsToChange) + m.foreachStmt(scanClocksAndResets) + ctx + } + + private def scanClocksAndResets(s: ir.Statement)(implicit ctx: ScanCtx): Unit = { + s.foreachExpr(checkExpr) + s match { + // track clock aliases + case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType => + val locName = loc.serialize + ctx.clockToEnable.get(expr.serialize).foreach { clockEn => + ctx.clockToEnable(locName) = clockEn + // keep track of memory clocks + if(loc.isInstanceOf[ir.SubField]) { + val parts = locName.split('.') + if(ctx.mems.contains(parts.head)) { + assert(parts.length == 3 && parts.last == "clk") + ctx.memPortToClockEnable.append(parts.dropRight(1).mkString(".") -> clockEn) + } + } + } + case ir.DefNode(_, name, value) if value.tpe == ir.ClockType => + ctx.clockToEnable.get(value.serialize).foreach(c => ctx.clockToEnable(name) = c) + // track reset aliases + case ir.Connect(_, loc, expr) if expr.tpe == ir.AsyncResetType && ctx.resetsToChange(expr.serialize) => + ctx.resetsToChange.add(loc.serialize) + case ir.DefNode(_, name, value) if value.tpe == ir.AsyncResetType && ctx.resetsToChange(value.serialize) => + ctx.resetsToChange.add(name) + // modify clocked elements + case ir.DefRegister(_, name, _, clock, _, _) => + ctx.clockToEnable.get(clock.serialize).foreach { clockEnable => + ctx.registerToEnable.append(name -> clockEnable) + } + case m : ir.DefMemory => + assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!") + assert(m.readLatency == 0 || m.readLatency == 1, "Only read-latency 1 and read latency 0 are supported!") + assert(m.writeLatency == 1, "Only write-latency 1 is supported!") + if(m.readers.nonEmpty && m.readLatency == 1) { + unsupportedError("Registers memory read ports are not properly implemented yet :(") + } + ctx.mems(m.name) = m + case other => other.foreachStmt(scanClocksAndResets) + } + } + + // we rely on people not casting clocks or async resets + private def checkExpr(expr: ir.Expression): Unit = expr match { + case ir.DoPrim(PrimOps.AsUInt, Seq(e), _, _) if e.tpe == ir.ClockType => + unsupportedError(s"Clock casts are not supported: ${expr.serialize}") + case ir.DoPrim(PrimOps.AsSInt, Seq(e), _, _) if e.tpe == ir.ClockType => + unsupportedError(s"Clock casts are not supported: ${expr.serialize}") + case ir.DoPrim(PrimOps.AsUInt, Seq(e), _, _) if e.tpe == ir.AsyncResetType => + unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}") + case ir.DoPrim(PrimOps.AsSInt, Seq(e), _, _) if e.tpe == ir.AsyncResetType => + unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}") + case ir.DoPrim(PrimOps.AsAsyncReset, _, _, _) => + unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}") + case ir.DoPrim(PrimOps.AsClock, _, _, _) => + unsupportedError(s"Clock casts are not supported: ${expr.serialize}") + case other => other.foreachExpr(checkExpr) + } + + private class ScanCtx(initialClockToEnable: Map[String, ir.Reference], initialResetsToChange: Set[String]) { + // keeps track of which clock signals will be replaced by which clock enable signal + val clockToEnable = mutable.HashMap[String, ir.Reference]() ++ initialClockToEnable + // kepp track of asynchronous resets that need to be changed to bool + val resetsToChange = mutable.HashSet[String]() ++ initialResetsToChange + // registers whose next function needs to be guarded with a clock enable + val registerToEnable = mutable.ArrayBuffer[(String, ir.Reference)]() + // memory enables which need to be guarded with clock enables + val memPortToClockEnable = mutable.ArrayBuffer[(String, ir.Reference)]() + // keep track of memory names + val mems = mutable.HashMap[String, ir.DefMemory]() + } + + private class Context(globalClockName: String, scanResults: ScanCtx) { + val globalClock: ir.Reference = ir.Reference(globalClockName, ir.ClockType, firrtl.PortKind, firrtl.SourceFlow) + // keeps track of which clock signals will be replaced by which clock enable signal + val isRemovedClock: String => Boolean = scanResults.clockToEnable.contains + // registers whose next function needs to be guarded with a clock enable + val registerToEnable: Map[String, ir.Reference] = scanResults.registerToEnable.toMap + // memory enables which need to be guarded with clock enables + val memPortToClockEnable: Map[String, ir.Reference] = scanResults.memPortToClockEnable.toMap + // keep track of memory names + val isMem: String => Boolean = scanResults.mems.contains + val memInfo: String => ir.DefMemory = scanResults.mems + val isResetToChange: String => Boolean = scanResults.resetsToChange.contains + } + + private var mainName: String = "" // for debugging + private def unsupportedError(msg: String): Nothing = + throw new UnsupportedFeatureException(s"StutteringClockTransform: [$mainName] $msg") + + private def mux(cond: ir.Expression, a: ir.Expression, b: ir.Expression): ir.Expression = { + ir.Mux(cond, a, b, Utils.mux_type_and_widths(a, b)) + } + private def and(a: ir.Expression, b: ir.Expression): ir.Expression = + ir.DoPrim(PrimOps.And, List(a, b), List(), Bool) + private val Bool = ir.UIntType(ir.IntWidth(1)) +} + +private class UnsupportedFeatureException(s: String) extends PassException(s)
\ No newline at end of file diff --git a/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala new file mode 100644 index 00000000..56f891e6 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala @@ -0,0 +1,61 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class Btor2Spec extends SMTBackendBaseSpec { + + it should "convert a hello world module" in { + val src = + """circuit m: + | module m: + | input clock: Clock + | input a: UInt<8> + | output b: UInt<16> + | b <= a + | assert(clock, eq(a, b), UInt(1), "") + |""".stripMargin + + val expected = + """1 sort bitvec 8 + |2 input 1 a + |3 sort bitvec 16 + |4 uext 3 2 8 + |5 output 4 ; b + |6 sort bitvec 1 + |7 uext 3 2 8 + |8 eq 6 7 4 + |9 not 6 8 + |10 bad 9 ; assert_ + |""".stripMargin + + assert(toBotr2Str(src) == expected) + } + + it should "include FileInfo in the output" in { + val src = + """circuit m: @[circuit 0:0] + | module m: @[module 0:0] + | input clock: Clock @[clock 0:0] + | input a: UInt<8> @[a 0:0] + | output b: UInt<16> @[b 0:0] + | b <= a @[b_a 0:0] + | assert(clock, eq(a, b), UInt(1), "") @[assert 0:0] + |""".stripMargin + + val expected = + """; @ module 0:0 + |1 sort bitvec 8 + |2 input 1 a ; @ a 0:0 + |3 sort bitvec 16 + |4 uext 3 2 8 + |5 output 4 ; b @ b 0:0, b_a 0:0 + |6 sort bitvec 1 + |7 uext 3 2 8 + |8 eq 6 7 4 + |9 not 6 8 + |10 bad 9 ; assert_ @ assert 0:0 + |""".stripMargin + + assert(toBotr2Str(src) == expected) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala new file mode 100644 index 00000000..015ac4a9 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala @@ -0,0 +1,224 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { + + def primop(op: String, resTpe: String, inTpes: Seq[String], consts: Seq[Int]): String = { + val inputs = inTpes.zipWithIndex.map { case (tpe, ii) => s" input i$ii : $tpe" }.mkString("\n") + val args = (inTpes.zipWithIndex.map { case (_, ii) => s"i$ii" } ++ consts.map(_.toString)).mkString(", ") + val src = + s"""circuit m: + | module m: + |$inputs + | output res: $resTpe + | res <= $op($args) + | + |""".stripMargin + val sys = toSys(src) + assert(sys.signals.length == 1) + sys.signals.head.e.toString + } + + def primop(signed: Boolean, op: String, resWidth: Int, inWidth: Seq[Int], consts: Seq[Int] = List(), + resAlwaysUnsigned: Boolean = false): String = { + val tpe = if(signed) "SInt" else "UInt" + val resTpe = if(resAlwaysUnsigned) "UInt" else tpe + val inTpes = inWidth.map(w => s"$tpe<$w>") + primop(op, s"$resTpe<$resWidth>", inTpes, consts) + } + + it should "correctly translate the add primitive operation with different operand sizes" in { + assert(primop(false, "add", 5, List(3, 5)) == "add(zext(i0, 3), zext(i1, 1))[4:0]") + assert(primop(false, "add", 5, List(3, 4)) == "add(zext(i0, 2), zext(i1, 1))") + assert(primop(true, "add", 5, List(3, 5)) == "add(sext(i0, 3), sext(i1, 1))[4:0]") + assert(primop(true, "add", 5, List(3, 4)) == "add(sext(i0, 2), sext(i1, 1))") + + // could be simplified to just `add(i0, i1)` + assert(primop(false, "add", 8, List(8, 8)) == "add(zext(i0, 1), zext(i1, 1))[7:0]") + } + + it should "correctly translate the `add` primitive operation" in { + assert(primop(false, "add", 8, List(7, 7)) == "add(zext(i0, 1), zext(i1, 1))") + } + + it should "correctly translate the `sub` primitive operation" in { + assert(primop(false, "sub", 8, List(7, 7)) == "sub(zext(i0, 1), zext(i1, 1))") + } + + it should "correctly translate the `mul` primitive operation" in { + assert(primop(false, "mul", 8, List(4, 4)) == "mul(zext(i0, 4), zext(i1, 4))") + } + + it should "correctly translate the `div` primitive operation" in { + // division is a little bit more complicated because the result of division by zero is undefined + assert(primop(false, "div", 8, List(8, 8)) == + "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))") + assert(primop(false, "div", 8, List(8, 4)) == + "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))") + + // signed division increases result width by 1 + assert(primop(true, "div", 8, List(7, 7)) == + "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))") + assert(primop(true, "div", 8, List(7, 4)) + == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))") + } + + it should "correctly translate the `rem` primitive operation" in { + // rem can decrease the size of operands, but we should only do that decrease on the result + assert(primop(false, "rem", 4, List(4, 8)) == "urem(zext(i0, 4), i1)[3:0]") + assert(primop(false, "rem", 4, List(8, 4)) == "urem(i0, zext(i1, 4))[3:0]") + assert(primop(true, "rem", 4, List(4, 8)) == "srem(sext(i0, 4), i1)[3:0]") + assert(primop(true, "rem", 4, List(8, 4)) == "srem(i0, sext(i1, 4))[3:0]") + // TODO: add test to make sure we are using the correct mod/rem operation for signed and unsigned + // https://groups.google.com/g/stp-users/c/od43h8q5RSI has some tests that we could copy and + // use with a SMT solver + } + + it should "correctly translate the comparison primitive operations" in { + // some comparisons are represented as the negation of others + assert(primop(false, "lt", 1, List(8, 8)) == "not(ugeq(i0, i1))") + assert(primop(false, "leq", 1, List(8, 8)) == "not(ugt(i0, i1))") + assert(primop(false, "gt", 1, List(8, 8)) == "ugt(i0, i1)") + assert(primop(false, "geq", 1, List(8, 8)) == "ugeq(i0, i1)") + assert(primop(false, "eq", 1, List(8, 8)) == "eq(i0, i1)") + assert(primop(false, "neq", 1, List(8, 8)) == "not(eq(i0, i1))") + + assert(primop(true, "lt", 1, List(8, 8), resAlwaysUnsigned = true) == "not(sgeq(i0, i1))") + assert(primop(true, "leq", 1, List(8, 8), resAlwaysUnsigned = true) == "not(sgt(i0, i1))") + assert(primop(true, "gt", 1, List(8, 8), resAlwaysUnsigned = true) == "sgt(i0, i1)") + assert(primop(true, "geq", 1, List(8, 8), resAlwaysUnsigned = true) == "sgeq(i0, i1)") + assert(primop(true, "eq", 1, List(8, 8), resAlwaysUnsigned = true) == "eq(i0, i1)") + assert(primop(true, "neq", 1, List(8, 8), resAlwaysUnsigned = true) == "not(eq(i0, i1))") + + // it should always extend the width to the max of both + assert(primop(false, "gt", 1, List(7, 8)) == "ugt(zext(i0, 1), i1)") + } + + it should "correctly translate the `pad` primitive operation" in { + // firrtl pad takes new width as argument, whereas the smt zext takes the number of bits to extend by + assert(primop(false, "pad", 8, List(3), List(8)) == "zext(i0, 5)") + assert(primop(false, "pad", 8, List(3), List(5)) == "zext(zext(i0, 2), 3)") + + // there is no negative padding, instead the result is just e + assert(primop(false, "pad", 3, List(3), List(2)) == "i0") + + assert(primop(true, "pad", 8, List(3), List(8)) == "sext(i0, 5)") + assert(primop(true, "pad", 8, List(3), List(5)) == "sext(sext(i0, 2), 3)") + } + + it should "correctly translate the asX primitive operations" in { + // these are all essentially no-ops + assert(primop(false, "asUInt", 3, List(3)) == "i0") + assert(primop(true, "asSInt", 3, List(3)) == "i0") + } + + it should "correctly translate the `shl` primitive operation" in { + assert(primop(false, "shl", 6, List(3), List(3)) == "concat(i0, 3'b0)") + assert(primop(true, "shl", 6, List(3), List(3)) == "concat(i0, 3'b0)") + assert(primop(false, "shl", 3, List(3), List(0)) == "i0") + } + + it should "correctly translate the `shr` primitive operation" in { + assert(primop(false, "shr", 6, List(9), List(3)) == "i0[8:3]") + assert(primop(true, "shr", 6, List(9), List(3)) == "i0[8:3]") + + // "If n is greater than or equal to the bit-width of e, + // the resulting value will be zero for unsigned types and the sign bit for signed types." + assert(primop(false, "shr", 1, List(3), List(3)) == "1'b0") + assert(primop(false, "shr", 1, List(3), List(4)) == "1'b0") + assert(primop(true, "shr", 1, List(3), List(3)) == "i0[2]") + assert(primop(true, "shr", 1, List(3), List(4)) == "i0[2]") + } + + it should "correctly translate the `dshl` primitive operation" in { + assert(primop(false, "dshl", 31, List(16, 4)) == "logical_shift_left(zext(i0, 15), zext(i1, 27))") + assert(primop(false, "dshl", 19, List(16, 2)) == "logical_shift_left(zext(i0, 3), zext(i1, 17))") + assert(primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) == + "logical_shift_left(sext(i0, 3), zext(i1, 17))") + } + + it should "correctly translate the `dshr` primitive operation" in { + assert(primop(false, "dshr", 16, List(16, 4)) == "logical_shift_right(i0, zext(i1, 12))") + assert(primop(false, "dshr", 16, List(16, 2)) == "logical_shift_right(i0, zext(i1, 14))") + assert(primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) == + "arithmetic_shift_right(i0, zext(i1, 14))") + } + + it should "correctly translate the `cvt` primitive operation" in { + // for signed operands, this is a no-op + assert(primop(true, "cvt", 3, List(3)) == "i0") + + // for unsigned, a zero is prepended + assert(primop("cvt", "SInt<16>", List("UInt<15>"), List()) == "concat(1'b0, i0)") + assert(primop("cvt", "SInt<16>", List("UInt<14>"), List()) == "sext(concat(1'b0, i0), 1)") + } + + it should "correctly translate the `neg` primitive operation" in { + assert(primop(true, "neg", 4, List(3)) == "neg(sext(i0, 1))") + assert(primop("neg", "SInt<4>", List("UInt<3>"), List()) == "neg(zext(i0, 1))") + } + + it should "correctly translate the `not` primitive operation" in { + assert(primop(false, "not", 4, List(4)) == "not(i0)") + assert(primop("not", "UInt<4>", List("SInt<4>"), List()) == "not(i0)") + } + + it should "correctly translate the binary bitwise primitive operations" in { + assert(primop(false, "and", 4, List(4, 3)) == "and(i0, zext(i1, 1))") + assert(primop("and", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "and(i0, sext(i1, 1))") + + assert(primop(false, "or", 4, List(4, 3)) == "or(i0, zext(i1, 1))") + assert(primop("or", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "or(i0, sext(i1, 1))") + + assert(primop(false, "xor", 4, List(4, 3)) == "xor(i0, zext(i1, 1))") + assert(primop("xor", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "xor(i0, sext(i1, 1))") + } + + it should "correctly translate the bitwise reduction primitive operation" in { + // zero width special cases are removed by the firrtl compiler + assert(primop(false, "andr", 1, List(0)) == "1'b1") + assert(primop(false, "orr", 1, List(0)) == "redor(1'b0)") + assert(primop(false, "xorr", 1, List(0)) == "redxor(1'b0)") + + assert(primop(false, "andr", 1, List(3)) == "redand(i0)") + assert(primop(true, "andr", 1, List(3), resAlwaysUnsigned = true) == "redand(i0)") + + assert(primop(false, "orr", 1, List(3)) == "redor(i0)") + assert(primop(true, "orr", 1, List(3), resAlwaysUnsigned = true) == "redor(i0)") + + assert(primop(false, "xorr", 1, List(3)) == "redxor(i0)") + assert(primop(true, "xorr", 1, List(3), resAlwaysUnsigned = true) == "redxor(i0)") + } + + it should "correctly translate the `cat` primitive operation" in { + assert(primop(false, "cat", 7, List(4, 3)) == "concat(i0, i1)") + assert(primop(true, "cat", 7, List(4, 3), resAlwaysUnsigned = true) == "concat(i0, i1)") + } + + it should "correctly translate the `bits` primitive operation" in { + assert(primop(false, "bits", 1, List(4), List(2,2)) == "i0[2]") + assert(primop(false, "bits", 2, List(4), List(2,1)) == "i0[2:1]") + assert(primop(false, "bits", 1, List(4), List(2,1)) == "i0[2:1][0]") + assert(primop(false, "bits", 3, List(4), List(2,1)) == "zext(i0[2:1], 1)") + + assert(primop(true, "bits", 1, List(4), List(2,2), resAlwaysUnsigned = true) == "i0[2]") + assert(primop(true, "bits", 2, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1]") + assert(primop(true, "bits", 1, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1][0]") + assert(primop(true, "bits", 3, List(4), List(2,1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)") + } + + it should "correctly translate the `head` primitive operation" in { + // "The result of the head operation are the n most significant bits of e" + assert(primop(false, "head", 1, List(4), List(1)) == "i0[3]") + assert(primop(false, "head", 1, List(5), List(1)) == "i0[4]") + assert(primop(false, "head", 3, List(5), List(3)) == "i0[4:2]") + } + + it should "correctly translate the `tail` primitive operation" in { + // "The tail operation truncates the n most significant bits from e" + assert(primop(false, "tail", 3, List(4), List(1)) == "i0[2:0]") + assert(primop(false, "tail", 4, List(5), List(1)) == "i0[3:0]") + assert(primop(false, "tail", 2, List(5), List(3)) == "i0[1:0]") + } +}
\ No newline at end of file diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala new file mode 100644 index 00000000..ca7974c5 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala @@ -0,0 +1,314 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +import firrtl.{MemoryArrayInit, MemoryScalarInit, Utils} + +private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { + behavior of "ModuleToTransitionSystem.run" + + + it should "model registers as state" in { + // if a signal is invalid, it could take on an arbitrary value in that cycle + val src = + """circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input en : UInt<1> + | input in : UInt<8> + | output out : UInt<8> + | + | reg r : UInt<8>, clock with : (reset => (reset, UInt<8>(0))) + | when en: + | r <= in + | out <= r + | + |""".stripMargin + val sys = toSys(src) + + assert(sys.signals.length == 2) + + // the when is translated as a ITE + val genSignal = sys.signals.filterNot(_.name == "out").head + assert(genSignal.e.toString == "ite(en, in, r)") + + // the reset is synchronous + val r = sys.states.head + assert(r.sym.name == "r") + assert(r.init.isEmpty, "we are not using any preset, so the initial register content is arbitrary") + assert(r.next.get.toString == s"ite(reset, 8'b0, ${genSignal.name})") + } + + private def memCircuit(depth: Int = 32) = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<${Utils.getUIntWidth(depth)}> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => $depth + | reader => r + | writer => w + | read-latency => 0 + | write-latency => 1 + | read-under-write => new + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + it should "model memories as state" in { + val sys = toSys(memCircuit()) + + assert(sys.signals.length == 9-2+1, "9 connects - 2 clock connects + 1 combinatorial read port") + + val sig = sys.signals.map(s => s.name -> s.e).toMap + + // masks and enables should all be true + val True = BVLiteral(1, 1) + assert(sig("m.w.mask") == True) + assert(sig("m.w.en") == True) + assert(sig("m.r.en") == True) + + // read data should always be enabled + assert(sig("m.r.data").toString == "m[m.r.addr]") + + // the memory is modelled as a state + val m = sys.states.find(_.sym.name == "m").get + assert(m.sym.isInstanceOf[ArraySymbol]) + val sym = m.sym.asInstanceOf[ArraySymbol] + assert(sym.indexWidth == 5) + assert(sym.dataWidth == 8) + assert(m.init.isEmpty) + //assert(m.next.get.toString.contains("m[m.w.addr := m.w.data]")) + assert(m.next.get.toString == "m[m.w.addr := m.w.data]") + } + + it should "support scalar initialization of a memory to 0" in { + val sys = toSys(memCircuit(), memInit = Map("m" -> MemoryScalarInit(0))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 32)") + } + + it should "support scalar initialization of a memory to 127" in { + val sys = toSys(memCircuit(31), memInit = Map("m" -> MemoryScalarInit(127))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b1111111] x 32)") + } + + it should "support array initialization of a memory to Seq(0, 1, 2, 3)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(0, 1, 2, 3)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 4)[2'b1 := 8'b1][2'b10 := 8'b10][2'b11 := 8'b11]") + } + + it should "support array initialization of a memory to Seq(1, 0, 1, 0)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(1, 0, 1, 0)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b1] x 4)[2'b1 := 8'b0][2'b11 := 8'b0]") + } + + it should "support array initialization of a memory to Seq(1, 0, 0, 0)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(1, 0, 0, 0)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 4)[2'b0 := 8'b1]") + } + + it should "support array initialization from a file" ignore { + assert(false, "TODO") + } + + it should "support memories with registered read port" in { + def src(readUnderWrite: String) = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w1, w2 + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w1.clk <= clock + | m.w1.mask <= UInt(1) + | m.w1.en <= UInt(1) + | m.w1.data <= in + | m.w1.addr <= addr + | m.w2.clk <= clock + | m.w2.mask <= UInt(1) + | m.w2.en <= UInt(1) + | m.w2.data <= in + | m.w2.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + + val oldValue = toSys(src("old")) + val oldMData = oldValue.states.find(_.sym.name == "m.r.data").get + assert(oldMData.sym.toString == "m.r.data") + assert(oldMData.next.get.toString == "m[m.r.addr]", "we just need to read the current value") + val readDataSignal = oldValue.signals.find(_.name == "m.r.data") + assert(readDataSignal.isEmpty, s"${readDataSignal.map(_.toString)} should not exist") + + val undefinedValue = toSys(src("undefined")) + val undefinedMData = undefinedValue.states.find(_.sym.name == "m.r.data").get + assert(undefinedMData.sym.toString == "m.r.data") + val undefined = "RANDOM.m_r_read_under_write_undefined" + assert(undefinedMData.next.get.toString == + s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])", + "randomize result if there is a write") + } + + it should "support memories with potential write-write conflicts" in { + val src = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w1, w2 + | read-latency => 1 + | write-latency => 1 + | read-under-write => undefined + | + | m.w1.clk <= clock + | m.w1.mask <= UInt(1) + | m.w1.en <= UInt(1) + | m.w1.data <= in + | m.w1.addr <= addr + | m.w2.clk <= clock + | m.w2.mask <= UInt(1) + | m.w2.en <= UInt(1) + | m.w2.data <= in + | m.w2.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + + val sys = toSys(src) + val m = sys.states.find(_.sym.name == "m").get + + val regularUpdate = "m[m.w1.addr := m.w1.data][m.w2.addr := m.w2.data]" + val collision = "eq(m.w1.addr, m.w2.addr)" + val collisionUpdate = "m[m.w1.addr := RANDOM.m_w1_write_write_collision]" + + assert(m.next.get.toString == s"ite($collision, $collisionUpdate, $regularUpdate)") + } + + it should "model invalid signals as inputs" in { + // if a signal is invalid, it could take on an arbitrary value in that cycle + val src = + """circuit m: + | module m: + | input en : UInt<1> + | output o : UInt<8> + | o is invalid + | when en: + | o <= UInt<8>(0) + |""".stripMargin + val sys = toSys(src) + assert(sys.inputs.length == 2) + val random = sys.inputs.filter(_.name.contains("RANDOM")) + assert(random.length == 1) + assert(random.head.width == 8) + } + + it should "throw an error on async reset" in { + val err = intercept[AsyncResetException] { + toSys( + """circuit m: + | module m: + | input reset : AsyncReset + |""".stripMargin + ) + } + assert(err.getMessage.contains("reset")) + } + + it should "throw an error on casting to async reset" in { + val err = intercept[AssertionError] { + toSys( + """circuit m: + | module m: + | input reset : UInt<1> + | node async = asAsyncReset(reset) + |""".stripMargin + ) + } + assert(err.getMessage.contains("reset")) + } + + it should "throw an error on multiple clocks" in { + val err = intercept[MultiClockException] { + toSys( + """circuit m: + | module m: + | input clk1 : Clock + | input clk2 : Clock + |""".stripMargin + ) + } + assert(err.getMessage.contains("clk1, clk2")) + } + + it should "throw an error on using a clock as uInt" in { + // While this could potentially be supported in the future, for now we do not allow + // a clock to be used for anything besides updating registers and memories. + val err = intercept[AssertionError] { + toSys( + """circuit m: + | module m: + | input clk : Clock + | output o : UInt<1> + | o <= asUInt(clk) + | + |""".stripMargin + ) + } + assert(err.getMessage.contains("clk")) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala new file mode 100644 index 00000000..6bfb5437 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala @@ -0,0 +1,38 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +import firrtl.annotations.Annotation +import firrtl.{MemoryInitValue, ir} +import firrtl.stage.{Forms, TransformManager} +import org.scalatest.flatspec.AnyFlatSpec + +private abstract class SMTBackendBaseSpec extends AnyFlatSpec { + private val dependencies = Forms.LowForm + private val compiler = new TransformManager(dependencies) + + protected def compile(src: String, annos: Seq[Annotation] = List()): ir.Circuit = { + val c = firrtl.Parser.parse(src) + compiler.runTransform(firrtl.CircuitState(c, annos)).circuit + } + + protected def toSys(src: String, mod: String = "m", presetRegs: Set[String] = Set(), + memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = { + val circuit = compile(src) + val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module] + // println(module.serialize) + new ModuleToTransitionSystem().run(module, presetRegs = presetRegs, memInit = memInit) + } + + protected def toBotr2(src: String, mod: String = "m"): Iterable[String] = + Btor2Serializer.serialize(toSys(src, mod)) + + protected def toBotr2Str(src: String, mod: String = "m"): String = + toBotr2(src, mod).mkString("\n") + "\n" + + protected def toSMTLib(src: String, mod: String = "m"): Iterable[String] = + SMTTransitionSystemEncoder.encode(toSys(src, mod)).map(SMTLibSerializer.serialize) + + protected def toSMTLibStr(src: String, mod: String = "m"): String = + toSMTLib(src, mod).mkString("\n") + "\n" +}
\ No newline at end of file diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala new file mode 100644 index 00000000..7193474d --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala @@ -0,0 +1,66 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class SMTLibSpec extends SMTBackendBaseSpec { + + it should "convert a hello world module" in { + val src = + """circuit m: + | module m: + | input clock: Clock + | input a: UInt<8> + | output b: UInt<16> + | b <= a + | assert(clock, eq(a, b), UInt(1), "") + |""".stripMargin + + val expected = + """(declare-sort m_s 0) + |; firrtl-smt2-input a 8 + |(declare-fun a_f (m_s) (_ BitVec 8)) + |; firrtl-smt2-output b 16 + |(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state))) + |; firrtl-smt2-assert assert_ 1 + |(define-fun assert__f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state))) + |(define-fun m_t ((state m_s) (state_n m_s)) Bool true) + |(define-fun m_i ((state m_s)) Bool true) + |(define-fun m_a ((state m_s)) Bool (assert__f state)) + |(define-fun m_u ((state m_s)) Bool true) + |""".stripMargin + + assert(toSMTLibStr(src) == expected) + } + + it should "include FileInfo in the output" in { + val src = + """circuit m: @[circuit 0:0] + | module m: @[module 0:0] + | input clock: Clock @[clock 0:0] + | input a: UInt<8> @[a 0:0] + | output b: UInt<16> @[b 0:0] + | b <= a @[b_a 0:0] + | assert(clock, eq(a, b), UInt(1), "") @[assert 0:0] + |""".stripMargin + + val expected = + """; @ module 0:0 + |(declare-sort m_s 0) + |; firrtl-smt2-input a 8 + |; @ a 0:0 + |(declare-fun a_f (m_s) (_ BitVec 8)) + |; firrtl-smt2-output b 16 + |; @ b 0:0, b_a 0:0 + |(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state))) + |; firrtl-smt2-assert assert_ 1 + |; @ assert 0:0 + |(define-fun assert__f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state))) + |(define-fun m_t ((state m_s) (state_n m_s)) Bool true) + |(define-fun m_i ((state m_s)) Bool true) + |(define-fun m_a ((state m_s)) Bool (assert__f state)) + |(define-fun m_u ((state m_s)) Bool true) + |""".stripMargin + + assert(toSMTLibStr(src) == expected) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala new file mode 100644 index 00000000..4c6901ea --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala @@ -0,0 +1,46 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import firrtl.annotations.CircuitTarget +import firrtl.backends.experimental.smt.{GlobalClockAnnotation, StutteringClockTransform} +import firrtl.options.Dependency +import firrtl.stage.RunFirrtlTransformAnnotation + +class AsyncResetSpec extends EndToEndSMTBaseSpec { + def annos(name: String) = Seq( + RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]), + GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock"))) + + "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs(RequiresZ3) in { + def in(resetType: String) = + s"""circuit AsyncReset00: + | module AsyncReset00: + | input global_clock: Clock + | input c: Clock + | input reset: $resetType + | input preset: AsyncReset + | + | ; a register with async reset + | reg r: UInt<4>, c with: (reset => (reset, UInt(3))) + | + | ; a counter/toggler connected to the clock c + | reg count: UInt<1>, c with: (reset => (preset, UInt(0))) + | count <= add(count, UInt(1)) + | + | ; the past machinery and the assertion uses the global clock + | reg past_valid: UInt<1>, global_clock with: (reset => (preset, UInt(0))) + | past_valid <= UInt(1) + | reg past_r: UInt<4>, global_clock + | past_r <= r + | reg past_count: UInt<1>, global_clock + | past_count <= count + | + | ; can the value of r change without the count changing? + | assert(global_clock, or(not(eq(count, past_count)), eq(r, past_r)), past_valid, "count = past(count) |-> r = past(r)") + |""".stripMargin + test(in("AsyncReset"), MCFail(1), kmax=2, annos=annos("AsyncReset00")) + test(in("UInt<1>"), MCSuccess, kmax=2, annos=annos("AsyncReset00")) + } + +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala new file mode 100644 index 00000000..2227719b --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala @@ -0,0 +1,230 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import java.io.{File, PrintWriter} + +import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation} +import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter} +import firrtl.options.TargetDirAnnotation +import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlStage, OutputFileAnnotation, RunFirrtlTransformAnnotation} +import firrtl.util.BackendCompilationUtilities +import logger.{LazyLogging, LogLevel, LogLevelAnnotation} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.must.Matchers + +import scala.sys.process._ + +class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { + "we" should "check if Z3 is available" taggedAs(RequiresZ3) in { + val log = ProcessLogger(_ => (), logger.warn(_)) + val ret = Process(Seq("which", "z3")).run(log).exitValue() + if(ret != 0) { + logger.error( + """The z3 SMT-Solver seems not to be installed. + |You can exclude the end-to-end smt backend tests which rely on z3 like this: + |sbt testOnly -- -l RequiresZ3 + |""".stripMargin) + } + assert(ret == 0) + } + + "Z3" should "be available in version 4" taggedAs(RequiresZ3) in { + assert(Z3ModelChecker.getZ3Version.startsWith("4.")) + } + + "a simple combinatorial check" should "pass" taggedAs(RequiresZ3) in { + val in = + """circuit CC00: + | module CC00: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, lt(add(a, b), UInt(3)), UInt(1), "a + b < 3") + |""".stripMargin + test(in, MCSuccess) + } + + "a simple combinatorial check" should "fail immediately" taggedAs(RequiresZ3) in { + val in = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, gt(add(a, b), UInt(3)), UInt(1), "a + b > 3") + |""".stripMargin + test(in, MCFail(0)) + } + + "adding the right assumption" should "make a test pass" taggedAs(RequiresZ3) in { + val in0 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + |""".stripMargin + val in1 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + | assume(c, neq(a, UInt(0)), UInt(1), "a != 0") + |""".stripMargin + test(in0, MCFail(0)) + test(in1, MCSuccess) + + val in2 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | input en: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + | assume(c, neq(a, UInt(0)), en, "a != 0 if en") + |""".stripMargin + test(in2, MCFail(0)) + } + + "a register connected to preset reset" should "be initialized with the reset value" taggedAs(RequiresZ3) in { + def in(rEq: Int) = + s"""circuit Preset00: + | module Preset00: + | input c: Clock + | input preset: AsyncReset + | reg r: UInt<4>, c with: (reset => (preset, UInt(3))) + | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq") + |""".stripMargin + test(in(3), MCSuccess, kmax = 1) + test(in(2), MCFail(0)) + } + + "a register's initial value" should "should not change" taggedAs(RequiresZ3) in { + val in = + """circuit Preset00: + | module Preset00: + | input c: Clock + | input preset: AsyncReset + | + | ; the past value of our register will only be valid in the 1st unrolling + | reg past_valid: UInt<1>, c with: (reset => (preset, UInt(0))) + | past_valid <= UInt(1) + | + | reg r: UInt<4>, c + | reg r_past: UInt<4>, c + | r_past <= r + | assert(c, eq(r, r_past), past_valid, "past_valid => r == r_past") + |""".stripMargin + test(in, MCSuccess, kmax = 2) + } +} + +abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers { + def test(src: String, expected: MCResult, kmax: Int = 0, clue: String = "", annos: Seq[Annotation] = Seq()): Unit = { + expected match { + case MCFail(k) => assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)") + case _ => + } + val fir = firrtl.Parser.parse(src) + val name = fir.main + val testDir = BackendCompilationUtilities.createTestDirectory("EndToEndSMT." + name) + // we automagically add a preset annotation if an input called preset exists + val presetAnno = if(!src.contains("input preset")) { None } else { + Some(PresetAnnotation(CircuitTarget(name).module(name).ref("preset"))) + } + val res = (new FirrtlStage).execute(Array(), Seq( + LogLevelAnnotation(LogLevel.Error), // silence warnings for tests + RunFirrtlTransformAnnotation(new SMTLibEmitter), + RunFirrtlTransformAnnotation(new Btor2Emitter), + FirrtlCircuitAnnotation(fir), + TargetDirAnnotation(testDir.getAbsolutePath) + ) ++ presetAnno ++ annos) + assert(res.collectFirst{ case _: OutputFileAnnotation => true }.isDefined) + val r = Z3ModelChecker.bmc(testDir, name, kmax) + assert(r == expected, clue + "\n" + s"$testDir") + } +} + +/** Minimal implementation of a Z3 based bounded model checker. + * A more complete version of this with better use feedback should eventually be provided by a + * chisel3 formal verification library. Do not use this implementation outside of the firrtl test suite! + * */ +private object Z3ModelChecker extends LazyLogging { + def getZ3Version: String = { + val (out, ret) = executeCmd("-version") + assert(ret == 0, "failed to call z3") + assert(out.startsWith("Z3 version"), s"$out does not start with 'Z3 version'") + val version = out.split(" ")(2) + version + } + + def bmc(testDir: File, main: String, kmax: Int): MCResult = { + assert(kmax >=0 && kmax < 50, "Trying to keep kmax in a reasonable range.") + val smtFile = new File(testDir, main + ".smt2") + val header = read(smtFile) + val steps = (0 to kmax).map(k => new File(testDir, main + s"_step$k.smt2")).zipWithIndex + steps.foreach { case (f,k) => + writeStep(f, main, header, k) + val success = executeStep(f.getAbsolutePath) + if(!success) return MCFail(k) + } + MCSuccess + } + + private def executeStep(filename: String): Boolean = { + val (out, ret) = executeCmd(filename) + assert(ret == 0, s"expected success (0), not $ret: `$out`\nz3 $filename") + assert(out == "sat" || out == "unsat", s"Unexpected output: $out") + out == "unsat" + } + + private def executeCmd(cmd: String): (String, Int) = { + var out = "" + val log = ProcessLogger(s => out = s, logger.warn(_)) + val ret = Process(Seq("z3", cmd)).run(log).exitValue() + (out, ret) + } + + private def writeStep(f: File, main: String, header: Iterable[String], k: Int): Unit = { + val pw = new PrintWriter(f) + val lines = header ++ step(main, k) ++ List("(check-sat)") + lines.foreach(pw.println) + pw.close() + } + + private def step(main: String, k: Int): Iterable[String] = { + // define all states + (0 to k).map(ii => s"(declare-fun s$ii () $main$StateTpe)") ++ + // assert that init holds in state 0 + List(s"(assert ($main$Init s0))") ++ + // assert transition relation + (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii+1}))") ++ + // assert that assumptions hold in all states + (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++ + // assert that assertions hold for all but last state + (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++ + // check to see if we can violate the assertions in the last state + List(s"(assert (not ($main$Asserts s$k)))") + } + + private def read(f: File): Iterable[String] = { + val source = scala.io.Source.fromFile(f) + try source.getLines().toVector finally source.close() + } + + // the following suffixes have to match the ones in [[SMTTransitionSystemEncoder]] + private val Transition = "_t" + private val Init = "_i" + private val Asserts = "_a" + private val Assumes = "_u" + private val StateTpe = "_s" +} + +private sealed trait MCResult +private case object MCSuccess extends MCResult +private case class MCFail(k: Int) extends MCResult diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala new file mode 100644 index 00000000..10de9cda --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala @@ -0,0 +1,107 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import firrtl.annotations.{CircuitTarget, MemoryArrayInitAnnotation, MemoryScalarInitAnnotation} + +class MemorySpec extends EndToEndSMTBaseSpec { + private def registeredTestMem(name: String, cmds: String, readUnderWrite: String): String = + registeredTestMem(name, cmds.split("\n"), readUnderWrite) + private def registeredTestMem(name: String, cmds: Iterable[String], readUnderWrite: String): String = + s"""circuit $name: + | module $name: + | input reset : UInt<1> + | input clock : Clock + | input preset: AsyncReset + | input write_addr : UInt<5> + | input read_addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= write_addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= read_addr + | + | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0))) + | cycle <= add(cycle, UInt(1)) + | node past_valid = geq(cycle, UInt(1)) + | + | ${cmds.mkString("\n ")} + |""".stripMargin + + "Registered test memory" should "return written data after two cycles" taggedAs(RequiresZ3) in { + val cmds = + """node past_past_valid = geq(cycle, UInt(2)) + |reg past_in: UInt<8>, clock + |past_in <= in + |reg past_past_in: UInt<8>, clock + |past_past_in <= past_in + |reg past_write_addr: UInt<5>, clock + |past_write_addr <= write_addr + | + |assume(clock, eq(read_addr, past_write_addr), past_valid, "read_addr = past(write_addr)") + |assert(clock, eq(out, past_past_in), past_past_valid, "out = past(past(in))") + |""".stripMargin + test(registeredTestMem("Mem00", cmds, "old"), MCSuccess, kmax = 3) + } + + private def readOnlyMem(pred: String, num: Int) = + s"""circuit Mem0$num: + | module Mem0$num: + | input c : Clock + | input read_addr : UInt<2> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 4 + | reader => r + | read-latency => 0 + | write-latency => 1 + | read-under-write => new + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= read_addr + | + | assert(c, $pred, UInt(1), "") + |""".stripMargin + private def m(num: Int) = CircuitTarget(s"Mem0$num").module(s"Mem0$num").ref("m") + + "read-only memory" should "always return 0" taggedAs(RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax=2, + annos=Seq(MemoryScalarInitAnnotation(m(1), 0))) + } + + "read-only memory" should "not always return 1" taggedAs(RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax=2, + annos=Seq(MemoryScalarInitAnnotation(m(2), 0))) + } + + "read-only memory" should "always return 1 or 2" taggedAs(RequiresZ3) in { + test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), MCSuccess, kmax=2, + annos=Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1)))) + } + + "read-only memory" should "not always return 1 or 2 or 3" taggedAs(RequiresZ3) in { + test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), MCFail(0), kmax=2, + annos=Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3)))) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala new file mode 100644 index 00000000..d633a1a0 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala @@ -0,0 +1,9 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import org.scalatest.Tag + +// To disable tests that require the Z3 SMT solver to be installed use the following: +// `sbt testOnly -- -l RequiresZ3` +object RequiresZ3 extends Tag("RequiresZ3") diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala new file mode 100644 index 00000000..cbf194dd --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala @@ -0,0 +1,46 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import java.io.File + +import firrtl.stage.{FirrtlStage, OutputFileAnnotation} +import firrtl.util.BackendCompilationUtilities +import logger.LazyLogging +import org.scalatest.flatspec.AnyFlatSpec + +import scala.sys.process.{Process, ProcessLogger} + +/** compiles the regression tests to SMTLib and parses the result with z3 */ +class SMTCompilationTest extends AnyFlatSpec with LazyLogging { + it should "generate valid SMTLib for AddNot" taggedAs(RequiresZ3) in { compileAndParse("AddNot") } + it should "generate valid SMTLib for FPU" taggedAs(RequiresZ3) in { compileAndParse("FPU") } + // we get a stack overflow in Scala 2.11 because of a deeply nested and(...) expression in the sequencer + it should "generate valid SMTLib for HwachaSequencer" taggedAs(RequiresZ3) ignore { compileAndParse("HwachaSequencer") } + it should "generate valid SMTLib for ICache" taggedAs(RequiresZ3) in { compileAndParse("ICache") } + it should "generate valid SMTLib for Ops" taggedAs(RequiresZ3) in { compileAndParse("Ops") } + // TODO: enable Rob test once we support more than 2 write ports on a memory + it should "generate valid SMTLib for Rob" taggedAs(RequiresZ3) ignore { compileAndParse("Rob") } + it should "generate valid SMTLib for RocketCore" taggedAs(RequiresZ3) in { compileAndParse("RocketCore") } + + private def compileAndParse(name: String): Unit = { + val testDir = BackendCompilationUtilities.createTestDirectory(name + "-smt") + val inputFile = new File(testDir, s"${name}.fir") + BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", inputFile) + + val args = Array( + "-ll", "error", // surpress warnings to keep test output clean + "--target-dir", testDir.toString, + "-i", inputFile.toString, + "-E", "experimental-smt2" + // "-fct", "firrtl.backends.experimental.smt.StutteringClockTransform" + ) + val res = (new FirrtlStage).execute(args, Seq()) + val fileName = res.collectFirst{ case OutputFileAnnotation(file) => file }.get + + val smtFile = testDir.toString + "/" + fileName + ".smt2" + val log = ProcessLogger(_ => (), logger.error(_)) + val z3Ret = Process(Seq("z3", smtFile)).run(log).exitValue() + assert(z3Ret == 0, s"Failed to parse SMTLib file $smtFile generated for $name") + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala new file mode 100644 index 00000000..8fa80b4c --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala @@ -0,0 +1,40 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +/** undefined values in firrtl are modelled as fresh auxiliary variables (inputs) */ +class UndefinedFirrtlSpec extends EndToEndSMTBaseSpec { + + "division by zero" should "result in an arbitrary value" taggedAs(RequiresZ3) in { + // the SMTLib spec defines the result of division by zero to be all 1s + // https://cs.nyu.edu/pipermail/smt-lib/2015/000977.html + def in(dEq: Int) = + s"""circuit CC00: + | module CC00: + | input c: Clock + | input a: UInt<2> + | input b: UInt<2> + | assume(c, eq(b, UInt(0)), UInt(1), "b = 0") + | node d = div(a, b) + | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq") + |""".stripMargin + // we try to assert that (d = a / 0) is any fixed value which should be false + (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"d = a / 0 = $ii") } + } + + // TODO: rem should probably also be undefined, but the spec isn't 100% clear here + + + "invalid signals" should "have an arbitrary values" taggedAs(RequiresZ3) in { + def in(aEq: Int) = + s"""circuit CC00: + | module CC00: + | input c: Clock + | wire a: UInt<2> + | a is invalid + | assert(c, eq(a, UInt($aEq)), UInt(1), "a = $aEq") + |""".stripMargin + // a should not be equivalent to any fixed value (0, 1, 2 or 3) + (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"a = $ii") } + } +} |
