summaryrefslogtreecommitdiff
path: root/src/jib/jib_smt.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/jib/jib_smt.ml')
-rw-r--r--src/jib/jib_smt.ml325
1 files changed, 207 insertions, 118 deletions
diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml
index 0d70695b..07cab66b 100644
--- a/src/jib/jib_smt.ml
+++ b/src/jib/jib_smt.ml
@@ -88,6 +88,7 @@ type ctx = {
arg_stack : (int * string) Stack.t;
ast : Type_check.tannot defs;
events : smt_exp Stack.t EventMap.t ref;
+ node : int;
pathcond : smt_exp Lazy.t;
use_string : bool ref;
use_real : bool ref
@@ -112,6 +113,7 @@ let initial_ctx () = {
arg_stack = Stack.create ();
ast = Defs [];
events = ref EventMap.empty;
+ node = -1;
pathcond = lazy (Bool_lit true);
use_string = ref false;
use_real = ref false;
@@ -210,7 +212,7 @@ let bvpint sz x =
let padding_size = sz - String.length bin in
if padding_size < 0 then
raise (Reporting.err_general Parse_ast.Unknown
- (Printf.sprintf "Count not create a %d-bit integer with value %s.\nTry increasing the maximum integer size"
+ (Printf.sprintf "Could not create a %d-bit integer with value %s.\nTry increasing the maximum integer size"
sz (Big_int.to_string x)));
let padding = String.make (sz - String.length bin) '0' in
Bin (padding ^ bin)
@@ -258,65 +260,71 @@ let zencode_ctor ctor_id unifiers =
Util.zencode_string (string_of_id ctor_id ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers)
let rec smt_cval ctx cval =
- match cval with
- | V_lit (vl, ctyp) -> smt_value ctx vl ctyp
- | V_id (Name (id, _) as ssa_id, _) ->
- begin match Type_check.Env.lookup_id id ctx.tc_env with
- | Enum _ -> Enum (zencode_id id)
- | _ -> Var (zencode_name ssa_id)
- end
- | V_id (ssa_id, _) -> Var (zencode_name ssa_id)
- | V_call (Neq, [cval1; cval2]) ->
- Fn ("not", [Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2])])
- | V_call (Bvor, [cval1; cval2]) ->
- Fn ("bvor", [smt_cval ctx cval1; smt_cval ctx cval2])
- | V_call (Bit_to_bool, [cval]) ->
- Fn ("=", [smt_cval ctx cval; Bin "1"])
- | V_call (Bnot, [cval]) ->
- Fn ("not", [smt_cval ctx cval])
- | V_call (Band, cvals) ->
- smt_conj (List.map (smt_cval ctx) cvals)
- | V_call (Bor, cvals) ->
- smt_disj (List.map (smt_cval ctx) cvals)
- | V_ctor_kind (union, ctor_id, unifiers, _) ->
- Fn ("not", [Tester (zencode_ctor ctor_id unifiers, smt_cval ctx union)])
- | V_ctor_unwrap (ctor_id, union, unifiers, _) ->
- Fn ("un" ^ zencode_ctor ctor_id unifiers, [smt_cval ctx union])
- | V_field (union, field) ->
- begin match cval_ctyp union with
- | CT_struct (struct_id, _) ->
- Fn (zencode_upper_id struct_id ^ "_" ^ field, [smt_cval ctx union])
- | _ -> failwith "Field for non-struct type"
- end
- | V_struct (fields, ctyp) ->
- begin match ctyp with
- | CT_struct (struct_id, field_ctyps) ->
- let set_field (field, cval) =
- match Util.assoc_compare_opt Id.compare field field_ctyps with
- | None -> failwith "Field type not found"
- | Some ctyp when ctyp_equal (cval_ctyp cval) ctyp ->
- smt_cval ctx cval
- | _ -> failwith "Type mismatch when generating struct for SMT"
- in
- Fn (zencode_upper_id struct_id, List.map set_field fields)
- | _ -> failwith "Struct does not have struct type"
- end
- | V_tuple_member (frag, len, n) ->
- ctx.tuple_sizes := IntSet.add len !(ctx.tuple_sizes);
- Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval ctx frag])
- | V_ref (Name (id, _), _) ->
- let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in
- assert (CTMap.cardinal rmap = 1);
- begin match CTMap.min_binding_opt rmap with
- | Some (ctyp, regs) ->
- begin match Util.list_index (fun reg -> Id.compare reg id = 0) regs with
- | Some i ->
- bvint (required_width (Big_int.of_int (List.length regs))) (Big_int.of_int i)
- | None -> assert false
+ match cval_ctyp cval with
+ | CT_constant n ->
+ bvint (required_width n) n
+ | _ ->
+ match cval with
+ | V_lit (vl, ctyp) -> smt_value ctx vl ctyp
+ | V_id (Name (id, _) as ssa_id, _) ->
+ begin match Type_check.Env.lookup_id id ctx.tc_env with
+ | Enum _ -> Enum (zencode_id id)
+ | _ -> Var (zencode_name ssa_id)
end
- | _ -> assert false
- end
- | cval -> failwith ("Unrecognised cval " ^ string_of_cval cval)
+ | V_id (ssa_id, _) -> Var (zencode_name ssa_id)
+ | V_call (Neq, [cval1; cval2]) ->
+ Fn ("not", [Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2])])
+ | V_call (Bvor, [cval1; cval2]) ->
+ Fn ("bvor", [smt_cval ctx cval1; smt_cval ctx cval2])
+ | V_call (Bit_to_bool, [cval]) ->
+ Fn ("=", [smt_cval ctx cval; Bin "1"])
+ | V_call (Eq, [cval1; cval2]) ->
+ Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2])
+ | V_call (Bnot, [cval]) ->
+ Fn ("not", [smt_cval ctx cval])
+ | V_call (Band, cvals) ->
+ smt_conj (List.map (smt_cval ctx) cvals)
+ | V_call (Bor, cvals) ->
+ smt_disj (List.map (smt_cval ctx) cvals)
+ | V_ctor_kind (union, ctor_id, unifiers, _) ->
+ Fn ("not", [Tester (zencode_ctor ctor_id unifiers, smt_cval ctx union)])
+ | V_ctor_unwrap (ctor_id, union, unifiers, _) ->
+ Fn ("un" ^ zencode_ctor ctor_id unifiers, [smt_cval ctx union])
+ | V_field (union, field) ->
+ begin match cval_ctyp union with
+ | CT_struct (struct_id, _) ->
+ Fn (zencode_upper_id struct_id ^ "_" ^ field, [smt_cval ctx union])
+ | _ -> failwith "Field for non-struct type"
+ end
+ | V_struct (fields, ctyp) ->
+ begin match ctyp with
+ | CT_struct (struct_id, field_ctyps) ->
+ let set_field (field, cval) =
+ match Util.assoc_compare_opt Id.compare field field_ctyps with
+ | None -> failwith "Field type not found"
+ | Some ctyp when ctyp_equal (cval_ctyp cval) ctyp ->
+ smt_cval ctx cval
+ | _ -> failwith "Type mismatch when generating struct for SMT"
+ in
+ Fn (zencode_upper_id struct_id, List.map set_field fields)
+ | _ -> failwith "Struct does not have struct type"
+ end
+ | V_tuple_member (frag, len, n) ->
+ ctx.tuple_sizes := IntSet.add len !(ctx.tuple_sizes);
+ Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval ctx frag])
+ | V_ref (Name (id, _), _) ->
+ let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in
+ assert (CTMap.cardinal rmap = 1);
+ begin match CTMap.min_binding_opt rmap with
+ | Some (ctyp, regs) ->
+ begin match Util.list_index (fun reg -> Id.compare reg id = 0) regs with
+ | Some i ->
+ bvint (required_width (Big_int.of_int (List.length regs))) (Big_int.of_int i)
+ | None -> assert false
+ end
+ | _ -> assert false
+ end
+ | cval -> failwith ("Unrecognised cval " ^ string_of_cval cval)
let add_event ctx ev smt =
let stack = event_stack ctx ev in
@@ -327,7 +335,7 @@ let add_pathcond_event ctx ev =
let overflow_check ctx smt =
if not !opt_ignore_overflow then (
- Util.warn "Adding overflow check in generated SMT";
+ Reporting.warn "Overflow check in generated SMT for" ctx.pragma_l "";
add_event ctx Overflow smt
)
@@ -382,7 +390,7 @@ let builtin_gteq = builtin_int_comparison "bvsge" Big_int.greater_equal
(* ***** Arithmetic operations: lib/arith.sail ***** *)
-(** [force_size n m exp] takes a smt expression assumed to be a
+(** [force_size ctx n m exp] takes a smt expression assumed to be a
integer (signed bitvector) of length m and forces it to be length n
by either sign extending it or truncating it as required *)
let force_size ?checked:(checked=true) ctx n m smt =
@@ -402,6 +410,16 @@ let force_size ?checked:(checked=true) ctx n m smt =
if checked then overflow_check ctx check else ();
Extract (n - 1, 0, smt)
+(** [unsigned_size ctx n m exp] is much like force_size, but it
+ assumes that the bitvector is unsigned *)
+let unsigned_size ?checked:(checked=true) ctx n m smt =
+ if n = m then
+ smt
+ else if n > m then
+ Fn ("concat", [bvzero (n - m); smt])
+ else
+ failwith "bad arguments to unsigned_size"
+
let int_size ctx = function
| CT_constant n -> required_width n
| CT_fint sz -> sz
@@ -420,13 +438,6 @@ let builtin_arith fn big_int_fn padding ctx v1 v2 ret_ctyp =
| CT_constant c1, CT_constant c2, _ ->
bvint (int_size ctx ret_ctyp) (big_int_fn c1 c2)
- | ctyp, CT_constant c, _ ->
- let n = int_size ctx ctyp in
- force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c]))
- | CT_constant c, ctyp, _ ->
- let n = int_size ctx ctyp in
- force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2]))
-
| ctyp1, ctyp2, _ ->
let ret_sz = int_size ctx ret_ctyp in
let smt1 = smt_cval ctx v1 in
@@ -438,6 +449,12 @@ let builtin_add_int = builtin_arith "bvadd" Big_int.add (fun x -> x + 1)
let builtin_sub_int = builtin_arith "bvsub" Big_int.sub (fun x -> x + 1)
let builtin_mult_int = builtin_arith "bvmul" Big_int.mul (fun x -> x * 2)
+let builtin_sub_nat ctx v1 v2 ret_ctyp =
+ let result = builtin_arith "bvsub" Big_int.sub (fun x -> x + 1) ctx v1 v2 ret_ctyp in
+ Ite (Fn ("bvslt", [result; bvint (int_size ctx ret_ctyp) Big_int.zero]),
+ bvint (int_size ctx ret_ctyp) Big_int.zero,
+ result)
+
let builtin_negate_int ctx v ret_ctyp =
match cval_ctyp v, ret_ctyp with
| _, CT_constant c ->
@@ -493,6 +510,53 @@ let builtin_pow2 ctx v ret_ctyp =
| _ -> builtin_type_error ctx "pow2" [v] (Some ret_ctyp)
+let builtin_max_int ctx v1 v2 ret_ctyp =
+ match cval_ctyp v1, cval_ctyp v2 with
+ | CT_constant n, CT_constant m ->
+ bvint (int_size ctx ret_ctyp) (max n m)
+
+ | ctyp1, ctyp2 ->
+ let ret_sz = int_size ctx ret_ctyp in
+ let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in
+ let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in
+ Ite (Fn ("bvslt", [smt1; smt2]),
+ smt2,
+ smt1)
+
+let builtin_min_int ctx v1 v2 ret_ctyp =
+ match cval_ctyp v1, cval_ctyp v2 with
+ | CT_constant n, CT_constant m ->
+ bvint (int_size ctx ret_ctyp) (min n m)
+
+ | ctyp1, ctyp2 ->
+ let ret_sz = int_size ctx ret_ctyp in
+ let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in
+ let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in
+ Ite (Fn ("bvslt", [smt1; smt2]),
+ smt1,
+ smt2)
+
+let builtin_eq_bits ctx v1 v2 =
+ match cval_ctyp v1, cval_ctyp v2 with
+ | CT_fbits (n, _), CT_fbits (m, _) ->
+ let o = max n m in
+ let smt1 = unsigned_size ctx o n (smt_cval ctx v1) in
+ let smt2 = unsigned_size ctx o n (smt_cval ctx v2) in
+ Fn ("=", [smt1; smt2])
+
+ | CT_lbits _, CT_lbits _ ->
+ Fn ("=", [smt_cval ctx v1; smt_cval ctx v2])
+
+ | CT_lbits _, CT_fbits (n, _) ->
+ let smt2 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v2) in
+ Fn ("=", [Fn ("contents", [smt_cval ctx v1]); smt2])
+
+ | CT_fbits (n, _), CT_lbits _ ->
+ let smt1 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v1) in
+ Fn ("=", [smt1; Fn ("contents", [smt_cval ctx v2])])
+
+ | _ -> builtin_type_error ctx "eq_bits" [v1; v2] None
+
let builtin_zeros ctx v ret_ctyp =
match cval_ctyp v, ret_ctyp with
| _, CT_fbits (n, _) -> bvzero n
@@ -514,7 +578,7 @@ let builtin_ones ctx cval = function
Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; bvones (lbits_size ctx)])]);
| ret_ctyp -> builtin_type_error ctx "ones" [cval] (Some ret_ctyp)
-(* [bvzeint esz cval] (BitVector Zero Extend INTeger), takes a cval
+(* [bvzeint ctx esz cval] (BitVector Zero Extend INTeger), takes a cval
which must be an integer type (either CT_fint, or CT_lint), and
produces a bitvector which is either zero extended or truncated to
exactly esz bits. *)
@@ -773,11 +837,10 @@ let builtin_replicate_bits ctx v1 v2 ret_ctyp =
let smt = smt_cval ctx v1 in
Fn ("concat", List.init (Big_int.to_int c) (fun _ -> smt))
- (*| CT_fbits (n, _), ctyp2, CT_lbits _ ->
- let len = Fn ("bvmul", [bvint ctx.lbits_index (Big_int.of_int n);
- Extract (ctx.lbits_index - 1, 0, smt_cval ctx v2)])
- in
- assert false*)
+ | CT_fbits (n, _), _, CT_fbits (m, _) ->
+ let smt = smt_cval ctx v1 in
+ let c = m / n in
+ Fn ("concat", List.init c (fun _ -> smt))
| _ -> builtin_type_error ctx "replicate_bits" [v1; v2] (Some ret_ctyp)
@@ -791,7 +854,12 @@ let builtin_sail_truncate ctx v1 v2 ret_ctyp =
assert (Big_int.to_int c = m && m < lbits_size ctx);
Extract (Big_int.to_int c - 1, 0, Fn ("contents", [smt_cval ctx v1]))
- | _ -> builtin_type_error ctx "sail_truncate" [v2; v2] (Some ret_ctyp)
+ | CT_fbits (n, _), _, CT_lbits _ ->
+ let smt1 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v1) in
+ let smt2 = bvzeint ctx ctx.lbits_index v2 in
+ Fn ("Bits", [smt2; Fn ("bvand", [bvmask ctx smt2; smt1])])
+
+ | _ -> builtin_type_error ctx "sail_truncate" [v1; v2] (Some ret_ctyp)
let builtin_sail_truncateLSB ctx v1 v2 ret_ctyp =
match cval_ctyp v1, cval_ctyp v2, ret_ctyp with
@@ -927,7 +995,6 @@ let builtin_sqrt_real ctx root v =
let smt_builtin ctx name args ret_ctyp =
match name, args, ret_ctyp with
- | "eq_bits", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2])
| "eq_anything", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2])
(* lib/flow.sail *)
@@ -946,6 +1013,7 @@ let smt_builtin ctx name args ret_ctyp =
(* lib/arith.sail *)
| "add_int", [v1; v2], _ -> builtin_add_int ctx v1 v2 ret_ctyp
| "sub_int", [v1; v2], _ -> builtin_sub_int ctx v1 v2 ret_ctyp
+ | "sub_nat", [v1; v2], _ -> builtin_sub_nat ctx v1 v2 ret_ctyp
| "mult_int", [v1; v2], _ -> builtin_mult_int ctx v1 v2 ret_ctyp
| "neg_int", [v], _ -> builtin_negate_int ctx v ret_ctyp
| "shl_int", [v1; v2], _ -> builtin_shl_int ctx v1 v2 ret_ctyp
@@ -955,6 +1023,9 @@ let smt_builtin ctx name args ret_ctyp =
| "abs_int", [v], _ -> builtin_abs_int ctx v ret_ctyp
| "pow2", [v], _ -> builtin_pow2 ctx v ret_ctyp
+ | "max_int", [v1; v2], _ -> builtin_max_int ctx v1 v2 ret_ctyp
+ | "min_int", [v1; v2], _ -> builtin_min_int ctx v1 v2 ret_ctyp
+
(* All signed and unsigned bitvector comparisons *)
| "slt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvslt" ctx v1 v2 ret_ctyp
| "ult_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvult" ctx v1 v2 ret_ctyp
@@ -966,6 +1037,7 @@ let smt_builtin ctx name args ret_ctyp =
| "ugteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvuge" ctx v1 v2 ret_ctyp
(* lib/vector_dec.sail *)
+ | "eq_bits", [v1; v2], CT_bool -> builtin_eq_bits ctx v1 v2
| "zeros", [v], _ -> builtin_zeros ctx v ret_ctyp
| "sail_zeros", [v], _ -> builtin_zeros ctx v ret_ctyp
| "ones", [v], _ -> builtin_ones ctx v ret_ctyp
@@ -1022,7 +1094,8 @@ let writes = ref (-1)
let builtin_write_mem ctx wk addr_size addr data_size data =
incr writes;
let name = "W" ^ string_of_int !writes in
- [Write_mem (name, smt_cval ctx wk, smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr), smt_cval ctx data, smt_ctyp ctx (cval_ctyp data))],
+ [Write_mem (name, ctx.node, Lazy.force ctx.pathcond, smt_cval ctx wk,
+ smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr), smt_cval ctx data, smt_ctyp ctx (cval_ctyp data))],
Var (name ^ "_ret")
let ea_writes = ref (-1)
@@ -1030,7 +1103,8 @@ let ea_writes = ref (-1)
let builtin_write_mem_ea ctx wk addr_size addr data_size =
incr ea_writes;
let name = "A" ^ string_of_int !ea_writes in
- [Write_mem_ea (name, smt_cval ctx wk, smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr), smt_cval ctx data_size, smt_ctyp ctx (cval_ctyp data_size))],
+ [Write_mem_ea (name, ctx.node, Lazy.force ctx.pathcond, smt_cval ctx wk,
+ smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr), smt_cval ctx data_size, smt_ctyp ctx (cval_ctyp data_size))],
Enum "unit"
let reads = ref (-1)
@@ -1038,15 +1112,16 @@ let reads = ref (-1)
let builtin_read_mem ctx rk addr_size addr data_size ret_ctyp =
incr reads;
let name = "R" ^ string_of_int !reads in
- [Read_mem (name, smt_ctyp ctx ret_ctyp, smt_cval ctx rk, smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr))],
- Var (name ^ "_ret")
+ [Read_mem (name, ctx.node, Lazy.force ctx.pathcond, smt_ctyp ctx ret_ctyp, smt_cval ctx rk,
+ smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr))],
+ Read_res name
let excl_results = ref (-1)
let builtin_excl_res ctx =
incr excl_results;
let name = "E" ^ string_of_int !excl_results in
- [Excl_res name],
+ [Excl_res (name, ctx.node, Lazy.force ctx.pathcond)],
Var (name ^ "_ret")
let barriers = ref (-1)
@@ -1054,7 +1129,7 @@ let barriers = ref (-1)
let builtin_barrier ctx bk =
incr barriers;
let name = "B" ^ string_of_int !barriers in
- [Barrier (name, smt_cval ctx bk)],
+ [Barrier (name, ctx.node, Lazy.force ctx.pathcond, smt_cval ctx bk)],
Enum "unit"
let rec smt_conversion ctx from_ctyp to_ctyp x =
@@ -1064,6 +1139,8 @@ let rec smt_conversion ctx from_ctyp to_ctyp x =
bvint sz c
| CT_constant c, CT_lint ->
bvint ctx.lint_size c
+ | CT_fint sz, CT_lint ->
+ force_size ctx ctx.lint_size sz x
| _, _ -> failwith (Printf.sprintf "Cannot perform conversion from %s to %s" (string_of_ctyp from_ctyp) (string_of_ctyp to_ctyp))
let define_const ctx id ctyp exp = Define_const (zencode_name id, smt_ctyp ctx ctyp, exp)
@@ -1295,8 +1372,8 @@ let smt_ssanode ctx cfg preds =
| Phi (id, ctyp, ids) ->
let get_pi n =
match get_vertex cfg n with
- | Some ((ssanodes, _), _, _) ->
- List.concat (List.map (function Pi guards -> guards | _ -> []) ssanodes)
+ | Some ((ssa_elems, _), _, _) ->
+ List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems)
| None -> failwith "Predecessor node does not exist"
in
let pis = List.map get_pi (IntSet.elements preds) in
@@ -1360,8 +1437,8 @@ let rec get_pathcond n cfg ctx =
let open Jib_ssa in
let get_pi m =
match get_vertex cfg m with
- | Some ((ssanodes, _), _, _) ->
- V_call (Band, List.concat (List.map (function Pi guards -> guards | _ -> []) ssanodes))
+ | Some ((ssa_elems, _), _, _) ->
+ V_call (Band, List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems))
| None -> failwith "Node does not exist"
in
match get_vertex cfg n with
@@ -1533,7 +1610,7 @@ let smt_instr ctx =
end
else if not extern then
let smt_args = List.map (smt_cval ctx) args in
- [define_const ctx id ret_ctyp (Fn (zencode_id function_id, smt_args))]
+ [define_const ctx id ret_ctyp (Ctor (zencode_id function_id, smt_args))]
else
failwith ("Unrecognised function " ^ string_of_id function_id)
@@ -1569,7 +1646,7 @@ let smt_instr ctx =
| instr ->
failwith ("Cannot translate: " ^ Pretty_print_sail.to_string (pp_instr instr))
-let smt_cfnode all_cdefs ctx ssanodes =
+let smt_cfnode all_cdefs ctx ssa_elems =
let open Jib_ssa in
function
| CF_start inits ->
@@ -1619,8 +1696,8 @@ module Make_optimizer(S : Sequence) = struct
| Some n -> Hashtbl.replace uses var (n + 1)
| None -> Hashtbl.add uses var 1
end
- | Enum _ | Hex _ | Bin _ | Bool_lit _ | String_lit _ | Real_lit _ -> ()
- | Fn (f, exps) ->
+ | Enum _ | Read_res _ | Hex _ | Bin _ | Bool_lit _ | String_lit _ | Real_lit _ -> ()
+ | Fn (_, exps) | Ctor (_, exps) ->
List.iter uses_in_exp exps
| Ite (cond, t, e) ->
uses_in_exp cond; uses_in_exp t; uses_in_exp e
@@ -1644,19 +1721,20 @@ module Make_optimizer(S : Sequence) = struct
end
| (Declare_datatypes _ | Declare_tuple _) as def ->
Stack.push def stack'
- | Write_mem (_, wk, addr, _, data, _) as def ->
- uses_in_exp wk; uses_in_exp addr; uses_in_exp data;
+ | Write_mem (_, _, active, wk, addr, _, data, _) as def ->
+ uses_in_exp active; uses_in_exp wk; uses_in_exp addr; uses_in_exp data;
Stack.push def stack'
- | Write_mem_ea (_, wk, addr, _, data_size, _) as def ->
- uses_in_exp wk; uses_in_exp addr; uses_in_exp data_size;
+ | Write_mem_ea (_, _, active, wk, addr, _, data_size, _) as def ->
+ uses_in_exp active; uses_in_exp wk; uses_in_exp addr; uses_in_exp data_size;
Stack.push def stack'
- | Read_mem (_, _, rk, addr, _) as def ->
- uses_in_exp rk; uses_in_exp addr;
+ | Read_mem (_, _, active, _, rk, addr, _) as def ->
+ uses_in_exp active; uses_in_exp rk; uses_in_exp addr;
Stack.push def stack'
- | Barrier (_, bk) as def ->
- uses_in_exp bk;
+ | Barrier (_, _, active, bk) as def ->
+ uses_in_exp active; uses_in_exp bk;
Stack.push def stack'
- | Excl_res _ as def ->
+ | Excl_res (_, _, active) as def ->
+ uses_in_exp active;
Stack.push def stack'
| Assert exp as def ->
uses_in_exp exp;
@@ -1666,35 +1744,46 @@ module Make_optimizer(S : Sequence) = struct
Stack.fold remove_unused () stack;
let vars = Hashtbl.create (Stack.length stack') in
+ let kinds = Hashtbl.create (Stack.length stack') in
let seq = S.create () in
let constant_propagate = function
| Declare_const _ as def ->
S.add def seq
| Define_const (var, typ, exp) ->
- begin match Hashtbl.find_opt uses var, simp_smt_exp vars exp with
+ let exp = simp_smt_exp vars kinds exp in
+ begin match Hashtbl.find_opt uses var, simp_smt_exp vars kinds exp with
| _, (Bin _ | Bool_lit _) ->
Hashtbl.add vars var exp
| _, Var _ when !opt_propagate_vars ->
Hashtbl.add vars var exp
+ | _, (Ctor (str, _)) ->
+ Hashtbl.add kinds var str;
+ S.add (Define_const (var, typ, exp)) seq
| Some 1, _ ->
Hashtbl.add vars var exp
| Some _, exp ->
S.add (Define_const (var, typ, exp)) seq
| None, _ -> assert false
end
- | Write_mem (name, wk, addr, addr_ty, data, data_ty) ->
- S.add (Write_mem (name, simp_smt_exp vars wk, simp_smt_exp vars addr, addr_ty, simp_smt_exp vars data, data_ty)) seq
- | Write_mem_ea (name, wk, addr, addr_ty, data_size, data_size_ty) ->
- S.add (Write_mem_ea (name, simp_smt_exp vars wk, simp_smt_exp vars addr, addr_ty, simp_smt_exp vars data_size, data_size_ty)) seq
- | Read_mem (name, typ, rk, addr, addr_typ) ->
- S.add (Read_mem (name, typ, simp_smt_exp vars rk, simp_smt_exp vars addr, addr_typ)) seq
- | Barrier (name, bk) ->
- S.add (Barrier (name, simp_smt_exp vars bk)) seq
- | Excl_res name ->
- S.add (Excl_res name) seq
+ | Write_mem (name, node, active, wk, addr, addr_ty, data, data_ty) ->
+ S.add (Write_mem (name, node, simp_smt_exp vars kinds active, simp_smt_exp vars kinds wk,
+ simp_smt_exp vars kinds addr, addr_ty, simp_smt_exp vars kinds data, data_ty))
+ seq
+ | Write_mem_ea (name, node, active, wk, addr, addr_ty, data_size, data_size_ty) ->
+ S.add (Write_mem_ea (name, node, simp_smt_exp vars kinds active, simp_smt_exp vars kinds wk,
+ simp_smt_exp vars kinds addr, addr_ty, simp_smt_exp vars kinds data_size, data_size_ty))
+ seq
+ | Read_mem (name, node, active, typ, rk, addr, addr_typ) ->
+ S.add (Read_mem (name, node, simp_smt_exp vars kinds active, typ, simp_smt_exp vars kinds rk,
+ simp_smt_exp vars kinds addr, addr_typ))
+ seq
+ | Barrier (name, node, active, bk) ->
+ S.add (Barrier (name, node, simp_smt_exp vars kinds active, simp_smt_exp vars kinds bk)) seq
+ | Excl_res (name, node, active) ->
+ S.add (Excl_res (name, node, simp_smt_exp vars kinds active)) seq
| Assert exp ->
- S.add (Assert (simp_smt_exp vars exp)) seq
+ S.add (Assert (simp_smt_exp vars kinds exp)) seq
| (Declare_datatypes _ | Declare_tuple _) as def ->
S.add def seq
| Define_fun _ -> assert false
@@ -1815,18 +1904,18 @@ let smt_instr_list name ctx all_cdefs instrs =
List.iter (fun n ->
begin match get_vertex cfg n with
| None -> ()
- | Some ((ssanodes, cfnode), preds, succs) ->
+ | Some ((ssa_elems, cfnode), preds, succs) ->
let muxers =
- ssanodes |> List.map (smt_ssanode ctx cfg preds) |> List.concat
+ ssa_elems |> List.map (smt_ssanode ctx cfg preds) |> List.concat
in
- let ctx = { ctx with pathcond = lazy (get_pathcond n cfg ctx) } in
- let basic_block = smt_cfnode all_cdefs ctx ssanodes cfnode in
+ let ctx = { ctx with node = n; pathcond = lazy (get_pathcond n cfg ctx) } in
+ let basic_block = smt_cfnode all_cdefs ctx ssa_elems cfnode in
push_smt_defs stack muxers;
- push_smt_defs stack basic_block;
+ push_smt_defs stack basic_block
end
) visit_order;
- stack
+ stack, cfg
let smt_cdef props lets name_file ctx all_cdefs = function
| CDEF_spec (function_id, arg_ctyps, ret_ctyp) when Bindings.mem function_id props ->
@@ -1853,7 +1942,7 @@ let smt_cdef props lets name_file ctx all_cdefs = function
|> remove_pointless_goto
in
- let stack = smt_instr_list (string_of_id function_id) ctx all_cdefs instrs in
+ let stack, _ = smt_instr_list (string_of_id function_id) ctx all_cdefs instrs in
let query = smt_query ctx pragma.query in
push_smt_defs stack [Assert (Fn ("not", [query]))];