diff options
| author | jp | 2020-02-12 17:46:48 +0000 |
|---|---|---|
| committer | jp | 2020-02-12 17:46:48 +0000 |
| commit | ed8bccd927306551f93d5aab8d0e2a92b9e5d227 (patch) | |
| tree | 55bf788c8155f0c7d024f2147f5eb3873729b02a /src/jib | |
| parent | 31a65c9b7383d2a87da0fbcf5c265d533146ac23 (diff) | |
| parent | 4a72cb8084237161d0bccc66f27d5fb6d24315e0 (diff) | |
Merge branch 'sail2' of https://github.com/rems-project/sail into sail2
Diffstat (limited to 'src/jib')
| -rw-r--r-- | src/jib/c_backend.ml | 863 | ||||
| -rw-r--r-- | src/jib/c_backend.mli | 3 | ||||
| -rw-r--r-- | src/jib/jib_compile.ml | 169 | ||||
| -rw-r--r-- | src/jib/jib_compile.mli | 78 | ||||
| -rw-r--r-- | src/jib/jib_ir.ml | 22 | ||||
| -rw-r--r-- | src/jib/jib_optimize.ml | 27 | ||||
| -rw-r--r-- | src/jib/jib_smt.ml | 847 | ||||
| -rw-r--r-- | src/jib/jib_smt.mli | 45 | ||||
| -rw-r--r-- | src/jib/jib_smt_fuzz.ml | 8 | ||||
| -rw-r--r-- | src/jib/jib_ssa.ml | 4 | ||||
| -rw-r--r-- | src/jib/jib_util.ml | 74 |
11 files changed, 1234 insertions, 906 deletions
diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml index 98ee5bc1..2b144d35 100644 --- a/src/jib/c_backend.ml +++ b/src/jib/c_backend.ml @@ -100,7 +100,9 @@ let zencode_uid (id, ctyps) = match ctyps with | [] -> Util.zencode_string (string_of_id id) | _ -> Util.zencode_string (string_of_id id ^ "#" ^ Util.string_of_list "_" string_of_ctyp ctyps) - + +let ctor_bindings = List.fold_left (fun map (id, ctyp) -> UBindings.add id ctyp map) UBindings.empty + (**************************************************************************) (* 2. Converting sail types to C types *) (**************************************************************************) @@ -108,90 +110,9 @@ let zencode_uid (id, ctyps) = let max_int n = Big_int.pred (Big_int.pow_int_positive 2 (n - 1)) let min_int n = Big_int.negate (Big_int.pow_int_positive 2 (n - 1)) -(** Convert a sail type into a C-type. This function can be quite - slow, because it uses ctx.local_env and SMT to analyse the Sail - types and attempts to fit them into the smallest possible C - types, provided ctx.optimize_smt is true (default) **) -let rec ctyp_of_typ ctx typ = - let Typ_aux (typ_aux, l) as typ = Env.expand_synonyms ctx.tc_env typ in - match typ_aux with - | Typ_id id when string_of_id id = "bit" -> CT_bit - | Typ_id id when string_of_id id = "bool" -> CT_bool - | Typ_id id when string_of_id id = "int" -> CT_lint - | Typ_id id when string_of_id id = "nat" -> CT_lint - | Typ_id id when string_of_id id = "unit" -> CT_unit - | Typ_id id when string_of_id id = "string" -> CT_string - | Typ_id id when string_of_id id = "real" -> CT_real - - | Typ_app (id, _) when string_of_id id = "atom_bool" -> CT_bool - - | Typ_app (id, args) when string_of_id id = "itself" -> - ctyp_of_typ ctx (Typ_aux (Typ_app (mk_id "atom", args), l)) - | Typ_app (id, _) when string_of_id id = "range" || string_of_id id = "atom" || string_of_id id = "implicit" -> - begin match destruct_range Env.empty typ with - | None -> assert false (* Checked if range type in guard *) - | Some (kids, constr, n, m) -> - let ctx = { ctx with local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env }in - match nexp_simp n, nexp_simp m with - | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) - when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> - CT_fint 64 - | n, m -> - if prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) then - CT_fint 64 - else - CT_lint - end - - | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> - CT_list (ctyp_of_typ ctx typ) - - (* When converting a sail bitvector type into C, we have three options in order of efficiency: - - If the length is obviously static and smaller than 64, use the fixed bits type (aka uint64_t), fbits. - - If the length is less than 64, then use a small bits type, sbits. - - If the length may be larger than 64, use a large bits type lbits. *) - | Typ_app (id, [A_aux (A_nexp n, _); - A_aux (A_order ord, _)]) - when string_of_id id = "bitvector" -> - let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in - begin match nexp_simp n with - | Nexp_aux (Nexp_constant n, _) when Big_int.less_equal n (Big_int.of_int 64) -> CT_fbits (Big_int.to_int n, direction) - | n when prove __POS__ ctx.local_env (nc_lteq n (nint 64)) -> CT_sbits (64, direction) - | _ -> CT_lbits direction - end - - | Typ_app (id, [A_aux (A_nexp n, _); - A_aux (A_order ord, _); - A_aux (A_typ typ, _)]) - when string_of_id id = "vector" -> - let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in - CT_vector (direction, ctyp_of_typ ctx typ) - - | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> - CT_ref (ctyp_of_typ ctx typ) - - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> UBindings.bindings) - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> UBindings.bindings) - | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) - - | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs) - - | Typ_exist _ -> - (* Use Type_check.destruct_exist when optimising with SMT, to - ensure that we don't cause any type variable clashes in - local_env, and that we can optimize the existential based upon - it's constraints. *) - begin match destruct_exist (Env.expand_synonyms ctx.local_env typ) with - | Some (kids, nc, typ) -> - let env = add_existential l kids nc ctx.local_env in - ctyp_of_typ { ctx with local_env = env } typ - | None -> raise (Reporting.err_unreachable l __POS__ "Existential cannot be destructured!") - end - - | Typ_var kid -> CT_poly - - | _ -> c_error ~loc:l ("No C type for type " ^ string_of_typ typ) - +(** This function is used to split types into those we allocate on the + stack, versus those which need to live on the heap, or otherwise + require some additional memory management *) let rec is_stack_ctyp ctyp = match ctyp with | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_enum _ -> true | CT_fint n -> n <= 64 @@ -199,7 +120,7 @@ let rec is_stack_ctyp ctyp = match ctyp with | CT_lint -> false | CT_lbits _ when !optimize_fixed_bits -> true | CT_lbits _ -> false - | CT_real | CT_string | CT_list _ | CT_vector _ -> false + | CT_real | CT_string | CT_list _ | CT_vector _ | CT_fvector _ -> false | CT_struct (_, fields) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) fields | CT_variant (_, ctors) -> false (* List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) ctors *) (* FIXME *) | CT_tup ctyps -> List.for_all is_stack_ctyp ctyps @@ -207,346 +128,442 @@ let rec is_stack_ctyp ctyp = match ctyp with | CT_poly -> true | CT_constant n -> Big_int.less_equal (min_int 64) n && Big_int.greater_equal n (max_int 64) -let is_stack_typ ctx typ = is_stack_ctyp (ctyp_of_typ ctx typ) - -let is_fbits_typ ctx typ = - match ctyp_of_typ ctx typ with - | CT_fbits _ -> true - | _ -> false - -let is_sbits_typ ctx typ = - match ctyp_of_typ ctx typ with - | CT_sbits _ -> true - | _ -> false - -let ctor_bindings = List.fold_left (fun map (id, ctyp) -> UBindings.add id ctyp map) UBindings.empty - -(**************************************************************************) -(* 3. Optimization of primitives and literals *) -(**************************************************************************) - -let hex_char = - let open Sail2_values in - function - | '0' -> [B0; B0; B0; B0] - | '1' -> [B0; B0; B0; B1] - | '2' -> [B0; B0; B1; B0] - | '3' -> [B0; B0; B1; B1] - | '4' -> [B0; B1; B0; B0] - | '5' -> [B0; B1; B0; B1] - | '6' -> [B0; B1; B1; B0] - | '7' -> [B0; B1; B1; B1] - | '8' -> [B1; B0; B0; B0] - | '9' -> [B1; B0; B0; B1] - | 'A' | 'a' -> [B1; B0; B1; B0] - | 'B' | 'b' -> [B1; B0; B1; B1] - | 'C' | 'c' -> [B1; B1; B0; B0] - | 'D' | 'd' -> [B1; B1; B0; B1] - | 'E' | 'e' -> [B1; B1; B1; B0] - | 'F' | 'f' -> [B1; B1; B1; B1] - | _ -> failwith "Invalid hex character" - -let literal_to_fragment (L_aux (l_aux, _) as lit) = - match l_aux with - | L_num n when Big_int.less_equal (min_int 64) n && Big_int.less_equal n (max_int 64) -> - Some (V_lit (VL_int n, CT_fint 64)) - | L_hex str when String.length str <= 16 -> - let padding = 16 - String.length str in - let padding = Util.list_init padding (fun _ -> Sail2_values.B0) in - let content = Util.string_to_list str |> List.map hex_char |> List.concat in - Some (V_lit (VL_bits (padding @ content, true), CT_fbits (String.length str * 4, true))) - | L_unit -> Some (V_lit (VL_unit, CT_unit)) - | L_true -> Some (V_lit (VL_bool true, CT_bool)) - | L_false -> Some (V_lit (VL_bool false, CT_bool)) - | _ -> None - -let c_literals ctx = - let rec c_literal env l = function - | AV_lit (lit, typ) as v when is_stack_ctyp (ctyp_of_typ { ctx with local_env = env } typ) -> - begin - match literal_to_fragment lit with - | Some cval -> AV_cval (cval, typ) - | None -> v - end - | AV_tuple avals -> AV_tuple (List.map (c_literal env l) avals) - | v -> v - in - map_aval c_literal - -let rec is_bitvector = function - | [] -> true - | AV_lit (L_aux (L_zero, _), _) :: avals -> is_bitvector avals - | AV_lit (L_aux (L_one, _), _) :: avals -> is_bitvector avals - | _ :: _ -> false - -let rec value_of_aval_bit = function - | AV_lit (L_aux (L_zero, _), _) -> Sail2_values.B0 - | AV_lit (L_aux (L_one, _), _) -> Sail2_values.B1 - | _ -> assert false - -(** Used to make sure the -Ofixed_int and -Ofixed_bits don't interfere - with assumptions made about optimizations in the common case. *) -let rec never_optimize = function - | CT_lbits _ | CT_lint -> true - | _ -> false - -let rec c_aval ctx = function - | AV_lit (lit, typ) as v -> - begin - match literal_to_fragment lit with - | Some cval -> AV_cval (cval, typ) - | None -> v - end - | AV_cval (cval, typ) -> AV_cval (cval, typ) - (* An id can be converted to a C fragment if it's type can be - stack-allocated. *) - | AV_id (id, lvar) as v -> - begin - match lvar with - | Local (_, typ) -> - let ctyp = ctyp_of_typ ctx typ in - if is_stack_ctyp ctyp && not (never_optimize ctyp) then - begin - try - (* We need to check that id's type hasn't changed due to flow typing *) - let _, ctyp' = Bindings.find id ctx.locals in - if ctyp_equal ctyp ctyp' then - AV_cval (V_id (name id, ctyp), typ) - else - (* id's type changed due to flow - typing, so it's really still heap allocated! *) - v - with - (* Hack: Assuming global letbindings don't change from flow typing... *) - Not_found -> AV_cval (V_id (name id, ctyp), typ) - end - else - v - | Register (_, _, typ) -> - let ctyp = ctyp_of_typ ctx typ in - if is_stack_ctyp ctyp && not (never_optimize ctyp) then - AV_cval (V_id (name id, ctyp), typ) - else - v - | _ -> v - end - | AV_vector (v, typ) when is_bitvector v && List.length v <= 64 -> - let bitstring = VL_bits (List.map value_of_aval_bit v, true) in - AV_cval (V_lit (bitstring, CT_fbits (List.length v, true)), typ) - | AV_tuple avals -> AV_tuple (List.map (c_aval ctx) avals) - | aval -> aval - -let c_fragment = function - | AV_cval (cval, _) -> cval - | _ -> assert false - let v_mask_lower i = V_lit (VL_bits (Util.list_init i (fun _ -> Sail2_values.B1), true), CT_fbits (i, true)) -(* Map over all the functions in an aexp. *) -let rec analyze_functions ctx f (AE_aux (aexp, env, l)) = - let ctx = { ctx with local_env = env } in - let aexp = match aexp with - | AE_app (id, vs, typ) -> f ctx id vs typ +module C_config : Config = struct - | AE_cast (aexp, typ) -> AE_cast (analyze_functions ctx f aexp, typ) +(** Convert a sail type into a C-type. This function can be quite + slow, because it uses ctx.local_env and SMT to analyse the Sail + types and attempts to fit them into the smallest possible C + types, provided ctx.optimize_smt is true (default) **) + let rec convert_typ ctx typ = + let Typ_aux (typ_aux, l) as typ = Env.expand_synonyms ctx.tc_env typ in + match typ_aux with + | Typ_id id when string_of_id id = "bit" -> CT_bit + | Typ_id id when string_of_id id = "bool" -> CT_bool + | Typ_id id when string_of_id id = "int" -> CT_lint + | Typ_id id when string_of_id id = "nat" -> CT_lint + | Typ_id id when string_of_id id = "unit" -> CT_unit + | Typ_id id when string_of_id id = "string" -> CT_string + | Typ_id id when string_of_id id = "real" -> CT_real + + | Typ_app (id, _) when string_of_id id = "atom_bool" -> CT_bool + + | Typ_app (id, args) when string_of_id id = "itself" -> + convert_typ ctx (Typ_aux (Typ_app (mk_id "atom", args), l)) + | Typ_app (id, _) when string_of_id id = "range" || string_of_id id = "atom" || string_of_id id = "implicit" -> + begin match destruct_range Env.empty typ with + | None -> assert false (* Checked if range type in guard *) + | Some (kids, constr, n, m) -> + let ctx = { ctx with local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env }in + match nexp_simp n, nexp_simp m with + | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) + when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> + CT_fint 64 + | n, m -> + if prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) then + CT_fint 64 + else + CT_lint + end - | AE_assign (id, typ, aexp) -> AE_assign (id, typ, analyze_functions ctx f aexp) + | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> + CT_list (convert_typ ctx typ) + + (* When converting a sail bitvector type into C, we have three options in order of efficiency: + - If the length is obviously static and smaller than 64, use the fixed bits type (aka uint64_t), fbits. + - If the length is less than 64, then use a small bits type, sbits. + - If the length may be larger than 64, use a large bits type lbits. *) + | Typ_app (id, [A_aux (A_nexp n, _); + A_aux (A_order ord, _)]) + when string_of_id id = "bitvector" -> + let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in + begin match nexp_simp n with + | Nexp_aux (Nexp_constant n, _) when Big_int.less_equal n (Big_int.of_int 64) -> CT_fbits (Big_int.to_int n, direction) + | n when prove __POS__ ctx.local_env (nc_lteq n (nint 64)) -> CT_sbits (64, direction) + | _ -> CT_lbits direction + end - | AE_write_ref (id, typ, aexp) -> AE_write_ref (id, typ, analyze_functions ctx f aexp) + | Typ_app (id, [A_aux (A_nexp n, _); + A_aux (A_order ord, _); + A_aux (A_typ typ, _)]) + when string_of_id id = "vector" -> + let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in + CT_vector (direction, convert_typ ctx typ) + + | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> + CT_ref (convert_typ ctx typ) + + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> UBindings.bindings) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> UBindings.bindings) + | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) + + | Typ_tup typs -> CT_tup (List.map (convert_typ ctx) typs) + + | Typ_exist _ -> + (* Use Type_check.destruct_exist when optimising with SMT, to + ensure that we don't cause any type variable clashes in + local_env, and that we can optimize the existential based + upon it's constraints. *) + begin match destruct_exist (Env.expand_synonyms ctx.local_env typ) with + | Some (kids, nc, typ) -> + let env = add_existential l kids nc ctx.local_env in + convert_typ { ctx with local_env = env } typ + | None -> raise (Reporting.err_unreachable l __POS__ "Existential cannot be destructured!") + end - | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, aval, analyze_functions ctx f aexp) + | Typ_var kid -> CT_poly - | AE_let (mut, id, typ1, aexp1, (AE_aux (_, env2, _) as aexp2), typ2) -> - let aexp1 = analyze_functions ctx f aexp1 in - (* Use aexp2's environment because it will contain constraints for id *) - let ctyp1 = ctyp_of_typ { ctx with local_env = env2 } typ1 in - let ctx = { ctx with locals = Bindings.add id (mut, ctyp1) ctx.locals } in - AE_let (mut, id, typ1, aexp1, analyze_functions ctx f aexp2, typ2) + | _ -> c_error ~loc:l ("No C type for type " ^ string_of_typ typ) - | AE_block (aexps, aexp, typ) -> AE_block (List.map (analyze_functions ctx f) aexps, analyze_functions ctx f aexp, typ) + let is_stack_typ ctx typ = is_stack_ctyp (convert_typ ctx typ) - | AE_if (aval, aexp1, aexp2, typ) -> - AE_if (aval, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2, typ) + let is_fbits_typ ctx typ = + match convert_typ ctx typ with + | CT_fbits _ -> true + | _ -> false - | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2) + let is_sbits_typ ctx typ = + match convert_typ ctx typ with + | CT_sbits _ -> true + | _ -> false - | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> - let aexp1 = analyze_functions ctx f aexp1 in - let aexp2 = analyze_functions ctx f aexp2 in - let aexp3 = analyze_functions ctx f aexp3 in - let aexp4 = analyze_functions ctx f aexp4 in - (* Currently we assume that loop indexes are always safe to put into an int64 *) - let ctx = { ctx with locals = Bindings.add id (Immutable, CT_fint 64) ctx.locals } in - AE_for (id, aexp1, aexp2, aexp3, order, aexp4) + (**************************************************************************) + (* 3. Optimization of primitives and literals *) + (**************************************************************************) + + let hex_char = + let open Sail2_values in + function + | '0' -> [B0; B0; B0; B0] + | '1' -> [B0; B0; B0; B1] + | '2' -> [B0; B0; B1; B0] + | '3' -> [B0; B0; B1; B1] + | '4' -> [B0; B1; B0; B0] + | '5' -> [B0; B1; B0; B1] + | '6' -> [B0; B1; B1; B0] + | '7' -> [B0; B1; B1; B1] + | '8' -> [B1; B0; B0; B0] + | '9' -> [B1; B0; B0; B1] + | 'A' | 'a' -> [B1; B0; B1; B0] + | 'B' | 'b' -> [B1; B0; B1; B1] + | 'C' | 'c' -> [B1; B1; B0; B0] + | 'D' | 'd' -> [B1; B1; B0; B1] + | 'E' | 'e' -> [B1; B1; B1; B0] + | 'F' | 'f' -> [B1; B1; B1; B1] + | _ -> failwith "Invalid hex character" + + let literal_to_fragment (L_aux (l_aux, _) as lit) = + match l_aux with + | L_num n when Big_int.less_equal (min_int 64) n && Big_int.less_equal n (max_int 64) -> + Some (V_lit (VL_int n, CT_fint 64)) + | L_hex str when String.length str <= 16 -> + let padding = 16 - String.length str in + let padding = Util.list_init padding (fun _ -> Sail2_values.B0) in + let content = Util.string_to_list str |> List.map hex_char |> List.concat in + Some (V_lit (VL_bits (padding @ content, true), CT_fbits (String.length str * 4, true))) + | L_unit -> Some (V_lit (VL_unit, CT_unit)) + | L_true -> Some (V_lit (VL_bool true, CT_bool)) + | L_false -> Some (V_lit (VL_bool false, CT_bool)) + | _ -> None + + let c_literals ctx = + let rec c_literal env l = function + | AV_lit (lit, typ) as v when is_stack_ctyp (convert_typ { ctx with local_env = env } typ) -> + begin + match literal_to_fragment lit with + | Some cval -> AV_cval (cval, typ) + | None -> v + end + | AV_tuple avals -> AV_tuple (List.map (c_literal env l) avals) + | v -> v + in + map_aval c_literal + + let rec is_bitvector = function + | [] -> true + | AV_lit (L_aux (L_zero, _), _) :: avals -> is_bitvector avals + | AV_lit (L_aux (L_one, _), _) :: avals -> is_bitvector avals + | _ :: _ -> false + + let rec value_of_aval_bit = function + | AV_lit (L_aux (L_zero, _), _) -> Sail2_values.B0 + | AV_lit (L_aux (L_one, _), _) -> Sail2_values.B1 + | _ -> assert false + + (** Used to make sure the -Ofixed_int and -Ofixed_bits don't + interfere with assumptions made about optimizations in the common + case. *) + let rec never_optimize = function + | CT_lbits _ | CT_lint -> true + | _ -> false - | AE_case (aval, cases, typ) -> - let analyze_case (AP_aux (_, env, _) as pat, aexp1, aexp2) = - let pat_bindings = Bindings.bindings (apat_types pat) in - let ctx = { ctx with local_env = env } in - let ctx = - List.fold_left (fun ctx (id, typ) -> { ctx with locals = Bindings.add id (Immutable, ctyp_of_typ ctx typ) ctx.locals }) ctx pat_bindings + let rec c_aval ctx = function + | AV_lit (lit, typ) as v -> + begin + match literal_to_fragment lit with + | Some cval -> AV_cval (cval, typ) + | None -> v + end + | AV_cval (cval, typ) -> AV_cval (cval, typ) + (* An id can be converted to a C fragment if it's type can be + stack-allocated. *) + | AV_id (id, lvar) as v -> + begin + match lvar with + | Local (_, typ) -> + let ctyp = convert_typ ctx typ in + if is_stack_ctyp ctyp && not (never_optimize ctyp) then + begin + try + (* We need to check that id's type hasn't changed due to flow typing *) + let _, ctyp' = Bindings.find id ctx.locals in + if ctyp_equal ctyp ctyp' then + AV_cval (V_id (name id, ctyp), typ) + else + (* id's type changed due to flow typing, so it's + really still heap allocated! *) + v + with + (* Hack: Assuming global letbindings don't change from flow typing... *) + Not_found -> AV_cval (V_id (name id, ctyp), typ) + end + else + v + | Register (_, _, typ) -> + let ctyp = convert_typ ctx typ in + if is_stack_ctyp ctyp && not (never_optimize ctyp) then + AV_cval (V_id (name id, ctyp), typ) + else + v + | _ -> v + end + | AV_vector (v, typ) when is_bitvector v && List.length v <= 64 -> + let bitstring = VL_bits (List.map value_of_aval_bit v, true) in + AV_cval (V_lit (bitstring, CT_fbits (List.length v, true)), typ) + | AV_tuple avals -> AV_tuple (List.map (c_aval ctx) avals) + | aval -> aval + + let c_fragment = function + | AV_cval (cval, _) -> cval + | _ -> assert false + + (* Map over all the functions in an aexp. *) + let rec analyze_functions ctx f (AE_aux (aexp, env, l)) = + let ctx = { ctx with local_env = env } in + let aexp = match aexp with + | AE_app (id, vs, typ) -> f ctx id vs typ + + | AE_cast (aexp, typ) -> AE_cast (analyze_functions ctx f aexp, typ) + + | AE_assign (id, typ, aexp) -> AE_assign (id, typ, analyze_functions ctx f aexp) + + | AE_write_ref (id, typ, aexp) -> AE_write_ref (id, typ, analyze_functions ctx f aexp) + + | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, aval, analyze_functions ctx f aexp) + + | AE_let (mut, id, typ1, aexp1, (AE_aux (_, env2, _) as aexp2), typ2) -> + let aexp1 = analyze_functions ctx f aexp1 in + (* Use aexp2's environment because it will contain constraints for id *) + let ctyp1 = convert_typ { ctx with local_env = env2 } typ1 in + let ctx = { ctx with locals = Bindings.add id (mut, ctyp1) ctx.locals } in + AE_let (mut, id, typ1, aexp1, analyze_functions ctx f aexp2, typ2) + + | AE_block (aexps, aexp, typ) -> AE_block (List.map (analyze_functions ctx f) aexps, analyze_functions ctx f aexp, typ) + + | AE_if (aval, aexp1, aexp2, typ) -> + AE_if (aval, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2, typ) + + | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2) + + | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> + let aexp1 = analyze_functions ctx f aexp1 in + let aexp2 = analyze_functions ctx f aexp2 in + let aexp3 = analyze_functions ctx f aexp3 in + let aexp4 = analyze_functions ctx f aexp4 in + (* Currently we assume that loop indexes are always safe to put into an int64 *) + let ctx = { ctx with locals = Bindings.add id (Immutable, CT_fint 64) ctx.locals } in + AE_for (id, aexp1, aexp2, aexp3, order, aexp4) + + | AE_case (aval, cases, typ) -> + let analyze_case (AP_aux (_, env, _) as pat, aexp1, aexp2) = + let pat_bindings = Bindings.bindings (apat_types pat) in + let ctx = { ctx with local_env = env } in + let ctx = + List.fold_left (fun ctx (id, typ) -> { ctx with locals = Bindings.add id (Immutable, convert_typ ctx typ) ctx.locals }) ctx pat_bindings + in + pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2 in - pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2 - in - AE_case (aval, List.map analyze_case cases, typ) + AE_case (aval, List.map analyze_case cases, typ) - | AE_try (aexp, cases, typ) -> - AE_try (analyze_functions ctx f aexp, List.map (fun (pat, aexp1, aexp2) -> pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2) cases, typ) + | AE_try (aexp, cases, typ) -> + AE_try (analyze_functions ctx f aexp, List.map (fun (pat, aexp1, aexp2) -> pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2) cases, typ) - | AE_field _ | AE_record_update _ | AE_val _ | AE_return _ | AE_throw _ as v -> v - in - AE_aux (aexp, env, l) + | AE_field _ | AE_record_update _ | AE_val _ | AE_return _ | AE_throw _ as v -> v + in + AE_aux (aexp, env, l) -let analyze_primop' ctx id args typ = - let no_change = AE_app (id, args, typ) in - let args = List.map (c_aval ctx) args in - let extern = if Env.is_extern id ctx.tc_env "c" then Env.get_extern id ctx.tc_env "c" else failwith "Not extern" in + let analyze_primop' ctx id args typ = + let no_change = AE_app (id, args, typ) in + let args = List.map (c_aval ctx) args in + let extern = if Env.is_extern id ctx.tc_env "c" then Env.get_extern id ctx.tc_env "c" else failwith "Not extern" in - let v_one = V_lit (VL_int (Big_int.of_int 1), CT_fint 64) in - let v_int n = V_lit (VL_int (Big_int.of_int n), CT_fint 64) in + let v_one = V_lit (VL_int (Big_int.of_int 1), CT_fint 64) in + let v_int n = V_lit (VL_int (Big_int.of_int n), CT_fint 64) in - match extern, args with - | "eq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - begin match cval_ctyp v1 with - | CT_fbits _ | CT_sbits _ -> + match extern, args with + | "eq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> + begin match cval_ctyp v1 with + | CT_fbits _ | CT_sbits _ -> AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) - | _ -> no_change - end + | _ -> no_change + end - | "neq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - begin match cval_ctyp v1 with - | CT_fbits _ | CT_sbits _ -> - AE_val (AV_cval (V_call (Neq, [v1; v2]), typ)) - | _ -> no_change - end + | "neq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> + begin match cval_ctyp v1 with + | CT_fbits _ | CT_sbits _ -> + AE_val (AV_cval (V_call (Neq, [v1; v2]), typ)) + | _ -> no_change + end - | "eq_int", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) + | "eq_int", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) - | "eq_bit", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) + | "eq_bit", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) - | "zeros", [_] -> - begin match destruct_vector ctx.tc_env typ with - | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) - when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> - let n = Big_int.to_int n in - AE_val (AV_cval (V_lit (VL_bits (Util.list_init n (fun _ -> Sail2_values.B0), true), CT_fbits (n, true)), typ)) - | _ -> no_change - end + | "zeros", [_] -> + begin match destruct_vector ctx.tc_env typ with + | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) + when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> + let n = Big_int.to_int n in + AE_val (AV_cval (V_lit (VL_bits (Util.list_init n (fun _ -> Sail2_values.B0), true), CT_fbits (n, true)), typ)) + | _ -> no_change + end - | "zero_extend", [AV_cval (v, _); _] -> - begin match destruct_vector ctx.tc_env typ with - | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) - when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> - AE_val (AV_cval (V_call (Zero_extend (Big_int.to_int n), [v]), typ)) - | _ -> no_change - end + | "zero_extend", [AV_cval (v, _); _] -> + begin match destruct_vector ctx.tc_env typ with + | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) + when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> + AE_val (AV_cval (V_call (Zero_extend (Big_int.to_int n), [v]), typ)) + | _ -> no_change + end - | "sign_extend", [AV_cval (v, _); _] -> - begin match destruct_vector ctx.tc_env typ with - | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) - when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> - AE_val (AV_cval (V_call (Sign_extend (Big_int.to_int n), [v]), typ)) - | _ -> no_change - end + | "sign_extend", [AV_cval (v, _); _] -> + begin match destruct_vector ctx.tc_env typ with + | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) + when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> + AE_val (AV_cval (V_call (Sign_extend (Big_int.to_int n), [v]), typ)) + | _ -> no_change + end - | "lteq", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Ilteq, [v1; v2]), typ)) - | "gteq", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Igteq, [v1; v2]), typ)) - | "lt", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Ilt, [v1; v2]), typ)) - | "gt", [AV_cval (v1, _); AV_cval (v2, _)] -> - AE_val (AV_cval (V_call (Igt, [v1; v2]), typ)) - - | "append", [AV_cval (v1, _); AV_cval (v2, _)] -> - begin match ctyp_of_typ ctx typ with - | CT_fbits _ | CT_sbits _ -> - AE_val (AV_cval (V_call (Concat, [v1; v2]), typ)) - | _ -> no_change - end + | "lteq", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Ilteq, [v1; v2]), typ)) + | "gteq", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Igteq, [v1; v2]), typ)) + | "lt", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Ilt, [v1; v2]), typ)) + | "gt", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Igt, [v1; v2]), typ)) + + | "append", [AV_cval (v1, _); AV_cval (v2, _)] -> + begin match convert_typ ctx typ with + | CT_fbits _ | CT_sbits _ -> + AE_val (AV_cval (V_call (Concat, [v1; v2]), typ)) + | _ -> no_change + end - | "not_bits", [AV_cval (v, _)] -> - AE_val (AV_cval (V_call (Bvnot, [v]), typ)) + | "not_bits", [AV_cval (v, _)] -> + AE_val (AV_cval (V_call (Bvnot, [v]), typ)) - | "add_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - AE_val (AV_cval (V_call (Bvadd, [v1; v2]), typ)) + | "add_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> + AE_val (AV_cval (V_call (Bvadd, [v1; v2]), typ)) - | "sub_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - AE_val (AV_cval (V_call (Bvsub, [v1; v2]), typ)) + | "sub_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> + AE_val (AV_cval (V_call (Bvsub, [v1; v2]), typ)) - | "and_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - AE_val (AV_cval (V_call (Bvand, [v1; v2]), typ)) + | "and_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> + AE_val (AV_cval (V_call (Bvand, [v1; v2]), typ)) - | "or_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - AE_val (AV_cval (V_call (Bvor, [v1; v2]), typ)) + | "or_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> + AE_val (AV_cval (V_call (Bvor, [v1; v2]), typ)) - | "xor_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> - AE_val (AV_cval (V_call (Bvxor, [v1; v2]), typ)) + | "xor_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> + AE_val (AV_cval (V_call (Bvxor, [v1; v2]), typ)) - | "vector_subrange", [AV_cval (vec, _); AV_cval (f, _); AV_cval (t, _)] -> - begin match ctyp_of_typ ctx typ with - | CT_fbits (n, true) -> - AE_val (AV_cval (V_call (Slice n, [vec; t]), typ)) - | _ -> no_change - end + | "vector_subrange", [AV_cval (vec, _); AV_cval (f, _); AV_cval (t, _)] -> + begin match convert_typ ctx typ with + | CT_fbits (n, true) -> + AE_val (AV_cval (V_call (Slice n, [vec; t]), typ)) + | _ -> no_change + end - | "slice", [AV_cval (vec, _); AV_cval (start, _); AV_cval (len, _)] -> - begin match ctyp_of_typ ctx typ with - | CT_fbits (n, _) -> - AE_val (AV_cval (V_call (Slice n, [vec; start]), typ)) - | CT_sbits (64, _) -> - AE_val (AV_cval (V_call (Sslice 64, [vec; start; len]), typ)) - | _ -> no_change - end + | "slice", [AV_cval (vec, _); AV_cval (start, _); AV_cval (len, _)] -> + begin match convert_typ ctx typ with + | CT_fbits (n, _) -> + AE_val (AV_cval (V_call (Slice n, [vec; start]), typ)) + | CT_sbits (64, _) -> + AE_val (AV_cval (V_call (Sslice 64, [vec; start; len]), typ)) + | _ -> no_change + end - | "vector_access", [AV_cval (vec, _); AV_cval (n, _)] -> - AE_val (AV_cval (V_call (Bvaccess, [vec; n]), typ)) + | "vector_access", [AV_cval (vec, _); AV_cval (n, _)] -> + AE_val (AV_cval (V_call (Bvaccess, [vec; n]), typ)) - | "add_int", [AV_cval (op1, _); AV_cval (op2, _)] -> - begin match destruct_range Env.empty typ with - | None -> no_change - | Some (kids, constr, n, m) -> - match nexp_simp n, nexp_simp m with - | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) + | "add_int", [AV_cval (op1, _); AV_cval (op2, _)] -> + begin match destruct_range Env.empty typ with + | None -> no_change + | Some (kids, constr, n, m) -> + match nexp_simp n, nexp_simp m with + | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> - AE_val (AV_cval (V_call (Iadd, [op1; op2]), typ)) - | n, m when prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) -> - AE_val (AV_cval (V_call (Iadd, [op1; op2]), typ)) - | _ -> no_change - end + AE_val (AV_cval (V_call (Iadd, [op1; op2]), typ)) + | n, m when prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) -> + AE_val (AV_cval (V_call (Iadd, [op1; op2]), typ)) + | _ -> no_change + end - | "replicate_bits", [AV_cval (vec, vtyp); _] -> - begin match destruct_vector ctx.tc_env typ, destruct_vector ctx.tc_env vtyp with - | Some (Nexp_aux (Nexp_constant n, _), _, _), Some (Nexp_aux (Nexp_constant m, _), _, _) - when Big_int.less_equal n (Big_int.of_int 64) -> - let times = Big_int.div n m in - if Big_int.equal (Big_int.mul m times) n then - AE_val (AV_cval (V_call (Replicate (Big_int.to_int times), [vec]), typ)) - else + | "replicate_bits", [AV_cval (vec, vtyp); _] -> + begin match destruct_vector ctx.tc_env typ, destruct_vector ctx.tc_env vtyp with + | Some (Nexp_aux (Nexp_constant n, _), _, _), Some (Nexp_aux (Nexp_constant m, _), _, _) + when Big_int.less_equal n (Big_int.of_int 64) -> + let times = Big_int.div n m in + if Big_int.equal (Big_int.mul m times) n then + AE_val (AV_cval (V_call (Replicate (Big_int.to_int times), [vec]), typ)) + else + no_change + | _, _ -> no_change - | _, _ -> - no_change - end - - | "undefined_bit", _ -> - AE_val (AV_cval (V_lit (VL_bit Sail2_values.B0, CT_bit), typ)) + end - | "undefined_bool", _ -> - AE_val (AV_cval (V_lit (VL_bool false, CT_bool), typ)) + | "undefined_bit", _ -> + AE_val (AV_cval (V_lit (VL_bit Sail2_values.B0, CT_bit), typ)) - | _, _ -> - no_change + | "undefined_bool", _ -> + AE_val (AV_cval (V_lit (VL_bool false, CT_bool), typ)) -let analyze_primop ctx id args typ = - let no_change = AE_app (id, args, typ) in - if !optimize_primops then - try analyze_primop' ctx id args typ with - | Failure str -> + | _, _ -> no_change - else - no_change + + let analyze_primop ctx id args typ = + let no_change = AE_app (id, args, typ) in + if !optimize_primops then + try analyze_primop' ctx id args typ with + | Failure str -> + no_change + else + no_change + + let optimize_anf ctx aexp = + analyze_functions ctx analyze_primop (c_literals ctx aexp) + + + let unroll_loops () = None + let specialize_calls = false + let ignore_64 = false + let struct_value = false + let use_real = false +end (** Functions that have heap-allocated return types are implemented by passing a pointer a location where the return value should be @@ -571,9 +588,9 @@ let fix_early_heap_return ret ret_ctyp instrs = before @ [iblock (rewrite_return instrs)] @ rewrite_return after - | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), _) :: after -> + | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> before - @ [iif cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp] + @ [iif l cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp] @ rewrite_return after | before, I_aux (I_funcall (CL_id (Return _, ctyp), extern, fid, args), aux) :: after -> before @@ -608,9 +625,9 @@ let fix_early_stack_return ret ret_ctyp instrs = before @ [iblock (rewrite_return instrs)] @ rewrite_return after - | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), _) :: after -> + | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> before - @ [iif cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp] + @ [iif l cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp] @ rewrite_return after | before, I_aux (I_funcall (CL_id (Return _, ctyp), extern, fid, args), aux) :: after -> before @@ -631,7 +648,7 @@ let fix_early_stack_return ret ret_ctyp instrs = rewrite_return instrs let rec insert_heap_returns ret_ctyps = function - | (CDEF_spec (id, _, ret_ctyp) as cdef) :: cdefs -> + | (CDEF_spec (id, _, _, ret_ctyp) as cdef) :: cdefs -> cdef :: insert_heap_returns (Bindings.add id ret_ctyp ret_ctyps) cdefs | CDEF_fundef (id, None, args, body) :: cdefs -> @@ -1022,7 +1039,7 @@ let optimize recursive_functions cdefs = let sgen_id id = Util.zencode_string (string_of_id id) let sgen_uid uid = zencode_uid uid -let sgen_name id = string_of_name id +let sgen_name id = string_of_name ~deref_current_exception:true ~zencode:true id let codegen_id id = string (sgen_id id) let codegen_uid id = string (sgen_uid id) @@ -1033,7 +1050,7 @@ let sgen_function_id id = let sgen_function_uid uid = let str = zencode_uid uid in !opt_prefix ^ String.sub str 1 (String.length str - 1) - + let codegen_function_id id = string (sgen_function_id id) let rec sgen_ctyp = function @@ -1052,6 +1069,7 @@ let rec sgen_ctyp = function | CT_variant (id, _) -> "struct " ^ sgen_id id | CT_list _ as l -> Util.zencode_string (string_of_ctyp l) | CT_vector _ as v -> Util.zencode_string (string_of_ctyp v) + | CT_fvector (_, ord, typ) -> sgen_ctyp (CT_vector (ord, typ)) | CT_string -> "sail_string" | CT_real -> "real" | CT_ref ctyp -> sgen_ctyp ctyp ^ "*" @@ -1073,6 +1091,7 @@ let rec sgen_ctyp_name = function | CT_variant (id, _) -> sgen_id id | CT_list _ as l -> Util.zencode_string (string_of_ctyp l) | CT_vector _ as v -> Util.zencode_string (string_of_ctyp v) + | CT_fvector (_, ord, typ) -> sgen_ctyp_name (CT_vector (ord, typ)) | CT_string -> "sail_string" | CT_real -> "real" | CT_ref ctyp -> "ref_" ^ sgen_ctyp_name ctyp @@ -1094,24 +1113,27 @@ let sgen_mask n = else failwith "Tried to create a mask literal for a vector greater than 64 bits." -let sgen_value = function +let rec sgen_value = function | VL_bits ([], _) -> "UINT64_C(0)" | VL_bits (bs, true) -> "UINT64_C(" ^ Sail2_values.show_bitlist bs ^ ")" | VL_bits (bs, false) -> "UINT64_C(" ^ Sail2_values.show_bitlist (List.rev bs) ^ ")" | VL_int i -> Big_int.to_string i ^ "l" | VL_bool true -> "true" | VL_bool false -> "false" - | VL_null -> "NULL" | VL_unit -> "UNIT" | VL_bit Sail2_values.B0 -> "UINT64_C(0)" | VL_bit Sail2_values.B1 -> "UINT64_C(1)" | VL_bit Sail2_values.BU -> failwith "Undefined bit found in value" | VL_real str -> str | VL_string str -> "\"" ^ str ^ "\"" - + | VL_empty_list -> "NULL" + | VL_enum element -> Util.zencode_string element + | VL_ref r -> "&" ^ Util.zencode_string r + | VL_undefined -> + Reporting.unreachable Parse_ast.Unknown __POS__ "Cannot generate C value for an undefined literal" + let rec sgen_cval = function - | V_id (id, ctyp) -> string_of_name id - | V_ref (id, _) -> "&" ^ string_of_name id + | V_id (id, ctyp) -> sgen_name id | V_lit (vl, ctyp) -> sgen_value vl | V_call (op, cvals) -> sgen_call op cvals | V_field (f, field) -> @@ -1133,8 +1155,6 @@ let rec sgen_cval = function and sgen_call op cvals = let open Printf in match op, cvals with - | Bit_to_bool, [v] -> - sprintf "((bool) %s)" (sgen_cval v) | Bnot, [v] -> "!(" ^ sgen_cval v ^ ")" | List_hd, [v] -> sprintf "(%s).hd" ("*" ^ sgen_cval v) @@ -1306,6 +1326,7 @@ let sgen_cval_param cval = let rec sgen_clexp = function | CL_id (Have_exception _, _) -> "have_exception" | CL_id (Current_exception _, _) -> "current_exception" + | CL_id (Throw_location _, _) -> "throw_location" | CL_id (Return _, _) -> assert false | CL_id (Name (id, _), _) -> "&" ^ sgen_id id | CL_field (clexp, field) -> "&((" ^ sgen_clexp clexp ^ ")->" ^ zencode_uid field ^ ")" @@ -1317,6 +1338,7 @@ let rec sgen_clexp = function let rec sgen_clexp_pure = function | CL_id (Have_exception _, _) -> "have_exception" | CL_id (Current_exception _, _) -> "current_exception" + | CL_id (Throw_location _, _) -> "throw_location" | CL_id (Return _, _) -> assert false | CL_id (Name (id, _), _) -> sgen_id id | CL_field (clexp, field) -> sgen_clexp_pure clexp ^ "." ^ zencode_uid field @@ -1400,21 +1422,26 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = ^^ jump 2 2 (separate_map hardline (codegen_instr fid ctx) instrs) ^^ hardline ^^ string " }" - | I_funcall (x, extern, f, args) -> + | I_funcall (x, special_extern, f, args) -> let c_args = Util.string_of_list ", " sgen_cval args in let ctyp = clexp_ctyp x in - let is_extern = Env.is_extern (fst f) ctx.tc_env "c" || extern in + let is_extern = Env.is_extern (fst f) ctx.tc_env "c" || special_extern in let fname = - if Env.is_extern (fst f) ctx.tc_env "c" then - Env.get_extern (fst f) ctx.tc_env "c" - else if extern then + if special_extern then string_of_id (fst f) + else if Env.is_extern (fst f) ctx.tc_env "c" then + Env.get_extern (fst f) ctx.tc_env "c" else sgen_function_uid f in let fname = match fname, ctyp with | "internal_pick", _ -> Printf.sprintf "pick_%s" (sgen_ctyp_name ctyp) + | "cons", _ -> + begin match snd f with + | [ctyp] -> Util.zencode_string ("cons#" ^ string_of_ctyp ctyp) + | _ -> c_error "cons without specified type" + end | "eq_anything", _ -> begin match args with | cval :: _ -> Printf.sprintf "eq_%s" (sgen_ctyp_name (cval_ctyp cval)) @@ -1765,6 +1792,8 @@ let codegen_type_def ctx = function ^^ string "struct zexception *current_exception = NULL;" ^^ hardline ^^ string "bool have_exception = false;" + ^^ hardline + ^^ string "sail_string *throw_location = NULL;" else empty @@ -1994,7 +2023,7 @@ let codegen_def' ctx = function string (Printf.sprintf "// register %s" (string_of_id id)) ^^ hardline ^^ string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id)) - | CDEF_spec (id, arg_ctyps, ret_ctyp) -> + | CDEF_spec (id, _, arg_ctyps, ret_ctyp) -> let static = if !opt_static then "static " else "" in if Env.is_extern id ctx.tc_env "c" then empty @@ -2009,7 +2038,7 @@ let codegen_def' ctx = function | None -> c_error ~loc:(id_loc id) ("No valspec found for " ^ string_of_id id) in - + (* Check that the function has the correct arity at this point. *) if List.length arg_ctyps <> List.length args then c_error ~loc:(id_loc id) ("function arguments " @@ -2095,7 +2124,7 @@ type c_gen_typ = let rec ctyp_dependencies = function | CT_tup ctyps -> List.concat (List.map ctyp_dependencies ctyps) @ [CTG_tup ctyps] | CT_list ctyp -> ctyp_dependencies ctyp @ [CTG_list ctyp] - | CT_vector (direction, ctyp) -> ctyp_dependencies ctyp @ [CTG_vector (direction, ctyp)] + | CT_vector (direction, ctyp) | CT_fvector (_, direction, ctyp) -> ctyp_dependencies ctyp @ [CTG_vector (direction, ctyp)] | CT_ref ctyp -> ctyp_dependencies ctyp | CT_struct (_, ctors) -> List.concat (List.map (fun (_, ctyp) -> ctyp_dependencies ctyp) ctors) | CT_variant (_, ctors) -> List.concat (List.map (fun (_, ctyp) -> ctyp_dependencies ctyp) ctors) @@ -2169,20 +2198,17 @@ let rec get_recursive_functions (Defs defs) = | [] -> IdSet.empty let jib_of_ast env ast = - let ctx = - initial_ctx - ~convert_typ:ctyp_of_typ - ~optimize_anf:(fun ctx aexp -> analyze_functions ctx analyze_primop (c_literals ctx aexp)) - env - in - Jib_compile.compile_ast ctx ast + let module Jibc = Make(C_config) in + let ctx = initial_ctx (add_special_functions env) in + Jibc.compile_ast ctx ast let compile_ast env output_chan c_includes ast = try let recursive_functions = Spec_analysis.top_sort_defs ast |> get_recursive_functions in let cdefs, ctx = jib_of_ast env ast in - Jib_interactive.ir := cdefs; + let cdefs', _ = Jib_optimize.remove_tuples cdefs ctx in + Jib_interactive.ir := cdefs'; let cdefs = insert_heap_returns Bindings.empty cdefs in let cdefs = optimize recursive_functions cdefs in @@ -2199,10 +2225,15 @@ let compile_ast env output_chan c_includes ast = let exn_boilerplate = if not (Bindings.mem (mk_id "exception") ctx.variants) then ([], []) else ([ " current_exception = sail_malloc(sizeof(struct zexception));"; - " CREATE(zexception)(current_exception);" ], - [ " KILL(zexception)(current_exception);"; + " CREATE(zexception)(current_exception);"; + " throw_location = sail_malloc(sizeof(sail_string));"; + " CREATE(sail_string)(throw_location);" ], + [ " if (have_exception) {fprintf(stderr, \"Exiting due to uncaught exception: %s\\n\", *throw_location);}"; + " KILL(zexception)(current_exception);"; " sail_free(current_exception);"; - " if (have_exception) {fprintf(stderr, \"Exiting due to uncaught exception\\n\"); exit(EXIT_FAILURE);}" ]) + " KILL(sail_string)(throw_location);"; + " sail_free(throw_location);"; + " if (have_exception) {exit(EXIT_FAILURE);}" ]) in let letbind_initializers = diff --git a/src/jib/c_backend.mli b/src/jib/c_backend.mli index 2f748fd7..e627ebd8 100644 --- a/src/jib/c_backend.mli +++ b/src/jib/c_backend.mli @@ -106,8 +106,5 @@ val optimize_alias : bool ref val optimize_fixed_int : bool ref val optimize_fixed_bits : bool ref -(** Convert a typ to a IR ctyp *) -val ctyp_of_typ : Jib_compile.ctx -> Ast.typ -> ctyp - val jib_of_ast : Env.t -> tannot Ast.defs -> cdef list * Jib_compile.ctx val compile_ast : Env.t -> out_channel -> string list -> tannot Ast.defs -> unit diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index 0efac940..4282ae30 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -58,6 +58,7 @@ open Value2 open Anf let opt_memo_cache = ref false +let opt_track_throw = ref true let optimize_aarch64_fast_struct = ref false @@ -151,38 +152,38 @@ type ctx = enums : IdSet.t Bindings.t; variants : (ctyp UBindings.t) Bindings.t; valspecs : (ctyp list * ctyp) Bindings.t; - tc_env : Env.t; local_env : Env.t; + tc_env : Env.t; locals : (mut * ctyp) Bindings.t; letbinds : int list; no_raw : bool; - convert_typ : ctx -> typ -> ctyp; - optimize_anf : ctx -> typ aexp -> typ aexp; - specialize_calls : bool; - ignore_64 : bool; - struct_value : bool; - use_real : bool; } -let initial_ctx ~convert_typ:convert_typ ~optimize_anf:optimize_anf env = +let initial_ctx env = { records = Bindings.empty; enums = Bindings.empty; variants = Bindings.empty; valspecs = Bindings.empty; - tc_env = env; local_env = env; + tc_env = env; locals = Bindings.empty; letbinds = []; no_raw = false; - convert_typ = convert_typ; - optimize_anf = optimize_anf; - specialize_calls = false; - ignore_64 = false; - struct_value = false; - use_real = false; } -let ctyp_of_typ ctx typ = ctx.convert_typ ctx typ +module type Config = sig + val convert_typ : ctx -> typ -> ctyp + val optimize_anf : ctx -> typ aexp -> typ aexp + val unroll_loops : unit -> int option + val specialize_calls : bool + val ignore_64 : bool + val struct_value : bool + val use_real : bool +end + +module Make(C: Config) = struct + +let ctyp_of_typ ctx typ = C.convert_typ ctx typ let rec chunkify n xs = match Util.take n xs, Util.drop n xs with @@ -210,12 +211,12 @@ let rec compile_aval l ctx = function end | AV_ref (id, typ) -> - [], V_ref (name id, ctyp_of_typ ctx (lvar_typ typ)), [] + [], V_lit (VL_ref (string_of_id id), CT_ref (ctyp_of_typ ctx (lvar_typ typ))), [] | AV_lit (L_aux (L_string str, _), typ) -> [], V_lit ((VL_string (String.escaped str)), ctyp_of_typ ctx typ), [] - | AV_lit (L_aux (L_num n, _), typ) when ctx.ignore_64 -> + | AV_lit (L_aux (L_num n, _), typ) when C.ignore_64 -> [], V_lit ((VL_int n), ctyp_of_typ ctx typ), [] | AV_lit (L_aux (L_num n, _), typ) when Big_int.less_equal (min_int 64) n && Big_int.less_equal n (max_int 64) -> @@ -237,7 +238,7 @@ let rec compile_aval l ctx = function | AV_lit (L_aux (L_false, _), _) -> [], V_lit (VL_bool false, CT_bool), [] | AV_lit (L_aux (L_real str, _), _) -> - if ctx.use_real then + if C.use_real then [], V_lit (VL_real str, CT_real), [] else let gs = ngensym () in @@ -247,6 +248,10 @@ let rec compile_aval l ctx = function | AV_lit (L_aux (L_unit, _), _) -> [], V_lit (VL_unit, CT_unit), [] + | AV_lit (L_aux (L_undef, _), typ) -> + let ctyp = ctyp_of_typ ctx typ in + [], V_lit (VL_undefined, ctyp), [] + | AV_lit (L_aux (_, l) as lit, _) -> raise (Reporting.err_general l ("Encountered unexpected literal " ^ string_of_lit lit ^ " when converting ANF represention into IR")) @@ -264,7 +269,7 @@ let rec compile_aval l ctx = function [iclear tup_ctyp gs] @ cleanup - | AV_record (fields, typ) when ctx.struct_value -> + | AV_record (fields, typ) when C.struct_value -> let ctyp = ctyp_of_typ ctx typ in let gs = ngensym () in let compile_fields (id, aval) = @@ -309,7 +314,7 @@ let rec compile_aval l ctx = function end (* Convert a small bitvector to a uint64_t literal. *) - | AV_vector (avals, typ) when is_bitvector avals && (List.length avals <= 64 || ctx.ignore_64) -> + | AV_vector (avals, typ) when is_bitvector avals && (List.length avals <= 64 || C.ignore_64) -> begin let bitstring = List.map value_of_aval_bit avals in let len = List.length avals in @@ -358,11 +363,14 @@ let rec compile_aval l ctx = function | V_lit (VL_bit Sail2_values.B1, _) -> [icopy l (CL_id (gs, ctyp)) (V_call (Bvor, [V_id (gs, ctyp); V_lit (mask i, ctyp)]))] | _ -> - (* FIXME: Make this work in C *) - setup @ [iif (V_call (Bit_to_bool, [cval])) [icopy l (CL_id (gs, ctyp)) (V_call (Bvor, [V_id (gs, ctyp); V_lit (mask i, ctyp)]))] [] CT_unit] @ cleanup + setup + @ [iextern (CL_id (gs, ctyp)) + (mk_id "update_fbits", []) + [V_id (gs, ctyp); V_lit (VL_int (Big_int.of_int i), CT_fint 64); cval]] + @ cleanup in [idecl ctyp gs; - icopy l (CL_id (gs, ctyp)) (V_lit (VL_bits (Util.list_init 64 (fun _ -> Sail2_values.B0), direction), ctyp))] + icopy l (CL_id (gs, ctyp)) (V_lit (VL_bits (Util.list_init len (fun _ -> Sail2_values.B0), direction), ctyp))] @ List.concat (List.mapi aval_mask (List.rev avals)), V_id (gs, ctyp), [] @@ -403,7 +411,7 @@ let rec compile_aval l ctx = function let gs = ngensym () in let mk_cons aval = let setup, cval, cleanup = compile_aval l ctx aval in - setup @ [ifuncall (CL_id (gs, CT_list ctyp)) (mk_id ("cons#" ^ string_of_ctyp ctyp), []) [cval; V_id (gs, CT_list ctyp)]] @ cleanup + setup @ [iextern (CL_id (gs, CT_list ctyp)) (mk_id "cons", [ctyp]) [cval; V_id (gs, CT_list ctyp)]] @ cleanup in [idecl (CT_list ctyp) gs] @ List.concat (List.map mk_cons (List.rev avals)), @@ -420,7 +428,7 @@ let optimize_call l ctx clexp id args arg_ctyps ret_ctyp = let have_ctyp = cval_ctyp cval in if is_polymorphic ctyp then V_poly (cval, have_ctyp) - else if ctx.specialize_calls || ctyp_equal ctyp have_ctyp then + else if C.specialize_calls || ctyp_equal ctyp have_ctyp then cval else let gs = ngensym () in @@ -429,7 +437,7 @@ let optimize_call l ctx clexp id args arg_ctyps ret_ctyp = V_id (gs, ctyp)) arg_ctyps args in - if ctx.specialize_calls || ctyp_equal (clexp_ctyp clexp) ret_ctyp then + if C.specialize_calls || ctyp_equal (clexp_ctyp clexp) ret_ctyp then !setup @ [ifuncall clexp id cast_args] @ !cleanup else let gs = ngensym () in @@ -440,7 +448,7 @@ let optimize_call l ctx clexp id args arg_ctyps ret_ctyp = iclear ret_ctyp gs] @ !cleanup in - if not ctx.specialize_calls && Env.is_extern (fst id) ctx.tc_env "c" then + if not C.specialize_calls && Env.is_extern (fst id) ctx.tc_env "c" then let extern = Env.get_extern (fst id) ctx.tc_env "c" in begin match extern, List.map cval_ctyp args, clexp_ctyp clexp with | "slice", [CT_fbits _; CT_lint; _], CT_fbits (n, _) -> @@ -506,7 +514,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = let ctyp = cval_ctyp cval in match apat_aux with | AP_id (pid, _) when Env.is_union_constructor pid ctx.tc_env -> - [ijump (V_ctor_kind (cval, pid, [], cval_ctyp cval)) case_label], + [ijump l (V_ctor_kind (cval, pid, [], cval_ctyp cval)) case_label], [], ctx @@ -517,7 +525,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = | AP_id (pid, _) when is_ct_enum ctyp -> begin match Env.lookup_id pid ctx.tc_env with | Unbound -> [idecl ctyp (name pid); icopy l (CL_id (name pid, ctyp)) cval], [], ctx - | _ -> [ijump (V_call (Neq, [V_id (name pid, ctyp); cval])) case_label], [], ctx + | _ -> [ijump l (V_call (Neq, [V_id (name pid, ctyp); cval])) case_label], [], ctx end | AP_id (pid, typ) -> @@ -562,7 +570,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = [], ctor_ctyp in let instrs, cleanup, ctx = compile_match ctx apat (V_ctor_unwrap (ctor, cval, unifiers, ctor_ctyp)) case_label in - [ijump (V_ctor_kind (cval, ctor, unifiers, pat_ctyp)) case_label] + [ijump l (V_ctor_kind (cval, ctor, unifiers, pat_ctyp)) case_label] @ instrs, cleanup, ctx @@ -581,12 +589,12 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = | CT_list ctyp -> let hd_setup, hd_cleanup, ctx = compile_match ctx hd_apat (V_call (List_hd, [cval])) case_label in let tl_setup, tl_cleanup, ctx = compile_match ctx tl_apat (V_call (List_tl, [cval])) case_label in - [ijump (V_call (Eq, [cval; V_lit (VL_null, CT_list ctyp)])) case_label] @ hd_setup @ tl_setup, tl_cleanup @ hd_cleanup, ctx + [ijump l (V_call (Eq, [cval; V_lit (VL_empty_list, CT_list ctyp)])) case_label] @ hd_setup @ tl_setup, tl_cleanup @ hd_cleanup, ctx | _ -> raise (Reporting.err_general l "Tried to pattern match cons on non list type") end - | AP_nil _ -> [ijump (V_call (Neq, [cval; V_lit (VL_null, ctyp)])) case_label], [], ctx + | AP_nil _ -> [ijump l (V_call (Neq, [cval; V_lit (VL_empty_list, ctyp)])) case_label], [], ctx let unit_cval = V_lit (VL_unit, CT_unit) @@ -633,7 +641,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = destructure @ (if not trivial_guard then guard_setup @ [idecl CT_bool gs; guard_call (CL_id (gs, CT_bool))] @ guard_cleanup - @ [iif (V_call (Bnot, [V_id (gs, CT_bool)])) (destructure_cleanup @ [igoto case_label]) [] CT_unit] + @ [iif l (V_call (Bnot, [V_id (gs, CT_bool)])) (destructure_cleanup @ [igoto case_label]) [] CT_unit] else []) @ body_setup @ [body_call (CL_id (case_return_id, ctyp))] @ body_cleanup @ destructure_cleanup @ [igoto finish_match_label] @@ -674,7 +682,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = destructure @ [icomment "end destructuring"] @ (if not trivial_guard then guard_setup @ [idecl CT_bool gs; guard_call (CL_id (gs, CT_bool))] @ guard_cleanup - @ [ijump (V_call (Bnot, [V_id (gs, CT_bool)])) try_label] + @ [ijump l (V_call (Bnot, [V_id (gs, CT_bool)])) try_label] @ [icomment "end guard"] else []) @ body_setup @ [body_call (CL_id (try_return_id, ctyp))] @ body_cleanup @ destructure_cleanup @@ -685,7 +693,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = assert (ctyp_equal ctyp (ctyp_of_typ ctx typ)); [idecl ctyp try_return_id; itry_block (aexp_setup @ [aexp_call (CL_id (try_return_id, ctyp))] @ aexp_cleanup); - ijump (V_call (Bnot, [V_id (have_exception, CT_bool)])) handled_exception_label] + ijump l (V_call (Bnot, [V_id (have_exception, CT_bool)])) handled_exception_label] @ List.concat (List.map compile_case cases) @ [igoto fallthrough_label; ilabel handled_exception_label; @@ -707,7 +715,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = in let setup, cval, cleanup = compile_aval l ctx aval in setup, - (fun clexp -> iif cval + (fun clexp -> iif l cval (compile_branch then_aexp clexp) (compile_branch else_aexp clexp) if_ctyp), @@ -742,7 +750,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = let gs = ngensym () in left_setup @ [ idecl CT_bool gs; - iif cval + iif l cval (right_setup @ [call (CL_id (gs, CT_bool))] @ right_cleanup) [icopy l (CL_id (gs, CT_bool)) (V_lit (VL_bool false, CT_bool))] CT_bool ] @@ -755,7 +763,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = let gs = ngensym () in left_setup @ [ idecl CT_bool gs; - iif cval + iif l cval [icopy l (CL_id (gs, CT_bool)) (V_lit (VL_bool true, CT_bool))] (right_setup @ [call (CL_id (gs, CT_bool))] @ right_cleanup) CT_bool ] @@ -813,7 +821,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = @ [iblock (cond_setup @ [cond_call (CL_id (gs, CT_bool))] @ cond_cleanup - @ [ijump loop_test loop_end_label] + @ [ijump l loop_test loop_end_label] @ body_setup @ [body_call (CL_id (unit_gs, CT_unit))] @ body_cleanup @@ -838,7 +846,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = @ cond_setup @ [cond_call (CL_id (gs, CT_bool))] @ cond_cleanup - @ [ijump loop_test loop_end_label] + @ [ijump l loop_test loop_end_label] @ [igoto loop_start_label])] @ [ilabel loop_end_label], (fun clexp -> icopy l clexp unit_cval), @@ -869,7 +877,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = | AE_throw (aval, typ) -> (* Cleanup info will be handled by fix_exceptions *) let throw_setup, cval, _ = compile_aval l ctx aval in - throw_setup @ [ithrow cval], + throw_setup @ [ithrow l cval], (fun clexp -> icomment "unreachable after throw"), [] @@ -927,21 +935,29 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = let loop_var = name loop_var in + let loop_body prefix continue = + prefix + @ [iblock ([ijump l (V_call ((if is_inc then Igt else Ilt), [V_id (loop_var, CT_fint 64); V_id (to_gs, CT_fint 64)])) loop_end_label] + @ body_setup + @ [body_call (CL_id (body_gs, CT_unit))] + @ body_cleanup + @ [icopy l (CL_id (loop_var, (CT_fint 64))) + (V_call ((if is_inc then Iadd else Isub), [V_id (loop_var, CT_fint 64); V_id (step_gs, CT_fint 64)]))] + @ continue ())] + in + (* We can either generate an actual loop body for C, or unroll the body for SMT *) + let actual = loop_body [ilabel loop_start_label] (fun () -> [igoto loop_start_label]) in + let rec unroll max n = loop_body [] (fun () -> if n < max then unroll max (n + 1) else [imatch_failure ()]) in + let body = match C.unroll_loops () with Some times -> unroll times 0 | None -> actual in + variable_init from_gs from_setup from_call from_cleanup @ variable_init to_gs to_setup to_call to_cleanup @ variable_init step_gs step_setup step_call step_cleanup @ [iblock ([idecl (CT_fint 64) loop_var; icopy l (CL_id (loop_var, (CT_fint 64))) (V_id (from_gs, CT_fint 64)); - idecl CT_unit body_gs; - iblock ([ilabel loop_start_label] - @ [ijump (V_call ((if is_inc then Igt else Ilt), [V_id (loop_var, CT_fint 64); V_id (to_gs, CT_fint 64)])) loop_end_label] - @ body_setup - @ [body_call (CL_id (body_gs, CT_unit))] - @ body_cleanup - @ [icopy l (CL_id (loop_var, (CT_fint 64))) - (V_call ((if is_inc then Iadd else Isub), [V_id (loop_var, CT_fint 64); V_id (step_gs, CT_fint 64)]))] - @ [igoto loop_start_label]); - ilabel loop_end_label])], + idecl CT_unit body_gs] + @ body + @ [ilabel loop_end_label])], (fun clexp -> icopy l clexp unit_cval), [] @@ -1032,19 +1048,23 @@ let fix_exception_block ?return:(return=None) ctx instrs = before @ [iblock (rewrite_exception (historic @ before) instrs)] @ rewrite_exception (historic @ before) after - | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), _) :: after -> + | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> let historic = historic @ before in before - @ [iif cval (rewrite_exception historic then_instrs) (rewrite_exception historic else_instrs) ctyp] + @ [iif l cval (rewrite_exception historic then_instrs) (rewrite_exception historic else_instrs) ctyp] @ rewrite_exception historic after | before, I_aux (I_throw cval, (_, l)) :: after -> before @ [icopy l (CL_id (current_exception, cval_ctyp cval)) cval; icopy l (CL_id (have_exception, CT_bool)) (V_lit (VL_bool true, CT_bool))] + @ (if !opt_track_throw then + let loc_string = Reporting.short_loc_to_string l in + [icopy l (CL_id (throw_location, CT_string)) (V_lit (VL_string loc_string, CT_string))] + else []) @ generate_cleanup (historic @ before) @ [igoto end_block_label] @ rewrite_exception (historic @ before) after - | before, (I_aux (I_funcall (x, _, f, args), _) as funcall) :: after -> + | before, (I_aux (I_funcall (x, _, f, args), (_, l)) as funcall) :: after -> let effects = match Env.get_val_spec (fst f) ctx.tc_env with | _, Typ_aux (Typ_fn (_, _, effects), _) -> effects | exception (Type_error _) -> no_effect (* nullary union constructor, so no val spec *) @@ -1053,7 +1073,7 @@ let fix_exception_block ?return:(return=None) ctx instrs = if has_effect effects BE_escape then before @ [funcall; - iif (V_id (have_exception, CT_bool)) (generate_cleanup (historic @ before) @ [igoto end_block_label]) [] CT_unit] + iif l (V_id (have_exception, CT_bool)) (generate_cleanup (historic @ before) @ [igoto end_block_label]) [] CT_unit] @ rewrite_exception (historic @ before) after else before @ funcall :: rewrite_exception (historic @ before) after @@ -1147,10 +1167,10 @@ let fix_early_return ret instrs = before @ [iblock (rewrite_return (historic @ before) instrs)] @ rewrite_return (historic @ before) after - | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), _) :: after -> + | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> let historic = historic @ before in before - @ [iif cval (rewrite_return historic then_instrs) (rewrite_return historic else_instrs) ctyp] + @ [iif l cval (rewrite_return historic then_instrs) (rewrite_return historic else_instrs) ctyp] @ rewrite_return historic after | before, I_aux (I_return cval, (_, l)) :: after -> let cleanup_label = label "cleanup_" in @@ -1211,7 +1231,7 @@ let compile_funcl ctx id pat guard exp = let guard_instrs = match guard with | Some guard -> - let guard_aexp = ctx.optimize_anf ctx (no_shadow (pat_ids pat) (anf guard)) in + let guard_aexp = C.optimize_anf ctx (no_shadow (pat_ids pat) (anf guard)) in let guard_setup, guard_call, guard_cleanup = compile_aexp ctx guard_aexp in let guard_label = label "guard_" in let gs = ngensym () in @@ -1220,7 +1240,7 @@ let compile_funcl ctx id pat guard exp = @ guard_setup @ [guard_call (CL_id (gs, CT_bool))] @ guard_cleanup - @ [ijump (V_id (gs, CT_bool)) guard_label; + @ [ijump (id_loc id) (V_id (gs, CT_bool)) guard_label; imatch_failure (); ilabel guard_label] )] @@ -1228,7 +1248,7 @@ let compile_funcl ctx id pat guard exp = in (* Optimize and compile the expression to ANF. *) - let aexp = ctx.optimize_anf ctx (no_shadow (pat_ids pat) (anf exp)) in + let aexp = C.optimize_anf ctx (no_shadow (pat_ids pat) (anf exp)) in let setup, call, cleanup = compile_aexp ctx aexp in let destructure, destructure_cleanup = @@ -1280,7 +1300,7 @@ and compile_def' n total ctx = function | DEF_reg_dec (DEC_aux (DEC_reg (_, _, typ, id), _)) -> [CDEF_reg_dec (id, ctyp_of_typ ctx typ, [])], ctx | DEF_reg_dec (DEC_aux (DEC_config (id, typ, exp), _)) -> - let aexp = ctx.optimize_anf ctx (no_shadow IdSet.empty (anf exp)) in + let aexp = C.optimize_anf ctx (no_shadow IdSet.empty (anf exp)) in let setup, call, cleanup = compile_aexp ctx aexp in let instrs = setup @ [call (CL_id (name id, ctyp_of_typ ctx typ))] @ cleanup in [CDEF_reg_dec (id, ctyp_of_typ ctx typ, instrs)], ctx @@ -1290,13 +1310,19 @@ and compile_def' n total ctx = function | DEF_spec (VS_aux (VS_val_spec (_, id, _, _), _)) -> let quant, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in + let extern = + if Env.is_extern id ctx.tc_env "c" then + Some (Env.get_extern id ctx.tc_env "c") + else + None + in let arg_typs, ret_typ = match fn_typ with | Typ_fn (arg_typs, ret_typ, _) -> arg_typs, ret_typ | _ -> assert false in let ctx' = { ctx with local_env = add_typquant (id_loc id) quant ctx.local_env } in let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ in - [CDEF_spec (id, arg_ctyps, ret_ctyp)], + [CDEF_spec (id, extern, arg_ctyps, ret_ctyp)], { ctx with valspecs = Bindings.add id (arg_ctyps, ret_ctyp) ctx.valspecs } | DEF_fundef (FD_aux (FD_function (_, _, _, [FCL_aux (FCL_Funcl (id, Pat_aux (Pat_exp (pat, exp), _)), _)]), _)) -> @@ -1323,7 +1349,7 @@ and compile_def' n total ctx = function | DEF_val (LB_aux (LB_val (pat, exp), _)) -> let ctyp = ctyp_of_typ ctx (typ_of_pat pat) in - let aexp = ctx.optimize_anf ctx (no_shadow IdSet.empty (anf exp)) in + let aexp = C.optimize_anf ctx (no_shadow IdSet.empty (anf exp)) in let setup, call, cleanup = compile_aexp ctx aexp in let apat = anf_pat ~global:true pat in let gs = ngensym () in @@ -1544,12 +1570,6 @@ let sort_ctype_defs cdefs = ctype_defs @ cdefs let compile_ast ctx (Defs defs) = - let assert_vs = Initial_check.extern_of_string (mk_id "sail_assert") "(bool, string) -> unit" in - let exit_vs = Initial_check.extern_of_string (mk_id "sail_exit") "unit -> unit" in - let cons_vs = Initial_check.extern_of_string (mk_id "sail_cons") "forall ('a : Type). ('a, list('a)) -> list('a)" in - - let ctx = { ctx with tc_env = snd (Type_error.check ctx.tc_env (Defs [assert_vs; exit_vs; cons_vs])) } in - if !opt_memo_cache then (try if Sys.is_directory "_sbuild" then @@ -1568,3 +1588,12 @@ let compile_ast ctx (Defs defs) = let cdefs, ctx = specialize_variants ctx [] cdefs in let cdefs = sort_ctype_defs cdefs in cdefs, ctx + +end + +let add_special_functions env = + let assert_vs = Initial_check.extern_of_string (mk_id "sail_assert") "(bool, string) -> unit" in + let exit_vs = Initial_check.extern_of_string (mk_id "sail_exit") "unit -> unit" in + let cons_vs = Initial_check.extern_of_string (mk_id "sail_cons") "forall ('a : Type). ('a, list('a)) -> list('a)" in + + snd (Type_error.check env (Defs [assert_vs; exit_vs; cons_vs])) diff --git a/src/jib/jib_compile.mli b/src/jib/jib_compile.mli index 273e9e03..9014d8f7 100644 --- a/src/jib/jib_compile.mli +++ b/src/jib/jib_compile.mli @@ -58,53 +58,69 @@ open Type_check (** This forces all integer struct fields to be represented as int64_t. Specifically intended for the various TLB structs in the - ARM v8.5 spec. *) + ARM v8.5 spec. It is unsound in general. *) val optimize_aarch64_fast_struct : bool ref +(** If true (default) track the location of the last exception thrown, + useful for debugging C but we want to turn it off for SMT generation + where we can't use strings *) +val opt_track_throw : bool ref + (** {2 Jib context} *) -(** Context for compiling Sail to Jib. We need to pass a (global) - typechecking environment given by checking the full AST. We have to - provide a conversion function from Sail types into Jib types, as - well as a function that optimizes ANF expressions (which can just - be the identity function) *) +(** Dynamic context for compiling Sail to Jib. We need to pass a + (global) typechecking environment given by checking the full + AST. *) type ctx = { records : (ctyp Jib_util.UBindings.t) Bindings.t; enums : IdSet.t Bindings.t; variants : (ctyp Jib_util.UBindings.t) Bindings.t; valspecs : (ctyp list * ctyp) Bindings.t; - tc_env : Env.t; local_env : Env.t; + tc_env : Env.t; locals : (mut * ctyp) Bindings.t; letbinds : int list; no_raw : bool; - convert_typ : ctx -> typ -> ctyp; - optimize_anf : ctx -> typ aexp -> typ aexp; - (** If false (default), function arguments must match the function - type exactly. If true, they can be more specific. *) - specialize_calls : bool; - (** If false (default), will ensure that fixed size bitvectors are - specifically less that 64-bits. If true this restriction will - be ignored. *) - ignore_64 : bool; - (** If false (default) we won't generate any V_struct values *) - struct_value : bool; - (** Allow real literals *) - use_real : bool; } -val initial_ctx : - convert_typ:(ctx -> typ -> ctyp) -> - optimize_anf:(ctx -> typ aexp -> typ aexp) -> - Env.t -> - ctx +val initial_ctx : Env.t -> ctx (** {2 Compilation functions} *) -(** Compile a Sail definition into a Jib definition. The first two - arguments are is the current definition number and the total number - of definitions, and can be used to drive a progress bar (see - Util.progress). *) -val compile_def : int -> int -> ctx -> tannot def -> cdef list * ctx +(** The Config module specifies static configuration for compiling + Sail into Jib. We have to provide a conversion function from Sail + types into Jib types, as well as a function that optimizes ANF + expressions (which can just be the identity function) *) +module type Config = sig + val convert_typ : ctx -> typ -> ctyp + val optimize_anf : ctx -> typ aexp -> typ aexp + (** Unroll all for loops a bounded number of times. Used for SMT + generation. *) + val unroll_loops : unit -> int option + (** If false, function arguments must match the function + type exactly. If true, they can be more specific. *) + val specialize_calls : bool + (** If false, will ensure that fixed size bitvectors are + specifically less that 64-bits. If true this restriction will + be ignored. *) + val ignore_64 : bool + (** If false we won't generate any V_struct values *) + val struct_value : bool + (** Allow real literals *) + val use_real : bool +end + +module Make(C: Config) : sig + (** Compile a Sail definition into a Jib definition. The first two + arguments are is the current definition number and the total + number of definitions, and can be used to drive a progress bar + (see Util.progress). *) + val compile_def : int -> int -> ctx -> tannot def -> cdef list * ctx + + val compile_ast : ctx -> tannot defs -> cdef list * ctx +end -val compile_ast : ctx -> tannot defs -> cdef list * ctx +(** Adds some special functions to the environment that are used to + convert several Sail language features, these are sail_assert, + sail_exit, and sail_cons. *) +val add_special_functions : Env.t -> Env.t diff --git a/src/jib/jib_ir.ml b/src/jib/jib_ir.ml index c5f2b20a..4bf726aa 100644 --- a/src/jib/jib_ir.ml +++ b/src/jib/jib_ir.ml @@ -69,7 +69,9 @@ let string_of_name = "return" ^ ssa_num n | Current_exception n -> "current_exception" ^ ssa_num n - + | Throw_location n -> + "throw_location" ^ ssa_num n + let rec string_of_clexp = function | CL_id (id, ctyp) -> string_of_name id | CL_field (clexp, field) -> string_of_clexp clexp ^ "." ^ string_of_uid field @@ -107,7 +109,9 @@ module Ir_formatter = struct | I_label label -> C.output_label_instr buf label_map label | I_jump (cval, label) -> - add_instr n buf indent (C.keyword "jump" ^ " " ^ C.value cval ^ " " ^ C.string_of_label (StringMap.find label label_map)) + add_instr n buf indent (C.keyword "jump" ^ " " ^ C.value cval ^ " " + ^ C.keyword "goto" ^ " " ^ C.string_of_label (StringMap.find label label_map) + ^ " ` \"" ^ Reporting.short_loc_to_string l ^ "\"") | I_goto label -> add_instr n buf indent (C.keyword "goto" ^ " " ^ C.string_of_label (StringMap.find label label_map)) | I_match_failure -> @@ -151,8 +155,10 @@ module Ir_formatter = struct let output_def buf = function | CDEF_reg_dec (id, ctyp, _) -> Buffer.add_string buf (sprintf "%s %s : %s" (C.keyword "register") (zencode_id id) (C.typ ctyp)) - | CDEF_spec (id, ctyps, ctyp) -> + | CDEF_spec (id, None, ctyps, ctyp) -> Buffer.add_string buf (sprintf "%s %s : (%s) -> %s" (C.keyword "val") (zencode_id id) (Util.string_of_list ", " C.typ ctyps) (C.typ ctyp)); + | CDEF_spec (id, Some extern, ctyps, ctyp) -> + Buffer.add_string buf (sprintf "%s %s = \"%s\" : (%s) -> %s" (C.keyword "val") (zencode_id id) extern (Util.string_of_list ", " C.typ ctyps) (C.typ ctyp)); | CDEF_fundef (id, ret, args, instrs) -> let instrs = C.modify_instrs instrs in let label_map = C.make_label_map instrs in @@ -244,10 +250,10 @@ let () = let open Interactive in let open Jib_interactive in - (fun arg -> + ArgString ("(val|register)? identifier", fun arg -> Action (fun () -> let is_def id = function | CDEF_fundef (id', _, _, _) -> Id.compare id id' = 0 - | CDEF_spec (id', _, _) -> Id.compare id (prepend_id "val " id') = 0 + | CDEF_spec (id', _, _, _) -> Id.compare id (prepend_id "val " id') = 0 | CDEF_reg_dec (id', _, _) -> Id.compare id (prepend_id "register " id') = 0 | _ -> false in @@ -258,12 +264,12 @@ let () = let buf = Buffer.create 256 in with_colors (fun () -> Flat_ir_formatter.output_def buf cdef); print_endline (Buffer.contents buf) - ) |> Interactive.(register_command ~name:"ir" ~help:(sprintf ":ir %s - Print the ir representation of a toplevel definition" (arg "(val|register)? identifier"))); + )) |> Interactive.register_command ~name:"ir" ~help:"Print the ir representation of a toplevel definition"; - (fun file -> + ArgString ("file", fun file -> Action (fun () -> let buf = Buffer.create 256 in let out_chan = open_out file in Flat_ir_formatter.output_defs buf !ir; output_string out_chan (Buffer.contents buf); close_out out_chan - ) |> Interactive.(register_command ~name:"dump_ir" ~help:(sprintf ":dump_ir %s - Dump the ir to a file" (arg "file"))) + )) |> Interactive.register_command ~name:"dump_ir" ~help:"Dump the ir to a file" diff --git a/src/jib/jib_optimize.ml b/src/jib/jib_optimize.ml index 323f3cd0..e0f3bf0d 100644 --- a/src/jib/jib_optimize.ml +++ b/src/jib/jib_optimize.ml @@ -102,10 +102,10 @@ let rec flatten_instrs = function | I_aux ((I_block block | I_try_block block), _) :: instrs -> flatten_instrs block @ flatten_instrs instrs - | I_aux (I_if (cval, then_instrs, else_instrs, _), _) :: instrs -> + | I_aux (I_if (cval, then_instrs, else_instrs, _), (_, l)) :: instrs -> let then_label = label "then_" in let endif_label = label "endif_" in - [ijump cval then_label] + [ijump l cval then_label] @ flatten_instrs else_instrs @ [igoto endif_label] @ [ilabel then_label] @@ -153,7 +153,7 @@ let unique_per_function_ids cdefs = | CDEF_reg_dec (id, ctyp, instrs) -> CDEF_reg_dec (id, ctyp, unique_instrs i instrs) | CDEF_type ctd -> CDEF_type ctd | CDEF_let (n, bindings, instrs) -> CDEF_let (n, bindings, unique_instrs i instrs) - | CDEF_spec (id, ctyps, ctyp) -> CDEF_spec (id, ctyps, ctyp) + | CDEF_spec (id, extern, ctyps, ctyp) -> CDEF_spec (id, extern, ctyps, ctyp) | CDEF_fundef (id, heap_return, args, instrs) -> CDEF_fundef (id, heap_return, args, unique_instrs i instrs) | CDEF_startup (id, instrs) -> CDEF_startup (id, unique_instrs i instrs) | CDEF_finish (id, instrs) -> CDEF_finish (id, unique_instrs i instrs) @@ -162,7 +162,6 @@ let unique_per_function_ids cdefs = let rec cval_subst id subst = function | V_id (id', ctyp) -> if Name.compare id id' = 0 then subst else V_id (id', ctyp) - | V_ref (reg_id, ctyp) -> V_ref (reg_id, ctyp) | V_lit (vl, ctyp) -> V_lit (vl, ctyp) | V_call (op, cvals) -> V_call (op, List.map (cval_subst id subst) cvals) | V_field (cval, field) -> V_field (cval_subst id subst cval, field) @@ -174,7 +173,6 @@ let rec cval_subst id subst = function let rec cval_map_id f = function | V_id (id, ctyp) -> V_id (f id, ctyp) - | V_ref (id, ctyp) -> V_ref (f id, ctyp) | V_lit (vl, ctyp) -> V_lit (vl, ctyp) | V_call (call, cvals) -> V_call (call, List.map (cval_map_id f) cvals) | V_field (cval, field) -> V_field (cval_map_id f cval, field) @@ -249,6 +247,7 @@ let ssa_name i = function | Name (id, _) -> Name (id, i) | Have_exception _ -> Have_exception i | Current_exception _ -> Current_exception i + | Throw_location _ -> Throw_location i | Return _ -> Return i let inline cdefs should_inline instrs = @@ -347,6 +346,15 @@ let rec remove_pointless_goto = function instr :: remove_pointless_goto instrs | [] -> [] +let rec remove_pointless_exit = function + | I_aux (I_end id, aux) :: I_aux (I_end _, _) :: instrs -> + I_aux (I_end id, aux) :: remove_pointless_exit instrs + | I_aux (I_end id, aux) :: I_aux (I_undefined _, _) :: instrs -> + I_aux (I_end id, aux) :: remove_pointless_exit instrs + | instr :: instrs -> + instr :: remove_pointless_exit instrs + | [] -> [] + module StringSet = Set.Make(String) let rec get_used_labels set = function @@ -364,7 +372,6 @@ let remove_unused_labels instrs = in go [] instrs - let remove_dead_after_goto instrs = let rec go acc = function | (I_aux (I_goto _, _) as instr) :: instrs -> go_dead (instr :: acc) instrs @@ -379,7 +386,7 @@ let remove_dead_after_goto instrs = let rec remove_dead_code instrs = let instrs' = - instrs |> remove_unused_labels |> remove_pointless_goto |> remove_dead_after_goto + instrs |> remove_unused_labels |> remove_pointless_goto |> remove_dead_after_goto |> remove_pointless_exit in if List.length instrs' < List.length instrs then remove_dead_code instrs' @@ -398,7 +405,7 @@ let remove_tuples cdefs ctx = CTSet.add ctyp (List.fold_left CTSet.union CTSet.empty (List.map all_tuples ctyps)) | CT_struct (_, id_ctyps) | CT_variant (_, id_ctyps) -> List.fold_left (fun cts (_, ctyp) -> CTSet.union (all_tuples ctyp) cts) CTSet.empty id_ctyps - | CT_list ctyp | CT_vector (_, ctyp) | CT_ref ctyp -> + | CT_list ctyp | CT_vector (_, ctyp) | CT_fvector (_, _, ctyp) | CT_ref ctyp -> all_tuples ctyp | CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_unit | CT_bool | CT_real | CT_bit | CT_poly | CT_string | CT_enum _ -> @@ -409,7 +416,7 @@ let remove_tuples cdefs ctx = 1 + List.fold_left (fun d ctyp -> max d (tuple_depth ctyp)) 0 ctyps | CT_struct (_, id_ctyps) | CT_variant (_, id_ctyps) -> List.fold_left (fun d (_, ctyp) -> max (tuple_depth ctyp) d) 0 id_ctyps - | CT_list ctyp | CT_vector (_, ctyp) | CT_ref ctyp -> + | CT_list ctyp | CT_vector (_, ctyp) | CT_fvector (_, _, ctyp) | CT_ref ctyp -> tuple_depth ctyp | CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_unit | CT_bool | CT_real | CT_bit | CT_poly | CT_string | CT_enum _ -> @@ -426,6 +433,7 @@ let remove_tuples cdefs ctx = CT_variant (id, List.map (fun (id, ctyp) -> id, fix_tuples ctyp) id_ctyps) | CT_list ctyp -> CT_list (fix_tuples ctyp) | CT_vector (d, ctyp) -> CT_vector (d, fix_tuples ctyp) + | CT_fvector (n, d, ctyp) -> CT_fvector (n, d, fix_tuples ctyp) | CT_ref ctyp -> CT_ref (fix_tuples ctyp) | (CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ | CT_unit | CT_bool | CT_real | CT_bit | CT_poly | CT_string | CT_enum _) as ctyp -> @@ -433,7 +441,6 @@ let remove_tuples cdefs ctx = in let rec fix_cval = function | V_id (id, ctyp) -> V_id (id, ctyp) - | V_ref (id, ctyp) -> V_ref (id, ctyp) | V_lit (vl, ctyp) -> V_lit (vl, ctyp) | V_ctor_kind (cval, id, unifiers, ctyp) -> V_ctor_kind (fix_cval cval, id, unifiers, ctyp) diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml index fbaf8d3f..81b876a4 100644 --- a/src/jib/jib_smt.ml +++ b/src/jib/jib_smt.ml @@ -73,6 +73,8 @@ let opt_debug_graphs = ref false let opt_propagate_vars = ref false +let opt_unroll_limit = ref 10 + module EventMap = Map.Make(Event) (* Note that we have to use x : ty ref rather than mutable x : ty, to @@ -89,6 +91,8 @@ type ctx = { pragma_l : Ast.l; arg_stack : (int * string) Stack.t; ast : Type_check.tannot defs; + shared : ctyp Bindings.t; + preserved : IdSet.t; events : smt_exp Stack.t EventMap.t ref; node : int; pathcond : smt_exp Lazy.t; @@ -114,6 +118,8 @@ let initial_ctx () = { pragma_l = Parse_ast.Unknown; arg_stack = Stack.create (); ast = Defs []; + shared = Bindings.empty; + preserved = IdSet.empty; events = ref EventMap.empty; node = -1; pathcond = lazy (Bool_lit true); @@ -129,6 +135,19 @@ let event_stack ctx ev = ctx.events := EventMap.add ev stack !(ctx.events); stack +let add_event ctx ev smt = + let stack = event_stack ctx ev in + Stack.push (Fn ("and", [Lazy.force ctx.pathcond; smt])) stack + +let add_pathcond_event ctx ev = + Stack.push (Lazy.force ctx.pathcond) (event_stack ctx ev) + +let overflow_check ctx smt = + if not !opt_ignore_overflow then ( + Reporting.warn "Overflow check in generated SMT for" ctx.pragma_l ""; + add_event ctx Overflow smt + ) + let lbits_size ctx = Util.power 2 ctx.lbits_index let vector_index = ref 5 @@ -179,6 +198,8 @@ let rec smt_ctyp ctx = function | _ -> failwith ("No registers with ctyp: " ^ string_of_ctyp ctyp) end | CT_list _ -> raise (Reporting.err_todo ctx.pragma_l "Lists not yet supported in SMT generation") + | CT_fvector _ -> + Reporting.unreachable ctx.pragma_l __POS__ "Found CT_fvector in SMT property" | CT_poly -> Reporting.unreachable ctx.pragma_l __POS__ "Found polymorphic type in SMT property" @@ -188,21 +209,17 @@ let rec smt_ctyp ctx = function don't have a very good way to get the binary representation of either an ocaml integer or a big integer. *) let bvpint sz x = + let open Sail2_values in if Big_int.less_equal Big_int.zero x && Big_int.less_equal x (Big_int.of_int max_int) then ( - let open Sail_lib in let x = Big_int.to_int x in - if sz mod 4 = 0 then - let hex = Printf.sprintf "%X" x in - let padding = String.make (sz / 4 - String.length hex) '0' in - Hex (padding ^ hex) - else - let bin = Printf.sprintf "%X" x |> list_of_string |> List.map hex_char |> List.concat in - let _, bin = Util.take_drop (function B0 -> true | B1 -> false) bin in - let bin = String.concat "" (List.map string_of_bit bin) in - let padding = String.make (sz - String.length bin) '0' in - Bin (padding ^ bin) + match Printf.sprintf "%X" x |> Util.string_to_list |> List.map nibble_of_char |> Util.option_all with + | Some nibbles -> + let bin = List.map (fun (a, b, c, d) -> [a; b; c; d]) nibbles |> List.concat in + let _, bin = Util.take_drop (function B0 -> true | _ -> false) bin in + let padding = List.init (sz - List.length bin) (fun _ -> B0) in + Bitvec_lit (padding @ bin) + | None -> assert false ) else if Big_int.greater x (Big_int.of_int max_int) then ( - let open Sail_lib in let y = ref x in let bin = ref [] in while (not (Big_int.equal !y Big_int.zero)) do @@ -210,14 +227,13 @@ let bvpint sz x = bin := (if Big_int.equal m Big_int.zero then B0 else B1) :: !bin; y := q done; - let bin = String.concat "" (List.map string_of_bit !bin) in - let padding_size = sz - String.length bin in + let padding_size = sz - List.length !bin in if padding_size < 0 then raise (Reporting.err_general Parse_ast.Unknown (Printf.sprintf "Could not create a %d-bit integer with value %s.\nTry increasing the maximum integer size" sz (Big_int.to_string x))); - let padding = String.make (sz - String.length bin) '0' in - Bin (padding ^ bin) + let padding = List.init padding_size (fun _ -> B0) in + Bitvec_lit (padding @ !bin) ) else failwith "Invalid bvpint" let bvint sz x = @@ -226,22 +242,68 @@ let bvint sz x = else bvpint sz x +(** [force_size ctx n m exp] takes a smt expression assumed to be a + integer (signed bitvector) of length m and forces it to be length n + by either sign extending it or truncating it as required *) +let force_size ?checked:(checked=true) ctx n m smt = + if n = m then + smt + else if n > m then + SignExtend (n - m, smt) + else + let check = + (* If the top bit of the truncated number is one *) + Ite (Fn ("=", [Extract (n - 1, n - 1, smt); Bitvec_lit [Sail2_values.B1]]), + (* Then we have an overflow, unless all bits we truncated were also one *) + Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvones (m - n)])]), + (* Otherwise, all the top bits must be zero *) + Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvzero (m - n)])])) + in + if checked then overflow_check ctx check else (); + Extract (n - 1, 0, smt) + +(** [unsigned_size ctx n m exp] is much like force_size, but it + assumes that the bitvector is unsigned *) +let unsigned_size ?checked:(checked=true) ctx n m smt = + if n = m then + smt + else if n > m then + Fn ("concat", [bvzero (n - m); smt]) + else + Extract (n - 1, 0, smt) + +let smt_conversion ctx from_ctyp to_ctyp x = + match from_ctyp, to_ctyp with + | _, _ when ctyp_equal from_ctyp to_ctyp -> x + | CT_constant c, CT_fint sz -> + bvint sz c + | CT_constant c, CT_lint -> + bvint ctx.lint_size c + | CT_fint sz, CT_lint -> + force_size ctx ctx.lint_size sz x + | CT_lint, CT_fint sz -> + force_size ctx sz ctx.lint_size x + | CT_lbits _, CT_fbits (n, _) -> + unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [x])) + | CT_fbits (n, _), CT_lbits _ -> + Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int n); unsigned_size ctx (lbits_size ctx) n x]) + + | _, _ -> failwith (Printf.sprintf "Cannot perform conversion from %s to %s" (string_of_ctyp from_ctyp) (string_of_ctyp to_ctyp)) + (* Translate Jib literals into SMT *) -let smt_value ctx vl ctyp = +let rec smt_value ctx vl ctyp = let open Value2 in match vl, ctyp with - | VL_bits (bs, true), CT_fbits (n, _) -> - (* FIXME: Output the correct number of bits in Jib_compile *) - begin match Sail2_values.hexstring_of_bits (List.rev (Util.take n (List.rev bs))) with - | Some s -> Hex (Xstring.implode s) - | None -> Bin (Xstring.implode (List.map Sail2_values.bitU_char (List.rev (Util.take n (List.rev bs))))) - end + | VL_bits (bv, true), CT_fbits (n, _) -> + unsigned_size ctx n (List.length bv) (Bitvec_lit bv) + | VL_bits (bv, true), CT_lbits _ -> + let sz = List.length bv in + Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int sz); unsigned_size ctx (lbits_size ctx) sz (Bitvec_lit bv)]) | VL_bool b, _ -> Bool_lit b | VL_int n, CT_constant m -> bvint (required_width n) n | VL_int n, CT_fint sz -> bvint sz n | VL_int n, CT_lint -> bvint ctx.lint_size n - | VL_bit Sail2_values.B0, CT_bit -> Bin "0" - | VL_bit Sail2_values.B1, CT_bit -> Bin "1" + | VL_bit b, CT_bit -> Bitvec_lit [b] | VL_unit, _ -> Enum "unit" | VL_string str, _ -> ctx.use_string := true; @@ -252,7 +314,21 @@ let smt_value ctx vl ctyp = Fn ("-", [Real_lit (String.sub str 1 (String.length str - 1))]) else Real_lit str - | vl, _ -> failwith ("Cannot translate literal to SMT: " ^ string_of_value vl) + | VL_enum str, _ -> Enum (Util.zencode_string str) + | VL_ref reg_name, _ -> + let id = mk_id reg_name in + let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in + assert (CTMap.cardinal rmap = 1); + begin match CTMap.min_binding_opt rmap with + | Some (ctyp, regs) -> + begin match Util.list_index (fun reg -> Id.compare reg id = 0) regs with + | Some i -> + bvint (required_width (Big_int.of_int (List.length regs))) (Big_int.of_int i) + | None -> assert false + end + | _ -> assert false + end + | _ -> failwith ("Cannot translate literal to SMT: " ^ string_of_value vl ^ " : " ^ string_of_ctyp ctyp) let rec smt_cval ctx cval = match cval_ctyp cval with @@ -264,6 +340,7 @@ let rec smt_cval ctx cval = | V_id (Name (id, _) as ssa_id, _) -> begin match Type_check.Env.lookup_id id ctx.tc_env with | Enum _ -> Enum (zencode_id id) + | _ when Bindings.mem id ctx.shared -> Shared (zencode_id id) | _ -> Var (zencode_name ssa_id) end | V_id (ssa_id, _) -> Var (zencode_name ssa_id) @@ -271,8 +348,6 @@ let rec smt_cval ctx cval = Fn ("not", [Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2])]) | V_call (Bvor, [cval1; cval2]) -> Fn ("bvor", [smt_cval ctx cval1; smt_cval ctx cval2]) - | V_call (Bit_to_bool, [cval]) -> - Fn ("=", [smt_cval ctx cval; Bin "1"]) | V_call (Eq, [cval1; cval2]) -> Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2]) | V_call (Bnot, [cval]) -> @@ -281,14 +356,18 @@ let rec smt_cval ctx cval = smt_conj (List.map (smt_cval ctx) cvals) | V_call (Bor, cvals) -> smt_disj (List.map (smt_cval ctx) cvals) + | V_call (Igt, [cval1; cval2]) -> + Fn ("bvsgt", [smt_cval ctx cval1; smt_cval ctx cval2]) + | V_call (Iadd, [cval1; cval2]) -> + Fn ("bvadd", [smt_cval ctx cval1; smt_cval ctx cval2]) | V_ctor_kind (union, ctor_id, unifiers, _) -> Fn ("not", [Tester (zencode_uid (ctor_id, unifiers), smt_cval ctx union)]) | V_ctor_unwrap (ctor_id, union, unifiers, _) -> Fn ("un" ^ zencode_uid (ctor_id, unifiers), [smt_cval ctx union]) - | V_field (union, field) -> - begin match cval_ctyp union with + | V_field (record, field) -> + begin match cval_ctyp record with | CT_struct (struct_id, _) -> - Fn (zencode_upper_id struct_id ^ "_" ^ zencode_uid field, [smt_cval ctx union]) + Field (zencode_upper_id struct_id ^ "_" ^ zencode_uid field, smt_cval ctx record) | _ -> failwith "Field for non-struct type" end | V_struct (fields, ctyp) -> @@ -297,43 +376,18 @@ let rec smt_cval ctx cval = let set_field (field, cval) = match Util.assoc_compare_opt UId.compare field field_ctyps with | None -> failwith "Field type not found" - | Some ctyp when ctyp_equal (cval_ctyp cval) ctyp -> - smt_cval ctx cval - | _ -> failwith "Type mismatch when generating struct for SMT" + | Some ctyp -> + zencode_upper_id struct_id ^ "_" ^ zencode_uid field, + smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval) in - Fn (zencode_upper_id struct_id, List.map set_field fields) + Struct (zencode_upper_id struct_id, List.map set_field fields) | _ -> failwith "Struct does not have struct type" end | V_tuple_member (frag, len, n) -> ctx.tuple_sizes := IntSet.add len !(ctx.tuple_sizes); Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval ctx frag]) - | V_ref (Name (id, _), _) -> - let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in - assert (CTMap.cardinal rmap = 1); - begin match CTMap.min_binding_opt rmap with - | Some (ctyp, regs) -> - begin match Util.list_index (fun reg -> Id.compare reg id = 0) regs with - | Some i -> - bvint (required_width (Big_int.of_int (List.length regs))) (Big_int.of_int i) - | None -> assert false - end - | _ -> assert false - end | cval -> failwith ("Unrecognised cval " ^ string_of_cval cval) -let add_event ctx ev smt = - let stack = event_stack ctx ev in - Stack.push (Fn ("=>", [Lazy.force ctx.pathcond; smt])) stack - -let add_pathcond_event ctx ev = - Stack.push (Lazy.force ctx.pathcond) (event_stack ctx ev) - -let overflow_check ctx smt = - if not !opt_ignore_overflow then ( - Reporting.warn "Overflow check in generated SMT for" ctx.pragma_l ""; - add_event ctx Overflow smt - ) - (**************************************************************************) (* 1. Generating SMT for Sail builtins *) (**************************************************************************) @@ -342,8 +396,8 @@ let builtin_type_error ctx fn cvals = let args = Util.string_of_list ", " (fun cval -> string_of_ctyp (cval_ctyp cval)) cvals in function | Some ret_ctyp -> - raise (Reporting.err_todo ctx.pragma_l - (Printf.sprintf "%s : (%s) -> %s" fn args (string_of_ctyp ret_ctyp))) + let message = Printf.sprintf "%s : (%s) -> %s" fn args (string_of_ctyp ret_ctyp) in + raise (Reporting.err_todo ctx.pragma_l message) | None -> raise (Reporting.err_todo ctx.pragma_l (Printf.sprintf "%s : (%s)" fn args)) @@ -385,36 +439,6 @@ let builtin_gteq = builtin_int_comparison "bvsge" Big_int.greater_equal (* ***** Arithmetic operations: lib/arith.sail ***** *) -(** [force_size ctx n m exp] takes a smt expression assumed to be a - integer (signed bitvector) of length m and forces it to be length n - by either sign extending it or truncating it as required *) -let force_size ?checked:(checked=true) ctx n m smt = - if n = m then - smt - else if n > m then - SignExtend (n - m, smt) - else - let check = - (* If the top bit of the truncated number is one *) - Ite (Fn ("=", [Extract (n - 1, n - 1, smt); Bin "1"]), - (* Then we have an overflow, unless all bits we truncated were also one *) - Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvones (m - n)])]), - (* Otherwise, all the top bits must be zero *) - Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvzero (m - n)])])) - in - if checked then overflow_check ctx check else (); - Extract (n - 1, 0, smt) - -(** [unsigned_size ctx n m exp] is much like force_size, but it - assumes that the bitvector is unsigned *) -let unsigned_size ?checked:(checked=true) ctx n m smt = - if n = m then - smt - else if n > m then - Fn ("concat", [bvzero (n - m); smt]) - else - Extract (n - 1, 0, smt) - let int_size ctx = function | CT_constant n -> required_width n | CT_fint sz -> sz @@ -457,8 +481,9 @@ let builtin_negate_int ctx v ret_ctyp = | CT_constant c, _ -> bvint (int_size ctx ret_ctyp) (Big_int.negate c) | ctyp, _ -> + let open Sail2_values in let smt = force_size ctx (int_size ctx ret_ctyp) (int_size ctx ctyp) (smt_cval ctx v) in - overflow_check ctx (Fn ("=", [smt; Bin ("1" ^ String.make (int_size ctx ret_ctyp - 1) '0')])); + overflow_check ctx (Fn ("=", [smt; Bitvec_lit (B1 :: List.init (int_size ctx ret_ctyp - 1) (fun _ -> B0))])); Fn ("bvneg", [smt]) let builtin_shift_int fn big_int_fn ctx v1 v2 ret_ctyp = @@ -494,7 +519,7 @@ let builtin_abs_int ctx v ret_ctyp = | ctyp, _ -> let sz = int_size ctx ctyp in let smt = smt_cval ctx v in - Ite (Fn ("=", [Extract (sz - 1, sz -1, smt); Bin "1"]), + Ite (Fn ("=", [Extract (sz - 1, sz -1, smt); Bitvec_lit [Sail2_values.B1]]), force_size ctx (int_size ctx ret_ctyp) sz (Fn ("bvneg", [smt])), force_size ctx (int_size ctx ret_ctyp) sz smt) @@ -531,6 +556,25 @@ let builtin_min_int ctx v1 v2 ret_ctyp = smt1, smt2) +let builtin_min_int ctx v1 v2 ret_ctyp = + match cval_ctyp v1, cval_ctyp v2 with + | CT_constant n, CT_constant m -> + bvint (int_size ctx ret_ctyp) (min n m) + + | ctyp1, ctyp2 -> + let ret_sz = int_size ctx ret_ctyp in + let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in + let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in + Ite (Fn ("bvslt", [smt1; smt2]), + smt1, + smt2) + +let builtin_tdiv_int = + builtin_arith "bvudiv" (Sail2_values.tdiv_int) (fun x -> x) + +let builtin_tmod_int = + builtin_arith "bvurem" (Sail2_values.tmod_int) (fun x -> x) + let bvmask ctx len = let all_ones = bvones (lbits_size ctx) in let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); len]) in @@ -623,7 +667,7 @@ let builtin_sign_extend ctx vbits vlen ret_ctyp = smt_cval ctx vbits | CT_fbits (n, _), CT_fbits (m, _) -> let bv = smt_cval ctx vbits in - let top_bit_one = Fn ("=", [Extract (n - 1, n - 1, bv); Bin "1"]) in + let top_bit_one = Fn ("=", [Extract (n - 1, n - 1, bv); Bitvec_lit [Sail2_values.B1]]) in Ite (top_bit_one, Fn ("concat", [bvones (m - n); bv]), Fn ("concat", [bvzero (m - n); bv])) | _ -> builtin_type_error ctx "sign_extend" [vbits; vlen] (Some ret_ctyp) @@ -658,14 +702,14 @@ let builtin_not_bits ctx v ret_ctyp = | _, _ -> builtin_type_error ctx "not_bits" [v] (Some ret_ctyp) let builtin_bitwise fn ctx v1 v2 ret_ctyp = - match cval_ctyp v1, cval_ctyp v2 with - | CT_fbits (n, _), CT_fbits (m, _) -> - assert (n = m); + match cval_ctyp v1, cval_ctyp v2, ret_ctyp with + | CT_fbits (n, _), CT_fbits (m, _), CT_fbits (o, _) -> + assert (n = m && m = o); let smt1 = smt_cval ctx v1 in let smt2 = smt_cval ctx v2 in Fn (fn, [smt1; smt2]) - | CT_lbits _, CT_lbits _ -> + | CT_lbits _, CT_lbits _, CT_lbits _ -> let smt1 = smt_cval ctx v1 in let smt2 = smt_cval ctx v2 in Fn ("Bits", [Fn ("len", [smt1]); Fn (fn, [Fn ("contents", [smt1]); Fn ("contents", [smt2])])]) @@ -674,6 +718,7 @@ let builtin_bitwise fn ctx v1 v2 ret_ctyp = let builtin_and_bits = builtin_bitwise "bvand" let builtin_or_bits = builtin_bitwise "bvor" +let builtin_xor_bits = builtin_bitwise "bvxor" let builtin_append ctx v1 v2 ret_ctyp = match cval_ctyp v1, cval_ctyp v2, ret_ctyp with @@ -743,19 +788,29 @@ let builtin_length ctx v ret_ctyp = | _, _ -> builtin_type_error ctx "length" [v] (Some ret_ctyp) let builtin_vector_subrange ctx vec i j ret_ctyp = - match cval_ctyp vec, cval_ctyp i, cval_ctyp j with - | CT_fbits (n, _), CT_constant i, CT_constant j -> + match cval_ctyp vec, cval_ctyp i, cval_ctyp j, ret_ctyp with + | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits _ -> Extract (Big_int.to_int i, Big_int.to_int j, smt_cval ctx vec) - | CT_lbits _, CT_constant i, CT_constant j -> + | CT_lbits _, CT_constant i, CT_constant j, CT_fbits _ -> Extract (Big_int.to_int i, Big_int.to_int j, Fn ("contents", [smt_cval ctx vec])) + | CT_fbits (n, _), i_ctyp, CT_constant j, CT_lbits _ when Big_int.equal j Big_int.zero -> + let len = force_size ~checked:false ctx ctx.lbits_index (int_size ctx i_ctyp) (smt_cval ctx i) in + Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; unsigned_size ctx (lbits_size ctx) n (smt_cval ctx vec)])]) + | _ -> builtin_type_error ctx "vector_subrange" [vec; i; j] (Some ret_ctyp) let builtin_vector_access ctx vec i ret_ctyp = match cval_ctyp vec, cval_ctyp i, ret_ctyp with | CT_fbits (n, _), CT_constant i, CT_bit -> Extract (Big_int.to_int i, Big_int.to_int i, smt_cval ctx vec) + | CT_lbits _, CT_constant i, CT_bit -> + Extract (Big_int.to_int i, Big_int.to_int i, Fn ("contents", [smt_cval ctx vec])) + + | CT_lbits _, i_ctyp, CT_bit -> + let shift = force_size ~checked:false ctx (lbits_size ctx) (int_size ctx i_ctyp) (smt_cval ctx i) in + Extract (0, 0, Fn ("bvlshr", [Fn ("contents", [smt_cval ctx vec]); shift])) | CT_vector _, CT_constant i, _ -> Fn ("select", [smt_cval ctx vec; bvint !vector_index i]) @@ -787,6 +842,21 @@ let builtin_vector_update ctx vec i x ret_ctyp = | _ -> builtin_type_error ctx "vector_update" [vec; i; x] (Some ret_ctyp) +let builtin_vector_update_subrange ctx vec i j x ret_ctyp = + match cval_ctyp vec, cval_ctyp i, cval_ctyp j, cval_ctyp x, ret_ctyp with + | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits (sz, _), CT_fbits (m, _) when n - 1 > Big_int.to_int i && Big_int.to_int j >= 0 -> + assert (n = m); + let top = Extract (n - 1, Big_int.to_int i + 1, smt_cval ctx vec) in + let bot = Extract (Big_int.to_int j - 1, 0, smt_cval ctx vec) in + Fn ("concat", [top; Fn ("concat", [smt_cval ctx x; bot])]) + + | CT_fbits (n, _), CT_constant i, CT_constant j, CT_fbits (sz, _), CT_fbits (m, _) when n - 1 = Big_int.to_int i && Big_int.to_int j >= 0 -> + assert (n = m); + let bot = Extract (Big_int.to_int j - 1, 0, smt_cval ctx vec) in + Fn ("concat", [smt_cval ctx x; bot]) + + | _ -> builtin_type_error ctx "vector_update_subrange" [vec; i; j; x] (Some ret_ctyp) + let builtin_unsigned ctx v ret_ctyp = match cval_ctyp v, ret_ctyp with | CT_fbits (n, _), CT_fint m when m > n -> @@ -800,6 +870,9 @@ let builtin_unsigned ctx v ret_ctyp = let smt = smt_cval ctx v in Fn ("concat", [bvzero (ctx.lint_size - n); smt]) + | CT_lbits _, CT_lint -> + Extract (ctx.lint_size - 1, 0, Fn ("contents", [smt_cval ctx v])) + | ctyp, _ -> builtin_type_error ctx "unsigned" [v] (Some ret_ctyp) let builtin_signed ctx v ret_ctyp = @@ -810,6 +883,9 @@ let builtin_signed ctx v ret_ctyp = | CT_fbits (n, _), CT_lint -> SignExtend(ctx.lint_size - n, smt_cval ctx v) + | CT_lbits _, CT_lint -> + Extract (ctx.lint_size - 1, 0, Fn ("contents", [smt_cval ctx v])) + | ctyp, _ -> builtin_type_error ctx "signed" [v] (Some ret_ctyp) let builtin_add_bits ctx v1 v2 ret_ctyp = @@ -818,6 +894,11 @@ let builtin_add_bits ctx v1 v2 ret_ctyp = assert (n = m && m = o); Fn ("bvadd", [smt_cval ctx v1; smt_cval ctx v2]) + | CT_lbits _, CT_lbits _, CT_lbits _ -> + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + Fn ("Bits", [Fn ("len", [smt1]); Fn ("bvadd", [Fn ("contents", [smt1]); Fn ("contents", [smt2])])]) + | _ -> builtin_type_error ctx "add_bits" [v1; v2] (Some ret_ctyp) let builtin_sub_bits ctx v1 v2 ret_ctyp = @@ -866,6 +947,13 @@ let builtin_replicate_bits ctx v1 v2 ret_ctyp = let c = m / n in Fn ("concat", List.init c (fun _ -> smt)) + | CT_fbits (n, _), v2_ctyp, CT_lbits _ -> + let times = (lbits_size ctx / n) + 1 in + let len = force_size ~checked:false ctx ctx.lbits_index (int_size ctx v2_ctyp) (smt_cval ctx v2) in + let smt1 = smt_cval ctx v1 in + let contents = Extract (lbits_size ctx - 1, 0, Fn ("concat", List.init times (fun _ -> smt1))) in + Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; contents])]) + | _ -> builtin_type_error ctx "replicate_bits" [v1; v2] (Some ret_ctyp) let builtin_sail_truncate ctx v1 v2 ret_ctyp = @@ -928,13 +1016,18 @@ let builtin_get_slice_int ctx v1 v2 v3 ret_ctyp = in Extract ((start + len) - 1, start, smt) + | CT_lint, CT_lint, CT_constant start, CT_lbits _ when Big_int.equal start Big_int.zero -> + let len = Extract (ctx.lbits_index - 1, 0, smt_cval ctx v1) in + let contents = unsigned_size ~checked:false ctx (lbits_size ctx) ctx.lint_size (smt_cval ctx v2) in + Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; contents])]) + | _ -> builtin_type_error ctx "get_slice_int" [v1; v2; v3] (Some ret_ctyp) let builtin_count_leading_zeros ctx v ret_ctyp = let ret_sz = int_size ctx ret_ctyp in let rec lzcnt sz smt = if sz == 1 then - Ite (Fn ("=", [Extract (0, 0, smt); Bin "0"]), + Ite (Fn ("=", [Extract (0, 0, smt); Bitvec_lit [Sail2_values.B0]]), bvint ret_sz (Big_int.of_int 1), bvint ret_sz (Big_int.zero)) else ( @@ -1050,6 +1143,8 @@ let smt_builtin ctx name args ret_ctyp = | "max_int", [v1; v2], _ -> builtin_max_int ctx v1 v2 ret_ctyp | "min_int", [v1; v2], _ -> builtin_min_int ctx v1 v2 ret_ctyp + | "ediv_int", [v1; v2], _ -> builtin_tdiv_int ctx v1 v2 ret_ctyp + (* All signed and unsigned bitvector comparisons *) | "slt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvslt" ctx v1 v2 ret_ctyp | "ult_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvult" ctx v1 v2 ret_ctyp @@ -1072,8 +1167,9 @@ let smt_builtin ctx name args ret_ctyp = | "sail_truncateLSB", [v1; v2], _ -> builtin_sail_truncateLSB ctx v1 v2 ret_ctyp | "shiftl", [v1; v2], _ -> builtin_shift "bvshl" ctx v1 v2 ret_ctyp | "shiftr", [v1; v2], _ -> builtin_shift "bvlshr" ctx v1 v2 ret_ctyp - | "or_bits", [v1; v2], _ -> builtin_or_bits ctx v1 v2 ret_ctyp | "and_bits", [v1; v2], _ -> builtin_and_bits ctx v1 v2 ret_ctyp + | "or_bits", [v1; v2], _ -> builtin_or_bits ctx v1 v2 ret_ctyp + | "xor_bits", [v1; v2], _ -> builtin_xor_bits ctx v1 v2 ret_ctyp | "not_bits", [v], _ -> builtin_not_bits ctx v ret_ctyp | "add_bits", [v1; v2], _ -> builtin_add_bits ctx v1 v2 ret_ctyp | "add_bits_int", [v1; v2], _ -> builtin_add_bits_int ctx v1 v2 ret_ctyp @@ -1084,6 +1180,7 @@ let smt_builtin ctx name args ret_ctyp = | "vector_access", [v1; v2], ret_ctyp -> builtin_vector_access ctx v1 v2 ret_ctyp | "vector_subrange", [v1; v2; v3], ret_ctyp -> builtin_vector_subrange ctx v1 v2 v3 ret_ctyp | "vector_update", [v1; v2; v3], ret_ctyp -> builtin_vector_update ctx v1 v2 v3 ret_ctyp + | "vector_update_subrange", [v1; v2; v3; v4], ret_ctyp -> builtin_vector_update_subrange ctx v1 v2 v3 v4 ret_ctyp | "sail_unsigned", [v], ret_ctyp -> builtin_unsigned ctx v ret_ctyp | "sail_signed", [v], ret_ctyp -> builtin_signed ctx v ret_ctyp | "replicate_bits", [v1; v2], ret_ctyp -> builtin_replicate_bits ctx v1 v2 ret_ctyp @@ -1110,16 +1207,30 @@ let smt_builtin ctx name args ret_ctyp = | "lteq_real", [v1; v2], CT_bool -> ctx.use_real := true; Fn ("<=", [smt_cval ctx v1; smt_cval ctx v2]) | "gteq_real", [v1; v2], CT_bool -> ctx.use_real := true; Fn (">=", [smt_cval ctx v1; smt_cval ctx v2]) - | _ -> failwith ("Unknown builtin " ^ name ^ " " ^ Util.string_of_list ", " string_of_ctyp (List.map cval_ctyp args) ^ " -> " ^ string_of_ctyp ret_ctyp) + | _ -> + Reporting.unreachable ctx.pragma_l __POS__ ("Unknown builtin " ^ name ^ " " ^ Util.string_of_list ", " string_of_ctyp (List.map cval_ctyp args) ^ " -> " ^ string_of_ctyp ret_ctyp) + +let loc_doc = function + | Parse_ast.Documented (str, l) -> str + | _ -> "UNKNOWN" (* Memory reads and writes as defined in lib/regfp.sail *) let writes = ref (-1) -let builtin_write_mem ctx wk addr_size addr data_size data = +let builtin_write_mem l ctx wk addr_size addr data_size data = incr writes; let name = "W" ^ string_of_int !writes in - [Write_mem (name, ctx.node, Lazy.force ctx.pathcond, smt_cval ctx wk, - smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr), smt_cval ctx data, smt_ctyp ctx (cval_ctyp data))], + [Write_mem { + name = name; + node = ctx.node; + active = Lazy.force ctx.pathcond; + kind = smt_cval ctx wk; + addr = smt_cval ctx addr; + addr_type = smt_ctyp ctx (cval_ctyp addr); + data = smt_cval ctx data; + data_type = smt_ctyp ctx (cval_ctyp data); + doc = loc_doc l + }], Var (name ^ "_ret") let ea_writes = ref (-1) @@ -1133,11 +1244,19 @@ let builtin_write_mem_ea ctx wk addr_size addr data_size = let reads = ref (-1) -let builtin_read_mem ctx rk addr_size addr data_size ret_ctyp = +let builtin_read_mem l ctx rk addr_size addr data_size ret_ctyp = incr reads; let name = "R" ^ string_of_int !reads in - [Read_mem (name, ctx.node, Lazy.force ctx.pathcond, smt_ctyp ctx ret_ctyp, smt_cval ctx rk, - smt_cval ctx addr, smt_ctyp ctx (cval_ctyp addr))], + [Read_mem { + name = name; + node = ctx.node; + active = Lazy.force ctx.pathcond; + ret_type = smt_ctyp ctx ret_ctyp; + kind = smt_cval ctx rk; + addr = smt_cval ctx addr; + addr_type = smt_ctyp ctx (cval_ctyp addr); + doc = loc_doc l + }], Read_res name let excl_results = ref (-1) @@ -1150,26 +1269,51 @@ let builtin_excl_res ctx = let barriers = ref (-1) -let builtin_barrier ctx bk = +let builtin_barrier l ctx bk = incr barriers; let name = "B" ^ string_of_int !barriers in - [Barrier (name, ctx.node, Lazy.force ctx.pathcond, smt_cval ctx bk)], + [Barrier { + name = name; + node = ctx.node; + active = Lazy.force ctx.pathcond; + kind = smt_cval ctx bk; + doc = loc_doc l + }], Enum "unit" -let rec smt_conversion ctx from_ctyp to_ctyp x = - match from_ctyp, to_ctyp with - | _, _ when ctyp_equal from_ctyp to_ctyp -> x - | CT_constant c, CT_fint sz -> - bvint sz c - | CT_constant c, CT_lint -> - bvint ctx.lint_size c - | CT_fint sz, CT_lint -> - force_size ctx ctx.lint_size sz x - | CT_lbits _, CT_fbits (n, _) -> - unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [x])) - | _, _ -> failwith (Printf.sprintf "Cannot perform conversion from %s to %s" (string_of_ctyp from_ctyp) (string_of_ctyp to_ctyp)) +let cache_maintenances = ref (-1) + +let builtin_cache_maintenance l ctx cmk addr_size addr = + incr cache_maintenances; + let name = "M" ^ string_of_int !cache_maintenances in + [Cache_maintenance { + name = name; + node = ctx.node; + active = Lazy.force ctx.pathcond; + kind = smt_cval ctx cmk; + addr = smt_cval ctx addr; + addr_type = smt_ctyp ctx (cval_ctyp addr); + doc = loc_doc l + }], + Enum "unit" + +let branch_announces = ref (-1) + +let builtin_branch_announce l ctx addr_size addr = + incr branch_announces; + let name = "C" ^ string_of_int !branch_announces in + [Branch_announce { + name = name; + node = ctx.node; + active = Lazy.force ctx.pathcond; + addr = smt_cval ctx addr; + addr_type = smt_ctyp ctx (cval_ctyp addr); + doc = loc_doc l + }], + Enum "unit" let define_const ctx id ctyp exp = Define_const (zencode_name id, smt_ctyp ctx ctyp, exp) +let preserve_const ctx id ctyp exp = Preserve_const (string_of_id id, smt_ctyp ctx ctyp, exp) let declare_const ctx id ctyp = Declare_const (zencode_name id, smt_ctyp ctx ctyp) let smt_ctype_def ctx = function @@ -1205,143 +1349,144 @@ let rec generate_reg_decs ctx inits = function let max_int n = Big_int.pred (Big_int.pow_int_positive 2 (n - 1)) let min_int n = Big_int.negate (Big_int.pow_int_positive 2 (n - 1)) -(** Convert a sail type into a C-type. This function can be quite - slow, because it uses ctx.local_env and SMT to analyse the Sail - types and attempts to fit them into the smallest possible C - types, provided ctx.optimize_smt is true (default) **) -let rec ctyp_of_typ ctx typ = - let open Ast in - let open Type_check in - let open Jib_compile in - let Typ_aux (typ_aux, l) as typ = Env.expand_synonyms ctx.tc_env typ in - match typ_aux with - | Typ_id id when string_of_id id = "bit" -> CT_bit - | Typ_id id when string_of_id id = "bool" -> CT_bool - | Typ_id id when string_of_id id = "int" -> CT_lint - | Typ_id id when string_of_id id = "nat" -> CT_lint - | Typ_id id when string_of_id id = "unit" -> CT_unit - | Typ_id id when string_of_id id = "string" -> CT_string - | Typ_id id when string_of_id id = "real" -> CT_real - - | Typ_app (id, _) when string_of_id id = "atom_bool" -> CT_bool - - | Typ_app (id, args) when string_of_id id = "itself" -> - ctyp_of_typ ctx (Typ_aux (Typ_app (mk_id "atom", args), l)) - | Typ_app (id, _) when string_of_id id = "range" || string_of_id id = "atom" || string_of_id id = "implicit" -> - begin match destruct_range Env.empty typ with - | None -> assert false (* Checked if range type in guard *) - | Some (kids, constr, n, m) -> - let ctx = { ctx with local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env } in - match nexp_simp n, nexp_simp m with - | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) - when n = m -> - CT_constant n - | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) - when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> - CT_fint 64 - | n, m -> - if prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) then +module SMT_config : Jib_compile.Config = struct + open Jib_compile + + (** Convert a sail type into a C-type. This function can be quite + slow, because it uses ctx.local_env and SMT to analyse the Sail + types and attempts to fit them into the smallest possible C + types, provided ctx.optimize_smt is true (default) **) + let rec convert_typ ctx typ = + let open Ast in + let open Type_check in + let Typ_aux (typ_aux, l) as typ = Env.expand_synonyms ctx.tc_env typ in + match typ_aux with + | Typ_id id when string_of_id id = "bit" -> CT_bit + | Typ_id id when string_of_id id = "bool" -> CT_bool + | Typ_id id when string_of_id id = "int" -> CT_lint + | Typ_id id when string_of_id id = "nat" -> CT_lint + | Typ_id id when string_of_id id = "unit" -> CT_unit + | Typ_id id when string_of_id id = "string" -> CT_string + | Typ_id id when string_of_id id = "real" -> CT_real + + | Typ_app (id, _) when string_of_id id = "atom_bool" -> CT_bool + + | Typ_app (id, args) when string_of_id id = "itself" -> + convert_typ ctx (Typ_aux (Typ_app (mk_id "atom", args), l)) + | Typ_app (id, _) when string_of_id id = "range" || string_of_id id = "atom" || string_of_id id = "implicit" -> + begin match destruct_range Env.empty typ with + | None -> assert false (* Checked if range type in guard *) + | Some (kids, constr, n, m) -> + let ctx = { ctx with local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env } in + match nexp_simp n, nexp_simp m with + | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) + when n = m -> + CT_constant n + | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) + when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> CT_fint 64 - else - CT_lint - end - - | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> - CT_list (ctyp_of_typ ctx typ) - - (* Note that we have to use lbits for zero-length bitvectors because they are not allowed by SMTLIB *) - | Typ_app (id, [A_aux (A_nexp n, _); A_aux (A_order ord, _)]) - when string_of_id id = "bitvector" -> - let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in - begin match nexp_simp n with - | Nexp_aux (Nexp_constant n, _) when Big_int.equal n Big_int.zero -> CT_lbits direction - | Nexp_aux (Nexp_constant n, _) -> CT_fbits (Big_int.to_int n, direction) - | _ -> CT_lbits direction - end - - | Typ_app (id, [A_aux (A_nexp n, _); - A_aux (A_order ord, _); - A_aux (A_typ typ, _)]) - when string_of_id id = "vector" -> - let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in - CT_vector (direction, ctyp_of_typ ctx typ) - - | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> - CT_ref (ctyp_of_typ ctx typ) - - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> UBindings.bindings) - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> UBindings.bindings) - | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) - - | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs) - - | Typ_exist _ -> - (* Use Type_check.destruct_exist when optimising with SMT, to - ensure that we don't cause any type variable clashes in - local_env, and that we can optimize the existential based upon - it's constraints. *) - begin match destruct_exist (Env.expand_synonyms ctx.local_env typ) with - | Some (kids, nc, typ) -> - let env = add_existential l kids nc ctx.local_env in - ctyp_of_typ { ctx with local_env = env } typ - | None -> raise (Reporting.err_unreachable l __POS__ "Existential cannot be destructured!") - end - - | Typ_var kid -> CT_poly - - | _ -> raise (Reporting.err_unreachable l __POS__ ("No SMT type for type " ^ string_of_typ typ)) + | n, m -> + if prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) then + CT_fint 64 + else + CT_lint + end -(**************************************************************************) -(* 3. Optimization of primitives and literals *) -(**************************************************************************) + | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> + CT_list (convert_typ ctx typ) + + (* Note that we have to use lbits for zero-length bitvectors because they are not allowed by SMTLIB *) + | Typ_app (id, [A_aux (A_nexp n, _); A_aux (A_order ord, _)]) + when string_of_id id = "bitvector" -> + let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in + begin match nexp_simp n with + | Nexp_aux (Nexp_constant n, _) when Big_int.equal n Big_int.zero -> CT_lbits direction + | Nexp_aux (Nexp_constant n, _) -> CT_fbits (Big_int.to_int n, direction) + | _ -> CT_lbits direction + end -let hex_char = - let open Sail2_values in - function - | '0' -> [B0; B0; B0; B0] - | '1' -> [B0; B0; B0; B1] - | '2' -> [B0; B0; B1; B0] - | '3' -> [B0; B0; B1; B1] - | '4' -> [B0; B1; B0; B0] - | '5' -> [B0; B1; B0; B1] - | '6' -> [B0; B1; B1; B0] - | '7' -> [B0; B1; B1; B1] - | '8' -> [B1; B0; B0; B0] - | '9' -> [B1; B0; B0; B1] - | 'A' | 'a' -> [B1; B0; B1; B0] - | 'B' | 'b' -> [B1; B0; B1; B1] - | 'C' | 'c' -> [B1; B1; B0; B0] - | 'D' | 'd' -> [B1; B1; B0; B1] - | 'E' | 'e' -> [B1; B1; B1; B0] - | 'F' | 'f' -> [B1; B1; B1; B1] - | _ -> failwith "Invalid hex character" - -let literal_to_cval (L_aux (l_aux, _) as lit) = - match l_aux with - | L_num n -> Some (V_lit (VL_int n, CT_constant n)) - | L_hex str when String.length str <= 16 -> - let content = Util.string_to_list str |> List.map hex_char |> List.concat in - Some (V_lit (VL_bits (content, true), CT_fbits (String.length str * 4, true))) - | L_unit -> Some (V_lit (VL_unit, CT_unit)) - | L_true -> Some (V_lit (VL_bool true, CT_bool)) - | L_false -> Some (V_lit (VL_bool false, CT_bool)) - | _ -> None - -let c_literals ctx = - let rec c_literal env l = function - | AV_lit (lit, typ) as v -> - begin match literal_to_cval lit with - | Some cval -> AV_cval (cval, typ) - | None -> v + | Typ_app (id, [A_aux (A_nexp n, _); + A_aux (A_order ord, _); + A_aux (A_typ typ, _)]) + when string_of_id id = "vector" -> + let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in + CT_vector (direction, convert_typ ctx typ) + + | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> + CT_ref (convert_typ ctx typ) + + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> UBindings.bindings) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> UBindings.bindings) + | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) + + | Typ_tup typs -> CT_tup (List.map (convert_typ ctx) typs) + + | Typ_exist _ -> + (* Use Type_check.destruct_exist when optimising with SMT, to + ensure that we don't cause any type variable clashes in + local_env, and that we can optimize the existential based + upon it's constraints. *) + begin match destruct_exist (Env.expand_synonyms ctx.local_env typ) with + | Some (kids, nc, typ) -> + let env = add_existential l kids nc ctx.local_env in + convert_typ { ctx with local_env = env } typ + | None -> raise (Reporting.err_unreachable l __POS__ "Existential cannot be destructured!") end - | AV_tuple avals -> AV_tuple (List.map (c_literal env l) avals) - | v -> v - in - map_aval c_literal -let unroll_foreach ctx = function + | Typ_var kid -> CT_poly + + | _ -> raise (Reporting.err_unreachable l __POS__ ("No SMT type for type " ^ string_of_typ typ)) + + let hex_char = + let open Sail2_values in + function + | '0' -> [B0; B0; B0; B0] + | '1' -> [B0; B0; B0; B1] + | '2' -> [B0; B0; B1; B0] + | '3' -> [B0; B0; B1; B1] + | '4' -> [B0; B1; B0; B0] + | '5' -> [B0; B1; B0; B1] + | '6' -> [B0; B1; B1; B0] + | '7' -> [B0; B1; B1; B1] + | '8' -> [B1; B0; B0; B0] + | '9' -> [B1; B0; B0; B1] + | 'A' | 'a' -> [B1; B0; B1; B0] + | 'B' | 'b' -> [B1; B0; B1; B1] + | 'C' | 'c' -> [B1; B1; B0; B0] + | 'D' | 'd' -> [B1; B1; B0; B1] + | 'E' | 'e' -> [B1; B1; B1; B0] + | 'F' | 'f' -> [B1; B1; B1; B1] + | _ -> failwith "Invalid hex character" + + let literal_to_cval (L_aux (l_aux, _) as lit) = + match l_aux with + | L_num n -> Some (V_lit (VL_int n, CT_constant n)) + | L_hex str when String.length str <= 16 -> + let content = Util.string_to_list str |> List.map hex_char |> List.concat in + Some (V_lit (VL_bits (content, true), CT_fbits (String.length str * 4, true))) + | L_unit -> Some (V_lit (VL_unit, CT_unit)) + | L_true -> Some (V_lit (VL_bool true, CT_bool)) + | L_false -> Some (V_lit (VL_bool false, CT_bool)) + | _ -> None + + let c_literals ctx = + let rec c_literal env l = function + | AV_lit (lit, typ) as v -> + begin match literal_to_cval lit with + | Some cval -> AV_cval (cval, typ) + | None -> v + end + | AV_tuple avals -> AV_tuple (List.map (c_literal env l) avals) + | v -> v + in + map_aval c_literal + +(* If we know the loop variables exactly (especially after + specialization), we can unroll the exact number of times required, + and omit any comparisons. *) +let unroll_static_foreach ctx = function | AE_aux (AE_for (id, from_aexp, to_aexp, by_aexp, order, body), env, l) as aexp -> - begin match ctyp_of_typ ctx (aexp_typ from_aexp), ctyp_of_typ ctx (aexp_typ to_aexp), ctyp_of_typ ctx (aexp_typ by_aexp), order with + begin match convert_typ ctx (aexp_typ from_aexp), convert_typ ctx (aexp_typ to_aexp), convert_typ ctx (aexp_typ by_aexp), order with | CT_constant f, CT_constant t, CT_constant b, Ord_aux (Ord_inc, _) -> let i = ref f in let unrolled = ref [] in @@ -1360,6 +1505,19 @@ let unroll_foreach ctx = function end | aexp -> aexp + let optimize_anf ctx aexp = + aexp + |> c_literals ctx + |> fold_aexp (unroll_static_foreach ctx) + + let specialize_calls = true + let ignore_64 = true + let unroll_loops () = Some !opt_unroll_limit + let struct_value = true + let use_real = true +end + + (**************************************************************************) (* 3. Generating SMT *) (**************************************************************************) @@ -1414,7 +1572,7 @@ let smt_ssanode ctx cfg preds = pis ids None in match mux with - | None -> [] + | None -> assert false | Some mux -> [Define_const (zencode_name id, smt_ctyp ctx ctyp, mux)] @@ -1492,7 +1650,7 @@ let rec rmw_write = function | CL_id _ -> assert false | CL_tuple (clexp, _) -> rmw_write clexp | CL_field (clexp, _) -> rmw_write clexp - | clexp -> assert false + | clexp -> failwith "Could not understand l-expression" let rmw_read = function | CL_rmw (read, _, _) -> zencode_name read @@ -1522,7 +1680,7 @@ let rmw_modify smt = function if UId.compare field field' = 0 then smt else - Fn (zencode_upper_id struct_id ^ "_" ^ zencode_uid field', [Var (rmw_read clexp)]) + Field (zencode_upper_id struct_id ^ "_" ^ zencode_uid field', Var (rmw_read clexp)) in Fn (zencode_upper_id struct_id, List.map set_field fields) | _ -> @@ -1564,7 +1722,7 @@ let smt_instr ctx = else if name = "platform_write_mem" then begin match args with | [wk; addr_size; addr; data_size; data] -> - let mem_event, var = builtin_write_mem ctx wk addr_size addr data_size data in + let mem_event, var = builtin_write_mem l ctx wk addr_size addr data_size data in mem_event @ [define_const ctx id ret_ctyp var] | _ -> Reporting.unreachable l __POS__ "Bad arguments for __write_mem" @@ -1580,7 +1738,7 @@ let smt_instr ctx = else if name = "platform_read_mem" then begin match args with | [rk; addr_size; addr; data_size] -> - let mem_event, var = builtin_read_mem ctx rk addr_size addr data_size ret_ctyp in + let mem_event, var = builtin_read_mem l ctx rk addr_size addr data_size ret_ctyp in mem_event @ [define_const ctx id ret_ctyp var] | _ -> Reporting.unreachable l __POS__ "Bad arguments for __read_mem" @@ -1588,7 +1746,23 @@ let smt_instr ctx = else if name = "platform_barrier" then begin match args with | [bk] -> - let mem_event, var = builtin_barrier ctx bk in + let mem_event, var = builtin_barrier l ctx bk in + mem_event @ [define_const ctx id ret_ctyp var] + | _ -> + Reporting.unreachable l __POS__ "Bad arguments for __barrier" + end + else if name = "platform_cache_maintenance" then + begin match args with + | [cmk; addr_size; addr] -> + let mem_event, var = builtin_cache_maintenance l ctx cmk addr_size addr in + mem_event @ [define_const ctx id ret_ctyp var] + | _ -> + Reporting.unreachable l __POS__ "Bad arguments for __barrier" + end + else if name = "platform_branch_announce" then + begin match args with + | [addr_size; addr] -> + let mem_event, var = builtin_branch_announce l ctx addr_size addr in mem_event @ [define_const ctx id ret_ctyp var] | _ -> Reporting.unreachable l __POS__ "Bad arguments for __barrier" @@ -1601,9 +1775,20 @@ let smt_instr ctx = | _ -> Reporting.unreachable l __POS__ "Bad arguments for __excl_res" end + else if name = "sail_exit" then + (add_event ctx Assertion (Bool_lit false); []) + else if name = "sail_assert" then + begin match args with + | [assertion; _] -> + let smt = smt_cval ctx assertion in + add_event ctx Assertion (Fn ("not", [smt])); + [] + | _ -> + Reporting.unreachable l __POS__ "Bad arguments for assertion" + end else let value = smt_builtin ctx name args ret_ctyp in - [define_const ctx id ret_ctyp value] + [define_const ctx id ret_ctyp (Syntactic (value, List.map (smt_cval ctx) args))] else if extern && string_of_id (fst function_id) = "internal_vector_init" then [declare_const ctx id ret_ctyp] else if extern && string_of_id (fst function_id) = "internal_vector_update" then @@ -1615,15 +1800,6 @@ let smt_instr ctx = | _ -> Reporting.unreachable l __POS__ "Bad arguments for internal_vector_update" end - else if string_of_id (fst function_id) = "sail_assert" then - begin match args with - | [assertion; _] -> - let smt = smt_cval ctx assertion in - add_event ctx Assertion smt; - [] - | _ -> - Reporting.unreachable l __POS__ "Bad arguments for assertion" - end else if string_of_id (fst function_id) = "sail_assume" then begin match args with | [assumption] -> @@ -1643,8 +1819,14 @@ let smt_instr ctx = Reporting.unreachable l __POS__ "Register reference write should be re-written by now" | I_aux (I_init (ctyp, id, cval), _) | I_aux (I_copy (CL_id (id, ctyp), cval), _) -> - [define_const ctx id ctyp - (smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval))] + begin match id with + | Name (id, _) when IdSet.mem id ctx.preserved -> + [preserve_const ctx id ctyp + (smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval))] + | _ -> + [define_const ctx id ctyp + (smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval))] + end | I_aux (I_copy (clexp, cval), _) -> let smt = smt_cval ctx cval in @@ -1721,13 +1903,19 @@ module Make_optimizer(S : Sequence) = struct | Some n -> Hashtbl.replace uses var (n + 1) | None -> Hashtbl.add uses var 1 end - | Enum _ | Read_res _ | Hex _ | Bin _ | Bool_lit _ | String_lit _ | Real_lit _ -> () + | Syntactic (exp, _) -> uses_in_exp exp + | Shared _ | Enum _ | Read_res _ | Bitvec_lit _ | Bool_lit _ | String_lit _ | Real_lit _ -> () | Fn (_, exps) | Ctor (_, exps) -> List.iter uses_in_exp exps + | Field (_, exp) -> + uses_in_exp exp + | Struct (_, fields) -> + List.iter (fun (_, exp) -> uses_in_exp exp) fields | Ite (cond, t, e) -> uses_in_exp cond; uses_in_exp t; uses_in_exp e | Extract (_, _, exp) | Tester (_, exp) | SignExtend (_, exp) -> uses_in_exp exp + | Forall _ -> assert false in let remove_unused () = function @@ -1737,6 +1925,11 @@ module Make_optimizer(S : Sequence) = struct | Some _ -> Stack.push def stack' end + | Declare_fun _ as def -> + Stack.push def stack' + | Preserve_const (_, _, exp) as def -> + uses_in_exp exp; + Stack.push def stack' | Define_const (var, _, exp) as def -> begin match Hashtbl.find_opt uses var with | None -> () @@ -1746,17 +1939,23 @@ module Make_optimizer(S : Sequence) = struct end | (Declare_datatypes _ | Declare_tuple _) as def -> Stack.push def stack' - | Write_mem (_, _, active, wk, addr, _, data, _) as def -> - uses_in_exp active; uses_in_exp wk; uses_in_exp addr; uses_in_exp data; + | Write_mem w as def -> + uses_in_exp w.active; uses_in_exp w.kind; uses_in_exp w.addr; uses_in_exp w.data; Stack.push def stack' | Write_mem_ea (_, _, active, wk, addr, _, data_size, _) as def -> uses_in_exp active; uses_in_exp wk; uses_in_exp addr; uses_in_exp data_size; Stack.push def stack' - | Read_mem (_, _, active, _, rk, addr, _) as def -> - uses_in_exp active; uses_in_exp rk; uses_in_exp addr; + | Read_mem r as def -> + uses_in_exp r.active; uses_in_exp r.kind; uses_in_exp r.addr; + Stack.push def stack' + | Barrier b as def -> + uses_in_exp b.active; uses_in_exp b.kind; Stack.push def stack' - | Barrier (_, _, active, bk) as def -> - uses_in_exp active; uses_in_exp bk; + | Cache_maintenance m as def -> + uses_in_exp m.active; uses_in_exp m.kind; uses_in_exp m.addr; + Stack.push def stack' + | Branch_announce c as def -> + uses_in_exp c.active; uses_in_exp c.addr; Stack.push def stack' | Excl_res (_, _, active) as def -> uses_in_exp active; @@ -1775,10 +1974,14 @@ module Make_optimizer(S : Sequence) = struct let constant_propagate = function | Declare_const _ as def -> S.add def seq + | Declare_fun _ as def -> + S.add def seq + | Preserve_const (var, typ, exp) -> + S.add (Preserve_const (var, typ, simp_smt_exp vars kinds exp)) seq | Define_const (var, typ, exp) -> let exp = simp_smt_exp vars kinds exp in begin match Hashtbl.find_opt uses var, simp_smt_exp vars kinds exp with - | _, (Bin _ | Bool_lit _) -> + | _, (Bitvec_lit _ | Bool_lit _) -> Hashtbl.add vars var exp | _, Var _ when !opt_propagate_vars -> Hashtbl.add vars var exp @@ -1791,20 +1994,30 @@ module Make_optimizer(S : Sequence) = struct S.add (Define_const (var, typ, exp)) seq | None, _ -> assert false end - | Write_mem (name, node, active, wk, addr, addr_ty, data, data_ty) -> - S.add (Write_mem (name, node, simp_smt_exp vars kinds active, simp_smt_exp vars kinds wk, - simp_smt_exp vars kinds addr, addr_ty, simp_smt_exp vars kinds data, data_ty)) + | Write_mem w -> + S.add (Write_mem { w with active = simp_smt_exp vars kinds w.active; + kind = simp_smt_exp vars kinds w.kind; + addr = simp_smt_exp vars kinds w.addr; + data = simp_smt_exp vars kinds w.data }) seq | Write_mem_ea (name, node, active, wk, addr, addr_ty, data_size, data_size_ty) -> S.add (Write_mem_ea (name, node, simp_smt_exp vars kinds active, simp_smt_exp vars kinds wk, simp_smt_exp vars kinds addr, addr_ty, simp_smt_exp vars kinds data_size, data_size_ty)) seq - | Read_mem (name, node, active, typ, rk, addr, addr_typ) -> - S.add (Read_mem (name, node, simp_smt_exp vars kinds active, typ, simp_smt_exp vars kinds rk, - simp_smt_exp vars kinds addr, addr_typ)) + | Read_mem r -> + S.add (Read_mem { r with active = simp_smt_exp vars kinds r.active; + kind = simp_smt_exp vars kinds r.kind; + addr = simp_smt_exp vars kinds r.addr }) seq - | Barrier (name, node, active, bk) -> - S.add (Barrier (name, node, simp_smt_exp vars kinds active, simp_smt_exp vars kinds bk)) seq + | Barrier b -> + S.add (Barrier { b with active = simp_smt_exp vars kinds b.active; kind = simp_smt_exp vars kinds b.kind }) seq + | Cache_maintenance m -> + S.add (Cache_maintenance { m with active = simp_smt_exp vars kinds m.active; + kind = simp_smt_exp vars kinds m.kind; + addr = simp_smt_exp vars kinds m.addr }) + seq + | Branch_announce c -> + S.add (Branch_announce { c with active = simp_smt_exp vars kinds c.active; addr = simp_smt_exp vars kinds c.addr }) seq | Excl_res (name, node, active) -> S.add (Excl_res (name, node, simp_smt_exp vars kinds active)) seq | Assert exp -> @@ -1843,6 +2056,26 @@ let smt_header ctx cdefs = register if it is. We also do a similar thing for *r = x *) let expand_reg_deref env register_map = function + | I_aux (I_funcall (CL_addr (CL_id (id, ctyp)), false, function_id, args), (_, l)) -> + begin match ctyp with + | CT_ref reg_ctyp -> + begin match CTMap.find_opt reg_ctyp register_map with + | Some regs -> + let end_label = label "end_reg_write_" in + let try_reg r = + let next_label = label "next_reg_write_" in + [ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); V_id (id, ctyp)])) next_label; + ifuncall (CL_id (name r, reg_ctyp)) function_id args; + igoto end_label; + ilabel next_label] + in + iblock (List.concat (List.map try_reg regs) @ [ilabel end_label]) + | None -> + raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) + end + | _ -> + raise (Reporting.err_general l "Register reference assignment must take a register reference as an argument") + end | I_aux (I_funcall (clexp, false, function_id, [reg_ref]), (_, l)) as instr -> let open Type_check in begin match (if Env.is_extern (fst function_id) env "smt" then Some (Env.get_extern (fst function_id) env "smt") else None) with @@ -1855,7 +2088,7 @@ let expand_reg_deref env register_map = function let end_label = label "end_reg_deref_" in let try_reg r = let next_label = label "next_reg_deref_" in - [ijump (V_call (Neq, [V_ref (name r, reg_ctyp); reg_ref])) next_label; + [ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); reg_ref])) next_label; icopy l clexp (V_id (name r, reg_ctyp)); igoto end_label; ilabel next_label] @@ -1877,7 +2110,7 @@ let expand_reg_deref env register_map = function let end_label = label "end_reg_write_" in let try_reg r = let next_label = label "next_reg_write_" in - [ijump (V_call (Neq, [V_ref (name r, reg_ctyp); V_id (id, ctyp)])) next_label; + [ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); V_id (id, ctyp)])) next_label; icopy l (CL_id (name r, reg_ctyp)) cval; igoto end_label; ilabel next_label] @@ -1927,7 +2160,7 @@ let smt_instr_list name ctx all_cdefs instrs = dump_graph name cfg; List.iter (fun n -> - begin match get_vertex cfg n with + match get_vertex cfg n with | None -> () | Some ((ssa_elems, cfnode), preds, succs) -> let muxers = @@ -1937,13 +2170,12 @@ let smt_instr_list name ctx all_cdefs instrs = let basic_block = smt_cfnode all_cdefs ctx ssa_elems cfnode in push_smt_defs stack muxers; push_smt_defs stack basic_block - end ) visit_order; - stack, cfg + stack, start, cfg let smt_cdef props lets name_file ctx all_cdefs = function - | CDEF_spec (function_id, arg_ctyps, ret_ctyp) when Bindings.mem function_id props -> + | CDEF_spec (function_id, _, arg_ctyps, ret_ctyp) when Bindings.mem function_id props -> begin match find_function [] function_id all_cdefs with | intervening_lets, Some (None, args, instrs) -> let prop_type, prop_args, pragma_l, vs = Bindings.find function_id props in @@ -1967,7 +2199,7 @@ let smt_cdef props lets name_file ctx all_cdefs = function |> remove_pointless_goto in - let stack, _ = smt_instr_list (string_of_id function_id) ctx all_cdefs instrs in + let stack, _, _ = smt_instr_list (string_of_id function_id) ctx all_cdefs instrs in let query = smt_query ctx pragma.query in push_smt_defs stack [Assert (Fn ("not", [query]))]; @@ -2038,25 +2270,20 @@ let rec build_register_map rmap = function | [] -> rmap let compile env ast = - let cdefs = - let open Jib_compile in - let ctx = - initial_ctx - ~convert_typ:ctyp_of_typ - ~optimize_anf:(fun ctx aexp -> fold_aexp (unroll_foreach ctx) (c_literals ctx aexp)) - env - in + let cdefs, jib_ctx = + let module Jibc = Jib_compile.Make(SMT_config) in + let ctx = Jib_compile.(initial_ctx (add_special_functions env)) in let t = Profile.start () in - let cdefs, ctx = compile_ast { ctx with specialize_calls = true; ignore_64 = true; struct_value = true; use_real = true } ast in + let cdefs, ctx = Jibc.compile_ast ctx ast in Profile.finish "Compiling to Jib IR" t; - cdefs + cdefs, ctx in let cdefs = Jib_optimize.unique_per_function_ids cdefs in let rmap = build_register_map CTMap.empty cdefs in - cdefs, { (initial_ctx ()) with tc_env = env; register_map = rmap; ast = ast } + cdefs, jib_ctx, { (initial_ctx ()) with tc_env = jib_ctx.tc_env; register_map = rmap; ast = ast } let serialize_smt_model file env ast = - let cdefs, ctx = compile env ast in + let cdefs, _, ctx = compile env ast in let out_chan = open_out file in Marshal.to_channel out_chan cdefs []; Marshal.to_channel out_chan (Type_check.Env.set_prover None ctx.tc_env) []; @@ -2073,7 +2300,7 @@ let deserialize_smt_model file = let generate_smt props name_file env ast = try - let cdefs, ctx = compile env ast in + let cdefs, _, ctx = compile env ast in smt_cdefs props [] name_file ctx cdefs cdefs with | Type_check.Type_error (_, l, err) -> diff --git a/src/jib/jib_smt.mli b/src/jib/jib_smt.mli index cdaf7e39..616877e4 100644 --- a/src/jib/jib_smt.mli +++ b/src/jib/jib_smt.mli @@ -73,44 +73,57 @@ val opt_default_lbits_index : int ref val opt_default_vector_index : int ref type ctx = { - (** Arbitrary-precision bitvectors are represented as a (BitVec lbits_index, BitVec (2 ^ lbits_index)) pair. *) lbits_index : int; - (** The size we use for integers where we don't know how large they are statically. *) + (** Arbitrary-precision bitvectors are represented as a (BitVec lbits_index, BitVec (2 ^ lbits_index)) pair. *) lint_size : int; + (** The size we use for integers where we don't know how large they are statically. *) + vector_index : int; (** A generic vector, vector('a) becomes Array (BitVec vector_index) 'a. We need to take care that vector_index is large enough for all generic vectors. *) - vector_index : int; - (** A map from each ctyp to a list of registers of that ctyp *) register_map : id list CTMap.t; - (** A set to keep track of all the tuple sizes we need to generate types for *) + (** A map from each ctyp to a list of registers of that ctyp *) tuple_sizes : IntSet.t ref; - (** tc_env is the global type-checking environment *) + (** A set to keep track of all the tuple sizes we need to generate types for *) tc_env : Type_check.Env.t; + (** tc_env is the global type-checking environment *) + pragma_l : Ast.l; (** A location, usually the $counterexample or $property we are generating the SMT for. Used for error messages. *) - pragma_l : Ast.l; - (** Used internally to keep track of function argument names *) arg_stack : (int * string) Stack.t; - (** The fully type-checked ast *) + (** Used internally to keep track of function argument names *) ast : Type_check.tannot defs; + (** The fully type-checked ast *) + shared : ctyp Bindings.t; + (** Shared variables. These variables do not get renamed by + Smtlib.suffix_variables_def, and their SSA number is + omitted. They should therefore only ever be read and never + written. Used by sail-axiomatic for symbolic values in the + initial litmus state. *) + preserved : IdSet.t; + (** icopy instructions to an id in preserved will generated a + define-const (by using Smtlib.Preserved_const) that will not be + simplified away or renamed. It will also not get a SSA + number. Such variables can therefore only ever be written to + once, and never read. They are used by sail-axiomatic to + extract information from the generated SMT. *) + events : smt_exp Stack.t EventMap.t ref; (** For every event type we have a stack of boolean SMT expressions for each occurance of that event. See src/property.ml for the event types *) - events : smt_exp Stack.t EventMap.t ref; + node : int; + pathcond : smt_exp Lazy.t; (** When generating SMT for an instruction pathcond will contain the global path conditional of the containing block/node in the control flow graph *) - node : int; - pathcond : smt_exp Lazy.t; + use_string : bool ref; + use_real : bool ref (** Set if we need to use strings or real numbers in the generated SMT, which then requires set-logic ALL or similar depending on the solver *) - use_string : bool ref; - use_real : bool ref } (** Compile an AST into Jib suitable for SMT generation, and initialise a context. *) -val compile : Type_check.Env.t -> Type_check.tannot defs -> cdef list * ctx +val compile : Type_check.Env.t -> Type_check.tannot defs -> cdef list * Jib_compile.ctx * ctx (* TODO: Currently we internally use mutable stacks and queues to avoid any issues with stack overflows caused by some non @@ -122,7 +135,7 @@ val smt_header : ctx -> cdef list -> smt_def list val smt_query : ctx -> Property.query -> smt_exp -val smt_instr_list : string -> ctx -> cdef list -> instr list -> smt_def Stack.t * (ssa_elem list * cf_node) Jib_ssa.array_graph +val smt_instr_list : string -> ctx -> cdef list -> instr list -> smt_def Stack.t * int * (ssa_elem list * cf_node) Jib_ssa.array_graph module type Sequence = sig type 'a t diff --git a/src/jib/jib_smt_fuzz.ml b/src/jib/jib_smt_fuzz.ml index 846d0178..58665bde 100644 --- a/src/jib/jib_smt_fuzz.ml +++ b/src/jib/jib_smt_fuzz.ml @@ -152,13 +152,13 @@ let rec run frame = exception Skip_iteration of string;; let fuzz_cdef ctx all_cdefs = function - | CDEF_spec (id, arg_ctyps, ret_ctyp) when not (string_of_id id = "and_bool" || string_of_id id = "or_bool") -> + | CDEF_spec (id, _, arg_ctyps, ret_ctyp) when not (string_of_id id = "and_bool" || string_of_id id = "or_bool") -> let open Type_check in let open Interpreter in if Env.is_extern id ctx.tc_env "smt" then ( let extern = Env.get_extern id ctx.tc_env "smt" in let typq, (Typ_aux (aux, _) as typ) = Env.get_val_spec id ctx.tc_env in - let istate = initial_state ctx.ast ctx.tc_env Value.primops in + let istate = initial_state ctx.ast ctx.tc_env !Value.primops in let header = smt_header ctx all_cdefs in prerr_endline (Util.("Fuzz: " |> cyan |> clear) ^ string_of_id id ^ " = \"" ^ extern ^ "\" : " ^ string_of_typ typ); @@ -192,7 +192,7 @@ let fuzz_cdef ctx all_cdefs = function @ [iend ()] in let smt_defs = - try fst (smt_instr_list extern ctx all_cdefs jib) with + try (fun (x, _, _) -> x) (smt_instr_list extern ctx all_cdefs jib) with | _ -> raise (Skip_iteration ("SMT error for: " ^ Util.string_of_list ", " string_of_exp (List.map fst values))) in @@ -253,6 +253,6 @@ let fuzz_cdef ctx all_cdefs = function let fuzz seed env ast = Random.init seed; - let cdefs, ctx = compile env ast in + let cdefs, _, ctx = compile env ast in List.iter (fuzz_cdef ctx cdefs) cdefs diff --git a/src/jib/jib_ssa.ml b/src/jib/jib_ssa.ml index 9c405a48..fe3238a4 100644 --- a/src/jib/jib_ssa.ml +++ b/src/jib/jib_ssa.ml @@ -504,6 +504,7 @@ let rename_variables graph root children = | Name (id, _) -> Name (id, i) | Have_exception _ -> Have_exception i | Current_exception _ -> Current_exception i + | Throw_location _ -> Throw_location i | Return _ -> Return i in @@ -524,9 +525,6 @@ let rename_variables graph root children = | V_id (id, ctyp) -> let i = top_stack id in V_id (ssa_name i id, ctyp) - | V_ref (id, ctyp) -> - let i = top_stack id in - V_ref (ssa_name i id, ctyp) | V_lit (vl, ctyp) -> V_lit (vl, ctyp) | V_call (id, fs) -> V_call (id, List.map fold_cval fs) | V_field (f, field) -> V_field (fold_cval f, field) diff --git a/src/jib/jib_util.ml b/src/jib/jib_util.ml index 13438208..9b06c7be 100644 --- a/src/jib/jib_util.ml +++ b/src/jib/jib_util.ml @@ -83,7 +83,7 @@ let ireset ?loc:(l=Parse_ast.Unknown) ctyp id = let iinit ?loc:(l=Parse_ast.Unknown) ctyp id cval = I_aux (I_init (ctyp, id, cval), (instr_number (), l)) -let iif ?loc:(l=Parse_ast.Unknown) cval then_instrs else_instrs ctyp = +let iif l cval then_instrs else_instrs ctyp = I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (instr_number (), l)) let ifuncall ?loc:(l=Parse_ast.Unknown) clexp id cvals = @@ -113,7 +113,7 @@ let iblock ?loc:(l=Parse_ast.Unknown) instrs = let itry_block ?loc:(l=Parse_ast.Unknown) instrs = I_aux (I_try_block instrs, (instr_number (), l)) -let ithrow ?loc:(l=Parse_ast.Unknown) cval = +let ithrow l cval = I_aux (I_throw cval, (instr_number (), l)) let icomment ?loc:(l=Parse_ast.Unknown) str = @@ -134,7 +134,7 @@ let imatch_failure ?loc:(l=Parse_ast.Unknown) () = let iraw ?loc:(l=Parse_ast.Unknown) str = I_aux (I_raw str, (instr_number (), l)) -let ijump ?loc:(l=Parse_ast.Unknown) cval label = +let ijump l cval label = I_aux (I_jump (cval, label), (instr_number (), l)) module Name = struct @@ -153,6 +153,8 @@ module Name = struct | _, Have_exception _ -> -1 | Current_exception _, _ -> 1 | _, Current_exception _ -> -1 + | Throw_location _, _ -> 1 + | _, Throw_location _ -> -1 end module NameSet = Set.Make(Name) @@ -160,6 +162,7 @@ module NameMap = Map.Make(Name) let current_exception = Current_exception (-1) let have_exception = Have_exception (-1) +let throw_location = Throw_location (-1) let return = Return (-1) let name id = Name (id, -1) @@ -167,8 +170,6 @@ let name id = Name (id, -1) let rec cval_rename from_id to_id = function | V_id (id, ctyp) when Name.compare id from_id = 0 -> V_id (to_id, ctyp) | V_id (id, ctyp) -> V_id (id, ctyp) - | V_ref (id, ctyp) when Name.compare id from_id = 0 -> V_ref (to_id, ctyp) - | V_ref (id, ctyp) -> V_ref (id, ctyp) | V_lit (vl, ctyp) -> V_lit (vl, ctyp) | V_call (call, cvals) -> V_call (call, List.map (cval_rename from_id to_id) cvals) | V_field (f, field) -> V_field (cval_rename from_id to_id f, field) @@ -257,8 +258,7 @@ let rec instr_rename from_id to_id (I_aux (instr, aux)) = (* 1. Instruction pretty printer *) (**************************************************************************) - -let string_of_name ?deref_current_exception:(dce=true) ?zencode:(zencode=true) = +let string_of_name ?deref_current_exception:(dce=false) ?zencode:(zencode=true) = let ssa_num n = if n = -1 then "" else ("/" ^ string_of_int n) in function | Name (id, n) -> @@ -271,6 +271,8 @@ let string_of_name ?deref_current_exception:(dce=true) ?zencode:(zencode=true) = "(*current_exception)" ^ ssa_num n | Current_exception n -> "current_exception" ^ ssa_num n + | Throw_location n -> + "throw_location" ^ ssa_num n let string_of_op = function | Bnot -> "@not" @@ -278,7 +280,6 @@ let string_of_op = function | Bor -> "@or" | List_hd -> "@hd" | List_tl -> "@tl" - | Bit_to_bool -> "@bit_to_bool" | Eq -> "@eq" | Neq -> "@neq" | Bvnot -> "@bvnot" @@ -309,9 +310,9 @@ let string_of_op = function let rec string_of_ctyp = function | CT_lint -> "%i" | CT_fint n -> "%i" ^ string_of_int n - | CT_lbits _ -> "%lb" - | CT_sbits (n, _) -> "%sb" ^ string_of_int n - | CT_fbits (n, _) -> "%fb" ^ string_of_int n + | CT_lbits _ -> "%bv" + | CT_sbits (n, _) -> "%sbv" ^ string_of_int n + | CT_fbits (n, _) -> "%bv" ^ string_of_int n | CT_constant n -> Big_int.to_string n | CT_bit -> "%bit" | CT_unit -> "%unit" @@ -323,6 +324,7 @@ let rec string_of_ctyp = function | CT_enum (id, _) -> "%enum " ^ Util.zencode_string (string_of_id id) | CT_variant (id, _) -> "%union " ^ Util.zencode_string (string_of_id id) | CT_vector (_, ctyp) -> "%vec(" ^ string_of_ctyp ctyp ^ ")" + | CT_fvector (n, _, ctyp) -> "%fvec(" ^ string_of_int n ^ ", " ^ string_of_ctyp ctyp ^ ")" | CT_list ctyp -> "%list(" ^ string_of_ctyp ctyp ^ ")" | CT_ref ctyp -> "&(" ^ string_of_ctyp ctyp ^ ")" | CT_poly -> "*" @@ -352,24 +354,27 @@ and full_string_of_ctyp = function | CT_ref ctyp -> "ref(" ^ full_string_of_ctyp ctyp ^ ")" | ctyp -> string_of_ctyp ctyp -let string_of_value = function - | VL_bits ([], _) -> "empty" +let rec string_of_value = function + | VL_bits ([], _) -> "UINT64_C(0)" | VL_bits (bs, true) -> Sail2_values.show_bitlist bs | VL_bits (bs, false) -> Sail2_values.show_bitlist (List.rev bs) | VL_int i -> Big_int.to_string i | VL_bool true -> "true" | VL_bool false -> "false" - | VL_null -> "NULL" | VL_unit -> "()" | VL_bit Sail2_values.B0 -> "bitzero" | VL_bit Sail2_values.B1 -> "bitone" - | VL_bit Sail2_values.BU -> "bitundef" + | VL_bit Sail2_values.BU -> failwith "Undefined bit found in value" | VL_real str -> str | VL_string str -> "\"" ^ str ^ "\"" + | VL_empty_list -> "NULL" + | VL_enum element -> Util.zencode_string element + | VL_ref r -> "&" ^ Util.zencode_string r + | VL_undefined -> "undefined" let rec string_of_cval = function | V_id (id, ctyp) -> string_of_name id - | V_ref (id, _) -> "&" ^ string_of_name id + | V_lit (VL_undefined, ctyp) -> string_of_value VL_undefined ^ " : " ^ string_of_ctyp ctyp | V_lit (vl, ctyp) -> string_of_value vl | V_call (op, cvals) -> Printf.sprintf "%s(%s)" (string_of_op op) (Util.string_of_list ", " string_of_cval cvals) @@ -377,16 +382,10 @@ let rec string_of_cval = function Printf.sprintf "%s.%s" (string_of_cval f) (string_of_uid field) | V_tuple_member (f, _, n) -> Printf.sprintf "%s.ztup%d" (string_of_cval f) n - | V_ctor_kind (f, ctor, [], _) -> - string_of_cval f ^ " is " ^ Util.zencode_string (string_of_id ctor) | V_ctor_kind (f, ctor, unifiers, _) -> - string_of_cval f ^ " is " ^ Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers) - | V_ctor_unwrap (ctor, f, [], _) -> - Printf.sprintf "%s as %s" (string_of_cval f) (string_of_id ctor) + string_of_cval f ^ " is " ^ string_of_uid (ctor, unifiers) | V_ctor_unwrap (ctor, f, unifiers, _) -> - Printf.sprintf "%s as %s" - (string_of_cval f) - (Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers)) + string_of_cval f ^ " as " ^ string_of_uid (ctor, unifiers) | V_struct (fields, _) -> Printf.sprintf "{%s}" (Util.string_of_list ", " (fun (field, cval) -> string_of_uid field ^ " = " ^ string_of_cval cval) fields) @@ -398,6 +397,7 @@ let rec map_ctyp f = function | CT_tup ctyps -> f (CT_tup (List.map (map_ctyp f) ctyps)) | CT_ref ctyp -> f (CT_ref (map_ctyp f ctyp)) | CT_vector (direction, ctyp) -> f (CT_vector (direction, map_ctyp f ctyp)) + | CT_fvector (n, direction, ctyp) -> f (CT_fvector (n, direction, map_ctyp f ctyp)) | CT_list ctyp -> f (CT_list (map_ctyp f ctyp)) | CT_struct (id, ctors) -> f (CT_struct (id, List.map (fun ((id, ctyps), ctyp) -> (id, List.map (map_ctyp f) ctyps), map_ctyp f ctyp) ctors)) @@ -423,6 +423,7 @@ let rec ctyp_equal ctyp1 ctyp2 = | CT_string, CT_string -> true | CT_real, CT_real -> true | CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> d1 = d2 && ctyp_equal ctyp1 ctyp2 + | CT_fvector (n1, d1, ctyp1), CT_fvector (n2, d2, ctyp2) -> n1 = n2 && d1 = d2 && ctyp_equal ctyp1 ctyp2 | CT_list ctyp1, CT_list ctyp2 -> ctyp_equal ctyp1 ctyp2 | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_equal ctyp1 ctyp2 | CT_poly, CT_poly -> true @@ -492,6 +493,11 @@ let rec ctyp_compare ctyp1 ctyp2 = | CT_vector _, _ -> 1 | _, CT_vector _ -> -1 + | CT_fvector (n1, d1, ctyp1), CT_fvector (n2, d2, ctyp2) -> + lex_ord (compare n1 n2) (lex_ord (ctyp_compare ctyp1 ctyp2) (compare d1 d2)) + | CT_fvector _, _ -> 1 + | _, CT_fvector _ -> -1 + | ctyp1, ctyp2 -> String.compare (full_string_of_ctyp ctyp1) (full_string_of_ctyp ctyp2) module CT = struct @@ -564,6 +570,7 @@ let rec ctyp_suprema = function | CT_struct (id, ctors) -> CT_struct (id, ctors) | CT_variant (id, ctors) -> CT_variant (id, ctors) | CT_vector (d, ctyp) -> CT_vector (d, ctyp_suprema ctyp) + | CT_fvector (n, d, ctyp) -> CT_fvector (n, d, ctyp_suprema ctyp) | CT_list ctyp -> CT_list (ctyp_suprema ctyp) | CT_ref ctyp -> CT_ref (ctyp_suprema ctyp) | CT_poly -> CT_poly @@ -573,7 +580,7 @@ let rec ctyp_ids = function | CT_struct (id, ctors) | CT_variant (id, ctors) -> IdSet.add id (List.fold_left (fun ids (_, ctyp) -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctors) | CT_tup ctyps -> List.fold_left (fun ids ctyp -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctyps - | CT_vector (_, ctyp) | CT_list ctyp | CT_ref ctyp -> ctyp_ids ctyp + | CT_vector (_, ctyp) | CT_fvector (_, _, ctyp) | CT_list ctyp | CT_ref ctyp -> ctyp_ids ctyp | CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit | CT_bool | CT_real | CT_bit | CT_string | CT_poly -> IdSet.empty @@ -588,11 +595,11 @@ let rec is_polymorphic = function | CT_tup ctyps -> List.exists is_polymorphic ctyps | CT_enum _ -> false | CT_struct (_, ctors) | CT_variant (_, ctors) -> List.exists (fun (_, ctyp) -> is_polymorphic ctyp) ctors - | CT_vector (_, ctyp) | CT_list ctyp | CT_ref ctyp -> is_polymorphic ctyp + | CT_fvector (_, _, ctyp) | CT_vector (_, ctyp) | CT_list ctyp | CT_ref ctyp -> is_polymorphic ctyp | CT_poly -> true let rec cval_deps = function - | V_id (id, _) | V_ref (id, _) -> NameSet.singleton id + | V_id (id, _) -> NameSet.singleton id | V_lit _ -> NameSet.empty | V_field (cval, _) | V_poly (cval, _) | V_tuple_member (cval, _, _) -> cval_deps cval | V_call (_, cvals) -> List.fold_left NameSet.union NameSet.empty (List.map cval_deps cvals) @@ -666,7 +673,6 @@ let rec map_clexp_ctyp f = function let rec map_cval_ctyp f = function | V_id (id, ctyp) -> V_id (id, f ctyp) - | V_ref (id, ctyp) -> V_ref (id, f ctyp) | V_lit (vl, ctyp) -> V_lit (vl, f ctyp) | V_ctor_kind (cval, id, unifiers, ctyp) -> V_ctor_kind (map_cval_ctyp f cval, id, List.map f unifiers, f ctyp) @@ -734,7 +740,7 @@ let rec concatmap_instr f (I_aux (instr, aux)) = I_try_block (List.concat (List.map (concatmap_instr f) instrs)) in f (I_aux (instr, aux)) - + (** Iterate over each instruction within an instruction, bottom-up *) let rec iter_instr f (I_aux (instr, aux)) = match instr with @@ -754,7 +760,7 @@ let cdef_map_instr f = function | CDEF_fundef (id, heap_return, args, instrs) -> CDEF_fundef (id, heap_return, args, List.map (map_instr f) instrs) | CDEF_startup (id, instrs) -> CDEF_startup (id, List.map (map_instr f) instrs) | CDEF_finish (id, instrs) -> CDEF_finish (id, List.map (map_instr f) instrs) - | CDEF_spec (id, ctyps, ctyp) -> CDEF_spec (id, ctyps, ctyp) + | CDEF_spec (id, extern, ctyps, ctyp) -> CDEF_spec (id, extern, ctyps, ctyp) | CDEF_type tdef -> CDEF_type tdef (** Map over each instruction in a cdef using concatmap_instr *) @@ -769,7 +775,7 @@ let cdef_concatmap_instr f = function CDEF_startup (id, List.concat (List.map (concatmap_instr f) instrs)) | CDEF_finish (id, instrs) -> CDEF_finish (id, List.concat (List.map (concatmap_instr f) instrs)) - | CDEF_spec (id, ctyps, ctyp) -> CDEF_spec (id, ctyps, ctyp) + | CDEF_spec (id, extern, ctyps, ctyp) -> CDEF_spec (id, extern, ctyps, ctyp) | CDEF_type tdef -> CDEF_type tdef let ctype_def_map_ctyp f = function @@ -784,7 +790,7 @@ let cdef_map_ctyp f = function | CDEF_fundef (id, heap_return, args, instrs) -> CDEF_fundef (id, heap_return, args, List.map (map_instr_ctyp f) instrs) | CDEF_startup (id, instrs) -> CDEF_startup (id, List.map (map_instr_ctyp f) instrs) | CDEF_finish (id, instrs) -> CDEF_finish (id, List.map (map_instr_ctyp f) instrs) - | CDEF_spec (id, ctyps, ctyp) -> CDEF_spec (id, List.map f ctyps, f ctyp) + | CDEF_spec (id, extern, ctyps, ctyp) -> CDEF_spec (id, extern, List.map f ctyps, f ctyp) | CDEF_type tdef -> CDEF_type (ctype_def_map_ctyp f tdef) (* Map over all sequences of instructions contained within an instruction *) @@ -838,7 +844,6 @@ let label str = let rec infer_call op vs = match op, vs with - | Bit_to_bool, _ -> CT_bool | Bnot, _ -> CT_bool | Band, _ -> CT_bool | Bor, _ -> CT_bool @@ -900,7 +905,6 @@ let rec infer_call op vs = and cval_ctyp = function | V_id (_, ctyp) -> ctyp - | V_ref (_, ctyp) -> CT_ref ctyp | V_lit (vl, ctyp) -> ctyp | V_ctor_kind _ -> CT_bool | V_ctor_unwrap (ctor, cval, unifiers, ctyp) -> ctyp @@ -984,7 +988,7 @@ let ctype_def_ctyps = function let cdef_ctyps = function | CDEF_reg_dec (_, ctyp, instrs) -> CTSet.add ctyp (instrs_ctyps instrs) - | CDEF_spec (_, ctyps, ctyp) -> + | CDEF_spec (_, _, ctyps, ctyp) -> CTSet.add ctyp (List.fold_left (fun m ctyp -> CTSet.add ctyp m) CTSet.empty ctyps) | CDEF_fundef (_, _, _, instrs) | CDEF_startup (_, instrs) | CDEF_finish (_, instrs) -> instrs_ctyps instrs |
