summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlasdair Armstrong2019-04-29 20:52:18 +0100
committerAlasdair Armstrong2019-04-29 20:56:22 +0100
commit44d533ad7e47daa636ee00a256b09fb3df9433ee (patch)
tree3895fb739b96db4657477d51ff13854cf8175c89 /src
parente1d0f39a9803051c646363c95c13c2d4ffb961c7 (diff)
SMT: Refactor overflow checks into generic event checking system
Have assert events for assertions and overflow events for potential integer overflow. Unclear how these should interact... The order in which such events are applied to the final assertion is potentially quite important. Overflow checks and assertions are now path sensitive, as they should be.
Diffstat (limited to 'src')
-rw-r--r--src/jib/jib_smt.ml161
-rw-r--r--src/smtlib.ml4
2 files changed, 119 insertions, 46 deletions
diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml
index 230e9e07..b62e761e 100644
--- a/src/jib/jib_smt.ml
+++ b/src/jib/jib_smt.ml
@@ -66,6 +66,27 @@ let opt_ignore_overflow = ref false
let opt_auto = ref false
+type event = Overflow | Assertion
+
+let event_name = function
+ | Overflow -> "overflow"
+ | Assertion -> "assert"
+
+module Event = struct
+ type t = event
+ let compare e1 e2 =
+ match e1, e2 with
+ | Overflow, Overflow -> 0
+ | Assertion, Assertion -> 0
+ | Overflow, _ -> 1
+ | _, Overflow -> -1
+end
+
+module EventMap = Map.Make(Event)
+
+(* Note that we have to use x : ty ref rather than mutable x : ty, to
+ make sure { ctx with x = ... } doesn't break the mutable state. *)
+
type ctx = {
(* Arbitrary-precision bitvectors are represented as a (BitVec lbits_index, BitVec (2 ^ lbits_index)) pair. *)
lbits_index : int;
@@ -76,11 +97,13 @@ type ctx = {
vector_index : int;
register_map : id list CTMap.t;
(* Keep track of all the tuple sizes we need to genenerate types for *)
- mutable tuple_sizes : IntSet.t;
+ tuple_sizes : IntSet.t ref;
tc_env : Type_check.Env.t;
pragma_l : Ast.l;
arg_stack : (int * string) Stack.t;
- ast : Type_check.tannot defs
+ ast : Type_check.tannot defs;
+ events : smt_def Stack.t EventMap.t ref;
+ pathcond : smt_exp;
}
(* These give the default bounds for various SMT types, stored in the
@@ -96,13 +119,34 @@ let initial_ctx () = {
lint_size = !opt_default_lint_size;
vector_index = !opt_default_vector_index;
register_map = CTMap.empty;
- tuple_sizes = IntSet.empty;
+ tuple_sizes = ref IntSet.empty;
tc_env = Type_check.initial_env;
pragma_l = Parse_ast.Unknown;
arg_stack = Stack.create ();
- ast = Defs []
+ ast = Defs [];
+ events = ref EventMap.empty;
+ pathcond = Bool_lit true;
}
+let event_stack ctx ev =
+ match EventMap.find_opt ev !(ctx.events) with
+ | Some stack -> stack
+ | None ->
+ let stack = Stack.create () in
+ ctx.events := EventMap.add ev stack !(ctx.events);
+ stack
+
+let event_check check (ev, checks) =
+ match ev with
+ | Overflow ->
+ if !opt_ignore_overflow then
+ Bool_lit false
+ else
+ Fn ("and", check :: checks)
+
+ | Assertion ->
+ Fn ("or", check :: List.map (fun c -> Fn ("not", [c])) checks)
+
let lbits_size ctx = Util.power 2 ctx.lbits_index
let vector_index = ref 5
@@ -138,7 +182,7 @@ let rec smt_ctyp ctx = function
| CT_variant (id, ctors) ->
mk_variant (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) ctors)
| CT_tup ctyps ->
- ctx.tuple_sizes <- IntSet.add (List.length ctyps) ctx.tuple_sizes;
+ ctx.tuple_sizes := IntSet.add (List.length ctyps) !(ctx.tuple_sizes);
Tuple (List.map (smt_ctyp ctx) ctyps)
| CT_vector (_, ctyp) -> Array (Bitvec !vector_index, smt_ctyp ctx ctyp)
| CT_string -> Bitvec 64
@@ -210,6 +254,7 @@ let smt_value ctx vl ctyp =
| VL_bit Sail2_values.B0, CT_bit -> Bin "0"
| VL_bit Sail2_values.B1, CT_bit -> Bin "1"
| VL_unit, _ -> Var "unit"
+ | VL_string _, _ -> Var "unit" (* FIXME: String support *)
| vl, _ -> failwith ("Cannot translate literal to SMT " ^ string_of_value vl)
let zencode_ctor ctor_id unifiers =
@@ -260,7 +305,7 @@ let rec smt_cval ctx cval =
| _ -> failwith "Struct does not have struct type"
end
| V_tuple_member (frag, len, n) ->
- ctx.tuple_sizes <- IntSet.add len ctx.tuple_sizes;
+ ctx.tuple_sizes := IntSet.add len !(ctx.tuple_sizes);
Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval ctx frag])
| V_ref (Name (id, _), _) ->
let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in
@@ -276,12 +321,14 @@ let rec smt_cval ctx cval =
end
| cval -> failwith ("Unrecognised cval " ^ string_of_cval ~zencode:false cval)
-let overflow_checks = Stack.create ()
+let add_event ctx ev smt =
+ let stack = event_stack ctx ev in
+ Stack.push (Define_const (event_name ev ^ string_of_int (Stack.length stack), Bool, Fn ("=>", [ctx.pathcond; smt]))) stack
-let overflow_check smt =
+let overflow_check ctx smt =
if not !opt_ignore_overflow then (
Util.warn "Adding overflow check in generated SMT";
- Stack.push (Define_const ("overflow" ^ string_of_int (Stack.length overflow_checks), Bool, Fn ("not", [smt]))) overflow_checks
+ add_event ctx Overflow (Fn ("not", [smt]))
)
(**************************************************************************)
@@ -338,7 +385,7 @@ let builtin_gteq = builtin_int_comparison "bvsge" Big_int.greater_equal
(** [force_size n m exp] takes a smt expression assumed to be a
integer (signed bitvector) of length m and forces it to be length n
by either sign extending it or truncating it as required *)
-let force_size ?checked:(checked=true) n m smt =
+let force_size ?checked:(checked=true) ctx n m smt =
if n = m then
smt
else if n > m then
@@ -352,7 +399,7 @@ let force_size ?checked:(checked=true) n m smt =
(* Otherwise, all the top bits must be zero *)
Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvzero (m - n)])]))
in
- if checked then overflow_check check else ();
+ if checked then overflow_check ctx check else ();
Extract (n - 1, 0, smt)
let int_size ctx = function
@@ -375,17 +422,17 @@ let builtin_arith fn big_int_fn padding ctx v1 v2 ret_ctyp =
| ctyp, CT_constant c, _ ->
let n = int_size ctx ctyp in
- force_size (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c]))
+ force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c]))
| CT_constant c, ctyp, _ ->
let n = int_size ctx ctyp in
- force_size (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2]))
+ force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2]))
| ctyp1, ctyp2, _ ->
let ret_sz = int_size ctx ret_ctyp in
let smt1 = smt_cval ctx v1 in
let smt2 = smt_cval ctx v2 in
- force_size ret_sz (padding ret_sz) (Fn (fn, [force_size (padding ret_sz) (int_size ctx ctyp1) smt1;
- force_size (padding ret_sz) (int_size ctx ctyp2) smt2]))
+ force_size ctx ret_sz (padding ret_sz) (Fn (fn, [force_size ctx (padding ret_sz) (int_size ctx ctyp1) smt1;
+ force_size ctx (padding ret_sz) (int_size ctx ctyp2) smt2]))
let builtin_add_int = builtin_arith "bvadd" Big_int.add (fun x -> x + 1)
let builtin_sub_int = builtin_arith "bvsub" Big_int.sub (fun x -> x + 1)
@@ -398,8 +445,8 @@ let builtin_negate_int ctx v ret_ctyp =
| CT_constant c, _ ->
bvint (int_size ctx ret_ctyp) (Big_int.negate c)
| ctyp, _ ->
- let smt = force_size (int_size ctx ret_ctyp) (int_size ctx ctyp) (smt_cval ctx v) in
- overflow_check (Fn ("=", [smt; Bin ("1" ^ String.make (int_size ctx ret_ctyp - 1) '0')]));
+ let smt = force_size ctx (int_size ctx ret_ctyp) (int_size ctx ctyp) (smt_cval ctx v) in
+ overflow_check ctx (Fn ("=", [smt; Bin ("1" ^ String.make (int_size ctx ret_ctyp - 1) '0')]));
Fn ("bvneg", [smt])
let builtin_shift_int fn big_int_fn ctx v1 v2 ret_ctyp =
@@ -411,17 +458,17 @@ let builtin_shift_int fn big_int_fn ctx v1 v2 ret_ctyp =
| ctyp, CT_constant c, _ ->
let n = int_size ctx ctyp in
- force_size (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c]))
+ force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c]))
| CT_constant c, ctyp, _ ->
let n = int_size ctx ctyp in
- force_size (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2]))
+ force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2]))
| ctyp1, ctyp2, _ ->
let ret_sz = int_size ctx ret_ctyp in
let smt1 = smt_cval ctx v1 in
let smt2 = smt_cval ctx v2 in
- (Fn (fn, [force_size ret_sz (int_size ctx ctyp1) smt1;
- force_size ret_sz (int_size ctx ctyp2) smt2]))
+ (Fn (fn, [force_size ctx ret_sz (int_size ctx ctyp1) smt1;
+ force_size ctx ret_sz (int_size ctx ctyp2) smt2]))
let builtin_shl_int = builtin_shift_int "bvshl" Big_int.shift_left
let builtin_shr_int = builtin_shift_int "bvashr" Big_int.shift_right
@@ -436,8 +483,8 @@ let builtin_abs_int ctx v ret_ctyp =
let sz = int_size ctx ctyp in
let smt = smt_cval ctx v in
Ite (Fn ("=", [Extract (sz - 1, sz -1, smt); Bin "1"]),
- force_size (int_size ctx ret_ctyp) sz (Fn ("bvneg", [smt])),
- force_size (int_size ctx ret_ctyp) sz smt)
+ force_size ctx (int_size ctx ret_ctyp) sz (Fn ("bvneg", [smt])),
+ force_size ctx (int_size ctx ret_ctyp) sz smt)
let builtin_pow2 ctx v ret_ctyp =
match cval_ctyp v, ret_ctyp with
@@ -628,7 +675,7 @@ let builtin_vector_access ctx vec i ret_ctyp =
| CT_vector _, CT_constant i, _ ->
Fn ("select", [smt_cval ctx vec; bvint !vector_index i])
| CT_vector _, index_ctyp, _ ->
- Fn ("select", [smt_cval ctx vec; force_size !vector_index (int_size ctx index_ctyp) (smt_cval ctx i)])
+ Fn ("select", [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i)])
| _ -> builtin_type_error ctx "vector_access" [vec; i] (Some ret_ctyp)
@@ -651,7 +698,7 @@ let builtin_vector_update ctx vec i x ret_ctyp =
| CT_vector _, CT_constant i, ctyp, CT_vector _ ->
Fn ("store", [smt_cval ctx vec; bvint !vector_index i; smt_cval ctx x])
| CT_vector _, index_ctyp, _, CT_vector _ ->
- Fn ("store", [smt_cval ctx vec; force_size !vector_index (int_size ctx index_ctyp) (smt_cval ctx i); smt_cval ctx x])
+ Fn ("store", [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i); smt_cval ctx x])
| _ -> builtin_type_error ctx "vector_update" [vec; i; x] (Some ret_ctyp)
@@ -699,10 +746,10 @@ let builtin_add_bits_int ctx v1 v2 ret_ctyp =
Fn ("bvadd", [smt_cval ctx v1; bvint o c])
| CT_fbits (n, _), CT_fint m, CT_fbits (o, _) when n = o ->
- Fn ("bvadd", [smt_cval ctx v1; force_size o m (smt_cval ctx v2)])
+ Fn ("bvadd", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)])
| CT_fbits (n, _), CT_lint, CT_fbits (o, _) when n = o ->
- Fn ("bvadd", [smt_cval ctx v1; force_size o ctx.lint_size (smt_cval ctx v2)])
+ Fn ("bvadd", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)])
| _ -> builtin_type_error ctx "add_bits_int" [v1; v2] (Some ret_ctyp)
@@ -712,10 +759,10 @@ let builtin_sub_bits_int ctx v1 v2 ret_ctyp =
Fn ("bvadd", [smt_cval ctx v1; bvint o (Big_int.negate c)])
| CT_fbits (n, _), CT_fint m, CT_fbits (o, _) when n = o ->
- Fn ("bvsub", [smt_cval ctx v1; force_size o m (smt_cval ctx v2)])
+ Fn ("bvsub", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)])
| CT_fbits (n, _), CT_lint, CT_fbits (o, _) when n = o ->
- Fn ("bvsub", [smt_cval ctx v1; force_size o ctx.lint_size (smt_cval ctx v2)])
+ Fn ("bvsub", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)])
| _ -> builtin_type_error ctx "sub_bits_int" [v1; v2] (Some ret_ctyp)
@@ -768,8 +815,8 @@ let builtin_slice ctx v1 v2 v3 ret_ctyp =
Extract(Big_int.to_int (Big_int.pred len), 0, builtin_shift "bvlshr" ctx v1 v2 (cval_ctyp v1))
| CT_fbits(n, ord), ctyp2, _, CT_lbits _ ->
- let smt1 = force_size (lbits_size ctx) n (smt_cval ctx v1) in
- let smt2 = force_size (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in
+ let smt1 = force_size ctx (lbits_size ctx) n (smt_cval ctx v1) in
+ let smt2 = force_size ctx (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in
let smt3 = bvzeint ctx ctx.lbits_index v3 in
Fn ("Bits", [smt3; Fn ("bvand", [Fn ("bvlshr", [smt1; smt2]); bvmask ctx smt3])])
@@ -783,7 +830,7 @@ let builtin_get_slice_int ctx v1 v2 v3 ret_ctyp =
let in_sz = int_size ctx ctyp in
let smt =
if in_sz < len + start then
- force_size (len + start) in_sz (smt_cval ctx v2)
+ force_size ctx (len + start) in_sz (smt_cval ctx v2)
else
smt_cval ctx v2
in
@@ -1196,6 +1243,13 @@ let smt_ssanode ctx cfg preds =
| Some mux ->
[Define_const (zencode_name id, smt_ctyp ctx ctyp, mux)]
+let rec get_pathcond ctx fns =
+ let open Jib_ssa in
+ match fns with
+ | [] | Pi [] :: _ -> Bool_lit true
+ | Pi cvals :: _ -> Fn ("and", List.map (smt_cval ctx) cvals)
+ | Phi _ :: fns -> get_pathcond ctx fns
+
(* For any complex l-expression we need to turn it into a
read-modify-write in the SMT solver. The SSA transform turns CL_id
nodes into CL_rmw (read, write, ctyp) nodes when CL_id is wrapped
@@ -1266,13 +1320,24 @@ let smt_instr ctx =
| [vec; i; x] ->
let sz = int_size ctx (cval_ctyp i) in
[define_const ctx id ret_ctyp
- (Fn ("store", [smt_cval ctx vec; force_size ~checked:false ctx.vector_index sz (smt_cval ctx i); smt_cval ctx x]))]
+ (Fn ("store", [smt_cval ctx vec; force_size ~checked:false ctx ctx.vector_index sz (smt_cval ctx i); smt_cval ctx x]))]
| _ ->
Reporting.unreachable l __POS__ "Bad arguments for internal_vector_update"
end
- else
+ else if string_of_id function_id = "sail_assert" then
+ begin match args with
+ | [assertion; _] ->
+ let smt = smt_cval ctx assertion in
+ add_event ctx Assertion smt;
+ []
+ | _ ->
+ Reporting.unreachable l __POS__ "Bad arguments for assertion"
+ end
+ else if not extern then
let smt_args = List.map (smt_cval ctx) args in
[define_const ctx id ret_ctyp (Fn (zencode_id function_id, smt_args))]
+ else
+ failwith ("Unrecognised function " ^ string_of_id function_id)
| I_aux (I_copy (CL_addr (CL_id (_, _)), _), (_, l)) ->
Reporting.unreachable l __POS__ "Register reference write should be re-written by now"
@@ -1298,22 +1363,26 @@ let smt_instr ctx =
[declare_const ctx id ctyp]
| I_aux (I_end id, _) ->
- if !opt_ignore_overflow then
- [Assert (Fn ("not", [Var (zencode_name id)]))]
- else
- let checks =
- Stack.fold (fun checks -> function (Define_const (name, _, _) as def) -> (name, def) :: checks | _ -> assert false) [] overflow_checks
- in
- List.map snd checks @ [Assert (Fn ("and", Fn ("not", [Var (zencode_name id)]) :: List.map (fun check -> Var (fst check)) checks))]
+ let checks =
+ EventMap.bindings !(ctx.events)
+ |> List.map (fun (ev, stack) ->
+ (ev, Stack.fold (fun checks -> function (Define_const (name, _, _) as def) -> (Var name, def) :: checks | _ -> assert false) [] stack))
+ |> List.map (fun (ev, checks) ->
+ ((ev, List.map fst checks), List.map snd checks))
+ in
+ List.concat (List.map snd checks)
+ @ [Assert (List.fold_left event_check (Fn ("not", [Var (zencode_name id)])) (List.map fst checks))]
| I_aux (I_clear _, _) -> []
| I_aux (I_match_failure, _) -> []
+ | I_aux (I_undefined ctyp, _) -> []
+
| instr ->
failwith ("Cannot translate: " ^ Pretty_print_sail.to_string (pp_instr instr))
-let smt_cfnode all_cdefs ctx =
+let smt_cfnode all_cdefs ctx ssanodes =
let open Jib_ssa in
function
| CF_start inits ->
@@ -1325,6 +1394,7 @@ let smt_cfnode all_cdefs ctx =
in
smt_reg_decs @ List.map smt_start (NameMap.bindings inits)
| CF_block instrs ->
+ let ctx = { ctx with pathcond = get_pathcond ctx ssanodes } in
List.concat (List.map (smt_instr ctx) instrs)
(* We can ignore any non basic-block/start control-flow nodes *)
| _ -> []
@@ -1413,7 +1483,7 @@ let optimize_smt stack =
let smt_header ctx cdefs =
let smt_ctype_defs = List.concat (generate_ctype_defs ctx cdefs) in
[declare_datatypes (mk_enum "Unit" ["unit"])]
- @ (IntSet.elements ctx.tuple_sizes |> List.map (fun n -> Declare_tuple n))
+ @ (IntSet.elements !(ctx.tuple_sizes) |> List.map (fun n -> Declare_tuple n))
@ [declare_datatypes (mk_record "Bits" [("len", Bitvec ctx.lbits_index);
("contents", Bitvec (lbits_size ctx))])
@@ -1482,8 +1552,7 @@ let smt_cdef props lets name_file ctx all_cdefs = function
let stack = Stack.create () in
- Stack.clear overflow_checks;
- let ctx = { ctx with pragma_l = pragma_l; arg_stack = Stack.create () } in
+ let ctx = { ctx with events = ref EventMap.empty; pragma_l = pragma_l; arg_stack = Stack.create () } in
(* When we create each argument declaration, give it a unique
location from the $property pragma, so we can identify it later. *)
@@ -1524,7 +1593,7 @@ let smt_cdef props lets name_file ctx all_cdefs = function
let muxers =
ssanodes |> List.map (smt_ssanode ctx cfg preds) |> List.concat
in
- let basic_block = smt_cfnode all_cdefs ctx cfnode in
+ let basic_block = smt_cfnode all_cdefs ctx ssanodes cfnode in
push_smt_defs stack muxers;
push_smt_defs stack basic_block;
end
diff --git a/src/smtlib.ml b/src/smtlib.ml
index 3ba85306..fe7aee5e 100644
--- a/src/smtlib.ml
+++ b/src/smtlib.ml
@@ -124,12 +124,16 @@ let bvones n =
let simp_fn = function
| Fn ("not", [Fn ("not", [exp])]) -> exp
+ | Fn ("not", [Bool_lit b]) -> Bool_lit (not b)
| Fn ("contents", [Fn ("Bits", [_; contents])]) -> contents
| Fn ("len", [Fn ("Bits", [len; _])]) -> len
+ | Fn ("or", [x]) -> x
+ | Fn ("and", [x]) -> x
| exp -> exp
let simp_ite = function
| Ite (cond, Bool_lit true, Bool_lit false) -> cond
+ | Ite (cond, Bool_lit x, Bool_lit y) when x = y -> Bool_lit x
| Ite (_, Var v, Var v') when v = v' -> Var v
| Ite (Bool_lit true, then_exp, _) -> then_exp
| Ite (Bool_lit false, _, else_exp) -> else_exp