diff options
| author | Kevin Laeufer | 2020-12-02 11:21:31 -0800 |
|---|---|---|
| committer | GitHub | 2020-12-02 19:21:31 +0000 |
| commit | 228878ecb49f87497638b41086c7194cd59ea50b (patch) | |
| tree | bea730acadf590bb273541cd37b5cc55a5efb7c9 /src | |
| parent | 6c5ce834e26386100b196881f6e487aed26c9c0a (diff) | |
smt: add support for uninterpreted ext modules (#1994)
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src')
8 files changed, 304 insertions, 55 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala index 4cd5c9f7..f96fd4e8 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala @@ -3,12 +3,28 @@ package firrtl.backends.experimental.smt +import firrtl.backends.experimental.smt.Btor2Serializer.functionCallToArrayRead + import scala.collection.mutable private object Btor2Serializer { def serialize(sys: TransitionSystem, skipOutput: Boolean = false): Iterable[String] = { new Btor2Serializer().run(sys, skipOutput) } + + private def functionCallToArrayRead(call: BVFunctionCall): BVExpr = { + if (call.args.isEmpty) { + BVSymbol(call.name, call.width) + } else { + val index = concat(call.args) + val a = ArraySymbol(call.name, indexWidth = index.width, dataWidth = call.width) + ArrayRead(a, index) + } + } + private def concat(e: Iterable[BVExpr]): BVExpr = { + require(e.nonEmpty) + e.reduce((a, b) => BVConcat(a, b)) + } } private class Btor2Serializer private () { @@ -65,6 +81,7 @@ private class Btor2Serializer private () { case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b) case BVOp(op, a, b) => binary(s(op), expr.width, a, b) case BVConcat(a, b) => binary("concat", expr.width, a, b) + case call: BVFunctionCall => s(functionCallToArrayRead(call)) case ArrayRead(array, index) => line(s"read ${t(expr.width)} ${s(array)} ${s(index)}") case BVIte(cond, tru, fals) => @@ -164,6 +181,17 @@ private class Btor2Serializer private () { declare(ii.name, line(s"input ${t(ii.width)} ${ii.name}")) } + // declare uninterpreted functions a constant arrays + sys.ufs.foreach { foo => + val sym = if (foo.argWidths.isEmpty) { BVSymbol(foo.name, foo.width) } + else { + ArraySymbol(foo.name, foo.argWidths.sum, foo.width) + } + comment(foo.toString) + declare(sym.name, line(s"state ${t(sym)} ${sym.name}")) + line(s"next ${t(sym)} ${s(sym)} ${s(sym)}") + } + // define state init sys.states.foreach { st => // calculate init expression before declaring the state diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala index aed2011a..145b5b0f 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -40,10 +40,12 @@ private case class TransitionSystem( assumes: Set[String], asserts: Set[String], fair: Set[String], + ufs: List[BVFunctionSymbol] = List(), comments: Map[String, String] = Map(), header: Array[String] = Array()) { def serialize: String = { (Iterator(name) ++ + ufs.map(u => u.toString) ++ inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++ signals.map(s => s"${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++ states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}")).mkString("\n") @@ -79,15 +81,25 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { .map(a => a.target.ref -> a.initValue) .toMap + // module look up table + val modules = circuit.modules.map(m => m.name -> m).toMap + + // collect uninterpreted module annotations + val uninterpreted = afterPreset.annotations.collect { + case a: UninterpretedModuleAnnotation => + UninterpretedModuleAnnotation.checkModule(modules(a.target.module), a) + a.target.module -> a + }.toMap + // convert the main module - val main = circuit.modules.find(_.name == circuit.main).get + val main = modules(circuit.main) val sys = main match { case x: ir.ExtModule => throw new ExtModuleException( "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog." ) case m: ir.Module => - new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit = memInit) + new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit = memInit, uninterpreted = uninterpreted) } val sortedSys = TopologicalSort.run(sys) @@ -122,12 +134,13 @@ private class MissingFeatureException(s: String) private class ModuleToTransitionSystem extends LazyLogging { def run( - m: ir.Module, - presetRegs: Set[String] = Set(), - memInit: Map[String, MemoryInitValue] = Map() + m: ir.Module, + presetRegs: Set[String] = Set(), + memInit: Map[String, MemoryInitValue] = Map(), + uninterpreted: Map[String, UninterpretedModuleAnnotation] = Map() ): TransitionSystem = { // first pass over the module to convert expressions; discover state and I/O - val scan = new ModuleScanner(makeRandom) + val scan = new ModuleScanner(makeRandom, uninterpreted) m.foreachPort(scan.onPort) m.foreachStmt(scan.onStatement) @@ -188,6 +201,10 @@ private class ModuleToTransitionSystem extends LazyLogging { val header = serializeInfo(m.info).map(InfoPrefix + _).toArray val fair = Set[String]() // as of firrtl 1.4 we do not support fairness constraints + + // collect unique functions + val ufs = scan.functionCalls.groupBy(_.name).map(_._2.head).toList + TransitionSystem( m.name, inputs.toArray, @@ -197,6 +214,7 @@ private class ModuleToTransitionSystem extends LazyLogging { constraints, bad, fair, + ufs, comments.toMap, header ) @@ -456,7 +474,10 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr, namespace: Nam } // performas a first pass over the module collecting all connections, wires, registers, input and outputs -private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLogging { +private class ModuleScanner( + makeRandom: (String, Int) => BVExpr, + uninterpreted: Map[String, UninterpretedModuleAnnotation]) + extends LazyLogging { import FirrtlExpressionSemantics.getWidth private[firrtl] val inputs = mutable.ArrayBuffer[BVSymbol]() @@ -473,10 +494,13 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog private[firrtl] val assumes = mutable.ArrayBuffer[String]() // maps identifiers to their info private[firrtl] val infos = mutable.ArrayBuffer[(String, ir.Info)]() - // keeps track of unused memory (data) outputs so that we can see where they are first used - private val unusedMemOutputs = mutable.LinkedHashMap[String, Int]() + // Keeps track of (so far) unused memory (data) and uninterpreted module outputs. + // This is used in order to delay declaring them for as long as possible. + private val unusedOutputs = mutable.LinkedHashMap[String, BVExpr]() // ensure unique names for assert/assume signals private[firrtl] val namespace = Namespace() + // keep track of all uninterpreted functions called + private[firrtl] val functionCalls = mutable.ArrayBuffer[BVFunctionSymbol]() private[firrtl] def onPort(p: ir.Port): Unit = { if (isAsyncReset(p.tpe)) { @@ -508,7 +532,7 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog case ir.DefNode(info, name, expr) => namespace.newName(name) if (!isClock(expr.tpe)) { - insertDummyAssignsForMemoryOutputs(expr) + insertDummyAssignsForUnusedOutputs(expr) infos.append(name -> info) val e = onExpression(expr, name) nodes.append(name) @@ -516,8 +540,8 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog } case ir.DefRegister(info, name, tpe, _, reset, init) => namespace.newName(name) - insertDummyAssignsForMemoryOutputs(reset) - insertDummyAssignsForMemoryOutputs(init) + insertDummyAssignsForUnusedOutputs(reset) + insertDummyAssignsForUnusedOutputs(init) infos.append(name -> info) val width = getWidth(tpe) val resetExpr = onExpression(reset, 1, name + "_reset") @@ -529,13 +553,13 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog val outputs = getMemOutputs(m) (getMemInputs(m) ++ outputs).foreach(memSignals.append(_)) val dataWidth = getWidth(m.dataType) - outputs.foreach(name => unusedMemOutputs(name) = dataWidth) + outputs.foreach(name => unusedOutputs(name) = BVSymbol(name, dataWidth)) memories.append(m) case ir.Connect(info, loc, expr) => if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") if (!isClock(loc.tpe)) { // we ignore clock connections val name = loc.serialize - insertDummyAssignsForMemoryOutputs(expr) + insertDummyAssignsForUnusedOutputs(expr) infos.append(name -> info) connects.append((name, onExpression(expr, getWidth(loc.tpe), name))) } @@ -544,40 +568,13 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog val name = loc.serialize infos.append(name -> info) connects.append((name, makeRandom(name + "_INVALID", getWidth(loc.tpe)))) - case ir.DefInstance(info, name, module, tpe) => - namespace.newName(name) - if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}") - // we treat all instances as blackboxes - logger.warn( - s"WARN: treating instance $name of $module as blackbox. " + - "Please flatten your hierarchy if you want to include submodules in the formal model." - ) - val ports = tpe.asInstanceOf[ir.BundleType].fields - // skip async reset ports - ports.filterNot(p => isAsyncReset(p.tpe)).foreach { p => - if (!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p") - val isOutput = p.flip == ir.Default - val pName = name + "." + p.name - infos.append(pName -> info) - // outputs of the submodule become inputs to our module - if (isOutput) { - if (isClock(p.tpe)) { - clocks.add(pName) - } else { - inputs.append(BVSymbol(pName, getWidth(p.tpe))) - } - } else { - if (!isClock(p.tpe)) { // we ignore clock outputs - outputs.append(pName) - } - } - } + case ir.DefInstance(info, name, module, tpe) => onInstance(info, name, module, tpe) case s @ ir.Verification(op, info, _, pred, en, msg) => if (op == ir.Formal.Cover) { logger.warn(s"WARN: Cover statement was ignored: ${s.serialize}") } else { - insertDummyAssignsForMemoryOutputs(pred) - insertDummyAssignsForMemoryOutputs(en) + insertDummyAssignsForUnusedOutputs(pred) + insertDummyAssignsForUnusedOutputs(en) val name = namespace.newName(msgToName(op.toString, msg.string)) val predicate = onExpression(pred, name + "_predicate") val enabled = onExpression(en, name + "_enabled") @@ -604,6 +601,70 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog case other => other.foreachStmt(onStatement) } + private def onInstance(info: ir.Info, name: String, module: String, tpe: ir.Type): Unit = { + namespace.newName(name) + if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}") + if (uninterpreted.contains(module)) { + onUninterpretedInstance(info: ir.Info, name: String, module: String, tpe: ir.Type) + } else { + // We treat all instances that aren't annotated as uninterpreted as blackboxes + // this means that their outputs could be any value, no matter what their inputs are. + logger.warn( + s"WARN: treating instance $name of $module as blackbox. " + + "Please flatten your hierarchy if you want to include submodules in the formal model." + ) + val ports = tpe.asInstanceOf[ir.BundleType].fields + // skip async reset ports + ports.filterNot(p => isAsyncReset(p.tpe)).foreach { p => + if (!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p") + val isOutput = p.flip == ir.Default + val pName = name + "." + p.name + infos.append(pName -> info) + // outputs of the submodule become inputs to our module + if (isOutput) { + if (isClock(p.tpe)) { + clocks.add(pName) + } else { + inputs.append(BVSymbol(pName, getWidth(p.tpe))) + } + } else { + if (!isClock(p.tpe)) { // we ignore clock outputs + outputs.append(pName) + } + } + } + } + } + + private def onUninterpretedInstance(info: ir.Info, instanceName: String, module: String, tpe: ir.Type): Unit = { + val anno = uninterpreted(module) + + // sanity checks for ports were done already using the UninterpretedModule.checkModule function + val ports = tpe.asInstanceOf[ir.BundleType].fields + + val outputs = ports.filter(_.flip == ir.Default).map(p => BVSymbol(p.name, getWidth(p.tpe))) + val inputs = ports.filterNot(_.flip == ir.Default).map(p => BVSymbol(p.name, getWidth(p.tpe))) + + assert(anno.stateBits == 0, "TODO: implement support for uninterpreted stateful modules!") + + // for state-less (i.e. combinatorial) circuits, the outputs only depend on the inputs + val args = inputs.map(i => BVSymbol(instanceName + "." + i.name, i.width)).toList + outputs.foreach { out => + val functionName = anno.prefix + "." + out.name + val call = BVFunctionCall(functionName, args, out.width) + val wireName = instanceName + "." + out.name + // remember which functions were called + functionCalls.append(call.toSymbol) + // insert the output definition right before its first use in an attempt to get SSA + unusedOutputs(wireName) = call + // treat these outputs as wires + wires.append(wireName) + } + + // we also treat the arguments as wires + wires ++= args.map(_.name) + } + private val readInputFields = List("en", "addr") private val writeInputFields = List("en", "mask", "addr", "data") private def getMemInputs(m: ir.DefMemory): Iterable[String] = { @@ -617,27 +678,27 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog val p = m.name + "." m.readers.map(r => p + r + ".data") } - // inserts a dummy assign right before a memory output is used for the first time + // inserts a dummy assign right before a memory/uninterpreted module output is used for the first time // example: // m.r.data <= m.r.data ; this is the dummy assign // test <= m.r.data ; this is the first use of m.r.data - private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if (unusedMemOutputs.nonEmpty) { - implicit val uses = mutable.ArrayBuffer[String]() - findUnusedMemoryOutputUse(next) + private def insertDummyAssignsForUnusedOutputs(next: ir.Expression): Unit = if (unusedOutputs.nonEmpty) { + val uses = mutable.ArrayBuffer[String]() + findUnusedOutputUse(next)(uses) if (uses.nonEmpty) { val useSet = uses.toSet - unusedMemOutputs.foreach { - case (name, width) => - if (useSet.contains(name)) connects.append(name -> BVSymbol(name, width)) + unusedOutputs.foreach { + case (name, value) => + if (useSet.contains(name)) connects.append(name -> value) } - useSet.foreach(name => unusedMemOutputs.remove(name)) + useSet.foreach(name => unusedOutputs.remove(name)) } } - private def findUnusedMemoryOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match { + private def findUnusedOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match { case s: ir.SubField => val name = s.serialize - if (unusedMemOutputs.contains(name)) uses.append(name) - case other => other.foreachExpr(findUnusedMemoryOutputUse) + if (unusedOutputs.contains(name)) uses.append(name) + case other => other.foreachExpr(findUnusedOutputUse) } private case class Context(baseName: String) extends TranslationContext { diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala index 63692006..0fc507e6 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala @@ -138,6 +138,17 @@ private case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr override def children: List[BVExpr] = List(cond, tru, fals) } +/** apply bv arguments to a function which returns a result of bit vector type */ +private case class BVFunctionCall(name: String, args: List[BVExpr], width: Int) extends BVExpr { + override def children = args + def toSymbol: BVFunctionSymbol = BVFunctionSymbol(name, args.map(_.width), width) + override def toString: String = args.mkString(name + "(", ", ", ")") +} + +private case class BVFunctionSymbol(name: String, argWidths: List[Int], width: Int) { + override def toString: String = s"$name : " + (argWidths :+ width).map(w => s"bv<$w>").mkString(" -> ") +} + private sealed trait ArrayExpr extends SMTExpr { val indexWidth: Int; val dataWidth: Int } private case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol { assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala index 19f1de84..13ed8bdd 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala @@ -54,6 +54,11 @@ private object SMTExprVisitor { case old @ BVIte(a, b, c) => val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) bv(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC)) + // n-ary + case old @ BVFunctionCall(name, args, width) => + val nArgs = args.map(a => map(a, bv, ar)) + val noneNew = nArgs.zip(args).forall { case (n, o) => n.eq(o) } + bv(if (noneNew) old else BVFunctionCall(name, nArgs, width)) } private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = e match { diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala index 7bc0a077..75bde09c 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala @@ -24,6 +24,11 @@ private object SMTLibSerializer { case a: ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth) } + def declareFunction(foo: BVFunctionSymbol): SMTCommand = { + val args = foo.argWidths.map(serializeBitVectorType) + DeclareFunction(BVSymbol(foo.name, foo.width), args) + } + private def serialize(e: BVExpr): String = e match { case BVLiteral(value, width) => val mask = (BigInt(1) << width) - 1 @@ -74,6 +79,7 @@ private object SMTLibSerializer { case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})" case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})" case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" + case BVFunctionCall(name, args, _) => args.map(serialize).mkString(s"($name ", " ", ")") case BVRawExpr(serialized, _) => serialized } diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala index f6d9a26f..d35fe139 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala @@ -20,6 +20,9 @@ private object SMTTransitionSystemEncoder { // emit header as comments cmds ++= sys.header.map(Comment) + // declare uninterpreted functions used in model + cmds ++= sys.ufs.map(SMTLibSerializer.declareFunction) + // declare state type val stateType = id(name + "_s") cmds += DeclareUninterpretedSort(stateType) diff --git a/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala b/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala new file mode 100644 index 00000000..c7442f69 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import firrtl.annotations._ +import firrtl.ir +import firrtl.passes.PassException + +/** ExtModules annotated as UninterpretedModule will be modelled as + * UninterpretedFunction (SMTLib) or constant arrays (btor2). + * This can be useful when trying to abstract over a function that the + * SMT solver or model checker is struggling with. + * + * E.g., one could declare an abstract 64bit multiplier like this: + * ``` + * extmodule Mul64 : + * input a : UInt<64> + * input b : UInt<64> + * output r : UInt<64> + * ``` + * Now instead of using Chisel to actually implement a multiplication circuit + * we can instantiate this Mul64 module twice: Once in our implementation + * and once for our correctness property that might specify how the + * multiply instruction is supposed to be executed on our CPU. + * Now instead of having to prove equivalence of multiplication circuits, the + * solver only has to make sure that the connections to the multiplier are correct, + * since if `a` and `b` are the same on both instances of `Mul64`, then the `r` output + * will also be the same. This is a much easier problem and will result in much faster + * solving due to manual abstraction. + * + * When [[stateBits]] is 0, we model the module as purely combinatorial circuit and + * thus expect there to be no clock wire going into the module. + * Every output is thus a function of all inputs of the module. + * + * When [[stateBits]] is an N greater than zero, we will model the module as having an abstract state of width N. + * Thus on every clock transition the abstract state is updated and all outputs will take the state + * as well as the current inputs as arguments. + * TODO: Support for stateful circuits is work in progress. + * + * All output functions well be prefixed with [[prefix]] and end in the name of the output pin. + * It is the users responsibility to ensure that all function names will be unique by choosing apropriate + * prefixes. + * + * The annotation is consumed by the [[FirrtlToTransitionSystem]] pass. + */ +case class UninterpretedModuleAnnotation(target: ModuleTarget, prefix: String, stateBits: Int = 0) + extends SingleTargetAnnotation[ModuleTarget] { + require(stateBits >= 0, "negative number of bits is forbidden") + if (stateBits > 0) throw new NotImplementedError("TODO: support for stateful circuits is not implemented yet!") + override def duplicate(n: ModuleTarget) = copy(n) +} + +object UninterpretedModuleAnnotation { + + /** checks to see whether the annotation module can actually be abstracted. Use *after* LowerTypes! */ + def checkModule(m: ir.DefModule, anno: UninterpretedModuleAnnotation): Unit = m match { + case _: ir.Module => + throw new UninterpretedModuleException(s"UninterpretedModuleAnnotation can only be used with extmodule! $anno") + case m: ir.ExtModule => + val clockInputs = m.ports.collect { case p @ ir.Port(_, _, ir.Input, ir.ClockType) => p.name } + val clockOutput = m.ports.collect { case p @ ir.Port(_, _, ir.Output, ir.ClockType) => p.name } + val asyncResets = m.ports.collect { case p @ ir.Port(_, _, _, ir.AsyncResetType) => p.name } + if (clockOutput.nonEmpty) { + throw new UninterpretedModuleException( + s"We do not support clock outputs for uninterpreted modules! $clockOutput" + ) + } + if (asyncResets.nonEmpty) { + throw new UninterpretedModuleException( + s"We do not support async reset I/O for uninterpreted modules! $asyncResets" + ) + } + if (anno.stateBits == 0) { + if (clockInputs.nonEmpty) { + throw new UninterpretedModuleException(s"A combinatorial module may not have any clock inputs! $clockInputs") + } + } else { + if (clockInputs.size != 1) { + throw new UninterpretedModuleException(s"A stateful module must have exactly one clock input! $clockInputs") + } + } + } +} + +private class UninterpretedModuleException(s: String) extends PassException(s) diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UninterpretedModulesSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UninterpretedModulesSpec.scala new file mode 100644 index 00000000..e4404d10 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UninterpretedModulesSpec.scala @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.backends.experimental.smt.end2end + +import firrtl.annotations.CircuitTarget +import firrtl.backends.experimental.smt.UninterpretedModuleAnnotation + +class UninterpretedModulesSpec extends EndToEndSMTBaseSpec { + + private def testCircuit(assumption: String = ""): String = { + s"""circuit UF00: + | module UF00: + | input clk: Clock + | input a: UInt<128> + | input b: UInt<128> + | input c: UInt<128> + | + | inst m0 of Magic + | m0.a <= a + | m0.b <= b + | + | inst m1 of Magic + | m1.a <= a + | m1.b <= c + | + | assert(clk, eq(m0.r, m1.r), UInt(1), "m0.r == m1.r") + | $assumption + | extmodule Magic: + | input a: UInt<128> + | input b: UInt<128> + | output r: UInt<128> + |""".stripMargin + } + private val magicAnno = UninterpretedModuleAnnotation(CircuitTarget("UF00").module("Magic"), "magic", 0) + + "two instances of the same uninterpreted module" should "give the same result when given the same inputs" taggedAs (RequiresZ3) in { + val assumeTheSame = """assume(clk, eq(b,c), UInt(1), "b == c")""" + test(testCircuit(assumeTheSame), MCSuccess, 1, "inputs are the same ==> outputs are the same", Seq(magicAnno)) + } + "two instances of the same uninterpreted module" should "not always give the same result when given potentially different inputs" taggedAs (RequiresZ3) in { + test( + testCircuit(), + MCFail(0), + 1, + "inputs are not necessarily the same ==> outputs can be different", + Seq(magicAnno) + ) + } +} |
