| author | Alasdair Armstrong | 2019-05-03 17:28:30 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2019-05-03 17:28:30 +0100 |
| commit | f6ad93e7cbbb3e43b045ae3313e556ea70e54c8f (patch) | |
| tree | c6f1bc2e499046cb7e5c22f750e0e63162f6d253 | |
| parent | c7a3389c34eebac4fed7764f339f4cd1b2b204f7 (diff) | |
Jib: Fix optimizations for SMT IR changes
Fixes C backend optimizations that were disabled due to changes in the
IR while working on the SMT generation.
Also adds a -Oaarch64_fast option that represents any integer field within a
struct as an int64_t. This is safe for the ARM v8.5 spec and improves
performance significantly (it reduces Linux boot times by 4-5 minutes).
Eventually this should probably become a directive that can be attached to an
arbitrary struct/type.
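To make the effect concrete, here is a rough sketch (in C, not taken from this patch) of the difference for a record containing an unbounded integer field. The struct and field names are invented for illustration, and the assumption that sail_int is the GMP-backed representation reflects the default lib/sail.h configuration:

```c
#include <stdint.h>
#include <gmp.h>

/* Illustration only: neither struct below is actual Sail-generated code.
 * By default an unbounded Sail integer field in a record becomes a sail_int,
 * which is GMP-backed and needs heap allocation plus explicit create/kill
 * calls whenever the record is copied or destroyed. */
typedef mpz_t sail_int;

struct tlb_entry_default {
  sail_int vpn;    /* arbitrary-precision integer */
  uint64_t flags;
};

/* With -Oaarch64_fast the same field is compiled as a plain int64_t, so the
 * whole record can be copied by value with no allocation.  This is only
 * sound if the spec never stores a value outside the int64_t range in such
 * a field, which this commit asserts holds for the ARM v8.5 structs. */
struct tlb_entry_fast {
  int64_t vpn;
  uint64_t flags;
};
```

Because the flag rewrites every integer field of every struct, the help text added to src/sail.ml describes it as potentially unsound in general.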
Fixes the -c_specialize option for ARM v8.5. However, this gives only a very
small performance improvement for a very large increase in compilation time.
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | language/jib.ott | 42 |
| -rw-r--r-- | lib/sail.c | 51 |
| -rw-r--r-- | lib/sail.h | 5 |
| -rw-r--r-- | src/interpreter.ml | 2 |
| -rw-r--r-- | src/jib/c_backend.ml | 221 |
| -rw-r--r-- | src/jib/jib_compile.ml | 91 |
| -rw-r--r-- | src/jib/jib_compile.mli | 5 |
| -rw-r--r-- | src/jib/jib_util.ml | 37 |
| -rw-r--r-- | src/sail.ml | 3 |
| -rw-r--r-- | src/specialize.ml | 10 |
| -rw-r--r-- | src/value.ml | 2 |
| -rwxr-xr-x | test/c/run_tests.py | 1 |
12 files changed, 328 insertions, 142 deletions
diff --git a/language/jib.ott b/language/jib.ott index 4ab0e22e..058c50d2 100644 --- a/language/jib.ott +++ b/language/jib.ott @@ -57,25 +57,33 @@ name :: '' ::= | return nat :: :: return op :: '' ::= - | not :: :: bnot - | hd :: :: list_hd - | tl :: :: list_tl - | bit_to_bool :: :: bit_to_bool - | eq :: :: eq - | neq :: :: neq + | not :: :: bnot + | hd :: :: list_hd + | tl :: :: list_tl + | bit_to_bool :: :: bit_to_bool + | eq :: :: eq + | neq :: :: neq % Integer ops - | lt :: :: ilt - | lteq :: :: ilteq - | gt :: :: igt - | gteq :: :: igteq - | add :: :: iadd - | sub :: :: isub + | lt :: :: ilt + | lteq :: :: ilteq + | gt :: :: igt + | gteq :: :: igteq + | add :: :: iadd + | sub :: :: isub % Bitvector ops - | bvor :: :: bvor - | bvand :: :: bvand - | concat :: :: concat - | zero_extend nat :: :: zero_extend - | sign_extend nat :: :: sign_extend + | bvnot :: :: bvnot + | bvor :: :: bvor + | bvand :: :: bvand + | bvxor :: :: bvxor + | bvadd :: :: bvadd + | bvsub :: :: bvsub + | bvaccess :: :: bvaccess + | concat :: :: concat + | zero_extend nat :: :: zero_extend + | sign_extend nat :: :: sign_extend + | slice nat :: :: slice + | sslice nat :: :: sslice + | replicate nat :: :: replicate cval :: 'V_' ::= | name : ctyp :: :: id @@ -1017,15 +1017,11 @@ void slice(lbits *rop, const lbits op, const sail_int start_mpz, const sail_int } } -inline +__attribute__((target ("bmi2"))) sbits sslice(const fbits op, const mach_int start, const mach_int len) { sbits rop; -#ifdef INTRINSICS rop.bits = _bzhi_u64(op >> start, len); -#else - rop.bits = (op >> start) & safe_rshift(UINT64_MAX, 64 - len); -#endif rop.len = len; return rop; } @@ -1126,18 +1122,25 @@ void reverse_endianness(lbits *rop, const lbits op) } } -inline bool eq_sbits(const sbits op1, const sbits op2) { return op1.bits == op2.bits; } -inline bool neq_sbits(const sbits op1, const sbits op2) { return op1.bits != op2.bits; } +__attribute__((target ("bmi2"))) +sbits not_sbits(const sbits op) +{ + sbits rop; + rop.bits = (~op.bits) & _bzhi_u64(UINT64_MAX, op.len); + rop.len = op.len; + return rop; +} + sbits xor_sbits(const sbits op1, const sbits op2) { sbits rop; @@ -1146,6 +1149,40 @@ sbits xor_sbits(const sbits op1, const sbits op2) return rop; } +sbits or_sbits(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = op1.bits | op2.bits; + rop.len = op1.len; + return rop; +} + +sbits and_sbits(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = op1.bits & op2.bits; + rop.len = op1.len; + return rop; +} + +__attribute__((target ("bmi2"))) +sbits add_sbits(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = (op1.bits + op2.bits) & _bzhi_u64(UINT64_MAX, op1.len); + rop.len = op1.len; + return rop; +} + +__attribute__((target ("bmi2"))) +sbits sub_sbits(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = (op1.bits - op2.bits) & _bzhi_u64(UINT64_MAX, op1.len); + rop.len = op1.len; + return rop; +} + /* ***** Sail Reals ***** */ void CREATE(real)(real *rop) @@ -333,7 +333,12 @@ void reverse_endianness(lbits*, lbits); bool eq_sbits(const sbits op1, const sbits op2); bool neq_sbits(const sbits op1, const sbits op2); +sbits not_sbits(const sbits op); sbits xor_sbits(const sbits op1, const sbits op2); +sbits or_sbits(const sbits op1, const sbits op2); +sbits and_sbits(const sbits op1, const sbits op2); +sbits add_sbits(const sbits op1, const sbits op2); +sbits sub_sbits(const sbits op1, const sbits op2); /* ***** Sail reals ***** */ diff --git a/src/interpreter.ml b/src/interpreter.ml index c1f84ae2..572c0a18 
100644 --- a/src/interpreter.ml +++ b/src/interpreter.ml @@ -241,7 +241,7 @@ let put_local name v : unit monad = let get_global_letbinds () : (Type_check.tannot letbind) list monad = Yield (Get_global_letbinds (fun lbs -> Pure lbs)) - + let early_return v = Yield (Early_return v) let assertion_failed msg = Yield (Assertion_failed msg) diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml index 2545ed7a..94425384 100644 --- a/src/jib/c_backend.ml +++ b/src/jib/c_backend.ml @@ -127,7 +127,7 @@ let rec ctyp_of_typ ctx typ = begin match destruct_range Env.empty typ with | None -> assert false (* Checked if range type in guard *) | Some (kids, constr, n, m) -> - let ctx = { ctx with local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env } in + let ctx = { ctx with local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env }in match nexp_simp n, nexp_simp m with | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> @@ -266,18 +266,6 @@ let c_literals ctx = in map_aval c_literal -let mask m = - if Big_int.less_equal m (Big_int.of_int 64) then - let n = Big_int.to_int m in - if n = 0 then - "UINT64_C(0)" - else if n mod 4 = 0 then - "UINT64_C(0x" ^ String.make (16 - n / 4) '0' ^ String.make (n / 4) 'F' ^ ")" - else - "UINT64_C(" ^ String.make (64 - n) '0' ^ String.make n '1' ^ ")" - else - failwith "Tried to create a mask literal for a vector greater than 64 bits." - let rec is_bitvector = function | [] -> true | AV_lit (L_aux (L_zero, _), _) :: avals -> is_bitvector avals @@ -424,6 +412,9 @@ let analyze_primop' ctx id args typ = | "eq_int", [AV_cval (v1, _); AV_cval (v2, _)] -> AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) + | "eq_bit", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) + | "zeros", [_] -> begin match destruct_vector ctx.tc_env typ with | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) @@ -465,56 +456,70 @@ let analyze_primop' ctx id args typ = | _ -> no_change end - (* - | "add_bits", [AV_C_fragment (v1, _, CT_fbits (n, ord)); AV_C_fragment (v2, _, CT_fbits _)] - when n <= 63 -> - AE_val (AV_C_fragment (F_op (F_op (v1, "+", v2), "&", v_mask_lower n), typ, CT_fbits (n, ord))) + | "not_bits", [AV_cval (v, _)] -> + AE_val (AV_cval (V_call (Bvnot, [v]), typ)) - | "xor_bits", [AV_C_fragment (v1, _, (CT_fbits _ as ctyp)); AV_C_fragment (v2, _, CT_fbits _)] -> - AE_val (AV_C_fragment (F_op (v1, "^", v2), typ, ctyp)) - | "xor_bits", [AV_C_fragment (v1, _, (CT_sbits _ as ctyp)); AV_C_fragment (v2, _, CT_sbits _)] -> - AE_val (AV_C_fragment (F_call ("xor_sbits", [v1; v2]), typ, ctyp)) + | "add_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Bvadd, [v1; v2]), typ)) - | "or_bits", [AV_C_fragment (v1, _, (CT_fbits _ as ctyp)); AV_C_fragment (v2, _, CT_fbits _)] -> - AE_val (AV_C_fragment (F_op (v1, "|", v2), typ, ctyp)) + | "sub_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Bvsub, [v1; v2]), typ)) - | "and_bits", [AV_C_fragment (v1, _, (CT_fbits _ as ctyp)); AV_C_fragment (v2, _, CT_fbits _)] -> - AE_val (AV_C_fragment (F_op (v1, "&", v2), typ, ctyp)) + | "and_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Bvand, [v1; v2]), typ)) - | "not_bits", [AV_C_fragment (v, _, ctyp)] -> - begin match destruct_vector ctx.tc_env typ with - | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) - 
when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> - AE_val (AV_C_fragment (F_op (F_unary ("~", v), "&", v_mask_lower (Big_int.to_int n)), typ, ctyp)) - | _ -> no_change - end + | "or_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Bvor, [v1; v2]), typ)) - | "vector_subrange", [AV_C_fragment (vec, _, CT_fbits _); AV_C_fragment (f, _, _); AV_C_fragment (t, _, _)] - when is_fbits_typ ctx typ -> - let len = F_op (f, "-", F_op (t, "-", v_one)) in - AE_val (AV_C_fragment (F_op (F_call ("safe_rshift", [F_raw "UINT64_MAX"; F_op (v_int 64, "-", len)]), "&", F_op (vec, ">>", t)), - typ, - ctyp_of_typ ctx typ)) + | "xor_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_call (Bvxor, [v1; v2]), typ)) - | "vector_access", [AV_C_fragment (vec, _, CT_fbits _); AV_C_fragment (n, _, _)] -> - AE_val (AV_C_fragment (F_op (v_one, "&", F_op (vec, ">>", n)), typ, CT_bit)) + | "vector_subrange", [AV_cval (vec, _); AV_cval (f, _); AV_cval (t, _)] -> + begin match ctyp_of_typ ctx typ with + | CT_fbits (n, true) -> + AE_val (AV_cval (V_call (Slice n, [vec; t]), typ)) + | _ -> no_change + end - | "eq_bit", [AV_C_fragment (a, _, _); AV_C_fragment (b, _, _)] -> - AE_val (AV_C_fragment (F_op (a, "==", b), typ, CT_bool)) + | "slice", [AV_cval (vec, _); AV_cval (start, _); AV_cval (len, _)] -> + begin match ctyp_of_typ ctx typ with + | CT_fbits (n, _) -> + AE_val (AV_cval (V_call (Slice n, [vec; start]), typ)) + | CT_sbits (64, _) -> + AE_val (AV_cval (V_call (Sslice 64, [vec; start; len]), typ)) + | _ -> no_change + end - | "slice", [AV_C_fragment (vec, _, CT_fbits _); AV_C_fragment (start, _, _); AV_C_fragment (len, _, _)] - when is_fbits_typ ctx typ -> - AE_val (AV_C_fragment (F_op (F_call ("safe_rshift", [F_raw "UINT64_MAX"; F_op (v_int 64, "-", len)]), "&", F_op (vec, ">>", start)), - typ, - ctyp_of_typ ctx typ)) + | "vector_access", [AV_cval (vec, _); AV_cval (n, _)] -> + AE_val (AV_cval (V_call (Bvaccess, [vec; n]), typ)) - | "slice", [AV_C_fragment (vec, _, CT_fbits _); AV_C_fragment (start, _, _); AV_C_fragment (len, _, _)] - when is_sbits_typ ctx typ -> - AE_val (AV_C_fragment (F_call ("sslice", [vec; start; len]), typ, ctyp_of_typ ctx typ)) + | "add_int", [AV_cval (op1, _); AV_cval (op2, _)] -> + begin match destruct_range Env.empty typ with + | None -> no_change + | Some (kids, constr, n, m) -> + match nexp_simp n, nexp_simp m with + | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) + when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> + AE_val (AV_cval (V_call (Iadd, [op1; op2]), typ)) + | n, m when prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) -> + AE_val (AV_cval (V_call (Iadd, [op1; op2]), typ)) + | _ -> no_change + end - | "undefined_bit", _ -> - AE_val (AV_C_fragment (F_lit (V_bit Sail2_values.B0), typ, CT_bit)) + | "replicate_bits", [AV_cval (vec, vtyp); _] -> + begin match destruct_vector ctx.tc_env typ, destruct_vector ctx.tc_env vtyp with + | Some (Nexp_aux (Nexp_constant n, _), _, _), Some (Nexp_aux (Nexp_constant m, _), _, _) + when Big_int.less_equal n (Big_int.of_int 64) -> + let times = Big_int.div n m in + if Big_int.equal (Big_int.mul m times) n then + AE_val (AV_cval (V_call (Replicate (Big_int.to_int times), [vec]), typ)) + else + no_change + | _, _ -> + no_change + end + (* | "undefined_vector", [AV_C_fragment (len, _, _); _] -> begin match destruct_vector ctx.tc_env typ with | Some (Nexp_aux 
(Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) @@ -539,39 +544,23 @@ let analyze_primop' ctx id args typ = | _ -> no_change end - | "add_int", [AV_C_fragment (op1, _, _); AV_C_fragment (op2, _, _)] -> - begin match destruct_range Env.empty typ with - | None -> no_change - | Some (kids, constr, n, m) -> - match nexp_simp n, nexp_simp m with - | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) - when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> - AE_val (AV_C_fragment (F_op (op1, "+", op2), typ, CT_fint 64)) - | n, m when prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) -> - AE_val (AV_C_fragment (F_op (op1, "+", op2), typ, CT_fint 64)) - | _ -> no_change - end - | "neg_int", [AV_C_fragment (frag, _, _)] -> AE_val (AV_C_fragment (F_op (v_int 0, "-", frag), typ, CT_fint 64)) - | "replicate_bits", [AV_C_fragment (vec, vtyp, _); AV_C_fragment (times, _, _)] -> - begin match destruct_vector ctx.tc_env typ, destruct_vector ctx.tc_env vtyp with - | Some (Nexp_aux (Nexp_constant n, _), _, _), Some (Nexp_aux (Nexp_constant m, _), _, _) - when Big_int.less_equal n (Big_int.of_int 64) -> - AE_val (AV_C_fragment (F_call ("fast_replicate_bits", [F_lit (V_int m); vec; times]), typ, ctyp_of_typ ctx typ)) - | _ -> no_change - end - | "vector_update_subrange", [AV_C_fragment (xs, _, CT_fbits (n, true)); AV_C_fragment (hi, _, CT_fint 64); AV_C_fragment (lo, _, CT_fint 64); AV_C_fragment (ys, _, CT_fbits (m, true))] -> AE_val (AV_C_fragment (F_call ("fast_update_subrange", [xs; hi; lo; ys]), typ, CT_fbits (n, true))) + *) + + | "undefined_bit", _ -> + AE_val (AV_cval (V_lit (VL_bit Sail2_values.B0, CT_bit), typ)) + | "undefined_bool", _ -> - AE_val (AV_C_fragment (F_lit (V_bool false), typ, CT_bool)) - *) + AE_val (AV_cval (V_lit (VL_bool false, CT_bool), typ)) + | _, _ -> c_debug (lazy ("No optimization routine found")); no_change @@ -1076,6 +1065,7 @@ let rec sgen_ctyp = function | CT_fbits _ -> "uint64_t" | CT_sbits _ -> "sbits" | CT_fint _ -> "int64_t" + | CT_constant _ -> "int64_t" | CT_lint -> "sail_int" | CT_lbits _ -> "lbits" | CT_tup _ as tup -> "struct " ^ Util.zencode_string ("tuple_" ^ string_of_ctyp tup) @@ -1088,7 +1078,6 @@ let rec sgen_ctyp = function | CT_real -> "real" | CT_ref ctyp -> sgen_ctyp ctyp ^ "*" | CT_poly -> "POLY" (* c_error "Tried to generate code for non-monomorphic type" *) - | CT_constant _ -> "CONSTANT" let rec sgen_ctyp_name = function | CT_unit -> "unit" @@ -1097,6 +1086,7 @@ let rec sgen_ctyp_name = function | CT_fbits _ -> "fbits" | CT_sbits _ -> "sbits" | CT_fint _ -> "mach_int" + | CT_constant _ -> "mach_int" | CT_lint -> "sail_int" | CT_lbits _ -> "lbits" | CT_tup _ as tup -> Util.zencode_string ("tuple_" ^ string_of_ctyp tup) @@ -1109,7 +1099,6 @@ let rec sgen_ctyp_name = function | CT_real -> "real" | CT_ref ctyp -> "ref_" ^ sgen_ctyp_name ctyp | CT_poly -> "POLY" (* c_error "Tried to generate code for non-monomorphic type" *) - | CT_constant _ -> "CONSTANT" let rec sgen_cval = function | V_id (id, ctyp) -> string_of_name id @@ -1172,9 +1161,77 @@ and sgen_call op cvals = | Isub, [v1; v2] -> sprintf "(%s - %s)" (sgen_cval v1) (sgen_cval v2) | Bvand, [v1; v2] -> - sprintf "(%s & %s)" (sgen_cval v1) (sgen_cval v2) + begin match cval_ctyp v1 with + | CT_fbits _ -> + sprintf "(%s & %s)" (sgen_cval v1) (sgen_cval v2) + | CT_sbits _ -> + sprintf "and_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> assert false + end + | Bvnot, [v] -> + begin 
match cval_ctyp v with + | CT_fbits (n, _) -> + sprintf "(~(%s) & %s)" (sgen_cval v) (sgen_cval (v_mask_lower n)) + | CT_sbits _ -> + sprintf "not_sbits(%s)" (sgen_cval v) + | _ -> assert false + end | Bvor, [v1; v2] -> - sprintf "(%s | %s)" (sgen_cval v1) (sgen_cval v2) + begin match cval_ctyp v1 with + | CT_fbits _ -> + sprintf "(%s | %s)" (sgen_cval v1) (sgen_cval v2) + | CT_sbits _ -> + sprintf "or_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> assert false + end + | Bvxor, [v1; v2] -> + begin match cval_ctyp v1 with + | CT_fbits _ -> + sprintf "(%s ^ %s)" (sgen_cval v1) (sgen_cval v2) + | CT_sbits _ -> + sprintf "xor_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> assert false + end + | Bvadd, [v1; v2] -> + begin match cval_ctyp v1 with + | CT_fbits (n, _) -> + sprintf "((%s + %s) & %s)" (sgen_cval v1) (sgen_cval v2) (sgen_cval (v_mask_lower n)) + | CT_sbits _ -> + sprintf "add_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> assert false + end + | Bvsub, [v1; v2] -> + begin match cval_ctyp v1 with + | CT_fbits (n, _) -> + sprintf "((%s - %s) & %s)" (sgen_cval v1) (sgen_cval v2) (sgen_cval (v_mask_lower n)) + | CT_sbits _ -> + sprintf "sub_sbits(%s, %s)" (sgen_cval v1) (sgen_cval v2) + | _ -> assert false + end + | Bvaccess, [vec; n] -> + begin match cval_ctyp vec with + | CT_fbits _ -> + sprintf "(UINT64_C(1) & (%s >> %s))" (sgen_cval vec) (sgen_cval n) + | CT_sbits _ -> + sprintf "(UINT64_C(1) & (%s.bits >> %s))" (sgen_cval vec) (sgen_cval n) + | _ -> assert false + end + | Slice len, [vec; start] -> + begin match cval_ctyp vec with + | CT_fbits _ -> + sprintf "(safe_rshift(UINT64_MAX, 64 - %d) & (%s >> %s))" len (sgen_cval vec) (sgen_cval start) + | CT_sbits _ -> + sprintf "(safe_rshift(UINT64_MAX, 64 - %d) & (%s.bits >> %s))" len (sgen_cval vec) (sgen_cval start) + | _ -> assert false + end + | Sslice 64, [vec; start; len] -> + begin match cval_ctyp vec with + | CT_fbits _ -> + sprintf "sslice(%s, %s, %s)" (sgen_cval vec) (sgen_cval start) (sgen_cval len) + | CT_sbits _ -> + sprintf "sslice(%s.bits, %s, %s)" (sgen_cval vec) (sgen_cval start) (sgen_cval len) + | _ -> assert false + end | Zero_extend n, [v] -> begin match cval_ctyp v with | CT_fbits _ -> sgen_cval v @@ -1190,6 +1247,12 @@ and sgen_call op cvals = sprintf "fast_sign_extend2(%s, %d)" (sgen_cval v) n | _ -> assert false end + | Replicate n, [v] -> + begin match cval_ctyp v with + | CT_fbits (m, _) -> + sprintf "fast_replicate_bits(UINT64_C(%d), %s, %d)" m (sgen_cval v) n + | _ -> assert false + end | Concat, [v1; v2] -> (* Optimized routines for all combinations of fixed and small bits appends, where the result is guaranteed to be smaller than 64. 
*) diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index 1a4b8b4b..b76e65b7 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -61,6 +61,8 @@ let opt_debug_function = ref "" let opt_debug_flow_graphs = ref false let opt_memo_cache = ref false +let optimize_aarch64_fast_struct = ref false + let ngensym () = name (gensym ()) (**************************************************************************) @@ -132,6 +134,10 @@ let is_ct_ref = function | CT_ref _ -> true | _ -> false +let iblock1 = function + | [instr] -> instr + | instrs -> iblock instrs + let ctor_bindings = List.fold_left (fun map (id, ctyp) -> Bindings.add id ctyp map) Bindings.empty (** The context type contains two type-checking @@ -266,7 +272,7 @@ let rec compile_aval l ctx = function [idecl ctyp gs], V_struct (fields, ctyp), [iclear ctyp gs] - + | AV_record (fields, typ) -> let ctyp = ctyp_of_typ ctx typ in let gs = ngensym () in @@ -396,7 +402,49 @@ let rec compile_aval l ctx = function V_id (gs, CT_list ctyp), [iclear (CT_list ctyp) gs] -let compile_funcall l ctx id args typ = +let optimize_call l ctx clexp id args arg_ctyps ret_ctyp = + let call () = + let setup = ref [] in + let cleanup = ref [] in + let cast_args = + List.map2 + (fun ctyp cval -> + let have_ctyp = cval_ctyp cval in + if is_polymorphic ctyp then + V_poly (cval, have_ctyp) + else if ctx.specialize_calls || ctyp_equal ctyp have_ctyp then + cval + else + let gs = ngensym () in + setup := iinit ctyp gs cval :: !setup; + cleanup := iclear ctyp gs :: !cleanup; + V_id (gs, ctyp)) + arg_ctyps args + in + if ctx.specialize_calls || ctyp_equal (clexp_ctyp clexp) ret_ctyp then + !setup @ [ifuncall clexp id cast_args] @ !cleanup + else + let gs = ngensym () in + List.rev !setup + @ [idecl ret_ctyp gs; + ifuncall (CL_id (gs, ret_ctyp)) id cast_args; + icopy l clexp (V_id (gs, ret_ctyp)); + iclear ret_ctyp gs] + @ !cleanup + in + if not ctx.specialize_calls && Env.is_extern id ctx.tc_env "c" then + let extern = Env.get_extern id ctx.tc_env "c" in + begin match extern, List.map cval_ctyp args, clexp_ctyp clexp with + | "slice", [CT_fbits _; _; _], CT_fbits (n, _) -> + let start = ngensym () in + [iinit (CT_fint 64) start (List.nth args 1); + icopy l clexp (V_call (Slice n, [List.nth args 0; V_id (start, CT_fint 64)]))] + | _, _, _ -> + call () + end + else call () + +let compile_funcall l ctx id args = let setup = ref [] in let cleanup = ref [] in @@ -412,38 +460,21 @@ let compile_funcall l ctx id args typ = in let ctx' = { ctx with local_env = add_typquant (id_loc id) quant ctx.tc_env } in let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ in - let final_ctyp = ctyp_of_typ ctx typ in + + assert (List.length arg_ctyps = List.length args); let setup_arg ctyp aval = let arg_setup, cval, arg_cleanup = compile_aval l ctx aval in setup := List.rev arg_setup @ !setup; cleanup := arg_cleanup @ !cleanup; - let have_ctyp = cval_ctyp cval in - if is_polymorphic ctyp then - V_poly (cval, have_ctyp) - else if ctx.specialize_calls || ctyp_equal ctyp have_ctyp then - cval - else - let gs = ngensym () in - setup := iinit ctyp gs cval :: !setup; - cleanup := iclear ctyp gs :: !cleanup; - V_id (gs, ctyp) + cval in - assert (List.length arg_ctyps = List.length args); - let setup_args = List.map2 setup_arg arg_ctyps args in List.rev !setup, begin fun clexp -> - if ctx.specialize_calls || ctyp_equal (clexp_ctyp clexp) ret_ctyp then - ifuncall clexp id setup_args - else - let gs = ngensym () in - iblock [idecl ret_ctyp 
gs; - ifuncall (CL_id (gs, ret_ctyp)) id setup_args; - icopy l clexp (V_id (gs, ret_ctyp)); - iclear ret_ctyp gs] + iblock1 (optimize_call l ctx clexp id setup_args arg_ctyps ret_ctyp) end, !cleanup @@ -553,15 +584,15 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = let setup, call, cleanup = compile_aexp ctx binding in let letb_setup, letb_cleanup = [idecl binding_ctyp (name id); - iblock (setup @ [call (CL_id (name id, binding_ctyp))] @ cleanup)], + iblock1 (setup @ [call (CL_id (name id, binding_ctyp))] @ cleanup)], [iclear binding_ctyp (name id)] in let ctx = { ctx with locals = Bindings.add id (mut, binding_ctyp) ctx.locals } in let setup, call, cleanup = compile_aexp ctx body in letb_setup @ setup, call, cleanup @ letb_cleanup - | AE_app (id, vs, typ) -> - compile_funcall l ctx id vs typ + | AE_app (id, vs, _) -> + compile_funcall l ctx id vs | AE_val aval -> let setup, cval, cleanup = compile_aval l ctx aval in @@ -912,6 +943,10 @@ and compile_block ctx = function let gs = ngensym () in iblock (setup @ [idecl CT_unit gs; call (CL_id (gs, CT_unit))] @ cleanup) :: rest +let fast_int = function + | CT_lint when !optimize_aarch64_fast_struct -> CT_fint 64 + | ctyp -> ctyp + (** Compile a sail type definition into a IR one. Most of the actual work of translating the typedefs into C is done by the code generator, as it's easy to keep track of structs, tuples and unions @@ -928,7 +963,7 @@ let compile_type_def ctx (TD_aux (type_def, (l, _))) = | TD_record (id, typq, ctors, _) -> let record_ctx = { ctx with local_env = add_typquant l typq ctx.local_env } in let ctors = - List.fold_left (fun ctors (typ, id) -> Bindings.add id (ctyp_of_typ record_ctx typ) ctors) Bindings.empty ctors + List.fold_left (fun ctors (typ, id) -> Bindings.add id (fast_int (ctyp_of_typ record_ctx typ)) ctors) Bindings.empty ctors in CTD_struct (id, Bindings.bindings ctors), { ctx with records = Bindings.add id ctors ctx.records } @@ -1318,7 +1353,7 @@ and compile_def' n total ctx = function (* Termination measures only needed for Coq, and other theorem prover output *) | DEF_measure _ -> [], ctx | DEF_loop_measures _ -> [], ctx - + | DEF_internal_mutrec fundefs -> let defs = List.map (fun fdef -> DEF_fundef fdef) fundefs in List.fold_left (fun (cdefs, ctx) def -> let cdefs', ctx = compile_def n total ctx def in (cdefs @ cdefs', ctx)) ([], ctx) defs diff --git a/src/jib/jib_compile.mli b/src/jib/jib_compile.mli index ac41670c..b882f2f9 100644 --- a/src/jib/jib_compile.mli +++ b/src/jib/jib_compile.mli @@ -63,6 +63,11 @@ val opt_debug_flow_graphs : bool ref (** Print the IR representation of a specific function. *) val opt_debug_function : string ref +(** This forces all integer struct fields to be represented as + int64_t. Specifically intended for the various TLB structs in the + ARM v8.5 spec. 
*) +val optimize_aarch64_fast_struct : bool ref + val ngensym : unit -> name (** {2 Jib context} *) diff --git a/src/jib/jib_util.ml b/src/jib/jib_util.ml index 1362d4f8..441f8bd1 100644 --- a/src/jib/jib_util.ml +++ b/src/jib/jib_util.ml @@ -282,21 +282,29 @@ let string_of_op = function | Bit_to_bool -> "bit_to_bool" | Eq -> "eq" | Neq -> "neq" + | Bvnot -> "bvnot" | Bvor -> "bvor" | Bvand -> "bvand" + | Bvxor -> "bvxor" + | Bvadd -> "bvadd" + | Bvsub -> "bvsub" + | Bvaccess -> "bvaccess" | Ilt -> "lt" | Igt -> "gt" | Ilteq -> "lteq" | Igteq -> "gteq" | Iadd -> "iadd" | Isub -> "isub" - | Zero_extend n -> "zero_extend" ^ string_of_int n - | Sign_extend n -> "sign_extend" ^ string_of_int n + | Zero_extend n -> "zero_extend::<" ^ string_of_int n ^ ">" + | Sign_extend n -> "sign_extend::<" ^ string_of_int n ^ ">" + | Slice n -> "slice::<" ^ string_of_int n ^ ">" + | Sslice n -> "sslice::<" ^ string_of_int n ^ ">" + | Replicate n -> "replicate::<" ^ string_of_int n ^ ">" | Concat -> "concat" let rec string_of_cval = function - | V_id (id, ctyp) -> string_of_name id - | V_ref (id, _) -> "&" ^ string_of_name id + | V_id (id, ctyp) -> string_of_name ~zencode:false id + | V_ref (id, _) -> "&" ^ string_of_name ~zencode:false id | V_lit (vl, ctyp) -> string_of_value vl | V_call (op, cvals) -> Printf.sprintf "%s(%s)" (string_of_op op) (Util.string_of_list ", " string_of_cval cvals) @@ -895,7 +903,9 @@ let rec infer_call op vs = | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid call to tl" end | (Eq | Neq), _ -> CT_bool - | (Bvor | Bvand), [v; _] -> cval_ctyp v + | Bvnot, [v] -> cval_ctyp v + | Bvaccess, _ -> CT_bit + | (Bvor | Bvand | Bvxor | Bvadd | Bvsub), [v; _] -> cval_ctyp v | (Ilt | Igt | Ilteq | Igteq), _ -> CT_bool | (Iadd | Isub), _ -> CT_fint 64 | (Zero_extend n | Sign_extend n), [v] -> @@ -904,6 +914,23 @@ let rec infer_call op vs = CT_fbits (n, ord) | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for zero/sign_extend argument" end + | Slice n, [vec; start] -> + begin match cval_ctyp vec with + | CT_fbits (_, ord) | CT_sbits (_, ord) -> + CT_fbits (n, ord) + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for extract argument" + end + | Sslice n, [vec; start; len] -> + begin match cval_ctyp vec with + | CT_fbits (_, ord) | CT_sbits (_, ord) -> + CT_sbits (n, ord) + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for extract argument" + end + | Replicate n, [vec] -> + begin match cval_ctyp vec with + | CT_fbits (m, ord) -> CT_fbits (n * m, ord) + | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Invalid type for replicate argument" + end | Concat, [v1; v2] -> begin match cval_ctyp v1, cval_ctyp v2 with | CT_fbits (n, ord), CT_fbits (m, _) -> diff --git a/src/sail.ml b/src/sail.ml index 14f24251..3c277fab 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -196,6 +196,9 @@ let options = Arg.align ([ ( "-Oconstant_fold", Arg.Set Constant_fold.optimize_constant_fold, " apply constant folding optimizations"); + ( "-Oaarch64_fast", + Arg.Set Jib_compile.optimize_aarch64_fast_struct, + " apply ARMv8.5 specific optimizations (potentially unsound in general)"); ( "-static", Arg.Set C_backend.opt_static, " make generated C functions static"); diff --git a/src/specialize.ml b/src/specialize.ml index b2eb5314..a601974e 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -73,7 +73,7 @@ let typ_ord_specialization = { let int_specialization = { is_polymorphic = is_int_kopt; - instantiation_filter = (fun _ arg -> match arg 
with A_aux (A_nexp _, _) -> true | _ -> false); + instantiation_filter = (fun _ arg -> match arg with A_aux (A_nexp (Nexp_aux (Nexp_constant _, _)), _) -> true | _ -> false); extern_filter = (fun externs -> match Ast_util.extern_assoc "c" externs with Some _ -> true | None -> false) } @@ -574,8 +574,12 @@ let specialize_ids spec ids ast = | None -> () end; let ast, _ = Type_error.check Type_check.initial_env ast in - let ast = - List.fold_left (fun ast id -> rewrite_polymorphic_calls spec id ast) ast (IdSet.elements ids) + let _, ast = + List.fold_left + (fun (n, ast) id -> + Util.progress "Rewriting " (string_of_id id) n total; + (n + 1, rewrite_polymorphic_calls spec id ast)) + (1, ast) (IdSet.elements ids) in let ast, env = Type_error.check Type_check.initial_env ast in let ast = remove_unused_valspecs env ast in diff --git a/src/value.ml b/src/value.ml index 7a65f6ea..958a1919 100644 --- a/src/value.ml +++ b/src/value.ml @@ -599,7 +599,7 @@ let value_string_append = function let value_decimal_string_of_bits = function | [v] -> V_string (Sail_lib.decimal_string_of_bits (coerce_bv v)) | _ -> failwith "value decimal_string_of_bits" - + let primops = List.fold_left (fun r (x, y) -> StringMap.add x y r) diff --git a/test/c/run_tests.py b/test/c/run_tests.py index be953749..f5347831 100755 --- a/test/c/run_tests.py +++ b/test/c/run_tests.py @@ -94,7 +94,6 @@ xml += test_c('unoptimized C', '', '', True) xml += test_c('optimized C', '-O2', '-O', True) xml += test_c('constant folding', '', '-Oconstant_fold', True) xml += test_c('monomorphised C', '-O2', '-O -Oconstant_fold -auto_mono', True) -xml += test_c('full optimizations', '-O2 -mbmi2 -DINTRINSICS', '-O -Oconstant_fold', True) xml += test_c('specialization', '-O1', '-O -c_specialize', True) xml += test_c('undefined behavior sanitised', '-O2 -fsanitize=undefined', '-O', False) |
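A note on the new small-bits helpers (not_sbits, or_sbits, and_sbits, add_sbits, sub_sbits) added to lib/sail.c: they all operate on the packed 64-bit payload and then truncate the result to the vector's length with the BMI2 _bzhi_u64 intrinsic, compiled with __attribute__((target("bmi2"))). The sketch below shows the same masking logic in portable form; the local sbits definition and the bzhi_portable fallback are illustrative stand-ins, not part of this patch:

```c
#include <stdint.h>

/* Local stand-in for the library's sbits type: a bitvector of up to 64 bits,
 * packed into a uint64_t payload plus an explicit length. */
typedef struct {
  uint64_t bits;  /* only the low `len` bits are significant */
  uint64_t len;   /* number of valid bits, <= 64 */
} sbits;

/* Portable equivalent of _bzhi_u64(x, n): clear every bit of x from position
 * n upwards (BZHI leaves x unchanged when n >= 64). */
static inline uint64_t bzhi_portable(uint64_t x, uint64_t n)
{
  return n >= 64 ? x : (x & ((UINT64_C(1) << n) - 1));
}

/* Same shape as the add_sbits added by this commit: add the payloads, then
 * truncate the result back to op1.len bits. */
static sbits add_sbits_portable(const sbits op1, const sbits op2)
{
  sbits rop;
  rop.bits = bzhi_portable(op1.bits + op2.bits, op1.len);
  rop.len = op1.len;
  return rop;
}
```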
