diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/jib/jib_smt.ml | 161 | ||||
| -rw-r--r-- | src/smtlib.ml | 4 |
2 files changed, 119 insertions, 46 deletions
diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml index 230e9e07..b62e761e 100644 --- a/src/jib/jib_smt.ml +++ b/src/jib/jib_smt.ml @@ -66,6 +66,27 @@ let opt_ignore_overflow = ref false let opt_auto = ref false +type event = Overflow | Assertion + +let event_name = function + | Overflow -> "overflow" + | Assertion -> "assert" + +module Event = struct + type t = event + let compare e1 e2 = + match e1, e2 with + | Overflow, Overflow -> 0 + | Assertion, Assertion -> 0 + | Overflow, _ -> 1 + | _, Overflow -> -1 +end + +module EventMap = Map.Make(Event) + +(* Note that we have to use x : ty ref rather than mutable x : ty, to + make sure { ctx with x = ... } doesn't break the mutable state. *) + type ctx = { (* Arbitrary-precision bitvectors are represented as a (BitVec lbits_index, BitVec (2 ^ lbits_index)) pair. *) lbits_index : int; @@ -76,11 +97,13 @@ type ctx = { vector_index : int; register_map : id list CTMap.t; (* Keep track of all the tuple sizes we need to genenerate types for *) - mutable tuple_sizes : IntSet.t; + tuple_sizes : IntSet.t ref; tc_env : Type_check.Env.t; pragma_l : Ast.l; arg_stack : (int * string) Stack.t; - ast : Type_check.tannot defs + ast : Type_check.tannot defs; + events : smt_def Stack.t EventMap.t ref; + pathcond : smt_exp; } (* These give the default bounds for various SMT types, stored in the @@ -96,13 +119,34 @@ let initial_ctx () = { lint_size = !opt_default_lint_size; vector_index = !opt_default_vector_index; register_map = CTMap.empty; - tuple_sizes = IntSet.empty; + tuple_sizes = ref IntSet.empty; tc_env = Type_check.initial_env; pragma_l = Parse_ast.Unknown; arg_stack = Stack.create (); - ast = Defs [] + ast = Defs []; + events = ref EventMap.empty; + pathcond = Bool_lit true; } +let event_stack ctx ev = + match EventMap.find_opt ev !(ctx.events) with + | Some stack -> stack + | None -> + let stack = Stack.create () in + ctx.events := EventMap.add ev stack !(ctx.events); + stack + +let event_check check (ev, checks) = + match ev with + | Overflow -> + if !opt_ignore_overflow then + Bool_lit false + else + Fn ("and", check :: checks) + + | Assertion -> + Fn ("or", check :: List.map (fun c -> Fn ("not", [c])) checks) + let lbits_size ctx = Util.power 2 ctx.lbits_index let vector_index = ref 5 @@ -138,7 +182,7 @@ let rec smt_ctyp ctx = function | CT_variant (id, ctors) -> mk_variant (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) ctors) | CT_tup ctyps -> - ctx.tuple_sizes <- IntSet.add (List.length ctyps) ctx.tuple_sizes; + ctx.tuple_sizes := IntSet.add (List.length ctyps) !(ctx.tuple_sizes); Tuple (List.map (smt_ctyp ctx) ctyps) | CT_vector (_, ctyp) -> Array (Bitvec !vector_index, smt_ctyp ctx ctyp) | CT_string -> Bitvec 64 @@ -210,6 +254,7 @@ let smt_value ctx vl ctyp = | VL_bit Sail2_values.B0, CT_bit -> Bin "0" | VL_bit Sail2_values.B1, CT_bit -> Bin "1" | VL_unit, _ -> Var "unit" + | VL_string _, _ -> Var "unit" (* FIXME: String support *) | vl, _ -> failwith ("Cannot translate literal to SMT " ^ string_of_value vl) let zencode_ctor ctor_id unifiers = @@ -260,7 +305,7 @@ let rec smt_cval ctx cval = | _ -> failwith "Struct does not have struct type" end | V_tuple_member (frag, len, n) -> - ctx.tuple_sizes <- IntSet.add len ctx.tuple_sizes; + ctx.tuple_sizes := IntSet.add len !(ctx.tuple_sizes); Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval ctx frag]) | V_ref (Name (id, _), _) -> let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in @@ -276,12 +321,14 @@ let rec smt_cval ctx cval = end | cval -> failwith ("Unrecognised cval " ^ string_of_cval ~zencode:false cval) -let overflow_checks = Stack.create () +let add_event ctx ev smt = + let stack = event_stack ctx ev in + Stack.push (Define_const (event_name ev ^ string_of_int (Stack.length stack), Bool, Fn ("=>", [ctx.pathcond; smt]))) stack -let overflow_check smt = +let overflow_check ctx smt = if not !opt_ignore_overflow then ( Util.warn "Adding overflow check in generated SMT"; - Stack.push (Define_const ("overflow" ^ string_of_int (Stack.length overflow_checks), Bool, Fn ("not", [smt]))) overflow_checks + add_event ctx Overflow (Fn ("not", [smt])) ) (**************************************************************************) @@ -338,7 +385,7 @@ let builtin_gteq = builtin_int_comparison "bvsge" Big_int.greater_equal (** [force_size n m exp] takes a smt expression assumed to be a integer (signed bitvector) of length m and forces it to be length n by either sign extending it or truncating it as required *) -let force_size ?checked:(checked=true) n m smt = +let force_size ?checked:(checked=true) ctx n m smt = if n = m then smt else if n > m then @@ -352,7 +399,7 @@ let force_size ?checked:(checked=true) n m smt = (* Otherwise, all the top bits must be zero *) Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvzero (m - n)])])) in - if checked then overflow_check check else (); + if checked then overflow_check ctx check else (); Extract (n - 1, 0, smt) let int_size ctx = function @@ -375,17 +422,17 @@ let builtin_arith fn big_int_fn padding ctx v1 v2 ret_ctyp = | ctyp, CT_constant c, _ -> let n = int_size ctx ctyp in - force_size (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c])) + force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c])) | CT_constant c, ctyp, _ -> let n = int_size ctx ctyp in - force_size (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2])) + force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2])) | ctyp1, ctyp2, _ -> let ret_sz = int_size ctx ret_ctyp in let smt1 = smt_cval ctx v1 in let smt2 = smt_cval ctx v2 in - force_size ret_sz (padding ret_sz) (Fn (fn, [force_size (padding ret_sz) (int_size ctx ctyp1) smt1; - force_size (padding ret_sz) (int_size ctx ctyp2) smt2])) + force_size ctx ret_sz (padding ret_sz) (Fn (fn, [force_size ctx (padding ret_sz) (int_size ctx ctyp1) smt1; + force_size ctx (padding ret_sz) (int_size ctx ctyp2) smt2])) let builtin_add_int = builtin_arith "bvadd" Big_int.add (fun x -> x + 1) let builtin_sub_int = builtin_arith "bvsub" Big_int.sub (fun x -> x + 1) @@ -398,8 +445,8 @@ let builtin_negate_int ctx v ret_ctyp = | CT_constant c, _ -> bvint (int_size ctx ret_ctyp) (Big_int.negate c) | ctyp, _ -> - let smt = force_size (int_size ctx ret_ctyp) (int_size ctx ctyp) (smt_cval ctx v) in - overflow_check (Fn ("=", [smt; Bin ("1" ^ String.make (int_size ctx ret_ctyp - 1) '0')])); + let smt = force_size ctx (int_size ctx ret_ctyp) (int_size ctx ctyp) (smt_cval ctx v) in + overflow_check ctx (Fn ("=", [smt; Bin ("1" ^ String.make (int_size ctx ret_ctyp - 1) '0')])); Fn ("bvneg", [smt]) let builtin_shift_int fn big_int_fn ctx v1 v2 ret_ctyp = @@ -411,17 +458,17 @@ let builtin_shift_int fn big_int_fn ctx v1 v2 ret_ctyp = | ctyp, CT_constant c, _ -> let n = int_size ctx ctyp in - force_size (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c])) + force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c])) | CT_constant c, ctyp, _ -> let n = int_size ctx ctyp in - force_size (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2])) + force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2])) | ctyp1, ctyp2, _ -> let ret_sz = int_size ctx ret_ctyp in let smt1 = smt_cval ctx v1 in let smt2 = smt_cval ctx v2 in - (Fn (fn, [force_size ret_sz (int_size ctx ctyp1) smt1; - force_size ret_sz (int_size ctx ctyp2) smt2])) + (Fn (fn, [force_size ctx ret_sz (int_size ctx ctyp1) smt1; + force_size ctx ret_sz (int_size ctx ctyp2) smt2])) let builtin_shl_int = builtin_shift_int "bvshl" Big_int.shift_left let builtin_shr_int = builtin_shift_int "bvashr" Big_int.shift_right @@ -436,8 +483,8 @@ let builtin_abs_int ctx v ret_ctyp = let sz = int_size ctx ctyp in let smt = smt_cval ctx v in Ite (Fn ("=", [Extract (sz - 1, sz -1, smt); Bin "1"]), - force_size (int_size ctx ret_ctyp) sz (Fn ("bvneg", [smt])), - force_size (int_size ctx ret_ctyp) sz smt) + force_size ctx (int_size ctx ret_ctyp) sz (Fn ("bvneg", [smt])), + force_size ctx (int_size ctx ret_ctyp) sz smt) let builtin_pow2 ctx v ret_ctyp = match cval_ctyp v, ret_ctyp with @@ -628,7 +675,7 @@ let builtin_vector_access ctx vec i ret_ctyp = | CT_vector _, CT_constant i, _ -> Fn ("select", [smt_cval ctx vec; bvint !vector_index i]) | CT_vector _, index_ctyp, _ -> - Fn ("select", [smt_cval ctx vec; force_size !vector_index (int_size ctx index_ctyp) (smt_cval ctx i)]) + Fn ("select", [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i)]) | _ -> builtin_type_error ctx "vector_access" [vec; i] (Some ret_ctyp) @@ -651,7 +698,7 @@ let builtin_vector_update ctx vec i x ret_ctyp = | CT_vector _, CT_constant i, ctyp, CT_vector _ -> Fn ("store", [smt_cval ctx vec; bvint !vector_index i; smt_cval ctx x]) | CT_vector _, index_ctyp, _, CT_vector _ -> - Fn ("store", [smt_cval ctx vec; force_size !vector_index (int_size ctx index_ctyp) (smt_cval ctx i); smt_cval ctx x]) + Fn ("store", [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i); smt_cval ctx x]) | _ -> builtin_type_error ctx "vector_update" [vec; i; x] (Some ret_ctyp) @@ -699,10 +746,10 @@ let builtin_add_bits_int ctx v1 v2 ret_ctyp = Fn ("bvadd", [smt_cval ctx v1; bvint o c]) | CT_fbits (n, _), CT_fint m, CT_fbits (o, _) when n = o -> - Fn ("bvadd", [smt_cval ctx v1; force_size o m (smt_cval ctx v2)]) + Fn ("bvadd", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)]) | CT_fbits (n, _), CT_lint, CT_fbits (o, _) when n = o -> - Fn ("bvadd", [smt_cval ctx v1; force_size o ctx.lint_size (smt_cval ctx v2)]) + Fn ("bvadd", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)]) | _ -> builtin_type_error ctx "add_bits_int" [v1; v2] (Some ret_ctyp) @@ -712,10 +759,10 @@ let builtin_sub_bits_int ctx v1 v2 ret_ctyp = Fn ("bvadd", [smt_cval ctx v1; bvint o (Big_int.negate c)]) | CT_fbits (n, _), CT_fint m, CT_fbits (o, _) when n = o -> - Fn ("bvsub", [smt_cval ctx v1; force_size o m (smt_cval ctx v2)]) + Fn ("bvsub", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)]) | CT_fbits (n, _), CT_lint, CT_fbits (o, _) when n = o -> - Fn ("bvsub", [smt_cval ctx v1; force_size o ctx.lint_size (smt_cval ctx v2)]) + Fn ("bvsub", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)]) | _ -> builtin_type_error ctx "sub_bits_int" [v1; v2] (Some ret_ctyp) @@ -768,8 +815,8 @@ let builtin_slice ctx v1 v2 v3 ret_ctyp = Extract(Big_int.to_int (Big_int.pred len), 0, builtin_shift "bvlshr" ctx v1 v2 (cval_ctyp v1)) | CT_fbits(n, ord), ctyp2, _, CT_lbits _ -> - let smt1 = force_size (lbits_size ctx) n (smt_cval ctx v1) in - let smt2 = force_size (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in + let smt1 = force_size ctx (lbits_size ctx) n (smt_cval ctx v1) in + let smt2 = force_size ctx (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in let smt3 = bvzeint ctx ctx.lbits_index v3 in Fn ("Bits", [smt3; Fn ("bvand", [Fn ("bvlshr", [smt1; smt2]); bvmask ctx smt3])]) @@ -783,7 +830,7 @@ let builtin_get_slice_int ctx v1 v2 v3 ret_ctyp = let in_sz = int_size ctx ctyp in let smt = if in_sz < len + start then - force_size (len + start) in_sz (smt_cval ctx v2) + force_size ctx (len + start) in_sz (smt_cval ctx v2) else smt_cval ctx v2 in @@ -1196,6 +1243,13 @@ let smt_ssanode ctx cfg preds = | Some mux -> [Define_const (zencode_name id, smt_ctyp ctx ctyp, mux)] +let rec get_pathcond ctx fns = + let open Jib_ssa in + match fns with + | [] | Pi [] :: _ -> Bool_lit true + | Pi cvals :: _ -> Fn ("and", List.map (smt_cval ctx) cvals) + | Phi _ :: fns -> get_pathcond ctx fns + (* For any complex l-expression we need to turn it into a read-modify-write in the SMT solver. The SSA transform turns CL_id nodes into CL_rmw (read, write, ctyp) nodes when CL_id is wrapped @@ -1266,13 +1320,24 @@ let smt_instr ctx = | [vec; i; x] -> let sz = int_size ctx (cval_ctyp i) in [define_const ctx id ret_ctyp - (Fn ("store", [smt_cval ctx vec; force_size ~checked:false ctx.vector_index sz (smt_cval ctx i); smt_cval ctx x]))] + (Fn ("store", [smt_cval ctx vec; force_size ~checked:false ctx ctx.vector_index sz (smt_cval ctx i); smt_cval ctx x]))] | _ -> Reporting.unreachable l __POS__ "Bad arguments for internal_vector_update" end - else + else if string_of_id function_id = "sail_assert" then + begin match args with + | [assertion; _] -> + let smt = smt_cval ctx assertion in + add_event ctx Assertion smt; + [] + | _ -> + Reporting.unreachable l __POS__ "Bad arguments for assertion" + end + else if not extern then let smt_args = List.map (smt_cval ctx) args in [define_const ctx id ret_ctyp (Fn (zencode_id function_id, smt_args))] + else + failwith ("Unrecognised function " ^ string_of_id function_id) | I_aux (I_copy (CL_addr (CL_id (_, _)), _), (_, l)) -> Reporting.unreachable l __POS__ "Register reference write should be re-written by now" @@ -1298,22 +1363,26 @@ let smt_instr ctx = [declare_const ctx id ctyp] | I_aux (I_end id, _) -> - if !opt_ignore_overflow then - [Assert (Fn ("not", [Var (zencode_name id)]))] - else - let checks = - Stack.fold (fun checks -> function (Define_const (name, _, _) as def) -> (name, def) :: checks | _ -> assert false) [] overflow_checks - in - List.map snd checks @ [Assert (Fn ("and", Fn ("not", [Var (zencode_name id)]) :: List.map (fun check -> Var (fst check)) checks))] + let checks = + EventMap.bindings !(ctx.events) + |> List.map (fun (ev, stack) -> + (ev, Stack.fold (fun checks -> function (Define_const (name, _, _) as def) -> (Var name, def) :: checks | _ -> assert false) [] stack)) + |> List.map (fun (ev, checks) -> + ((ev, List.map fst checks), List.map snd checks)) + in + List.concat (List.map snd checks) + @ [Assert (List.fold_left event_check (Fn ("not", [Var (zencode_name id)])) (List.map fst checks))] | I_aux (I_clear _, _) -> [] | I_aux (I_match_failure, _) -> [] + | I_aux (I_undefined ctyp, _) -> [] + | instr -> failwith ("Cannot translate: " ^ Pretty_print_sail.to_string (pp_instr instr)) -let smt_cfnode all_cdefs ctx = +let smt_cfnode all_cdefs ctx ssanodes = let open Jib_ssa in function | CF_start inits -> @@ -1325,6 +1394,7 @@ let smt_cfnode all_cdefs ctx = in smt_reg_decs @ List.map smt_start (NameMap.bindings inits) | CF_block instrs -> + let ctx = { ctx with pathcond = get_pathcond ctx ssanodes } in List.concat (List.map (smt_instr ctx) instrs) (* We can ignore any non basic-block/start control-flow nodes *) | _ -> [] @@ -1413,7 +1483,7 @@ let optimize_smt stack = let smt_header ctx cdefs = let smt_ctype_defs = List.concat (generate_ctype_defs ctx cdefs) in [declare_datatypes (mk_enum "Unit" ["unit"])] - @ (IntSet.elements ctx.tuple_sizes |> List.map (fun n -> Declare_tuple n)) + @ (IntSet.elements !(ctx.tuple_sizes) |> List.map (fun n -> Declare_tuple n)) @ [declare_datatypes (mk_record "Bits" [("len", Bitvec ctx.lbits_index); ("contents", Bitvec (lbits_size ctx))]) @@ -1482,8 +1552,7 @@ let smt_cdef props lets name_file ctx all_cdefs = function let stack = Stack.create () in - Stack.clear overflow_checks; - let ctx = { ctx with pragma_l = pragma_l; arg_stack = Stack.create () } in + let ctx = { ctx with events = ref EventMap.empty; pragma_l = pragma_l; arg_stack = Stack.create () } in (* When we create each argument declaration, give it a unique location from the $property pragma, so we can identify it later. *) @@ -1524,7 +1593,7 @@ let smt_cdef props lets name_file ctx all_cdefs = function let muxers = ssanodes |> List.map (smt_ssanode ctx cfg preds) |> List.concat in - let basic_block = smt_cfnode all_cdefs ctx cfnode in + let basic_block = smt_cfnode all_cdefs ctx ssanodes cfnode in push_smt_defs stack muxers; push_smt_defs stack basic_block; end diff --git a/src/smtlib.ml b/src/smtlib.ml index 3ba85306..fe7aee5e 100644 --- a/src/smtlib.ml +++ b/src/smtlib.ml @@ -124,12 +124,16 @@ let bvones n = let simp_fn = function | Fn ("not", [Fn ("not", [exp])]) -> exp + | Fn ("not", [Bool_lit b]) -> Bool_lit (not b) | Fn ("contents", [Fn ("Bits", [_; contents])]) -> contents | Fn ("len", [Fn ("Bits", [len; _])]) -> len + | Fn ("or", [x]) -> x + | Fn ("and", [x]) -> x | exp -> exp let simp_ite = function | Ite (cond, Bool_lit true, Bool_lit false) -> cond + | Ite (cond, Bool_lit x, Bool_lit y) when x = y -> Bool_lit x | Ite (_, Var v, Var v') when v = v' -> Var v | Ite (Bool_lit true, then_exp, _) -> then_exp | Ite (Bool_lit false, _, else_exp) -> else_exp |
