summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/jib/jib_smt.ml161
-rw-r--r--src/smtlib.ml4
2 files changed, 119 insertions, 46 deletions
diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml
index 230e9e07..b62e761e 100644
--- a/src/jib/jib_smt.ml
+++ b/src/jib/jib_smt.ml
@@ -66,6 +66,27 @@ let opt_ignore_overflow = ref false
let opt_auto = ref false
+type event = Overflow | Assertion
+
+let event_name = function
+ | Overflow -> "overflow"
+ | Assertion -> "assert"
+
+module Event = struct
+ type t = event
+ let compare e1 e2 =
+ match e1, e2 with
+ | Overflow, Overflow -> 0
+ | Assertion, Assertion -> 0
+ | Overflow, _ -> 1
+ | _, Overflow -> -1
+end
+
+module EventMap = Map.Make(Event)
+
+(* Note that we have to use x : ty ref rather than mutable x : ty, to
+ make sure { ctx with x = ... } doesn't break the mutable state. *)
+
type ctx = {
(* Arbitrary-precision bitvectors are represented as a (BitVec lbits_index, BitVec (2 ^ lbits_index)) pair. *)
lbits_index : int;
@@ -76,11 +97,13 @@ type ctx = {
vector_index : int;
register_map : id list CTMap.t;
(* Keep track of all the tuple sizes we need to genenerate types for *)
- mutable tuple_sizes : IntSet.t;
+ tuple_sizes : IntSet.t ref;
tc_env : Type_check.Env.t;
pragma_l : Ast.l;
arg_stack : (int * string) Stack.t;
- ast : Type_check.tannot defs
+ ast : Type_check.tannot defs;
+ events : smt_def Stack.t EventMap.t ref;
+ pathcond : smt_exp;
}
(* These give the default bounds for various SMT types, stored in the
@@ -96,13 +119,34 @@ let initial_ctx () = {
lint_size = !opt_default_lint_size;
vector_index = !opt_default_vector_index;
register_map = CTMap.empty;
- tuple_sizes = IntSet.empty;
+ tuple_sizes = ref IntSet.empty;
tc_env = Type_check.initial_env;
pragma_l = Parse_ast.Unknown;
arg_stack = Stack.create ();
- ast = Defs []
+ ast = Defs [];
+ events = ref EventMap.empty;
+ pathcond = Bool_lit true;
}
+let event_stack ctx ev =
+ match EventMap.find_opt ev !(ctx.events) with
+ | Some stack -> stack
+ | None ->
+ let stack = Stack.create () in
+ ctx.events := EventMap.add ev stack !(ctx.events);
+ stack
+
+let event_check check (ev, checks) =
+ match ev with
+ | Overflow ->
+ if !opt_ignore_overflow then
+ Bool_lit false
+ else
+ Fn ("and", check :: checks)
+
+ | Assertion ->
+ Fn ("or", check :: List.map (fun c -> Fn ("not", [c])) checks)
+
let lbits_size ctx = Util.power 2 ctx.lbits_index
let vector_index = ref 5
@@ -138,7 +182,7 @@ let rec smt_ctyp ctx = function
| CT_variant (id, ctors) ->
mk_variant (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) ctors)
| CT_tup ctyps ->
- ctx.tuple_sizes <- IntSet.add (List.length ctyps) ctx.tuple_sizes;
+ ctx.tuple_sizes := IntSet.add (List.length ctyps) !(ctx.tuple_sizes);
Tuple (List.map (smt_ctyp ctx) ctyps)
| CT_vector (_, ctyp) -> Array (Bitvec !vector_index, smt_ctyp ctx ctyp)
| CT_string -> Bitvec 64
@@ -210,6 +254,7 @@ let smt_value ctx vl ctyp =
| VL_bit Sail2_values.B0, CT_bit -> Bin "0"
| VL_bit Sail2_values.B1, CT_bit -> Bin "1"
| VL_unit, _ -> Var "unit"
+ | VL_string _, _ -> Var "unit" (* FIXME: String support *)
| vl, _ -> failwith ("Cannot translate literal to SMT " ^ string_of_value vl)
let zencode_ctor ctor_id unifiers =
@@ -260,7 +305,7 @@ let rec smt_cval ctx cval =
| _ -> failwith "Struct does not have struct type"
end
| V_tuple_member (frag, len, n) ->
- ctx.tuple_sizes <- IntSet.add len ctx.tuple_sizes;
+ ctx.tuple_sizes := IntSet.add len !(ctx.tuple_sizes);
Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval ctx frag])
| V_ref (Name (id, _), _) ->
let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in
@@ -276,12 +321,14 @@ let rec smt_cval ctx cval =
end
| cval -> failwith ("Unrecognised cval " ^ string_of_cval ~zencode:false cval)
-let overflow_checks = Stack.create ()
+let add_event ctx ev smt =
+ let stack = event_stack ctx ev in
+ Stack.push (Define_const (event_name ev ^ string_of_int (Stack.length stack), Bool, Fn ("=>", [ctx.pathcond; smt]))) stack
-let overflow_check smt =
+let overflow_check ctx smt =
if not !opt_ignore_overflow then (
Util.warn "Adding overflow check in generated SMT";
- Stack.push (Define_const ("overflow" ^ string_of_int (Stack.length overflow_checks), Bool, Fn ("not", [smt]))) overflow_checks
+ add_event ctx Overflow (Fn ("not", [smt]))
)
(**************************************************************************)
@@ -338,7 +385,7 @@ let builtin_gteq = builtin_int_comparison "bvsge" Big_int.greater_equal
(** [force_size n m exp] takes a smt expression assumed to be a
integer (signed bitvector) of length m and forces it to be length n
by either sign extending it or truncating it as required *)
-let force_size ?checked:(checked=true) n m smt =
+let force_size ?checked:(checked=true) ctx n m smt =
if n = m then
smt
else if n > m then
@@ -352,7 +399,7 @@ let force_size ?checked:(checked=true) n m smt =
(* Otherwise, all the top bits must be zero *)
Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvzero (m - n)])]))
in
- if checked then overflow_check check else ();
+ if checked then overflow_check ctx check else ();
Extract (n - 1, 0, smt)
let int_size ctx = function
@@ -375,17 +422,17 @@ let builtin_arith fn big_int_fn padding ctx v1 v2 ret_ctyp =
| ctyp, CT_constant c, _ ->
let n = int_size ctx ctyp in
- force_size (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c]))
+ force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c]))
| CT_constant c, ctyp, _ ->
let n = int_size ctx ctyp in
- force_size (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2]))
+ force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2]))
| ctyp1, ctyp2, _ ->
let ret_sz = int_size ctx ret_ctyp in
let smt1 = smt_cval ctx v1 in
let smt2 = smt_cval ctx v2 in
- force_size ret_sz (padding ret_sz) (Fn (fn, [force_size (padding ret_sz) (int_size ctx ctyp1) smt1;
- force_size (padding ret_sz) (int_size ctx ctyp2) smt2]))
+ force_size ctx ret_sz (padding ret_sz) (Fn (fn, [force_size ctx (padding ret_sz) (int_size ctx ctyp1) smt1;
+ force_size ctx (padding ret_sz) (int_size ctx ctyp2) smt2]))
let builtin_add_int = builtin_arith "bvadd" Big_int.add (fun x -> x + 1)
let builtin_sub_int = builtin_arith "bvsub" Big_int.sub (fun x -> x + 1)
@@ -398,8 +445,8 @@ let builtin_negate_int ctx v ret_ctyp =
| CT_constant c, _ ->
bvint (int_size ctx ret_ctyp) (Big_int.negate c)
| ctyp, _ ->
- let smt = force_size (int_size ctx ret_ctyp) (int_size ctx ctyp) (smt_cval ctx v) in
- overflow_check (Fn ("=", [smt; Bin ("1" ^ String.make (int_size ctx ret_ctyp - 1) '0')]));
+ let smt = force_size ctx (int_size ctx ret_ctyp) (int_size ctx ctyp) (smt_cval ctx v) in
+ overflow_check ctx (Fn ("=", [smt; Bin ("1" ^ String.make (int_size ctx ret_ctyp - 1) '0')]));
Fn ("bvneg", [smt])
let builtin_shift_int fn big_int_fn ctx v1 v2 ret_ctyp =
@@ -411,17 +458,17 @@ let builtin_shift_int fn big_int_fn ctx v1 v2 ret_ctyp =
| ctyp, CT_constant c, _ ->
let n = int_size ctx ctyp in
- force_size (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c]))
+ force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c]))
| CT_constant c, ctyp, _ ->
let n = int_size ctx ctyp in
- force_size (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2]))
+ force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2]))
| ctyp1, ctyp2, _ ->
let ret_sz = int_size ctx ret_ctyp in
let smt1 = smt_cval ctx v1 in
let smt2 = smt_cval ctx v2 in
- (Fn (fn, [force_size ret_sz (int_size ctx ctyp1) smt1;
- force_size ret_sz (int_size ctx ctyp2) smt2]))
+ (Fn (fn, [force_size ctx ret_sz (int_size ctx ctyp1) smt1;
+ force_size ctx ret_sz (int_size ctx ctyp2) smt2]))
let builtin_shl_int = builtin_shift_int "bvshl" Big_int.shift_left
let builtin_shr_int = builtin_shift_int "bvashr" Big_int.shift_right
@@ -436,8 +483,8 @@ let builtin_abs_int ctx v ret_ctyp =
let sz = int_size ctx ctyp in
let smt = smt_cval ctx v in
Ite (Fn ("=", [Extract (sz - 1, sz -1, smt); Bin "1"]),
- force_size (int_size ctx ret_ctyp) sz (Fn ("bvneg", [smt])),
- force_size (int_size ctx ret_ctyp) sz smt)
+ force_size ctx (int_size ctx ret_ctyp) sz (Fn ("bvneg", [smt])),
+ force_size ctx (int_size ctx ret_ctyp) sz smt)
let builtin_pow2 ctx v ret_ctyp =
match cval_ctyp v, ret_ctyp with
@@ -628,7 +675,7 @@ let builtin_vector_access ctx vec i ret_ctyp =
| CT_vector _, CT_constant i, _ ->
Fn ("select", [smt_cval ctx vec; bvint !vector_index i])
| CT_vector _, index_ctyp, _ ->
- Fn ("select", [smt_cval ctx vec; force_size !vector_index (int_size ctx index_ctyp) (smt_cval ctx i)])
+ Fn ("select", [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i)])
| _ -> builtin_type_error ctx "vector_access" [vec; i] (Some ret_ctyp)
@@ -651,7 +698,7 @@ let builtin_vector_update ctx vec i x ret_ctyp =
| CT_vector _, CT_constant i, ctyp, CT_vector _ ->
Fn ("store", [smt_cval ctx vec; bvint !vector_index i; smt_cval ctx x])
| CT_vector _, index_ctyp, _, CT_vector _ ->
- Fn ("store", [smt_cval ctx vec; force_size !vector_index (int_size ctx index_ctyp) (smt_cval ctx i); smt_cval ctx x])
+ Fn ("store", [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i); smt_cval ctx x])
| _ -> builtin_type_error ctx "vector_update" [vec; i; x] (Some ret_ctyp)
@@ -699,10 +746,10 @@ let builtin_add_bits_int ctx v1 v2 ret_ctyp =
Fn ("bvadd", [smt_cval ctx v1; bvint o c])
| CT_fbits (n, _), CT_fint m, CT_fbits (o, _) when n = o ->
- Fn ("bvadd", [smt_cval ctx v1; force_size o m (smt_cval ctx v2)])
+ Fn ("bvadd", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)])
| CT_fbits (n, _), CT_lint, CT_fbits (o, _) when n = o ->
- Fn ("bvadd", [smt_cval ctx v1; force_size o ctx.lint_size (smt_cval ctx v2)])
+ Fn ("bvadd", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)])
| _ -> builtin_type_error ctx "add_bits_int" [v1; v2] (Some ret_ctyp)
@@ -712,10 +759,10 @@ let builtin_sub_bits_int ctx v1 v2 ret_ctyp =
Fn ("bvadd", [smt_cval ctx v1; bvint o (Big_int.negate c)])
| CT_fbits (n, _), CT_fint m, CT_fbits (o, _) when n = o ->
- Fn ("bvsub", [smt_cval ctx v1; force_size o m (smt_cval ctx v2)])
+ Fn ("bvsub", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)])
| CT_fbits (n, _), CT_lint, CT_fbits (o, _) when n = o ->
- Fn ("bvsub", [smt_cval ctx v1; force_size o ctx.lint_size (smt_cval ctx v2)])
+ Fn ("bvsub", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)])
| _ -> builtin_type_error ctx "sub_bits_int" [v1; v2] (Some ret_ctyp)
@@ -768,8 +815,8 @@ let builtin_slice ctx v1 v2 v3 ret_ctyp =
Extract(Big_int.to_int (Big_int.pred len), 0, builtin_shift "bvlshr" ctx v1 v2 (cval_ctyp v1))
| CT_fbits(n, ord), ctyp2, _, CT_lbits _ ->
- let smt1 = force_size (lbits_size ctx) n (smt_cval ctx v1) in
- let smt2 = force_size (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in
+ let smt1 = force_size ctx (lbits_size ctx) n (smt_cval ctx v1) in
+ let smt2 = force_size ctx (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in
let smt3 = bvzeint ctx ctx.lbits_index v3 in
Fn ("Bits", [smt3; Fn ("bvand", [Fn ("bvlshr", [smt1; smt2]); bvmask ctx smt3])])
@@ -783,7 +830,7 @@ let builtin_get_slice_int ctx v1 v2 v3 ret_ctyp =
let in_sz = int_size ctx ctyp in
let smt =
if in_sz < len + start then
- force_size (len + start) in_sz (smt_cval ctx v2)
+ force_size ctx (len + start) in_sz (smt_cval ctx v2)
else
smt_cval ctx v2
in
@@ -1196,6 +1243,13 @@ let smt_ssanode ctx cfg preds =
| Some mux ->
[Define_const (zencode_name id, smt_ctyp ctx ctyp, mux)]
+let rec get_pathcond ctx fns =
+ let open Jib_ssa in
+ match fns with
+ | [] | Pi [] :: _ -> Bool_lit true
+ | Pi cvals :: _ -> Fn ("and", List.map (smt_cval ctx) cvals)
+ | Phi _ :: fns -> get_pathcond ctx fns
+
(* For any complex l-expression we need to turn it into a
read-modify-write in the SMT solver. The SSA transform turns CL_id
nodes into CL_rmw (read, write, ctyp) nodes when CL_id is wrapped
@@ -1266,13 +1320,24 @@ let smt_instr ctx =
| [vec; i; x] ->
let sz = int_size ctx (cval_ctyp i) in
[define_const ctx id ret_ctyp
- (Fn ("store", [smt_cval ctx vec; force_size ~checked:false ctx.vector_index sz (smt_cval ctx i); smt_cval ctx x]))]
+ (Fn ("store", [smt_cval ctx vec; force_size ~checked:false ctx ctx.vector_index sz (smt_cval ctx i); smt_cval ctx x]))]
| _ ->
Reporting.unreachable l __POS__ "Bad arguments for internal_vector_update"
end
- else
+ else if string_of_id function_id = "sail_assert" then
+ begin match args with
+ | [assertion; _] ->
+ let smt = smt_cval ctx assertion in
+ add_event ctx Assertion smt;
+ []
+ | _ ->
+ Reporting.unreachable l __POS__ "Bad arguments for assertion"
+ end
+ else if not extern then
let smt_args = List.map (smt_cval ctx) args in
[define_const ctx id ret_ctyp (Fn (zencode_id function_id, smt_args))]
+ else
+ failwith ("Unrecognised function " ^ string_of_id function_id)
| I_aux (I_copy (CL_addr (CL_id (_, _)), _), (_, l)) ->
Reporting.unreachable l __POS__ "Register reference write should be re-written by now"
@@ -1298,22 +1363,26 @@ let smt_instr ctx =
[declare_const ctx id ctyp]
| I_aux (I_end id, _) ->
- if !opt_ignore_overflow then
- [Assert (Fn ("not", [Var (zencode_name id)]))]
- else
- let checks =
- Stack.fold (fun checks -> function (Define_const (name, _, _) as def) -> (name, def) :: checks | _ -> assert false) [] overflow_checks
- in
- List.map snd checks @ [Assert (Fn ("and", Fn ("not", [Var (zencode_name id)]) :: List.map (fun check -> Var (fst check)) checks))]
+ let checks =
+ EventMap.bindings !(ctx.events)
+ |> List.map (fun (ev, stack) ->
+ (ev, Stack.fold (fun checks -> function (Define_const (name, _, _) as def) -> (Var name, def) :: checks | _ -> assert false) [] stack))
+ |> List.map (fun (ev, checks) ->
+ ((ev, List.map fst checks), List.map snd checks))
+ in
+ List.concat (List.map snd checks)
+ @ [Assert (List.fold_left event_check (Fn ("not", [Var (zencode_name id)])) (List.map fst checks))]
| I_aux (I_clear _, _) -> []
| I_aux (I_match_failure, _) -> []
+ | I_aux (I_undefined ctyp, _) -> []
+
| instr ->
failwith ("Cannot translate: " ^ Pretty_print_sail.to_string (pp_instr instr))
-let smt_cfnode all_cdefs ctx =
+let smt_cfnode all_cdefs ctx ssanodes =
let open Jib_ssa in
function
| CF_start inits ->
@@ -1325,6 +1394,7 @@ let smt_cfnode all_cdefs ctx =
in
smt_reg_decs @ List.map smt_start (NameMap.bindings inits)
| CF_block instrs ->
+ let ctx = { ctx with pathcond = get_pathcond ctx ssanodes } in
List.concat (List.map (smt_instr ctx) instrs)
(* We can ignore any non basic-block/start control-flow nodes *)
| _ -> []
@@ -1413,7 +1483,7 @@ let optimize_smt stack =
let smt_header ctx cdefs =
let smt_ctype_defs = List.concat (generate_ctype_defs ctx cdefs) in
[declare_datatypes (mk_enum "Unit" ["unit"])]
- @ (IntSet.elements ctx.tuple_sizes |> List.map (fun n -> Declare_tuple n))
+ @ (IntSet.elements !(ctx.tuple_sizes) |> List.map (fun n -> Declare_tuple n))
@ [declare_datatypes (mk_record "Bits" [("len", Bitvec ctx.lbits_index);
("contents", Bitvec (lbits_size ctx))])
@@ -1482,8 +1552,7 @@ let smt_cdef props lets name_file ctx all_cdefs = function
let stack = Stack.create () in
- Stack.clear overflow_checks;
- let ctx = { ctx with pragma_l = pragma_l; arg_stack = Stack.create () } in
+ let ctx = { ctx with events = ref EventMap.empty; pragma_l = pragma_l; arg_stack = Stack.create () } in
(* When we create each argument declaration, give it a unique
location from the $property pragma, so we can identify it later. *)
@@ -1524,7 +1593,7 @@ let smt_cdef props lets name_file ctx all_cdefs = function
let muxers =
ssanodes |> List.map (smt_ssanode ctx cfg preds) |> List.concat
in
- let basic_block = smt_cfnode all_cdefs ctx cfnode in
+ let basic_block = smt_cfnode all_cdefs ctx ssanodes cfnode in
push_smt_defs stack muxers;
push_smt_defs stack basic_block;
end
diff --git a/src/smtlib.ml b/src/smtlib.ml
index 3ba85306..fe7aee5e 100644
--- a/src/smtlib.ml
+++ b/src/smtlib.ml
@@ -124,12 +124,16 @@ let bvones n =
let simp_fn = function
| Fn ("not", [Fn ("not", [exp])]) -> exp
+ | Fn ("not", [Bool_lit b]) -> Bool_lit (not b)
| Fn ("contents", [Fn ("Bits", [_; contents])]) -> contents
| Fn ("len", [Fn ("Bits", [len; _])]) -> len
+ | Fn ("or", [x]) -> x
+ | Fn ("and", [x]) -> x
| exp -> exp
let simp_ite = function
| Ite (cond, Bool_lit true, Bool_lit false) -> cond
+ | Ite (cond, Bool_lit x, Bool_lit y) when x = y -> Bool_lit x
| Ite (_, Var v, Var v') when v = v' -> Var v
| Ite (Bool_lit true, then_exp, _) -> then_exp
| Ite (Bool_lit false, _, else_exp) -> else_exp