aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala114
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala20
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala577
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala12
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala266
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala86
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala60
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala77
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala59
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala100
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala69
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala120
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala16
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala6
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala2
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala8
16 files changed, 848 insertions, 744 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
index f96fd4e8..a6eaa51b 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
@@ -3,28 +3,12 @@
package firrtl.backends.experimental.smt
-import firrtl.backends.experimental.smt.Btor2Serializer.functionCallToArrayRead
-
import scala.collection.mutable
private object Btor2Serializer {
def serialize(sys: TransitionSystem, skipOutput: Boolean = false): Iterable[String] = {
new Btor2Serializer().run(sys, skipOutput)
}
-
- private def functionCallToArrayRead(call: BVFunctionCall): BVExpr = {
- if (call.args.isEmpty) {
- BVSymbol(call.name, call.width)
- } else {
- val index = concat(call.args)
- val a = ArraySymbol(call.name, indexWidth = index.width, dataWidth = call.width)
- ArrayRead(a, index)
- }
- }
- private def concat(e: Iterable[BVExpr]): BVExpr = {
- require(e.nonEmpty)
- e.reduce((a, b) => BVConcat(a, b))
- }
}
private class Btor2Serializer private () {
@@ -55,7 +39,7 @@ private class Btor2Serializer private () {
// bit vector expression serialization
private def s(expr: BVExpr): Int = expr match {
case BVLiteral(value, width) => lit(value, width)
- case BVSymbol(name, _) => symbols(name)
+ case BVSymbol(name, _) => symbols.getOrElse(name, throw new RuntimeException(s"Unknown symbol: $name"))
case BVExtend(e, 0, _) => s(e)
case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by")
case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by")
@@ -86,13 +70,13 @@ private class Btor2Serializer private () {
line(s"read ${t(expr.width)} ${s(array)} ${s(index)}")
case BVIte(cond, tru, fals) =>
line(s"ite ${t(expr.width)} ${s(cond)} ${s(tru)} ${s(fals)}")
- case r: BVRawExpr =>
- throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}")
+ case b @ BVAnd(terms) => variadic("and", b.width, terms)
+ case b @ BVOr(terms) => variadic("or", b.width, terms)
+ case forall: BVForall =>
+ throw new RuntimeException(s"Quantifiers are not supported by the btor2 format: ${forall}")
}
private def s(op: Op.Value): String = op match {
- case Op.And => "and"
- case Op.Or => "or"
case Op.Xor => "xor"
case Op.ArithmeticShiftRight => "sra"
case Op.ShiftRight => "srl"
@@ -112,6 +96,14 @@ private class Btor2Serializer private () {
private def binary(op: String, width: Int, a: BVExpr, b: BVExpr): Int =
line(s"$op ${t(width)} ${s(a)} ${s(b)}")
+ private def variadic(op: String, width: Int, terms: List[BVExpr]): Int = terms match {
+ case Seq() | Seq(_) => throw new RuntimeException(s"expected at least two elements in variadic op $op")
+ case Seq(a, b) => binary(op, width, a, b)
+ case head :: tail =>
+ val tailId = variadic(op, width, tail)
+ line(s"$op ${t(width)} ${s(head)} ${tailId}")
+ }
+
private def lit(value: BigInt, w: Int): Int = {
val typ = t(w)
lazy val mask = (BigInt(1) << w) - 1
@@ -141,10 +133,24 @@ private class Btor2Serializer private () {
// While the spec does not seem to allow array ite, it seems to be supported in practice.
// It is essential to model memories, so any support in the wild should be fairly well tested.
line(s"ite ${t(expr.indexWidth, expr.dataWidth)} ${s(cond)} ${s(tru)} ${s(fals)}")
- case ArrayConstant(e, _) => s(e)
- case r: ArrayRawExpr =>
- throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}")
+ case ArrayConstant(e, indexWidth) =>
+ // The problem we are facing here is that the only way to create a constant array from a bv expression
+ // seems to be to use the bv expression as the init value of a state variable.
+ // Thus we need to create a fake state for every array init expression.
+ arrayConstants.getOrElseUpdate(
+ e.toString, {
+ comment(s"$expr")
+ val eId = s(e)
+ val tpeId = t(indexWidth, e.width)
+ val state = line(s"state $tpeId")
+ line(s"init $tpeId $state $eId")
+ state
+ }
+ )
+ case f: ArrayFunctionCall =>
+ throw new RuntimeException(s"The btor2 format does not support uninterpreted functions that return arrays!: $f")
}
+ private val arrayConstants = mutable.HashMap[String, Int]()
private def s(expr: SMTExpr): Int = expr match {
case b: BVExpr => s(b)
@@ -157,38 +163,62 @@ private class Btor2Serializer private () {
case a: ArrayExpr => t(a.indexWidth, a.dataWidth)
}
+ private def functionCallToArrayRead(call: BVFunctionCall): BVExpr = {
+ if (call.args.isEmpty) {
+ BVSymbol(call.name, call.width)
+ } else {
+ val args: List[BVExpr] = call.args.map {
+ case b: BVExpr => b
+ case other => throw new RuntimeException(s"Unsupported call argument: $other in $call")
+ }
+ val index = concat(args)
+ val a = ArraySymbol(call.name, indexWidth = index.width, dataWidth = call.width)
+ ArrayRead(a, index)
+ }
+ }
+ private def concat(e: Iterable[BVExpr]): BVExpr = {
+ require(e.nonEmpty)
+ e.reduce((a, b) => BVConcat(a, b))
+ }
+
def run(sys: TransitionSystem, skipOutput: Boolean): Iterable[String] = {
- def declare(name: String, expr: => Int): Unit = {
+ def declare(name: String, lbl: Option[SignalLabel], expr: => Int): Unit = {
assert(!symbols.contains(name), s"Trying to redeclare `$name`")
val id = expr
symbols(name) = id
- if (!skipOutput && sys.outputs.contains(name)) line(s"output $id ; $name")
- if (sys.assumes.contains(name)) line(s"constraint $id ; $name")
- if (sys.asserts.contains(name)) {
- val invertedId = line(s"not ${t(1)} $id")
- line(s"bad $invertedId ; $name")
+ // add label
+ lbl match {
+ case Some(IsOutput) => if (!skipOutput) line(s"output $id ; $name")
+ case Some(IsConstraint) => line(s"constraint $id ; $name")
+ case Some(IsBad) => line(s"bad $id ; $name")
+ case Some(IsFair) => line(s"fair $id ; $name")
+ case _ =>
}
- if (sys.fair.contains(name)) line(s"fair $id ; $name")
// add trailing comment
sys.comments.get(name).foreach(trailingComment)
}
// header
- sys.header.foreach(comment)
+ if (sys.header.nonEmpty) {
+ sys.header.split('\n').foreach(comment)
+ }
// declare inputs
sys.inputs.foreach { ii =>
- declare(ii.name, line(s"input ${t(ii.width)} ${ii.name}"))
+ declare(ii.name, None, line(s"input ${t(ii.width)} ${ii.name}"))
}
// declare uninterpreted functions a constant arrays
- sys.ufs.foreach { foo =>
- val sym = if (foo.argWidths.isEmpty) { BVSymbol(foo.name, foo.width) }
+ val ufs = TransitionSystem.findUninterpretedFunctions(sys)
+ ufs.foreach { foo =>
+ // only functions returning bit-vectors are supported!
+ val bvSym = foo.sym.asInstanceOf[BVSymbol]
+ val sym = if (foo.args.isEmpty) { bvSym }
else {
- ArraySymbol(foo.name, foo.argWidths.sum, foo.width)
+ ArraySymbol(bvSym.name, foo.args.map(_.asInstanceOf[BVExpr].width).sum, bvSym.width)
}
comment(foo.toString)
- declare(sym.name, line(s"state ${t(sym)} ${sym.name}"))
+ declare(sym.name, None, line(s"state ${t(sym)} ${sym.name}"))
line(s"next ${t(sym)} ${s(sym)} ${s(sym)}")
}
@@ -196,14 +226,18 @@ private class Btor2Serializer private () {
sys.states.foreach { st =>
// calculate init expression before declaring the state
// this is required by btormc (presumably to avoid cycles in the init expression)
- val initId = st.init.map { init => comment(s"${st.sym}.init"); s(init) }
- declare(st.sym.name, line(s"state ${t(st.sym)} ${st.sym.name}"))
+ val initId = st.init.map {
+ // only in the context of initializing a state can we use a bv expression to model an array
+ case ArrayConstant(e, _) => comment(s"${st.sym}.init"); s(e)
+ case init => comment(s"${st.sym}.init"); s(init)
+ }
+ declare(st.sym.name, None, line(s"state ${t(st.sym)} ${st.sym.name}"))
st.init.foreach { init => line(s"init ${t(init)} ${s(st.sym)} ${initId.get}") }
}
// define all other signals
sys.signals.foreach { signal =>
- declare(signal.name, s(signal.e))
+ declare(signal.name, Some(signal.lbl), s(signal.e))
}
// define state next
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
index 2c08ff6a..c7524e21 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
@@ -7,17 +7,11 @@ import firrtl.ir
import firrtl.PrimOps
import firrtl.passes.CheckWidths.WidthTooBig
-private trait TranslationContext {
- def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, firrtl.bitWidth(tpe).toInt)
-}
-
private object FirrtlExpressionSemantics {
- def toSMT(e: ir.Expression)(implicit ctx: TranslationContext): BVExpr = {
+ def toSMT(e: ir.Expression): BVExpr = {
val eSMT = e match {
case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts)
- case r: ir.Reference => ctx.getReference(r.serialize, r.tpe)
- case r: ir.SubField => ctx.getReference(r.serialize, r.tpe)
- case r: ir.SubIndex => ctx.getReference(r.serialize, r.tpe)
+ case r: ir.RefLikeExpression => BVSymbol(r.serialize, getWidth(r))
case ir.UIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt)
case ir.SIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt)
case ir.Mux(cond, tval, fval, _) =>
@@ -34,7 +28,7 @@ private object FirrtlExpressionSemantics {
}
/** Ensures that the result has the desired width by appropriately extending it. */
- def toSMT(e: ir.Expression, width: Int, allowNarrow: Boolean = false)(implicit ctx: TranslationContext): BVExpr =
+ def toSMT(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr =
forceWidth(toSMT(e), isSigned(e), width, allowNarrow)
private def forceWidth(eSMT: BVExpr, eSigned: Boolean, width: Int, allowNarrow: Boolean = false): BVExpr = {
@@ -52,8 +46,6 @@ private object FirrtlExpressionSemantics {
op: ir.PrimOp,
args: Seq[ir.Expression],
consts: Seq[BigInt]
- )(
- implicit ctx: TranslationContext
): BVExpr = {
(op, args, consts) match {
case (PrimOps.Add, Seq(e1, e2), _) =>
@@ -137,10 +129,10 @@ private object FirrtlExpressionSemantics {
case (PrimOps.Not, Seq(e), _) => BVNot(toSMT(e))
case (PrimOps.And, Seq(e1, e2), _) =>
val width = args.map(getWidth).max
- BVOp(Op.And, toSMT(e1, width), toSMT(e2, width))
+ BVAnd(toSMT(e1, width), toSMT(e2, width))
case (PrimOps.Or, Seq(e1, e2), _) =>
val width = args.map(getWidth).max
- BVOp(Op.Or, toSMT(e1, width), toSMT(e2, width))
+ BVOr(toSMT(e1, width), toSMT(e2, width))
case (PrimOps.Xor, Seq(e1, e2), _) =>
val width = args.map(getWidth).max
BVOp(Op.Xor, toSMT(e1, width), toSMT(e2, width))
@@ -170,7 +162,7 @@ private object FirrtlExpressionSemantics {
private val BV1BitZero = BVLiteral(0, 1)
- def isSigned(e: ir.Expression): Boolean = e.tpe match {
+ private def isSigned(e: ir.Expression): Boolean = e.tpe match {
case _: ir.SIntType => true
case _ => false
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
index 726a8854..c5fff849 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
@@ -4,60 +4,24 @@
package firrtl.backends.experimental.smt
import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation}
-import firrtl.bitWidth
-import FirrtlExpressionSemantics.getWidth
+import firrtl._
import firrtl.backends.experimental.smt.random._
-import firrtl.graph.MutableDiGraph
import firrtl.options.Dependency
import firrtl.passes.MemPortUtils.memPortField
import firrtl.passes.PassException
import firrtl.passes.memlib.VerilogMemDelays
import firrtl.stage.Forms
import firrtl.stage.TransformManager.TransformDependency
-import firrtl.transforms.{DeadCodeElimination, EnsureNamedStatements, PropagatePresetAnnotations}
-import firrtl.{
- ir,
- CircuitState,
- DependencyAPIMigration,
- MemoryArrayInit,
- MemoryInitValue,
- MemoryScalarInit,
- Namespace,
- Transform,
- Utils
-}
+import firrtl.transforms.{EnsureNamedStatements, PropagatePresetAnnotations}
import logger.LazyLogging
import scala.collection.mutable
-// Contains code to convert a flat firrtl module into a functional transition system which
-// can then be exported as SMTLib or Btor2 file.
-
-private case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr])
-private case class Signal(name: String, e: BVExpr) { def toSymbol: BVSymbol = BVSymbol(name, e.width) }
-private case class TransitionSystem(
- name: String,
- inputs: Array[BVSymbol],
- states: Array[State],
- signals: Array[Signal],
- outputs: Set[String],
- assumes: Set[String],
- asserts: Set[String],
- fair: Set[String],
- ufs: List[BVFunctionSymbol] = List(),
- comments: Map[String, String] = Map(),
- header: Array[String] = Array()) {
- def serialize: String = {
- (Iterator(name) ++
- ufs.map(u => u.toString) ++
- inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++
- signals.map(s => s"${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++
- states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}")).mkString("\n")
- }
-}
-
private case class TransitionSystemAnnotation(sys: TransitionSystem) extends NoTargetAnnotation
+/** Contains code to convert a flat firrtl module into a functional transition system which
+ * can then be exported as SMTLib or Btor2 file.
+ */
object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration {
override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++
Seq(
@@ -94,12 +58,12 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration {
// convert the main module
val main = modules(circuit.main)
val sys = main match {
- case x: ir.ExtModule =>
+ case _: ir.ExtModule =>
throw new ExtModuleException(
"External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog."
)
case m: ir.Module =>
- new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit = memInit, uninterpreted = uninterpreted)
+ new ModuleToTransitionSystem(presetRegs = presetRegs, memInit = memInit, uninterpreted = uninterpreted).run(m)
}
val sortedSys = TopologicalSort.run(sys)
@@ -131,336 +95,118 @@ private class MultiClockException(s: String) extends PassException(s + Unsupport
private class MissingFeatureException(s: String)
extends PassException("Unfortunately the SMT backend does not yet support: " + s)
-private class ModuleToTransitionSystem extends LazyLogging {
+private class ModuleToTransitionSystem(
+ presetRegs: Set[String],
+ memInit: Map[String, MemoryInitValue],
+ uninterpreted: Map[String, UninterpretedModuleAnnotation])
+ extends LazyLogging {
- def run(
- m: ir.Module,
- presetRegs: Set[String] = Set(),
- memInit: Map[String, MemoryInitValue] = Map(),
- uninterpreted: Map[String, UninterpretedModuleAnnotation] = Map()
- ): TransitionSystem = {
+ def run(m: ir.Module): TransitionSystem = {
// first pass over the module to convert expressions; discover state and I/O
- val scan = new ModuleScanner(uninterpreted)
- m.foreachPort(scan.onPort)
- m.foreachStmt(scan.onStatement)
+ m.foreachPort(onPort)
+ m.foreachStmt(onStatement)
// multi-clock support requires the StutteringClock transform to be run
- if (scan.clocks.size > 1) {
- throw new MultiClockException(s"The module ${m.name} has more than one clock: ${scan.clocks.mkString(", ")}")
+ if (clocks.size > 1) {
+ throw new MultiClockException(s"The module ${m.name} has more than one clock: ${clocks.mkString(", ")}")
}
- // turn wires and nodes into signals
- val outputs = scan.outputs.toSet
- val constraints = scan.assumes.toSet
- val bad = scan.asserts.toSet
- val isSignal = (scan.wires ++ scan.nodes ++ scan.memSignals).toSet ++ outputs ++ constraints ++ bad
- val signals = scan.connects.filter { case (name, _) => isSignal.contains(name) }.map {
- case (name, expr) => Signal(name, expr)
- }
-
- // turn registers and memories into states
- val registers = scan.registers.map(r => r._1 -> r).toMap
- val regStates = scan.connects.filter(s => registers.contains(s._1)).map {
- case (name, nextExpr) =>
- val (_, width, resetExpr, initExpr) = registers(name)
- onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs)
- }
- // turn memories into state
- val memoryStatesAndOutputs = scan.memories.map(m => onMemory(m, scan.connects, memInit.get(m.name)))
- // replace pseudo assigns for memory outputs
- val memOutputs = memoryStatesAndOutputs.flatMap(_._2).toMap
- val signalsWithMem = signals.map { s =>
- if (memOutputs.contains(s.name)) {
- s.copy(e = memOutputs(s.name))
- } else { s }
- }
- // filter out any left-over self assignments (this happens when we have a registered read port)
- .filter(s =>
- s match {
- case Signal(n0, BVSymbol(n1, _)) if n0 == n1 => false
- case _ => true
- }
- )
- val states = regStates.toArray ++ memoryStatesAndOutputs.map(_._1)
-
// generate comments from infos
val comments = mutable.HashMap[String, String]()
- scan.infos.foreach {
+ infos.foreach {
case (name, info) =>
- serializeInfo(info).foreach { infoString =>
- if (comments.contains(name)) { comments(name) += InfoSeparator + infoString }
- else { comments(name) = InfoPrefix + infoString }
+ val infoStr = info.serialize.trim
+ if (infoStr.nonEmpty) {
+ val prefix = comments.get(name).map(_ + ", ").getOrElse("")
+ comments(name) = prefix + infoStr
}
}
- // inputs are original module inputs and any DefRandom signal
- val inputs = scan.inputs
-
// module info to the comment header
- val header = serializeInfo(m.info).map(InfoPrefix + _).toArray
-
- val fair = Set[String]() // as of firrtl 1.4 we do not support fairness constraints
-
- // collect unique functions
- val ufs = scan.functionCalls.groupBy(_.name).map(_._2.head).toList
-
- TransitionSystem(
- m.name,
- inputs.toArray,
- states,
- signalsWithMem.toArray,
- outputs,
- constraints,
- bad,
- fair,
- ufs,
- comments.toMap,
- header
- )
- }
-
- private def onRegister(
- name: String,
- width: Int,
- resetExpr: BVExpr,
- initExpr: BVExpr,
- nextExpr: BVExpr,
- presetRegs: Set[String]
- ): State = {
- assert(initExpr.width == width)
- assert(nextExpr.width == width)
- assert(resetExpr.width == 1)
- val sym = BVSymbol(name, width)
- val hasReset = initExpr != sym
- val isPreset = presetRegs.contains(name)
- assert(!isPreset || hasReset, s"Expected preset register $name to have a reset value, not just $initExpr!")
- if (hasReset) {
- val init = if (isPreset) Some(initExpr) else None
- val next = if (isPreset) nextExpr else BVIte(resetExpr, initExpr, nextExpr)
- State(sym, next = Some(next), init = init)
- } else {
- State(sym, next = Some(nextExpr), init = None)
- }
- }
-
- type Connects = Iterable[(String, BVExpr)]
- private def onMemory(m: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (State, Connects) = {
- checkMem(m)
-
- // map of inputs to the memory
- val inputs = connects.filter(_._1.startsWith(m.name)).toMap
-
- // derive the type of the memory from the dataType and depth
- val dataWidth = bitWidth(m.dataType).toInt
- val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1)
- val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth)
-
- // there could be a constant init
- val init = initValue.map(getInit(m, indexWidth, dataWidth, _))
- init.foreach(e => assert(e.dataWidth == memSymbol.dataWidth && e.indexWidth == memSymbol.indexWidth))
-
- // derive next state expression
- val next = if (m.writers.isEmpty) {
- memSymbol
- } else {
- m.writers.foldLeft[ArrayExpr](memSymbol) {
- case (prev, write) =>
- // update
- val addr = BVSymbol(memPortField(m, write, "addr").serialize, indexWidth)
- val data = BVSymbol(memPortField(m, write, "data").serialize, dataWidth)
- val update = ArrayStore(prev, index = addr, data = data)
-
- // update guard
- val en = BVSymbol(memPortField(m, write, "en").serialize, 1)
- val mask = BVSymbol(memPortField(m, write, "mask").serialize, 1)
- val alwaysEnabled = Seq(en, mask).forall(s => inputs(s.name) == True)
- if (alwaysEnabled) { update }
- else {
- ArrayIte(and(en, mask), update, prev)
- }
- }
- }
-
- val state = State(memSymbol, init, Some(next))
-
- // derive read expressions
- val readSignals = m.readers.map { read =>
- val addr = BVSymbol(memPortField(m, read, "addr").serialize, indexWidth)
- memPortField(m, read, "data").serialize -> ArrayRead(memSymbol, addr)
- }
+ val header = m.info.serialize.trim
- (state, readSignals)
+ TransitionSystem(m.name, inputs.toList, states.values.toList, signals.toList, comments.toMap, header)
}
- private def getInit(m: ir.DefMemory, indexWidth: Int, dataWidth: Int, initValue: MemoryInitValue): ArrayExpr =
- initValue match {
- case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, dataWidth), indexWidth)
- case MemoryArrayInit(values) =>
- assert(
- values.length == m.depth,
- s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!"
- )
- // in order to get a more compact encoding try to find the most common values
- val histogram = mutable.LinkedHashMap[BigInt, Int]()
- values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0))
- val baseValue = histogram.maxBy(_._2)._1
- val base = ArrayConstant(BVLiteral(baseValue, dataWidth), indexWidth)
- values.zipWithIndex
- .filterNot(_._1 == baseValue)
- .foldLeft[ArrayExpr](base) {
- case (array, (value, index)) =>
- ArrayStore(array, BVLiteral(index, indexWidth), BVLiteral(value, dataWidth))
- }
- case other => throw new RuntimeException(s"Unsupported memory init option: $other")
- }
+ private val inputs = mutable.ArrayBuffer[BVSymbol]()
+ private val clocks = mutable.ArrayBuffer[String]()
+ private val signals = mutable.ArrayBuffer[Signal]()
+ private val states = mutable.LinkedHashMap[String, State]()
+ private val infos = mutable.ArrayBuffer[(String, ir.Info)]()
- // TODO: add to BV expression library
- private def and(a: BVExpr, b: BVExpr): BVExpr = (a, b) match {
- case (True, True) => True
- case (True, x) => x
- case (x, True) => x
- case _ => BVOp(Op.And, a, b)
- }
-
- private val True = BVLiteral(1, 1)
- private def checkMem(m: ir.DefMemory): Unit = {
- assert(m.readLatency == 0, "Expected read latency to be 0. Did you run VerilogMemDelays?")
- assert(m.writeLatency == 1, "Expected read latency to be 1. Did you run VerilogMemDelays?")
- assert(
- m.dataType.isInstanceOf[ir.GroundType],
- s"Memory $m is of type ${m.dataType} which is not a ground type!"
- )
- assert(m.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.")
- }
-
- private val InfoSeparator = ", "
- private val InfoPrefix = "@ "
- private def serializeInfo(info: ir.Info): Option[String] = info match {
- case ir.NoInfo => None
- case f: ir.FileInfo => Some(f.escaped)
- case m: ir.MultiInfo =>
- val infos = m.flatten
- if (infos.isEmpty) { None }
- else { Some(infos.map(_.escaped).mkString(InfoSeparator)) }
- }
-}
-
-// performas a first pass over the module collecting all connections, wires, registers, input and outputs
-private class ModuleScanner(
- uninterpreted: Map[String, UninterpretedModuleAnnotation])
- extends LazyLogging {
- import FirrtlExpressionSemantics.getWidth
-
- private[firrtl] val inputs = mutable.ArrayBuffer[BVSymbol]()
- private[firrtl] val outputs = mutable.ArrayBuffer[String]()
- private[firrtl] val clocks = mutable.LinkedHashSet[String]()
- private[firrtl] val wires = mutable.ArrayBuffer[String]()
- private[firrtl] val nodes = mutable.ArrayBuffer[String]()
- private[firrtl] val memSignals = mutable.ArrayBuffer[String]()
- private[firrtl] val registers = mutable.ArrayBuffer[(String, Int, BVExpr, BVExpr)]()
- private[firrtl] val memories = mutable.ArrayBuffer[ir.DefMemory]()
- // DefNode, Connect, IsInvalid and VerificationStatement connections
- private[firrtl] val connects = mutable.ArrayBuffer[(String, BVExpr)]()
- private[firrtl] val asserts = mutable.ArrayBuffer[String]()
- private[firrtl] val assumes = mutable.ArrayBuffer[String]()
- // maps identifiers to their info
- private[firrtl] val infos = mutable.ArrayBuffer[(String, ir.Info)]()
- // Keeps track of (so far) unused memory (data) and uninterpreted module outputs.
- // This is used in order to delay declaring them for as long as possible.
- private val unusedOutputs = mutable.LinkedHashMap[String, BVExpr]()
- // ensure unique names for assert/assume signals
- private[firrtl] val namespace = Namespace()
- // keep track of all uninterpreted functions called
- private[firrtl] val functionCalls = mutable.ArrayBuffer[BVFunctionSymbol]()
-
- private[firrtl] def onPort(p: ir.Port): Unit = {
+ private def onPort(p: ir.Port): Unit = {
if (isAsyncReset(p.tpe)) {
throw new AsyncResetException(s"Found AsyncReset ${p.name}.")
}
- namespace.newName(p.name)
infos.append(p.name -> p.info)
p.direction match {
case ir.Input =>
if (isClock(p.tpe)) {
- clocks.add(p.name)
+ clocks.append(p.name)
} else {
inputs.append(BVSymbol(p.name, bitWidth(p.tpe).toInt))
}
case ir.Output =>
- if (!isClock(p.tpe)) { // we ignore clock outputs
- outputs.append(p.name)
- }
}
}
- private[firrtl] def onStatement(s: ir.Statement): Unit = s match {
- case DefRandom(info, name, tpe, _, _) =>
- namespace.newName(name)
+ private def onStatement(s: ir.Statement): Unit = s match {
+ case DefRandom(info, name, tpe, _, en) =>
assert(!isClock(tpe), "rand should never be a clock!")
- // we model random sources as inputs and ignore the enable signal
+ // we model random sources as inputs and the enable signal as output
infos.append(name -> info)
inputs.append(BVSymbol(name, bitWidth(tpe).toInt))
- case ir.DefWire(info, name, tpe) =>
- namespace.newName(name)
- if (!isClock(tpe) && !isAsyncReset(tpe)) {
- infos.append(name -> info)
- wires.append(name)
+ signals.append(Signal(name + ".en", onExpression(en, 1), IsOutput))
+ case w: ir.DefWire =>
+ if (!isClock(w.tpe)) {
+ // InlineInstances can insert wires without re-running RemoveWires for now we just deal with it when
+ // the Wires is connected to (ir.Connect).
}
case ir.DefNode(info, name, expr) =>
- namespace.newName(name)
if (!isClock(expr.tpe) && !isAsyncReset(expr.tpe)) {
- insertDummyAssignsForUnusedOutputs(expr)
infos.append(name -> info)
- val e = onExpression(expr)
- nodes.append(name)
- connects.append((name, e))
+ signals.append(Signal(name, onExpression(expr), IsNode))
}
- case ir.DefRegister(info, name, tpe, _, reset, init) =>
- namespace.newName(name)
- insertDummyAssignsForUnusedOutputs(reset)
- insertDummyAssignsForUnusedOutputs(init)
- infos.append(name -> info)
- val width = bitWidth(tpe).toInt
- val resetExpr = onExpression(reset, 1)
- val initExpr = onExpression(init, width)
- registers.append((name, width, resetExpr, initExpr))
+ case r: ir.DefRegister =>
+ infos.append(r.name -> r.info)
+ states(r.name) = onRegister(r)
case m: ir.DefMemory =>
- namespace.newName(m.name)
infos.append(m.name -> m.info)
- val outputs = getMemOutputs(m)
- (getMemInputs(m) ++ outputs).foreach(memSignals.append(_))
- val dataWidth = bitWidth(m.dataType).toInt
- outputs.foreach(name => unusedOutputs(name) = BVSymbol(name, dataWidth))
- memories.append(m)
+ states(m.name) = onMemory(m)
case ir.Connect(info, loc, expr) =>
if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
- if (!isClock(loc.tpe)) { // we ignore clock connections
+ if (!isClock(loc.tpe) && !isAsyncReset(expr.tpe)) { // we ignore clock connections
val name = loc.serialize
- insertDummyAssignsForUnusedOutputs(expr)
- infos.append(name -> info)
- connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt, allowNarrow = true)))
+ val e = onExpression(expr, bitWidth(loc.tpe).toInt, allowNarrow = false)
+ Utils.kind(loc) match {
+ case RegKind => states(name) = states(name).copy(next = Some(e))
+ case PortKind | InstanceKind => // module output or submodule input
+ infos.append(name -> info)
+ signals.append(Signal(name, e, IsOutput))
+ case MemKind | WireKind =>
+ // InlineInstances can insert wires without re-running RemoveWires for now we just deal with it.
+ infos.append(name -> info)
+ signals.append(Signal(name, e, IsNode))
+ }
}
- case i @ ir.IsInvalid(info, loc) =>
- if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
+ case i: ir.IsInvalid =>
throw new UnsupportedFeatureException(s"IsInvalid statements are not supported: ${i.serialize}")
case ir.DefInstance(info, name, module, tpe) => onInstance(info, name, module, tpe)
- case s @ ir.Verification(op, info, _, pred, en, msg) =>
- if (op == ir.Formal.Cover) {
+ case s: ir.Verification =>
+ if (s.op == ir.Formal.Cover) {
logger.info(s"[info] Cover statement was ignored: ${s.serialize}")
} else {
- insertDummyAssignsForUnusedOutputs(pred)
- insertDummyAssignsForUnusedOutputs(en)
val name = s.name
- val predicate = onExpression(pred)
- val enabled = onExpression(en)
+ val predicate = onExpression(s.pred)
+ val enabled = onExpression(s.en)
val e = BVImplies(enabled, predicate)
- infos.append(name -> info)
- connects.append(name -> e)
- if (op == ir.Formal.Assert) {
- asserts.append(name)
+ infos.append(name -> s.info)
+ val signal = if (s.op == ir.Formal.Assert) {
+ Signal(name, BVNot(e), IsBad)
} else {
- assumes.append(name)
+ Signal(name, e, IsConstraint)
}
+ signals.append(signal)
}
case s: ir.Conditionally =>
error(s"When conditions are not supported. Please run ExpandWhens: ${s.serialize}")
@@ -475,20 +221,29 @@ private class ModuleScanner(
)
} else {
// we treat Stop statements with a non-zero exit value as assertions that en will always be false!
- insertDummyAssignsForUnusedOutputs(s.en)
val name = s.name
infos.append(name -> s.info)
- val enabled = onExpression(s.en)
- connects.append(name -> BVNot(enabled))
- asserts.append(name)
+ signals.append(Signal(name, onExpression(s.en), IsBad))
}
case s: ir.Print =>
logger.info(s"Info: ignoring: ${s.serialize}")
case other => other.foreachStmt(onStatement)
}
+ private def onRegister(r: ir.DefRegister): State = {
+ val width = bitWidth(r.tpe).toInt
+ val resetExpr = onExpression(r.reset, 1)
+ assert(resetExpr == False(), s"Expected reset expression of ${r.name} to be 0, not $resetExpr")
+ val initExpr = onExpression(r.init, width)
+ val sym = BVSymbol(r.name, width)
+ val hasReset = initExpr != sym
+ val isPreset = presetRegs.contains(r.name)
+ assert(!isPreset || hasReset, s"Expected preset register ${r.name} to have a reset value, not just $initExpr!")
+ val state = State(sym, if (isPreset) Some(initExpr) else None, None)
+ state
+ }
+
private def onInstance(info: ir.Info, name: String, module: String, tpe: ir.Type): Unit = {
- namespace.newName(name)
if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}")
if (uninterpreted.contains(module)) {
onUninterpretedInstance(info: ir.Info, name: String, module: String, tpe: ir.Type)
@@ -509,14 +264,10 @@ private class ModuleScanner(
// outputs of the submodule become inputs to our module
if (isOutput) {
if (isClock(p.tpe)) {
- clocks.add(pName)
+ clocks.append(pName)
} else {
inputs.append(BVSymbol(pName, bitWidth(p.tpe).toInt))
}
- } else {
- if (!isClock(p.tpe)) { // we ignore clock outputs
- outputs.append(pName)
- }
}
}
}
@@ -539,112 +290,90 @@ private class ModuleScanner(
val functionName = anno.prefix + "." + out.name
val call = BVFunctionCall(functionName, args, out.width)
val wireName = instanceName + "." + out.name
- // remember which functions were called
- functionCalls.append(call.toSymbol)
- // insert the output definition right before its first use in an attempt to get SSA
- unusedOutputs(wireName) = call
- // treat these outputs as wires
- wires.append(wireName)
+ signals.append(Signal(wireName, call))
}
-
- // we also treat the arguments as wires
- wires ++= args.map(_.name)
}
- private val readInputFields = List("en", "addr")
- private val writeInputFields = List("en", "mask", "addr", "data")
- private def getMemInputs(m: ir.DefMemory): Iterable[String] = {
- assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!")
- val p = m.name + "."
- m.writers.flatMap(w => writeInputFields.map(p + w + "." + _)) ++
- m.readers.flatMap(r => readInputFields.map(p + r + "." + _))
- }
- private def getMemOutputs(m: ir.DefMemory): Iterable[String] = {
- assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!")
- val p = m.name + "."
- m.readers.map(r => p + r + ".data")
- }
- // inserts a dummy assign right before a memory/uninterpreted module output is used for the first time
- // example:
- // m.r.data <= m.r.data ; this is the dummy assign
- // test <= m.r.data ; this is the first use of m.r.data
- private def insertDummyAssignsForUnusedOutputs(next: ir.Expression): Unit = if (unusedOutputs.nonEmpty) {
- val uses = mutable.ArrayBuffer[String]()
- findUnusedOutputUse(next)(uses)
- if (uses.nonEmpty) {
- val useSet = uses.toSet
- unusedOutputs.foreach {
- case (name, value) =>
- if (useSet.contains(name)) connects.append(name -> value)
+ private def onMemory(m: ir.DefMemory): State = {
+ checkMem(m)
+
+ // derive the type of the memory from the dataType and depth
+ val dataWidth = bitWidth(m.dataType).toInt
+ val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1)
+ val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth)
+
+ // there could be a constant init
+ val init = memInit.get(m.name).map(getMemInit(m, indexWidth, dataWidth, _))
+ init.foreach(e => assert(e.dataWidth == memSymbol.dataWidth && e.indexWidth == memSymbol.indexWidth))
+
+ // derive next state expression
+ val next = if (m.writers.isEmpty) {
+ memSymbol
+ } else {
+ m.writers.foldLeft[ArrayExpr](memSymbol) {
+ case (prev, write) =>
+ // update
+ val addr = BVSymbol(memPortField(m, write, "addr").serialize, indexWidth)
+ val data = BVSymbol(memPortField(m, write, "data").serialize, dataWidth)
+ val update = ArrayStore(prev, index = addr, data = data)
+
+ // update guard
+ val en = BVSymbol(memPortField(m, write, "en").serialize, 1)
+ val mask = BVSymbol(memPortField(m, write, "mask").serialize, 1)
+ ArrayIte(BVAnd(en, mask), update, prev)
}
- useSet.foreach(name => unusedOutputs.remove(name))
}
- }
- private def findUnusedOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match {
- case s: ir.SubField =>
- val name = s.serialize
- if (unusedOutputs.contains(name)) uses.append(name)
- case other => other.foreachExpr(findUnusedOutputUse)
- }
- private case class Context() extends TranslationContext {}
+ val state = State(memSymbol, init, Some(next))
- private def onExpression(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr = {
- implicit val ctx: TranslationContext = Context()
- FirrtlExpressionSemantics.toSMT(e, width, allowNarrow)
+ // derive read expressions
+ val readSignals = m.readers.map { read =>
+ val addr = BVSymbol(memPortField(m, read, "addr").serialize, indexWidth)
+ Signal(memPortField(m, read, "data").serialize, ArrayRead(memSymbol, addr), IsNode)
+ }
+ signals ++= readSignals
+
+ state
}
- private def onExpression(e: ir.Expression): BVExpr = {
- implicit val ctx: TranslationContext = Context()
- FirrtlExpressionSemantics.toSMT(e)
+
+ private def getMemInit(m: ir.DefMemory, indexWidth: Int, dataWidth: Int, initValue: MemoryInitValue): ArrayExpr =
+ initValue match {
+ case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, dataWidth), indexWidth)
+ case MemoryArrayInit(values) =>
+ assert(
+ values.length == m.depth,
+ s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!"
+ )
+ // in order to get a more compact encoding try to find the most common values
+ val histogram = mutable.LinkedHashMap[BigInt, Int]()
+ values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0))
+ val baseValue = histogram.maxBy(_._2)._1
+ val base = ArrayConstant(BVLiteral(baseValue, dataWidth), indexWidth)
+ values.zipWithIndex
+ .filterNot(_._1 == baseValue)
+ .foldLeft[ArrayExpr](base) {
+ case (array, (value, index)) =>
+ ArrayStore(array, BVLiteral(index, indexWidth), BVLiteral(value, dataWidth))
+ }
+ case other => throw new RuntimeException(s"Unsupported memory init option: $other")
+ }
+
+ private def checkMem(m: ir.DefMemory): Unit = {
+ assert(m.readLatency == 0, "Expected read latency to be 0. Did you run VerilogMemDelays?")
+ assert(m.writeLatency == 1, "Expected read latency to be 1. Did you run VerilogMemDelays?")
+ assert(
+ m.dataType.isInstanceOf[ir.GroundType],
+ s"Memory $m is of type ${m.dataType} which is not a ground type!"
+ )
+ assert(m.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.")
}
+ private def onExpression(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr =
+ FirrtlExpressionSemantics.toSMT(e, width, allowNarrow)
+ private def onExpression(e: ir.Expression): BVExpr = FirrtlExpressionSemantics.toSMT(e)
+
private def error(msg: String): Unit = throw new RuntimeException(msg)
private def isGroundType(tpe: ir.Type): Boolean = tpe.isInstanceOf[ir.GroundType]
private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType
private def isAsyncReset(tpe: ir.Type): Boolean = tpe == ir.AsyncResetType
}
-
-private object TopologicalSort {
-
- /** Ensures that all signals in the resulting system are topologically sorted.
- * This is necessary because [[firrtl.transforms.RemoveWires]] does
- * not sort assignments to outputs, submodule inputs nor memory ports.
- */
- def run(sys: TransitionSystem): TransitionSystem = {
- val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name)
- val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates)
- // TODO: maybe sort init expressions of states (this should not be needed most of the time)
- signalOrder match {
- case None => sys
- case Some(order) =>
- val signalMap = sys.signals.map(s => s.name -> s).toMap
- // we flatMap over `get` in order to ignore inputs/states in the order
- sys.copy(signals = order.flatMap(signalMap.get).toArray)
- }
- }
-
- private def sort(signals: Iterable[(String, SMTExpr)], globalSignals: Iterable[String]): Option[Iterable[String]] = {
- val known = new mutable.HashSet[String]() ++ globalSignals
- var needsReordering = false
- val digraph = new MutableDiGraph[String]
- signals.foreach {
- case (name, expr) =>
- digraph.addVertex(name)
- val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr)
- uniqueDependencies.foreach { d =>
- if (!known.contains(d)) { needsReordering = true }
- digraph.addPairWithEdge(name, d)
- }
- known.add(name)
- }
- if (needsReordering) {
- Some(digraph.linearize.reverse)
- } else { None }
- }
-
- private def findDependencies(expr: SMTExpr): List[String] = expr match {
- case BVSymbol(name, _) => List(name)
- case ArraySymbol(name, _, _) => List(name)
- case other => other.children.flatMap(findDependencies)
- }
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala
new file mode 100644
index 00000000..21a64f98
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala
@@ -0,0 +1,12 @@
+// SPDX-License-Identifier: Apache-2.0
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+private sealed trait SMTCommand
+private case class Comment(msg: String) extends SMTCommand
+private case class SetLogic(logic: String) extends SMTCommand
+private case class DefineFunction(name: String, args: Seq[SMTFunctionArg], e: SMTExpr) extends SMTCommand
+private case class DeclareFunction(sym: SMTSymbol, args: Seq[SMTFunctionArg]) extends SMTCommand
+private case class DeclareUninterpretedSort(name: String) extends SMTCommand
+private case class DeclareUninterpretedSymbol(name: String, tpe: String) extends SMTCommand
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
index 0fc507e6..a40717f9 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
@@ -5,9 +5,20 @@
package firrtl.backends.experimental.smt
-private sealed trait SMTExpr { def children: List[SMTExpr] }
-private sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr { val name: String }
+/** base trait for all SMT expressions */
+private sealed trait SMTExpr extends SMTFunctionArg {
+ def tpe: SMTType
+ def children: List[SMTExpr]
+}
+private sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr {
+ def name: String
+
+ /** keeps the type of the symbol while changing the name */
+ def rename(newName: String): SMTSymbol
+}
private object SMTSymbol {
+
+ /** makes a SMTSymbol of the same type as the expression */
def fromExpr(name: String, e: SMTExpr): SMTSymbol = e match {
case b: BVExpr => BVSymbol(name, b.width)
case a: ArrayExpr => ArraySymbol(name, a.indexWidth, a.dataWidth)
@@ -17,91 +28,115 @@ private sealed trait SMTNullaryExpr extends SMTExpr {
override def children: List[SMTExpr] = List()
}
-private sealed trait BVExpr extends SMTExpr { def width: Int }
+/** a SMT bit vector expression: https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml */
+private sealed trait BVExpr extends SMTExpr {
+ def width: Int
+ def tpe: BVType = BVType(width)
+ override def toString: String = SMTExprSerializer.serialize(this)
+}
private case class BVLiteral(value: BigInt, width: Int) extends BVExpr with SMTNullaryExpr {
private def minWidth = value.bitLength + (if (value <= 0) 1 else 0)
+ assert(value >= 0, "Negative values are not supported! Please normalize by calculating 2s complement.")
assert(width > 0, "Zero or negative width literals are not allowed!")
assert(width >= minWidth, "Value (" + value.toString + ") too big for BitVector of width " + width + " bits.")
- override def toString: String = if (width <= 8) {
- width.toString + "'b" + value.toString(2)
- } else { width.toString + "'x" + value.toString(16) }
+}
+private object BVLiteral {
+ def apply(nums: String): BVLiteral = nums.head match {
+ case 'b' => BVLiteral(BigInt(nums.drop(1), 2), nums.length - 1)
+ }
}
private case class BVSymbol(name: String, width: Int) extends BVExpr with SMTSymbol {
assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
- assert(!name.contains("\\"), s"Invalid id $name contains `\\`")
assert(width > 0, "Zero width bit vectors are not supported!")
- override def toString: String = name
- def toStringWithType: String = name + " : " + SMTExpr.serializeType(this)
+ override def rename(newName: String) = BVSymbol(newName, width)
}
private sealed trait BVUnaryExpr extends BVExpr {
def e: BVExpr
+
+ /** same function, different child, e.g.: not(x) -- reapply(Y) --> not(Y) */
+ def reapply(expr: BVExpr): BVUnaryExpr
override def children: List[BVExpr] = List(e)
}
private case class BVExtend(e: BVExpr, by: Int, signed: Boolean) extends BVUnaryExpr {
assert(by >= 0, "Extension must be non-negative!")
override val width: Int = e.width + by
- override def toString: String = if (signed) { s"sext($e, $by)" }
- else { s"zext($e, $by)" }
+ override def reapply(expr: BVExpr) = BVExtend(expr, by, signed)
}
// also known as bit extract operation
private case class BVSlice(e: BVExpr, hi: Int, lo: Int) extends BVUnaryExpr {
assert(lo >= 0, s"lo (lsb) must be non-negative!")
assert(hi >= lo, s"hi (msb) must not be smaller than lo (lsb): msb: $hi lsb: $lo")
assert(e.width > hi, s"Out off bounds hi (msb) access: width: ${e.width} msb: $hi")
- override def width: Int = hi - lo + 1
- override def toString: String = if (hi == lo) s"$e[$hi]" else s"$e[$hi:$lo]"
+ override def width: Int = hi - lo + 1
+ override def reapply(expr: BVExpr) = BVSlice(expr, hi, lo)
}
private case class BVNot(e: BVExpr) extends BVUnaryExpr {
- override val width: Int = e.width
- override def toString: String = s"not($e)"
+ override val width: Int = e.width
+ override def reapply(expr: BVExpr) = new BVNot(expr)
}
private case class BVNegate(e: BVExpr) extends BVUnaryExpr {
- override val width: Int = e.width
- override def toString: String = s"neg($e)"
+ override val width: Int = e.width
+ override def reapply(expr: BVExpr) = BVNegate(expr)
}
+
private case class BVReduceOr(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
- override def toString: String = s"redor($e)"
+ override def width: Int = 1
+ override def reapply(expr: BVExpr) = BVReduceOr(expr)
}
private case class BVReduceAnd(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
- override def toString: String = s"redand($e)"
+ override def width: Int = 1
+ override def reapply(expr: BVExpr) = BVReduceAnd(expr)
}
private case class BVReduceXor(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
- override def toString: String = s"redxor($e)"
+ override def width: Int = 1
+ override def reapply(expr: BVExpr) = BVReduceXor(expr)
}
private sealed trait BVBinaryExpr extends BVExpr {
def a: BVExpr
def b: BVExpr
override def children: List[BVExpr] = List(a, b)
-}
-private case class BVImplies(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
- assert(a.width == 1 && b.width == 1, s"Both arguments need to be 1-bit!")
- override def width: Int = 1
- override def toString: String = s"impl($a, $b)"
+
+ /** same function, different child, e.g.: add(a,b) -- reapply(a,c) --> add(a,c) */
+ def reapply(nA: BVExpr, nB: BVExpr): BVBinaryExpr
}
private case class BVEqual(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
assert(a.width == b.width, s"Both argument need to be the same width!")
- override def width: Int = 1
- override def toString: String = s"eq($a, $b)"
+ override def width: Int = 1
+ override def reapply(nA: BVExpr, nB: BVExpr) = BVEqual(nA, nB)
}
+// added as a separate node because it is used a lot in model checking and benefits from pretty printing
+private class BVImplies(val a: BVExpr, val b: BVExpr) extends BVBinaryExpr {
+ assert(a.width == 1, s"The antecedent needs to be a boolean expression!")
+ assert(b.width == 1, s"The consequent needs to be a boolean expression!")
+ override def width: Int = 1
+ override def reapply(nA: BVExpr, nB: BVExpr) = new BVImplies(nA, nB)
+}
+private object BVImplies {
+ def apply(a: BVExpr, b: BVExpr): BVExpr = {
+ assert(a.width == b.width, s"Both argument need to be the same width!")
+ (a, b) match {
+ case (True(), b) => b // (!1 || b) = b
+ case (False(), _) => True() // (!0 || _) = (1 || _) = 1
+ case (_, True()) => True() // (!a || 1) = 1
+ case (a, False()) => BVNot(a) // (!a || 0) = !a
+ case (a, b) => new BVImplies(a, b)
+ }
+ }
+ def unapply(i: BVImplies): Some[(BVExpr, BVExpr)] = Some((i.a, i.b))
+}
+
private object Compare extends Enumeration {
val Greater, GreaterEqual = Value
}
private case class BVComparison(op: Compare.Value, a: BVExpr, b: BVExpr, signed: Boolean) extends BVBinaryExpr {
assert(a.width == b.width, s"Both argument need to be the same width!")
override def width: Int = 1
- override def toString: String = op match {
- case Compare.Greater => (if (signed) "sgt" else "ugt") + s"($a, $b)"
- case Compare.GreaterEqual => (if (signed) "sgeq" else "ugeq") + s"($a, $b)"
- }
+ override def reapply(nA: BVExpr, nB: BVExpr) = BVComparison(op, nA, nB, signed)
}
+
private object Op extends Enumeration {
- val And = Value("and")
- val Or = Value("or")
val Xor = Value("xor")
val ShiftLeft = Value("logical_shift_left")
val ArithmeticShiftRight = Value("arithmetic_shift_right")
@@ -117,51 +152,65 @@ private object Op extends Enumeration {
}
private case class BVOp(op: Op.Value, a: BVExpr, b: BVExpr) extends BVBinaryExpr {
assert(a.width == b.width, s"Both argument need to be the same width!")
- override val width: Int = a.width
- override def toString: String = s"$op($a, $b)"
+ override val width: Int = a.width
+ override def reapply(nA: BVExpr, nB: BVExpr) = BVOp(op, nA, nB)
}
private case class BVConcat(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
- override val width: Int = a.width + b.width
- override def toString: String = s"concat($a, $b)"
+ override val width: Int = a.width + b.width
+ override def reapply(nA: BVExpr, nB: BVExpr) = BVConcat(nA, nB)
}
private case class ArrayRead(array: ArrayExpr, index: BVExpr) extends BVExpr {
assert(array.indexWidth == index.width, "Index with does not match expected array index width!")
override val width: Int = array.dataWidth
- override def toString: String = s"$array[$index]"
override def children: List[SMTExpr] = List(array, index)
}
private case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr {
assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!")
assert(tru.width == fals.width, s"Both branches need to be of the same width! ${tru.width} vs ${fals.width}")
override val width: Int = tru.width
- override def toString: String = s"ite($cond, $tru, $fals)"
override def children: List[BVExpr] = List(cond, tru, fals)
}
-/** apply bv arguments to a function which returns a result of bit vector type */
-private case class BVFunctionCall(name: String, args: List[BVExpr], width: Int) extends BVExpr {
- override def children = args
- def toSymbol: BVFunctionSymbol = BVFunctionSymbol(name, args.map(_.width), width)
- override def toString: String = args.mkString(name + "(", ", ", ")")
+private case class BVAnd(terms: List[BVExpr]) extends BVExpr {
+ require(terms.size > 1)
+ override val width: Int = terms.head.width
+ require(terms.forall(_.width == width))
+ override def children: List[BVExpr] = terms
}
-private case class BVFunctionSymbol(name: String, argWidths: List[Int], width: Int) {
- override def toString: String = s"$name : " + (argWidths :+ width).map(w => s"bv<$w>").mkString(" -> ")
+private case class BVOr(terms: List[BVExpr]) extends BVExpr {
+ require(terms.size > 1)
+ override val width: Int = terms.head.width
+ require(terms.forall(_.width == width))
+ override def children: List[BVExpr] = terms
}
-private sealed trait ArrayExpr extends SMTExpr { val indexWidth: Int; val dataWidth: Int }
+private sealed trait ArrayExpr extends SMTExpr {
+ val indexWidth: Int
+ val dataWidth: Int
+ def tpe: ArrayType = ArrayType(indexWidth = indexWidth, dataWidth = dataWidth)
+ override def toString: String = SMTExprSerializer.serialize(this)
+}
private case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol {
assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
assert(!name.contains("\\"), s"Invalid id $name contains `\\`")
- override def toString: String = name
- def toStringWithType: String = s"$name : bv<$indexWidth> -> bv<$dataWidth>"
+ override def rename(newName: String) = ArraySymbol(newName, indexWidth, dataWidth)
+}
+private case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr {
+ override val dataWidth: Int = e.width
+ override def children: List[SMTExpr] = List(e)
+}
+private case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr {
+ assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!")
+ assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!")
+ override def width: Int = 1
+ override def children: List[SMTExpr] = List(a, b)
}
private case class ArrayStore(array: ArrayExpr, index: BVExpr, data: BVExpr) extends ArrayExpr {
assert(array.indexWidth == index.width, "Index with does not match expected array index width!")
assert(array.dataWidth == data.width, "Data with does not match expected array data width!")
override val dataWidth: Int = array.dataWidth
override val indexWidth: Int = array.indexWidth
- override def toString: String = s"$array[$index := $data]"
override def children: List[SMTExpr] = List(array, index, data)
}
private case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) extends ArrayExpr {
@@ -176,20 +225,79 @@ private case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) exten
)
override val dataWidth: Int = tru.dataWidth
override val indexWidth: Int = tru.indexWidth
- override def toString: String = s"ite($cond, $tru, $fals)"
override def children: List[SMTExpr] = List(cond, tru, fals)
}
-private case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr {
- assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!")
- assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!")
- override def width: Int = 1
- override def toString: String = s"eq($a, $b)"
- override def children: List[SMTExpr] = List(a, b)
+
+private case class BVForall(variable: BVSymbol, e: BVExpr) extends BVUnaryExpr {
+ assert(e.width == 1, "Can only quantify over boolean expressions!")
+ override def width = 1
+ override def reapply(expr: BVExpr) = BVForall(variable, expr)
}
-private case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr {
- override val dataWidth: Int = e.width
- override def toString: String = s"([$e] x ${(BigInt(1) << indexWidth)})"
- override def children: List[SMTExpr] = List(e)
+
+/** apply arguments to a function which returns a result of bit vector type */
+private case class BVFunctionCall(name: String, args: List[SMTFunctionArg], width: Int) extends BVExpr {
+ override def children = args.map(_.asInstanceOf[SMTExpr])
+}
+
+/** apply arguments to a function which returns a result of array type */
+private case class ArrayFunctionCall(name: String, args: List[SMTFunctionArg], indexWidth: Int, dataWidth: Int)
+ extends ArrayExpr {
+ override def children = args.map(_.asInstanceOf[SMTExpr])
+}
+private sealed trait SMTFunctionArg
+// we allow symbols with uninterpreted type to be function arguments
+private case class UTSymbol(name: String, tpe: String) extends SMTFunctionArg
+
+private object BVAnd {
+ def apply(a: BVExpr, b: BVExpr): BVExpr = {
+ assert(a.width == b.width, s"Both argument need to be the same width!")
+ (a, b) match {
+ case (True(), b) => b
+ case (a, True()) => a
+ case (False(), _) => False()
+ case (_, False()) => False()
+ case (a, b) => new BVAnd(List(a, b))
+ }
+ }
+ def apply(exprs: List[BVExpr]): BVExpr = {
+ assert(exprs.nonEmpty, "Don't know what to do with an empty list!")
+ val nonTriviallyTrue = exprs.filterNot(_ == True())
+ nonTriviallyTrue.distinct match {
+ case Seq() => True()
+ case Seq(one) => one
+ case terms => new BVAnd(terms)
+ }
+ }
+}
+private object BVOr {
+ def apply(a: BVExpr, b: BVExpr): BVExpr = {
+ assert(a.width == b.width, s"Both argument need to be the same width!")
+ (a, b) match {
+ case (True(), _) => True()
+ case (_, True()) => True()
+ case (False(), b) => b
+ case (a, False()) => a
+ case (a, b) => new BVOr(List(a, b))
+ }
+ }
+ def apply(exprs: List[BVExpr]): BVExpr = {
+ assert(exprs.nonEmpty, "Don't know what to do with an empty list!")
+ val nonTriviallyFalse = exprs.filterNot(_ == False())
+ nonTriviallyFalse.distinct match {
+ case Seq() => False()
+ case Seq(one) => one
+ case terms => new BVOr(terms)
+ }
+ }
+}
+
+private object BVNot {
+ def apply(e: BVExpr): BVExpr = e match {
+ case True() => False()
+ case False() => True()
+ case BVNot(inner) => inner
+ case other => new BVNot(other)
+ }
}
private object SMTEqual {
@@ -200,6 +308,14 @@ private object SMTEqual {
}
}
+private object SMTIte {
+ def apply(cond: BVExpr, tru: SMTExpr, fals: SMTExpr): SMTExpr = (tru, fals) match {
+ case (ab: BVExpr, bb: BVExpr) => BVIte(cond, ab, bb)
+ case (aa: ArrayExpr, ba: ArrayExpr) => ArrayIte(cond, aa, ba)
+ case _ => throw new RuntimeException(s"Cannot mux $tru and $fals")
+ }
+}
+
private object SMTExpr {
def serializeType(e: SMTExpr): String = e match {
case b: BVExpr => s"bv<${b.width}>"
@@ -207,8 +323,20 @@ private object SMTExpr {
}
}
-// Raw SMTLib encoded expressions as an escape hatch used in the [[SMTTransitionSystemEncoder]]
-private case class BVRawExpr(serialized: String, width: Int) extends BVExpr with SMTNullaryExpr
-private case class ArrayRawExpr(serialized: String, indexWidth: Int, dataWidth: Int)
- extends ArrayExpr
- with SMTNullaryExpr
+// unapply for matching BVLiteral(1, 1)
+private object True {
+ private val _True = BVLiteral(1, 1)
+ def apply(): BVLiteral = _True
+ def unapply(l: BVLiteral): Boolean = l.value == 1 && l.width == 1
+}
+
+// unapply for matching BVLiteral(0, 1)
+private object False {
+ private val _False = BVLiteral(0, 1)
+ def apply(): BVLiteral = _False
+ def unapply(l: BVLiteral): Boolean = l.value == 0 && l.width == 1
+}
+
+private sealed trait SMTType
+private case class BVType(width: Int) extends SMTType
+private case class ArrayType(indexWidth: Int, dataWidth: Int) extends SMTType
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala
new file mode 100644
index 00000000..c991941f
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala
@@ -0,0 +1,86 @@
+// SPDX-License-Identifier: Apache-2.0
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+package firrtl.backends.experimental.smt
+
+private object SMTExprMap {
+ def mapExpr(expr: SMTExpr, f: SMTExpr => SMTExpr): SMTExpr = {
+ val bv = (b: BVExpr) => f(b).asInstanceOf[BVExpr]
+ val ar = (a: ArrayExpr) => f(a).asInstanceOf[ArrayExpr]
+ expr match {
+ case b: BVExpr => mapExpr(b, bv, ar)
+ case a: ArrayExpr => mapExpr(a, bv, ar)
+ }
+ }
+
+ /** maps bv/ar over subexpressions of expr and returns expr with the results replaced */
+ def mapExpr(expr: BVExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): BVExpr = expr match {
+ // nullary
+ case old: BVLiteral => old
+ case old: BVSymbol => old
+ // unary
+ case old @ BVExtend(e, by, signed) => val n = bv(e); if (n.eq(e)) old else BVExtend(n, by, signed)
+ case old @ BVSlice(e, hi, lo) => val n = bv(e); if (n.eq(e)) old else BVSlice(n, hi, lo)
+ case old @ BVNot(e) => val n = bv(e); if (n.eq(e)) old else BVNot(n)
+ case old @ BVNegate(e) => val n = bv(e); if (n.eq(e)) old else BVNegate(n)
+ case old @ BVForall(variables, e) => val n = bv(e); if (n.eq(e)) old else BVForall(variables, n)
+ case old @ BVReduceAnd(e) => val n = bv(e); if (n.eq(e)) old else BVReduceAnd(n)
+ case old @ BVReduceOr(e) => val n = bv(e); if (n.eq(e)) old else BVReduceOr(n)
+ case old @ BVReduceXor(e) => val n = bv(e); if (n.eq(e)) old else BVReduceXor(n)
+ // binary
+ case old @ BVEqual(a, b) =>
+ val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB)
+ case old @ ArrayEqual(a, b) =>
+ val (nA, nB) = (ar(a), ar(b)); if (nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB)
+ case old @ BVComparison(op, a, b, signed) =>
+ val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed)
+ case old @ BVOp(op, a, b) =>
+ val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB)
+ case old @ BVConcat(a, b) =>
+ val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB)
+ case old @ ArrayRead(a, b) =>
+ val (nA, nB) = (ar(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB)
+ case old @ BVImplies(a, b) =>
+ val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB)
+ // ternary
+ case old @ BVIte(a, b, c) =>
+ val (nA, nB, nC) = (bv(a), bv(b), bv(c))
+ if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC)
+ // n-ary
+ case old @ BVFunctionCall(name, args, width) =>
+ val nArgs = args.map {
+ case b: BVExpr => bv(b)
+ case a: ArrayExpr => ar(a)
+ case u: UTSymbol => u
+ }
+ val anyNew = nArgs.zip(args).exists { case (n, o) => !n.eq(o) }
+ if (anyNew) BVFunctionCall(name, nArgs, width) else old
+ case old @ BVAnd(terms) =>
+ val nTerms = terms.map(bv)
+ val anyNew = nTerms.zip(terms).exists { case (n, o) => !n.eq(o) }
+ if (anyNew) BVAnd(nTerms) else old
+ case old @ BVOr(terms) =>
+ val nTerms = terms.map(bv)
+ val anyNew = nTerms.zip(terms).exists { case (n, o) => !n.eq(o) }
+ if (anyNew) BVOr(nTerms) else old
+ }
+
+ /** maps bv/ar over subexpressions of expr and returns expr with the results replaced */
+ def mapExpr(expr: ArrayExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): ArrayExpr = expr match {
+ case old: ArraySymbol => old
+ case old @ ArrayConstant(e, indexWidth) => val n = bv(e); if (n.eq(e)) old else ArrayConstant(n, indexWidth)
+ case old @ ArrayStore(a, b, c) =>
+ val (nA, nB, nC) = (ar(a), bv(b), bv(c))
+ if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC)
+ case old @ ArrayIte(a, b, c) =>
+ val (nA, nB, nC) = (bv(a), ar(b), ar(c))
+ if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC)
+ case old @ ArrayFunctionCall(name, args, indexWidth, dataWidth) =>
+ val nArgs = args.map {
+ case b: BVExpr => bv(b)
+ case a: ArrayExpr => ar(a)
+ case u: UTSymbol => u
+ }
+ val anyNew = nArgs.zip(args).exists { case (n, o) => !n.eq(o) }
+ if (anyNew) ArrayFunctionCall(name, nArgs, indexWidth, dataWidth) else old
+ }
+}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala
new file mode 100644
index 00000000..4aaf78a2
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala
@@ -0,0 +1,60 @@
+// SPDX-License-Identifier: Apache-2.0
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+private object SMTExprSerializer {
+ def serialize(expr: BVExpr): String = expr match {
+ // nullary
+ case lit: BVLiteral =>
+ if (lit.width <= 8) {
+ lit.width.toString + "'b" + lit.value.toString(2)
+ } else {
+ lit.width.toString + "'x" + lit.value.toString(16)
+ }
+ case BVSymbol(name, _) => name
+ // unary
+ case BVExtend(e, by, false) => s"zext(${serialize(e)}, $by)"
+ case BVExtend(e, by, true) => s"sext(${serialize(e)}, $by)"
+ case BVSlice(e, hi, lo) if hi == lo => s"${serialize(e)}[$hi]"
+ case BVSlice(e, hi, lo) => s"${serialize(e)}[$hi:$lo]"
+ case BVNot(e) => s"not(${serialize(e)})"
+ case BVNegate(e) => s"neg(${serialize(e)})"
+ case BVForall(variable, e) => s"forall(${variable.name} : bv<${variable.width}, ${serialize(e)})"
+ case BVReduceAnd(e) => s"redand(${serialize(e)})"
+ case BVReduceOr(e) => s"redor(${serialize(e)})"
+ case BVReduceXor(e) => s"redxor(${serialize(e)})"
+ // binary
+ case BVEqual(a, b) => s"eq(${serialize(a)}, ${serialize(b)})"
+ case BVComparison(Compare.Greater, a, b, false) => s"ugt(${serialize(a)}, ${serialize(b)})"
+ case BVComparison(Compare.Greater, a, b, true) => s"sgt(${serialize(a)}, ${serialize(b)})"
+ case BVComparison(Compare.GreaterEqual, a, b, false) => s"ugeq(${serialize(a)}, ${serialize(b)})"
+ case BVComparison(Compare.GreaterEqual, a, b, true) => s"sgeq(${serialize(a)}, ${serialize(b)})"
+ case BVOp(op, a, b) => s"$op(${serialize(a)}, ${serialize(b)})"
+ case BVConcat(a, b) => s"concat(${serialize(a)}, ${serialize(b)})"
+ case ArrayRead(array, index) => s"${serialize(array)}[${serialize(index)}]"
+ case ArrayEqual(a, b) => s"eq(${serialize(a)}, ${serialize(b)})"
+ case BVImplies(a, b) => s"implies(${serialize(a)}, ${serialize(b)})"
+ // ternary
+ case BVIte(cond, tru, fals) => s"ite(${serialize(cond)}, ${serialize(tru)}, ${serialize(fals)})"
+ // n-ary
+ case BVFunctionCall(name, args, _) => name + serialize(args).mkString("(", ",", ")")
+ case BVAnd(terms) => terms.map(serialize).mkString("and(", ", ", ")")
+ case BVOr(terms) => terms.map(serialize).mkString("or(", ", ", ")")
+ }
+
+ def serialize(expr: ArrayExpr): String = expr match {
+ case ArraySymbol(name, _, _) => name
+ case ArrayConstant(e, indexWidth) => s"([${serialize(e)}] x ${(BigInt(1) << indexWidth)})"
+ case ArrayStore(array, index, data) => s"${serialize(array)}[${serialize(index)} := ${serialize(data)}]"
+ case ArrayIte(cond, tru, fals) => s"ite(${serialize(cond)}, ${serialize(tru)}, ${serialize(fals)})"
+ case ArrayFunctionCall(name, args, _, _) => name + serialize(args).mkString("(", ",", ")")
+ }
+
+ private def serialize(args: Iterable[SMTFunctionArg]): Iterable[String] =
+ args.map {
+ case b: BVExpr => serialize(b)
+ case a: ArrayExpr => serialize(a)
+ case u: UTSymbol => u.name
+ }
+}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
deleted file mode 100644
index 13ed8bdd..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
+++ /dev/null
@@ -1,77 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-/** Similar to the mapExpr and foreachExpr methods of the firrtl ir nodes, but external to the case classes */
-private object SMTExprVisitor {
- type ArrayFun = ArrayExpr => ArrayExpr
- type BVFun = BVExpr => BVExpr
-
- def map[T <: SMTExpr](bv: BVFun, ar: ArrayFun)(e: T): T = e match {
- case b: BVExpr => map(b, bv, ar).asInstanceOf[T]
- case a: ArrayExpr => map(a, bv, ar).asInstanceOf[T]
- }
- def map[T <: SMTExpr](f: SMTExpr => SMTExpr)(e: T): T =
- map(b => f(b).asInstanceOf[BVExpr], a => f(a).asInstanceOf[ArrayExpr])(e)
-
- private def map(e: BVExpr, bv: BVFun, ar: ArrayFun): BVExpr = e match {
- // nullary
- case old: BVLiteral => bv(old)
- case old: BVSymbol => bv(old)
- case old: BVRawExpr => bv(old)
- // unary
- case old @ BVExtend(e, by, signed) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVExtend(n, by, signed))
- case old @ BVSlice(e, hi, lo) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVSlice(n, hi, lo))
- case old @ BVNot(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVNot(n))
- case old @ BVNegate(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVNegate(n))
- case old @ BVReduceAnd(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceAnd(n))
- case old @ BVReduceOr(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceOr(n))
- case old @ BVReduceXor(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceXor(n))
- // binary
- case old @ BVImplies(a, b) =>
- val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if (nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB))
- case old @ BVEqual(a, b) =>
- val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if (nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB))
- case old @ ArrayEqual(a, b) =>
- val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if (nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB))
- case old @ BVComparison(op, a, b, signed) =>
- val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if (nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed))
- case old @ BVOp(op, a, b) =>
- val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if (nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB))
- case old @ BVConcat(a, b) =>
- val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if (nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB))
- case old @ ArrayRead(a, b) =>
- val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if (nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB))
- // ternary
- case old @ BVIte(a, b, c) =>
- val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
- bv(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC))
- // n-ary
- case old @ BVFunctionCall(name, args, width) =>
- val nArgs = args.map(a => map(a, bv, ar))
- val noneNew = nArgs.zip(args).forall { case (n, o) => n.eq(o) }
- bv(if (noneNew) old else BVFunctionCall(name, nArgs, width))
- }
-
- private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = e match {
- case old: ArrayRawExpr => ar(old)
- case old: ArraySymbol => ar(old)
- case old @ ArrayConstant(e, indexWidth) =>
- val n = map(e, bv, ar); ar(if (n.eq(e)) old else ArrayConstant(n, indexWidth))
- case old @ ArrayStore(a, b, c) =>
- val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
- ar(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC))
- case old @ ArrayIte(a, b, c) =>
- val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
- ar(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC))
- }
-
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
index 75bde09c..bb4e0348 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
@@ -19,14 +19,9 @@ private object SMTLibSerializer {
case a: ArrayExpr => serialize(a)
}
- def serializeType(e: SMTExpr): String = e match {
- case b: BVExpr => serializeBitVectorType(b.width)
- case a: ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth)
- }
-
- def declareFunction(foo: BVFunctionSymbol): SMTCommand = {
- val args = foo.argWidths.map(serializeBitVectorType)
- DeclareFunction(BVSymbol(foo.name, foo.width), args)
+ def serialize(t: SMTType): String = t match {
+ case BVType(width) => serializeBitVectorType(width)
+ case ArrayType(indexWidth, dataWidth) => serializeArrayType(indexWidth, dataWidth)
}
private def serialize(e: BVExpr): String = e match {
@@ -71,37 +66,57 @@ private object SMTLibSerializer {
case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})"
case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})"
// boolean operations get a special treatment for 1-bit vectors aka bools
- case BVOp(Op.And, a, b) if a.width == 1 => s"(and ${serialize(a)} ${serialize(b)})"
- case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})"
+ case b: BVAnd => serializeVariadic(if (b.width == 1) "and" else "bvand", b.terms)
+ case b: BVOr => serializeVariadic(if (b.width == 1) "or" else "bvor", b.terms)
case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})"
case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})")
case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})"
case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})"
case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})"
case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
- case BVFunctionCall(name, args, _) => args.map(serialize).mkString(s"($name ", " ", ")")
- case BVRawExpr(serialized, _) => serialized
+ case BVFunctionCall(name, args, _) => args.map(serializeArg).mkString(s"($name ", " ", ")")
+ case BVForall(variable, e) => s"(forall ((${variable.name} ${serialize(variable.tpe)})) ${serialize(e)})"
+ }
+
+ private def serializeVariadic(op: String, terms: List[BVExpr]): String = terms match {
+ case Seq() | Seq(_) => throw new RuntimeException(s"expected at least two elements in variadic op $op")
+ case Seq(a, b) => s"($op ${serialize(a)} ${serialize(b)})"
+ case head :: tail => s"($op ${serialize(head)} ${serializeVariadic(op, tail)})"
}
def serialize(e: ArrayExpr): String = e match {
- case ArraySymbol(name, _, _) => escapeIdentifier(name)
- case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})"
- case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
- case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})"
- case ArrayRawExpr(serialized, _, _) => serialized
+ case ArraySymbol(name, _, _) => escapeIdentifier(name)
+ case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})"
+ case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
+ case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})"
+ case ArrayFunctionCall(name, args, _, _) => args.map(serializeArg).mkString(s"($name ", " ", ")")
}
def serialize(c: SMTCommand): String = c match {
case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n")
case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)"
case DefineFunction(name, args, e) =>
- val aa = args.map(a => s"(${escapeIdentifier(a._1)} ${a._2})").mkString(" ")
- s"(define-fun ${escapeIdentifier(name)} ($aa) ${serializeType(e)} ${serialize(e)})"
+ val aa = args.map(a => s"(${serializeArg(a)} ${serializeArgTpe(a)})").mkString(" ")
+ s"(define-fun ${escapeIdentifier(name)} ($aa) ${serialize(e.tpe)} ${serialize(e)})"
case DeclareFunction(sym, tpes) =>
- val aa = tpes.mkString(" ")
- s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serializeType(sym)})"
+ val aa = tpes.map(serializeArgTpe).mkString(" ")
+ s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serialize(sym.tpe)})"
+ case SetLogic(logic) => s"(set-logic $logic)"
+ case DeclareUninterpretedSymbol(name, tpe) =>
+ s"(declare-fun ${escapeIdentifier(name)} () ${escapeIdentifier(tpe)})"
}
+ private def serializeArgTpe(a: SMTFunctionArg): String =
+ a match {
+ case u: UTSymbol => escapeIdentifier(u.tpe)
+ case s: SMTExpr => serialize(s.tpe)
+ }
+ private def serializeArg(a: SMTFunctionArg): String =
+ a match {
+ case u: UTSymbol => escapeIdentifier(u.name)
+ case s: SMTExpr => serialize(s)
+ }
+
private def serializeArrayType(indexWidth: Int, dataWidth: Int): String =
s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})"
private def serializeBitVectorType(width: Int): String =
@@ -109,8 +124,6 @@ private object SMTLibSerializer {
else { assert(width > 1); s"(_ BitVec $width)" }
private def serialize(op: Op.Value): String = op match {
- case Op.And => "bvand"
- case Op.Or => "bvor"
case Op.Xor => "bvxor"
case Op.ArithmeticShiftRight => "bvashr"
case Op.ShiftRight => "bvlshr"
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
index d35fe139..472363cc 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
@@ -17,37 +17,41 @@ private object SMTTransitionSystemEncoder {
val cmds = mutable.ArrayBuffer[SMTCommand]()
val name = sys.name
- // emit header as comments
- cmds ++= sys.header.map(Comment)
+ // declare UFs if necessary
+ cmds ++= TransitionSystem.findUninterpretedFunctions(sys)
- // declare uninterpreted functions used in model
- cmds ++= sys.ufs.map(SMTLibSerializer.declareFunction)
+ // emit header as comments
+ if (sys.header.nonEmpty) {
+ cmds ++= sys.header.split('\n').map(Comment)
+ }
// declare state type
val stateType = id(name + "_s")
cmds += DeclareUninterpretedSort(stateType)
+ // state symbol
+ val State = UTSymbol("state", stateType)
+ val StateNext = UTSymbol("state_n", stateType)
+
// inputs and states are modelled as constants
def declare(sym: SMTSymbol, kind: String): Unit = {
cmds ++= toDescription(sym, kind, sys.comments.get)
val s = SMTSymbol.fromExpr(sym.name + SignalSuffix, sym)
- cmds += DeclareFunction(s, List(stateType))
+ cmds += DeclareFunction(s, List(State))
}
sys.inputs.foreach(i => declare(i, "input"))
sys.states.foreach(s => declare(s.sym, "register"))
// signals are just functions of other signals, inputs and state
def define(sym: SMTSymbol, e: SMTExpr, suffix: String = SignalSuffix): Unit = {
- cmds += DefineFunction(sym.name + suffix, List((State, stateType)), replaceSymbols(e))
+ val withReplacedSymbols = replaceSymbols(SignalSuffix, State)(e)
+ cmds += DefineFunction(sym.name + suffix, List(State), withReplacedSymbols)
}
sys.signals.foreach { signal =>
- val kind = if (sys.outputs.contains(signal.name)) { "output" }
- else if (sys.assumes.contains(signal.name)) { "assume" }
- else if (sys.asserts.contains(signal.name)) { "assert" }
- else { "wire" }
- val sym = SMTSymbol.fromExpr(signal.name, signal.e)
- cmds ++= toDescription(sym, kind, sys.comments.get)
- define(sym, signal.e)
+ val sym = signal.sym
+ cmds ++= toDescription(sym, lblToKind(signal.lbl), sys.comments.get)
+ val e = if (signal.lbl == IsBad) BVNot(signal.e.asInstanceOf[BVExpr]) else signal.e
+ define(sym, e)
}
// define the next and init functions for all states
@@ -60,72 +64,70 @@ private object SMTTransitionSystemEncoder {
}
}
- def defineConjunction(e: Iterable[BVExpr], suffix: String): Unit = {
- define(BVSymbol(name, 1), andReduce(e), suffix)
+ def defineConjunction(e: List[BVExpr], suffix: String): Unit = {
+ define(BVSymbol(name, 1), if (e.isEmpty) True() else BVAnd(e), suffix)
}
// the transition relation asserts that the value of the next state is the next value from the previous state
// e.g., (reg state_n) == (reg_next state)
val transitionRelations = sys.states.map { state =>
- val newState = symbolToFunApp(state.sym, SignalSuffix, StateNext)
- val nextOldState = symbolToFunApp(state.sym, NextSuffix, State)
+ val newState = replaceSymbols(SignalSuffix, StateNext)(state.sym)
+ val nextOldState = replaceSymbols(NextSuffix, State)(state.sym)
SMTEqual(newState, nextOldState)
}
// the transition relation is over two states
- val transitionExpr = replaceSymbols(andReduce(transitionRelations))
- cmds += DefineFunction(name + "_t", List((State, stateType), (StateNext, stateType)), transitionExpr)
+ val transitionExpr = if (transitionRelations.isEmpty) { True() }
+ else {
+ replaceSymbols(SignalSuffix, State)(BVAnd(transitionRelations))
+ }
+ cmds += DefineFunction(name + "_t", List(State, StateNext), transitionExpr)
// The init relation just asserts that all init function hold
val initRelations = sys.states.filter(_.init.isDefined).map { state =>
- val stateSignal = symbolToFunApp(state.sym, SignalSuffix, State)
- val initSignal = symbolToFunApp(state.sym, InitSuffix, State)
+ val stateSignal = replaceSymbols(SignalSuffix, State)(state.sym)
+ val initSignal = replaceSymbols(InitSuffix, State)(state.sym)
SMTEqual(stateSignal, initSignal)
}
defineConjunction(initRelations, "_i")
// assertions and assumptions
- val assertions = sys.signals.filter(a => sys.asserts.contains(a.name)).map(a => replaceSymbols(a.toSymbol))
- defineConjunction(assertions, "_a")
- val assumptions = sys.signals.filter(a => sys.assumes.contains(a.name)).map(a => replaceSymbols(a.toSymbol))
- defineConjunction(assumptions, "_u")
+ val assertions = sys.signals.filter(_.lbl == IsBad).map(a => replaceSymbols(SignalSuffix, State)(a.sym))
+ defineConjunction(assertions.map(_.asInstanceOf[BVExpr]), AssertionSuffix)
+ val assumptions = sys.signals.filter(_.lbl == IsConstraint).map(a => replaceSymbols(SignalSuffix, State)(a.sym))
+ defineConjunction(assumptions.map(_.asInstanceOf[BVExpr]), AssumptionSuffix)
cmds
}
private def id(s: String): String = SMTLibSerializer.escapeIdentifier(s)
- private val State = "state"
- private val StateNext = "state_n"
private val SignalSuffix = "_f"
private val NextSuffix = "_next"
private val InitSuffix = "_init"
+ val AssertionSuffix = "_a"
+ val AssumptionSuffix = "_u"
+ private def lblToKind(lbl: SignalLabel): String = lbl match {
+ case IsNode | IsInit | IsNext => "wire"
+ case IsOutput => "output"
+ // for the SMT encoding we turn bad state signals back into assertions
+ case IsBad => "assert"
+ case IsConstraint => "assume"
+ case IsFair => "fair"
+ }
private def toDescription(sym: SMTSymbol, kind: String, comments: String => Option[String]): List[Comment] = {
List(sym match {
- case BVSymbol(name, width) =>
- Comment(s"firrtl-smt2-$kind $name $width")
+ case BVSymbol(name, width) => Comment(s"firrtl-smt2-$kind $name $width")
case ArraySymbol(name, indexWidth, dataWidth) =>
Comment(s"firrtl-smt2-$kind $name $indexWidth $dataWidth")
}) ++ comments(sym.name).map(Comment)
}
-
- private def andReduce(e: Iterable[BVExpr]): BVExpr =
- if (e.isEmpty) BVLiteral(1, 1) else e.reduce((a, b) => BVOp(Op.And, a, b))
-
// All signals are modelled with functions that need to be called with the state as argument,
// this replaces all Symbols with function applications to the state.
- private def replaceSymbols(e: SMTExpr): SMTExpr = {
- SMTExprVisitor.map(symbolToFunApp(_, SignalSuffix, State))(e)
- }
- private def replaceSymbols(e: BVExpr): BVExpr = replaceSymbols(e.asInstanceOf[SMTExpr]).asInstanceOf[BVExpr]
- private def symbolToFunApp(sym: SMTExpr, suffix: String, arg: String): SMTExpr = sym match {
- case BVSymbol(name, width) => BVRawExpr(s"(${id(name + suffix)} $arg)", width)
- case ArraySymbol(name, indexWidth, dataWidth) => ArrayRawExpr(s"(${id(name + suffix)} $arg)", indexWidth, dataWidth)
- case other => other
- }
+ private def replaceSymbols(suffix: String, arg: SMTFunctionArg, vars: Set[String] = Set())(e: SMTExpr): SMTExpr =
+ e match {
+ case BVSymbol(name, width) if !vars(name) => BVFunctionCall(id(name + suffix), List(arg), width)
+ case ArraySymbol(name, indexWidth, dataWidth) if !vars(name) =>
+ ArrayFunctionCall(id(name + suffix), List(arg), indexWidth, dataWidth)
+ case fa @ BVForall(variable, _) => SMTExprMap.mapExpr(fa, replaceSymbols(suffix, arg, vars + variable.name))
+ case other => SMTExprMap.mapExpr(other, replaceSymbols(suffix, arg, vars))
+ }
}
-
-/** minimal set of pseudo SMT commands needed for our encoding */
-private sealed trait SMTCommand
-private case class Comment(msg: String) extends SMTCommand
-private case class DefineFunction(name: String, args: Seq[(String, String)], e: SMTExpr) extends SMTCommand
-private case class DeclareFunction(sym: SMTSymbol, tpes: Seq[String]) extends SMTCommand
-private case class DeclareUninterpretedSort(name: String) extends SMTCommand
diff --git a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
index eac9f00a..5db39ac9 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
@@ -3,13 +3,14 @@
package firrtl.backends.experimental.smt
-import firrtl.{ir, CircuitState, DependencyAPIMigration, Namespace, PrimOps, RenameMap, Transform, Utils}
-import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation, ReferenceTarget, SingleTargetAnnotation}
+import firrtl._
+import firrtl.annotations._
import firrtl.ir.EmptyStmt
import firrtl.options.Dependency
import firrtl.passes.PassException
import firrtl.stage.Forms
import firrtl.stage.TransformManager.TransformDependency
+import firrtl.transforms.PropagatePresetAnnotations
import scala.collection.mutable
@@ -30,7 +31,10 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// this pass needs to run *before* converting to a transition system
override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq(Dependency(FirrtlToTransitionSystem))
// since this pass only runs on the main module, inlining needs to happen before
- override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances])
+ override def optionalPrerequisites: Seq[TransformDependency] = Seq(
+ Dependency[firrtl.passes.InlineInstances],
+ Dependency[PropagatePresetAnnotations]
+ )
override protected def execute(state: CircuitState): CircuitState = {
if (state.circuit.modules.size > 1) {
@@ -66,10 +70,10 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// replace all other clocks with enable signals, unless they are the global clock
val clocks = portsWithGlobalClock.filter(p => p.tpe == ir.ClockType && p.name != globalClock).map(_.name)
val clockToEnable = clocks.map { c =>
- c -> ir.Reference(namespace.newName(c + "_en"), Bool, firrtl.PortKind, firrtl.SourceFlow)
+ c -> ir.Reference(namespace.newName(c + "_en"), Utils.BoolType, firrtl.PortKind, firrtl.SourceFlow)
}.toMap
val portsWithEnableSignals = portsWithGlobalClock.map { p =>
- if (clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Bool) }
+ if (clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Utils.BoolType) }
else { p }
}
// replace async reset with synchronous reset (since everything will we synchronous with the global clock)
@@ -78,9 +82,12 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
val isPresetReset = state.annotations.collect { case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet
val resetsToChange = asyncResets.filterNot(isPresetReset).toSet
val portsWithSyncReset = portsWithEnableSignals.map { p =>
- if (resetsToChange.contains(p.name)) { p.copy(tpe = Bool) }
+ if (resetsToChange.contains(p.name)) { p.copy(tpe = Utils.BoolType) }
else { p }
}
+ val presetRegs = state.annotations.collect {
+ case PresetRegAnnotation(target) if target.module == mainName => target.ref
+ }.toSet
// discover clock and reset connections
val scan = scanClocks(main, clockToEnable, resetsToChange)
@@ -94,7 +101,7 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
}
// make changes
- implicit val ctx: Context = new Context(globalClock, scan)
+ implicit val ctx: Context = new Context(globalClock, scan, presetRegs)
val newMain = main.copy(ports = portsWithSyncReset).mapStmt(onStatement)
val nonMainModules = state.circuit.modules.filterNot(_.name == state.circuit.main)
@@ -119,15 +126,19 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// for write ports we guard the write enable with the clock enable signal, similar to registers
if (isWritePort) {
val clockEn = ctx.memPortToClockEnable(mem + "." + port)
- val guardedEnable = and(clockEn, c.expr)
+ val guardedEnable = Utils.and(clockEn, c.expr)
c.copy(expr = guardedEnable)
} else { c }
} else { c }
// register field connects
case c @ ir.Connect(_, r: ir.Reference, next) if ctx.registerToEnable.contains(r.name) =>
val clockEnable = ctx.registerToEnable(r.name)
- val guardedNext = mux(clockEnable, next, r)
- c.copy(expr = guardedNext)
+ val guardedNext = Utils.mux(clockEnable, next, r)
+ val withReset = ctx.registerToAsyncReset.get(r.name) match {
+ case None => guardedNext
+ case Some((asyncReset, init)) => Utils.mux(asyncReset, init, guardedNext)
+ }
+ c.copy(expr = withReset)
// remove other clock wires and nodes
case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType && ctx.isRemovedClock(loc.serialize) => EmptyStmt
case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
@@ -135,21 +146,16 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// change async reset to synchronous reset
case ir.Connect(info, loc: ir.Reference, expr: ir.Reference)
if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) =>
- ir.Connect(info, loc.copy(tpe = Bool), expr.copy(tpe = Bool))
+ ir.Connect(info, loc.copy(tpe = Utils.BoolType), expr.copy(tpe = Utils.BoolType))
case d @ ir.DefNode(_, name, value: ir.Reference)
if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) =>
- d.copy(value = value.copy(tpe = Bool))
- case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => d.copy(tpe = Bool)
+ d.copy(value = value.copy(tpe = Utils.BoolType))
+ case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) =>
+ d.copy(tpe = Utils.BoolType)
// change memory clock and synchronize reset
- case ir.DefRegister(info, name, tpe, clock, reset, init) if ctx.registerToEnable.contains(name) =>
- val clockEnable = ctx.registerToEnable(name)
- val newReset = reset match {
- case r @ ir.Reference(name, _, _, _) if ctx.isResetToChange(name) => r.copy(tpe = Bool)
- case other => other
- }
- val synchronizedReset = if (reset.tpe == ir.AsyncResetType) { newReset }
- else { and(newReset, clockEnable) }
- ir.DefRegister(info, name, tpe, ctx.globalClock, synchronizedReset, init)
+ case ir.DefRegister(info, name, tpe, _, _, init) if ctx.registerToEnable.contains(name) =>
+ val newInit = if (ctx.isPresetReg(name)) init else ir.Reference(name, tpe, RegKind, SourceFlow)
+ ir.DefRegister(info, name, tpe, ctx.globalClock, Utils.False(), newInit)
case other => other.mapStmt(onStatement)
}
}
@@ -189,10 +195,14 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
case ir.DefNode(_, name, value) if value.tpe == ir.AsyncResetType && ctx.resetsToChange(value.serialize) =>
ctx.resetsToChange.add(name)
// modify clocked elements
- case ir.DefRegister(_, name, _, clock, _, _) =>
+ case ir.DefRegister(_, name, _, clock, reset, init) =>
ctx.clockToEnable.get(clock.serialize).foreach { clockEnable =>
ctx.registerToEnable.append(name -> clockEnable)
}
+ reset match {
+ case Utils.False() =>
+ case other => ctx.registerToAsyncReset.append(name -> (other, init))
+ }
case m: ir.DefMemory =>
assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!")
assert(m.readLatency == 0 || m.readLatency == 1, "Only read-latency 1 and read latency 0 are supported!")
@@ -229,18 +239,22 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
val resetsToChange = mutable.HashSet[String]() ++ initialResetsToChange
// registers whose next function needs to be guarded with a clock enable
val registerToEnable = mutable.ArrayBuffer[(String, ir.Reference)]()
+ // registers with asynchronous reset
+ val registerToAsyncReset = mutable.ArrayBuffer[(String, (ir.Expression, ir.Expression))]()
// memory enables which need to be guarded with clock enables
val memPortToClockEnable = mutable.ArrayBuffer[(String, ir.Reference)]()
// keep track of memory names
val mems = mutable.HashMap[String, ir.DefMemory]()
}
- private class Context(globalClockName: String, scanResults: ScanCtx) {
+ private class Context(globalClockName: String, scanResults: ScanCtx, val isPresetReg: String => Boolean) {
val globalClock: ir.Reference = ir.Reference(globalClockName, ir.ClockType, firrtl.PortKind, firrtl.SourceFlow)
// keeps track of which clock signals will be replaced by which clock enable signal
val isRemovedClock: String => Boolean = scanResults.clockToEnable.contains
// registers whose next function needs to be guarded with a clock enable
val registerToEnable: Map[String, ir.Reference] = scanResults.registerToEnable.toMap
+ // registers with asynchronous reset
+ val registerToAsyncReset: Map[String, (ir.Expression, ir.Expression)] = scanResults.registerToAsyncReset.toMap
// memory enables which need to be guarded with clock enables
val memPortToClockEnable: Map[String, ir.Reference] = scanResults.memPortToClockEnable.toMap
// keep track of memory names
@@ -252,13 +266,6 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
private var mainName: String = "" // for debugging
private def unsupportedError(msg: String): Nothing =
throw new UnsupportedFeatureException(s"StutteringClockTransform: [$mainName] $msg")
-
- private def mux(cond: ir.Expression, a: ir.Expression, b: ir.Expression): ir.Expression = {
- ir.Mux(cond, a, b, Utils.mux_type_and_widths(a, b))
- }
- private def and(a: ir.Expression, b: ir.Expression): ir.Expression =
- ir.DoPrim(PrimOps.And, List(a, b), List(), Bool)
- private val Bool = ir.UIntType(ir.IntWidth(1))
}
private class UnsupportedFeatureException(s: String) extends PassException(s)
diff --git a/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala
new file mode 100644
index 00000000..66a1b385
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala
@@ -0,0 +1,120 @@
+// SPDX-License-Identifier: Apache-2.0
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+import firrtl.graph.MutableDiGraph
+import scala.collection.mutable
+
+private case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr]) {
+ def name: String = sym.name
+}
+private case class Signal(name: String, e: SMTExpr, lbl: SignalLabel = IsNode) {
+ def toSymbol: SMTSymbol = SMTSymbol.fromExpr(name, e)
+ def sym: SMTSymbol = toSymbol
+}
+private case class TransitionSystem(
+ name: String,
+ inputs: List[BVSymbol],
+ states: List[State],
+ signals: List[Signal],
+ comments: Map[String, String] = Map(),
+ header: String = "") {
+ def serialize: String = TransitionSystem.serialize(this)
+}
+
+private sealed trait SignalLabel
+private case object IsNode extends SignalLabel
+private case object IsOutput extends SignalLabel
+private case object IsConstraint extends SignalLabel
+private case object IsBad extends SignalLabel
+private case object IsFair extends SignalLabel
+private case object IsNext extends SignalLabel
+private case object IsInit extends SignalLabel
+
+private object SignalLabel {
+ private val labels = Seq(IsNode, IsOutput, IsConstraint, IsBad, IsFair, IsNext, IsInit)
+ val labelStrings = Seq("node", "output", "constraint", "bad", "fair", "next", "init")
+ val labelToString: SignalLabel => String = labels.zip(labelStrings).toMap
+ val stringToLabel: String => SignalLabel = labelStrings.zip(labels).toMap
+}
+
+private object TransitionSystem {
+ def serialize(sys: TransitionSystem): String = {
+ (Iterator(sys.name) ++
+ sys.inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++
+ sys.signals.map(s => s"${SignalLabel.labelToString(s.lbl)} ${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++
+ sys.states.map(serialize)).mkString("\n")
+ }
+
+ def serialize(s: State): String = {
+ s"state ${s.sym.name} : ${SMTExpr.serializeType(s.sym)}" +
+ s.init.map("\n [init] " + _).getOrElse("") +
+ s.next.map("\n [next] " + _).getOrElse("")
+ }
+
+ def systemExpressions(sys: TransitionSystem): List[SMTExpr] =
+ sys.signals.map(_.e) ++ sys.states.flatMap(s => s.init ++ s.next)
+
+ def findUninterpretedFunctions(sys: TransitionSystem): List[DeclareFunction] = {
+ val calls = systemExpressions(sys).flatMap(findUFCalls)
+ // find unique functions
+ calls.groupBy(_.sym.name).map(_._2.head).toList
+ }
+
+ private def findUFCalls(e: SMTExpr): List[DeclareFunction] = {
+ val f = e match {
+ case BVFunctionCall(name, args, width) =>
+ Some(DeclareFunction(BVSymbol(name, width), args))
+ case ArrayFunctionCall(name, args, indexWidth, dataWidth) =>
+ Some(DeclareFunction(ArraySymbol(name, indexWidth, dataWidth), args))
+ case _ => None
+ }
+ f.toList ++ e.children.flatMap(findUFCalls)
+ }
+}
+
+private object TopologicalSort {
+
+ /** Ensures that all signals in the resulting system are topologically sorted.
+ * This is necessary because [[firrtl.transforms.RemoveWires]] does
+ * not sort assignments to outputs, submodule inputs nor memory ports.
+ */
+ def run(sys: TransitionSystem): TransitionSystem = {
+ val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name)
+ val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates)
+ // TODO: maybe sort init expressions of states (this should not be needed most of the time)
+ signalOrder match {
+ case None => sys
+ case Some(order) =>
+ val signalMap = sys.signals.map(s => s.name -> s).toMap
+ // we flatMap over `get` in order to ignore inputs/states in the order
+ sys.copy(signals = order.flatMap(signalMap.get).toList)
+ }
+ }
+
+ private def sort(signals: Iterable[(String, SMTExpr)], globalSignals: Iterable[String]): Option[Iterable[String]] = {
+ val known = new mutable.HashSet[String]() ++ globalSignals
+ var needsReordering = false
+ val digraph = new MutableDiGraph[String]
+ signals.foreach {
+ case (name, expr) =>
+ digraph.addVertex(name)
+ val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr)
+ uniqueDependencies.foreach { d =>
+ if (!known.contains(d)) { needsReordering = true }
+ digraph.addPairWithEdge(name, d)
+ }
+ known.add(name)
+ }
+ if (needsReordering) {
+ Some(digraph.linearize.reverse)
+ } else { None }
+ }
+
+ private def findDependencies(expr: SMTExpr): List[String] = expr match {
+ case BVSymbol(name, _) => List(name)
+ case ArraySymbol(name, _, _) => List(name)
+ case other => other.children.flatMap(findDependencies)
+ }
+}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala
index fdd51a37..21e8289e 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala
@@ -26,9 +26,8 @@ class Btor2Spec extends AnyFlatSpec {
|5 output 4 ; b
|6 sort bitvec 1
|7 uext 3 2 8
- |8 eq 6 7 4
- |9 not 6 8
- |10 bad 9 ; a_eq_b
+ |8 neq 6 7 4
+ |9 bad 8 ; a_eq_b
|""".stripMargin
assert(SMTBackendHelpers.toBotr2Str(src) == expected)
@@ -46,17 +45,16 @@ class Btor2Spec extends AnyFlatSpec {
|""".stripMargin
val expected =
- """; @ module 0:0
+ """; @[module 0:0]
|1 sort bitvec 8
- |2 input 1 a ; @ a 0:0
+ |2 input 1 a ; @[a 0:0]
|3 sort bitvec 16
|4 uext 3 2 8
- |5 output 4 ; b @ b 0:0, b_a 0:0
+ |5 output 4 ; b @[b 0:0], @[b_a 0:0]
|6 sort bitvec 1
|7 uext 3 2 8
- |8 eq 6 7 4
- |9 not 6 8
- |10 bad 9 ; assert_0 @ assert 0:0
+ |8 neq 6 7 4
+ |9 bad 8 ; assert_0 @[assert 0:0]
|""".stripMargin
assert(SMTBackendHelpers.toBotr2Str(src) == expected)
diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala
index c100da56..1fd0e99b 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala
@@ -94,8 +94,7 @@ class FirrtlModuleToTransitionSystemSpec extends AnyFlatSpec {
assert(sym.indexWidth == 5)
assert(sym.dataWidth == 8)
assert(m.init.isEmpty)
- //assert(m.next.get.toString.contains("m[m.w.addr := m.w.data]"))
- assert(m.next.get.toString == "m[m.w.addr := m.w.data]")
+ assert(m.next.get.toString.contains("m[m.w.addr := m.w.data]"))
}
it should "support scalar initialization of a memory to 0" in {
@@ -170,7 +169,8 @@ class FirrtlModuleToTransitionSystemSpec extends AnyFlatSpec {
|""".stripMargin
val sys = SMTBackendHelpers.toSys(src)
assert(sys.inputs.isEmpty, "Clock inputs should be ignored.")
- assert(sys.outputs.isEmpty, "Clock outputs should be ignored.")
+ val outputs = sys.signals.filter(_.lbl == IsOutput)
+ assert(outputs.isEmpty, "Clock outputs should be ignored.")
assert(sys.signals.isEmpty, "Connects of clock type should be ignored.")
}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala
index 71d1d38c..8f4486ab 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala
@@ -38,7 +38,7 @@ private object SMTBackendHelpers {
val circuit = if (modelUndef) compileUndef(src) else compile(src)
val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module]
// println(module.serialize)
- new ModuleToTransitionSystem().run(module, presetRegs = presetRegs, memInit = memInit)
+ new ModuleToTransitionSystem(presetRegs = presetRegs, memInit = memInit, uninterpreted = Map()).run(module)
}
def toBotr2(src: String, mod: String = "m"): Iterable[String] =
diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala
index 338d760c..4d96631e 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala
@@ -47,16 +47,16 @@ class SMTLibSpec extends AnyFlatSpec {
|""".stripMargin
val expected =
- """; @ module 0:0
+ """; @[module 0:0]
|(declare-sort m_s 0)
|; firrtl-smt2-input a 8
- |; @ a 0:0
+ |; @[a 0:0]
|(declare-fun a_f (m_s) (_ BitVec 8))
|; firrtl-smt2-output b 16
- |; @ b 0:0, b_a 0:0
+ |; @[b 0:0], @[b_a 0:0]
|(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state)))
|; firrtl-smt2-assert assert_0 1
- |; @ assert 0:0
+ |; @[assert 0:0]
|(define-fun assert_0_f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state)))
|(define-fun m_t ((state m_s) (state_n m_s)) Bool true)
|(define-fun m_i ((state m_s)) Bool true)