diff options
Diffstat (limited to 'src/jib/jib_smt.ml')
| -rw-r--r-- | src/jib/jib_smt.ml | 325 |
1 files changed, 207 insertions, 118 deletions
diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml index 0d70695b..07cab66b 100644 --- a/src/jib/jib_smt.ml +++ b/src/jib/jib_smt.ml @@ -88,6 +88,7 @@ type ctx = { arg_stack : (int * string) Stack.t; ast : Type_check.tannot defs; events : smt_exp Stack.t EventMap.t ref; + node : int; pathcond : smt_exp Lazy.t; use_string : bool ref; use_real : bool ref @@ -112,6 +113,7 @@ let initial_ctx () = { arg_stack = Stack.create (); ast = Defs []; events = ref EventMap.empty; + node = -1; pathcond = lazy (Bool_lit true); use_string = ref false; use_real = ref false; @@ -210,7 +212,7 @@ let bvpint sz x = let padding_size = sz - String.length bin in if padding_size < 0 then raise (Reporting.err_general Parse_ast.Unknown - (Printf.sprintf "Count not create a %d-bit integer with value %s.\nTry increasing the maximum integer size" + (Printf.sprintf "Could not create a %d-bit integer with value %s.\nTry increasing the maximum integer size" sz (Big_int.to_string x))); let padding = String.make (sz - String.length bin) '0' in Bin (padding ^ bin) @@ -258,65 +260,71 @@ let zencode_ctor ctor_id unifiers = Util.zencode_string (string_of_id ctor_id ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers) let rec smt_cval ctx cval = - match cval with - | V_lit (vl, ctyp) -> smt_value ctx vl ctyp - | V_id (Name (id, _) as ssa_id, _) -> - begin match Type_check.Env.lookup_id id ctx.tc_env with - | Enum _ -> Enum (zencode_id id) - | _ -> Var (zencode_name ssa_id) - end - | V_id (ssa_id, _) -> Var (zencode_name ssa_id) - | V_call (Neq, [cval1; cval2]) -> - Fn ("not", [Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2])]) - | V_call (Bvor, [cval1; cval2]) -> - Fn ("bvor", [smt_cval ctx cval1; smt_cval ctx cval2]) - | V_call (Bit_to_bool, [cval]) -> - Fn ("=", [smt_cval ctx cval; Bin "1"]) - | V_call (Bnot, [cval]) -> - Fn ("not", [smt_cval ctx cval]) - | V_call (Band, cvals) -> - smt_conj (List.map (smt_cval ctx) cvals) - | V_call (Bor, cvals) -> - smt_disj (List.map (smt_cval ctx) cvals) - | V_ctor_kind (union, ctor_id, unifiers, _) -> - Fn ("not", [Tester (zencode_ctor ctor_id unifiers, smt_cval ctx union)]) - | V_ctor_unwrap (ctor_id, union, unifiers, _) -> - Fn ("un" ^ zencode_ctor ctor_id unifiers, [smt_cval ctx union]) - | V_field (union, field) -> - begin match cval_ctyp union with - | CT_struct (struct_id, _) -> - Fn (zencode_upper_id struct_id ^ "_" ^ field, [smt_cval ctx union]) - | _ -> failwith "Field for non-struct type" - end - | V_struct (fields, ctyp) -> - begin match ctyp with - | CT_struct (struct_id, field_ctyps) -> - let set_field (field, cval) = - match Util.assoc_compare_opt Id.compare field field_ctyps with - | None -> failwith "Field type not found" - | Some ctyp when ctyp_equal (cval_ctyp cval) ctyp -> - smt_cval ctx cval - | _ -> failwith "Type mismatch when generating struct for SMT" - in - Fn (zencode_upper_id struct_id, List.map set_field fields) - | _ -> failwith "Struct does not have struct type" - end - | V_tuple_member (frag, len, n) -> - ctx.tuple_sizes := IntSet.add len !(ctx.tuple_sizes); - Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval ctx frag]) - | V_ref (Name (id, _), _) -> - let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in - assert (CTMap.cardinal rmap = 1); - begin match CTMap.min_binding_opt rmap with - | Some (ctyp, regs) -> - begin match Util.list_index (fun reg -> Id.compare reg id = 0) regs with - | Some i -> - bvint (required_width (Big_int.of_int (List.length regs))) (Big_int.of_int i) - | None -> assert false + match cval_ctyp cval with + | CT_constant n -> + bvint (required_width n) n + | _ -> + match cval with + | V_lit (vl, ctyp) -> smt_value ctx vl ctyp + | V_id (Name (id, _) as ssa_id, _) -> + begin match Type_check.Env.lookup_id id ctx.tc_env with + | Enum _ -> Enum (zencode_id id) + | _ -> Var (zencode_name ssa_id) end - | _ -> assert false - end - | cval -> failwith ("Unrecognised cval " ^ string_of_cval cval) + | V_id (ssa_id, _) -> Var (zencode_name ssa_id) + | V_call (Neq, [cval1; cval2]) -> + Fn ("not", [Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2])]) + | V_call (Bvor, [cval1; cval2]) -> + Fn ("bvor", [smt_cval ctx cval1; smt_cval ctx cval2]) + | V_call (Bit_to_bool, [cval]) -> + Fn ("=", [smt_cval ctx cval; Bin "1"]) + | V_call (Eq, [cval1; cval2]) -> + Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2]) + | V_call (Bnot, [cval]) -> + Fn ("not", [smt_cval ctx cval]) + | V_call (Band, cvals) -> + smt_conj (List.map (smt_cval ctx) cvals) + | V_call (Bor, cvals) -> + smt_disj (List.map (smt_cval ctx) cvals) + | V_ctor_kind (union, ctor_id, unifiers, _) -> + Fn ("not", [Tester (zencode_ctor ctor_id unifiers, smt_cval ctx union)]) + | V_ctor_unwrap (ctor_id, union, unifiers, _) -> + Fn ("un" ^ zencode_ctor ctor_id unifiers, [smt_cval ctx union]) + | V_field (union, field) -> + begin match cval_ctyp union with + | CT_struct (struct_id, _) -> + Fn (zencode_upper_id struct_id ^ "_" ^ field, [smt_cval ctx union]) + | _ -> failwith "Field for non-struct type" + end + | V_struct (fields, ctyp) -> + begin match ctyp with + | CT_struct (struct_id, field_ctyps) -> + let set_field (field, cval) = + match Util.assoc_compare_opt Id.compare field field_ctyps with + | None -> failwith "Field type not found" + | Some ctyp when ctyp_equal (cval_ctyp cval) ctyp -> + smt_cval ctx cval + | _ -> failwith "Type mismatch when generating struct for SMT" + in + Fn (zencode_upper_id struct_id, List.map set_field fields) + | _ -> failwith "Struct does not have struct type" + end + | V_tuple_member (frag, len, n) -> + ctx.tuple_sizes := IntSet.add len !(ctx.tuple_sizes); + Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval ctx frag]) + | V_ref (Name (id, _), _) -> + let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in + assert (CTMap.cardinal rmap = 1); + begin match CTMap.min_binding_opt rmap with + | Some (ctyp, regs) -> + begin match Util.list_index (fun reg -> Id.compare reg id = 0) regs with + | Some i -> + bvint (required_width (Big_int.of_int (List.length regs))) (Big_int.of_int i) + | None -> assert false + end + | _ -> assert false + end + | cval -> failwith ("Unrecognised cval " ^ string_of_cval cval) let add_event ctx ev smt = let stack = event_stack ctx ev in @@ -327,7 +335,7 @@ let add_pathcond_event ctx ev = let overflow_check ctx smt = if not !opt_ignore_overflow then ( - Util.warn "Adding overflow check in generated SMT"; + Reporting.warn "Overflow check in generated SMT for" ctx.pragma_l ""; add_event ctx Overflow smt ) @@ -382,7 +390,7 @@ let builtin_gteq = builtin_int_comparison "bvsge" Big_int.greater_equal (* ***** Arithmetic operations: lib/arith.sail ***** *) -(** [force_size n m exp] takes a smt expression assumed to be a +(** [force_size ctx n m exp] takes a smt expression assumed to be a integer (signed bitvector) of length m and forces it to be length n by either sign extending it or truncating it as required *) let force_size ?checked:(checked=true) ctx n m smt = @@ -402,6 +410,16 @@ let force_size ?checked:(checked=true) ctx n m smt = if checked then overflow_check ctx check else (); Extract (n - 1, 0, smt) +(** [unsigned_size ctx n m exp] is much like force_size, but it + assumes that the bitvector is unsigned *) +let unsigned_size ?checked:(checked=true) ctx n m smt = + if n = m then + smt + else if n > m then + Fn ("concat", [bvzero (n - m); smt]) + else + failwith "bad arguments to unsigned_size" + let int_size ctx = function | CT_constant n -> required_width n | CT_fint sz -> sz @@ -420,13 +438,6 @@ let builtin_arith fn big_int_fn padding ctx v1 v2 ret_ctyp = | CT_constant c1, CT_constant c2, _ -> bvint (int_size ctx ret_ctyp) (big_int_fn c1 c2) - | ctyp, CT_constant c, _ -> - let n = int_size ctx ctyp in - force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c])) - | CT_constant c, ctyp, _ -> - let n = int_size ctx ctyp in - force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2])) - | ctyp1, ctyp2, _ -> let ret_sz = int_size ctx ret_ctyp in let smt1 = smt_cval ctx v1 in @@ -438,6 +449,12 @@ let builtin_add_int = builtin_arith "bvadd" Big_int.add (fun x -> x + 1) let builtin_sub_int = builtin_arith "bvsub" Big_int.sub (fun x -> x + 1) let builtin_mult_int = builtin_arith "bvmul" Big_int.mul (fun x -> x * 2) +let builtin_sub_nat ctx v1 v2 ret_ctyp = + let result = builtin_arith "bvsub" Big_int.sub (fun x -> x + 1) ctx v1 v2 ret_ctyp in + Ite (Fn ("bvslt", [result; bvint (int_size ctx ret_ctyp) Big_int.zero]), + bvint (int_size ctx ret_ctyp) Big_int.zero, + result) + let builtin_negate_int ctx v ret_ctyp = match cval_ctyp v, ret_ctyp with | _, CT_constant c -> @@ -493,6 +510,53 @@ let builtin_pow2 ctx v ret_ctyp = | _ -> builtin_type_error ctx "pow2" [v] (Some ret_ctyp) +let builtin_max_int ctx v1 v2 ret_ctyp = + match cval_ctyp v1, cval_ctyp v2 with + | CT_constant n, CT_constant m -> + bvint (int_size ctx ret_ctyp) (max n m) + + | ctyp1, ctyp2 -> + let ret_sz = int_size ctx ret_ctyp in + let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in + let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in + Ite (Fn ("bvslt", [smt1; smt2]), + smt2, + smt1) + +let builtin_min_int ctx v1 v2 ret_ctyp = + match cval_ctyp v1, cval_ctyp v2 with + | CT_constant n, CT_constant m -> + bvint (int_size ctx ret_ctyp) (min n m) + + | ctyp1, ctyp2 -> + let ret_sz = int_size ctx ret_ctyp in + let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in + let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in + Ite (Fn ("bvslt", [smt1; smt2]), + smt1, + smt2) + +let builtin_eq_bits ctx v1 v2 = + match cval_ctyp v1, cval_ctyp v2 with + | CT_fbits (n, _), CT_fbits (m, _) -> + let o = max n m in + let smt1 = unsigned_size ctx o n (smt_cval ctx v1) in + let smt2 = unsigned_size ctx o n (smt_cval ctx v2) in + Fn ("=", [smt1; smt2]) + + | CT_lbits _, CT_lbits _ -> + Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) + + | CT_lbits _, CT_fbits (n, _) -> + let smt2 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v2) in + Fn ("=", [Fn ("contents", [smt_cval ctx v1]); smt2]) + + | CT_fbits (n, _), CT_lbits _ -> + let smt1 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v1) in + Fn ("=", [smt1; Fn ("contents", [smt_cval ctx v2])]) + + | _ -> builtin_type_error ctx "eq_bits" [v1; v2] None + let builtin_zeros ctx v ret_ctyp = match cval_ctyp v, ret_ctyp with | _, CT_fbits (n, _) -> bvzero n @@ -514,7 +578,7 @@ let builtin_ones ctx cval = function Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; bvones (lbits_size ctx)])]); | ret_ctyp -> builtin_type_error ctx "ones" [cval] (Some ret_ctyp) -(* [bvzeint esz cval] (BitVector Zero Extend INTeger), takes a cval +(* [bvzeint ctx esz cval] (BitVector Zero Extend INTeger), takes a cval which must be an integer type (either CT_fint, or CT_lint), and produces a bitvector which is either zero extended or truncated to exactly esz bits. *) @@ -773,11 +837,10 @@ let builtin_replicate_bits ctx v1 v2 ret_ctyp = let smt = smt_cval ctx v1 in Fn ("concat", List.init (Big_int.to_int c) (fun _ -> smt)) - (*| CT_fbits (n, _), ctyp2, CT_lbits _ -> - let len = Fn ("bvmul", [bvint ctx.lbits_index (Big_int.of_int n); - Extract (ctx.lbits_index - 1, 0, smt_cval ctx v2)]) - in - assert false*) + | CT_fbits (n, _), _, CT_fbits (m, _) -> + let smt = smt_cval ctx v1 in + let c = m / n in + Fn ("concat", List.init c (fun _ -> smt)) | _ -> builtin_type_error ctx "replicate_bits" [v1; v2] (Some ret_ctyp) @@ -791,7 +854,12 @@ let builtin_sail_truncate ctx v1 v2 ret_ctyp = assert (Big_int.to_int c = m && m < lbits_size ctx); Extract (Big_int.to_int c - 1, 0, Fn ("contents", [smt_cval ctx v1])) - | _ -> builtin_type_error ctx "sail_truncate" [v2; v2] (Some ret_ctyp) + | CT_fbits (n, _), _, CT_lbits _ -> + let smt1 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v1) in + let smt2 = bvzeint ctx ctx.lbits_index v2 in + Fn ("Bits", [smt2; Fn ("bvand", [bvmask ctx smt2; smt1])]) + + | _ -> builtin_type_error ctx "sail_truncate" [v1; v2] (Some ret_ctyp) let builtin_sail_truncateLSB ctx v1 v2 ret_ctyp = match cval_ctyp v1, cval_ctyp v2, ret_ctyp with @@ -927,7 +995,6 @@ let builtin_sqrt_real ctx root v = let smt_builtin ctx name args ret_ctyp = match name, args, ret_ctyp with - | "eq_bits", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) | "eq_anything", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) (* lib/flow.sail *) @@ -946,6 +1013,7 @@ let smt_builtin ctx name args ret_ctyp = (* lib/arith.sail *) | "add_int", [v1; v2], _ -> builtin_add_int ctx v1 v2 ret_ctyp | "sub_int", [v1; v2], _ -> builtin_sub_int ctx v1 v2 ret_ctyp + | "sub_nat", [v1; v2], _ -> builtin_sub_nat ctx v1 v2 ret_ctyp | "mult_int", [v1; v2], _ -> builtin_mult_int ctx v1 v2 ret_ctyp | "neg_int", [v], _ -> builtin_negate_int ctx v ret_ctyp | "shl_int", [v1; v2], _ -> builtin_shl_int ctx v1 v2 ret_ctyp @@ -955,6 +1023,9 @@ let smt_builtin ctx name args ret_ctyp = | "abs_int", [v], _ -> builtin_abs_int ctx v ret_ctyp | "pow2", [v], _ -> builtin_pow2 ctx v ret_ctyp + | "max_int", [v1; v2], _ -> builtin_max_int ctx v1 v2 ret_ctyp + | "min_int", [v1; v2], _ -> builtin_min_int ctx v1 v2 ret_ctyp + (* All signed and unsigned bitvector comparisons *) | "slt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvslt" ctx v1 v2 ret_ctyp | "ult_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvult" ctx v1 v2 ret_ctyp @@ -966,6 +1037,7 @@ let smt_builtin ctx name args ret_ctyp = | "ugteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvuge" ctx v1 v2 ret_ctyp (* lib/vector_dec.sail *) + | "eq_bits", [v1; v2], CT_bool -> builtin_eq_bits ctx v1 v2 | "zeros", [v], _ -> builtin_zeros ctx v ret_ctyp | "sail_zeros", [v], _ -> builtin_zeros ctx v ret_ctyp | "ones", [v], _ -> builtin_ones ctx v ret_ctyp @@ -1022,7 +1094,8 @@ let writes = ref (-1) let builtin_write_mem ctx wk addr_size addr data_size data = incr writes; let name = "W" ^ string_of_int !writes in - [Write_mem (name, smt_cval ctx wk, smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr), smt_cval ctx data, smt_ctyp ctx (cval_ctyp data))], + [Write_mem (name, ctx.node, Lazy.force ctx.pathcond, smt_cval ctx wk, + smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr), smt_cval ctx data, smt_ctyp ctx (cval_ctyp data))], Var (name ^ "_ret") let ea_writes = ref (-1) @@ -1030,7 +1103,8 @@ let ea_writes = ref (-1) let builtin_write_mem_ea ctx wk addr_size addr data_size = incr ea_writes; let name = "A" ^ string_of_int !ea_writes in - [Write_mem_ea (name, smt_cval ctx wk, smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr), smt_cval ctx data_size, smt_ctyp ctx (cval_ctyp data_size))], + [Write_mem_ea (name, ctx.node, Lazy.force ctx.pathcond, smt_cval ctx wk, + smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr), smt_cval ctx data_size, smt_ctyp ctx (cval_ctyp data_size))], Enum "unit" let reads = ref (-1) @@ -1038,15 +1112,16 @@ let reads = ref (-1) let builtin_read_mem ctx rk addr_size addr data_size ret_ctyp = incr reads; let name = "R" ^ string_of_int !reads in - [Read_mem (name, smt_ctyp ctx ret_ctyp, smt_cval ctx rk, smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr))], - Var (name ^ "_ret") + [Read_mem (name, ctx.node, Lazy.force ctx.pathcond, smt_ctyp ctx ret_ctyp, smt_cval ctx rk, + smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr))], + Read_res name let excl_results = ref (-1) let builtin_excl_res ctx = incr excl_results; let name = "E" ^ string_of_int !excl_results in - [Excl_res name], + [Excl_res (name, ctx.node, Lazy.force ctx.pathcond)], Var (name ^ "_ret") let barriers = ref (-1) @@ -1054,7 +1129,7 @@ let barriers = ref (-1) let builtin_barrier ctx bk = incr barriers; let name = "B" ^ string_of_int !barriers in - [Barrier (name, smt_cval ctx bk)], + [Barrier (name, ctx.node, Lazy.force ctx.pathcond, smt_cval ctx bk)], Enum "unit" let rec smt_conversion ctx from_ctyp to_ctyp x = @@ -1064,6 +1139,8 @@ let rec smt_conversion ctx from_ctyp to_ctyp x = bvint sz c | CT_constant c, CT_lint -> bvint ctx.lint_size c + | CT_fint sz, CT_lint -> + force_size ctx ctx.lint_size sz x | _, _ -> failwith (Printf.sprintf "Cannot perform conversion from %s to %s" (string_of_ctyp from_ctyp) (string_of_ctyp to_ctyp)) let define_const ctx id ctyp exp = Define_const (zencode_name id, smt_ctyp ctx ctyp, exp) @@ -1295,8 +1372,8 @@ let smt_ssanode ctx cfg preds = | Phi (id, ctyp, ids) -> let get_pi n = match get_vertex cfg n with - | Some ((ssanodes, _), _, _) -> - List.concat (List.map (function Pi guards -> guards | _ -> []) ssanodes) + | Some ((ssa_elems, _), _, _) -> + List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems) | None -> failwith "Predecessor node does not exist" in let pis = List.map get_pi (IntSet.elements preds) in @@ -1360,8 +1437,8 @@ let rec get_pathcond n cfg ctx = let open Jib_ssa in let get_pi m = match get_vertex cfg m with - | Some ((ssanodes, _), _, _) -> - V_call (Band, List.concat (List.map (function Pi guards -> guards | _ -> []) ssanodes)) + | Some ((ssa_elems, _), _, _) -> + V_call (Band, List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems)) | None -> failwith "Node does not exist" in match get_vertex cfg n with @@ -1533,7 +1610,7 @@ let smt_instr ctx = end else if not extern then let smt_args = List.map (smt_cval ctx) args in - [define_const ctx id ret_ctyp (Fn (zencode_id function_id, smt_args))] + [define_const ctx id ret_ctyp (Ctor (zencode_id function_id, smt_args))] else failwith ("Unrecognised function " ^ string_of_id function_id) @@ -1569,7 +1646,7 @@ let smt_instr ctx = | instr -> failwith ("Cannot translate: " ^ Pretty_print_sail.to_string (pp_instr instr)) -let smt_cfnode all_cdefs ctx ssanodes = +let smt_cfnode all_cdefs ctx ssa_elems = let open Jib_ssa in function | CF_start inits -> @@ -1619,8 +1696,8 @@ module Make_optimizer(S : Sequence) = struct | Some n -> Hashtbl.replace uses var (n + 1) | None -> Hashtbl.add uses var 1 end - | Enum _ | Hex _ | Bin _ | Bool_lit _ | String_lit _ | Real_lit _ -> () - | Fn (f, exps) -> + | Enum _ | Read_res _ | Hex _ | Bin _ | Bool_lit _ | String_lit _ | Real_lit _ -> () + | Fn (_, exps) | Ctor (_, exps) -> List.iter uses_in_exp exps | Ite (cond, t, e) -> uses_in_exp cond; uses_in_exp t; uses_in_exp e @@ -1644,19 +1721,20 @@ module Make_optimizer(S : Sequence) = struct end | (Declare_datatypes _ | Declare_tuple _) as def -> Stack.push def stack' - | Write_mem (_, wk, addr, _, data, _) as def -> - uses_in_exp wk; uses_in_exp addr; uses_in_exp data; + | Write_mem (_, _, active, wk, addr, _, data, _) as def -> + uses_in_exp active; uses_in_exp wk; uses_in_exp addr; uses_in_exp data; Stack.push def stack' - | Write_mem_ea (_, wk, addr, _, data_size, _) as def -> - uses_in_exp wk; uses_in_exp addr; uses_in_exp data_size; + | Write_mem_ea (_, _, active, wk, addr, _, data_size, _) as def -> + uses_in_exp active; uses_in_exp wk; uses_in_exp addr; uses_in_exp data_size; Stack.push def stack' - | Read_mem (_, _, rk, addr, _) as def -> - uses_in_exp rk; uses_in_exp addr; + | Read_mem (_, _, active, _, rk, addr, _) as def -> + uses_in_exp active; uses_in_exp rk; uses_in_exp addr; Stack.push def stack' - | Barrier (_, bk) as def -> - uses_in_exp bk; + | Barrier (_, _, active, bk) as def -> + uses_in_exp active; uses_in_exp bk; Stack.push def stack' - | Excl_res _ as def -> + | Excl_res (_, _, active) as def -> + uses_in_exp active; Stack.push def stack' | Assert exp as def -> uses_in_exp exp; @@ -1666,35 +1744,46 @@ module Make_optimizer(S : Sequence) = struct Stack.fold remove_unused () stack; let vars = Hashtbl.create (Stack.length stack') in + let kinds = Hashtbl.create (Stack.length stack') in let seq = S.create () in let constant_propagate = function | Declare_const _ as def -> S.add def seq | Define_const (var, typ, exp) -> - begin match Hashtbl.find_opt uses var, simp_smt_exp vars exp with + let exp = simp_smt_exp vars kinds exp in + begin match Hashtbl.find_opt uses var, simp_smt_exp vars kinds exp with | _, (Bin _ | Bool_lit _) -> Hashtbl.add vars var exp | _, Var _ when !opt_propagate_vars -> Hashtbl.add vars var exp + | _, (Ctor (str, _)) -> + Hashtbl.add kinds var str; + S.add (Define_const (var, typ, exp)) seq | Some 1, _ -> Hashtbl.add vars var exp | Some _, exp -> S.add (Define_const (var, typ, exp)) seq | None, _ -> assert false end - | Write_mem (name, wk, addr, addr_ty, data, data_ty) -> - S.add (Write_mem (name, simp_smt_exp vars wk, simp_smt_exp vars addr, addr_ty, simp_smt_exp vars data, data_ty)) seq - | Write_mem_ea (name, wk, addr, addr_ty, data_size, data_size_ty) -> - S.add (Write_mem_ea (name, simp_smt_exp vars wk, simp_smt_exp vars addr, addr_ty, simp_smt_exp vars data_size, data_size_ty)) seq - | Read_mem (name, typ, rk, addr, addr_typ) -> - S.add (Read_mem (name, typ, simp_smt_exp vars rk, simp_smt_exp vars addr, addr_typ)) seq - | Barrier (name, bk) -> - S.add (Barrier (name, simp_smt_exp vars bk)) seq - | Excl_res name -> - S.add (Excl_res name) seq + | Write_mem (name, node, active, wk, addr, addr_ty, data, data_ty) -> + S.add (Write_mem (name, node, simp_smt_exp vars kinds active, simp_smt_exp vars kinds wk, + simp_smt_exp vars kinds addr, addr_ty, simp_smt_exp vars kinds data, data_ty)) + seq + | Write_mem_ea (name, node, active, wk, addr, addr_ty, data_size, data_size_ty) -> + S.add (Write_mem_ea (name, node, simp_smt_exp vars kinds active, simp_smt_exp vars kinds wk, + simp_smt_exp vars kinds addr, addr_ty, simp_smt_exp vars kinds data_size, data_size_ty)) + seq + | Read_mem (name, node, active, typ, rk, addr, addr_typ) -> + S.add (Read_mem (name, node, simp_smt_exp vars kinds active, typ, simp_smt_exp vars kinds rk, + simp_smt_exp vars kinds addr, addr_typ)) + seq + | Barrier (name, node, active, bk) -> + S.add (Barrier (name, node, simp_smt_exp vars kinds active, simp_smt_exp vars kinds bk)) seq + | Excl_res (name, node, active) -> + S.add (Excl_res (name, node, simp_smt_exp vars kinds active)) seq | Assert exp -> - S.add (Assert (simp_smt_exp vars exp)) seq + S.add (Assert (simp_smt_exp vars kinds exp)) seq | (Declare_datatypes _ | Declare_tuple _) as def -> S.add def seq | Define_fun _ -> assert false @@ -1815,18 +1904,18 @@ let smt_instr_list name ctx all_cdefs instrs = List.iter (fun n -> begin match get_vertex cfg n with | None -> () - | Some ((ssanodes, cfnode), preds, succs) -> + | Some ((ssa_elems, cfnode), preds, succs) -> let muxers = - ssanodes |> List.map (smt_ssanode ctx cfg preds) |> List.concat + ssa_elems |> List.map (smt_ssanode ctx cfg preds) |> List.concat in - let ctx = { ctx with pathcond = lazy (get_pathcond n cfg ctx) } in - let basic_block = smt_cfnode all_cdefs ctx ssanodes cfnode in + let ctx = { ctx with node = n; pathcond = lazy (get_pathcond n cfg ctx) } in + let basic_block = smt_cfnode all_cdefs ctx ssa_elems cfnode in push_smt_defs stack muxers; - push_smt_defs stack basic_block; + push_smt_defs stack basic_block end ) visit_order; - stack + stack, cfg let smt_cdef props lets name_file ctx all_cdefs = function | CDEF_spec (function_id, arg_ctyps, ret_ctyp) when Bindings.mem function_id props -> @@ -1853,7 +1942,7 @@ let smt_cdef props lets name_file ctx all_cdefs = function |> remove_pointless_goto in - let stack = smt_instr_list (string_of_id function_id) ctx all_cdefs instrs in + let stack, _ = smt_instr_list (string_of_id function_id) ctx all_cdefs instrs in let query = smt_query ctx pragma.query in push_smt_defs stack [Assert (Fn ("not", [query]))]; |
