aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/Emitter.scala8
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala189
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala182
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala605
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala85
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala196
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala73
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala151
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala128
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala253
10 files changed, 1869 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index 432fa59e..ae9a7dad 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -14,7 +14,8 @@ import firrtl.PrimOps._
import firrtl.WrappedExpression._
import Utils._
import MemPortUtils.{memPortField, memType}
-import firrtl.options.{HasShellOptions, CustomFileEmission, ShellOption, PhaseException}
+import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter}
+import firrtl.options.{CustomFileEmission, HasShellOptions, PhaseException, ShellOption}
import firrtl.options.Viewer.view
import firrtl.stage.{FirrtlFileAnnotation, FirrtlOptions, RunFirrtlTransformAnnotation, TransformManager}
// Datastructures
@@ -49,9 +50,14 @@ object EmitCircuitAnnotation extends HasShellOptions {
EmitCircuitAnnotation(classOf[VerilogEmitter]))
case "sverilog" => Seq(RunFirrtlTransformAnnotation(new SystemVerilogEmitter),
EmitCircuitAnnotation(classOf[SystemVerilogEmitter]))
+ case "experimental-btor2" => Seq(RunFirrtlTransformAnnotation(new Btor2Emitter),
+ EmitCircuitAnnotation(classOf[Btor2Emitter]))
+ case "experimental-smt2" => Seq(RunFirrtlTransformAnnotation(new SMTLibEmitter),
+ EmitCircuitAnnotation(classOf[SMTLibEmitter]))
case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)") },
helpText = "Run the specified circuit emitter (all modules in one file)",
shortOption = Some("E"),
+ // the experimental options are intentionally excluded from the help message
helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>") ) )
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
new file mode 100644
index 00000000..f7ab9927
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
@@ -0,0 +1,189 @@
+// See LICENSE for license details.
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+import scala.collection.mutable
+
+private object Btor2Serializer {
+ def serialize(sys: TransitionSystem, skipOutput: Boolean = false): Iterable[String] = {
+ new Btor2Serializer().run(sys, skipOutput)
+ }
+}
+
+private class Btor2Serializer private () {
+ private val symbols = mutable.HashMap[String, Int]()
+ private val lines = mutable.ArrayBuffer[String]()
+ private var index = 1
+
+ private def line(l: String): Int = {
+ val ii = index
+ lines += s"$ii $l"
+ index += 1
+ ii
+ }
+
+ private def comment(c: String): Unit = { lines += s"; $c" }
+ private def trailingComment(c: String): Unit = {
+ val lastLine = lines.last
+ val newLine = if(lastLine.contains(';')) { lastLine + " " + c} else { lastLine + " ; " + c }
+ lines(lines.size - 1) = newLine
+ }
+
+ // bit vector type serialization
+ private val bitVecTypeCache = mutable.HashMap[Int, Int]()
+
+ private def t(width: Int): Int = bitVecTypeCache.getOrElseUpdate(width, line(s"sort bitvec $width"))
+
+ // bit vector expression serialization
+ private def s(expr: BVExpr): Int = expr match {
+ case BVLiteral(value, width) => lit(value, width)
+ case BVSymbol(name, _) => symbols(name)
+ case BVExtend(e, 0, _) => s(e)
+ case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by")
+ case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by")
+ case BVSlice(e, hi, lo) =>
+ if (lo == 0 && hi == e.width - 1) { s(e) } else {
+ line(s"slice ${t(expr.width)} ${s(e)} $hi $lo")
+ }
+ case BVNot(BVEqual(a, b)) => binary("neq", expr.width, a, b)
+ case BVNot(BVNot(e)) => s(e)
+ case BVNot(e) => unary("not", expr.width, e)
+ case BVNegate(e) => unary("neg", expr.width, e)
+ case BVReduceAnd(e) => unary("redand", expr.width, e)
+ case BVReduceOr(e) => unary("redor", expr.width, e)
+ case BVReduceXor(e) => unary("redxor", expr.width, e)
+ case BVImplies(BVLiteral(v, 1), b) if v == 1 => s(b)
+ case BVImplies(a, b) => binary("implies", expr.width, a, b)
+ case BVEqual(a, b) => binary("eq", expr.width, a, b)
+ case ArrayEqual(a, b) => line(s"eq ${t(expr.width)} ${s(a)} ${s(b)}")
+ case BVComparison(Compare.Greater, a, b, false) => binary("ugt", expr.width, a, b)
+ case BVComparison(Compare.GreaterEqual, a, b, false) => binary("ugte", expr.width, a, b)
+ case BVComparison(Compare.Greater, a, b, true) => binary("sgt", expr.width, a, b)
+ case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b)
+ case BVOp(op, a, b) => binary(s(op), expr.width, a, b)
+ case BVConcat(a, b) => binary("concat", expr.width, a, b)
+ case ArrayRead(array, index) =>
+ line(s"read ${t(expr.width)} ${s(array)} ${s(index)}")
+ case BVIte(cond, tru, fals) =>
+ line(s"ite ${t(expr.width)} ${s(cond)} ${s(tru)} ${s(fals)}")
+ case r : BVRawExpr =>
+ throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}")
+ }
+
+ private def s(op: Op.Value): String = op match {
+ case Op.And => "and"
+ case Op.Or => "or"
+ case Op.Xor => "xor"
+ case Op.ArithmeticShiftRight => "sra"
+ case Op.ShiftRight => "srl"
+ case Op.ShiftLeft => "sll"
+ case Op.Add => "add"
+ case Op.Mul => "mul"
+ case Op.Sub => "sub"
+ case Op.SignedDiv => "sdiv"
+ case Op.UnsignedDiv => "udiv"
+ case Op.SignedMod => "smod"
+ case Op.SignedRem => "srem"
+ case Op.UnsignedRem => "urem"
+ }
+
+ private def unary(op: String, width: Int, e: BVExpr): Int = line(s"$op ${t(width)} ${s(e)}")
+
+ private def binary(op: String, width: Int, a: BVExpr, b: BVExpr): Int =
+ line(s"$op ${t(width)} ${s(a)} ${s(b)}")
+
+ private def lit(value: BigInt, w: Int): Int = {
+ val typ = t(w)
+ lazy val mask = (BigInt(1) << w) - 1
+ if (value == 0) line(s"zero $typ")
+ else if (value == 1) line(s"one $typ")
+ else if (value == mask) line(s"ones $typ")
+ else {
+ val digits = value.toString(2)
+ val padded = digits.reverse.padTo(w, '0').reverse
+ line(s"const $typ $padded")
+ }
+ }
+
+ // array type serialization
+ private val arrayTypeCache = mutable.HashMap[(Int, Int), Int]()
+
+ private def t(indexWidth: Int, dataWidth: Int): Int =
+ arrayTypeCache.getOrElseUpdate((indexWidth, dataWidth), line(s"sort array ${t(indexWidth)} ${t(dataWidth)}"))
+
+ // array expression serialization
+ private def s(expr: ArrayExpr): Int = expr match {
+ case ArraySymbol(name, _, _) => symbols(name)
+ case ArrayStore(array, index, data) =>
+ line(s"write ${t(expr.indexWidth, expr.dataWidth)} ${s(array)} ${s(index)} ${s(data)}")
+ case ArrayIte(cond, tru, fals) =>
+ // println("WARN: ITE on array is probably not supported by btor2")
+ // While the spec does not seem to allow array ite, it seems to be supported in practice.
+ // It is essential to model memories, so any support in the wild should be fairly well tested.
+ line(s"ite ${t(expr.indexWidth, expr.dataWidth)} ${s(cond)} ${s(tru)} ${s(fals)}")
+ case ArrayConstant(e, _) => s(e)
+ case r : ArrayRawExpr =>
+ throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}")
+ }
+
+ private def s(expr: SMTExpr): Int = expr match {
+ case b: BVExpr => s(b)
+ case a: ArrayExpr => s(a)
+ }
+
+ // serialize the type of the expression
+ private def t(expr: SMTExpr): Int = expr match {
+ case b: BVExpr => t(b.width)
+ case a: ArrayExpr => t(a.indexWidth, a.dataWidth)
+ }
+
+ def run(sys: TransitionSystem, skipOutput: Boolean): Iterable[String] = {
+ def declare(name: String, expr: => Int): Unit = {
+ assert(!symbols.contains(name), s"Trying to redeclare `$name`")
+ val id = expr
+ symbols(name) = id
+ if (!skipOutput && sys.outputs.contains(name)) line(s"output $id ; $name")
+ if (sys.assumes.contains(name)) line(s"constraint $id ; $name")
+ if (sys.asserts.contains(name)){
+ val invertedId = line(s"not ${t(1)} $id")
+ line(s"bad $invertedId ; $name")
+ }
+ if (sys.fair.contains(name)) line(s"fair $id ; $name")
+ // add trailing comment
+ sys.comments.get(name).foreach(trailingComment)
+ }
+
+ // header
+ sys.header.foreach(comment)
+
+ // declare inputs
+ sys.inputs.foreach { ii =>
+ declare(ii.name, line(s"input ${t(ii.width)} ${ii.name}"))
+ }
+
+ // define state init
+ sys.states.foreach { st =>
+ // calculate init expression before declaring the state
+ // this is required by btormc (presumably to avoid cycles in the init expression)
+ val initId = st.init.map { init => comment(s"${st.sym}.init"); s(init) }
+ declare(st.sym.name, line(s"state ${t(st.sym)} ${st.sym.name}"))
+ st.init.foreach { init => line(s"init ${t(init)} ${s(st.sym)} ${initId.get}") }
+ }
+
+ // define all other signals
+ sys.signals.foreach { signal =>
+ declare(signal.name, s(signal.e))
+ }
+
+ // define state next
+ sys.states.foreach { st =>
+ st.next.foreach { next =>
+ comment(s"${st.sym}.next")
+ line(s"next ${t(next)} ${s(st.sym)} ${s(next)}")
+ }
+ }
+
+ lines
+ }
+}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
new file mode 100644
index 00000000..0a223840
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
@@ -0,0 +1,182 @@
+// See LICENSE for license details.
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+import firrtl.ir
+import firrtl.PrimOps
+import firrtl.passes.CheckWidths.WidthTooBig
+
+private trait TranslationContext {
+ def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, FirrtlExpressionSemantics.getWidth(tpe))
+ def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe))
+ def getRandom(width: Int): BVExpr
+}
+
+private object FirrtlExpressionSemantics {
+ def getWidth(tpe: ir.Type): Int = tpe match {
+ case ir.UIntType(ir.IntWidth(w)) => w.toInt
+ case ir.SIntType(ir.IntWidth(w)) => w.toInt
+ case ir.ClockType => 1
+ case ir.ResetType => 1
+ case ir.AnalogType(ir.IntWidth(w)) => w.toInt
+ case other => throw new RuntimeException(s"Cannot handle type $other")
+ }
+
+ def toSMT(e: ir.Expression)(implicit ctx: TranslationContext): BVExpr = {
+ val eSMT = e match {
+ case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts)
+ case r : ir.Reference => ctx.getReference(r.serialize, r.tpe)
+ case r : ir.SubField => ctx.getReference(r.serialize, r.tpe)
+ case r : ir.SubIndex => ctx.getReference(r.serialize, r.tpe)
+ case ir.UIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt)
+ case ir.SIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt)
+ case ir.Mux(cond, tval, fval, _) =>
+ val width = List(tval, fval).map(getWidth).max
+ BVIte(toSMT(cond), toSMT(tval, width), toSMT(fval, width))
+ case ir.ValidIf(cond, value, tpe) =>
+ val tru = toSMT(value)
+ BVIte(toSMT(cond), tru, ctx.getRandom(tpe))
+ }
+ assert(eSMT.width == getWidth(e), "We aim to always produce a SMT expression of the same width as the firrtl expression.")
+ eSMT
+ }
+
+ /** Ensures that the result has the desired width by appropriately extending it. */
+ def toSMT(e: ir.Expression, width: Int, allowNarrow: Boolean = false)(implicit ctx: TranslationContext): BVExpr =
+ forceWidth(toSMT(e), isSigned(e), width, allowNarrow)
+
+ private def forceWidth(eSMT: BVExpr, eSigned: Boolean, width: Int, allowNarrow: Boolean = false): BVExpr = {
+ if(eSMT.width == width) { eSMT }
+ else if(width < eSMT.width) {
+ assert(allowNarrow, s"Narrowing from ${eSMT.width} bits to $width bits is not allowed!")
+ BVSlice(eSMT, width - 1, 0)
+ } else {
+ BVExtend(eSMT, width - eSMT.width, eSigned)
+ }
+ }
+
+ // see "Primitive Operations" section in the Firrtl Specification
+ private def onPrim(op: ir.PrimOp, args: Seq[ir.Expression], consts: Seq[BigInt])(implicit ctx: TranslationContext):
+ BVExpr = {
+ (op, args, consts) match {
+ case (PrimOps.Add, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).max + 1
+ BVOp(Op.Add, toSMT(e1, width), toSMT(e2, width))
+ case (PrimOps.Sub, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).max + 1
+ BVOp(Op.Sub, toSMT(e1, width), toSMT(e2, width))
+ case (PrimOps.Mul, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).sum
+ BVOp(Op.Mul, toSMT(e1, width), toSMT(e2, width))
+ case (PrimOps.Div, Seq(num, den), _) =>
+ val (width, op) = if(isSigned(num)) {
+ (getWidth(num) + 1, Op.SignedDiv)
+ } else { (getWidth(num), Op.UnsignedDiv) }
+ // "The result of a division where den is zero is undefined."
+ val undef = ctx.getRandom(width)
+ val denSMT = toSMT(den)
+ val denIsZero = BVEqual(denSMT, BVLiteral(0, denSMT.width))
+ val numByDen = BVOp(op, toSMT(num, width), forceWidth(denSMT, isSigned(den), width))
+ BVIte(denIsZero, undef, numByDen)
+ case (PrimOps.Div, Seq(num, den), _) if isSigned(num) =>
+ val width = getWidth(num) + 1
+ BVOp(Op.SignedDiv, toSMT(num, width), toSMT(den, width))
+ case (PrimOps.Rem, Seq(num, den), _) =>
+ val op = if(isSigned(num)) Op.SignedRem else Op.UnsignedRem
+ val width = args.map(getWidth).max
+ val resWidth = args.map(getWidth).min
+ val res = BVOp(op, toSMT(num, width), toSMT(den, width))
+ if(res.width > resWidth) { BVSlice(res, resWidth - 1, 0) } else { res }
+ case (PrimOps.Lt, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).max
+ BVNot(BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1)))
+ case (PrimOps.Leq, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).max
+ BVNot(BVComparison(Compare.Greater, toSMT(e1, width), toSMT(e2, width), isSigned(e1)))
+ case (PrimOps.Gt, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).max
+ BVComparison(Compare.Greater, toSMT(e1, width), toSMT(e2, width), isSigned(e1))
+ case (PrimOps.Geq, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).max
+ BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1))
+ case (PrimOps.Eq, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).max
+ BVEqual(toSMT(e1, width), toSMT(e2, width))
+ case (PrimOps.Neq, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).max
+ BVNot(BVEqual(toSMT(e1, width), toSMT(e2, width)))
+ case (PrimOps.Pad, Seq(e), Seq(n)) =>
+ val width = getWidth(e)
+ if(n <= width) { toSMT(e) } else { BVExtend(toSMT(e), n.toInt - width, isSigned(e)) }
+ case (PrimOps.AsUInt, Seq(e), _) => checkForClockInCast(PrimOps.AsUInt, e) ; toSMT(e)
+ case (PrimOps.AsSInt, Seq(e), _) => checkForClockInCast(PrimOps.AsSInt, e) ; toSMT(e)
+ case (PrimOps.AsFixedPoint, Seq(e), _) => throw new AssertionError("Fixed-Point numbers need to be lowered!")
+ case (PrimOps.AsClock, Seq(e), _) => toSMT(e)
+ case (PrimOps.AsAsyncReset, Seq(e), _) =>
+ checkForClockInCast(PrimOps.AsAsyncReset, e)
+ throw new AssertionError(s"Asynchronous resets are not supported! Cannot cast ${e.serialize}.")
+ case (PrimOps.Shl, Seq(e), Seq(n)) => if(n == 0) { toSMT(e) } else {
+ val zeros = BVLiteral(0, n.toInt)
+ BVConcat(toSMT(e), zeros)
+ }
+ case (PrimOps.Shr, Seq(e), Seq(n)) =>
+ val width = getWidth(e)
+ // "If n is greater than or equal to the bit-width of e,
+ // the resulting value will be zero for unsigned types
+ // and the sign bit for signed types"
+ if(n >= width) {
+ if(isSigned(e)) { BV1BitZero } else { BVSlice(toSMT(e), width - 1, width - 1) }
+ } else {
+ BVSlice(toSMT(e), width - 1, n.toInt)
+ }
+ case (PrimOps.Dshl, Seq(e1, e2), _) =>
+ val width = getWidth(e1) + (1 << getWidth(e2)) - 1
+ BVOp(Op.ShiftLeft, toSMT(e1, width), toSMT(e2, width))
+ case (PrimOps.Dshr, Seq(e1, e2), _) =>
+ val width = getWidth(e1)
+ val o = if(isSigned(e1)) Op.ArithmeticShiftRight else Op.ShiftRight
+ BVOp(o, toSMT(e1, width), toSMT(e2, width))
+ case (PrimOps.Cvt, Seq(e), _) => if(isSigned(e)) { toSMT(e) } else { BVConcat(BV1BitZero, toSMT(e)) }
+ case (PrimOps.Neg, Seq(e), _) => BVNegate(BVExtend(toSMT(e), 1, isSigned(e)))
+ case (PrimOps.Not, Seq(e), _) => BVNot(toSMT(e))
+ case (PrimOps.And, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).max
+ BVOp(Op.And, toSMT(e1, width), toSMT(e2, width))
+ case (PrimOps.Or, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).max
+ BVOp(Op.Or, toSMT(e1, width), toSMT(e2, width))
+ case (PrimOps.Xor, Seq(e1, e2), _) =>
+ val width = args.map(getWidth).max
+ BVOp(Op.Xor, toSMT(e1, width), toSMT(e2, width))
+ case (PrimOps.Andr, Seq(e), _) => BVReduceAnd(toSMT(e))
+ case (PrimOps.Orr, Seq(e), _) => BVReduceOr(toSMT(e))
+ case (PrimOps.Xorr, Seq(e), _) => BVReduceXor(toSMT(e))
+ case (PrimOps.Cat, Seq(e1, e2), _) => BVConcat(toSMT(e1), toSMT(e2))
+ case (PrimOps.Bits, Seq(e), Seq(hi, lo)) => BVSlice(toSMT(e), hi.toInt, lo.toInt)
+ case (PrimOps.Head, Seq(e), Seq(n)) =>
+ val width = getWidth(e)
+ assert(n >= 0 && n <= width)
+ BVSlice(toSMT(e), width - 1, width - n.toInt)
+ case (PrimOps.Tail, Seq(e), Seq(n)) =>
+ val width = getWidth(e)
+ assert(n >= 0 && n <= width)
+ assert(n < width, "While allowed by the firrtl standard, we do not support 0-bit values in this backend!")
+ BVSlice(toSMT(e), width - n.toInt - 1, 0)
+ }
+ }
+
+ /** For now we strictly forbid casting clocks to anything else.
+ * Eventually this should be replaced by a more sophisticated clock analysis pass. */
+ private def checkForClockInCast(cast: ir.PrimOp, signal: ir.Expression): Unit = {
+ assert(signal.tpe != ir.ClockType, s"Cannot cast (${cast.serialize}) clock expression ${signal.serialize}!")
+ }
+
+ private val BV1BitZero = BVLiteral(0, 1)
+
+ def isSigned(e: ir.Expression): Boolean = e.tpe match {
+ case _: ir.SIntType => true
+ case _ => false
+ }
+ private def getWidth(e: ir.Expression): Int = getWidth(e.tpe)
+}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
new file mode 100644
index 00000000..b3a2ff17
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
@@ -0,0 +1,605 @@
+// See LICENSE for license details.
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation}
+import FirrtlExpressionSemantics.getWidth
+import firrtl.graph.MutableDiGraph
+import firrtl.options.Dependency
+import firrtl.passes.PassException
+import firrtl.stage.Forms
+import firrtl.stage.TransformManager.TransformDependency
+import firrtl.transforms.PropagatePresetAnnotations
+import firrtl.{CircuitState, DependencyAPIMigration, MemoryArrayInit, MemoryInitValue, MemoryScalarInit, Transform, Utils, ir}
+import logger.LazyLogging
+
+import scala.collection.mutable
+
+// Contains code to convert a flat firrtl module into a functional transition system which
+// can then be exported as SMTLib or Btor2 file.
+
+private case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr])
+private case class Signal(name: String, e: BVExpr) { def toSymbol: BVSymbol = BVSymbol(name, e.width) }
+private case class TransitionSystem(
+ name: String, inputs: Array[BVSymbol], states: Array[State], signals: Array[Signal],
+ outputs: Set[String], assumes: Set[String], asserts: Set[String], fair: Set[String],
+ comments: Map[String, String] = Map(), header: Array[String] = Array()) {
+ def serialize: String = {
+ (Iterator(name) ++
+ inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++
+ signals.map(s => s"${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++
+ states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}")
+ ).mkString("\n")
+ }
+}
+
+private case class TransitionSystemAnnotation(sys: TransitionSystem) extends NoTargetAnnotation
+
+object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration {
+ // TODO: We only really need [[Forms.MidForm]] + LowerTypes, but we also want to fail if there are CombLoops
+ // TODO: We also would like to run some optimization passes, but RemoveValidIf won't allow us to model DontCare
+ // precisely and PadWidths emits ill-typed firrtl.
+ override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm
+ override def invalidates(a: Transform): Boolean = false
+ // since this pass only runs on the main module, inlining needs to happen before
+ override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances])
+
+ // We run the propagate preset annotations pass manually since we do not want to remove ValidIfs and other
+ // Verilog emission passes.
+ // Ideally we would go in and enable the [[PropagatePresetAnnotations]] to only depend on LowForm.
+ private val presetPass = new PropagatePresetAnnotations
+ override protected def execute(state: CircuitState): CircuitState = {
+ // run the preset pass to extract all preset registers and remove preset reset signals
+ val afterPreset = presetPass.execute(state)
+ val circuit = afterPreset.circuit
+ val presetRegs = afterPreset.annotations
+ .collect { case PresetRegAnnotation(target) if target.module == circuit.main => target.ref }.toSet
+
+ // collect all non-random memory initialization
+ val memInit = afterPreset.annotations.collect { case a: MemoryInitAnnotation if !a.isRandomInit => a }
+ .filter(_.target.module == circuit.main).map(a => a.target.ref -> a.initValue).toMap
+
+ // convert the main module
+ val main = circuit.modules.find(_.name == circuit.main).get
+ val sys = main match {
+ case x: ir.ExtModule =>
+ throw new ExtModuleException(
+ "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog.")
+ case m: ir.Module =>
+ new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit=memInit)
+ }
+
+ val sortedSys = TopologicalSort.run(sys)
+ val anno = TransitionSystemAnnotation(sortedSys)
+ state.copy(circuit=circuit, annotations = afterPreset.annotations :+ anno )
+ }
+}
+
+private object UnsupportedException {
+ val HowToRunStuttering: String =
+ """
+ |You can run the StutteringClockTransform which
+ |replaces all clock inputs with a clock enable signal.
+ |This is required not only for multi-clock designs, but also to
+ |accurately model asynchronous reset which could happen even if there
+ |isn't a clock edge.
+ | If you are using the firrtl CLI, please add:
+ | -fct firrtl.backends.experimental.smt.StutteringClockTransform
+ | If you are calling into firrtl programmatically you can use:
+ | RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform])
+ | To designate a clock to be the global_clock (i.e. the simulation tick), use:
+ | GlobalClockAnnotation(CircuitTarget(...).module(...).ref("your_clock")))
+ |""".stripMargin
+}
+
+private class ExtModuleException(s: String) extends PassException(s)
+private class AsyncResetException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering)
+private class MultiClockException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering)
+private class MissingFeatureException(s: String) extends PassException("Unfortunately the SMT backend does not yet support: " + s)
+
+private class ModuleToTransitionSystem extends LazyLogging {
+
+ def run(m: ir.Module, presetRegs: Set[String] = Set(), memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = {
+ // first pass over the module to convert expressions; discover state and I/O
+ val scan = new ModuleScanner(makeRandom)
+ m.foreachPort(scan.onPort)
+ // multi-clock support requires the StutteringClock transform to be run
+ if(scan.clocks.size > 1) {
+ throw new MultiClockException(s"The module ${m.name} has more than one clock: ${scan.clocks.mkString(", ")}")
+ }
+ m.foreachStmt(scan.onStatement)
+
+ // turn wires and nodes into signals
+ val outputs = scan.outputs.toSet
+ val constraints = scan.assumes.toSet
+ val bad = scan.asserts.toSet
+ val isSignal = (scan.wires ++ scan.nodes ++ scan.memSignals).toSet ++ outputs ++ constraints ++ bad
+ val signals = scan.connects.filter{ case(name, _) => isSignal.contains(name) }
+ .map { case (name, expr) => Signal(name, expr) }
+
+ // turn registers and memories into states
+ val registers = scan.registers.map(r => r._1 -> r).toMap
+ val regStates = scan.connects.filter(s => registers.contains(s._1)).map { case (name, nextExpr) =>
+ val (_, width, resetExpr, initExpr) = registers(name)
+ onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs)
+ }
+ // turn memories into state
+ val memoryEncoding = new MemoryEncoding(makeRandom)
+ val memoryStatesAndOutputs = scan.memories.map(m => memoryEncoding.onMemory(m, scan.connects, memInit.get(m.name)))
+ // replace pseudo assigns for memory outputs
+ val memOutputs = memoryStatesAndOutputs.flatMap(_._2).toMap
+ val signalsWithMem = signals.map { s =>
+ if (memOutputs.contains(s.name)) {
+ s.copy(e = memOutputs(s.name))
+ } else { s }
+ }
+ // filter out any left-over self assignments (this happens when we have a registered read port)
+ .filter(s => s match { case Signal(n0, BVSymbol(n1, _)) if n0 == n1 => false case _ => true })
+ val states = regStates.toArray ++ memoryStatesAndOutputs.flatMap(_._1)
+
+ // generate comments from infos
+ val comments = mutable.HashMap[String, String]()
+ scan.infos.foreach { case (name, info) =>
+ serializeInfo(info).foreach { infoString =>
+ if(comments.contains(name)) { comments(name) += InfoSeparator + infoString }
+ else { comments(name) = InfoPrefix + infoString }
+ }
+ }
+
+ // inputs are original module inputs and any "random" signal we need for modelling
+ val inputs = scan.inputs ++ randoms.values
+
+ // module info to the comment header
+ val header = serializeInfo(m.info).map(InfoPrefix + _).toArray
+
+ val fair = Set[String]() // as of firrtl 1.4 we do not support fairness constraints
+ TransitionSystem(m.name, inputs.toArray, states, signalsWithMem.toArray, outputs, constraints, bad, fair, comments.toMap, header)
+ }
+
+ private def onRegister(name: String, width: Int, resetExpr: BVExpr, initExpr: BVExpr,
+ nextExpr: BVExpr, presetRegs: Set[String]): State = {
+ assert(initExpr.width == width)
+ assert(nextExpr.width == width)
+ assert(resetExpr.width == 1)
+ val sym = BVSymbol(name, width)
+ val hasReset = initExpr != sym
+ val isPreset = presetRegs.contains(name)
+ assert(!isPreset || hasReset, s"Expected preset register $name to have a reset value, not just $initExpr!")
+ if(hasReset) {
+ val init = if(isPreset) Some(initExpr) else None
+ val next = if(isPreset) nextExpr else BVIte(resetExpr, initExpr, nextExpr)
+ State(sym, next = Some(next), init = init)
+ } else {
+ State(sym, next = Some(nextExpr), init = None)
+ }
+ }
+
+ private val InfoSeparator = ", "
+ private val InfoPrefix = "@ "
+ private def serializeInfo(info: ir.Info): Option[String] = info match {
+ case ir.NoInfo => None
+ case f : ir.FileInfo => Some(f.escaped)
+ case m : ir.MultiInfo =>
+ val infos = m.flatten
+ if(infos.isEmpty) { None } else { Some(infos.map(_.escaped).mkString(InfoSeparator)) }
+ }
+
+ private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]()
+ private def makeRandom(baseName: String, width: Int): BVExpr = {
+ // TODO: actually ensure that there cannot be any name clashes with other identifiers
+ val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii)
+ val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get
+ val sym = BVSymbol(name, width)
+ randoms(name) = sym
+ sym
+ }
+}
+
+private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLogging {
+ type Connects = Iterable[(String, BVExpr)]
+ def onMemory(defMem: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (Iterable[State], Connects) = {
+ // we can only work on appropriately lowered memories
+ assert(defMem.dataType.isInstanceOf[ir.GroundType],
+ s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!")
+ assert(defMem.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.")
+
+ // collect all memory meta-data in a custom class
+ val m = new MemInfo(defMem)
+
+ // find all connections related to this memory
+ val inputs = connects.filter(_._1.startsWith(m.prefix)).toMap
+
+ // there could be a constant init
+ val init = initValue.map(getInit(m, _))
+
+ // parse and check read and write ports
+ val writers = defMem.writers.map( w => new WritePort(m, w, inputs))
+ val readers = defMem.readers.map( r => new ReadPort(m, r, inputs))
+
+ // derive next state from all write ports
+ assert(defMem.writeLatency == 1, "Only memories with write-latency of one are supported.")
+ val next: ArrayExpr = if(writers.isEmpty) { m.sym } else {
+ if(writers.length > 2) {
+ throw new UnsupportedFeatureException(s"memories with 3+ write ports (${m.name})")
+ }
+ val validData = writers.foldLeft[ArrayExpr](m.sym) { case (sym, w) => w.writeTo(sym) }
+ if(writers.length == 1) { validData } else {
+ assert(writers.length == 2)
+ val conflict = writers.head.doesConflict(writers.last)
+ val conflictData = writers.head.makeRandomData("_write_write_collision")
+ val conflictStore = ArrayStore(m.sym, writers.head.addr, conflictData)
+ ArrayIte(conflict, conflictStore, validData)
+ }
+ }
+ val state = State(m.sym, init, Some(next))
+
+ // derive data signals from all read ports
+ assert(defMem.readLatency >= 0)
+ if(defMem.readLatency > 1) {
+ throw new UnsupportedFeatureException(s"memories with read latency 2+ (${m.name})")
+ }
+ val readPortSignals = if(defMem.readLatency == 0) {
+ readers.map { r =>
+ // combinatorial read
+ if(defMem.readUnderWrite != ir.ReadUnderWrite.New) {
+ //logger.warn(s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." +
+ // s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored.")
+ }
+ // since we do a combinatorial read, the "old" data is the current data
+ val data = r.readOld()
+ r.data.name -> data
+ }
+ } else { Seq() }
+ val readPortStates = if(defMem.readLatency == 1) {
+ readers.map { r =>
+ // we create a register for the read port data
+ val next = defMem.readUnderWrite match {
+ case ir.ReadUnderWrite.New =>
+ throw new UnsupportedFeatureException(s"registered read ports that return the new value (${m.name}.${r.name})")
+ // the thing that makes this hard is to properly handle write conflicts
+ case ir.ReadUnderWrite.Undefined =>
+ val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r)))
+ if(anyWriteToTheSameAddress == False) { r.readOld() } else {
+ val readUnderWriteData = r.makeRandomData("_read_under_write_undefined")
+ BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.readOld())
+ }
+ case ir.ReadUnderWrite.Old => r.readOld()
+ }
+ State(r.data, init=None, next=Some(next))
+ }
+ } else { Seq() }
+
+ (state +: readPortStates, readPortSignals)
+ }
+
+ private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match {
+ case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, m.dataWidth), m.indexWidth)
+ case MemoryArrayInit(values) =>
+ assert(values.length == m.depth,
+ s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!")
+ // in order to get a more compact encoding try to find the most common values
+ val histogram = mutable.LinkedHashMap[BigInt, Int]()
+ values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0))
+ val baseValue = histogram.maxBy(_._2)._1
+ val base = ArrayConstant(BVLiteral(baseValue, m.dataWidth), m.indexWidth)
+ values.zipWithIndex.filterNot(_._1 == baseValue)
+ .foldLeft[ArrayExpr](base) { case (array, (value, index)) =>
+ ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth))
+ }
+ case other => throw new RuntimeException(s"Unsupported memory init option: $other")
+ }
+
+ private class MemInfo(m: ir.DefMemory) {
+ val name = m.name
+ val depth = m.depth
+ // derrive the type of the memory from the dataType and depth
+ val dataWidth = getWidth(m.dataType)
+ val indexWidth = Utils.getUIntWidth(m.depth - 1) max 1
+ val sym = ArraySymbol(m.name, indexWidth, dataWidth)
+ val prefix = m.name + "."
+ val fullAddressRange = (BigInt(1) << indexWidth) == m.depth
+ lazy val depthBV = BVLiteral(m.depth, indexWidth)
+ def isValidAddress(addr: BVExpr): BVExpr = {
+ if(fullAddressRange) { True } else {
+ BVComparison(Compare.Greater, depthBV, addr, signed = false)
+ }
+ }
+ }
+ private abstract class MemPort(memory: MemInfo, val name: String, inputs: String => BVExpr) {
+ val en: BVSymbol = makeField("en", 1)
+ val data: BVSymbol = makeField("data", memory.dataWidth)
+ val addr: BVSymbol = makeField("addr", memory.indexWidth)
+ protected def makeField(field: String, width: Int): BVSymbol = BVSymbol(memory.prefix + name + "." + field, width)
+ // make sure that all widths are correct
+ assert(inputs(en.name).width == en.width)
+ assert(inputs(addr.name).width == addr.width)
+ val enIsTrue: Boolean = inputs(en.name) == True
+ def makeRandomData(suffix: String): BVExpr =
+ makeRandom(memory.name + "_" + name + suffix, memory.dataWidth)
+ def readOld(): BVExpr = {
+ val canBeOutOfRange = !memory.fullAddressRange
+ val canBeDisabled = !enIsTrue
+ val data = ArrayRead(memory.sym, addr)
+ val dataWithRangeCheck = if(canBeOutOfRange) {
+ val outOfRangeData = makeRandomData("_addr_out_of_range")
+ BVIte(memory.isValidAddress(addr), data, outOfRangeData)
+ } else { data }
+ val dataWithEnabledCheck = if(canBeDisabled) {
+ val disabledData = makeRandomData("_not_enabled")
+ BVIte(en, dataWithRangeCheck, disabledData)
+ } else { dataWithRangeCheck }
+ dataWithEnabledCheck
+ }
+ }
+ private class WritePort(memory: MemInfo, name: String, inputs: String => BVExpr)
+ extends MemPort(memory, name, inputs) {
+ assert(inputs(data.name).width == data.width)
+ val mask: BVSymbol = makeField("mask", 1)
+ assert(inputs(mask.name).width == mask.width)
+ val maskIsTrue: Boolean = inputs(mask.name) == True
+ val doWrite: BVExpr = (enIsTrue, maskIsTrue) match {
+ case (true, true) => True
+ case (true, false) => mask
+ case (false, true) => en
+ case (false, false) => and(en, mask)
+ }
+ def doesConflict(r: ReadPort): BVExpr = {
+ val sameAddress = BVEqual(r.addr, addr)
+ if(doWrite == True) { sameAddress } else { and(doWrite, sameAddress) }
+ }
+ def doesConflict(w: WritePort): BVExpr = {
+ val bothWrite = and(doWrite, w.doWrite)
+ val sameAddress = BVEqual(addr, w.addr)
+ if(bothWrite == True) { sameAddress } else { and(doWrite, sameAddress) }
+ }
+ def writeTo(array: ArrayExpr): ArrayExpr = {
+ val doUpdate = if(memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr))
+ val update = ArrayStore(array, index=addr, data=data)
+ if(doUpdate == True) update else ArrayIte(doUpdate, update, array)
+ }
+
+ }
+ private class ReadPort(memory: MemInfo, name: String, inputs: String => BVExpr)
+ extends MemPort(memory, name, inputs) {
+ }
+
+ private def and(a: BVExpr, b: BVExpr): BVExpr = (a,b) match {
+ case (True, True) => True
+ case (True, x) => x
+ case (x, True) => x
+ case _ => BVOp(Op.And, a, b)
+ }
+ private def or(a: BVExpr, b: BVExpr): BVExpr = BVOp(Op.Or, a, b)
+ private val True = BVLiteral(1, 1)
+ private val False = BVLiteral(0, 1)
+ private def all(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) False else b.reduce((a,b) => and(a,b))
+ private def any(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) True else b.reduce((a,b) => or(a,b))
+}
+
+// performas a first pass over the module collecting all connections, wires, registers, input and outputs
+private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLogging {
+ import FirrtlExpressionSemantics.getWidth
+
+ private[firrtl] val inputs = mutable.ArrayBuffer[BVSymbol]()
+ private[firrtl] val outputs = mutable.ArrayBuffer[String]()
+ private[firrtl] val clocks = mutable.LinkedHashSet[String]()
+ private[firrtl] val wires = mutable.ArrayBuffer[String]()
+ private[firrtl] val nodes = mutable.ArrayBuffer[String]()
+ private[firrtl] val memSignals = mutable.ArrayBuffer[String]()
+ private[firrtl] val registers = mutable.ArrayBuffer[(String, Int, BVExpr, BVExpr)]()
+ private[firrtl] val memories = mutable.ArrayBuffer[ir.DefMemory]()
+ // DefNode, Connect, IsInvalid and VerificationStatement connections
+ private[firrtl] val connects = mutable.ArrayBuffer[(String, BVExpr)]()
+ private[firrtl] val asserts = mutable.ArrayBuffer[String]()
+ private[firrtl] val assumes = mutable.ArrayBuffer[String]()
+ // maps identifiers to their info
+ private[firrtl] val infos = mutable.ArrayBuffer[(String, ir.Info)]()
+ // keeps track of unused memory (data) outputs so that we can see where they are first used
+ private val unusedMemOutputs = mutable.LinkedHashMap[String, Int]()
+
+ private[firrtl] def onPort(p: ir.Port): Unit = {
+ if(isAsyncReset(p.tpe)) {
+ throw new AsyncResetException(s"Found AsyncReset ${p.name}.")
+ }
+ infos.append(p.name -> p.info)
+ p.direction match {
+ case ir.Input =>
+ if(isClock(p.tpe)) {
+ clocks.add(p.name)
+ } else {
+ inputs.append(BVSymbol(p.name, getWidth(p.tpe)))
+ }
+ case ir.Output => outputs.append(p.name)
+ }
+ }
+
+ private[firrtl] def onStatement(s: ir.Statement): Unit = s match {
+ case ir.DefWire(info, name, tpe) =>
+ if(!isClock(tpe)) {
+ infos.append(name -> info)
+ wires.append(name)
+ }
+ case ir.DefNode(info, name, expr) =>
+ if(!isClock(expr.tpe)) {
+ insertDummyAssignsForMemoryOutputs(expr)
+ infos.append(name -> info)
+ val e = onExpression(expr, name)
+ nodes.append(name)
+ connects.append((name, e))
+ }
+ case ir.DefRegister(info, name, tpe, _, reset, init) =>
+ insertDummyAssignsForMemoryOutputs(reset)
+ insertDummyAssignsForMemoryOutputs(init)
+ infos.append(name -> info)
+ val width = getWidth(tpe)
+ val resetExpr = onExpression(reset, 1, name + "_reset")
+ val initExpr = onExpression(init, width, name + "_init")
+ registers.append((name, width, resetExpr, initExpr))
+ case m : ir.DefMemory =>
+ infos.append(m.name -> m.info)
+ val outputs = getMemOutputs(m)
+ (getMemInputs(m) ++ outputs).foreach(memSignals.append(_))
+ val dataWidth = getWidth(m.dataType)
+ outputs.foreach(name => unusedMemOutputs(name) = dataWidth)
+ memories.append(m)
+ case ir.Connect(info, loc, expr) =>
+ if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
+ val name = loc.serialize
+ insertDummyAssignsForMemoryOutputs(expr)
+ infos.append(name -> info)
+ connects.append((name, onExpression(expr, getWidth(loc.tpe), name)))
+ case ir.IsInvalid(info, loc) =>
+ if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
+ val name = loc.serialize
+ infos.append(name -> info)
+ connects.append((name, makeRandom(name + "_INVALID", getWidth(loc.tpe))))
+ case ir.DefInstance(info, name, module, tpe) =>
+ if(!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}")
+ // we treat all instances as blackboxes
+ logger.warn(s"WARN: treating instance $name of $module as blackbox. " +
+ "Please flatten your hierarchy if you want to include submodules in the formal model.")
+ val ports = tpe.asInstanceOf[ir.BundleType].fields
+ // skip clock and async reset ports
+ ports.filterNot(p => isClock(p.tpe) || isAsyncReset(p.tpe) ).foreach { p =>
+ if(!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p")
+ val isOutput = p.flip == ir.Default
+ val pName = name + "." + p.name
+ infos.append(pName -> info)
+ // outputs of the submodule become inputs to our module
+ if(isOutput) {
+ inputs.append(BVSymbol(pName, getWidth(p.tpe)))
+ } else {
+ outputs.append(pName)
+ }
+ }
+ case s @ ir.Verification(op, info, _, pred, en, msg) =>
+ if(op == ir.Formal.Cover) {
+ logger.warn(s"WARN: Cover statement was ignored: ${s.serialize}")
+ } else {
+ val name = msgToName(op.toString, msg.string)
+ val predicate = onExpression(pred, name + "_predicate")
+ val enabled = onExpression(en, name + "_enabled")
+ val e = BVImplies(enabled, predicate)
+ infos.append(name -> info)
+ connects.append(name -> e)
+ if(op == ir.Formal.Assert) {
+ asserts.append(name)
+ } else {
+ assumes.append(name)
+ }
+ }
+ case s : ir.Conditionally =>
+ error(s"When conditions are not supported. Please run ExpandWhens: ${s.serialize}")
+ case s : ir.PartialConnect =>
+ error(s"PartialConnects are not supported. Please run ExpandConnects: ${s.serialize}")
+ case s : ir.Attach =>
+ error(s"Analog wires are not supported in the SMT backend: ${s.serialize}")
+ case s : ir.Stop =>
+ // we could wire up the stop condition as output for debug reasons
+ logger.warn(s"WARN: Stop statements are currently not supported. Ignoring: ${s.serialize}")
+ case s : ir.Print =>
+ logger.warn(s"WARN: Print statements are not supported. Ignoring: ${s.serialize}")
+ case other => other.foreachStmt(onStatement)
+ }
+
+ private val readInputFields = List("en", "addr")
+ private val writeInputFields = List("en", "mask", "addr", "data")
+ private def getMemInputs(m: ir.DefMemory): Iterable[String] = {
+ assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!")
+ val p = m.name + "."
+ m.writers.flatMap(w => writeInputFields.map(p + w + "." + _)) ++
+ m.readers.flatMap(r => readInputFields.map(p + r + "." + _))
+ }
+ private def getMemOutputs(m: ir.DefMemory): Iterable[String] = {
+ assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!")
+ val p = m.name + "."
+ m.readers.map(r => p + r + ".data")
+ }
+ // inserts a dummy assign right before a memory output is used for the first time
+ // example:
+ // m.r.data <= m.r.data ; this is the dummy assign
+ // test <= m.r.data ; this is the first use of m.r.data
+ private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if(unusedMemOutputs.nonEmpty) {
+ implicit val uses = mutable.ArrayBuffer[String]()
+ findUnusedMemoryOutputUse(next)
+ if(uses.nonEmpty) {
+ val useSet = uses.toSet
+ unusedMemOutputs.foreach { case (name, width) =>
+ if(useSet.contains(name)) connects.append(name -> BVSymbol(name, width))
+ }
+ useSet.foreach(name => unusedMemOutputs.remove(name))
+ }
+ }
+ private def findUnusedMemoryOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match {
+ case s : ir.SubField =>
+ val name = s.serialize
+ if(unusedMemOutputs.contains(name)) uses.append(name)
+ case other => other.foreachExpr(findUnusedMemoryOutputUse)
+ }
+
+ private case class Context(baseName: String) extends TranslationContext {
+ override def getRandom(width: Int): BVExpr = makeRandom(baseName, width)
+ }
+
+ private def onExpression(e: ir.Expression, width: Int, randomPrefix: String): BVExpr = {
+ implicit val ctx: TranslationContext = Context(randomPrefix)
+ FirrtlExpressionSemantics.toSMT(e, width, allowNarrow = false)
+ }
+ private def onExpression(e: ir.Expression, randomPrefix: String): BVExpr = {
+ implicit val ctx: TranslationContext = Context(randomPrefix)
+ FirrtlExpressionSemantics.toSMT(e)
+ }
+
+ private def msgToName(prefix: String, msg: String): String = {
+ // TODO: ensure that we can generate unique names
+ prefix + "_" + msg.replace(" ", "_").replace("|", "")
+ }
+ private def error(msg: String): Unit = throw new RuntimeException(msg)
+ private def isGroundType(tpe: ir.Type): Boolean = tpe.isInstanceOf[ir.GroundType]
+ private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType
+ private def isAsyncReset(tpe: ir.Type): Boolean = tpe == ir.AsyncResetType
+}
+
+private object TopologicalSort {
+ /** Ensures that all signals in the resulting system are topologically sorted.
+ * This is necessary because [[firrtl.transforms.RemoveWires]] does
+ * not sort assignments to outputs, submodule inputs nor memory ports.
+ * */
+ def run(sys: TransitionSystem): TransitionSystem = {
+ val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name)
+ val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates)
+ // TODO: maybe sort init expressions of states (this should not be needed most of the time)
+ signalOrder match {
+ case None => sys
+ case Some(order) =>
+ val signalMap = sys.signals.map(s => s.name -> s).toMap
+ // we flatMap over `get` in order to ignore inputs/states in the order
+ sys.copy(signals = order.flatMap(signalMap.get).toArray)
+ }
+ }
+
+ private def sort(signals: Iterable[(String, SMTExpr)], globalSignals: Iterable[String]): Option[Iterable[String]] = {
+ val known = new mutable.HashSet[String]() ++ globalSignals
+ var needsReordering = false
+ val digraph = new MutableDiGraph[String]
+ signals.foreach { case (name, expr) =>
+ digraph.addVertex(name)
+ val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr)
+ uniqueDependencies.foreach { d =>
+ if(!known.contains(d)) { needsReordering = true }
+ digraph.addPairWithEdge(name, d)
+ }
+ known.add(name)
+ }
+ if(needsReordering) {
+ Some(digraph.linearize.reverse)
+ } else { None }
+ }
+
+ private def findDependencies(expr: SMTExpr): List[String] = expr match {
+ case BVSymbol(name, _) => List(name)
+ case ArraySymbol(name, _, _) => List(name)
+ case other => other.children.flatMap(findDependencies)
+ }
+} \ No newline at end of file
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala
new file mode 100644
index 00000000..322b8961
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala
@@ -0,0 +1,85 @@
+// See LICENSE for license details.
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+import java.io.Writer
+
+import firrtl._
+import firrtl.annotations.{Annotation, NoTargetAnnotation}
+import firrtl.options.Viewer.view
+import firrtl.options.{CustomFileEmission, Dependency}
+import firrtl.stage.FirrtlOptions
+
+
+private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform with Emitter with DependencyAPIMigration {
+ override def prerequisites: Seq[Dependency[Transform]] = Seq(Dependency(FirrtlToTransitionSystem))
+ override def invalidates(a: Transform): Boolean = false
+
+ override def emit(state: CircuitState, writer: Writer): Unit = error("Deprecated since firrtl 1.0!")
+
+ protected def serialize(sys: TransitionSystem): Annotation
+
+ private val BleedingEdgeWarning =
+ """WARNING: The SMT and BTOR2 emitters are experimental preview features.
+ |- they might be removed in future versions without deprecation warning
+ |- their behavior and interfaces might change without notice
+ |- they are unsupported, we won't be able to fix any issues you find in them
+ |- however, we do accept pull requests: https://github.com/freechipsproject/firrtl/pulls
+ |""".stripMargin
+
+ override protected def execute(state: CircuitState): CircuitState = {
+ val emitCircuit = state.annotations.exists {
+ case EmitCircuitAnnotation(a) if this.getClass == a => true
+ case EmitAllModulesAnnotation(a) if this.getClass == a => error("EmitAllModulesAnnotation not supported!")
+ case _ => false
+ }
+
+ if(!emitCircuit) { return state }
+
+ logger.warn(BleedingEdgeWarning)
+
+ val sys = state.annotations.collectFirst{ case TransitionSystemAnnotation(sys) => sys }.getOrElse {
+ error("Could not find the transition system!")
+ }
+ state.copy(annotations = state.annotations :+ serialize(sys))
+ }
+
+ protected def generatedHeader(format: String, name: String): String =
+ s"; $format description generated by firrtl ${BuildInfo.version} for module $name.\n"
+
+ protected def error(msg: String): Nothing = throw new RuntimeException(msg)
+}
+
+case class EmittedSMTModelAnnotation(name: String, src: String, outputSuffix: String)
+ extends NoTargetAnnotation with CustomFileEmission {
+ override protected def baseFileName(annotations: AnnotationSeq): String =
+ view[FirrtlOptions](annotations).outputFileName.getOrElse(name)
+ override protected def suffix: Option[String] = Some(outputSuffix)
+ override def getBytes: Iterable[Byte] = src.getBytes
+}
+
+private[firrtl] class Btor2Emitter extends SMTEmitter {
+ override def outputSuffix: String = ".btor2"
+ override protected def serialize(sys: TransitionSystem): Annotation = {
+ val btor = generatedHeader("BTOR", sys.name) + Btor2Serializer.serialize(sys).mkString("\n") + "\n"
+ EmittedSMTModelAnnotation(sys.name, btor, outputSuffix)
+ }
+}
+
+private[firrtl] class SMTLibEmitter extends SMTEmitter {
+ override def outputSuffix: String = ".smt2"
+ override protected def serialize(sys: TransitionSystem): Annotation = {
+ val hasMemory = sys.states.exists(_.sym.isInstanceOf[ArrayExpr])
+ val logic = SMTLibSerializer.setLogic(hasMemory) + "\n"
+ val header = if(hasMemory) {
+ "; We have to disable the logic for z3 to accept the non-standard \"as const\"\n" +
+ "; see https://github.com/Z3Prover/z3/issues/1803\n" +
+ "; for CVC4 you probably want to include the logic\n" +
+ ";" + logic
+ } else { logic }
+ val smt = generatedHeader("SMT-LIBv2", sys.name) + header +
+ SMTTransitionSystemEncoder.encode(sys).map(SMTLibSerializer.serialize).mkString("\n") + "\n"
+ EmittedSMTModelAnnotation(sys.name, smt, outputSuffix)
+ }
+} \ No newline at end of file
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
new file mode 100644
index 00000000..10a89e8d
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
@@ -0,0 +1,196 @@
+// See LICENSE for license details.
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+// Inspired by the uclid5 SMT library (https://github.com/uclid-org/uclid).
+// And the btor2 documentation (BTOR2 , BtorMC and Boolector 3.0 by Niemetz et.al.)
+
+package firrtl.backends.experimental.smt
+
+private sealed trait SMTExpr { def children: List[SMTExpr] }
+private sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr { val name: String }
+private object SMTSymbol {
+ def fromExpr(name: String, e: SMTExpr): SMTSymbol = e match {
+ case b: BVExpr => BVSymbol(name, b.width)
+ case a: ArrayExpr => ArraySymbol(name, a.indexWidth, a.dataWidth)
+ }
+}
+private sealed trait SMTNullaryExpr extends SMTExpr {
+ override def children: List[SMTExpr] = List()
+}
+
+private sealed trait BVExpr extends SMTExpr { def width: Int }
+private case class BVLiteral(value: BigInt, width: Int) extends BVExpr with SMTNullaryExpr {
+ private def minWidth = value.bitLength + (if(value <= 0) 1 else 0)
+ assert(width > 0, "Zero or negative width literals are not allowed!")
+ assert(width >= minWidth, "Value (" + value.toString + ") too big for BitVector of width " + width + " bits.")
+ override def toString: String = if(width <= 8) {
+ width.toString + "'b" + value.toString(2)
+ } else { width.toString + "'x" + value.toString(16) }
+}
+private case class BVSymbol(name: String, width: Int) extends BVExpr with SMTSymbol {
+ assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
+ assert(!name.contains("\\"), s"Invalid id $name contains `\\`")
+ assert(width > 0, "Zero width bit vectors are not supported!")
+ override def toString: String = name
+ def toStringWithType: String = name + " : " + SMTExpr.serializeType(this)
+}
+
+private sealed trait BVUnaryExpr extends BVExpr {
+ def e: BVExpr
+ override def children: List[BVExpr] = List(e)
+}
+private case class BVExtend(e: BVExpr, by: Int, signed: Boolean) extends BVUnaryExpr {
+ assert(by >= 0, "Extension must be non-negative!")
+ override val width: Int = e.width + by
+ override def toString: String = if(signed) { s"sext($e, $by)" } else { s"zext($e, $by)" }
+}
+// also known as bit extract operation
+private case class BVSlice(e: BVExpr, hi: Int, lo: Int) extends BVUnaryExpr {
+ assert(lo >= 0, s"lo (lsb) must be non-negative!")
+ assert(hi >= lo, s"hi (msb) must not be smaller than lo (lsb): msb: $hi lsb: $lo")
+ assert(e.width > hi, s"Out off bounds hi (msb) access: width: ${e.width} msb: $hi")
+ override def width: Int = hi - lo + 1
+ override def toString: String = if(hi == lo) s"$e[$hi]" else s"$e[$hi:$lo]"
+}
+private case class BVNot(e: BVExpr) extends BVUnaryExpr {
+ override val width: Int = e.width
+ override def toString: String = s"not($e)"
+}
+private case class BVNegate(e: BVExpr) extends BVUnaryExpr {
+ override val width: Int = e.width
+ override def toString: String = s"neg($e)"
+}
+private case class BVReduceOr(e: BVExpr) extends BVUnaryExpr {
+ override def width: Int = 1
+ override def toString: String = s"redor($e)"
+}
+private case class BVReduceAnd(e: BVExpr) extends BVUnaryExpr {
+ override def width: Int = 1
+ override def toString: String = s"redand($e)"
+}
+private case class BVReduceXor(e: BVExpr) extends BVUnaryExpr {
+ override def width: Int = 1
+ override def toString: String = s"redxor($e)"
+}
+
+private sealed trait BVBinaryExpr extends BVExpr {
+ def a: BVExpr
+ def b: BVExpr
+ override def children: List[BVExpr] = List(a, b)
+}
+private case class BVImplies(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
+ assert(a.width == 1 && b.width == 1, s"Both arguments need to be 1-bit!")
+ override def width: Int = 1
+ override def toString: String = s"impl($a, $b)"
+}
+private case class BVEqual(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
+ assert(a.width == b.width, s"Both argument need to be the same width!")
+ override def width: Int = 1
+ override def toString: String = s"eq($a, $b)"
+}
+private object Compare extends Enumeration {
+ val Greater, GreaterEqual = Value
+}
+private case class BVComparison(op: Compare.Value, a: BVExpr, b: BVExpr, signed: Boolean) extends BVBinaryExpr {
+ assert(a.width == b.width, s"Both argument need to be the same width!")
+ override def width: Int = 1
+ override def toString: String = op match {
+ case Compare.Greater => (if(signed) "sgt" else "ugt") + s"($a, $b)"
+ case Compare.GreaterEqual => (if(signed) "sgeq" else "ugeq") + s"($a, $b)"
+ }
+}
+private object Op extends Enumeration {
+ val And = Value("and")
+ val Or = Value("or")
+ val Xor = Value("xor")
+ val ShiftLeft = Value("logical_shift_left")
+ val ArithmeticShiftRight = Value("arithmetic_shift_right")
+ val ShiftRight = Value("logical_shift_right")
+ val Add = Value("add")
+ val Mul = Value("mul")
+ val SignedDiv = Value("sdiv")
+ val UnsignedDiv = Value("udiv")
+ val SignedMod = Value("smod")
+ val SignedRem = Value("srem")
+ val UnsignedRem = Value("urem")
+ val Sub = Value("sub")
+}
+private case class BVOp(op: Op.Value, a: BVExpr, b: BVExpr) extends BVBinaryExpr {
+ assert(a.width == b.width, s"Both argument need to be the same width!")
+ override val width: Int = a.width
+ override def toString: String = s"$op($a, $b)"
+}
+private case class BVConcat(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
+ override val width: Int = a.width + b.width
+ override def toString: String = s"concat($a, $b)"
+}
+private case class ArrayRead(array: ArrayExpr, index: BVExpr) extends BVExpr {
+ assert(array.indexWidth == index.width, "Index with does not match expected array index width!")
+ override val width: Int = array.dataWidth
+ override def toString: String = s"$array[$index]"
+ override def children: List[SMTExpr] = List(array, index)
+}
+private case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr {
+ assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!")
+ assert(tru.width == fals.width, s"Both branches need to be of the same width! ${tru.width} vs ${fals.width}")
+ override val width: Int = tru.width
+ override def toString: String = s"ite($cond, $tru, $fals)"
+ override def children: List[BVExpr] = List(cond, tru, fals)
+}
+
+private sealed trait ArrayExpr extends SMTExpr { val indexWidth: Int; val dataWidth: Int }
+private case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol {
+ assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
+ assert(!name.contains("\\"), s"Invalid id $name contains `\\`")
+ override def toString: String = name
+ def toStringWithType: String = s"$name : bv<$indexWidth> -> bv<$dataWidth>"
+}
+private case class ArrayStore(array: ArrayExpr, index: BVExpr, data: BVExpr) extends ArrayExpr {
+ assert(array.indexWidth == index.width, "Index with does not match expected array index width!")
+ assert(array.dataWidth == data.width, "Data with does not match expected array data width!")
+ override val dataWidth: Int = array.dataWidth
+ override val indexWidth: Int = array.indexWidth
+ override def toString: String = s"$array[$index := $data]"
+ override def children: List[SMTExpr] = List(array, index, data)
+}
+private case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) extends ArrayExpr {
+ assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!")
+ assert(tru.indexWidth == fals.indexWidth,
+ s"Both branches need to be of the same type! ${tru.indexWidth} vs ${fals.indexWidth}")
+ assert(tru.dataWidth == fals.dataWidth,
+ s"Both branches need to be of the same type! ${tru.dataWidth} vs ${fals.dataWidth}")
+ override val dataWidth: Int = tru.dataWidth
+ override val indexWidth: Int = tru.indexWidth
+ override def toString: String = s"ite($cond, $tru, $fals)"
+ override def children: List[SMTExpr] = List(cond, tru, fals)
+}
+private case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr {
+ assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!")
+ assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!")
+ override def width: Int = 1
+ override def toString: String = s"eq($a, $b)"
+ override def children: List[SMTExpr] = List(a, b)
+}
+private case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr {
+ override val dataWidth: Int = e.width
+ override def toString: String = s"([$e] x ${ (BigInt(1) << indexWidth) })"
+ override def children: List[SMTExpr] = List(e)
+}
+
+private object SMTEqual {
+ def apply(a: SMTExpr, b: SMTExpr): BVExpr = (a,b) match {
+ case (ab : BVExpr, bb : BVExpr) => BVEqual(ab, bb)
+ case (aa : ArrayExpr, ba: ArrayExpr) => ArrayEqual(aa, ba)
+ case _ => throw new RuntimeException(s"Cannot compare $a and $b")
+ }
+}
+
+private object SMTExpr {
+ def serializeType(e: SMTExpr): String = e match {
+ case b: BVExpr => s"bv<${b.width}>"
+ case a: ArrayExpr => s"bv<${a.indexWidth}> -> bv<${a.dataWidth}>"
+ }
+}
+
+// Raw SMTLib encoded expressions as an escape hatch used in the [[SMTTransitionSystemEncoder]]
+private case class BVRawExpr(serialized: String, width: Int) extends BVExpr with SMTNullaryExpr
+private case class ArrayRawExpr(serialized: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTNullaryExpr \ No newline at end of file
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
new file mode 100644
index 00000000..14e73253
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
@@ -0,0 +1,73 @@
+// See LICENSE for license details.
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+/** Similar to the mapExpr and foreachExpr methods of the firrtl ir nodes, but external to the case classes */
+private object SMTExprVisitor {
+ type ArrayFun = ArrayExpr => ArrayExpr
+ type BVFun = BVExpr => BVExpr
+
+ def map[T <: SMTExpr](bv: BVFun, ar: ArrayFun)(e: T): T = e match {
+ case b: BVExpr => map(b, bv, ar).asInstanceOf[T]
+ case a: ArrayExpr => map(a, bv, ar).asInstanceOf[T]
+ }
+ def map[T <: SMTExpr](f: SMTExpr => SMTExpr)(e: T): T =
+ map(b => f(b).asInstanceOf[BVExpr], a => f(a).asInstanceOf[ArrayExpr])(e)
+
+ private def map(e: BVExpr, bv: BVFun, ar: ArrayFun): BVExpr = e match {
+ // nullary
+ case old : BVLiteral => bv(old)
+ case old : BVSymbol => bv(old)
+ case old : BVRawExpr => bv(old)
+ // unary
+ case old @ BVExtend(e, by, signed) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVExtend(n, by, signed))
+ case old @ BVSlice(e, hi, lo) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVSlice(n, hi, lo))
+ case old @ BVNot(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNot(n))
+ case old @ BVNegate(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNegate(n))
+ case old @ BVReduceAnd(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceAnd(n))
+ case old @ BVReduceOr(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceOr(n))
+ case old @ BVReduceXor(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceXor(n))
+ // binary
+ case old @ BVImplies(a, b) =>
+ val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
+ bv(if(nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB))
+ case old @ BVEqual(a, b) =>
+ val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
+ bv(if(nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB))
+ case old @ ArrayEqual(a, b) =>
+ val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
+ bv(if(nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB))
+ case old @ BVComparison(op, a, b, signed) =>
+ val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
+ bv(if(nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed))
+ case old @ BVOp(op, a, b) =>
+ val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
+ bv(if(nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB))
+ case old @ BVConcat(a, b) =>
+ val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
+ bv(if(nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB))
+ case old @ ArrayRead(a, b) =>
+ val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
+ bv(if(nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB))
+ // ternary
+ case old @ BVIte(a, b, c) =>
+ val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
+ bv(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC))
+ }
+
+
+ private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = e match {
+ case old : ArrayRawExpr => ar(old)
+ case old : ArraySymbol => ar(old)
+ case old @ ArrayConstant(e, indexWidth) =>
+ val n = map(e, bv, ar) ; ar(if(n.eq(e)) old else ArrayConstant(n, indexWidth))
+ case old @ ArrayStore(a, b, c) =>
+ val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
+ ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC))
+ case old @ ArrayIte(a, b, c) =>
+ val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
+ ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC))
+ }
+
+}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
new file mode 100644
index 00000000..1993da87
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
@@ -0,0 +1,151 @@
+// See LICENSE for license details.
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+import scala.util.matching.Regex
+
+/** Converts STM Expressions to a SMTLib compatible string representation.
+ * See http://smtlib.cs.uiowa.edu/
+ * Assumes well typed expression, so it is advisable to run the TypeChecker
+ * before serializing!
+ * Automatically converts 1-bit vectors to bool.
+ */
+private object SMTLibSerializer {
+ def setLogic(hasMem: Boolean) = "(set-logic QF_" + (if(hasMem) "A" else "") + "UFBV)"
+
+ def serialize(e: SMTExpr): String = e match {
+ case b : BVExpr => serialize(b)
+ case a : ArrayExpr => serialize(a)
+ }
+
+ def serializeType(e: SMTExpr): String = e match {
+ case b : BVExpr => serializeBitVectorType(b.width)
+ case a : ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth)
+ }
+
+ private def serialize(e: BVExpr): String = e match {
+ case BVLiteral(value, width) =>
+ val mask = (BigInt(1) << width) - 1
+ val twosComplement = if(value < 0) { ((~(-value)) & mask) + 1 } else value
+ if(width == 1) {
+ if(twosComplement == 1) "true" else "false"
+ } else {
+ s"(_ bv$twosComplement $width)"
+ }
+ case BVSymbol(name, _) => escapeIdentifier(name)
+ case BVExtend(e, 0, _) => serialize(e)
+ case BVExtend(BVLiteral(value, width), by, false) => serialize(BVLiteral(value, width + by))
+ case BVExtend(e, by, signed) =>
+ val foo = if(signed) "sign_extend" else "zero_extend"
+ s"((_ $foo $by) ${asBitVector(e)})"
+ case BVSlice(e, hi, lo) =>
+ if(lo == 0 && hi == e.width - 1) { serialize(e)
+ } else {
+ val bits = s"((_ extract $hi $lo) ${asBitVector(e)})"
+ // 1-bit extracts need to be turned into a boolean
+ if(lo == hi) { toBool(bits) } else { bits }
+ }
+ case BVNot(BVEqual(a, b)) if a.width == 1 => s"(distinct ${serialize(a)} ${serialize(b)})"
+ case BVNot(BVNot(e)) => serialize(e)
+ case BVNot(e) => if(e.width == 1) { s"(not ${serialize(e)})" } else { s"(bvnot ${serialize(e)})" }
+ case BVNegate(e) => s"(bvneg ${asBitVector(e)})"
+ case r: BVReduceAnd => serialize(Expander.expand(r))
+ case r: BVReduceOr => serialize(Expander.expand(r))
+ case r: BVReduceXor => serialize(Expander.expand(r))
+ case BVImplies(BVLiteral(v, 1), b) if v == 1 => serialize(b)
+ case BVImplies(a, b) => s"(=> ${serialize(a)} ${serialize(b)})"
+ case BVEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
+ case ArrayEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
+ case BVComparison(Compare.Greater, a, b, false) => s"(bvugt ${asBitVector(a)} ${asBitVector(b)})"
+ case BVComparison(Compare.GreaterEqual, a, b, false) => s"(bvuge ${asBitVector(a)} ${asBitVector(b)})"
+ case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})"
+ case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})"
+ // boolean operations get a special treatment for 1-bit vectors aka bools
+ case BVOp(Op.And, a, b) if a.width == 1 => s"(and ${serialize(a)} ${serialize(b)})"
+ case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})"
+ case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})"
+ case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})")
+ case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})"
+ case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})"
+ case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})"
+ case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
+ case BVRawExpr(serialized, _) => serialized
+ }
+
+ def serialize(e: ArrayExpr): String = e match {
+ case ArraySymbol(name, _, _) => escapeIdentifier(name)
+ case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})"
+ case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
+ case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})"
+ case ArrayRawExpr(serialized, _, _) => serialized
+ }
+
+ def serialize(c: SMTCommand): String = c match {
+ case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n")
+ case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)"
+ case DefineFunction(name, args, e) =>
+ val aa = args.map(a => s"(${escapeIdentifier(a._1)} ${a._2})").mkString(" ")
+ s"(define-fun ${escapeIdentifier(name)} ($aa) ${serializeType(e)} ${serialize(e)})"
+ case DeclareFunction(sym, tpes) =>
+ val aa = tpes.mkString(" ")
+ s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serializeType(sym)})"
+ }
+
+ private def serializeArrayType(indexWidth: Int, dataWidth: Int): String =
+ s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})"
+ private def serializeBitVectorType(width: Int): String =
+ if(width == 1) { "Bool" } else { assert(width > 1) ; s"(_ BitVec $width)" }
+
+ private def serialize(op: Op.Value): String = op match {
+ case Op.And => "bvand"
+ case Op.Or => "bvor"
+ case Op.Xor => "bvxor"
+ case Op.ArithmeticShiftRight => "bvashr"
+ case Op.ShiftRight => "bvlshr"
+ case Op.ShiftLeft => "bvshl"
+ case Op.Add => "bvadd"
+ case Op.Mul => "bvmul"
+ case Op.Sub => "bvsub"
+ case Op.SignedDiv => "bvsdiv"
+ case Op.UnsignedDiv => "bvudiv"
+ case Op.SignedMod => "bvsmod"
+ case Op.SignedRem => "bvsrem"
+ case Op.UnsignedRem => "bvurem"
+ }
+
+ private def toBool(e: String): String = s"(= $e (_ bv1 1))"
+
+ private val bvZero = "(_ bv0 1)"
+ private val bvOne = "(_ bv1 1)"
+ private def asBitVector(e: BVExpr): String =
+ if(e.width > 1) { serialize(e) } else { s"(ite ${serialize(e)} $bvOne $bvZero)" }
+
+ // See <simple_symbol> definition in the Concrete Syntax Appendix of the SMTLib Spec
+ private val simple: Regex = raw"[a-zA-Z\+-/\*\=%\?!\.\$$_~&\^<>@][a-zA-Z0-9\+-/\*\=%\?!\.\$$_~&\^<>@]*".r
+ def escapeIdentifier(name: String): String = name match {
+ case simple() => name
+ case _ => if(name.startsWith("|") && name.endsWith("|")) name else s"|$name|"
+ }
+}
+
+/** Expands expressions that are not natively supported by SMTLib */
+private object Expander {
+ def expand(r: BVReduceAnd): BVExpr = {
+ if(r.e.width == 1) { r.e } else {
+ val allOnes = (BigInt(1) << r.e.width) - 1
+ BVEqual(r.e, BVLiteral(allOnes, r.e.width))
+ }
+ }
+ def expand(r: BVReduceOr): BVExpr = {
+ if(r.e.width == 1) { r.e } else {
+ BVNot(BVEqual(r.e, BVLiteral(0, r.e.width)))
+ }
+ }
+ def expand(r: BVReduceXor): BVExpr = {
+ if(r.e.width == 1) { r.e } else {
+ val bits = (0 until r.e.width).map(ii => BVSlice(r.e, ii, ii))
+ bits.reduce[BVExpr]((a,b) => BVOp(Op.Xor, a, b))
+ }
+ }
+}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
new file mode 100644
index 00000000..e9acc05b
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
@@ -0,0 +1,128 @@
+// See LICENSE for license details.
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+import scala.collection.mutable
+
+/** This Transition System encoding is directly inspired by yosys' SMT backend:
+ * https://github.com/YosysHQ/yosys/blob/master/backends/smt2/smt2.cc
+ * It if fairly compact, but unfortunately, the use of an uninterpreted sort for the state
+ * prevents this encoding from working with boolector.
+ * For simplicity reasons, we do not support hierarchical designs (no `_h` function).
+ * */
+private object SMTTransitionSystemEncoder {
+
+ def encode(sys: TransitionSystem): Iterable[SMTCommand] = {
+ val cmds = mutable.ArrayBuffer[SMTCommand]()
+ val name = sys.name
+
+ // emit header as comments
+ cmds ++= sys.header.map(Comment)
+
+ // declare state type
+ val stateType = id(name + "_s")
+ cmds += DeclareUninterpretedSort(stateType)
+
+ // inputs and states are modelled as constants
+ def declare(sym: SMTSymbol, kind: String): Unit = {
+ cmds ++= toDescription(sym, kind, sys.comments.get)
+ val s = SMTSymbol.fromExpr(sym.name + SignalSuffix, sym)
+ cmds += DeclareFunction(s, List(stateType))
+ }
+ sys.inputs.foreach(i => declare(i, "input"))
+ sys.states.foreach(s => declare(s.sym, "register"))
+
+ // signals are just functions of other signals, inputs and state
+ def define(sym: SMTSymbol, e: SMTExpr, suffix: String = SignalSuffix): Unit = {
+ cmds += DefineFunction(sym.name + suffix, List((State, stateType)), replaceSymbols(e))
+ }
+ sys.signals.foreach { signal =>
+ val kind = if(sys.outputs.contains(signal.name)) { "output"
+ } else if(sys.assumes.contains(signal.name)) { "assume"
+ } else if(sys.asserts.contains(signal.name)) { "assert"
+ } else { "wire" }
+ val sym = SMTSymbol.fromExpr(signal.name, signal.e)
+ cmds ++= toDescription(sym, kind, sys.comments.get)
+ define(sym, signal.e)
+ }
+
+ // define the next and init functions for all states
+ sys.states.foreach { state =>
+ assert(state.next.nonEmpty, "Next function required")
+ define(state.sym, state.next.get, NextSuffix)
+ // init is optional
+ state.init.foreach { init =>
+ define(state.sym, init, InitSuffix)
+ }
+ }
+
+ def defineConjunction(e: Iterable[BVExpr], suffix: String): Unit = {
+ define(BVSymbol(name, 1), andReduce(e), suffix)
+ }
+
+ // the transition relation asserts that the value of the next state is the next value from the previous state
+ // e.g., (reg state_n) == (reg_next state)
+ val transitionRelations = sys.states.map { state =>
+ val newState = symbolToFunApp(state.sym, SignalSuffix, StateNext)
+ val nextOldState = symbolToFunApp(state.sym, NextSuffix, State)
+ SMTEqual(newState, nextOldState)
+ }
+ // the transition relation is over two states
+ val transitionExpr = replaceSymbols(andReduce(transitionRelations))
+ cmds += DefineFunction(name + "_t", List((State, stateType), (StateNext, stateType)), transitionExpr)
+
+ // The init relation just asserts that all init function hold
+ val initRelations = sys.states.filter(_.init.isDefined).map { state =>
+ val stateSignal = symbolToFunApp(state.sym, SignalSuffix, State)
+ val initSignal = symbolToFunApp(state.sym, InitSuffix, State)
+ SMTEqual(stateSignal, initSignal)
+ }
+ defineConjunction(initRelations, "_i")
+
+ // assertions and assumptions
+ val assertions = sys.signals.filter(a => sys.asserts.contains(a.name)).map(a => replaceSymbols(a.toSymbol))
+ defineConjunction(assertions, "_a")
+ val assumptions = sys.signals.filter(a => sys.assumes.contains(a.name)).map(a => replaceSymbols(a.toSymbol))
+ defineConjunction(assumptions, "_u")
+
+ cmds
+ }
+
+ private def id(s: String): String = SMTLibSerializer.escapeIdentifier(s)
+ private val State = "state"
+ private val StateNext = "state_n"
+ private val SignalSuffix = "_f"
+ private val NextSuffix = "_next"
+ private val InitSuffix = "_init"
+ private def toDescription(sym: SMTSymbol, kind: String, comments: String => Option[String]): List[Comment] = {
+ List(sym match {
+ case BVSymbol(name, width) =>
+ Comment(s"firrtl-smt2-$kind $name $width")
+ case ArraySymbol(name, indexWidth, dataWidth) =>
+ Comment(s"firrtl-smt2-$kind $name $indexWidth $dataWidth")
+ }) ++ comments(sym.name).map(Comment)
+ }
+
+ private def andReduce(e: Iterable[BVExpr]): BVExpr =
+ if(e.isEmpty) BVLiteral(1, 1) else e.reduce((a,b) => BVOp(Op.And, a, b))
+
+ // All signals are modelled with functions that need to be called with the state as argument,
+ // this replaces all Symbols with function applications to the state.
+ private def replaceSymbols(e: SMTExpr): SMTExpr = {
+ SMTExprVisitor.map(symbolToFunApp(_, SignalSuffix, State))(e)
+ }
+ private def replaceSymbols(e: BVExpr): BVExpr = replaceSymbols(e.asInstanceOf[SMTExpr]).asInstanceOf[BVExpr]
+ private def symbolToFunApp(sym: SMTExpr, suffix: String, arg: String): SMTExpr = sym match {
+ case BVSymbol(name, width) => BVRawExpr(s"(${id(name+suffix)} $arg)", width)
+ case ArraySymbol(name, indexWidth, dataWidth) => ArrayRawExpr(s"(${id(name+suffix)} $arg)", indexWidth, dataWidth)
+ case other => other
+ }
+}
+
+/** minimal set of pseudo SMT commands needed for our encoding */
+private sealed trait SMTCommand
+private case class Comment(msg: String) extends SMTCommand
+private case class DefineFunction(name: String, args: Seq[(String, String)], e: SMTExpr) extends SMTCommand
+private case class DeclareFunction(sym: SMTSymbol, tpes: Seq[String]) extends SMTCommand
+private case class DeclareUninterpretedSort(name: String) extends SMTCommand
diff --git a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
new file mode 100644
index 00000000..d8e203f8
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
@@ -0,0 +1,253 @@
+// See LICENSE for license details.
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+import firrtl.{CircuitState, DependencyAPIMigration, Namespace, PrimOps, RenameMap, Transform, Utils, ir}
+import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation, ReferenceTarget, SingleTargetAnnotation}
+import firrtl.ir.EmptyStmt
+import firrtl.options.Dependency
+import firrtl.passes.PassException
+import firrtl.stage.Forms
+import firrtl.stage.TransformManager.TransformDependency
+
+import scala.collection.mutable
+
+case class GlobalClockAnnotation(target: ReferenceTarget) extends SingleTargetAnnotation[ReferenceTarget] {
+ override def duplicate(n: ReferenceTarget): Annotation = this.copy(n)
+}
+
+/** Converts every input clock into a clock enable input and adds a single global clock.
+ * - all registers and memory ports will be connected to the new global clock
+ * - all registers and memory ports will be guarded by the enable signal of their original clock
+ * - the clock enabled signal can be understood as a clock tick or posedge
+ * - this transform can be used in order to (formally) verify designs with multiple clocks or asynchronous resets
+ */
+class StutteringClockTransform extends Transform with DependencyAPIMigration {
+ override def prerequisites: Seq[TransformDependency] = Forms.LowForm
+ override def invalidates(a: Transform): Boolean = false
+
+ // this pass needs to run *before* converting to a transition system
+ override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq(Dependency(FirrtlToTransitionSystem))
+ // since this pass only runs on the main module, inlining needs to happen before
+ override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances])
+
+
+ override protected def execute(state: CircuitState): CircuitState = {
+ if(state.circuit.modules.size > 1) {
+ logger.warn("WARN: StutteringClockTransform currently only supports running on a single module.\n" +
+ s"All submodules of ${state.circuit.main} will be ignored! Please inline all submodules if this is not what you want.")
+ }
+
+ // get main module
+ val main = state.circuit.modules.find(_.name == state.circuit.main).get match {
+ case m: ir.Module => m
+ case e: ir.ExtModule => unsupportedError(s"Cannot run on extmodule $e")
+ }
+ mainName = main.name
+
+ val namespace = Namespace(main)
+
+ // create a global clock
+ val globalClocks = state.annotations.collect { case GlobalClockAnnotation(c) => c }
+ assert(globalClocks.size < 2, "There can only be a single global clock: " + globalClocks.mkString(", "))
+ val (globalClock, portsWithGlobalClock) = globalClocks.headOption match {
+ case Some(clock) =>
+ assert(clock.module == main.name, "GlobalClock needs to be an input of the main module!")
+ assert(main.ports.exists(_.name == clock.ref), "GlobalClock needs to be an input port!")
+ assert(main.ports.find(_.name == clock.ref).get.direction == ir.Input, "GlobalClock needs to be an input port!")
+ (clock.ref, main.ports)
+ case None =>
+ val name = namespace.newName("global_clock")
+ (name, ir.Port(ir.NoInfo, name, ir.Input, ir.ClockType) +: main.ports)
+ }
+
+ // replace all other clocks with enable signals, unless they are the global clock
+ val clocks = portsWithGlobalClock.filter(p => p.tpe == ir.ClockType && p.name != globalClock).map(_.name)
+ val clockToEnable = clocks.map{c =>
+ c -> ir.Reference(namespace.newName(c + "_en"), Bool, firrtl.PortKind, firrtl.SourceFlow)
+ }.toMap
+ val portsWithEnableSignals = portsWithGlobalClock.map { p =>
+ if(clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Bool) } else { p }
+ }
+ // replace async reset with synchronous reset (since everything will we synchronous with the global clock)
+ // unless it is a preset reset
+ val asyncResets = portsWithEnableSignals.filter(_.tpe == ir.AsyncResetType).map(_.name)
+ val isPresetReset = state.annotations.collect{ case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet
+ val resetsToChange = asyncResets.filterNot(isPresetReset).toSet
+ val portsWithSyncReset = portsWithEnableSignals.map { p =>
+ if(resetsToChange.contains(p.name)) { p.copy(tpe = Bool) } else { p }
+ }
+
+ // discover clock and reset connections
+ val scan = scanClocks(main, clockToEnable, resetsToChange)
+
+ // rename clocks to clock enable signals
+ val mRef = CircuitTarget(state.circuit.main).module(main.name)
+ val renameMap = RenameMap()
+ scan.clockToEnable.foreach { case (clk, en) =>
+ renameMap.record(mRef.ref(clk), mRef.ref(en.name))
+ }
+
+ // make changes
+ implicit val ctx: Context = new Context(globalClock, scan)
+ val newMain = main.copy(ports = portsWithSyncReset).mapStmt(onStatement)
+
+ val nonMainModules = state.circuit.modules.filterNot(_.name == state.circuit.main)
+ val newCircuit = state.circuit.copy(modules = nonMainModules :+ newMain)
+ state.copy(circuit = newCircuit, renames = Some(renameMap))
+ }
+
+ private def onStatement(s: ir.Statement)(implicit ctx: Context): ir.Statement = {
+ s.foreachExpr(checkExpr)
+ s match {
+ // memory field connects
+ case c @ ir.Connect(_, ir.SubField(ir.SubField(ir.Reference(mem, _, _, _), port, _, _), field, _, _), _)
+ if ctx.isMem(mem) && ctx.memPortToClockEnable.contains(mem + "." + port) =>
+ // replace clock with the global clock
+ if(field == "clk") {
+ c.copy(expr = ctx.globalClock)
+ } else if(field == "en") {
+ val m = ctx.memInfo(mem)
+ val isWritePort = m.writers.contains(port)
+ assert(isWritePort || m.readers.contains(port))
+
+ // for write ports we guard the write enable with the clock enable signal, similar to registers
+ if(isWritePort) {
+ val clockEn = ctx.memPortToClockEnable(mem + "." + port)
+ val guardedEnable = and(clockEn, c.expr)
+ c.copy(expr = guardedEnable)
+ } else { c }
+ } else { c}
+ // register field connects
+ case c @ ir.Connect(_, r : ir.Reference, next) if ctx.registerToEnable.contains(r.name) =>
+ val clockEnable = ctx.registerToEnable(r.name)
+ val guardedNext = mux(clockEnable, next, r)
+ c.copy(expr = guardedNext)
+ // remove other clock wires and nodes
+ case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType && ctx.isRemovedClock(loc.serialize) => EmptyStmt
+ case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
+ case ir.DefWire(_, name, tpe) if tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
+ // change async reset to synchronous reset
+ case ir.Connect(info, loc: ir.Reference, expr: ir.Reference) if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) =>
+ ir.Connect(info, loc.copy(tpe=Bool), expr.copy(tpe=Bool))
+ case d @ ir.DefNode(_, name, value: ir.Reference) if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) =>
+ d.copy(value = value.copy(tpe=Bool))
+ case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => d.copy(tpe=Bool)
+ // change memory clock and synchronize reset
+ case ir.DefRegister(info, name, tpe, clock, reset, init) if ctx.registerToEnable.contains(name) =>
+ val clockEnable = ctx.registerToEnable(name)
+ val newReset = reset match {
+ case r @ ir.Reference(name, _, _, _) if ctx.isResetToChange(name) => r.copy(tpe=Bool)
+ case other => other
+ }
+ val synchronizedReset = if(reset.tpe == ir.AsyncResetType) { newReset } else { and(newReset, clockEnable) }
+ ir.DefRegister(info, name, tpe, ctx.globalClock, synchronizedReset, init)
+ case other => other.mapStmt(onStatement)
+ }
+ }
+
+ private def scanClocks(m: ir.Module, initialClockToEnable: Map[String, ir.Reference], resetsToChange: Set[String]): ScanCtx = {
+ implicit val ctx: ScanCtx = new ScanCtx(initialClockToEnable, resetsToChange)
+ m.foreachStmt(scanClocksAndResets)
+ ctx
+ }
+
+ private def scanClocksAndResets(s: ir.Statement)(implicit ctx: ScanCtx): Unit = {
+ s.foreachExpr(checkExpr)
+ s match {
+ // track clock aliases
+ case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType =>
+ val locName = loc.serialize
+ ctx.clockToEnable.get(expr.serialize).foreach { clockEn =>
+ ctx.clockToEnable(locName) = clockEn
+ // keep track of memory clocks
+ if(loc.isInstanceOf[ir.SubField]) {
+ val parts = locName.split('.')
+ if(ctx.mems.contains(parts.head)) {
+ assert(parts.length == 3 && parts.last == "clk")
+ ctx.memPortToClockEnable.append(parts.dropRight(1).mkString(".") -> clockEn)
+ }
+ }
+ }
+ case ir.DefNode(_, name, value) if value.tpe == ir.ClockType =>
+ ctx.clockToEnable.get(value.serialize).foreach(c => ctx.clockToEnable(name) = c)
+ // track reset aliases
+ case ir.Connect(_, loc, expr) if expr.tpe == ir.AsyncResetType && ctx.resetsToChange(expr.serialize) =>
+ ctx.resetsToChange.add(loc.serialize)
+ case ir.DefNode(_, name, value) if value.tpe == ir.AsyncResetType && ctx.resetsToChange(value.serialize) =>
+ ctx.resetsToChange.add(name)
+ // modify clocked elements
+ case ir.DefRegister(_, name, _, clock, _, _) =>
+ ctx.clockToEnable.get(clock.serialize).foreach { clockEnable =>
+ ctx.registerToEnable.append(name -> clockEnable)
+ }
+ case m : ir.DefMemory =>
+ assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!")
+ assert(m.readLatency == 0 || m.readLatency == 1, "Only read-latency 1 and read latency 0 are supported!")
+ assert(m.writeLatency == 1, "Only write-latency 1 is supported!")
+ if(m.readers.nonEmpty && m.readLatency == 1) {
+ unsupportedError("Registers memory read ports are not properly implemented yet :(")
+ }
+ ctx.mems(m.name) = m
+ case other => other.foreachStmt(scanClocksAndResets)
+ }
+ }
+
+ // we rely on people not casting clocks or async resets
+ private def checkExpr(expr: ir.Expression): Unit = expr match {
+ case ir.DoPrim(PrimOps.AsUInt, Seq(e), _, _) if e.tpe == ir.ClockType =>
+ unsupportedError(s"Clock casts are not supported: ${expr.serialize}")
+ case ir.DoPrim(PrimOps.AsSInt, Seq(e), _, _) if e.tpe == ir.ClockType =>
+ unsupportedError(s"Clock casts are not supported: ${expr.serialize}")
+ case ir.DoPrim(PrimOps.AsUInt, Seq(e), _, _) if e.tpe == ir.AsyncResetType =>
+ unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}")
+ case ir.DoPrim(PrimOps.AsSInt, Seq(e), _, _) if e.tpe == ir.AsyncResetType =>
+ unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}")
+ case ir.DoPrim(PrimOps.AsAsyncReset, _, _, _) =>
+ unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}")
+ case ir.DoPrim(PrimOps.AsClock, _, _, _) =>
+ unsupportedError(s"Clock casts are not supported: ${expr.serialize}")
+ case other => other.foreachExpr(checkExpr)
+ }
+
+ private class ScanCtx(initialClockToEnable: Map[String, ir.Reference], initialResetsToChange: Set[String]) {
+ // keeps track of which clock signals will be replaced by which clock enable signal
+ val clockToEnable = mutable.HashMap[String, ir.Reference]() ++ initialClockToEnable
+ // kepp track of asynchronous resets that need to be changed to bool
+ val resetsToChange = mutable.HashSet[String]() ++ initialResetsToChange
+ // registers whose next function needs to be guarded with a clock enable
+ val registerToEnable = mutable.ArrayBuffer[(String, ir.Reference)]()
+ // memory enables which need to be guarded with clock enables
+ val memPortToClockEnable = mutable.ArrayBuffer[(String, ir.Reference)]()
+ // keep track of memory names
+ val mems = mutable.HashMap[String, ir.DefMemory]()
+ }
+
+ private class Context(globalClockName: String, scanResults: ScanCtx) {
+ val globalClock: ir.Reference = ir.Reference(globalClockName, ir.ClockType, firrtl.PortKind, firrtl.SourceFlow)
+ // keeps track of which clock signals will be replaced by which clock enable signal
+ val isRemovedClock: String => Boolean = scanResults.clockToEnable.contains
+ // registers whose next function needs to be guarded with a clock enable
+ val registerToEnable: Map[String, ir.Reference] = scanResults.registerToEnable.toMap
+ // memory enables which need to be guarded with clock enables
+ val memPortToClockEnable: Map[String, ir.Reference] = scanResults.memPortToClockEnable.toMap
+ // keep track of memory names
+ val isMem: String => Boolean = scanResults.mems.contains
+ val memInfo: String => ir.DefMemory = scanResults.mems
+ val isResetToChange: String => Boolean = scanResults.resetsToChange.contains
+ }
+
+ private var mainName: String = "" // for debugging
+ private def unsupportedError(msg: String): Nothing =
+ throw new UnsupportedFeatureException(s"StutteringClockTransform: [$mainName] $msg")
+
+ private def mux(cond: ir.Expression, a: ir.Expression, b: ir.Expression): ir.Expression = {
+ ir.Mux(cond, a, b, Utils.mux_type_and_widths(a, b))
+ }
+ private def and(a: ir.Expression, b: ir.Expression): ir.Expression =
+ ir.DoPrim(PrimOps.And, List(a, b), List(), Bool)
+ private val Bool = ir.UIntType(ir.IntWidth(1))
+}
+
+private class UnsupportedFeatureException(s: String) extends PassException(s) \ No newline at end of file