diff options
Diffstat (limited to 'src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala')
| -rw-r--r-- | src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala | 59 |
1 files changed, 36 insertions, 23 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala index 75bde09c..bb4e0348 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala @@ -19,14 +19,9 @@ private object SMTLibSerializer { case a: ArrayExpr => serialize(a) } - def serializeType(e: SMTExpr): String = e match { - case b: BVExpr => serializeBitVectorType(b.width) - case a: ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth) - } - - def declareFunction(foo: BVFunctionSymbol): SMTCommand = { - val args = foo.argWidths.map(serializeBitVectorType) - DeclareFunction(BVSymbol(foo.name, foo.width), args) + def serialize(t: SMTType): String = t match { + case BVType(width) => serializeBitVectorType(width) + case ArrayType(indexWidth, dataWidth) => serializeArrayType(indexWidth, dataWidth) } private def serialize(e: BVExpr): String = e match { @@ -71,37 +66,57 @@ private object SMTLibSerializer { case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})" case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})" // boolean operations get a special treatment for 1-bit vectors aka bools - case BVOp(Op.And, a, b) if a.width == 1 => s"(and ${serialize(a)} ${serialize(b)})" - case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})" + case b: BVAnd => serializeVariadic(if (b.width == 1) "and" else "bvand", b.terms) + case b: BVOr => serializeVariadic(if (b.width == 1) "or" else "bvor", b.terms) case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})" case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})") case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})" case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})" case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})" case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" - case BVFunctionCall(name, args, _) => args.map(serialize).mkString(s"($name ", " ", ")") - case BVRawExpr(serialized, _) => serialized + case BVFunctionCall(name, args, _) => args.map(serializeArg).mkString(s"($name ", " ", ")") + case BVForall(variable, e) => s"(forall ((${variable.name} ${serialize(variable.tpe)})) ${serialize(e)})" + } + + private def serializeVariadic(op: String, terms: List[BVExpr]): String = terms match { + case Seq() | Seq(_) => throw new RuntimeException(s"expected at least two elements in variadic op $op") + case Seq(a, b) => s"($op ${serialize(a)} ${serialize(b)})" + case head :: tail => s"($op ${serialize(head)} ${serializeVariadic(op, tail)})" } def serialize(e: ArrayExpr): String = e match { - case ArraySymbol(name, _, _) => escapeIdentifier(name) - case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})" - case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" - case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})" - case ArrayRawExpr(serialized, _, _) => serialized + case ArraySymbol(name, _, _) => escapeIdentifier(name) + case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})" + case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" + case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})" + case ArrayFunctionCall(name, args, _, _) => args.map(serializeArg).mkString(s"($name ", " ", ")") } def serialize(c: SMTCommand): String = c match { case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n") case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)" case DefineFunction(name, args, e) => - val aa = args.map(a => s"(${escapeIdentifier(a._1)} ${a._2})").mkString(" ") - s"(define-fun ${escapeIdentifier(name)} ($aa) ${serializeType(e)} ${serialize(e)})" + val aa = args.map(a => s"(${serializeArg(a)} ${serializeArgTpe(a)})").mkString(" ") + s"(define-fun ${escapeIdentifier(name)} ($aa) ${serialize(e.tpe)} ${serialize(e)})" case DeclareFunction(sym, tpes) => - val aa = tpes.mkString(" ") - s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serializeType(sym)})" + val aa = tpes.map(serializeArgTpe).mkString(" ") + s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serialize(sym.tpe)})" + case SetLogic(logic) => s"(set-logic $logic)" + case DeclareUninterpretedSymbol(name, tpe) => + s"(declare-fun ${escapeIdentifier(name)} () ${escapeIdentifier(tpe)})" } + private def serializeArgTpe(a: SMTFunctionArg): String = + a match { + case u: UTSymbol => escapeIdentifier(u.tpe) + case s: SMTExpr => serialize(s.tpe) + } + private def serializeArg(a: SMTFunctionArg): String = + a match { + case u: UTSymbol => escapeIdentifier(u.name) + case s: SMTExpr => serialize(s) + } + private def serializeArrayType(indexWidth: Int, dataWidth: Int): String = s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})" private def serializeBitVectorType(width: Int): String = @@ -109,8 +124,6 @@ private object SMTLibSerializer { else { assert(width > 1); s"(_ BitVec $width)" } private def serialize(op: Op.Value): String = op match { - case Op.And => "bvand" - case Op.Or => "bvor" case Op.Xor => "bvxor" case Op.ArithmeticShiftRight => "bvashr" case Op.ShiftRight => "bvlshr" |
