diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/jib/anf.ml | 36 | ||||
| -rw-r--r-- | src/jib/anf.mli | 5 | ||||
| -rw-r--r-- | src/jib/c_backend.ml | 48 | ||||
| -rw-r--r-- | src/jib/jib_compile.ml | 2 | ||||
| -rw-r--r-- | src/jib/jib_optimize.ml | 9 | ||||
| -rw-r--r-- | src/jib/jib_smt.ml | 25 | ||||
| -rw-r--r-- | src/jib/jib_util.ml | 7 |
7 files changed, 100 insertions, 32 deletions
diff --git a/src/jib/anf.ml b/src/jib/anf.ml index 4bb24032..3edd2cd7 100644 --- a/src/jib/anf.ml +++ b/src/jib/anf.ml @@ -153,6 +153,21 @@ let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) = in AP_aux (apat_aux, env, l) +let rec aval_typ = function + | AV_lit (_, typ) -> typ + | AV_id (_, lvar) -> lvar_typ lvar + | AV_ref (_, lvar) -> lvar_typ lvar + | AV_tuple avals -> tuple_typ (List.map aval_typ avals) + | AV_list (_, typ) -> typ + | AV_vector (_, typ) -> typ + | AV_record (_, typ) -> typ + | AV_cval (_, typ) -> typ + +let aexp_typ (AE_aux (aux, _, _)) = + match aux with + | AE_val aval -> aval_typ aval + | AE_app (_, _, typ) -> typ + let rec aval_rename from_id to_id = function | AV_lit (lit, typ) -> AV_lit (lit, typ) | AV_id (id, lvar) when Id.compare id from_id = 0 -> AV_id (to_id, lvar) @@ -298,6 +313,27 @@ let rec map_functions f (AE_aux (aexp, env, l)) = in AE_aux (aexp, env, l) +let rec fold_aexp f (AE_aux (aexp, env, l)) = + let aexp = match aexp with + | AE_app (id, vs, typ) -> AE_app (id, vs, typ) + | AE_cast (aexp, typ) -> AE_cast (fold_aexp f aexp, typ) + | AE_assign (id, typ, aexp) -> AE_assign (id, typ, fold_aexp f aexp) + | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, aval, fold_aexp f aexp) + | AE_let (mut, id, typ1, aexp1, aexp2, typ2) -> AE_let (mut, id, typ1, fold_aexp f aexp1, fold_aexp f aexp2, typ2) + | AE_block (aexps, aexp, typ) -> AE_block (List.map (fold_aexp f) aexps, fold_aexp f aexp, typ) + | AE_if (aval, aexp1, aexp2, typ) -> + AE_if (aval, fold_aexp f aexp1, fold_aexp f aexp2, typ) + | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, fold_aexp f aexp1, fold_aexp f aexp2) + | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> + AE_for (id, fold_aexp f aexp1, fold_aexp f aexp2, fold_aexp f aexp3, order, fold_aexp f aexp4) + | AE_case (aval, cases, typ) -> + AE_case (aval, List.map (fun (pat, aexp1, aexp2) -> pat, fold_aexp f aexp1, fold_aexp f aexp2) cases, typ) + | AE_try (aexp, cases, typ) -> + AE_try (fold_aexp f aexp, List.map (fun (pat, aexp1, aexp2) -> pat, fold_aexp f aexp1, fold_aexp f aexp2) cases, typ) + | AE_field _ | AE_record_update _ | AE_val _ | AE_return _ | AE_throw _ as v -> v + in + f (AE_aux (aexp, env, l)) + (* For debugging we provide a pretty printer for ANF expressions. *) let pp_lvar lvar doc = diff --git a/src/jib/anf.mli b/src/jib/anf.mli index 6bc274e6..571546cb 100644 --- a/src/jib/anf.mli +++ b/src/jib/anf.mli @@ -134,12 +134,17 @@ val gensym : unit -> id (** {2 Functions for transforming ANF expressions} *) +val aval_typ : typ aval -> typ +val aexp_typ : typ aexp -> typ + (** Map over all values in an ANF expression *) val map_aval : (Env.t -> Ast.l -> 'a aval -> 'a aval) -> 'a aexp -> 'a aexp (** Map over all function calls in an ANF expression *) val map_functions : (Env.t -> Ast.l -> id -> ('a aval) list -> 'a -> 'a aexp_aux) -> 'a aexp -> 'a aexp +val fold_aexp : ('a aexp -> 'a aexp) -> 'a aexp -> 'a aexp + (** Remove all variable shadowing in an ANF expression *) val no_shadow : IdSet.t -> 'a aexp -> 'a aexp diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml index f6d8dd80..0fe986e4 100644 --- a/src/jib/c_backend.ml +++ b/src/jib/c_backend.ml @@ -405,40 +405,46 @@ let analyze_primop' ctx id args typ = match extern, args with (* - | "eq_bits", [AV_cval (v1, _, CT_fbits _); AV_cval (v2, _, _)] -> - AE_val (AV_cval (F_op (v1, "==", v2), typ, CT_bool)) - | "eq_bits", [AV_cval (v1, _, CT_sbits _); AV_cval (v2, _, _)] -> - AE_val (AV_cval (F_call ("eq_sbits", [v1; v2]), typ, CT_bool)) - - | "neq_bits", [AV_cval (v1, _, CT_fbits _); AV_cval (v2, _, _)] -> - AE_val (AV_cval (F_op (v1, "!=", v2), typ, CT_bool)) - | "neq_bits", [AV_cval (v1, _, CT_sbits _); AV_cval (v2, _, _)] -> - AE_val (AV_cval (F_call ("neq_sbits", [v1; v2]), typ, CT_bool)) - - | "eq_int", [AV_cval (v1, typ1, _); AV_cval (v2, typ2, _)] -> - AE_val (AV_cval (F_op (v1, "==", v2), typ, CT_bool)) + | "eq_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + begin match cval_ctyp v1 with + | CT_fbits _ -> + AE_val (AV_cval (V_op (v1, "==", v2), typ)) + | CT_sbits _ -> + AE_val (AV_cval (V_call ("eq_sbits", [v1; v2]), typ)) + | _ -> no_change + end - | "zeros", [_] -> - begin match destruct_vector ctx.tc_env typ with - | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) - when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> - AE_val (AV_C_fragment (F_raw "0x0", typ, CT_fbits (Big_int.to_int n, true))) + | "neq_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + begin match cval_ctyp v1 with + | CT_fbits _ -> + AE_val (AV_cval (V_op (v1, "!=", v2), typ)) + | CT_sbits _ -> + AE_val (AV_cval (V_call ("neq_sbits", [v1; v2]), typ)) | _ -> no_change end - | "zero_extend", [AV_C_fragment (v1, _, CT_fbits _); _] -> + | "eq_int", [AV_cval (v1, _); AV_cval (v2, _)] -> + AE_val (AV_cval (V_op (v1, "==", v2), typ)) + + | "zeros", [_] -> begin match destruct_vector ctx.tc_env typ with | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> - AE_val (AV_C_fragment (v1, typ, CT_fbits (Big_int.to_int n, true))) + let n = Big_int.to_int n in + AE_val (AV_cval (V_lit (VL_bits (Util.list_init n (fun _ -> Sail2_values.B0), true), CT_fbits (n, true)), typ)) | _ -> no_change end - | "zero_extend", [AV_C_fragment (v1, _, CT_sbits _); _] -> + | "zero_extend", [AV_cval (v1, _); _] -> begin match destruct_vector ctx.tc_env typ with | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _)) when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) -> - AE_val (AV_C_fragment (F_call ("fast_zero_extend", [v1; v_int (Big_int.to_int n)]), typ, CT_fbits (Big_int.to_int n, true))) + begin match cval_ctyp v1 with + | CT_fbits _ -> + AE_val (AV_C_fragment (v1, typ, CT_fbits (Big_int.to_int n, true))) + | CT_sbits _ -> + AE_val (AV_C_fragment (F_call ("fast_zero_extend", [v1; v_int (Big_int.to_int n)]), typ, CT_fbits (Big_int.to_int n, true))) + end | _ -> no_change end diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index bc3314ec..a8a0c640 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -335,7 +335,7 @@ let rec compile_aval l ctx = function [icopy l (CL_id (gs, ctyp)) (V_op (V_id (gs, ctyp), "|", V_lit (mask i, ctyp)))] | _ -> (* FIXME: Make this work in C *) - setup @ [iif (V_unary ("bit_to_bool", cval)) [icopy l (CL_id (gs, ctyp)) (V_op (V_id (gs, ctyp), "|", V_lit (mask i, ctyp)))] [] CT_unit] @ cleanup + setup @ [iif (V_call ("bit_to_bool", [cval])) [icopy l (CL_id (gs, ctyp)) (V_op (V_id (gs, ctyp), "|", V_lit (mask i, ctyp)))] [] CT_unit] @ cleanup in [idecl ctyp gs; icopy l (CL_id (gs, ctyp)) (V_lit (VL_bits (Util.list_init 64 (fun _ -> Sail2_values.B0), direction), ctyp))] diff --git a/src/jib/jib_optimize.ml b/src/jib/jib_optimize.ml index 3fc42aa3..331cf65e 100644 --- a/src/jib/jib_optimize.ml +++ b/src/jib/jib_optimize.ml @@ -169,6 +169,8 @@ let rec cval_subst id subst = function | V_ctor_unwrap (ctor, cval, unifiers, ctyp) -> V_ctor_unwrap (ctor, cval_subst id subst cval, unifiers, ctyp) | V_struct (fields, ctyp) -> V_struct (List.map (fun (field, cval) -> field, cval_subst id subst cval) fields, ctyp) | V_poly (cval, ctyp) -> V_poly (cval_subst id subst cval, ctyp) + | V_hd cval -> V_hd (cval_subst id subst cval) + | V_tl cval -> V_tl (cval_subst id subst cval) let rec cval_map_id f = function | V_id (id, ctyp) -> V_id (f id, ctyp) @@ -231,16 +233,13 @@ let rec instrs_subst id subst = | [] -> [] let rec clexp_subst id subst = function - | CL_id (id', ctyp) when Name.compare id id' = 0 -> - if ctyp_equal ctyp (clexp_ctyp subst) then - subst - else - subst + | CL_id (id', ctyp) when Name.compare id id' = 0 -> subst | CL_id (id', ctyp) -> CL_id (id', ctyp) | CL_field (clexp, field) -> CL_field (clexp_subst id subst clexp, field) | CL_addr clexp -> CL_addr (clexp_subst id subst clexp) | CL_tuple (clexp, n) -> CL_tuple (clexp_subst id subst clexp, n) | CL_void -> CL_void + | CL_rmw _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Cannot substitute into read-modify-write construct" let rec find_function fid = function | CDEF_fundef (fid', heap_return, args, body) :: _ when Id.compare fid fid' = 0 -> diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml index 44f8e24b..76239b35 100644 --- a/src/jib/jib_smt.ml +++ b/src/jib/jib_smt.ml @@ -177,7 +177,7 @@ let rec smt_cval env cval = Fn ("not", [Fn ("=", [smt_cval env frag1; smt_cval env frag2])]) | V_op (frag1, "|", frag2) -> Fn ("bvor", [smt_cval env frag1; smt_cval env frag2]) - | V_unary ("bit_to_bool", cval) -> + | V_call ("bit_to_bool", [cval]) -> Fn ("=", [smt_cval env cval; Bin "1"]) | V_unary ("!", cval) -> Fn ("not", [smt_cval env cval]) @@ -1029,6 +1029,27 @@ let c_literals ctx = in map_aval c_literal +let unroll_foreach ctx = function + | AE_aux (AE_for (id, from_aexp, to_aexp, by_aexp, order, body), env, l) as aexp -> + begin match ctyp_of_typ ctx (aexp_typ from_aexp), ctyp_of_typ ctx (aexp_typ to_aexp), ctyp_of_typ ctx (aexp_typ by_aexp), order with + | CT_constant f, CT_constant t, CT_constant b, Ord_aux (Ord_inc, _) -> + let i = ref f in + let unrolled = ref [] in + while Big_int.less_equal !i t do + let current_index = AE_aux (AE_val (AV_lit (L_aux (L_num !i, gen_loc l), atom_typ (nconstant !i))), env, gen_loc l) in + let iteration = AE_aux (AE_let (Immutable, id, atom_typ (nconstant !i), current_index, body, unit_typ), env, gen_loc l) in + unrolled := iteration :: !unrolled; + i := Big_int.add !i b + done; + begin match !unrolled with + | last :: iterations -> + AE_aux (AE_block (List.rev iterations, last, unit_typ), env, gen_loc l) + | [] -> AE_aux (AE_val (AV_lit (L_aux (L_unit, gen_loc l), unit_typ)), env, gen_loc l) + end + | _ -> aexp + end + | aexp -> aexp + (**************************************************************************) (* 3. Generating SMT *) (**************************************************************************) @@ -1371,7 +1392,7 @@ let generate_smt props name_file env ast = let ctx = initial_ctx ~convert_typ:ctyp_of_typ - ~optimize_anf:(fun ctx aexp -> c_literals ctx aexp) + ~optimize_anf:(fun ctx aexp -> fold_aexp (unroll_foreach ctx) (c_literals ctx aexp)) env in let t = Profile.start () in diff --git a/src/jib/jib_util.ml b/src/jib/jib_util.ml index df2ce369..7b0b4f92 100644 --- a/src/jib/jib_util.ml +++ b/src/jib/jib_util.ml @@ -280,9 +280,9 @@ let string_of_name ?deref_current_exception:(dce=true) ?zencode:(zencode=true) = "current_exception" ^ ssa_num n let rec string_of_cval ?zencode:(zencode=true) = function - | V_id (id, ctyp) -> string_of_name ~zencode:zencode id ^ " : " ^ string_of_ctyp ctyp + | V_id (id, ctyp) -> string_of_name ~zencode:zencode id | V_ref (id, _) -> "&" ^ string_of_name ~zencode:zencode id - | V_lit (vl, ctyp) -> string_of_value vl ^ " : " ^ string_of_ctyp ctyp + | V_lit (vl, ctyp) -> string_of_value vl | V_call (str, cvals) -> Printf.sprintf "%s(%s)" str (Util.string_of_list ", " (string_of_cval ~zencode:zencode) cvals) | V_field (f, field) -> @@ -675,6 +675,7 @@ let rec cval_deps = function | V_ctor_kind (cval, _, _, _) -> cval_deps cval | V_ctor_unwrap (_, cval, _, _) -> cval_deps cval | V_hd cval | V_tl cval -> cval_deps cval + | V_struct (fields, ctyp) -> List.fold_left (fun ns (_, cval) -> NameSet.union ns (cval_deps cval)) NameSet.empty fields let rec clexp_deps = function | CL_id (id, _) -> NameSet.empty, NameSet.singleton id @@ -890,7 +891,6 @@ let label str = let rec infer_unary v = function | "!" -> CT_bool - | "bit_to_bool" -> CT_bool | op -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Could not infer unary " ^ op) and infer_op v1 v2 = function @@ -907,6 +907,7 @@ and infer_op v1 v2 = function | op -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Cannot infer binary op: " ^ op) and infer_call vs = function + | "bit_to_bool" -> CT_bool | op -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Cannot infer call: " ^ op) and cval_ctyp = function |
