aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala')
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala59
1 files changed, 36 insertions, 23 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
index 75bde09c..bb4e0348 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
@@ -19,14 +19,9 @@ private object SMTLibSerializer {
case a: ArrayExpr => serialize(a)
}
- def serializeType(e: SMTExpr): String = e match {
- case b: BVExpr => serializeBitVectorType(b.width)
- case a: ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth)
- }
-
- def declareFunction(foo: BVFunctionSymbol): SMTCommand = {
- val args = foo.argWidths.map(serializeBitVectorType)
- DeclareFunction(BVSymbol(foo.name, foo.width), args)
+ def serialize(t: SMTType): String = t match {
+ case BVType(width) => serializeBitVectorType(width)
+ case ArrayType(indexWidth, dataWidth) => serializeArrayType(indexWidth, dataWidth)
}
private def serialize(e: BVExpr): String = e match {
@@ -71,37 +66,57 @@ private object SMTLibSerializer {
case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})"
case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})"
// boolean operations get a special treatment for 1-bit vectors aka bools
- case BVOp(Op.And, a, b) if a.width == 1 => s"(and ${serialize(a)} ${serialize(b)})"
- case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})"
+ case b: BVAnd => serializeVariadic(if (b.width == 1) "and" else "bvand", b.terms)
+ case b: BVOr => serializeVariadic(if (b.width == 1) "or" else "bvor", b.terms)
case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})"
case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})")
case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})"
case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})"
case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})"
case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
- case BVFunctionCall(name, args, _) => args.map(serialize).mkString(s"($name ", " ", ")")
- case BVRawExpr(serialized, _) => serialized
+ case BVFunctionCall(name, args, _) => args.map(serializeArg).mkString(s"($name ", " ", ")")
+ case BVForall(variable, e) => s"(forall ((${variable.name} ${serialize(variable.tpe)})) ${serialize(e)})"
+ }
+
+ private def serializeVariadic(op: String, terms: List[BVExpr]): String = terms match {
+ case Seq() | Seq(_) => throw new RuntimeException(s"expected at least two elements in variadic op $op")
+ case Seq(a, b) => s"($op ${serialize(a)} ${serialize(b)})"
+ case head :: tail => s"($op ${serialize(head)} ${serializeVariadic(op, tail)})"
}
def serialize(e: ArrayExpr): String = e match {
- case ArraySymbol(name, _, _) => escapeIdentifier(name)
- case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})"
- case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
- case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})"
- case ArrayRawExpr(serialized, _, _) => serialized
+ case ArraySymbol(name, _, _) => escapeIdentifier(name)
+ case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})"
+ case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
+ case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})"
+ case ArrayFunctionCall(name, args, _, _) => args.map(serializeArg).mkString(s"($name ", " ", ")")
}
def serialize(c: SMTCommand): String = c match {
case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n")
case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)"
case DefineFunction(name, args, e) =>
- val aa = args.map(a => s"(${escapeIdentifier(a._1)} ${a._2})").mkString(" ")
- s"(define-fun ${escapeIdentifier(name)} ($aa) ${serializeType(e)} ${serialize(e)})"
+ val aa = args.map(a => s"(${serializeArg(a)} ${serializeArgTpe(a)})").mkString(" ")
+ s"(define-fun ${escapeIdentifier(name)} ($aa) ${serialize(e.tpe)} ${serialize(e)})"
case DeclareFunction(sym, tpes) =>
- val aa = tpes.mkString(" ")
- s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serializeType(sym)})"
+ val aa = tpes.map(serializeArgTpe).mkString(" ")
+ s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serialize(sym.tpe)})"
+ case SetLogic(logic) => s"(set-logic $logic)"
+ case DeclareUninterpretedSymbol(name, tpe) =>
+ s"(declare-fun ${escapeIdentifier(name)} () ${escapeIdentifier(tpe)})"
}
+ private def serializeArgTpe(a: SMTFunctionArg): String =
+ a match {
+ case u: UTSymbol => escapeIdentifier(u.tpe)
+ case s: SMTExpr => serialize(s.tpe)
+ }
+ private def serializeArg(a: SMTFunctionArg): String =
+ a match {
+ case u: UTSymbol => escapeIdentifier(u.name)
+ case s: SMTExpr => serialize(s)
+ }
+
private def serializeArrayType(indexWidth: Int, dataWidth: Int): String =
s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})"
private def serializeBitVectorType(width: Int): String =
@@ -109,8 +124,6 @@ private object SMTLibSerializer {
else { assert(width > 1); s"(_ BitVec $width)" }
private def serialize(op: Op.Value): String = op match {
- case Op.And => "bvand"
- case Op.Or => "bvor"
case Op.Xor => "bvxor"
case Op.ArithmeticShiftRight => "bvashr"
case Op.ShiftRight => "bvlshr"