aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala28
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala171
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala11
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala5
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala6
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala3
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala86
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/UninterpretedModulesSpec.scala49
8 files changed, 304 insertions, 55 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
index 4cd5c9f7..f96fd4e8 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
@@ -3,12 +3,28 @@
package firrtl.backends.experimental.smt
+import firrtl.backends.experimental.smt.Btor2Serializer.functionCallToArrayRead
+
import scala.collection.mutable
private object Btor2Serializer {
def serialize(sys: TransitionSystem, skipOutput: Boolean = false): Iterable[String] = {
new Btor2Serializer().run(sys, skipOutput)
}
+
+ private def functionCallToArrayRead(call: BVFunctionCall): BVExpr = {
+ if (call.args.isEmpty) {
+ BVSymbol(call.name, call.width)
+ } else {
+ val index = concat(call.args)
+ val a = ArraySymbol(call.name, indexWidth = index.width, dataWidth = call.width)
+ ArrayRead(a, index)
+ }
+ }
+ private def concat(e: Iterable[BVExpr]): BVExpr = {
+ require(e.nonEmpty)
+ e.reduce((a, b) => BVConcat(a, b))
+ }
}
private class Btor2Serializer private () {
@@ -65,6 +81,7 @@ private class Btor2Serializer private () {
case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b)
case BVOp(op, a, b) => binary(s(op), expr.width, a, b)
case BVConcat(a, b) => binary("concat", expr.width, a, b)
+ case call: BVFunctionCall => s(functionCallToArrayRead(call))
case ArrayRead(array, index) =>
line(s"read ${t(expr.width)} ${s(array)} ${s(index)}")
case BVIte(cond, tru, fals) =>
@@ -164,6 +181,17 @@ private class Btor2Serializer private () {
declare(ii.name, line(s"input ${t(ii.width)} ${ii.name}"))
}
+ // declare uninterpreted functions a constant arrays
+ sys.ufs.foreach { foo =>
+ val sym = if (foo.argWidths.isEmpty) { BVSymbol(foo.name, foo.width) }
+ else {
+ ArraySymbol(foo.name, foo.argWidths.sum, foo.width)
+ }
+ comment(foo.toString)
+ declare(sym.name, line(s"state ${t(sym)} ${sym.name}"))
+ line(s"next ${t(sym)} ${s(sym)} ${s(sym)}")
+ }
+
// define state init
sys.states.foreach { st =>
// calculate init expression before declaring the state
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
index aed2011a..145b5b0f 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
@@ -40,10 +40,12 @@ private case class TransitionSystem(
assumes: Set[String],
asserts: Set[String],
fair: Set[String],
+ ufs: List[BVFunctionSymbol] = List(),
comments: Map[String, String] = Map(),
header: Array[String] = Array()) {
def serialize: String = {
(Iterator(name) ++
+ ufs.map(u => u.toString) ++
inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++
signals.map(s => s"${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++
states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}")).mkString("\n")
@@ -79,15 +81,25 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration {
.map(a => a.target.ref -> a.initValue)
.toMap
+ // module look up table
+ val modules = circuit.modules.map(m => m.name -> m).toMap
+
+ // collect uninterpreted module annotations
+ val uninterpreted = afterPreset.annotations.collect {
+ case a: UninterpretedModuleAnnotation =>
+ UninterpretedModuleAnnotation.checkModule(modules(a.target.module), a)
+ a.target.module -> a
+ }.toMap
+
// convert the main module
- val main = circuit.modules.find(_.name == circuit.main).get
+ val main = modules(circuit.main)
val sys = main match {
case x: ir.ExtModule =>
throw new ExtModuleException(
"External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog."
)
case m: ir.Module =>
- new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit = memInit)
+ new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit = memInit, uninterpreted = uninterpreted)
}
val sortedSys = TopologicalSort.run(sys)
@@ -122,12 +134,13 @@ private class MissingFeatureException(s: String)
private class ModuleToTransitionSystem extends LazyLogging {
def run(
- m: ir.Module,
- presetRegs: Set[String] = Set(),
- memInit: Map[String, MemoryInitValue] = Map()
+ m: ir.Module,
+ presetRegs: Set[String] = Set(),
+ memInit: Map[String, MemoryInitValue] = Map(),
+ uninterpreted: Map[String, UninterpretedModuleAnnotation] = Map()
): TransitionSystem = {
// first pass over the module to convert expressions; discover state and I/O
- val scan = new ModuleScanner(makeRandom)
+ val scan = new ModuleScanner(makeRandom, uninterpreted)
m.foreachPort(scan.onPort)
m.foreachStmt(scan.onStatement)
@@ -188,6 +201,10 @@ private class ModuleToTransitionSystem extends LazyLogging {
val header = serializeInfo(m.info).map(InfoPrefix + _).toArray
val fair = Set[String]() // as of firrtl 1.4 we do not support fairness constraints
+
+ // collect unique functions
+ val ufs = scan.functionCalls.groupBy(_.name).map(_._2.head).toList
+
TransitionSystem(
m.name,
inputs.toArray,
@@ -197,6 +214,7 @@ private class ModuleToTransitionSystem extends LazyLogging {
constraints,
bad,
fair,
+ ufs,
comments.toMap,
header
)
@@ -456,7 +474,10 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr, namespace: Nam
}
// performas a first pass over the module collecting all connections, wires, registers, input and outputs
-private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLogging {
+private class ModuleScanner(
+ makeRandom: (String, Int) => BVExpr,
+ uninterpreted: Map[String, UninterpretedModuleAnnotation])
+ extends LazyLogging {
import FirrtlExpressionSemantics.getWidth
private[firrtl] val inputs = mutable.ArrayBuffer[BVSymbol]()
@@ -473,10 +494,13 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
private[firrtl] val assumes = mutable.ArrayBuffer[String]()
// maps identifiers to their info
private[firrtl] val infos = mutable.ArrayBuffer[(String, ir.Info)]()
- // keeps track of unused memory (data) outputs so that we can see where they are first used
- private val unusedMemOutputs = mutable.LinkedHashMap[String, Int]()
+ // Keeps track of (so far) unused memory (data) and uninterpreted module outputs.
+ // This is used in order to delay declaring them for as long as possible.
+ private val unusedOutputs = mutable.LinkedHashMap[String, BVExpr]()
// ensure unique names for assert/assume signals
private[firrtl] val namespace = Namespace()
+ // keep track of all uninterpreted functions called
+ private[firrtl] val functionCalls = mutable.ArrayBuffer[BVFunctionSymbol]()
private[firrtl] def onPort(p: ir.Port): Unit = {
if (isAsyncReset(p.tpe)) {
@@ -508,7 +532,7 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
case ir.DefNode(info, name, expr) =>
namespace.newName(name)
if (!isClock(expr.tpe)) {
- insertDummyAssignsForMemoryOutputs(expr)
+ insertDummyAssignsForUnusedOutputs(expr)
infos.append(name -> info)
val e = onExpression(expr, name)
nodes.append(name)
@@ -516,8 +540,8 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
}
case ir.DefRegister(info, name, tpe, _, reset, init) =>
namespace.newName(name)
- insertDummyAssignsForMemoryOutputs(reset)
- insertDummyAssignsForMemoryOutputs(init)
+ insertDummyAssignsForUnusedOutputs(reset)
+ insertDummyAssignsForUnusedOutputs(init)
infos.append(name -> info)
val width = getWidth(tpe)
val resetExpr = onExpression(reset, 1, name + "_reset")
@@ -529,13 +553,13 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
val outputs = getMemOutputs(m)
(getMemInputs(m) ++ outputs).foreach(memSignals.append(_))
val dataWidth = getWidth(m.dataType)
- outputs.foreach(name => unusedMemOutputs(name) = dataWidth)
+ outputs.foreach(name => unusedOutputs(name) = BVSymbol(name, dataWidth))
memories.append(m)
case ir.Connect(info, loc, expr) =>
if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
if (!isClock(loc.tpe)) { // we ignore clock connections
val name = loc.serialize
- insertDummyAssignsForMemoryOutputs(expr)
+ insertDummyAssignsForUnusedOutputs(expr)
infos.append(name -> info)
connects.append((name, onExpression(expr, getWidth(loc.tpe), name)))
}
@@ -544,40 +568,13 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
val name = loc.serialize
infos.append(name -> info)
connects.append((name, makeRandom(name + "_INVALID", getWidth(loc.tpe))))
- case ir.DefInstance(info, name, module, tpe) =>
- namespace.newName(name)
- if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}")
- // we treat all instances as blackboxes
- logger.warn(
- s"WARN: treating instance $name of $module as blackbox. " +
- "Please flatten your hierarchy if you want to include submodules in the formal model."
- )
- val ports = tpe.asInstanceOf[ir.BundleType].fields
- // skip async reset ports
- ports.filterNot(p => isAsyncReset(p.tpe)).foreach { p =>
- if (!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p")
- val isOutput = p.flip == ir.Default
- val pName = name + "." + p.name
- infos.append(pName -> info)
- // outputs of the submodule become inputs to our module
- if (isOutput) {
- if (isClock(p.tpe)) {
- clocks.add(pName)
- } else {
- inputs.append(BVSymbol(pName, getWidth(p.tpe)))
- }
- } else {
- if (!isClock(p.tpe)) { // we ignore clock outputs
- outputs.append(pName)
- }
- }
- }
+ case ir.DefInstance(info, name, module, tpe) => onInstance(info, name, module, tpe)
case s @ ir.Verification(op, info, _, pred, en, msg) =>
if (op == ir.Formal.Cover) {
logger.warn(s"WARN: Cover statement was ignored: ${s.serialize}")
} else {
- insertDummyAssignsForMemoryOutputs(pred)
- insertDummyAssignsForMemoryOutputs(en)
+ insertDummyAssignsForUnusedOutputs(pred)
+ insertDummyAssignsForUnusedOutputs(en)
val name = namespace.newName(msgToName(op.toString, msg.string))
val predicate = onExpression(pred, name + "_predicate")
val enabled = onExpression(en, name + "_enabled")
@@ -604,6 +601,70 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
case other => other.foreachStmt(onStatement)
}
+ private def onInstance(info: ir.Info, name: String, module: String, tpe: ir.Type): Unit = {
+ namespace.newName(name)
+ if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}")
+ if (uninterpreted.contains(module)) {
+ onUninterpretedInstance(info: ir.Info, name: String, module: String, tpe: ir.Type)
+ } else {
+ // We treat all instances that aren't annotated as uninterpreted as blackboxes
+ // this means that their outputs could be any value, no matter what their inputs are.
+ logger.warn(
+ s"WARN: treating instance $name of $module as blackbox. " +
+ "Please flatten your hierarchy if you want to include submodules in the formal model."
+ )
+ val ports = tpe.asInstanceOf[ir.BundleType].fields
+ // skip async reset ports
+ ports.filterNot(p => isAsyncReset(p.tpe)).foreach { p =>
+ if (!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p")
+ val isOutput = p.flip == ir.Default
+ val pName = name + "." + p.name
+ infos.append(pName -> info)
+ // outputs of the submodule become inputs to our module
+ if (isOutput) {
+ if (isClock(p.tpe)) {
+ clocks.add(pName)
+ } else {
+ inputs.append(BVSymbol(pName, getWidth(p.tpe)))
+ }
+ } else {
+ if (!isClock(p.tpe)) { // we ignore clock outputs
+ outputs.append(pName)
+ }
+ }
+ }
+ }
+ }
+
+ private def onUninterpretedInstance(info: ir.Info, instanceName: String, module: String, tpe: ir.Type): Unit = {
+ val anno = uninterpreted(module)
+
+ // sanity checks for ports were done already using the UninterpretedModule.checkModule function
+ val ports = tpe.asInstanceOf[ir.BundleType].fields
+
+ val outputs = ports.filter(_.flip == ir.Default).map(p => BVSymbol(p.name, getWidth(p.tpe)))
+ val inputs = ports.filterNot(_.flip == ir.Default).map(p => BVSymbol(p.name, getWidth(p.tpe)))
+
+ assert(anno.stateBits == 0, "TODO: implement support for uninterpreted stateful modules!")
+
+ // for state-less (i.e. combinatorial) circuits, the outputs only depend on the inputs
+ val args = inputs.map(i => BVSymbol(instanceName + "." + i.name, i.width)).toList
+ outputs.foreach { out =>
+ val functionName = anno.prefix + "." + out.name
+ val call = BVFunctionCall(functionName, args, out.width)
+ val wireName = instanceName + "." + out.name
+ // remember which functions were called
+ functionCalls.append(call.toSymbol)
+ // insert the output definition right before its first use in an attempt to get SSA
+ unusedOutputs(wireName) = call
+ // treat these outputs as wires
+ wires.append(wireName)
+ }
+
+ // we also treat the arguments as wires
+ wires ++= args.map(_.name)
+ }
+
private val readInputFields = List("en", "addr")
private val writeInputFields = List("en", "mask", "addr", "data")
private def getMemInputs(m: ir.DefMemory): Iterable[String] = {
@@ -617,27 +678,27 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
val p = m.name + "."
m.readers.map(r => p + r + ".data")
}
- // inserts a dummy assign right before a memory output is used for the first time
+ // inserts a dummy assign right before a memory/uninterpreted module output is used for the first time
// example:
// m.r.data <= m.r.data ; this is the dummy assign
// test <= m.r.data ; this is the first use of m.r.data
- private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if (unusedMemOutputs.nonEmpty) {
- implicit val uses = mutable.ArrayBuffer[String]()
- findUnusedMemoryOutputUse(next)
+ private def insertDummyAssignsForUnusedOutputs(next: ir.Expression): Unit = if (unusedOutputs.nonEmpty) {
+ val uses = mutable.ArrayBuffer[String]()
+ findUnusedOutputUse(next)(uses)
if (uses.nonEmpty) {
val useSet = uses.toSet
- unusedMemOutputs.foreach {
- case (name, width) =>
- if (useSet.contains(name)) connects.append(name -> BVSymbol(name, width))
+ unusedOutputs.foreach {
+ case (name, value) =>
+ if (useSet.contains(name)) connects.append(name -> value)
}
- useSet.foreach(name => unusedMemOutputs.remove(name))
+ useSet.foreach(name => unusedOutputs.remove(name))
}
}
- private def findUnusedMemoryOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match {
+ private def findUnusedOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match {
case s: ir.SubField =>
val name = s.serialize
- if (unusedMemOutputs.contains(name)) uses.append(name)
- case other => other.foreachExpr(findUnusedMemoryOutputUse)
+ if (unusedOutputs.contains(name)) uses.append(name)
+ case other => other.foreachExpr(findUnusedOutputUse)
}
private case class Context(baseName: String) extends TranslationContext {
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
index 63692006..0fc507e6 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
@@ -138,6 +138,17 @@ private case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr
override def children: List[BVExpr] = List(cond, tru, fals)
}
+/** apply bv arguments to a function which returns a result of bit vector type */
+private case class BVFunctionCall(name: String, args: List[BVExpr], width: Int) extends BVExpr {
+ override def children = args
+ def toSymbol: BVFunctionSymbol = BVFunctionSymbol(name, args.map(_.width), width)
+ override def toString: String = args.mkString(name + "(", ", ", ")")
+}
+
+private case class BVFunctionSymbol(name: String, argWidths: List[Int], width: Int) {
+ override def toString: String = s"$name : " + (argWidths :+ width).map(w => s"bv<$w>").mkString(" -> ")
+}
+
private sealed trait ArrayExpr extends SMTExpr { val indexWidth: Int; val dataWidth: Int }
private case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol {
assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
index 19f1de84..13ed8bdd 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
@@ -54,6 +54,11 @@ private object SMTExprVisitor {
case old @ BVIte(a, b, c) =>
val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
bv(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC))
+ // n-ary
+ case old @ BVFunctionCall(name, args, width) =>
+ val nArgs = args.map(a => map(a, bv, ar))
+ val noneNew = nArgs.zip(args).forall { case (n, o) => n.eq(o) }
+ bv(if (noneNew) old else BVFunctionCall(name, nArgs, width))
}
private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = e match {
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
index 7bc0a077..75bde09c 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
@@ -24,6 +24,11 @@ private object SMTLibSerializer {
case a: ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth)
}
+ def declareFunction(foo: BVFunctionSymbol): SMTCommand = {
+ val args = foo.argWidths.map(serializeBitVectorType)
+ DeclareFunction(BVSymbol(foo.name, foo.width), args)
+ }
+
private def serialize(e: BVExpr): String = e match {
case BVLiteral(value, width) =>
val mask = (BigInt(1) << width) - 1
@@ -74,6 +79,7 @@ private object SMTLibSerializer {
case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})"
case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})"
case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
+ case BVFunctionCall(name, args, _) => args.map(serialize).mkString(s"($name ", " ", ")")
case BVRawExpr(serialized, _) => serialized
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
index f6d9a26f..d35fe139 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
@@ -20,6 +20,9 @@ private object SMTTransitionSystemEncoder {
// emit header as comments
cmds ++= sys.header.map(Comment)
+ // declare uninterpreted functions used in model
+ cmds ++= sys.ufs.map(SMTLibSerializer.declareFunction)
+
// declare state type
val stateType = id(name + "_s")
cmds += DeclareUninterpretedSort(stateType)
diff --git a/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala b/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala
new file mode 100644
index 00000000..c7442f69
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala
@@ -0,0 +1,86 @@
+// SPDX-License-Identifier: Apache-2.0
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+import firrtl.annotations._
+import firrtl.ir
+import firrtl.passes.PassException
+
+/** ExtModules annotated as UninterpretedModule will be modelled as
+ * UninterpretedFunction (SMTLib) or constant arrays (btor2).
+ * This can be useful when trying to abstract over a function that the
+ * SMT solver or model checker is struggling with.
+ *
+ * E.g., one could declare an abstract 64bit multiplier like this:
+ * ```
+ * extmodule Mul64 :
+ * input a : UInt<64>
+ * input b : UInt<64>
+ * output r : UInt<64>
+ * ```
+ * Now instead of using Chisel to actually implement a multiplication circuit
+ * we can instantiate this Mul64 module twice: Once in our implementation
+ * and once for our correctness property that might specify how the
+ * multiply instruction is supposed to be executed on our CPU.
+ * Now instead of having to prove equivalence of multiplication circuits, the
+ * solver only has to make sure that the connections to the multiplier are correct,
+ * since if `a` and `b` are the same on both instances of `Mul64`, then the `r` output
+ * will also be the same. This is a much easier problem and will result in much faster
+ * solving due to manual abstraction.
+ *
+ * When [[stateBits]] is 0, we model the module as purely combinatorial circuit and
+ * thus expect there to be no clock wire going into the module.
+ * Every output is thus a function of all inputs of the module.
+ *
+ * When [[stateBits]] is an N greater than zero, we will model the module as having an abstract state of width N.
+ * Thus on every clock transition the abstract state is updated and all outputs will take the state
+ * as well as the current inputs as arguments.
+ * TODO: Support for stateful circuits is work in progress.
+ *
+ * All output functions well be prefixed with [[prefix]] and end in the name of the output pin.
+ * It is the users responsibility to ensure that all function names will be unique by choosing apropriate
+ * prefixes.
+ *
+ * The annotation is consumed by the [[FirrtlToTransitionSystem]] pass.
+ */
+case class UninterpretedModuleAnnotation(target: ModuleTarget, prefix: String, stateBits: Int = 0)
+ extends SingleTargetAnnotation[ModuleTarget] {
+ require(stateBits >= 0, "negative number of bits is forbidden")
+ if (stateBits > 0) throw new NotImplementedError("TODO: support for stateful circuits is not implemented yet!")
+ override def duplicate(n: ModuleTarget) = copy(n)
+}
+
+object UninterpretedModuleAnnotation {
+
+ /** checks to see whether the annotation module can actually be abstracted. Use *after* LowerTypes! */
+ def checkModule(m: ir.DefModule, anno: UninterpretedModuleAnnotation): Unit = m match {
+ case _: ir.Module =>
+ throw new UninterpretedModuleException(s"UninterpretedModuleAnnotation can only be used with extmodule! $anno")
+ case m: ir.ExtModule =>
+ val clockInputs = m.ports.collect { case p @ ir.Port(_, _, ir.Input, ir.ClockType) => p.name }
+ val clockOutput = m.ports.collect { case p @ ir.Port(_, _, ir.Output, ir.ClockType) => p.name }
+ val asyncResets = m.ports.collect { case p @ ir.Port(_, _, _, ir.AsyncResetType) => p.name }
+ if (clockOutput.nonEmpty) {
+ throw new UninterpretedModuleException(
+ s"We do not support clock outputs for uninterpreted modules! $clockOutput"
+ )
+ }
+ if (asyncResets.nonEmpty) {
+ throw new UninterpretedModuleException(
+ s"We do not support async reset I/O for uninterpreted modules! $asyncResets"
+ )
+ }
+ if (anno.stateBits == 0) {
+ if (clockInputs.nonEmpty) {
+ throw new UninterpretedModuleException(s"A combinatorial module may not have any clock inputs! $clockInputs")
+ }
+ } else {
+ if (clockInputs.size != 1) {
+ throw new UninterpretedModuleException(s"A stateful module must have exactly one clock input! $clockInputs")
+ }
+ }
+ }
+}
+
+private class UninterpretedModuleException(s: String) extends PassException(s)
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UninterpretedModulesSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UninterpretedModulesSpec.scala
new file mode 100644
index 00000000..e4404d10
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UninterpretedModulesSpec.scala
@@ -0,0 +1,49 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtl.backends.experimental.smt.end2end
+
+import firrtl.annotations.CircuitTarget
+import firrtl.backends.experimental.smt.UninterpretedModuleAnnotation
+
+class UninterpretedModulesSpec extends EndToEndSMTBaseSpec {
+
+ private def testCircuit(assumption: String = ""): String = {
+ s"""circuit UF00:
+ | module UF00:
+ | input clk: Clock
+ | input a: UInt<128>
+ | input b: UInt<128>
+ | input c: UInt<128>
+ |
+ | inst m0 of Magic
+ | m0.a <= a
+ | m0.b <= b
+ |
+ | inst m1 of Magic
+ | m1.a <= a
+ | m1.b <= c
+ |
+ | assert(clk, eq(m0.r, m1.r), UInt(1), "m0.r == m1.r")
+ | $assumption
+ | extmodule Magic:
+ | input a: UInt<128>
+ | input b: UInt<128>
+ | output r: UInt<128>
+ |""".stripMargin
+ }
+ private val magicAnno = UninterpretedModuleAnnotation(CircuitTarget("UF00").module("Magic"), "magic", 0)
+
+ "two instances of the same uninterpreted module" should "give the same result when given the same inputs" taggedAs (RequiresZ3) in {
+ val assumeTheSame = """assume(clk, eq(b,c), UInt(1), "b == c")"""
+ test(testCircuit(assumeTheSame), MCSuccess, 1, "inputs are the same ==> outputs are the same", Seq(magicAnno))
+ }
+ "two instances of the same uninterpreted module" should "not always give the same result when given potentially different inputs" taggedAs (RequiresZ3) in {
+ test(
+ testCircuit(),
+ MCFail(0),
+ 1,
+ "inputs are not necessarily the same ==> outputs can be different",
+ Seq(magicAnno)
+ )
+ }
+}