diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/LICENCE | 48 | ||||
| -rw-r--r-- | src/ast_util.ml | 18 | ||||
| -rw-r--r-- | src/ast_util.mli | 6 | ||||
| -rw-r--r-- | src/bitfield.ml | 40 | ||||
| -rw-r--r-- | src/c_backend.ml | 420 | ||||
| -rw-r--r-- | src/monomorphise.ml | 365 | ||||
| -rw-r--r-- | src/pattern_completeness.ml | 7 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 42 | ||||
| -rw-r--r-- | src/pretty_print_sail.ml | 2 | ||||
| -rw-r--r-- | src/rewrites.ml | 1 | ||||
| -rw-r--r-- | src/sail.ml | 2 | ||||
| -rw-r--r-- | src/specialize.ml | 7 |
12 files changed, 712 insertions, 246 deletions
diff --git a/src/LICENCE b/src/LICENCE deleted file mode 100644 index c777e037..00000000 --- a/src/LICENCE +++ /dev/null @@ -1,48 +0,0 @@ - Sail - -Copyright (c) 2013-2017 - Kathyrn Gray - Shaked Flur - Stephen Kell - Gabriel Kerneis - Robert Norton-Wright - Christopher Pulte - Peter Sewell - Alasdair Armstrong - Brian Campbell - Thomas Bauereiss - Anthony Fox - Jon French - Dominic Mulligan - Stephen Kell - Mark Wassell - -All rights reserved. - -This software was developed by the University of Cambridge Computer -Laboratory and the University of Edinburgh as part of the Rigorous -Engineering of Mainstream Systems (REMS) project, funded by EPSRC -grant EP/K008528/1. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in - the documentation and/or other materials provided with the - distribution. - -THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED -TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF -USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT -OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. diff --git a/src/ast_util.ml b/src/ast_util.ml index 27ae93e8..5bbf9a40 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -182,6 +182,7 @@ module IdSet = Set.Make(Id) module KBindings = Map.Make(Kid) module KidSet = Set.Make(Kid) module NexpSet = Set.Make(Nexp) +module NexpMap = Map.Make(Nexp) let rec nexp_identical nexp1 nexp2 = (Nexp.compare nexp1 nexp2 = 0) @@ -235,6 +236,14 @@ and nexp_simp_aux = function when Big_int.equal c1 (Big_int.negate c2) -> n | _, _ -> Nexp_minus (n1, n2) end + | Nexp_app (Id_aux (Id "div",_) as id,[n1;n2]) -> + begin + let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in + let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in + match n1_simp, n2_simp with + | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (Big_int.div c1 c2) + | _, _ -> Nexp_app (id,[n1;n2]) + end | nexp -> nexp let mk_typ typ = Typ_aux (typ, Parse_ast.Unknown) @@ -483,7 +492,7 @@ let append_id id str = match id with | Id_aux (Id v, l) -> Id_aux (Id (v ^ str), l) | Id_aux (DeIid v, l) -> Id_aux (DeIid (v ^ str), l) - + let prepend_kid str = function | Kid_aux (Var v, l) -> Kid_aux (Var ("'" ^ str ^ String.sub v 1 (String.length v - 1)), l) @@ -1098,3 +1107,10 @@ and subst_lexp id value (LEXP_aux (lexp_aux, annot) as lexp) = | LEXP_field (lexp, id') -> LEXP_field (subst_lexp id value lexp, id') in wrap lexp_aux + +let hex_to_bin hex = + Util.string_to_list hex + |> List.map Sail_lib.hex_char + |> List.concat + |> List.map Sail_lib.char_of_bit + |> (fun bits -> String.init (List.length bits) (List.nth bits)) diff --git a/src/ast_util.mli b/src/ast_util.mli index bbbde27f..9f815899 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -248,6 +248,10 @@ module NexpSet : sig include Set.S with type elt = nexp end +module NexpMap : sig + include Map.S with type key = nexp +end + module BESet : sig include Set.S with type elt = base_effect end @@ -316,3 +320,5 @@ val ids_of_defs : 'a defs -> IdSet.t val pat_ids : 'a pat -> IdSet.t val subst : id -> 'a exp -> 'a exp -> 'a exp + +val hex_to_bin : string -> string diff --git a/src/bitfield.ml b/src/bitfield.ml index 67a26b89..391a653d 100644 --- a/src/bitfield.ml +++ b/src/bitfield.ml @@ -92,18 +92,18 @@ let full_accessor name size order = combine [full_getter name size order; full_setter name size order; full_overload name order] (* For every index range, create a getter and setter *) -let index_range_getter' name field order start stop = +let index_range_getter name field order start stop = let size = if start > stop then start - (stop - 1) else stop - (start - 1) in - let irg_val = Printf.sprintf "val _get_%s : %s -> %s" field name (bitvec size order) in - let irg_function = Printf.sprintf "function _get_%s Mk_%s(v) = v[%i .. %i]" field name start stop in + let irg_val = Printf.sprintf "val _get_%s_%s : %s -> %s" name field name (bitvec size order) in + let irg_function = Printf.sprintf "function _get_%s_%s Mk_%s(v) = v[%i .. %i]" name field name start stop in combine [ast_of_def_string order irg_val; ast_of_def_string order irg_function] -let index_range_setter' name field order start stop = +let index_range_setter name field order start stop = let size = if start > stop then start - (stop - 1) else stop - (start - 1) in - let irs_val = Printf.sprintf "val _set_%s : (register(%s), %s) -> unit effect {wreg}" field name (bitvec size order) in + let irs_val = Printf.sprintf "val _set_%s_%s : (register(%s), %s) -> unit effect {wreg}" name field name (bitvec size order) in (* Read-modify-write using an internal _reg_deref function without rreg effect *) let irs_function = String.concat "\n" - [ Printf.sprintf "function _set_%s (r_ref, v) = {" field; + [ Printf.sprintf "function _set_%s_%s (r_ref, v) = {" name field; Printf.sprintf " r = _get_%s(_reg_deref(r_ref));" name; Printf.sprintf " r[%i .. %i] = v;" start stop; Printf.sprintf " (*r_ref) = Mk_%s(r)" name; @@ -112,16 +112,30 @@ let index_range_setter' name field order start stop = in combine [ast_of_def_string order irs_val; ast_of_def_string order irs_function] -let index_range_overload field order = - ast_of_def_string order (Printf.sprintf "overload _mod_%s = {_get_%s, _set_%s}" field field field) +let index_range_update name field order start stop = + let size = if start > stop then start - (stop - 1) else stop - (start - 1) in + let iru_val = Printf.sprintf "val _update_%s_%s : (%s, %s) -> %s" name field name (bitvec size order) name in + (* Read-modify-write using an internal _reg_deref function without rreg effect *) + let iru_function = String.concat "\n" + [ Printf.sprintf "function _update_%s_%s (Mk_%s(v), x) = {" name field name; + Printf.sprintf " Mk_%s([v with %i .. %i = x]);" name start stop; + "}" + ] + in + let iru_overload = Printf.sprintf "overload update_%s = {_update_%s_%s}" field name field in + combine [ast_of_def_string order iru_val; ast_of_def_string order iru_function; ast_of_def_string order iru_overload] + +let index_range_overload name field order = + ast_of_def_string order (Printf.sprintf "overload _mod_%s = {_get_%s_%s, _set_%s_%s}" field name field name field) let index_range_accessor name field order (BF_aux (bf_aux, l)) = - let getter n m = index_range_getter' name field order (Big_int.to_int n) (Big_int.to_int m) in - let setter n m = index_range_setter' name field order (Big_int.to_int n) (Big_int.to_int m) in - let overload = index_range_overload field order in + let getter n m = index_range_getter name field order (Big_int.to_int n) (Big_int.to_int m) in + let setter n m = index_range_setter name field order (Big_int.to_int n) (Big_int.to_int m) in + let update n m = index_range_update name field order (Big_int.to_int n) (Big_int.to_int m) in + let overload = index_range_overload name field order in match bf_aux with - | BF_single n -> combine [getter n n; setter n n; overload] - | BF_range (n, m) -> combine [getter n m; setter n m; overload] + | BF_single n -> combine [getter n n; setter n n; update n n; overload] + | BF_range (n, m) -> combine [getter n m; setter n m; update n m; overload] | BF_concat _ -> failwith "Unimplemented" let field_accessor name order (id, ir) = index_range_accessor name (string_of_id id) order ir diff --git a/src/c_backend.ml b/src/c_backend.ml index 77f1b39f..fa1f2b5e 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -158,6 +158,107 @@ and aval = | AV_record of aval Bindings.t * typ | AV_C_fragment of fragment * typ +(* Renaming variables in ANF expressions *) + +let rec frag_rename from_id to_id = function + | F_id id when Id.compare id from_id = 0 -> F_id to_id + | F_id id -> F_id id + | F_lit str -> F_lit str + | F_have_exception -> F_have_exception + | F_current_exception -> F_current_exception + | F_op (f1, op, f2) -> F_op (frag_rename from_id to_id f1, op, frag_rename from_id to_id f2) + | F_unary (op, f) -> F_unary (op, frag_rename from_id to_id f) + | F_field (f, field) -> F_field (frag_rename from_id to_id f, field) + +let rec apat_bindings = function + | AP_tup apats -> List.fold_left IdSet.union IdSet.empty (List.map apat_bindings apats) + | AP_id id -> IdSet.singleton id + | AP_global (id, typ) -> IdSet.empty + | AP_app (id, apat) -> apat_bindings apat + | AP_cons (apat1, apat2) -> IdSet.union (apat_bindings apat1) (apat_bindings apat2) + | AP_nil -> IdSet.empty + | AP_wild -> IdSet.empty + +let rec aval_rename from_id to_id = function + | AV_lit (lit, typ) -> AV_lit (lit, typ) + | AV_id (id, lvar) when Id.compare id from_id = 0 -> AV_id (to_id, lvar) + | AV_id (id, lvar) -> AV_id (id, lvar) + | AV_ref (id, lvar) when Id.compare id from_id = 0 -> AV_ref (to_id, lvar) + | AV_ref (id, lvar) -> AV_ref (id, lvar) + | AV_tuple avals -> AV_tuple (List.map (aval_rename from_id to_id) avals) + | AV_list (avals, typ) -> AV_list (List.map (aval_rename from_id to_id) avals, typ) + | AV_vector (avals, typ) -> AV_vector (List.map (aval_rename from_id to_id) avals, typ) + | AV_record (avals, typ) -> AV_record (Bindings.map (aval_rename from_id to_id) avals, typ) + | AV_C_fragment (fragment, typ) -> AV_C_fragment (frag_rename from_id to_id fragment, typ) + +let rec aexp_rename from_id to_id aexp = + let recur = aexp_rename from_id to_id in + match aexp with + | AE_val aval -> AE_val (aval_rename from_id to_id aval) + | AE_app (id, avals, typ) -> AE_app (id, List.map (aval_rename from_id to_id) avals, typ) + | AE_cast (aexp, typ) -> AE_cast (recur aexp, typ) + | AE_assign (id, typ, aexp) when Id.compare from_id id = 0 -> AE_assign (to_id, typ, aexp) + | AE_assign (id, typ, aexp) -> AE_assign (id, typ, aexp) + | AE_let (id, typ1, aexp1, aexp2, typ2) when Id.compare from_id id = 0 -> AE_let (id, typ1, aexp1, aexp2, typ2) + | AE_let (id, typ1, aexp1, aexp2, typ2) -> AE_let (id, typ1, recur aexp1, recur aexp2, typ2) + | AE_block (aexps, aexp, typ) -> AE_block (List.map recur aexps, recur aexp, typ) + | AE_return (aval, typ) -> AE_return (aval_rename from_id to_id aval, typ) + | AE_throw (aval, typ) -> AE_throw (aval_rename from_id to_id aval, typ) + | AE_if (aval, then_aexp, else_aexp, typ) -> AE_if (aval_rename from_id to_id aval, recur then_aexp, recur else_aexp, typ) + | AE_field (aval, id, typ) -> AE_field (aval_rename from_id to_id aval, id, typ) + | AE_case (aval, apexps, typ) -> AE_case (aval_rename from_id to_id aval, List.map (apexp_rename from_id to_id) apexps, typ) + | AE_try (aexp, apexps, typ) -> AE_try (aexp_rename from_id to_id aexp, List.map (apexp_rename from_id to_id) apexps, typ) + | AE_record_update (aval, avals, typ) -> AE_record_update (aval_rename from_id to_id aval, Bindings.map (aval_rename from_id to_id) avals, typ) + | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) when Id.compare from_id to_id = 0 -> AE_for (id, aexp1, aexp2, aexp3, order, aexp4) + | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> AE_for (id, recur aexp1, recur aexp2, recur aexp3, order, recur aexp4) + | AE_loop (loop, aexp1, aexp2) -> AE_loop (loop, recur aexp1, recur aexp2) + +and apexp_rename from_id to_id (apat, aexp1, aexp2) = + if IdSet.mem from_id (apat_bindings apat) then + (apat, aexp1, aexp2) + else + (apat, aexp_rename from_id to_id aexp1, aexp_rename from_id to_id aexp2) + +let shadow_counter = ref 0 + +let new_shadow id = + let shadow_id = append_id id ("shadow#" ^ string_of_int !shadow_counter) in + incr shadow_counter; + shadow_id + +let rec no_shadow ids aexp = + match aexp with + | AE_val aval -> AE_val aval + | AE_app (id, avals, typ) -> AE_app (id, avals, typ) + | AE_cast (aexp, typ) -> AE_cast (no_shadow ids aexp, typ) + | AE_assign (id, typ, aexp) -> AE_assign (id, typ, no_shadow ids aexp) + | AE_let (id, typ1, aexp1, aexp2, typ2) when IdSet.mem id ids -> + let shadow_id = new_shadow id in + let aexp1 = no_shadow ids aexp1 in + let ids = IdSet.add shadow_id ids in + AE_let (shadow_id, typ1, aexp1, no_shadow ids (aexp_rename id shadow_id aexp2), typ2) + | AE_let (id, typ1, aexp1, aexp2, typ2) -> + AE_let (id, typ1, no_shadow ids aexp1, no_shadow (IdSet.add id ids) aexp2, typ2) + | AE_block (aexps, aexp, typ) -> AE_block (List.map (no_shadow ids) aexps, no_shadow ids aexp, typ) + | AE_return (aval, typ) -> AE_return (aval, typ) + | AE_throw (aval, typ) -> AE_throw (aval, typ) + | AE_if (aval, then_aexp, else_aexp, typ) -> AE_if (aval, no_shadow ids then_aexp, no_shadow ids else_aexp, typ) + | AE_field (aval, id, typ) -> AE_field (aval, id, typ) + | AE_case (aval, apexps, typ) -> AE_case (aval, List.map (no_shadow_apexp ids) apexps, typ) + | AE_try (aexp, apexps, typ) -> AE_try (no_shadow ids aexp, List.map (no_shadow_apexp ids) apexps, typ) + | AE_record_update (aval, avals, typ) -> AE_record_update (aval, avals, typ) + | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> + let ids = IdSet.add id ids in + AE_for (id, no_shadow ids aexp1, no_shadow ids aexp2, no_shadow ids aexp3, order, no_shadow ids aexp4) + | AE_loop (loop, aexp1, aexp2) -> AE_loop (loop, no_shadow ids aexp1, no_shadow ids aexp2) + +and no_shadow_apexp ids (apat, aexp1, aexp2) = + let shadows = IdSet.inter (apat_bindings apat) ids in + let shadows = List.map (fun id -> id, new_shadow id) (IdSet.elements shadows) in + let rename aexp = List.fold_left (fun aexp (from_id, to_id) -> aexp_rename from_id to_id aexp) aexp shadows in + let ids = IdSet.union ids (IdSet.of_list (List.map snd shadows)) in + (apat, no_shadow ids (rename aexp1), no_shadow ids (rename aexp2)) + (* Map over all the avals in an aexp. *) let rec map_aval f = function | AE_val v -> AE_val (f v) @@ -557,8 +658,6 @@ let rec anf (E_aux (e_aux, exp_annot) as exp) = | E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _ | E_internal_plet _ | E_internal_return _ | E_internal_exp_user _ -> failwith "encountered unexpected internal node when converting to ANF" - | E_record _ -> AE_val (AV_lit (mk_lit (L_string "testing"), string_typ)) (* c_error ("Cannot convert to ANF: " ^ string_of_exp exp) *) - (**************************************************************************) (* 2. Converting sail types to C types *) (**************************************************************************) @@ -927,8 +1026,14 @@ let rec instr_ctyps (I_aux (instr, aux)) = | I_throw cval | I_jump (cval, _) | I_return cval -> [cval_ctyp cval] | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure -> [] +let rec c_ast_registers = function + | CDEF_reg_dec (id, ctyp) :: ast -> (id, ctyp) :: c_ast_registers ast + | _ :: ast -> c_ast_registers ast + | [] -> [] + let cdef_ctyps ctx = function | CDEF_reg_dec (_, ctyp) -> [ctyp] + | CDEF_spec (_, ctyps, ctyp) -> ctyp :: ctyps | CDEF_fundef (id, _, _, instrs) -> (* TODO: Move this code to DEF_fundef -> CDEF_fundef translation, and modify bytecode.ott *) let _, Typ_aux (fn_typ, _) = @@ -1036,6 +1141,8 @@ let pp_ctype_def = function ^^ surround 2 0 lbrace (separate_map (semi ^^ hardline) (fun (id, ctyp) -> pp_id id ^^ string " : " ^^ pp_ctyp ctyp) ctors) rbrace let pp_cdef = function + | CDEF_spec (id, ctyps, ctyp) -> + pp_keyword "val" ^^ pp_id id ^^ space ^^ parens (separate_map (comma ^^ space) pp_ctyp ctyps) ^^ string " -> " ^^ pp_ctyp ctyp | CDEF_fundef (id, ret, args, instrs) -> let ret = match ret with | None -> empty @@ -1071,6 +1178,10 @@ let is_ct_list = function | CT_list _ -> true | _ -> false +let is_ct_vector = function + | CT_vector _ -> true + | _ -> false + let rec is_bitvector = function | [] -> true | AV_lit (L_aux (L_zero, _), _) :: avals -> is_bitvector avals @@ -1100,7 +1211,7 @@ let rec compile_aval ctx = function | AV_lit (L_aux (L_num n, _), typ) when Big_int.less_equal min_int64 n && Big_int.less_equal n max_int64 -> let gs = gensym () in [idecl CT_mpz gs; - iinit CT_mpz gs (F_lit (Big_int.to_string n ^ "L"), CT_int64)], + iinit CT_mpz gs (F_lit (Big_int.to_string n ^ "l"), CT_int64)], (F_id gs, CT_mpz), [iclear CT_mpz gs] @@ -1117,6 +1228,13 @@ let rec compile_aval ctx = function | AV_lit (L_aux (L_true, _), _) -> [], (F_lit "true", CT_bool), [] | AV_lit (L_aux (L_false, _), _) -> [], (F_lit "false", CT_bool), [] + | AV_lit (L_aux (L_real str, _), _) -> + let gs = gensym () in + [idecl CT_real gs; + iinit CT_real gs (F_lit ("\"" ^ str ^ "\""), CT_string)], + (F_id gs, CT_real), + [iclear CT_real gs] + | AV_lit (L_aux (_, l) as lit, _) -> c_error ~loc:l ("Encountered unexpected literal " ^ string_of_lit lit) @@ -1301,7 +1419,7 @@ let rec compile_match ctx apat cval case_label = [] | AP_global (pid, _), _ -> [icopy (CL_id pid) cval], [] | AP_id pid, (frag, ctyp) when is_ct_enum ctyp -> - [ijump (F_op (F_id pid, "!=", frag), CT_bool) case_label], [] + [idecl ctyp pid; ijump (F_op (F_id pid, "!=", frag), CT_bool) case_label], [] | AP_id pid, _ -> let ctyp = cval_ctyp cval in let init, cleanup = if is_stack_ctyp ctyp then [], [] else [ialloc ctyp pid], [iclear ctyp pid] in @@ -1355,15 +1473,14 @@ let label str = let rec compile_aexp ctx = function | AE_let (id, _, binding, body, typ) -> let setup, ctyp, call, cleanup = compile_aexp ctx binding in - let letb1, letb1c = + let letb_setup, letb_cleanup = if is_stack_ctyp ctyp then - [idecl ctyp id; call (CL_id id)], [] + [idecl ctyp id; iblock (setup @ [call (CL_id id)] @ cleanup)], [] else - [idecl ctyp id; ialloc ctyp id; call (CL_id id)], [iclear ctyp id] + [idecl ctyp id; ialloc ctyp id; iblock (setup @ [call (CL_id id)] @ cleanup)], [iclear ctyp id] in - let letb2 = setup @ letb1 @ cleanup in let setup, ctyp, call, cleanup = compile_aexp ctx body in - letb2 @ setup, ctyp, call, cleanup @ letb1c + letb_setup @ setup, ctyp, call, cleanup @ letb_cleanup | AE_app (id, vs, typ) -> compile_funcall ctx id vs typ @@ -1539,6 +1656,29 @@ let rec compile_aexp ctx = function (fun clexp -> icopy clexp unit_fragment), [] + | AE_loop (Until, cond, body) -> + let loop_start_label = label "repeat_" in + let loop_end_label = label "until_" in + let cond_setup, _, cond_call, cond_cleanup = compile_aexp ctx cond in + let body_setup, _, body_call, body_cleanup = compile_aexp ctx body in + let gs = gensym () in + let unit_gs = gensym () in + let loop_test = (F_unary ("!", F_id gs), CT_bool) in + [idecl CT_bool gs; idecl CT_unit unit_gs] + @ [ilabel loop_start_label] + @ [iblock (body_setup + @ [body_call (CL_id unit_gs)] + @ body_cleanup + @ cond_setup + @ [cond_call (CL_id gs)] + @ cond_cleanup + @ [ijump loop_test loop_end_label] + @ [igoto loop_start_label])] + @ [ilabel loop_end_label], + CT_unit, + (fun clexp -> icopy clexp unit_fragment), + [] + | AE_cast (aexp, typ) -> compile_aexp ctx aexp | AE_return (aval, typ) -> @@ -1583,15 +1723,20 @@ and compile_block ctx = function let gs = gensym () in setup @ [idecl CT_unit gs; call (CL_id gs)] @ cleanup @ rest -let rec pat_ids (P_aux (p_aux, (l, _)) as pat) = - match p_aux with - | P_id id -> [id] - | P_tup pats -> List.concat (List.map pat_ids pats) - | P_lit (L_aux (L_unit, _)) -> let gs = gensym () in [gs] - | P_wild -> let gs = gensym () in [gs] - | P_var (pat, _) -> pat_ids pat - | P_typ (_, pat) -> pat_ids pat - | _ -> c_error ~loc:l ("Cannot compile pattern " ^ string_of_pat pat ^ " to C") +(* FIXME: this function is a bit of a hack *) +let rec pat_ids (Typ_aux (arg_typ_aux, _) as arg_typ) (P_aux (p_aux, (l, _)) as pat) = + prerr_endline (string_of_typ arg_typ); + match p_aux, arg_typ_aux with + | P_id id, _ -> [id] + | P_tup pats, Typ_tup arg_typs when List.length pats = List.length arg_typs -> + List.concat (List.map2 pat_ids arg_typs pats) + | P_tup pats, _ -> c_error ~loc:l ("Cannot compile tuple pattern " ^ string_of_pat pat ^ " to C, as it doesn't have tuple type.") + | P_lit (L_aux (L_unit, _)), _ -> let gs = gensym () in [gs] + | P_wild, Typ_tup arg_typs -> List.map (fun _ -> let gs = gensym () in gs) arg_typs + | P_wild, _ -> let gs = gensym () in [gs] + | P_var (pat, _), _ -> pat_ids arg_typ pat + | P_typ (_, pat), _ -> pat_ids arg_typ pat + | _, _ -> c_error ~loc:l ("Cannot compile pattern " ^ string_of_pat pat ^ " to C") (** Compile a sail type definition into a IR one. Most of the actual work of translating the typedefs into C is done by the code @@ -1773,27 +1918,40 @@ let compile_def ctx = function [CDEF_reg_dec (id, ctyp_of_typ ctx typ)], ctx | DEF_reg_dec _ -> failwith "Unsupported register declaration" (* FIXME *) - | DEF_spec _ -> [], ctx + | DEF_spec (VS_aux (VS_val_spec (_, id, _, _), _)) -> + let _, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in + let arg_typs, ret_typ = match fn_typ with + | Typ_fn (Typ_aux (Typ_tup arg_typs, _), ret_typ, _) -> arg_typs, ret_typ + | Typ_fn (arg_typ, ret_typ, _) -> [arg_typ], ret_typ + | _ -> assert false + in + let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx) arg_typs, ctyp_of_typ ctx ret_typ in + [CDEF_spec (id, arg_ctyps, ret_ctyp)], ctx | DEF_fundef (FD_aux (FD_function (_, _, _, [FCL_aux (FCL_Funcl (id, Pat_aux (Pat_exp (pat, exp), _)), _)]), _)) -> - let aexp = map_functions (analyze_primop ctx) (c_literals ctx (anf exp)) in - prerr_endline (Pretty_print_sail.to_string (pp_aexp aexp)); + let aexp = map_functions (analyze_primop ctx) (c_literals ctx (no_shadow IdSet.empty (anf exp))) in + if string_of_id id = "system_barriers_decode" then prerr_endline (Pretty_print_sail.to_string (pp_aexp aexp)) else (); let setup, ctyp, call, cleanup = compile_aexp ctx aexp in let gs = gensym () in let pat = match pat with | P_aux (P_tup [], annot) -> P_aux (P_lit (mk_lit L_unit), annot) | _ -> pat in + let _, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in + let arg_typ, ret_typ = match fn_typ with + | Typ_fn (arg_typ, ret_typ, _) -> arg_typ, ret_typ + | _ -> assert false + in prerr_endline (string_of_id id ^ " : " ^ string_of_ctyp ctyp); if is_stack_ctyp ctyp then let instrs = [idecl ctyp gs] @ setup @ [call (CL_id gs)] @ cleanup @ [ireturn (F_id gs, ctyp)] in let instrs = fix_exception ctx instrs in - [CDEF_fundef (id, None, pat_ids pat, instrs)], ctx + [CDEF_fundef (id, None, pat_ids arg_typ pat, instrs)], ctx else let instrs = setup @ [call (CL_addr gs)] @ cleanup in let instrs = fix_early_return (CL_addr gs) ctx instrs in let instrs = fix_exception ctx instrs in - [CDEF_fundef (id, Some gs, pat_ids pat, instrs)], ctx + [CDEF_fundef (id, Some gs, pat_ids arg_typ pat, instrs)], ctx | DEF_fundef (FD_aux (FD_function (_, _, _, []), (l, _))) -> c_error ~loc:l "Encountered function with no clauses" @@ -1809,7 +1967,7 @@ let compile_def ctx = function [CDEF_type tdef], ctx | DEF_val (LB_aux (LB_val (pat, exp), _)) -> - let aexp = map_functions (analyze_primop ctx) (c_literals ctx (anf exp)) in + let aexp = map_functions (analyze_primop ctx) (c_literals ctx (no_shadow IdSet.empty (anf exp))) in let setup, ctyp, call, cleanup = compile_aexp ctx aexp in let apat = anf_pat ~global:true pat in let gs = gensym () in @@ -2073,8 +2231,9 @@ let sgen_ctyp = function | CT_enum (id, _) -> "enum " ^ sgen_id id | CT_variant (id, _) -> "struct " ^ sgen_id id | CT_list _ as l -> Util.zencode_string (string_of_ctyp l) - | CT_vector _ -> "int" (* FIXME *) + | CT_vector _ as v -> Util.zencode_string (string_of_ctyp v) | CT_string -> "sail_string" + | CT_real -> "real" let sgen_ctyp_name = function | CT_unit -> "unit" @@ -2089,8 +2248,9 @@ let sgen_ctyp_name = function | CT_enum (id, _) -> sgen_id id | CT_variant (id, _) -> sgen_id id | CT_list _ as l -> Util.zencode_string (string_of_ctyp l) - | CT_vector _ -> "int" (* FIXME *) + | CT_vector _ as v -> Util.zencode_string (string_of_ctyp v) | CT_string -> "sail_string" + | CT_real -> "real" let sgen_cval_param (frag, ctyp) = match ctyp with @@ -2149,12 +2309,41 @@ let rec codegen_instr ctx (I_aux (instr, _)) = ^^ jump 2 2 (separate_map hardline (codegen_instr ctx) instrs) ^^ hardline ^^ string " }" | I_funcall (x, f, args, ctyp) -> - let args = Util.string_of_list ", " sgen_cval args in + let c_args = Util.string_of_list ", " sgen_cval args in let fname = if Env.is_extern f ctx.tc_env "c" then Env.get_extern f ctx.tc_env "c" else sgen_id f in + let fname = + match fname, ctyp with + | "internal_pick", _ -> Printf.sprintf "pick_%s" (sgen_ctyp_name ctyp) + | "eq_anything", _ -> + begin match args with + | cval :: _ -> Printf.sprintf "eq_%s" (sgen_ctyp_name (cval_ctyp cval)) + | _ -> c_error "eq_anything function with bad arity." + end + | "length", _ -> + begin match args with + | cval :: _ -> Printf.sprintf "length_%s" (sgen_ctyp_name (cval_ctyp cval)) + | _ -> c_error "length function with bad arity." + end + | "vector_access", CT_bit -> "bitvector_access" + | "vector_access", _ -> + begin match args with + | cval :: _ -> Printf.sprintf "vector_access_%s" (sgen_ctyp_name (cval_ctyp cval)) + | _ -> c_error "vector access function with bad arity." + end + | "vector_update_subrange", _ -> Printf.sprintf "vector_update_subrange_%s" (sgen_ctyp_name ctyp) + | "vector_subrange", _ -> Printf.sprintf "vector_subrange_%s" (sgen_ctyp_name ctyp) + | "vector_update", CT_uint64 _ -> "update_uint64_t" + | "vector_update", CT_bv _ -> "update_bv" + | "vector_update", _ -> Printf.sprintf "vector_update_%s" (sgen_ctyp_name ctyp) + | "undefined_vector", CT_uint64 _ -> "undefined_uint64_t" + | "undefined_vector", CT_bv _ -> "undefined_bv_t" + | "undefined_vector", _ -> Printf.sprintf "undefined_vector_%s" (sgen_ctyp_name ctyp) + | fname, _ -> fname + in if is_stack_ctyp ctyp then - string (Printf.sprintf " %s = %s(%s);" (sgen_clexp_pure x) fname args) + string (Printf.sprintf " %s = %s(%s);" (sgen_clexp_pure x) fname c_args) else - string (Printf.sprintf " %s(%s, %s);" fname (sgen_clexp x) args) + string (Printf.sprintf " %s(%s, %s);" fname (sgen_clexp x) c_args) | I_clear (ctyp, id) -> string (Printf.sprintf " clear_%s(&%s);" (sgen_ctyp_name ctyp) (sgen_id id)) | I_init (ctyp, id, cval) -> @@ -2165,6 +2354,23 @@ let rec codegen_instr ctx (I_aux (instr, _)) = (sgen_cval_param cval)) | I_alloc (ctyp, id) -> string (Printf.sprintf " init_%s(&%s);" (sgen_ctyp_name ctyp) (sgen_id id)) + (* FIXME: This just covers the cases we see in our specs, need a + special conversion code-generator for full generality *) + | I_convert (x, CT_tup ctyps1, y, CT_tup ctyps2) when List.length ctyps1 = List.length ctyps2 -> + let convert i (ctyp1, ctyp2) = + if ctyp_equal ctyp1 ctyp2 then string " /* no change */" + else if is_stack_ctyp ctyp1 then + string (Printf.sprintf " %s.ztup%i = convert_%s_of_%s(%s.ztup%i);" + (sgen_clexp_pure x) + i + (sgen_ctyp_name ctyp1) + (sgen_ctyp_name ctyp2) + (sgen_id y) + i) + else + c_error "Cannot compile type conversion" + in + separate hardline (List.mapi convert (List.map2 (fun x y -> (x, y)) ctyps1 ctyps2)) | I_convert (x, ctyp1, y, ctyp2) -> if is_stack_ctyp ctyp1 then string (Printf.sprintf " %s = convert_%s_of_%s(%s);" @@ -2195,8 +2401,14 @@ let rec codegen_instr ctx (I_aux (instr, _)) = let codegen_type_def ctx = function | CTD_enum (id, ids) -> + let codegen_eq = + let name = sgen_id id in + string (Printf.sprintf "bool eq_%s(enum %s op1, enum %s op2) { return op1 == op2; }" name name name) + in string (Printf.sprintf "// enum %s" (string_of_id id)) ^^ hardline ^^ separate space [string "enum"; codegen_id id; lbrace; separate_map (comma ^^ space) upper_codegen_id ids; rbrace ^^ semi] + ^^ twice hardline + ^^ codegen_eq | CTD_struct (id, ctors) -> (* Generate a set_T function for every struct T *) @@ -2224,6 +2436,9 @@ let codegen_type_def ctx = function (separate hardline (Bindings.bindings ctors |> List.map (codegen_field_init f) |> List.concat)) rbrace in + let codegen_eq = + string (Printf.sprintf "bool eq_%s(struct %s op1, struct %s op2) { return true; }" (sgen_id id) (sgen_id id) (sgen_id id)) + in (* Generate the struct and add the generated functions *) let codegen_ctor (id, ctyp) = string (sgen_ctyp ctyp) ^^ space ^^ codegen_id id @@ -2239,6 +2454,8 @@ let codegen_type_def ctx = function ^^ codegen_init "init" id (ctor_bindings ctors) ^^ twice hardline ^^ codegen_init "clear" id (ctor_bindings ctors) + ^^ twice hardline + ^^ codegen_eq | CTD_variant (id, tus) -> let codegen_tu (ctor_id, ctyp) = @@ -2403,7 +2620,8 @@ let codegen_list_init id = let codegen_list_clear id ctyp = string (Printf.sprintf "void clear_%s(%s *rop) {\n" (sgen_id id) (sgen_id id)) ^^ string (Printf.sprintf " if (*rop == NULL) return;") - ^^ string (Printf.sprintf " clear_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)) + ^^ (if is_stack_ctyp ctyp then empty + else string (Printf.sprintf " clear_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp))) ^^ string (Printf.sprintf " clear_%s(&(*rop)->tl);\n" (sgen_id id)) ^^ string " free(*rop);" ^^ string "}" @@ -2412,8 +2630,11 @@ let codegen_list_set id ctyp = string (Printf.sprintf "void internal_set_%s(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) ^^ string " if (op == NULL) { *rop = NULL; return; };\n" ^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id)) - ^^ string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)) - ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, op->hd);\n" (sgen_ctyp_name ctyp)) + ^^ (if is_stack_ctyp ctyp then + string " (*rop)->hd = op->hd;\n" + else + string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)) + ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, op->hd);\n" (sgen_ctyp_name ctyp))) ^^ string (Printf.sprintf " internal_set_%s(&(*rop)->tl, op->tl);\n" (sgen_id id)) ^^ string "}" ^^ twice hardline @@ -2426,11 +2647,20 @@ let codegen_cons id ctyp = let cons_id = mk_id ("cons#" ^ string_of_ctyp ctyp) in string (Printf.sprintf "void %s(%s *rop, const %s x, const %s xs) {\n" (sgen_id cons_id) (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) ^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id)) - ^^ string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)) - ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, x);\n" (sgen_ctyp_name ctyp)) + ^^ (if is_stack_ctyp ctyp then + string " (*rop)->hd = x;\n" + else + string (Printf.sprintf " init_%s(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)) + ^^ string (Printf.sprintf " set_%s(&(*rop)->hd, x);\n" (sgen_ctyp_name ctyp))) ^^ string " (*rop)->tl = xs;\n" ^^ string "}" +let codegen_pick id ctyp = + if is_stack_ctyp ctyp then + string (Printf.sprintf "%s pick_%s(const %s xs) { return xs->hd; }" (sgen_ctyp ctyp) (sgen_ctyp_name ctyp) (sgen_id id)) + else + string (Printf.sprintf "void pick_%s(%s *x, const %s xs) { set_%s(x, xs->hd); }" (sgen_ctyp_name ctyp) (sgen_ctyp ctyp) (sgen_id id) (sgen_ctyp_name ctyp)) + let codegen_list ctx ctyp = let id = mk_id (string_of_ctyp (CT_list ctyp)) in if IdSet.mem id !generated then @@ -2443,6 +2673,94 @@ let codegen_list ctx ctyp = ^^ codegen_list_clear id ctyp ^^ twice hardline ^^ codegen_list_set id ctyp ^^ twice hardline ^^ codegen_cons id ctyp ^^ twice hardline + ^^ codegen_pick id ctyp ^^ twice hardline + end + +let codegen_vector ctx (direction, ctyp) = + let id = mk_id (string_of_ctyp (CT_vector (direction, ctyp))) in + if IdSet.mem id !generated then + empty + else + let vector_typedef = + string (Printf.sprintf "struct %s {\n size_t len;\n %s *data;\n};\n" (sgen_id id) (sgen_ctyp ctyp)) + ^^ string (Printf.sprintf "typedef struct %s %s;" (sgen_id id) (sgen_id id)) + in + let vector_init = + string (Printf.sprintf "void init_%s(%s *rop) {\n rop->len = 0;\n rop->data = NULL;\n}" (sgen_id id) (sgen_id id)) + in + let vector_set = + string (Printf.sprintf "void set_%s(%s *rop, %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) + ^^ string (Printf.sprintf " clear_%s(rop);\n" (sgen_id id)) + ^^ string " rop->len = op.len;\n" + ^^ string (Printf.sprintf " rop->data = malloc((rop->len) * sizeof(%s));\n" (sgen_ctyp ctyp)) + ^^ string " for (int i = 0; i < op.len; i++) {\n" + ^^ string (if is_stack_ctyp ctyp then + " (rop->data)[i] = op.data[i];\n" + else + Printf.sprintf " init_%s((rop->data) + i);\n set_%s((rop->data) + i, op.data[i]);\n" (sgen_ctyp_name ctyp) (sgen_ctyp_name ctyp)) + ^^ string " }\n" + ^^ string "}" + in + let vector_clear = + string (Printf.sprintf "void clear_%s(%s *rop) {\n" (sgen_id id) (sgen_id id)) + ^^ (if is_stack_ctyp ctyp then empty + else + string " for (int i = 0; i < (rop->len); i++) {\n" + ^^ string (Printf.sprintf " clear_%s((rop->data) + i);\n" (sgen_ctyp_name ctyp)) + ^^ string " }\n") + ^^ string " if (rop->data != NULL) free(rop->data);\n" + ^^ string "}" + in + let vector_update = + string (Printf.sprintf "void vector_update_%s(%s *rop, %s op, mpz_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) + ^^ string " int m = mpz_get_ui(n);\n" + ^^ string " if (rop->data == op.data) {\n" + ^^ string (if is_stack_ctyp ctyp then + " rop->data[m] = elem;\n" + else + Printf.sprintf " set_%s((rop->data) + m, elem);\n" (sgen_ctyp_name ctyp)) + ^^ string " } else {\n" + ^^ string (Printf.sprintf " set_%s(rop, op);\n" (sgen_id id)) + ^^ string (if is_stack_ctyp ctyp then + " rop->data[m] = elem;\n" + else + Printf.sprintf " set_%s((rop->data) + m, elem);\n" (sgen_ctyp_name ctyp)) + ^^ string " }\n" + ^^ string "}" + in + let vector_access = + if is_stack_ctyp ctyp then + string (Printf.sprintf "%s vector_access_%s(%s op, mpz_t n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id)) + ^^ string " int m = mpz_get_ui(n);\n" + ^^ string " return op.data[m];\n" + ^^ string "}" + else + string (Printf.sprintf "void vector_access_%s(%s *rop, %s op, mpz_t n) {\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) + ^^ string " int m = mpz_get_ui(n);\n" + ^^ string (Printf.sprintf " set_%s(rop, op.data[m]);\n" (sgen_ctyp_name ctyp)) + ^^ string "}" + in + let vector_undefined = + string (Printf.sprintf "void undefined_vector_%s(%s *rop, mpz_t len, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) + ^^ string (Printf.sprintf " rop->len = mpz_get_ui(len);\n") + ^^ string (Printf.sprintf " rop->data = malloc((rop->len) * sizeof(%s));\n" (sgen_ctyp ctyp)) + ^^ string " for (int i = 0; i < (rop->len); i++) {\n" + ^^ string (if is_stack_ctyp ctyp then + " (rop->data)[i] = elem;\n" + else + Printf.sprintf " init_%s((rop->data) + i);\n set_%s((rop->data) + i, elem);\n" (sgen_ctyp_name ctyp) (sgen_ctyp_name ctyp)) + ^^ string " }\n" + ^^ string "}" + in + begin + generated := IdSet.add id !generated; + vector_typedef ^^ twice hardline + ^^ vector_init ^^ twice hardline + ^^ vector_clear ^^ twice hardline + ^^ vector_undefined ^^ twice hardline + ^^ vector_access ^^ twice hardline + ^^ vector_set ^^ twice hardline + ^^ vector_update ^^ twice hardline end let codegen_def' ctx = function @@ -2450,6 +2768,14 @@ let codegen_def' ctx = function string (Printf.sprintf "// register %s" (string_of_id id)) ^^ hardline ^^ string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id)) + | CDEF_spec (id, arg_ctyps, ret_ctyp) -> + if Env.is_extern id ctx.tc_env "c" then + empty + else if is_stack_ctyp ret_ctyp then + string (Printf.sprintf "%s %s(%s);" (sgen_ctyp ret_ctyp) (sgen_id id) (Util.string_of_list ", " sgen_ctyp arg_ctyps)) + else + string (Printf.sprintf "void %s(%s *rop, %s);" (sgen_id id) (sgen_ctyp ret_ctyp) (Util.string_of_list ", " sgen_ctyp arg_ctyps)) + | CDEF_fundef (id, ret_arg, args, instrs) as def -> if !opt_ddump_flow_graphs then make_dot id (instrs_graph instrs) else (); let instrs = add_local_labels instrs in @@ -2504,13 +2830,20 @@ let codegen_def ctx def = | CT_list ctyp -> ctyp | _ -> assert false in + let unvector = function + | CT_vector (direction, ctyp) -> (direction, ctyp) + | _ -> assert false + in let tups = List.filter is_ct_tup (cdef_ctyps ctx def) in let tups = List.map (fun ctyp -> codegen_tup ctx (untup ctyp)) tups in let lists = List.filter is_ct_list (cdef_ctyps ctx def) in let lists = List.map (fun ctyp -> codegen_list ctx (unlist ctyp)) lists in - prerr_endline (Pretty_print_sail.to_string (pp_cdef def)); + let vectors = List.filter is_ct_vector (cdef_ctyps ctx def) in + let vectors = List.map (fun ctyp -> codegen_vector ctx (unvector ctyp)) vectors in + (* prerr_endline (Pretty_print_sail.to_string (pp_cdef def)); *) concat tups ^^ concat lists + ^^ concat vectors ^^ codegen_def' ctx def let compile_ast ctx (Defs defs) = @@ -2542,14 +2875,29 @@ let compile_ast ctx (Defs defs) = List.map (fun n -> Printf.sprintf " kill_letbind_%d();" n) ctx.letbinds in + let regs = c_ast_registers cdefs in + + let register_init_clear (id, ctyp) = + if is_stack_ctyp ctyp then + [], [] + else + [ Printf.sprintf " init_%s(&%s);" (sgen_ctyp_name ctyp) (sgen_id id) ], + [ Printf.sprintf " clear_%s(&%s);" (sgen_ctyp_name ctyp) (sgen_id id) ] + in + let postamble = separate hardline (List.map string ( [ "int main(void)"; - "{" ] + "{"; + " setup_real();" ] @ fst exn_boilerplate + @ List.concat (List.map (fun r -> fst (register_init_clear r)) regs) + @ (if regs = [] then [] else [ " zinitializze_registers(UNIT);" ]) @ letbind_initializers @ [ " zmain(UNIT);" ] @ letbind_finalizers + @ List.concat (List.map (fun r -> snd (register_init_clear r)) regs) @ snd exn_boilerplate + @ [ " return 0;" ] @ [ "}" ] )) in diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 71efcb22..d14097af 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -54,7 +54,7 @@ open Ast_util module Big_int = Nat_big_num open Type_check -let size_set_limit = 32 +let size_set_limit = 64 let optmap v f = match v with @@ -69,6 +69,11 @@ let bindings_union s1 s2 = | _, (Some x) -> Some x | (Some x), _ -> Some x | _, _ -> None) s1 s2 +let kbindings_union s1 s2 = + KBindings.merge (fun _ x y -> match x,y with + | _, (Some x) -> Some x + | (Some x), _ -> Some x + | _, _ -> None) s1 s2 let subst_nexp substs nexp = let rec s_snexp substs (Nexp_aux (ne,l) as nexp) = @@ -615,9 +620,9 @@ let bindings_from_pat p = and aux_fpat (FP_aux (FP_Fpat (_,p), _)) = aux_pat p in aux_pat p -let remove_bound env pat = +let remove_bound (substs,ksubsts) pat = let bound = bindings_from_pat pat in - List.fold_left (fun sub v -> Bindings.remove v sub) env bound + List.fold_left (fun sub v -> Bindings.remove v sub) substs bound, ksubsts (* Attempt simple pattern matches *) let lit_match = function @@ -721,6 +726,30 @@ let int_of_str_lit = function | L_bin bin -> Big_int.of_string ("0b" ^ bin) | _ -> assert false +let bits_of_lit = function + | L_bin bin -> bin + | L_hex hex -> hex_to_bin hex + | _ -> assert false + +let slice_lit (L_aux (lit,ll)) i len (Ord_aux (ord,_)) = + let i = Big_int.to_int i in + let len = Big_int.to_int len in + match match ord with + | Ord_inc -> Some i + | Ord_dec -> Some (len - i) + | Ord_var _ -> None + with + | None -> None + | Some i -> + match lit with + | L_bin bin -> Some (L_aux (L_bin (String.sub bin i len),Generated ll)) + | _ -> assert false + +let concat_vec lit1 lit2 = + let bits1 = bits_of_lit lit1 in + let bits2 = bits_of_lit lit2 in + L_bin (bits1 ^ bits2) + let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) = match l1,l2 with | (L_zero|L_false), (L_zero|L_false) @@ -758,16 +787,47 @@ let try_app (l,ann) (id,args) = | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit,_), _)] -> Some (E_aux (E_lit (L_aux (L_num (int_of_str_lit lit),new_l)),(l,ann))) | _ -> None + else if is_id "slice" then + match args with + | [E_aux (E_lit (L_aux ((L_hex _| L_bin _),_) as lit), + (_,Some (_,Typ_aux (Typ_app (_,[_;Typ_arg_aux (Typ_arg_order ord,_);_]),_),_))); + E_aux (E_lit L_aux (L_num i,_), _); + E_aux (E_lit L_aux (L_num len,_), _)] -> + (match slice_lit lit i len ord with + | Some lit' -> Some (E_aux (E_lit lit',(l,ann))) + | None -> None) + | _ -> None + else if is_id "bitvector_concat" then + match args with + | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit1,_), _); + E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit2,_), _)] -> + Some (E_aux (E_lit (L_aux (concat_vec lit1 lit2,new_l)),(l,ann))) + | _ -> None else if is_id "shl_int" then match args with | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> Some (E_aux (E_lit (L_aux (L_num (Big_int.shift_left i (Big_int.to_int j)),new_l)),(l,ann))) | _ -> None - else if is_id "mult_int" then + else if is_id "mult_int" || is_id "mult_range" then match args with | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> Some (E_aux (E_lit (L_aux (L_num (Big_int.mul i j),new_l)),(l,ann))) | _ -> None + else if is_id "quotient_nat" then + match args with + | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> + Some (E_aux (E_lit (L_aux (L_num (Big_int.div i j),new_l)),(l,ann))) + | _ -> None + else if is_id "add_range" then + match args with + | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> + Some (E_aux (E_lit (L_aux (L_num (Big_int.add i j),new_l)),(l,ann))) + | _ -> None + else if is_id "negate_range" then + match args with + | [E_aux (E_lit L_aux (L_num i,_),_)] -> + Some (E_aux (E_lit (L_aux (L_num (Big_int.negate i),new_l)),(l,ann))) + | _ -> None else if is_id "ex_int" then match args with | [E_aux (E_lit lit,(l,_))] -> Some (E_aux (E_lit lit,(l,ann))) @@ -1034,6 +1094,13 @@ let apply_pat_choices choices = e_assert = rewrite_assert; e_case = rewrite_case } +(* Check whether the current environment with the given kid assignments is + inconsistent (and hence whether the code is dead) *) +let is_env_inconsistent env ksubsts = + let env = KBindings.fold (fun k nexp env -> + Env.add_constraint (nc_eq (nvar k) nexp) env) ksubsts env in + prove env nc_false + let split_defs all_errors splits defs = let no_errors_happened = ref true in let split_constructors (Defs defs) = @@ -1065,8 +1132,13 @@ let split_defs all_errors splits defs = let (refinements, defs') = split_constructors defs in + (* COULD DO: dead code is only eliminated at if expressions, but we could + also cut out impossible case branches and code after assertions. *) + (* Constant propogation. Takes maps of immutable/mutable variables to subsitute. + The substs argument also contains the current type-level kid refinements + so that we can check for dead code. Extremely conservative about evaluation order of assignments in subexpressions, dropping assignments rather than committing to any particular order *) @@ -1123,7 +1195,7 @@ let split_defs all_errors splits defs = let env = Type_check.env_of_annot (l, annot) in (try match Env.lookup_id id env with - | Local (Immutable,_) -> Bindings.find id substs + | Local (Immutable,_) -> Bindings.find id (fst substs) | Local (Mutable,_) -> Bindings.find id assigns | _ -> exp with Not_found -> exp),assigns @@ -1154,20 +1226,48 @@ let split_defs all_errors splits defs = re (E_tuple es') assigns | E_if (e1,e2,e3) -> let e1',assigns = const_prop_exp substs assigns e1 in - let e2',assigns2 = const_prop_exp substs assigns e2 in - let e3',assigns3 = const_prop_exp substs assigns e3 in - (match drop_casts e1' with + let e1_no_casts = drop_casts e1' in + (match e1_no_casts with | E_aux (E_lit (L_aux ((L_true|L_false) as lit ,_)),_) -> - (match lit with L_true -> e2',assigns2 | _ -> e3',assigns3) + (match lit with + | L_true -> const_prop_exp substs assigns e2 + | _ -> const_prop_exp substs assigns e3) | _ -> - let assigns = isubst_minus_set assigns (assigned_vars e2) in - let assigns = isubst_minus_set assigns (assigned_vars e3) in - re (E_if (e1',e2',e3')) assigns) + (* If the guard is an equality check, propagate the value. *) + let env1 = env_of e1_no_casts in + let is_equal id = + List.exists (fun id' -> Id.compare id id' == 0) + (Env.get_overloads (Id_aux (DeIid "==", Parse_ast.Unknown)) + env1) + in + let substs_true = + match e1_no_casts with + | E_aux (E_app (id, [E_aux (E_id var,_); vl]),_) + | E_aux (E_app (id, [vl; E_aux (E_id var,_)]),_) + when is_equal id -> + if is_value vl then + (match Env.lookup_id var env1 with + | Local (Immutable,_) -> Bindings.add var vl (fst substs),snd substs + | _ -> substs) + else substs + | _ -> substs + in + (* Discard impossible branches *) + if is_env_inconsistent (env_of e2) (snd substs) then + const_prop_exp substs assigns e3 + else if is_env_inconsistent (env_of e3) (snd substs) then + const_prop_exp substs_true assigns e2 + else + let e2',assigns2 = const_prop_exp substs_true assigns e2 in + let e3',assigns3 = const_prop_exp substs assigns e3 in + let assigns = isubst_minus_set assigns (assigned_vars e2) in + let assigns = isubst_minus_set assigns (assigned_vars e3) in + re (E_if (e1',e2',e3')) assigns) | E_for (id,e1,e2,e3,ord,e4) -> (* Treat e1, e2 and e3 (from, to and by) as a non-det tuple *) let e1',e2',e3',assigns = non_det_exp_3 e1 e2 e3 in let assigns = isubst_minus_set assigns (assigned_vars e4) in - let e4',_ = const_prop_exp (Bindings.remove id substs) assigns e4 in + let e4',_ = const_prop_exp (Bindings.remove id (fst substs),snd substs) assigns e4 in re (E_for (id,e1',e2',e3',ord,e4')) assigns | E_loop (loop,e1,e2) -> let assigns = isubst_minus_set assigns (IdSet.union (assigned_vars e1) (assigned_vars e2)) in @@ -1227,7 +1327,7 @@ let split_defs all_errors splits defs = | Some (E_aux (_,(_,annot')) as exp,newbindings,kbindings) -> let exp = nexp_subst_exp (kbindings_from_list kbindings) exp in let newbindings_env = bindings_from_list newbindings in - let substs' = bindings_union substs newbindings_env in + let substs' = bindings_union (fst substs) newbindings_env, snd substs in const_prop_exp substs' assigns exp) | E_let (lb,e2) -> begin @@ -1245,7 +1345,7 @@ let split_defs all_errors splits defs = | Some (e'',bindings,kbindings) -> let e'' = nexp_subst_exp (kbindings_from_list kbindings) e'' in let bindings = bindings_from_list bindings in - let substs'' = bindings_union substs' bindings in + let substs'' = bindings_union (fst substs') bindings, snd substs' in const_prop_exp substs'' assigns e'' else plain () end @@ -1350,9 +1450,9 @@ let split_defs all_errors splits defs = let cases = List.map (function | FCL_aux (FCL_Funcl (_,pexp), ann) -> pexp) fcls in - match can_match_with_env env arg cases Bindings.empty Bindings.empty with + match can_match_with_env env arg cases (Bindings.empty,KBindings.empty) Bindings.empty with | Some (exp,bindings,kbindings) -> - let substs = bindings_from_list bindings in + let substs = bindings_from_list bindings, kbindings_from_list kbindings in let result,_ = const_prop_exp substs Bindings.empty exp in let result = match result with | E_aux (E_return e,_) -> e @@ -1361,7 +1461,7 @@ let split_defs all_errors splits defs = if is_value result then Some result else None | None -> None - and can_match_with_env env (E_aux (e,(l,annot)) as exp0) cases substs assigns = + and can_match_with_env env (E_aux (e,(l,annot)) as exp0) cases (substs,ksubsts) assigns = let rec findpat_generic check_pat description assigns = function | [] -> (Reporting_basic.print_err false true l "Monomorphisation" ("Failed to find a case for " ^ description); None) @@ -1373,7 +1473,7 @@ let split_defs all_errors splits defs = Some (exp, [(id', exp0)], []) | (Pat_aux (Pat_when (P_aux (P_id id',_),guard,exp),_))::tl when pat_id_is_variable env id' -> begin - let substs = Bindings.add id' exp0 substs in + let substs = Bindings.add id' exp0 substs, ksubsts in let (E_aux (guard,_)),assigns = const_prop_exp substs assigns guard in match guard with | E_lit (L_aux (L_true,_)) -> Some (exp,[(id',exp0)],[]) @@ -1385,7 +1485,8 @@ let split_defs all_errors splits defs = | DoesNotMatch -> findpat_generic check_pat description assigns tl | DoesMatch (vsubst,ksubst) -> begin let guard = nexp_subst_exp (kbindings_from_list ksubst) guard in - let substs = bindings_union substs (bindings_from_list vsubst) in + let substs = bindings_union substs (bindings_from_list vsubst), + kbindings_union ksubsts (kbindings_from_list ksubst) in let (E_aux (guard,_)),assigns = const_prop_exp substs assigns guard in match guard with | E_lit (L_aux (L_true,_)) -> Some (exp,vsubst,ksubst) @@ -1463,8 +1564,8 @@ let split_defs all_errors splits defs = can_match_with_env env exp in - let subst_exp substs exp = - let substs = bindings_from_list substs in + let subst_exp substs ksubsts exp = + let substs = bindings_from_list substs, ksubsts in fst (const_prop_exp substs Bindings.empty exp) in @@ -1813,8 +1914,9 @@ let split_defs all_errors splits defs = | VarSplit patsubsts -> if check_split_size patsubsts (pat_loc p) then List.map (fun (pat',substs,pchoices,ksubsts) -> - let exp' = nexp_subst_exp (kbindings_from_list ksubsts) e in - let exp' = subst_exp substs exp' in + let ksubsts = kbindings_from_list ksubsts in + let exp' = nexp_subst_exp ksubsts e in + let exp' = subst_exp substs ksubsts exp' in let exp' = apply_pat_choices pchoices exp' in let exp' = stop_at_false_assertions exp' in Pat_aux (Pat_exp (pat', map_exp exp'),l)) @@ -1833,11 +1935,12 @@ let split_defs all_errors splits defs = | VarSplit patsubsts -> if check_split_size patsubsts (pat_loc p) then List.map (fun (pat',substs,pchoices,ksubsts) -> - let exp1' = nexp_subst_exp (kbindings_from_list ksubsts) e1 in - let exp1' = subst_exp substs exp1' in + let ksubsts = kbindings_from_list ksubsts in + let exp1' = nexp_subst_exp ksubsts e1 in + let exp1' = subst_exp substs ksubsts exp1' in let exp1' = apply_pat_choices pchoices exp1' in - let exp2' = nexp_subst_exp (kbindings_from_list ksubsts) e2 in - let exp2' = subst_exp substs exp2' in + let exp2' = nexp_subst_exp ksubsts e2 in + let exp2' = subst_exp substs ksubsts exp2' in let exp2' = apply_pat_choices pchoices exp2' in let exp2' = stop_at_false_assertions exp2' in Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)) @@ -1917,27 +2020,27 @@ let findi f = let mapat f is xs = let rec aux n = function - | _, [] -> [] - | (i,_)::is, h::t when i = n -> + | [] -> [] + | h::t when Util.IntSet.mem n is -> let h' = f h in - let t' = aux (n+1) (is, t) in + let t' = aux (n+1) t in h'::t' - | is, h::t -> - let t' = aux (n+1) (is, t) in + | h::t -> + let t' = aux (n+1) t in h::t' - in aux 0 (is, xs) + in aux 0 xs let mapat_extra f is xs = let rec aux n = function - | _, [] -> [], [] - | (i,v)::is, h::t when i = n -> - let h',x = f v h in - let t',xs = aux (n+1) (is, t) in + | [] -> [], [] + | h::t when Util.IntSet.mem n is -> + let h',x = f h in + let t',xs = aux (n+1) t in h'::t',x::xs - | is, h::t -> - let t',xs = aux (n+1) (is, t) in + | h::t -> + let t',xs = aux (n+1) t in h::t',xs - in aux 0 (is, xs) + in aux 0 xs let tyvars_bound_in_pat pat = let open Rewriter in @@ -1975,34 +2078,45 @@ let sizes_of_annot = function | _,None -> KidSet.empty | _,Some (env,typ,_) -> sizes_of_typ (Env.base_typ_of env typ) -let change_parameter_pat kid = function - | P_aux (P_id var, (l,_)) - | P_aux (P_typ (_,P_aux (P_id var, (l,_))),_) - -> P_aux (P_id var, (l,None)), (var,kid) +let change_parameter_pat = function + | P_aux (P_id var, (l,Some (env,typ,_))) + | P_aux (P_typ (_,P_aux (P_id var, (l,Some (env,typ,_)))),_) -> + P_aux (P_id var, (l,None)), var | P_aux (_,(l,_)) -> raise (Reporting_basic.err_unreachable l "Expected variable pattern") (* We add code to change the itself('n) parameter into the corresponding integer. *) -let add_var_rebind exp (var,kid) = +let add_var_rebind exp var = let l = Generated Unknown in let annot = (l,None) in E_aux (E_let (LB_aux (LB_val (P_aux (P_id var,annot), E_aux (E_app (mk_id "size_itself_int",[E_aux (E_id var,annot)]),annot)),annot),exp),annot) (* atom('n) arguments to function calls need to be rewritten *) -let replace_with_the_value (E_aux (_,(l,_)) as exp) = +let replace_with_the_value bound_nexps (E_aux (_,(l,_)) as exp) = let env = env_of exp in let typ, wrap = match typ_of exp with | Typ_aux (Typ_exist (kids,nc,typ),l) -> typ, fun t -> Typ_aux (Typ_exist (kids,nc,t),l) | typ -> typ, fun x -> x in let typ = Env.expand_synonyms env typ in + let replace_size size = + (* TODO: pick simpler nexp when there's a choice (also in pretty printer) *) + let is_equal nexp = + prove env (NC_aux (NC_equal (size,nexp), Parse_ast.Unknown)) + in + if is_nexp_constant size then size else + match List.find is_equal bound_nexps with + | nexp -> nexp + | exception Not_found -> size + in let mk_exp nexp l l' = - E_aux (E_cast (wrap (Typ_aux (Typ_app (Id_aux (Id "itself",Generated Unknown), - [Typ_arg_aux (Typ_arg_nexp nexp,l')]),Generated Unknown)), - E_aux (E_app (Id_aux (Id "make_the_value",Generated Unknown),[exp]),(Generated l,None))), - (Generated l,None)) + let nexp = replace_size nexp in + E_aux (E_cast (wrap (Typ_aux (Typ_app (Id_aux (Id "itself",Generated Unknown), + [Typ_arg_aux (Typ_arg_nexp nexp,l')]),Generated Unknown)), + E_aux (E_app (Id_aux (Id "make_the_value",Generated Unknown),[exp]),(Generated l,None))), + (Generated l,None)) in match typ with | Typ_aux (Typ_app (Id_aux (Id "range",_), @@ -2032,91 +2146,77 @@ let replace_type env typ = let rewrite_size_parameters env (Defs defs) = let open Rewriter in - let size_vars pexp = - fst (fold_pexp - { (compute_exp_alg KidSet.empty KidSet.union) with - e_aux = (fun ((s,e),annot) -> KidSet.union s (sizes_of_annot annot), E_aux (e,annot)); - e_let = (fun ((sl,lb),(s2,e2)) -> KidSet.union sl (KidSet.diff s2 (tyvars_bound_in_lb lb)), E_let (lb,e2)); - e_for = (fun (id,(s1,e1),(s2,e2),(s3,e3),ord,(s4,e4)) -> - let kid = mk_kid ("loop_" ^ string_of_id id) in - KidSet.union s1 (KidSet.union s2 (KidSet.union s3 (KidSet.remove kid s4))), - E_for (id,e1,e2,e3,ord,e4)); - pat_exp = (fun ((sp,pat),(s,e)) -> KidSet.diff s (tyvars_bound_in_pat pat), Pat_exp (pat,e))} - pexp) - in - let exposed_sizes_funcl fnsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) = - let sizes = size_vars pexp in - let pat,guard,exp,pannot = destruct_pexp pexp in - let visible_tyvars = - KidSet.union - (Pretty_print_lem.lem_tyvars_of_typ (pat_typ_of pat)) - (Pretty_print_lem.lem_tyvars_of_typ (typ_of exp)) - in - let expose_tyvars = KidSet.diff sizes visible_tyvars in - KidSet.union fnsizes expose_tyvars - in - let sizes_funcl expose_tyvars fsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) = + let open Util in + + let sizes_funcl fsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) = let pat,guard,exp,pannot = destruct_pexp pexp in let parameters = match pat with | P_aux (P_tup ps,_) -> ps | _ -> [pat] in - let to_change = Util.map_filter - (fun kid -> - let check (P_aux (_,(_,Some (env,typ,_)))) = - match Env.expand_synonyms env typ with - Typ_aux (Typ_app(Id_aux (Id "range",_), - [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_); - Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid'',_)),_)]),_) -> - if Kid.compare kid kid' = 0 && Kid.compare kid kid'' = 0 then Some kid else None - | Typ_aux (Typ_app(Id_aux (Id "atom", _), - [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_)]), _) -> - if Kid.compare kid kid' = 0 then Some kid else None - | _ -> None - in match findi check parameters with - | None -> (Reporting_basic.print_error (Reporting_basic.Err_general (l, - ("Unable to find an argument for " ^ string_of_kid kid))); - None) - | Some i -> Some i) - (KidSet.elements expose_tyvars) + let add_parameter (i,nmap) (P_aux (_,(_,Some (env,typ,_)))) = + let nmap = + match Env.base_typ_of env typ with + Typ_aux (Typ_app(Id_aux (Id "range",_), + [Typ_arg_aux (Typ_arg_nexp nexp,_); + Typ_arg_aux (Typ_arg_nexp nexp',_)]),_) + when Nexp.compare nexp nexp' = 0 && not (NexpMap.mem nexp nmap) -> + NexpMap.add nexp i nmap + | Typ_aux (Typ_app(Id_aux (Id "atom", _), + [Typ_arg_aux (Typ_arg_nexp nexp,_)]), _) + when not (NexpMap.mem nexp nmap) -> + NexpMap.add nexp i nmap + | _ -> nmap + in (i+1,nmap) + in + let (_,nexp_map) = List.fold_left add_parameter (0,NexpMap.empty) parameters in + let nexp_list = NexpMap.bindings nexp_map in + let parameters_for = function + | Some (env,typ,_) -> + begin match Env.base_typ_of env typ with + | Typ_aux (Typ_app (Id_aux (Id "vector",_), [Typ_arg_aux (Typ_arg_nexp size,_);_;_]),_) + when not (is_nexp_constant size) -> + begin + match NexpMap.find size nexp_map with + | i -> IntSet.singleton i + | exception Not_found -> + (* Look for equivalent nexps, but only in consistent type env *) + if prove env (NC_aux (NC_false,Unknown)) then IntSet.empty else + match List.find (fun (nexp,i) -> + prove env (NC_aux (NC_equal (nexp,size),Unknown))) nexp_list with + | _, i -> IntSet.singleton i + | exception Not_found -> IntSet.empty + end + | _ -> IntSet.empty + end + | None -> IntSet.empty in - let ik_compare (i,k) (i',k') = - match compare (i : int) i' with - | 0 -> Kid.compare k k' - | x -> x + let parameters_to_rewrite = + fst (fold_pexp + { (compute_exp_alg IntSet.empty IntSet.union) with + e_aux = (fun ((s,e),(l,annot)) -> IntSet.union s (parameters_for annot),E_aux (e,(l,annot))) + } pexp) in - let to_change = List.sort ik_compare to_change in + let new_nexps = NexpSet.of_list (List.map fst + (List.filter (fun (nexp,i) -> IntSet.mem i parameters_to_rewrite) nexp_list)) in match Bindings.find id fsizes with - | old -> if List.for_all2 (fun x y -> ik_compare x y = 0) old to_change then fsizes else - let str l = String.concat "," (List.map (fun (i,k) -> string_of_int i ^ "." ^ string_of_kid k) l) in - raise (Reporting_basic.err_general l - ("Different size type variables in different clauses of " ^ string_of_id id ^ - " old: " ^ str old ^ " new: " ^ str to_change)) - | exception Not_found -> Bindings.add id to_change fsizes + | old,old_nexps -> Bindings.add id (IntSet.union old parameters_to_rewrite, + NexpSet.union old_nexps new_nexps) fsizes + | exception Not_found -> Bindings.add id (parameters_to_rewrite, new_nexps) fsizes in let sizes_def fsizes = function | DEF_fundef (FD_aux (FD_function (_,_,_,funcls),_)) -> - let expose_tyvars = List.fold_left exposed_sizes_funcl KidSet.empty funcls in - List.fold_left (sizes_funcl expose_tyvars) fsizes funcls + List.fold_left sizes_funcl fsizes funcls | _ -> fsizes in let fn_sizes = List.fold_left sizes_def Bindings.empty defs in - let rewrite_e_app (id,args) = - match Bindings.find id fn_sizes with - | [] -> E_app (id,args) - | to_change -> - let args' = mapat replace_with_the_value to_change args in - E_app (id,args') - | exception Not_found -> E_app (id,args) - in let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),(l,annot))) = let pat,guard,body,(pl,_) = destruct_pexp pexp in - let pat,guard,body = + let pat,guard,body, nexps = (* Update pattern and add itself -> nat wrapper to body *) match Bindings.find id fn_sizes with - | [] -> pat,guard,body - | to_change -> + | to_change,nexps -> let pat, vars = match pat with P_aux (P_tup pats,(l,_)) -> @@ -2124,13 +2224,10 @@ let rewrite_size_parameters env (Defs defs) = P_aux (P_tup pats,(l,None)), vars | P_aux (_,(l,_)) -> begin - match to_change with - | [0,kid] -> - let pat, var = change_parameter_pat kid pat in + if IntSet.is_empty to_change then pat, [] + else + let pat, var = change_parameter_pat pat in pat, [var] - | _ -> - raise (Reporting_basic.err_unreachable l - "Expected multiple parameters at single parameter") end in (* TODO: only add bindings that are necessary (esp for guards) *) @@ -2139,10 +2236,24 @@ let rewrite_size_parameters env (Defs defs) = | None -> None | Some exp -> Some (List.fold_left add_var_rebind exp vars) in - pat,guard,body - | exception Not_found -> pat,guard,body + pat,guard,body,nexps + | exception Not_found -> pat,guard,body,NexpSet.empty in (* Update function applications *) + let funcl_typ = typ_of_annot (l,annot) in + let already_visible_nexps = + NexpSet.union + (Pretty_print_lem.lem_nexps_of_typ funcl_typ) + (Pretty_print_lem.typeclass_nexps funcl_typ) + in + let bound_nexps = NexpSet.elements (NexpSet.union nexps already_visible_nexps) in + let rewrite_e_app (id,args) = + match Bindings.find id fn_sizes with + | to_change,_ -> + let args' = mapat (replace_with_the_value bound_nexps) to_change args in + E_app (id,args') + | exception Not_found -> E_app (id,args) + in let body = fold_exp { id_exp_alg with e_app = rewrite_e_app } body in let guard = match guard with | None -> None @@ -2156,8 +2267,7 @@ let rewrite_size_parameters env (Defs defs) = | DEF_spec (VS_aux (VS_val_spec (typschm,id,extern,cast),(l,annot))) as spec -> begin match Bindings.find id fn_sizes with - | [] -> spec - | to_change -> + | to_change,_ when not (IntSet.is_empty to_change) -> let typschm = match typschm with | TypSchm_aux (TypSchm_ts (tq,typ),l) -> let typ = match typ with @@ -2169,6 +2279,7 @@ let rewrite_size_parameters env (Defs defs) = in TypSchm_aux (TypSchm_ts (tq,typ),l) in DEF_spec (VS_aux (VS_val_spec (typschm,id,extern,cast),(l,None))) + | _ -> spec | exception Not_found -> spec end | def -> def diff --git a/src/pattern_completeness.ml b/src/pattern_completeness.ml index 123592a3..ebb402e5 100644 --- a/src/pattern_completeness.ml +++ b/src/pattern_completeness.ml @@ -58,13 +58,6 @@ type ctx = variants : IdSet.t Bindings.t } -let hex_to_bin hex = - Util.string_to_list hex - |> List.map Sail_lib.hex_char - |> List.concat - |> List.map Sail_lib.char_of_bit - |> (fun bits -> String.init (List.length bits) (List.nth bits)) - type gpat = | GP_lit of lit | GP_wild diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 38862382..ac8ad48d 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -179,10 +179,11 @@ let doc_nexp_lem nexp = | Nexp_minus (n1, n2) -> mangle_nexp n1 ^ "_minus_" ^ mangle_nexp n2 | Nexp_exp n -> "exp_" ^ mangle_nexp n | Nexp_neg n -> "neg_" ^ mangle_nexp n + | _ -> + raise (Reporting_basic.err_unreachable l + ("cannot pretty-print nexp \"" ^ string_of_nexp full_nexp ^ "\"")) end in string ("'" ^ mangle_nexp full_nexp) - (* raise (Reporting_basic.err_unreachable l - ("cannot pretty-print non-atomic nexp \"" ^ string_of_nexp full_nexp ^ "\"")) *) (* Rewrite mangled names of type variables to the original names *) let rec orig_nexp (Nexp_aux (nexp, l)) = @@ -321,12 +322,30 @@ let contains_t_pp_var ctxt (Typ_aux (t,a) as typ) = NexpSet.diff (lem_nexps_of_typ typ) ctxt.bound_nexps |> NexpSet.exists (fun nexp -> not (is_nexp_constant nexp)) -let doc_tannot_lem ctxt eff typ = - if contains_t_pp_var ctxt typ then empty - else +let replace_typ_size ctxt env (Typ_aux (t,a)) = + match t with + | Typ_app (Id_aux (Id "vector",_) as id, [Typ_arg_aux (Typ_arg_nexp size,_);ord;typ']) -> + begin + let is_equal nexp = + prove env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown)) + in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with + | nexp -> Some (Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp nexp,Parse_ast.Unknown);ord;typ']),a)) + | exception Not_found -> None + end + | _ -> None + +let doc_tannot_lem ctxt env eff typ = + let of_typ typ = let ta = doc_typ_lem typ in if eff then string " : M " ^^ parens ta else string " : " ^^ ta + in + if contains_t_pp_var ctxt typ + then + match replace_typ_size ctxt env typ with + | None -> empty + | Some typ -> of_typ typ + else of_typ typ let doc_lit_lem (L_aux(lit,l)) = match lit with @@ -676,10 +695,11 @@ let doc_exp_lem, doc_let_lem = let argspp = align (separate_map (break 1) (expV true) args) in let epp = align (call ^//^ argspp) in let (taepp,aexp_needed) = - let t = Env.expand_synonyms (env_of full_exp) (typ_of full_exp) in + let env = env_of full_exp in + let t = Env.expand_synonyms env (typ_of full_exp) in let eff = effect_of full_exp in if typ_needs_printed t - then (align epp ^^ (doc_tannot_lem ctxt (effectful eff) t), true) + then (align epp ^^ (doc_tannot_lem ctxt env (effectful eff) t), true) else (epp, aexp_needed) in liftR (if aexp_needed then parens (align taepp) else taepp) end @@ -714,7 +734,7 @@ let doc_exp_lem, doc_let_lem = if has_effect eff BE_rreg then let epp = separate space [string "read_reg";doc_id_lem id] in if is_bitvector_typ base_typ - then liftR (parens (epp ^^ doc_tannot_lem ctxt true base_typ)) + then liftR (parens (epp ^^ doc_tannot_lem ctxt env true base_typ)) else liftR epp else if is_ctor env id then doc_id_lem_ctor id else doc_id_lem id @@ -768,7 +788,7 @@ let doc_exp_lem, doc_let_lem = let (epp,aexp_needed) = if is_bit_typ etyp && !opt_mwords then let bepp = string "of_bits" ^^ space ^^ parens (align epp) in - (bepp ^^ doc_tannot_lem ctxt false t, true) + (bepp ^^ doc_tannot_lem ctxt (env_of full_exp) false t, true) else (epp,aexp_needed) in if aexp_needed then parens (align epp) else epp | E_vector_update(v,e1,e2) -> @@ -912,7 +932,7 @@ let doc_typdef_lem (TD_aux(td, (l, annot))) = match td with mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown), [mk_typ_arg (Typ_arg_typ rectyp); mk_typ_arg (Typ_arg_typ ftyp)])) in - let rfannot = doc_tannot_lem empty_ctxt false reftyp in + let rfannot = doc_tannot_lem empty_ctxt env false reftyp in let get, set = string "rec_val" ^^ dot ^^ fname fid, anglebars (space ^^ string "rec_val with " ^^ @@ -1342,7 +1362,7 @@ let doc_regtype_fields (tname, (n1, n2, fields)) = mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown), [mk_typ_arg (Typ_arg_typ (mk_id_typ (mk_id tname))); mk_typ_arg (Typ_arg_typ ftyp)])) in - let rfannot = doc_tannot_lem empty_ctxt false reftyp in + let rfannot = doc_tannot_lem empty_ctxt Env.empty false reftyp in doc_op equals (concat [string "let "; parens (concat [string tname; underscore; doc_id_lem fid; rfannot])]) (concat [ diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index 1dac7a1c..7620ca50 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -265,7 +265,7 @@ let fixities = (mk_id "|", (InfixR, 2)); ] in - ref Bindings.empty (*(fixities' : (prec * int) Bindings.t)*) + ref (fixities' : (prec * int) Bindings.t) let rec doc_exp (E_aux (e_aux, _) as exp) = match e_aux with diff --git a/src/rewrites.ml b/src/rewrites.ml index 9cba6b39..cc3df801 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -3005,6 +3005,7 @@ let rewrite_defs_c = [ ("simple_assignments", rewrite_simple_assignments); ("remove_vector_concat", rewrite_defs_remove_vector_concat); ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats); + ("guarded_pats", rewrite_defs_guarded_pats); ("exp_lift_assign", rewrite_defs_exp_lift_assign); ("constraint", rewrite_constraint); ("trivial_sizeof", rewrite_trivial_sizeof); diff --git a/src/sail.ml b/src/sail.ml index 35a7279b..95e060b2 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -88,7 +88,7 @@ let options = Arg.align ([ Arg.Tuple [Arg.Set opt_print_ocaml; Arg.Set Initial_check.opt_undefined_gen; Arg.Set Ocaml_backend.opt_trace_ocaml], " output an OCaml translated version of the input with tracing instrumentation, implies -ocaml"); ( "-c", - Arg.Tuple [Arg.Set opt_print_c; (* Arg.Set Initial_check.opt_undefined_gen *)], + Arg.Tuple [Arg.Set opt_print_c; Arg.Set Initial_check.opt_undefined_gen], " output a C translated version of the input"); ( "-lem_ast", Arg.Set opt_print_lem_ast, diff --git a/src/specialize.ml b/src/specialize.ml index efa8783e..2ebc7307 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -80,6 +80,10 @@ let id_of_instantiation id instantiation = let str = Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) ^ "#" in prepend_id str id +let string_of_instantiation instantiation = + let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ Type_check.string_of_uvar uvar in + Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) + (* Returns a list of all the instantiations of a function id in an ast. *) let rec instantiations_of id ast = @@ -161,6 +165,7 @@ let specialize_id_valspec instantiations id ast = let typschm = mk_typschm typq typ in let spec_id = id_of_instantiation id instantiation in + if IdSet.mem spec_id !spec_ids then [] else begin spec_ids := IdSet.add spec_id !spec_ids; @@ -209,7 +214,7 @@ let specialize_id_overloads instantiations id (Defs defs) = valspecs are then re-specialized. This process is iterated until the whole spec is specialized. *) let remove_unused_valspecs ast = - let calls = ref (IdSet.singleton (mk_id "main")) in + let calls = ref (IdSet.of_list [mk_id "main"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"]) in let vs_ids = Initial_check.val_spec_ids ast in let inspect_exp = function |
