summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/jib/jib_optimize.ml84
-rw-r--r--src/jib/jib_optimize.mli4
-rw-r--r--src/jib/jib_smt.ml39
-rw-r--r--src/jib/jib_util.ml5
-rw-r--r--src/smtlib.ml14
5 files changed, 121 insertions, 25 deletions
diff --git a/src/jib/jib_optimize.ml b/src/jib/jib_optimize.ml
index 93abf498..3fc42aa3 100644
--- a/src/jib/jib_optimize.ml
+++ b/src/jib/jib_optimize.ml
@@ -82,7 +82,7 @@ let optimize_unit instrs =
let flat_counter = ref 0
let flat_id orig_id =
- let id = mk_id (string_of_name orig_id ^ "_local#" ^ string_of_int !flat_counter) in
+ let id = mk_id (string_of_name ~zencode:false orig_id ^ "_l#" ^ string_of_int !flat_counter) in
incr flat_counter;
name id
@@ -170,14 +170,29 @@ let rec cval_subst id subst = function
| V_struct (fields, ctyp) -> V_struct (List.map (fun (field, cval) -> field, cval_subst id subst cval) fields, ctyp)
| V_poly (cval, ctyp) -> V_poly (cval_subst id subst cval, ctyp)
+let rec cval_map_id f = function
+ | V_id (id, ctyp) -> V_id (f id, ctyp)
+ | V_ref (id, ctyp) -> V_ref (f id, ctyp)
+ | V_lit (vl, ctyp) -> V_lit (vl, ctyp)
+ | V_call (call, cvals) -> V_call (call, List.map (cval_map_id f) cvals)
+ | V_op (cval1, op, cval2) -> V_op (cval_map_id f cval1, op, cval_map_id f cval2)
+ | V_unary (op, cval) -> V_unary (op, cval_map_id f cval)
+ | V_field (cval, field) -> V_field (cval_map_id f cval, field)
+ | V_tuple_member (cval, len, n) -> V_tuple_member (cval_map_id f cval, len, n)
+ | V_ctor_kind (cval, ctor, unifiers, ctyp) -> V_ctor_kind (cval_map_id f cval, ctor, unifiers, ctyp)
+ | V_ctor_unwrap (ctor, cval, unifiers, ctyp) -> V_ctor_unwrap (ctor, cval_map_id f cval, unifiers, ctyp)
+ | V_hd cval -> V_hd (cval_map_id f cval)
+ | V_tl cval -> V_tl (cval_map_id f cval)
+ | V_struct (fields, ctyp) ->
+ V_struct (List.map (fun (field, cval) -> field, cval_map_id f cval) fields, ctyp)
+ | V_poly (cval, ctyp) -> V_poly (cval_map_id f cval, ctyp)
+
let rec instrs_subst id subst =
function
| (I_aux (I_decl (_, id'), _) :: _) as instrs when Name.compare id id' = 0 ->
- prerr_endline ("DECL: " ^ string_of_name id);
instrs
| I_aux (I_init (ctyp, id', cval), aux) :: rest when Name.compare id id' = 0 ->
- prerr_endline ("INIT: " ^ string_of_name id);
I_aux (I_init (ctyp, id', cval_subst id subst cval), aux) :: rest
| (I_aux (I_reset (_, id'), _) :: _) as instrs when Name.compare id id' = 0 ->
@@ -215,10 +230,6 @@ let rec instrs_subst id subst =
| [] -> []
-let instrs_subst' id subst =
- prerr_endline (string_of_name id ^ " => " ^ string_of_cval subst);
- instrs_subst id subst
-
let rec clexp_subst id subst = function
| CL_id (id', ctyp) when Name.compare id id' = 0 ->
if ctyp_equal ctyp (clexp_ctyp subst) then
@@ -239,6 +250,12 @@ let rec find_function fid = function
| [] -> None
+let ssa_name i = function
+ | Name (id, _) -> Name (id, i)
+ | Have_exception _ -> Have_exception i
+ | Current_exception _ -> Current_exception i
+ | Return _ -> Return i
+
let inline cdefs should_inline instrs =
let inlines = ref (-1) in
let label_count = ref (-1) in
@@ -266,6 +283,26 @@ let inline cdefs should_inline instrs =
| instr -> instr
in
+ let fix_substs =
+ let f = cval_map_id (ssa_name (-1)) in
+ function
+ | I_aux (I_init (ctyp, id, cval), aux) ->
+ I_aux (I_init (ctyp, id, f cval), aux)
+ | I_aux (I_jump (cval, label), aux) ->
+ I_aux (I_jump (f cval, label), aux)
+ | I_aux (I_funcall (clexp, extern, function_id, args), aux) ->
+ I_aux (I_funcall (clexp, extern, function_id, List.map f args), aux)
+ | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) ->
+ I_aux (I_if (f cval, then_instrs, else_instrs, ctyp), aux)
+ | I_aux (I_copy (clexp, cval), aux) ->
+ I_aux (I_copy (clexp, f cval), aux)
+ | I_aux (I_return cval, aux) ->
+ I_aux (I_return (f cval), aux)
+ | I_aux (I_throw cval, aux) ->
+ I_aux (I_throw (f cval), aux)
+ | instr -> instr
+ in
+
let rec inline_instr = function
| I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline function_id ->
begin match find_function function_id cdefs with
@@ -273,7 +310,14 @@ let inline cdefs should_inline instrs =
incr inlines;
incr label_count;
let inline_label = label "end_inline_" in
- let body = List.fold_right2 instrs_subst' (List.map name ids) args body in
+ (* For situations where we have e.g. x => x' and x' => y, we
+ use a dummy SSA number turning this into x => x'/-2 and
+ x' => y/-2, ensuring x's won't get turned into y's. This
+ is undone by fix_substs which removes the -2 SSA
+ numbers. *)
+ let args = List.map (cval_map_id (ssa_name (-2))) args in
+ let body = List.fold_right2 instrs_subst (List.map name ids) args body in
+ let body = List.map (map_instr fix_substs) body in
let body = List.map (map_instr fix_labels) body in
let body = List.map (map_instr (replace_end inline_label)) body in
let body = List.map (map_instr (replace_return clexp)) body in
@@ -302,6 +346,30 @@ let inline cdefs should_inline instrs =
let rec remove_pointless_goto = function
| I_aux (I_goto label, _) :: I_aux (I_label label', aux) :: instrs when label = label' ->
I_aux (I_label label', aux) :: remove_pointless_goto instrs
+ | I_aux (I_goto label, aux) :: I_aux (I_goto _, _) :: instrs ->
+ I_aux (I_goto label, aux) :: remove_pointless_goto instrs
| instr :: instrs ->
instr :: remove_pointless_goto instrs
| [] -> []
+
+module StringSet = Set.Make(String)
+
+let rec get_used_labels set = function
+ | I_aux (I_goto label, _) :: instrs -> get_used_labels (StringSet.add label set) instrs
+ | I_aux (I_jump (_, label), _) :: instrs -> get_used_labels (StringSet.add label set) instrs
+ | _ :: instrs -> get_used_labels set instrs
+ | [] -> set
+
+let remove_unused_labels instrs =
+ let used = get_used_labels StringSet.empty instrs in
+ let rec go acc = function
+ | I_aux (I_label label, _) :: instrs when not (StringSet.mem label used) -> go acc instrs
+ | instr :: instrs -> go (instr :: acc) instrs
+ | [] -> List.rev acc
+ in
+ go [] instrs
+
+let rec remove_clear = function
+ | I_aux (I_clear _, _) :: instrs -> remove_clear instrs
+ | instr :: instrs -> instr :: remove_clear instrs
+ | [] -> []
diff --git a/src/jib/jib_optimize.mli b/src/jib/jib_optimize.mli
index e9f5cfcf..d992793c 100644
--- a/src/jib/jib_optimize.mli
+++ b/src/jib/jib_optimize.mli
@@ -65,5 +65,9 @@ val unique_per_function_ids : cdef list -> cdef list
val inline : cdef list -> (Ast.id -> bool) -> instr list -> instr list
+val remove_clear : instr list -> instr list
+
(** Remove gotos immediately followed by the label it jumps to *)
val remove_pointless_goto : instr list -> instr list
+
+val remove_unused_labels : instr list -> instr list
diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml
index c094cdaa..dd2fb0cb 100644
--- a/src/jib/jib_smt.ml
+++ b/src/jib/jib_smt.ml
@@ -203,7 +203,7 @@ let rec smt_cval env cval =
in
Fn (zencode_upper_id struct_id, List.map set_field fields)
| _ -> failwith "Struct does not have struct type"
- end
+ end
| V_tuple_member (frag, len, n) ->
Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval env frag])
| cval -> failwith ("Unrecognised cval " ^ string_of_cval ~zencode:false cval)
@@ -379,14 +379,14 @@ let bvmask len =
let all_ones = bvones (lbits_size ()) in
let shift = Fn ("concat", [bvzero (lbits_size () - !lbits_index); len]) in
bvnot (bvshl all_ones shift)
-
+
let builtin_ones env cval = function
| CT_fbits (n, _) -> bvones n
| CT_lbits _ ->
let len = extract (!lbits_index - 1) 0 (smt_cval env cval) in
- Fn ("Bits", [len; Fn ("bvand", [bvmask len; bvones (lbits_size ())])]);
+ Fn ("Bits", [len; Fn ("bvand", [bvmask len; bvones (lbits_size ())])]);
| ret_ctyp -> builtin_type_error "ones" [cval] (Some ret_ctyp)
-
+
(* [bvzeint esz cval] (BitVector Zero Extend INTeger), takes a cval
which must be an integer type (either CT_fint, or CT_lint), and
produces a bitvector which is either zero extended or truncated to
@@ -434,7 +434,7 @@ let builtin_sign_extend env vbits vlen ret_ctyp =
let bv = smt_cval env vbits in
let top_bit_one = Fn ("=", [Extract (n - 1, n - 1, bv); Bin "1"]) in
Ite (top_bit_one, Fn ("concat", [bvones (m - n); bv]), Fn ("concat", [bvzero (m - n); bv]))
-
+
| _ -> failwith "Cannot compile zero_extend"
let builtin_shift shiftop env vbits vshift ret_ctyp =
@@ -646,7 +646,7 @@ let builtin_replicate_bits env v1 v2 ret_ctyp =
Extract (!lbits_index - 1, 0, smt_cval env v2)])
in
assert false*)
-
+
| _ -> builtin_type_error "replicate_bits" [v1; v2] (Some ret_ctyp)
let builtin_sail_truncate env v1 v2 ret_ctyp =
@@ -768,6 +768,13 @@ let builtin_set_slice_bits env v1 v2 v3 v4 v5 ret_ctyp =
| _ -> builtin_type_error "set_slice" [v1; v2; v3; v4; v5] (Some ret_ctyp)
+let builtin_compare_bits fn env v1 v2 ret_ctyp =
+ match cval_ctyp v1, cval_ctyp v2 with
+ | CT_fbits (n, _), CT_fbits (m, _) when n = m ->
+ Fn (fn, [smt_cval env v1; smt_cval env v2])
+
+ | _ -> builtin_type_error fn [v1; v2] (Some ret_ctyp)
+
let smt_builtin env name args ret_ctyp =
match name, args, ret_ctyp with
| "eq_bits", [v1; v2], _ -> Fn ("=", [smt_cval env v1; smt_cval env v2])
@@ -797,6 +804,16 @@ let smt_builtin env name args ret_ctyp =
| "shr_mach_int", [v1; v2], _ -> builtin_shr_int env v1 v2 ret_ctyp
| "abs_int", [v], _ -> builtin_abs_int env v ret_ctyp
+ (* All signed and unsigned bitvector comparisons *)
+ | "slt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvslt" env v1 v2 ret_ctyp
+ | "ult_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvult" env v1 v2 ret_ctyp
+ | "sgt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvsgt" env v1 v2 ret_ctyp
+ | "ugt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvugt" env v1 v2 ret_ctyp
+ | "slteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvsle" env v1 v2 ret_ctyp
+ | "ulteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvule" env v1 v2 ret_ctyp
+ | "sgteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvsge" env v1 v2 ret_ctyp
+ | "ugteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvuge" env v1 v2 ret_ctyp
+
(* lib/vector_dec.sail *)
| "zeros", [v], _ -> builtin_zeros env v ret_ctyp
| "sail_zeros", [v], _ -> builtin_zeros env v ret_ctyp
@@ -826,7 +843,7 @@ let smt_builtin env name args ret_ctyp =
| "slice", [v1; v2; v3], ret_ctyp -> builtin_slice env v1 v2 v3 ret_ctyp
| "get_slice_int", [v1; v2; v3], ret_ctyp -> builtin_get_slice_int env v1 v2 v3 ret_ctyp
| "set_slice", [v1; v2; v3; v4; v5], ret_ctyp -> builtin_set_slice_bits env v1 v2 v3 v4 v5 ret_ctyp
-
+
| _ -> failwith ("Bad builtin " ^ name ^ " " ^ Util.string_of_list ", " string_of_ctyp (List.map cval_ctyp args) ^ " -> " ^ string_of_ctyp ret_ctyp)
let rec smt_conversion from_ctyp to_ctyp x =
@@ -1292,6 +1309,7 @@ let smt_cdef props name_file env all_cdefs = function
(* |> optimize_unit *)
|> inline all_cdefs (fun _ -> true)
|> flatten_instrs
+ |> remove_unused_labels
|> remove_pointless_goto
in
@@ -1320,8 +1338,8 @@ let smt_cdef props name_file env all_cdefs = function
) visit_order;
let out_chan = open_out (name_file (string_of_id function_id)) in
- output_string out_chan "(set-option :produce-models true)\n";
- (*output_string out_chan "(set-logic QF_AUFBVDT)\n";*)
+ (* output_string out_chan "(set-option :produce-models true)\n"; *)
+ output_string out_chan "(set-logic QF_AUFBVDT)\n";
(* let stack' = Stack.create () in
Stack.iter (fun def -> Stack.push def stack') stack;
@@ -1330,7 +1348,7 @@ let smt_cdef props name_file env all_cdefs = function
Queue.iter (fun def -> output_string out_chan (string_of_smt_def def); output_string out_chan "\n") queue;
output_string out_chan "(check-sat)\n";
- output_string out_chan "(get-model)\n"
+ (* output_string out_chan "(get-model)\n" *)
| _ -> failwith "Bad function body"
end
@@ -1355,6 +1373,7 @@ let generate_smt props name_file env ast =
let t = Profile.start () in
let cdefs, ctx = compile_ast { ctx with specialize_calls = true; ignore_64 = true; struct_value = true } ast in
Profile.finish "Compiling to Jib IR" t;
+ let cdefs = Jib_optimize.unique_per_function_ids cdefs in
smt_cdefs props name_file env cdefs cdefs
with
diff --git a/src/jib/jib_util.ml b/src/jib/jib_util.ml
index d9edd9c4..df2ce369 100644
--- a/src/jib/jib_util.ml
+++ b/src/jib/jib_util.ml
@@ -266,7 +266,7 @@ let string_of_value = function
| VL_string str -> "\"" ^ str ^ "\""
let string_of_name ?deref_current_exception:(dce=true) ?zencode:(zencode=true) =
- let ssa_num n = if n < 0 then "" else ("/" ^ string_of_int n) in
+ let ssa_num n = if n = -1 then "" else ("/" ^ string_of_int n) in
function
| Name (id, n) ->
(if zencode then Util.zencode_string (string_of_id id) else string_of_id id) ^ ssa_num n
@@ -1047,5 +1047,6 @@ let rec instrs_rename from_id to_id =
| I_aux (I_block block, aux) :: instrs -> I_aux (I_block (irename block), aux) :: irename instrs
| I_aux (I_try_block block, aux) :: instrs -> I_aux (I_try_block (irename block), aux) :: irename instrs
| I_aux (I_throw cval, aux) :: instrs -> I_aux (I_throw (crename cval), aux) :: irename instrs
- | (I_aux ((I_comment _ | I_raw _ | I_end _ | I_label _ | I_goto _ | I_match_failure | I_undefined _), _) as instr) :: instrs -> instr :: irename instrs
+ | I_aux (I_end id, aux) :: instrs -> I_aux (I_end (rename id), aux) :: irename instrs
+ | (I_aux ((I_comment _ | I_raw _ | I_label _ | I_goto _ | I_match_failure | I_undefined _), _) as instr) :: instrs -> instr :: irename instrs
| [] -> []
diff --git a/src/smtlib.ml b/src/smtlib.ml
index abac05f2..7d6bb564 100644
--- a/src/smtlib.ml
+++ b/src/smtlib.ml
@@ -119,17 +119,21 @@ let bvones n =
else
Bin (String.concat "" (Util.list_init n (fun _ -> "1")))
-let rec simp_fn = function
+let simp_fn = function
| Fn ("not", [Fn ("not", [exp])]) -> exp
+ | Fn ("contents", [Fn ("Bits", [_; contents])]) -> contents
+ | Fn ("len", [Fn ("Bits", [len; _])]) -> len
| exp -> exp
-let rec simp_ite = function
+let simp_ite = function
| Ite (cond, Bool_lit true, Bool_lit false) -> cond
| Ite (_, Var v, Var v') when v = v' -> Var v
+ | Ite (Bool_lit true, then_exp, _) -> then_exp
+ | Ite (Bool_lit false, _, else_exp) -> else_exp
| exp -> exp
-
+
let rec simp_smt_exp vars = function
- | Var v ->
+ | Var v ->
begin match Hashtbl.find_opt vars v with
| Some exp -> simp_smt_exp vars exp
| None -> Var v
@@ -185,7 +189,7 @@ let rec pp_smt_exp =
parens (string (Printf.sprintf "(_ is %s)" kind) ^^ space ^^ pp_smt_exp exp)
| SignExtend (i, exp) ->
parens (string (Printf.sprintf "(_ sign_extend %d)" i) ^^ space ^^ pp_smt_exp exp)
-
+
let rec pp_smt_typ =
let open PPrint in
function