diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/jib/jib_optimize.ml | 84 | ||||
| -rw-r--r-- | src/jib/jib_optimize.mli | 4 | ||||
| -rw-r--r-- | src/jib/jib_smt.ml | 39 | ||||
| -rw-r--r-- | src/jib/jib_util.ml | 5 | ||||
| -rw-r--r-- | src/smtlib.ml | 14 |
5 files changed, 121 insertions, 25 deletions
diff --git a/src/jib/jib_optimize.ml b/src/jib/jib_optimize.ml index 93abf498..3fc42aa3 100644 --- a/src/jib/jib_optimize.ml +++ b/src/jib/jib_optimize.ml @@ -82,7 +82,7 @@ let optimize_unit instrs = let flat_counter = ref 0 let flat_id orig_id = - let id = mk_id (string_of_name orig_id ^ "_local#" ^ string_of_int !flat_counter) in + let id = mk_id (string_of_name ~zencode:false orig_id ^ "_l#" ^ string_of_int !flat_counter) in incr flat_counter; name id @@ -170,14 +170,29 @@ let rec cval_subst id subst = function | V_struct (fields, ctyp) -> V_struct (List.map (fun (field, cval) -> field, cval_subst id subst cval) fields, ctyp) | V_poly (cval, ctyp) -> V_poly (cval_subst id subst cval, ctyp) +let rec cval_map_id f = function + | V_id (id, ctyp) -> V_id (f id, ctyp) + | V_ref (id, ctyp) -> V_ref (f id, ctyp) + | V_lit (vl, ctyp) -> V_lit (vl, ctyp) + | V_call (call, cvals) -> V_call (call, List.map (cval_map_id f) cvals) + | V_op (cval1, op, cval2) -> V_op (cval_map_id f cval1, op, cval_map_id f cval2) + | V_unary (op, cval) -> V_unary (op, cval_map_id f cval) + | V_field (cval, field) -> V_field (cval_map_id f cval, field) + | V_tuple_member (cval, len, n) -> V_tuple_member (cval_map_id f cval, len, n) + | V_ctor_kind (cval, ctor, unifiers, ctyp) -> V_ctor_kind (cval_map_id f cval, ctor, unifiers, ctyp) + | V_ctor_unwrap (ctor, cval, unifiers, ctyp) -> V_ctor_unwrap (ctor, cval_map_id f cval, unifiers, ctyp) + | V_hd cval -> V_hd (cval_map_id f cval) + | V_tl cval -> V_tl (cval_map_id f cval) + | V_struct (fields, ctyp) -> + V_struct (List.map (fun (field, cval) -> field, cval_map_id f cval) fields, ctyp) + | V_poly (cval, ctyp) -> V_poly (cval_map_id f cval, ctyp) + let rec instrs_subst id subst = function | (I_aux (I_decl (_, id'), _) :: _) as instrs when Name.compare id id' = 0 -> - prerr_endline ("DECL: " ^ string_of_name id); instrs | I_aux (I_init (ctyp, id', cval), aux) :: rest when Name.compare id id' = 0 -> - prerr_endline ("INIT: " ^ string_of_name id); I_aux (I_init (ctyp, id', cval_subst id subst cval), aux) :: rest | (I_aux (I_reset (_, id'), _) :: _) as instrs when Name.compare id id' = 0 -> @@ -215,10 +230,6 @@ let rec instrs_subst id subst = | [] -> [] -let instrs_subst' id subst = - prerr_endline (string_of_name id ^ " => " ^ string_of_cval subst); - instrs_subst id subst - let rec clexp_subst id subst = function | CL_id (id', ctyp) when Name.compare id id' = 0 -> if ctyp_equal ctyp (clexp_ctyp subst) then @@ -239,6 +250,12 @@ let rec find_function fid = function | [] -> None +let ssa_name i = function + | Name (id, _) -> Name (id, i) + | Have_exception _ -> Have_exception i + | Current_exception _ -> Current_exception i + | Return _ -> Return i + let inline cdefs should_inline instrs = let inlines = ref (-1) in let label_count = ref (-1) in @@ -266,6 +283,26 @@ let inline cdefs should_inline instrs = | instr -> instr in + let fix_substs = + let f = cval_map_id (ssa_name (-1)) in + function + | I_aux (I_init (ctyp, id, cval), aux) -> + I_aux (I_init (ctyp, id, f cval), aux) + | I_aux (I_jump (cval, label), aux) -> + I_aux (I_jump (f cval, label), aux) + | I_aux (I_funcall (clexp, extern, function_id, args), aux) -> + I_aux (I_funcall (clexp, extern, function_id, List.map f args), aux) + | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) -> + I_aux (I_if (f cval, then_instrs, else_instrs, ctyp), aux) + | I_aux (I_copy (clexp, cval), aux) -> + I_aux (I_copy (clexp, f cval), aux) + | I_aux (I_return cval, aux) -> + I_aux (I_return (f cval), aux) + | I_aux (I_throw cval, aux) -> + I_aux (I_throw (f cval), aux) + | instr -> instr + in + let rec inline_instr = function | I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline function_id -> begin match find_function function_id cdefs with @@ -273,7 +310,14 @@ let inline cdefs should_inline instrs = incr inlines; incr label_count; let inline_label = label "end_inline_" in - let body = List.fold_right2 instrs_subst' (List.map name ids) args body in + (* For situations where we have e.g. x => x' and x' => y, we + use a dummy SSA number turning this into x => x'/-2 and + x' => y/-2, ensuring x's won't get turned into y's. This + is undone by fix_substs which removes the -2 SSA + numbers. *) + let args = List.map (cval_map_id (ssa_name (-2))) args in + let body = List.fold_right2 instrs_subst (List.map name ids) args body in + let body = List.map (map_instr fix_substs) body in let body = List.map (map_instr fix_labels) body in let body = List.map (map_instr (replace_end inline_label)) body in let body = List.map (map_instr (replace_return clexp)) body in @@ -302,6 +346,30 @@ let inline cdefs should_inline instrs = let rec remove_pointless_goto = function | I_aux (I_goto label, _) :: I_aux (I_label label', aux) :: instrs when label = label' -> I_aux (I_label label', aux) :: remove_pointless_goto instrs + | I_aux (I_goto label, aux) :: I_aux (I_goto _, _) :: instrs -> + I_aux (I_goto label, aux) :: remove_pointless_goto instrs | instr :: instrs -> instr :: remove_pointless_goto instrs | [] -> [] + +module StringSet = Set.Make(String) + +let rec get_used_labels set = function + | I_aux (I_goto label, _) :: instrs -> get_used_labels (StringSet.add label set) instrs + | I_aux (I_jump (_, label), _) :: instrs -> get_used_labels (StringSet.add label set) instrs + | _ :: instrs -> get_used_labels set instrs + | [] -> set + +let remove_unused_labels instrs = + let used = get_used_labels StringSet.empty instrs in + let rec go acc = function + | I_aux (I_label label, _) :: instrs when not (StringSet.mem label used) -> go acc instrs + | instr :: instrs -> go (instr :: acc) instrs + | [] -> List.rev acc + in + go [] instrs + +let rec remove_clear = function + | I_aux (I_clear _, _) :: instrs -> remove_clear instrs + | instr :: instrs -> instr :: remove_clear instrs + | [] -> [] diff --git a/src/jib/jib_optimize.mli b/src/jib/jib_optimize.mli index e9f5cfcf..d992793c 100644 --- a/src/jib/jib_optimize.mli +++ b/src/jib/jib_optimize.mli @@ -65,5 +65,9 @@ val unique_per_function_ids : cdef list -> cdef list val inline : cdef list -> (Ast.id -> bool) -> instr list -> instr list +val remove_clear : instr list -> instr list + (** Remove gotos immediately followed by the label it jumps to *) val remove_pointless_goto : instr list -> instr list + +val remove_unused_labels : instr list -> instr list diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml index c094cdaa..dd2fb0cb 100644 --- a/src/jib/jib_smt.ml +++ b/src/jib/jib_smt.ml @@ -203,7 +203,7 @@ let rec smt_cval env cval = in Fn (zencode_upper_id struct_id, List.map set_field fields) | _ -> failwith "Struct does not have struct type" - end + end | V_tuple_member (frag, len, n) -> Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval env frag]) | cval -> failwith ("Unrecognised cval " ^ string_of_cval ~zencode:false cval) @@ -379,14 +379,14 @@ let bvmask len = let all_ones = bvones (lbits_size ()) in let shift = Fn ("concat", [bvzero (lbits_size () - !lbits_index); len]) in bvnot (bvshl all_ones shift) - + let builtin_ones env cval = function | CT_fbits (n, _) -> bvones n | CT_lbits _ -> let len = extract (!lbits_index - 1) 0 (smt_cval env cval) in - Fn ("Bits", [len; Fn ("bvand", [bvmask len; bvones (lbits_size ())])]); + Fn ("Bits", [len; Fn ("bvand", [bvmask len; bvones (lbits_size ())])]); | ret_ctyp -> builtin_type_error "ones" [cval] (Some ret_ctyp) - + (* [bvzeint esz cval] (BitVector Zero Extend INTeger), takes a cval which must be an integer type (either CT_fint, or CT_lint), and produces a bitvector which is either zero extended or truncated to @@ -434,7 +434,7 @@ let builtin_sign_extend env vbits vlen ret_ctyp = let bv = smt_cval env vbits in let top_bit_one = Fn ("=", [Extract (n - 1, n - 1, bv); Bin "1"]) in Ite (top_bit_one, Fn ("concat", [bvones (m - n); bv]), Fn ("concat", [bvzero (m - n); bv])) - + | _ -> failwith "Cannot compile zero_extend" let builtin_shift shiftop env vbits vshift ret_ctyp = @@ -646,7 +646,7 @@ let builtin_replicate_bits env v1 v2 ret_ctyp = Extract (!lbits_index - 1, 0, smt_cval env v2)]) in assert false*) - + | _ -> builtin_type_error "replicate_bits" [v1; v2] (Some ret_ctyp) let builtin_sail_truncate env v1 v2 ret_ctyp = @@ -768,6 +768,13 @@ let builtin_set_slice_bits env v1 v2 v3 v4 v5 ret_ctyp = | _ -> builtin_type_error "set_slice" [v1; v2; v3; v4; v5] (Some ret_ctyp) +let builtin_compare_bits fn env v1 v2 ret_ctyp = + match cval_ctyp v1, cval_ctyp v2 with + | CT_fbits (n, _), CT_fbits (m, _) when n = m -> + Fn (fn, [smt_cval env v1; smt_cval env v2]) + + | _ -> builtin_type_error fn [v1; v2] (Some ret_ctyp) + let smt_builtin env name args ret_ctyp = match name, args, ret_ctyp with | "eq_bits", [v1; v2], _ -> Fn ("=", [smt_cval env v1; smt_cval env v2]) @@ -797,6 +804,16 @@ let smt_builtin env name args ret_ctyp = | "shr_mach_int", [v1; v2], _ -> builtin_shr_int env v1 v2 ret_ctyp | "abs_int", [v], _ -> builtin_abs_int env v ret_ctyp + (* All signed and unsigned bitvector comparisons *) + | "slt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvslt" env v1 v2 ret_ctyp + | "ult_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvult" env v1 v2 ret_ctyp + | "sgt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvsgt" env v1 v2 ret_ctyp + | "ugt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvugt" env v1 v2 ret_ctyp + | "slteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvsle" env v1 v2 ret_ctyp + | "ulteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvule" env v1 v2 ret_ctyp + | "sgteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvsge" env v1 v2 ret_ctyp + | "ugteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvuge" env v1 v2 ret_ctyp + (* lib/vector_dec.sail *) | "zeros", [v], _ -> builtin_zeros env v ret_ctyp | "sail_zeros", [v], _ -> builtin_zeros env v ret_ctyp @@ -826,7 +843,7 @@ let smt_builtin env name args ret_ctyp = | "slice", [v1; v2; v3], ret_ctyp -> builtin_slice env v1 v2 v3 ret_ctyp | "get_slice_int", [v1; v2; v3], ret_ctyp -> builtin_get_slice_int env v1 v2 v3 ret_ctyp | "set_slice", [v1; v2; v3; v4; v5], ret_ctyp -> builtin_set_slice_bits env v1 v2 v3 v4 v5 ret_ctyp - + | _ -> failwith ("Bad builtin " ^ name ^ " " ^ Util.string_of_list ", " string_of_ctyp (List.map cval_ctyp args) ^ " -> " ^ string_of_ctyp ret_ctyp) let rec smt_conversion from_ctyp to_ctyp x = @@ -1292,6 +1309,7 @@ let smt_cdef props name_file env all_cdefs = function (* |> optimize_unit *) |> inline all_cdefs (fun _ -> true) |> flatten_instrs + |> remove_unused_labels |> remove_pointless_goto in @@ -1320,8 +1338,8 @@ let smt_cdef props name_file env all_cdefs = function ) visit_order; let out_chan = open_out (name_file (string_of_id function_id)) in - output_string out_chan "(set-option :produce-models true)\n"; - (*output_string out_chan "(set-logic QF_AUFBVDT)\n";*) + (* output_string out_chan "(set-option :produce-models true)\n"; *) + output_string out_chan "(set-logic QF_AUFBVDT)\n"; (* let stack' = Stack.create () in Stack.iter (fun def -> Stack.push def stack') stack; @@ -1330,7 +1348,7 @@ let smt_cdef props name_file env all_cdefs = function Queue.iter (fun def -> output_string out_chan (string_of_smt_def def); output_string out_chan "\n") queue; output_string out_chan "(check-sat)\n"; - output_string out_chan "(get-model)\n" + (* output_string out_chan "(get-model)\n" *) | _ -> failwith "Bad function body" end @@ -1355,6 +1373,7 @@ let generate_smt props name_file env ast = let t = Profile.start () in let cdefs, ctx = compile_ast { ctx with specialize_calls = true; ignore_64 = true; struct_value = true } ast in Profile.finish "Compiling to Jib IR" t; + let cdefs = Jib_optimize.unique_per_function_ids cdefs in smt_cdefs props name_file env cdefs cdefs with diff --git a/src/jib/jib_util.ml b/src/jib/jib_util.ml index d9edd9c4..df2ce369 100644 --- a/src/jib/jib_util.ml +++ b/src/jib/jib_util.ml @@ -266,7 +266,7 @@ let string_of_value = function | VL_string str -> "\"" ^ str ^ "\"" let string_of_name ?deref_current_exception:(dce=true) ?zencode:(zencode=true) = - let ssa_num n = if n < 0 then "" else ("/" ^ string_of_int n) in + let ssa_num n = if n = -1 then "" else ("/" ^ string_of_int n) in function | Name (id, n) -> (if zencode then Util.zencode_string (string_of_id id) else string_of_id id) ^ ssa_num n @@ -1047,5 +1047,6 @@ let rec instrs_rename from_id to_id = | I_aux (I_block block, aux) :: instrs -> I_aux (I_block (irename block), aux) :: irename instrs | I_aux (I_try_block block, aux) :: instrs -> I_aux (I_try_block (irename block), aux) :: irename instrs | I_aux (I_throw cval, aux) :: instrs -> I_aux (I_throw (crename cval), aux) :: irename instrs - | (I_aux ((I_comment _ | I_raw _ | I_end _ | I_label _ | I_goto _ | I_match_failure | I_undefined _), _) as instr) :: instrs -> instr :: irename instrs + | I_aux (I_end id, aux) :: instrs -> I_aux (I_end (rename id), aux) :: irename instrs + | (I_aux ((I_comment _ | I_raw _ | I_label _ | I_goto _ | I_match_failure | I_undefined _), _) as instr) :: instrs -> instr :: irename instrs | [] -> [] diff --git a/src/smtlib.ml b/src/smtlib.ml index abac05f2..7d6bb564 100644 --- a/src/smtlib.ml +++ b/src/smtlib.ml @@ -119,17 +119,21 @@ let bvones n = else Bin (String.concat "" (Util.list_init n (fun _ -> "1"))) -let rec simp_fn = function +let simp_fn = function | Fn ("not", [Fn ("not", [exp])]) -> exp + | Fn ("contents", [Fn ("Bits", [_; contents])]) -> contents + | Fn ("len", [Fn ("Bits", [len; _])]) -> len | exp -> exp -let rec simp_ite = function +let simp_ite = function | Ite (cond, Bool_lit true, Bool_lit false) -> cond | Ite (_, Var v, Var v') when v = v' -> Var v + | Ite (Bool_lit true, then_exp, _) -> then_exp + | Ite (Bool_lit false, _, else_exp) -> else_exp | exp -> exp - + let rec simp_smt_exp vars = function - | Var v -> + | Var v -> begin match Hashtbl.find_opt vars v with | Some exp -> simp_smt_exp vars exp | None -> Var v @@ -185,7 +189,7 @@ let rec pp_smt_exp = parens (string (Printf.sprintf "(_ is %s)" kind) ^^ space ^^ pp_smt_exp exp) | SignExtend (i, exp) -> parens (string (Printf.sprintf "(_ sign_extend %d)" i) ^^ space ^^ pp_smt_exp exp) - + let rec pp_smt_typ = let open PPrint in function |
