diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.mli | 1 | ||||
| -rw-r--r-- | src/rewriter_new_tc.ml | 2625 | ||||
| -rw-r--r-- | src/rewriter_new_tc.mli | 150 | ||||
| -rw-r--r-- | src/spec_analysis_new_tc.ml | 667 | ||||
| -rw-r--r-- | src/spec_analysis_new_tc.mli | 70 | ||||
| -rw-r--r-- | src/type_check_new.mli | 6 | ||||
| -rw-r--r-- | src/util.ml | 6 | ||||
| -rw-r--r-- | src/util.mli | 6 |
8 files changed, 3531 insertions, 0 deletions
diff --git a/src/ast_util.mli b/src/ast_util.mli index edce060d..9d6c5653 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -47,6 +47,7 @@ open Ast val map_exp_annot : ('a annot -> 'b annot) -> 'a exp -> 'b exp val map_pat_annot : ('a annot -> 'b annot) -> 'a pat -> 'b pat val map_lexp_annot : ('a annot -> 'b annot) -> 'a lexp -> 'b lexp +val map_letbind_annot : ('a annot -> 'b annot) -> 'a letbind -> 'b letbind (* Extract locations from identifiers *) val id_loc : id -> Parse_ast.l diff --git a/src/rewriter_new_tc.ml b/src/rewriter_new_tc.ml new file mode 100644 index 00000000..5fb50446 --- /dev/null +++ b/src/rewriter_new_tc.ml @@ -0,0 +1,2625 @@ +(**************************************************************************) +(* Sail *) +(* *) +(* Copyright (c) 2013-2017 *) +(* Kathyrn Gray *) +(* Shaked Flur *) +(* Stephen Kell *) +(* Gabriel Kerneis *) +(* Robert Norton-Wright *) +(* Christopher Pulte *) +(* Peter Sewell *) +(* *) +(* All rights reserved. *) +(* *) +(* This software was developed by the University of Cambridge Computer *) +(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *) +(* (REMS) project, funded by EPSRC grant EP/K008528/1. *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(**************************************************************************) + +open Big_int +open Ast +open Ast_util +open Type_check_new +open Spec_analysis_new_tc +(*type typ = Type_internal.t +type 'a exp = 'a Ast.exp +type 'a emap = 'a Envmap.t +type envs = Type_check.envs +type 'a namemap = (typ * 'a exp) emap*) + +type 'a rewriters = { + rewrite_exp : 'a rewriters -> 'a exp -> 'a exp; + rewrite_lexp : 'a rewriters -> 'a lexp -> 'a lexp; + rewrite_pat : 'a rewriters -> 'a pat -> 'a pat; + rewrite_let : 'a rewriters -> 'a letbind -> 'a letbind; + rewrite_fun : 'a rewriters -> 'a fundef -> 'a fundef; + rewrite_def : 'a rewriters -> 'a def -> 'a def; + rewrite_defs : 'a rewriters -> 'a defs -> 'a defs; + } + + +let (>>) f g = fun x -> g(f(x)) + +let get_env_annot = function + | (_,Some(env,_,_)) -> env + | (l,None) -> raise (Reporting_basic.err_typ l "no type information") + +let get_typ_annot = function + | (_,Some(_,typ,_)) -> typ + | (l,None) -> raise (Reporting_basic.err_typ l "no type information") + +let get_eff_annot = function + | (_,Some(_,_,eff)) -> eff + | (l,None) -> raise (Reporting_basic.err_typ l "no type information") + +let get_env_exp (E_aux (_,a)) = get_env_annot a +let get_typ_exp (E_aux (_,a)) = get_typ_annot a +let get_eff_exp (E_aux (_,a)) = get_eff_annot a +let get_eff_fpat (FP_aux (_,a)) = get_eff_annot a +let get_eff_lexp (LEXP_aux (_,a)) = get_eff_annot a +let get_eff_fexp (FE_aux (_,a)) = get_eff_annot a +let get_eff_fexps (FES_aux (FES_Fexps (fexps,_),_)) = + List.fold_left union_effects no_effect (List.map get_eff_fexp fexps) +let get_eff_opt_default (Def_val_aux (_,a)) = get_eff_annot a +let get_eff_pexp (Pat_aux (_,a)) = get_eff_annot a +let get_eff_lb (LB_aux (_,a)) = get_eff_annot a + +let get_loc_exp (E_aux (_,(l,_))) = l + +let rec is_vector_typ = function + | Typ_aux (Typ_app (Id_aux (Id "vector",_), [_;_;_;_]), _) -> true + | Typ_aux (Typ_app (Id_aux (Id "register",_), [Typ_arg_aux (Typ_arg_typ rtyp,_)]), _) -> + is_vector_typ rtyp + | _ -> false + +let get_typ_app_args = function + | Typ_aux (Typ_app (Id_aux (Id c,_), targs), l) -> + (c, List.map (fun (Typ_arg_aux (a,_)) -> a) targs, l) + | Typ_aux (_, l) -> raise (Reporting_basic.err_typ l "get_typ_app_args called on non-app type") + +let rec get_vector_typ_args typ = match get_typ_app_args typ with + | ("vector", [Typ_arg_nexp start; Typ_arg_nexp len; Typ_arg_order ord; Typ_arg_typ etyp], _) -> + (start, len, ord, etyp) + | ("register", [Typ_arg_typ rtyp], _) -> get_vector_typ_args rtyp + | (_, _, l) -> raise (Reporting_basic.err_typ l "get_vector_typ_args called on non-vector type") + +let order_is_inc = function + | Ord_aux (Ord_inc, _) -> true + | Ord_aux (Ord_dec, _) -> false + | Ord_aux (Ord_var _, l) -> + raise (Reporting_basic.err_unreachable l "order_is_inc called on vector with variable ordering") + +let is_bit_typ = function + | Typ_aux (Typ_id (Id_aux (Id "bit", _)), _) -> true + | _ -> false + +let is_bitvector_typ typ = + if is_vector_typ typ then + let (_,_,_,etyp) = get_vector_typ_args typ in + is_bit_typ etyp + else false + +let simple_annot l typ = (Parse_ast.Generated l, Some (Env.empty, typ, no_effect)) +let simple_num l n = E_aux ( + E_lit (L_aux (L_num n, Parse_ast.Generated l)), + simple_annot (Parse_ast.Generated l) + (atom_typ (Nexp_aux (Nexp_constant n, Parse_ast.Generated l)))) + +let fresh_name_counter = ref 0 + +let fresh_name () = + let current = !fresh_name_counter in + let () = fresh_name_counter := (current + 1) in + current +let reset_fresh_name_counter () = + fresh_name_counter := 0 + +let fresh_id pre l = + let current = fresh_name () in + Id_aux (Id (pre ^ string_of_int current), Parse_ast.Generated l) + +let fresh_id_exp pre ((l,annot)) = + let id = fresh_id pre l in + E_aux (E_id id, (Parse_ast.Generated l, annot)) + +let fresh_id_pat pre ((l,annot)) = + let id = fresh_id pre l in + P_aux (P_id id, (Parse_ast.Generated l, annot)) + +let union_eff_exps es = + List.fold_left union_effects no_effect (List.map get_eff_exp es) + +let fix_eff_exp (E_aux (e,((l,_) as annot))) = + let effsum = union_effects (get_eff_annot annot) (match e with + | E_block es -> union_eff_exps es + | E_nondet es -> union_eff_exps es + | E_id _ + | E_lit _ -> no_effect + | E_cast (_,e) -> get_eff_exp e + | E_app (_,es) + | E_tuple es -> union_eff_exps es + | E_app_infix (e1,_,e2) -> union_eff_exps [e1;e2] + | E_if (e1,e2,e3) -> union_eff_exps [e1;e2;e3] + | E_for (_,e1,e2,e3,_,e4) -> union_eff_exps [e1;e2;e3;e4] + | E_vector es -> union_eff_exps es + | E_vector_indexed (ies,opt_default) -> + let (_,es) = List.split ies in + union_effects (get_eff_opt_default opt_default) (union_eff_exps es) + | E_vector_access (e1,e2) -> union_eff_exps [e1;e2] + | E_vector_subrange (e1,e2,e3) -> union_eff_exps [e1;e2;e3] + | E_vector_update (e1,e2,e3) -> union_eff_exps [e1;e2;e3] + | E_vector_update_subrange (e1,e2,e3,e4) -> union_eff_exps [e1;e2;e3;e4] + | E_vector_append (e1,e2) -> union_eff_exps [e1;e2] + | E_list es -> union_eff_exps es + | E_cons (e1,e2) -> union_eff_exps [e1;e2] + | E_record fexps -> get_eff_fexps fexps + | E_record_update(e,fexps) -> + union_effects (get_eff_exp e) (get_eff_fexps fexps) + | E_field (e,_) -> get_eff_exp e + | E_case (e,pexps) -> + List.fold_left union_effects (get_eff_exp e) (List.map get_eff_pexp pexps) + | E_let (lb,e) -> union_effects (get_eff_lb lb) (get_eff_exp e) + | E_assign (lexp,e) -> union_effects (get_eff_lexp lexp) (get_eff_exp e) + | E_exit e -> get_eff_exp e + | E_return e -> get_eff_exp e + | E_sizeof _ | E_sizeof_internal _ -> no_effect + | E_assert (c,m) -> no_effect + | E_comment _ | E_comment_struc _ -> no_effect + | E_internal_cast (_,e) -> get_eff_exp e + | E_internal_exp _ -> no_effect + | E_internal_exp_user _ -> no_effect + | E_internal_let (lexp,e1,e2) -> + union_effects (get_eff_lexp lexp) + (union_effects (get_eff_exp e1) (get_eff_exp e2)) + | E_internal_plet (_,e1,e2) -> union_effects (get_eff_exp e1) (get_eff_exp e2) + | E_internal_return e1 -> get_eff_exp e1) + in + E_aux (e, (l, Some (get_env_annot annot, get_typ_annot annot, effsum))) + +let fix_effsum_lexp (LEXP_aux (lexp,((l,_) as annot))) = + let effsum = union_effects (get_eff_annot annot) (match lexp with + | LEXP_id _ -> no_effect + | LEXP_cast _ -> no_effect + | LEXP_memory (_,es) -> union_eff_exps es + | LEXP_vector (lexp,e) -> union_effects (get_eff_lexp lexp) (get_eff_exp e) + | LEXP_vector_range (lexp,e1,e2) -> + union_effects (get_eff_lexp lexp) + (union_effects (get_eff_exp e1) (get_eff_exp e2)) + | LEXP_field (lexp,_) -> get_eff_lexp lexp) in + LEXP_aux (lexp, (l, Some (get_env_annot annot, get_typ_annot annot, effsum))) + +let fix_effsum_fexp (FE_aux (fexp,((l,_) as annot))) = + let effsum = union_effects (get_eff_annot annot) (match fexp with + | FE_Fexp (_,e) -> get_eff_exp e) in + FE_aux (fexp, (l, Some (get_env_annot annot, get_typ_annot annot, effsum))) + +let fix_effsum_fexps fexps = fexps (* FES_aux have no effect information *) + +let fix_effsum_opt_default (Def_val_aux (opt_default,((l,_) as annot))) = + let effsum = union_effects (get_eff_annot annot) (match opt_default with + | Def_val_empty -> no_effect + | Def_val_dec e -> get_eff_exp e) in + Def_val_aux (opt_default, (l, Some (get_env_annot annot, get_typ_annot annot, effsum))) + +let fix_effsum_pexp (Pat_aux (pexp,((l,_) as annot))) = + let effsum = union_effects (get_eff_annot annot) (match pexp with + | Pat_exp (_,e) -> get_eff_exp e) in + Pat_aux (pexp, (l, Some (get_env_annot annot, get_typ_annot annot, effsum))) + +let fix_effsum_lb (LB_aux (lb,((l,_) as annot))) = + let effsum = union_effects (get_eff_annot annot) (match lb with + | LB_val_explicit (_,_,e) -> get_eff_exp e + | LB_val_implicit (_,e) -> get_eff_exp e) in + LB_aux (lb, (l, Some (get_env_annot annot, get_typ_annot annot, effsum))) + +let effectful_effs = function + | Effect_aux (Effect_set effs, _) -> + List.exists + (fun (BE_aux (be,_)) -> + match be with + | BE_nondet | BE_unspec | BE_undef | BE_lset -> false + | _ -> true + ) effs + | _ -> true + +let effectful eaux = effectful_effs (get_eff_exp eaux) + +let updates_vars_effs = function + | Effect_aux (Effect_set effs, _) -> + List.exists + (fun (BE_aux (be,_)) -> + match be with + | BE_lset -> true + | _ -> false + ) effs + | _ -> true + +let updates_vars eaux = updates_vars_effs (get_eff_exp eaux) + +let id_to_string (Id_aux(id,l)) = + match id with + | Id(s) -> s + | DeIid(s) -> s + + +(*let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option = match ls with + | [] -> None + | (v1,v2)::ls -> if (eq v1 v) then Some v2 else partial_assoc eq v ls + +let mk_atom_typ i = {t=Tapp("atom",[TA_nexp i])} + +let simple_num l n : tannot exp = + let typ = simple_annot (mk_atom_typ (mk_c (big_int_of_int n))) in + E_aux (E_lit (L_aux (L_num n,l)), (l,typ)) + +let rec rewrite_nexp_to_exp program_vars l nexp = + let rewrite n = rewrite_nexp_to_exp program_vars l n in + let typ = mk_atom_typ nexp in + let actual_rewrite_n nexp = + match nexp.nexp with + | Nconst i -> E_aux (E_lit (L_aux (L_num (int_of_big_int i),l)), (l,simple_annot typ)) + | Nadd (n1,n2) -> E_aux (E_app_infix (rewrite n1,(Id_aux (Id "+",l)),rewrite n2), + (l, (tag_annot typ (External (Some "add"))))) + | Nmult (n1,n2) -> E_aux (E_app_infix (rewrite n1,(Id_aux (Id "*",l)),rewrite n2), + (l, tag_annot typ (External (Some "multiply")))) + | Nsub (n1,n2) -> E_aux (E_app_infix (rewrite n1,(Id_aux (Id "-",l)),rewrite n2), + (l, tag_annot typ (External (Some "minus")))) + | N2n (n, _) -> E_aux (E_app_infix (E_aux (E_lit (L_aux (L_num 2,l)), (l, simple_annot (mk_atom_typ n_two))), + (Id_aux (Id "**",l)), + rewrite n), (l, tag_annot typ (External (Some "power")))) + | Npow(n,i) -> E_aux (E_app_infix + (rewrite n, (Id_aux (Id "**",l)), + E_aux (E_lit (L_aux (L_num i,l)), + (l, simple_annot (mk_atom_typ (mk_c_int i))))), + (l, tag_annot typ (External (Some "power")))) + | Nneg(n) -> E_aux (E_app_infix (E_aux (E_lit (L_aux (L_num 0,l)), (l, simple_annot (mk_atom_typ n_zero))), + (Id_aux (Id "-",l)), + rewrite n), + (l, tag_annot typ (External (Some "minus")))) + | Nvar v -> (*TODO these need to generate an error as it's a place where there's insufficient specification. + But, for now I need to permit this to make power.sail compile, and most errors are in trap + or vectors *) + (*let _ = Printf.eprintf "unbound variable here %s\n" v in*) + E_aux (E_id (Id_aux (Id v,l)),(l,simple_annot typ)) + | _ -> raise (Reporting_basic.err_unreachable l ("rewrite_nexp given n that can't be rewritten: " ^ (n_to_string nexp))) in + match program_vars with + | None -> actual_rewrite_n nexp + | Some program_vars -> + (match partial_assoc nexp_eq_check nexp program_vars with + | None -> actual_rewrite_n nexp + | Some(None,ev) -> + (*let _ = Printf.eprintf "var case of rewrite, %s\n" ev in*) + E_aux (E_id (Id_aux (Id ev,l)), (l, simple_annot typ)) + | Some(Some f,ev) -> + E_aux (E_app ((Id_aux (Id f,l)), [ (E_aux (E_id (Id_aux (Id ev,l)), (l,simple_annot typ)))]), + (l, tag_annot typ (External (Some f))))) + +let rec match_to_program_vars ns bounds = + match ns with + | [] -> [] + | n::ns -> match find_var_from_nexp n bounds with + | None -> match_to_program_vars ns bounds + | Some(augment,ev) -> + (*let _ = Printf.eprintf "adding n %s to program var %s\n" (n_to_string n) ev in*) + (n,(augment,ev))::(match_to_program_vars ns bounds)*) + +let explode s = + let rec exp i l = if i < 0 then l else exp (i - 1) (s.[i] :: l) in + exp (String.length s - 1) [] + + +let vector_string_to_bit_list l lit = + + let hexchar_to_binlist = function + | '0' -> ['0';'0';'0';'0'] + | '1' -> ['0';'0';'0';'1'] + | '2' -> ['0';'0';'1';'0'] + | '3' -> ['0';'0';'1';'1'] + | '4' -> ['0';'1';'0';'0'] + | '5' -> ['0';'1';'0';'1'] + | '6' -> ['0';'1';'1';'0'] + | '7' -> ['0';'1';'1';'1'] + | '8' -> ['1';'0';'0';'0'] + | '9' -> ['1';'0';'0';'1'] + | 'A' -> ['1';'0';'1';'0'] + | 'B' -> ['1';'0';'1';'1'] + | 'C' -> ['1';'1';'0';'0'] + | 'D' -> ['1';'1';'0';'1'] + | 'E' -> ['1';'1';'1';'0'] + | 'F' -> ['1';'1';'1';'1'] + | _ -> raise (Reporting_basic.err_unreachable l "hexchar_to_binlist given unrecognized character") in + + let s_bin = match lit with + | L_hex s_hex -> List.flatten (List.map hexchar_to_binlist (explode (String.uppercase s_hex))) + | L_bin s_bin -> explode s_bin + | _ -> raise (Reporting_basic.err_unreachable l "s_bin given non vector literal") in + + List.map (function '0' -> L_aux (L_zero, Parse_ast.Generated l) + | '1' -> L_aux (L_one,Parse_ast.Generated l) + | _ -> raise (Reporting_basic.err_unreachable (Parse_ast.Generated l) "binary had non-zero or one")) s_bin + +let rewrite_pat rewriters (P_aux (pat,(l,annot))) = + let rewrap p = P_aux (p,(l,annot)) in + let rewrite = rewriters.rewrite_pat rewriters in + match pat with + | P_lit (L_aux ((L_hex _ | L_bin _) as lit,_)) -> + let ps = List.map (fun p -> P_aux (P_lit p, simple_annot l bit_typ)) + (vector_string_to_bit_list l lit) in + rewrap (P_vector ps) + | P_lit _ | P_wild | P_id _ -> rewrap pat + | P_as(pat,id) -> rewrap (P_as( rewrite pat, id)) + | P_typ(typ,pat) -> rewrite pat + | P_app(id ,pats) -> rewrap (P_app(id, List.map rewrite pats)) + | P_record(fpats,_) -> + rewrap (P_record(List.map (fun (FP_aux(FP_Fpat(id,pat),pannot)) -> FP_aux(FP_Fpat(id, rewrite pat), pannot)) fpats, + false)) + | P_vector pats -> rewrap (P_vector(List.map rewrite pats)) + | P_vector_indexed ipats -> rewrap (P_vector_indexed(List.map (fun (i,pat) -> (i, rewrite pat)) ipats)) + | P_vector_concat pats -> rewrap (P_vector_concat (List.map rewrite pats)) + | P_tup pats -> rewrap (P_tup (List.map rewrite pats)) + | P_list pats -> rewrap (P_list (List.map rewrite pats)) + +let rewrite_exp rewriters (E_aux (exp,(l,annot))) = + let rewrap e = E_aux (e,(l,annot)) in + let rewrite = rewriters.rewrite_exp rewriters in + match exp with + | E_comment _ | E_comment_struc _ -> rewrap exp + | E_block exps -> rewrap (E_block (List.map rewrite exps)) + | E_nondet exps -> rewrap (E_nondet (List.map rewrite exps)) + | E_lit (L_aux ((L_hex _ | L_bin _) as lit,_)) -> + let es = List.map (fun p -> E_aux (E_lit p, simple_annot l bit_typ)) + (vector_string_to_bit_list l lit) in + rewrap (E_vector es) + | E_id _ | E_lit _ -> rewrap exp + | E_cast (typ, exp) -> rewrap (E_cast (typ, rewrite exp)) + | E_app (id,exps) -> rewrap (E_app (id,List.map rewrite exps)) + | E_app_infix(el,id,er) -> rewrap (E_app_infix(rewrite el,id,rewrite er)) + | E_tuple exps -> rewrap (E_tuple (List.map rewrite exps)) + | E_if (c,t,e) -> rewrap (E_if (rewrite c,rewrite t, rewrite e)) + | E_for (id, e1, e2, e3, o, body) -> + rewrap (E_for (id, rewrite e1, rewrite e2, rewrite e3, o, rewrite body)) + | E_vector exps -> rewrap (E_vector (List.map rewrite exps)) + | E_vector_indexed (exps,(Def_val_aux(default,dannot))) -> + let def = match default with + | Def_val_empty -> default + | Def_val_dec e -> Def_val_dec (rewrite e) in + rewrap (E_vector_indexed (List.map (fun (i,e) -> (i, rewrite e)) exps, Def_val_aux(def,dannot))) + | E_vector_access (vec,index) -> rewrap (E_vector_access (rewrite vec,rewrite index)) + | E_vector_subrange (vec,i1,i2) -> + rewrap (E_vector_subrange (rewrite vec,rewrite i1,rewrite i2)) + | E_vector_update (vec,index,new_v) -> + rewrap (E_vector_update (rewrite vec,rewrite index,rewrite new_v)) + | E_vector_update_subrange (vec,i1,i2,new_v) -> + rewrap (E_vector_update_subrange (rewrite vec,rewrite i1,rewrite i2,rewrite new_v)) + | E_vector_append (v1,v2) -> rewrap (E_vector_append (rewrite v1,rewrite v2)) + | E_list exps -> rewrap (E_list (List.map rewrite exps)) + | E_cons(h,t) -> rewrap (E_cons (rewrite h,rewrite t)) + | E_record (FES_aux (FES_Fexps(fexps, bool),fannot)) -> + rewrap (E_record + (FES_aux (FES_Fexps + (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) -> + FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps, bool), fannot))) + | E_record_update (re,(FES_aux (FES_Fexps(fexps, bool),fannot))) -> + rewrap (E_record_update ((rewrite re), + (FES_aux (FES_Fexps + (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) -> + FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps, bool), fannot)))) + | E_field(exp,id) -> rewrap (E_field(rewrite exp,id)) + | E_case (exp ,pexps) -> + rewrap (E_case (rewrite exp, + (List.map + (fun (Pat_aux (Pat_exp(p,e),pannot)) -> + Pat_aux (Pat_exp(rewriters.rewrite_pat rewriters p,rewrite e),pannot)) pexps))) + | E_let (letbind,body) -> rewrap (E_let(rewriters.rewrite_let rewriters letbind,rewrite body)) + | E_assign (lexp,exp) -> rewrap (E_assign(rewriters.rewrite_lexp rewriters lexp,rewrite exp)) + | E_sizeof n -> rewrap (E_sizeof n) + | E_exit e -> rewrap (E_exit (rewrite e)) + | E_return e -> rewrap (E_return (rewrite e)) + | E_assert(e1,e2) -> rewrap (E_assert(rewrite e1,rewrite e2)) + | E_internal_cast (casted_annot,exp) -> + check_exp (get_env_exp exp) (strip_exp exp) (get_typ_annot casted_annot) + (*let new_exp = rewrite exp in + (*let _ = Printf.eprintf "Removing an internal_cast with %s\n" (tannot_to_string casted_annot) in*) + (match casted_annot,exp with + | Base((_,t),_,_,_,_,_),E_aux(ec,(ecl,Base((_,exp_t),_,_,_,_,_))) -> + (*let _ = Printf.eprintf "Considering removing an internal cast where the two types are %s and %s\n" + (t_to_string t) (t_to_string exp_t) in*) + (match t.t,exp_t.t with + (*TODO should pass d_env into here so that I can look at the abbreviations if there are any here*) + | Tapp("vector",[TA_nexp n1;TA_nexp nw1;TA_ord o1;_]), + Tapp("vector",[TA_nexp n2;TA_nexp nw2;TA_ord o2;_]) + | Tapp("vector",[TA_nexp n1;TA_nexp nw1;TA_ord o1;_]), + Tapp("reg",[TA_typ {t=(Tapp("vector",[TA_nexp n2; TA_nexp nw2; TA_ord o2;_]))}]) -> + (match n1.nexp with + | Nconst i1 -> if nexp_eq n1 n2 then new_exp else rewrap (E_cast (t_to_typ t,new_exp)) + | _ -> (match o1.order with + | Odec -> + (*let _ = Printf.eprintf "Considering removing a cast or not: %s %s, %b\n" + (n_to_string nw1) (n_to_string n1) (nexp_one_more_than nw1 n1) in*) + rewrap (E_cast (Typ_aux (Typ_var (Kid_aux((Var "length"),Parse_ast.Generated l)), + Parse_ast.Generated l),new_exp)) + | _ -> new_exp)) + | _ -> new_exp + | Base((_,t),_,_,_,_,_),_ -> + (*let _ = Printf.eprintf "Considering removing an internal cast where the remaining type is %s\n%!" + (t_to_string t) in*) + (match t.t with + | Tapp("vector",[TA_nexp n1;TA_nexp nw1;TA_ord o1;_]) -> + (match o1.order with + | Odec -> + let _ = Printf.eprintf "Considering removing a cast or not: %s %s, %b\n" + (n_to_string nw1) (n_to_string n1) (nexp_one_more_than nw1 n1) in + rewrap (E_cast (Typ_aux (Typ_var (Kid_aux((Var "length"), Parse_ast.Generated l)), + Parse_ast.Generated l), new_exp)) + | _ -> new_exp) + | _ -> new_exp) + | _ -> (*let _ = Printf.eprintf "Not a base match?\n" in*) new_exp*) + (*| E_internal_exp (l,impl) -> + match impl with + | Base((_,t),_,_,_,_,bounds) -> + (*let _ = Printf.eprintf "Rewriting internal expression, with type %s, and bounds %s\n" + (t_to_string t) (bounds_to_string bounds) in*) + let bounds = match nmap with | None -> bounds | Some (nm,_) -> add_map_to_bounds nm bounds in + (*let _ = Printf.eprintf "Bounds after looking at nmap %s\n" (bounds_to_string bounds) in*) + (match t.t with + (*Old case; should possibly be removed*) + | Tapp("register",[TA_typ {t= Tapp("vector",[ _; TA_nexp r;_;_])}]) + | Tapp("vector", [_;TA_nexp r;_;_]) + | Tabbrev(_, {t=Tapp("vector",[_;TA_nexp r;_;_])}) -> + (*let _ = Printf.eprintf "vector case with %s, bounds are %s\n" + (n_to_string r) (bounds_to_string bounds) in*) + let nexps = expand_nexp r in + (match (match_to_program_vars nexps bounds) with + | [] -> rewrite_nexp_to_exp None l r + | map -> rewrite_nexp_to_exp (Some map) l r) + | Tapp("implicit", [TA_nexp i]) -> + (*let _ = Printf.eprintf "Implicit case with %s\n" (n_to_string i) in*) + let nexps = expand_nexp i in + (match (match_to_program_vars nexps bounds) with + | [] -> rewrite_nexp_to_exp None l i + | map -> rewrite_nexp_to_exp (Some map) l i) + | _ -> + raise (Reporting_basic.err_unreachable l + ("Internal_exp given unexpected types " ^ (t_to_string t)))) + | _ -> raise (Reporting_basic.err_unreachable l ("Internal_exp given none Base annot"))*) + (*| E_sizeof_internal (l,impl) -> + (match impl with + | Base((_,t),_,_,_,_,bounds) -> + let bounds = match nmap with | None -> bounds | Some (nm,_) -> add_map_to_bounds nm bounds in + (match t.t with + | Tapp("atom",[TA_nexp n]) -> + let nexps = expand_nexp n in + (*let _ = Printf.eprintf "Removing sizeof_internal with type %s\n" (t_to_string t) in*) + (match (match_to_program_vars nexps bounds) with + | [] -> rewrite_nexp_to_exp None l n + | map -> rewrite_nexp_to_exp (Some map) l n) + | _ -> raise (Reporting_basic.err_unreachable l ("Sizeof internal had non-atom type " ^ (t_to_string t)))) + | _ -> raise (Reporting_basic.err_unreachable l ("Sizeof internal had none base annot"))*) + (*| E_internal_exp_user ((l,user_spec),(_,impl)) -> + (match (user_spec,impl) with + | (Base((_,tu),_,_,_,_,_), Base((_,ti),_,_,_,_,bounds)) -> + (*let _ = Printf.eprintf "E_interal_user getting rewritten two types are %s and %s\n" + (t_to_string tu) (t_to_string ti) in*) + let bounds = match nmap with | None -> bounds | Some (nm,_) -> add_map_to_bounds nm bounds in + (match (tu.t,ti.t) with + | (Tapp("implicit", [TA_nexp u]),Tapp("implicit",[TA_nexp i])) -> + (*let _ = Printf.eprintf "Implicit case with %s\n" (n_to_string i) in*) + let nexps = expand_nexp i in + (match (match_to_program_vars nexps bounds) with + | [] -> rewrite_nexp_to_exp None l i + (*add u to program_vars env; for now it will work out properly by accident*) + | map -> rewrite_nexp_to_exp (Some map) l i) + | _ -> + raise (Reporting_basic.err_unreachable l + ("Internal_exp_user given unexpected types " ^ (t_to_string tu) ^ ", " ^ (t_to_string ti)))) + | _ -> raise (Reporting_basic.err_unreachable l ("Internal_exp_user given none Base annot")))*) + | E_internal_let _ -> raise (Reporting_basic.err_unreachable l "Internal let found before it should have been introduced") + | E_internal_return _ -> raise (Reporting_basic.err_unreachable l "Internal return found before it should have been introduced") + | E_internal_plet _ -> raise (Reporting_basic.err_unreachable l " Internal plet found before it should have been introduced") + | _ -> rewrap exp + +let rewrite_let rewriters (LB_aux(letbind,(l,annot))) = + (*let local_map = get_map_tannot annot in + let map = + match map,local_map with + | None,None -> None + | None,Some m -> Some(m, Envmap.empty) + | Some(m,s), None -> Some(m,s) + | Some(m,s), Some m' -> match merge_option_maps (Some m) local_map with + | None -> Some(m,s) (*Shouldn't happen*) + | Some new_m -> Some(new_m,s) in*) + match letbind with + | LB_val_explicit (typschm, pat,exp) -> + LB_aux(LB_val_explicit (typschm,rewriters.rewrite_pat rewriters pat, + rewriters.rewrite_exp rewriters exp),(l,annot)) + | LB_val_implicit ( pat, exp) -> + LB_aux(LB_val_implicit (rewriters.rewrite_pat rewriters pat, + rewriters.rewrite_exp rewriters exp),(l,annot)) + +let rewrite_lexp rewriters (LEXP_aux(lexp,(l,annot))) = + let rewrap le = LEXP_aux(le,(l,annot)) in + match lexp with + | LEXP_id _ | LEXP_cast _ -> rewrap lexp + | LEXP_tup tupls -> rewrap (LEXP_tup (List.map (rewriters.rewrite_lexp rewriters) tupls)) + | LEXP_memory (id,exps) -> rewrap (LEXP_memory(id,List.map (rewriters.rewrite_exp rewriters) exps)) + | LEXP_vector (lexp,exp) -> + rewrap (LEXP_vector (rewriters.rewrite_lexp rewriters lexp,rewriters.rewrite_exp rewriters exp)) + | LEXP_vector_range (lexp,exp1,exp2) -> + rewrap (LEXP_vector_range (rewriters.rewrite_lexp rewriters lexp, + rewriters.rewrite_exp rewriters exp1, + rewriters.rewrite_exp rewriters exp2)) + | LEXP_field (lexp,id) -> rewrap (LEXP_field (rewriters.rewrite_lexp rewriters lexp,id)) + +let rewrite_fun rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) = + let rewrite_funcl (FCL_aux (FCL_Funcl(id,pat,exp),(l,annot))) = + let _ = reset_fresh_name_counter () in + (*let _ = Printf.eprintf "Rewriting function %s, pattern %s\n" + (match id with (Id_aux (Id i,_)) -> i) (Pretty_print.pat_to_string pat) in*) + (*let map = get_map_tannot fdannot in + let map = + match map with + | None -> None + | Some m -> Some(m, Envmap.empty) in*) + (FCL_aux (FCL_Funcl (id,rewriters.rewrite_pat rewriters pat, + rewriters.rewrite_exp rewriters exp),(l,annot))) + in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot)) + +let rewrite_def rewriters d = match d with + | DEF_type _ | DEF_kind _ | DEF_spec _ | DEF_default _ | DEF_reg_dec _ | DEF_comm _ | DEF_overload _ -> d + | DEF_fundef fdef -> DEF_fundef (rewriters.rewrite_fun rewriters fdef) + | DEF_val letbind -> DEF_val (rewriters.rewrite_let rewriters letbind) + | DEF_scattered _ -> raise (Reporting_basic.err_unreachable Parse_ast.Unknown "DEF_scattered survived to rewritter") + +let rewrite_defs_base rewriters (Defs defs) = + let rec rewrite ds = match ds with + | [] -> [] + | d::ds -> (rewriters.rewrite_def rewriters d)::(rewrite ds) in + Defs (rewrite defs) + +let rewrite_defs (Defs defs) = rewrite_defs_base + {rewrite_exp = rewrite_exp; + rewrite_pat = rewrite_pat; + rewrite_let = rewrite_let; + rewrite_lexp = rewrite_lexp; + rewrite_fun = rewrite_fun; + rewrite_def = rewrite_def; + rewrite_defs = rewrite_defs_base} (Defs defs) + +module Envmap = Finite_map.Fmap_map(String) + +(* TODO: This seems to only consider a single assignment (or possibly two, in + separate branches of an if-expression). Hence, it seems the result is always + at most one variable. Is this intended? + It is only used below when pulling out local variables inside if-expressions + into the outer scope, which seems dubious. I comment it out for now. *) +(*let rec introduced_variables (E_aux (exp,(l,annot))) = + match exp with + | E_cast (typ, exp) -> introduced_variables exp + | E_if (c,t,e) -> Envmap.intersect (introduced_variables t) (introduced_variables e) + | E_assign (lexp,exp) -> introduced_vars_le lexp exp + | _ -> Envmap.empty + +and introduced_vars_le (LEXP_aux(lexp,annot)) exp = + match lexp with + | LEXP_id (Id_aux (Id id,_)) | LEXP_cast(_,(Id_aux (Id id,_))) -> + (match annot with + | Base((_,t),Emp_intro,_,_,_,_) -> + Envmap.insert Envmap.empty (id,(t,exp)) + | _ -> Envmap.empty) + | _ -> Envmap.empty*) + +type ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg = + { p_lit : lit -> 'pat_aux + ; p_wild : 'pat_aux + ; p_as : 'pat * id -> 'pat_aux + ; p_typ : Ast.typ * 'pat -> 'pat_aux + ; p_id : id -> 'pat_aux + ; p_app : id * 'pat list -> 'pat_aux + ; p_record : 'fpat list * bool -> 'pat_aux + ; p_vector : 'pat list -> 'pat_aux + ; p_vector_indexed : (int * 'pat) list -> 'pat_aux + ; p_vector_concat : 'pat list -> 'pat_aux + ; p_tup : 'pat list -> 'pat_aux + ; p_list : 'pat list -> 'pat_aux + ; p_aux : 'pat_aux * 'a annot -> 'pat + ; fP_aux : 'fpat_aux * 'a annot -> 'fpat + ; fP_Fpat : id * 'pat -> 'fpat_aux + } + +let rec fold_pat_aux (alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg) : 'a pat_aux -> 'pat_aux = + function + | P_lit lit -> alg.p_lit lit + | P_wild -> alg.p_wild + | P_id id -> alg.p_id id + | P_as (p,id) -> alg.p_as (fold_pat alg p,id) + | P_typ (typ,p) -> alg.p_typ (typ,fold_pat alg p) + | P_app (id,ps) -> alg.p_app (id,List.map (fold_pat alg) ps) + | P_record (ps,b) -> alg.p_record (List.map (fold_fpat alg) ps, b) + | P_vector ps -> alg.p_vector (List.map (fold_pat alg) ps) + | P_vector_indexed ps -> alg.p_vector_indexed (List.map (fun (i,p) -> (i, fold_pat alg p)) ps) + | P_vector_concat ps -> alg.p_vector_concat (List.map (fold_pat alg) ps) + | P_tup ps -> alg.p_tup (List.map (fold_pat alg) ps) + | P_list ps -> alg.p_list (List.map (fold_pat alg) ps) + + +and fold_pat (alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg) : 'a pat -> 'pat = + function + | P_aux (pat,annot) -> alg.p_aux (fold_pat_aux alg pat,annot) +and fold_fpat_aux (alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg) : 'a fpat_aux -> 'fpat_aux = + function + | FP_Fpat (id,pat) -> alg.fP_Fpat (id,fold_pat alg pat) +and fold_fpat (alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg) : 'a fpat -> 'fpat = + function + | FP_aux (fpat,annot) -> alg.fP_aux (fold_fpat_aux alg fpat,annot) + +(* identity fold from term alg to term alg *) +let id_pat_alg : ('a,'a pat, 'a pat_aux, 'a fpat, 'a fpat_aux) pat_alg = + { p_lit = (fun lit -> P_lit lit) + ; p_wild = P_wild + ; p_as = (fun (pat,id) -> P_as (pat,id)) + ; p_typ = (fun (typ,pat) -> P_typ (typ,pat)) + ; p_id = (fun id -> P_id id) + ; p_app = (fun (id,ps) -> P_app (id,ps)) + ; p_record = (fun (ps,b) -> P_record (ps,b)) + ; p_vector = (fun ps -> P_vector ps) + ; p_vector_indexed = (fun ps -> P_vector_indexed ps) + ; p_vector_concat = (fun ps -> P_vector_concat ps) + ; p_tup = (fun ps -> P_tup ps) + ; p_list = (fun ps -> P_list ps) + ; p_aux = (fun (pat,annot) -> P_aux (pat,annot)) + ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) + ; fP_Fpat = (fun (id,pat) -> FP_Fpat (id,pat)) + } + +type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux, + 'opt_default_aux,'opt_default,'pexp,'pexp_aux,'letbind_aux,'letbind, + 'pat,'pat_aux,'fpat,'fpat_aux) exp_alg = + { e_block : 'exp list -> 'exp_aux + ; e_nondet : 'exp list -> 'exp_aux + ; e_id : id -> 'exp_aux + ; e_lit : lit -> 'exp_aux + ; e_cast : Ast.typ * 'exp -> 'exp_aux + ; e_app : id * 'exp list -> 'exp_aux + ; e_app_infix : 'exp * id * 'exp -> 'exp_aux + ; e_tuple : 'exp list -> 'exp_aux + ; e_if : 'exp * 'exp * 'exp -> 'exp_aux + ; e_for : id * 'exp * 'exp * 'exp * Ast.order * 'exp -> 'exp_aux + ; e_vector : 'exp list -> 'exp_aux + ; e_vector_indexed : (int * 'exp) list * 'opt_default -> 'exp_aux + ; e_vector_access : 'exp * 'exp -> 'exp_aux + ; e_vector_subrange : 'exp * 'exp * 'exp -> 'exp_aux + ; e_vector_update : 'exp * 'exp * 'exp -> 'exp_aux + ; e_vector_update_subrange : 'exp * 'exp * 'exp * 'exp -> 'exp_aux + ; e_vector_append : 'exp * 'exp -> 'exp_aux + ; e_list : 'exp list -> 'exp_aux + ; e_cons : 'exp * 'exp -> 'exp_aux + ; e_record : 'fexps -> 'exp_aux + ; e_record_update : 'exp * 'fexps -> 'exp_aux + ; e_field : 'exp * id -> 'exp_aux + ; e_case : 'exp * 'pexp list -> 'exp_aux + ; e_let : 'letbind * 'exp -> 'exp_aux + ; e_assign : 'lexp * 'exp -> 'exp_aux + ; e_exit : 'exp -> 'exp_aux + ; e_return : 'exp -> 'exp_aux + ; e_assert : 'exp * 'exp -> 'exp_aux + ; e_internal_cast : 'a annot * 'exp -> 'exp_aux + ; e_internal_exp : 'a annot -> 'exp_aux + ; e_internal_exp_user : 'a annot * 'a annot -> 'exp_aux + ; e_internal_let : 'lexp * 'exp * 'exp -> 'exp_aux + ; e_internal_plet : 'pat * 'exp * 'exp -> 'exp_aux + ; e_internal_return : 'exp -> 'exp_aux + ; e_aux : 'exp_aux * 'a annot -> 'exp + ; lEXP_id : id -> 'lexp_aux + ; lEXP_memory : id * 'exp list -> 'lexp_aux + ; lEXP_cast : Ast.typ * id -> 'lexp_aux + ; lEXP_tup : 'lexp list -> 'lexp_aux + ; lEXP_vector : 'lexp * 'exp -> 'lexp_aux + ; lEXP_vector_range : 'lexp * 'exp * 'exp -> 'lexp_aux + ; lEXP_field : 'lexp * id -> 'lexp_aux + ; lEXP_aux : 'lexp_aux * 'a annot -> 'lexp + ; fE_Fexp : id * 'exp -> 'fexp_aux + ; fE_aux : 'fexp_aux * 'a annot -> 'fexp + ; fES_Fexps : 'fexp list * bool -> 'fexps_aux + ; fES_aux : 'fexps_aux * 'a annot -> 'fexps + ; def_val_empty : 'opt_default_aux + ; def_val_dec : 'exp -> 'opt_default_aux + ; def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default + ; pat_exp : 'pat * 'exp -> 'pexp_aux + ; pat_aux : 'pexp_aux * 'a annot -> 'pexp + ; lB_val_explicit : typschm * 'pat * 'exp -> 'letbind_aux + ; lB_val_implicit : 'pat * 'exp -> 'letbind_aux + ; lB_aux : 'letbind_aux * 'a annot -> 'letbind + ; pat_alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg + } + +let rec fold_exp_aux alg = function + | E_block es -> alg.e_block (List.map (fold_exp alg) es) + | E_nondet es -> alg.e_nondet (List.map (fold_exp alg) es) + | E_id id -> alg.e_id id + | E_lit lit -> alg.e_lit lit + | E_cast (typ,e) -> alg.e_cast (typ, fold_exp alg e) + | E_app (id,es) -> alg.e_app (id, List.map (fold_exp alg) es) + | E_app_infix (e1,id,e2) -> alg.e_app_infix (fold_exp alg e1, id, fold_exp alg e2) + | E_tuple es -> alg.e_tuple (List.map (fold_exp alg) es) + | E_if (e1,e2,e3) -> alg.e_if (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3) + | E_for (id,e1,e2,e3,order,e4) -> + alg.e_for (id,fold_exp alg e1, fold_exp alg e2, fold_exp alg e3, order, fold_exp alg e4) + | E_vector es -> alg.e_vector (List.map (fold_exp alg) es) + | E_vector_indexed (es,opt) -> + alg.e_vector_indexed (List.map (fun (id,e) -> (id,fold_exp alg e)) es, fold_opt_default alg opt) + | E_vector_access (e1,e2) -> alg.e_vector_access (fold_exp alg e1, fold_exp alg e2) + | E_vector_subrange (e1,e2,e3) -> + alg.e_vector_subrange (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3) + | E_vector_update (e1,e2,e3) -> + alg.e_vector_update (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3) + | E_vector_update_subrange (e1,e2,e3,e4) -> + alg.e_vector_update_subrange (fold_exp alg e1,fold_exp alg e2, fold_exp alg e3, fold_exp alg e4) + | E_vector_append (e1,e2) -> alg.e_vector_append (fold_exp alg e1, fold_exp alg e2) + | E_list es -> alg.e_list (List.map (fold_exp alg) es) + | E_cons (e1,e2) -> alg.e_cons (fold_exp alg e1, fold_exp alg e2) + | E_record fexps -> alg.e_record (fold_fexps alg fexps) + | E_record_update (e,fexps) -> alg.e_record_update (fold_exp alg e, fold_fexps alg fexps) + | E_field (e,id) -> alg.e_field (fold_exp alg e, id) + | E_case (e,pexps) -> alg.e_case (fold_exp alg e, List.map (fold_pexp alg) pexps) + | E_let (letbind,e) -> alg.e_let (fold_letbind alg letbind, fold_exp alg e) + | E_assign (lexp,e) -> alg.e_assign (fold_lexp alg lexp, fold_exp alg e) + | E_exit e -> alg.e_exit (fold_exp alg e) + | E_return e -> alg.e_return (fold_exp alg e) + | E_assert(e1,e2) -> alg.e_assert (fold_exp alg e1, fold_exp alg e2) + | E_internal_cast (annot,e) -> alg.e_internal_cast (annot, fold_exp alg e) + | E_internal_exp annot -> alg.e_internal_exp annot + | E_internal_exp_user (annot1,annot2) -> alg.e_internal_exp_user (annot1,annot2) + | E_internal_let (lexp,e1,e2) -> + alg.e_internal_let (fold_lexp alg lexp, fold_exp alg e1, fold_exp alg e2) + | E_internal_plet (pat,e1,e2) -> + alg.e_internal_plet (fold_pat alg.pat_alg pat, fold_exp alg e1, fold_exp alg e2) + | E_internal_return e -> alg.e_internal_return (fold_exp alg e) +and fold_exp alg (E_aux (exp_aux,annot)) = alg.e_aux (fold_exp_aux alg exp_aux, annot) +and fold_lexp_aux alg = function + | LEXP_id id -> alg.lEXP_id id + | LEXP_memory (id,es) -> alg.lEXP_memory (id, List.map (fold_exp alg) es) + | LEXP_cast (typ,id) -> alg.lEXP_cast (typ,id) + | LEXP_vector (lexp,e) -> alg.lEXP_vector (fold_lexp alg lexp, fold_exp alg e) + | LEXP_vector_range (lexp,e1,e2) -> + alg.lEXP_vector_range (fold_lexp alg lexp, fold_exp alg e1, fold_exp alg e2) + | LEXP_field (lexp,id) -> alg.lEXP_field (fold_lexp alg lexp, id) +and fold_lexp alg (LEXP_aux (lexp_aux,annot)) = + alg.lEXP_aux (fold_lexp_aux alg lexp_aux, annot) +and fold_fexp_aux alg (FE_Fexp (id,e)) = alg.fE_Fexp (id, fold_exp alg e) +and fold_fexp alg (FE_aux (fexp_aux,annot)) = alg.fE_aux (fold_fexp_aux alg fexp_aux,annot) +and fold_fexps_aux alg (FES_Fexps (fexps,b)) = alg.fES_Fexps (List.map (fold_fexp alg) fexps, b) +and fold_fexps alg (FES_aux (fexps_aux,annot)) = alg.fES_aux (fold_fexps_aux alg fexps_aux, annot) +and fold_opt_default_aux alg = function + | Def_val_empty -> alg.def_val_empty + | Def_val_dec e -> alg.def_val_dec (fold_exp alg e) +and fold_opt_default alg (Def_val_aux (opt_default_aux,annot)) = + alg.def_val_aux (fold_opt_default_aux alg opt_default_aux, annot) +and fold_pexp_aux alg (Pat_exp (pat,e)) = alg.pat_exp (fold_pat alg.pat_alg pat, fold_exp alg e) +and fold_pexp alg (Pat_aux (pexp_aux,annot)) = alg.pat_aux (fold_pexp_aux alg pexp_aux, annot) +and fold_letbind_aux alg = function + | LB_val_explicit (t,pat,e) -> alg.lB_val_explicit (t,fold_pat alg.pat_alg pat, fold_exp alg e) + | LB_val_implicit (pat,e) -> alg.lB_val_implicit (fold_pat alg.pat_alg pat, fold_exp alg e) +and fold_letbind alg (LB_aux (letbind_aux,annot)) = alg.lB_aux (fold_letbind_aux alg letbind_aux, annot) + +let id_exp_alg = + { e_block = (fun es -> E_block es) + ; e_nondet = (fun es -> E_nondet es) + ; e_id = (fun id -> E_id id) + ; e_lit = (fun lit -> (E_lit lit)) + ; e_cast = (fun (typ,e) -> E_cast (typ,e)) + ; e_app = (fun (id,es) -> E_app (id,es)) + ; e_app_infix = (fun (e1,id,e2) -> E_app_infix (e1,id,e2)) + ; e_tuple = (fun es -> E_tuple es) + ; e_if = (fun (e1,e2,e3) -> E_if (e1,e2,e3)) + ; e_for = (fun (id,e1,e2,e3,order,e4) -> E_for (id,e1,e2,e3,order,e4)) + ; e_vector = (fun es -> E_vector es) + ; e_vector_indexed = (fun (es,opt2) -> E_vector_indexed (es,opt2)) + ; e_vector_access = (fun (e1,e2) -> E_vector_access (e1,e2)) + ; e_vector_subrange = (fun (e1,e2,e3) -> E_vector_subrange (e1,e2,e3)) + ; e_vector_update = (fun (e1,e2,e3) -> E_vector_update (e1,e2,e3)) + ; e_vector_update_subrange = (fun (e1,e2,e3,e4) -> E_vector_update_subrange (e1,e2,e3,e4)) + ; e_vector_append = (fun (e1,e2) -> E_vector_append (e1,e2)) + ; e_list = (fun es -> E_list es) + ; e_cons = (fun (e1,e2) -> E_cons (e1,e2)) + ; e_record = (fun fexps -> E_record fexps) + ; e_record_update = (fun (e1,fexp) -> E_record_update (e1,fexp)) + ; e_field = (fun (e1,id) -> (E_field (e1,id))) + ; e_case = (fun (e1,pexps) -> E_case (e1,pexps)) + ; e_let = (fun (lb,e2) -> E_let (lb,e2)) + ; e_assign = (fun (lexp,e2) -> E_assign (lexp,e2)) + ; e_exit = (fun e1 -> E_exit (e1)) + ; e_return = (fun e1 -> E_return e1) + ; e_assert = (fun (e1,e2) -> E_assert(e1,e2)) + ; e_internal_cast = (fun (a,e1) -> E_internal_cast (a,e1)) + ; e_internal_exp = (fun a -> E_internal_exp a) + ; e_internal_exp_user = (fun (a1,a2) -> E_internal_exp_user (a1,a2)) + ; e_internal_let = (fun (lexp, e2, e3) -> E_internal_let (lexp,e2,e3)) + ; e_internal_plet = (fun (pat, e1, e2) -> E_internal_plet (pat,e1,e2)) + ; e_internal_return = (fun e -> E_internal_return e) + ; e_aux = (fun (e,annot) -> E_aux (e,annot)) + ; lEXP_id = (fun id -> LEXP_id id) + ; lEXP_memory = (fun (id,es) -> LEXP_memory (id,es)) + ; lEXP_cast = (fun (typ,id) -> LEXP_cast (typ,id)) + ; lEXP_tup = (fun tups -> LEXP_tup tups) + ; lEXP_vector = (fun (lexp,e2) -> LEXP_vector (lexp,e2)) + ; lEXP_vector_range = (fun (lexp,e2,e3) -> LEXP_vector_range (lexp,e2,e3)) + ; lEXP_field = (fun (lexp,id) -> LEXP_field (lexp,id)) + ; lEXP_aux = (fun (lexp,annot) -> LEXP_aux (lexp,annot)) + ; fE_Fexp = (fun (id,e) -> FE_Fexp (id,e)) + ; fE_aux = (fun (fexp,annot) -> FE_aux (fexp,annot)) + ; fES_Fexps = (fun (fexps,b) -> FES_Fexps (fexps,b)) + ; fES_aux = (fun (fexp,annot) -> FES_aux (fexp,annot)) + ; def_val_empty = Def_val_empty + ; def_val_dec = (fun e -> Def_val_dec e) + ; def_val_aux = (fun (defval,aux) -> Def_val_aux (defval,aux)) + ; pat_exp = (fun (pat,e) -> (Pat_exp (pat,e))) + ; pat_aux = (fun (pexp,a) -> (Pat_aux (pexp,a))) + ; lB_val_explicit = (fun (typ,pat,e) -> LB_val_explicit (typ,pat,e)) + ; lB_val_implicit = (fun (pat,e) -> LB_val_implicit (pat,e)) + ; lB_aux = (fun (lb,annot) -> LB_aux (lb,annot)) + ; pat_alg = id_pat_alg + } + + +let remove_vector_concat_pat pat = + + (* ivc: bool that indicates whether the exp is in a vector_concat pattern *) + let remove_typed_patterns = + fold_pat { id_pat_alg with + p_aux = (function + | (P_typ (_,P_aux (p,_)),annot) + | (p,annot) -> + P_aux (p,annot) + ) + } in + + let pat = remove_typed_patterns pat in + + let fresh_id_v = fresh_id "v__" in + + (* expects that P_typ elements have been removed from AST, + that the length of all vectors involved is known, + that we don't have indexed vectors *) + + (* introduce names for all patterns of form P_vector_concat *) + let name_vector_concat_roots = + { p_lit = (fun lit -> P_lit lit) + ; p_typ = (fun (typ,p) -> P_typ (typ,p false)) (* cannot happen *) + ; p_wild = P_wild + ; p_as = (fun (pat,id) -> P_as (pat true,id)) + ; p_id = (fun id -> P_id id) + ; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps)) + ; p_record = (fun (fpats,b) -> P_record (fpats, b)) + ; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps)) + ; p_vector_indexed = (fun ps -> P_vector_indexed (List.map (fun (i,p) -> (i,p false)) ps)) + ; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps)) + ; p_tup = (fun ps -> P_tup (List.map (fun p -> p false) ps)) + ; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps)) + ; p_aux = + (fun (pat,((l,_) as annot)) contained_in_p_as -> + match pat with + | P_vector_concat pats -> + (if contained_in_p_as + then P_aux (pat,annot) + else P_aux (P_as (P_aux (pat,annot),fresh_id_v l),annot)) + | _ -> P_aux (pat,annot) + ) + ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) + ; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false)) + } in + + let pat = (fold_pat name_vector_concat_roots pat) false in + + (* introduce names for all unnamed child nodes of P_vector_concat *) + let name_vector_concat_elements = + let p_vector_concat pats = + let aux ((P_aux (p,((l,_) as a))) as pat) = match p with + | P_vector _ -> P_aux (P_as (pat,fresh_id_v l),a) + | P_id id -> P_aux (P_id id,a) + | P_as (p,id) -> P_aux (P_as (p,id),a) + | P_wild -> P_aux (P_wild,a) + | _ -> + raise + (Reporting_basic.err_unreachable + l "name_vector_concat_elements: Non-vector in vector-concat pattern") in + P_vector_concat (List.map aux pats) in + {id_pat_alg with p_vector_concat = p_vector_concat} in + + let pat = fold_pat name_vector_concat_elements pat in + + + + let rec tag_last = function + | x :: xs -> let is_last = xs = [] in (x,is_last) :: tag_last xs + | _ -> [] in + + (* remove names from vectors in vector_concat patterns and collect them as declarations for the + function body or expression *) + let unname_vector_concat_elements = (* : + ('a, + 'a pat * ((tannot exp -> tannot exp) list), + 'a pat_aux * ((tannot exp -> tannot exp) list), + 'a fpat * ((tannot exp -> tannot exp) list), + 'a fpat_aux * ((tannot exp -> tannot exp) list)) + pat_alg = *) + + (* build a let-expression of the form "let child = root[i..j] in body" *) + let letbind_vec (rootid,rannot) (child,cannot) (i,j) = + let (l,_) = cannot in + let (Id_aux (Id rootname,_)) = rootid in + let (Id_aux (Id childname,_)) = child in + + (*let vlength_info (Base ((_,{t = Tapp("vector",[_;TA_nexp nexp;_;_])}),_,_,_,_,_)) = + nexp in*) + let uannot = (Parse_ast.Generated l, ()) in + let unit_exp l eaux = E_aux (eaux, uannot) in + let unit_num l n = unit_exp l (E_lit (L_aux (L_num n, l))) in + + let root = unit_exp l (E_id rootid) in + let index_i = unit_num l i in + let index_j = (*match j with + | Some j ->*) unit_num l j in + (*)| None -> + let length_app_exp = unit_exp l (E_app (Id_aux (Id "length",l),[root])) in + (*let (_,length_root_nexp,_,_) = get_vector_typ_args (snd rannot) in + let length_app_exp : tannot exp = + let typ = mk_atom_typ length_root_nexp in + let annot = (l,tag_annot typ (External (Some "length"))) in + E_aux (E_app (Id_aux (Id "length",l),[root]),annot) in*) + let minus = Id_aux (Id "-",l) in + let one_exp = simple_num l 1 in + unit_exp l (E_app_infix(length_app_exp,minus,one_exp)) in*) + + let subv = unit_exp l (E_vector_subrange (root, index_i, index_j)) in + (*(E_app (Id_aux (Id "slice_raw",Unknown), [root;index_i;index_j])) in*) + + let letbind = LB_aux (LB_val_implicit (P_aux (P_id child,uannot),subv),uannot) in + (map_letbind_annot (fun (l,_) -> (l,None)) letbind, + (fun body -> unit_exp l (E_let (letbind,body))), + (rootname,childname)) in + + let p_aux = function + | ((P_as (P_aux (P_vector_concat pats,rannot'),rootid),decls),rannot) -> + let (start,last_idx) = (match get_vector_typ_args (get_typ_annot rannot') with + | (Nexp_aux (Nexp_constant start,_), Nexp_aux (Nexp_constant length,_), ord, _) -> + (start, if order_is_inc ord then start + length - 1 else start - length + 1) + | _ -> + raise (Reporting_basic.err_unreachable (fst rannot') + ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern"))) in + let aux (pos,pat_acc,decl_acc) (P_aux (p,cannot),is_last) = + let (_,length,ord,_) = get_vector_typ_args (get_typ_annot cannot) in + (*)| (_,length,ord,_) ->*) + let (pos',index_j) = match length with + | Nexp_aux (Nexp_constant i,_) -> + if order_is_inc ord then (pos+i, pos+i-1) + else (pos-i, pos-i+1) + | Nexp_aux (_,l) -> + if is_last then (pos,last_idx) + else + raise + (Reporting_basic.err_unreachable + l ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern")) in + (match p with + (* if we see a named vector pattern, remove the name and remember to + declare it later *) + | P_as (P_aux (p,cannot),cname) -> + let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in + (pos', pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)]) + (* if we see a P_id variable, remember to declare it later *) + | P_id cname -> + let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in + (pos', pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)]) + (* normal vector patterns are fine *) + | _ -> (pos', pat_acc @ [P_aux (p,cannot)],decl_acc) ) + (* non-vector patterns aren't *) + (*)| _ -> + raise + (Reporting_basic.err_unreachable + (fst cannot) + ("unname_vector_concat_elements: Non-vector in vector-concat pattern:" ^ + string_of_typ (get_typ_annot cannot)) + )*) in + let pats_tagged = tag_last pats in + let (_,pats',decls') = List.fold_left aux (start,[],[]) pats_tagged in + + (* abuse P_vector_concat as a P_vector_const pattern: it has the of + patterns as an argument but they're meant to be consed together *) + (P_aux (P_as (P_aux (P_vector_concat pats',rannot'),rootid),rannot), decls @ decls') + | ((p,decls),annot) -> (P_aux (p,annot),decls) in + + { p_lit = (fun lit -> (P_lit lit,[])) + ; p_wild = (P_wild,[]) + ; p_as = (fun ((pat,decls),id) -> (P_as (pat,id),decls)) + ; p_typ = (fun (typ,(pat,decls)) -> (P_typ (typ,pat),decls)) + ; p_id = (fun id -> (P_id id,[])) + ; p_app = (fun (id,ps) -> let (ps,decls) = List.split ps in + (P_app (id,ps),List.flatten decls)) + ; p_record = (fun (ps,b) -> let (ps,decls) = List.split ps in + (P_record (ps,b),List.flatten decls)) + ; p_vector = (fun ps -> let (ps,decls) = List.split ps in + (P_vector ps,List.flatten decls)) + ; p_vector_indexed = (fun ps -> let (is,ps) = List.split ps in + let (ps,decls) = List.split ps in + let ps = List.combine is ps in + (P_vector_indexed ps,List.flatten decls)) + ; p_vector_concat = (fun ps -> let (ps,decls) = List.split ps in + (P_vector_concat ps,List.flatten decls)) + ; p_tup = (fun ps -> let (ps,decls) = List.split ps in + (P_tup ps,List.flatten decls)) + ; p_list = (fun ps -> let (ps,decls) = List.split ps in + (P_list ps,List.flatten decls)) + ; p_aux = (fun ((pat,decls),annot) -> p_aux ((pat,decls),annot)) + ; fP_aux = (fun ((fpat,decls),annot) -> (FP_aux (fpat,annot),decls)) + ; fP_Fpat = (fun (id,(pat,decls)) -> (FP_Fpat (id,pat),decls)) + } in + + let (pat,decls) = fold_pat unname_vector_concat_elements pat in + + let decls = + let module S = Set.Make(String) in + + let roots_needed = + List.fold_right + (fun (_,(rootid,childid)) roots_needed -> + if S.mem childid roots_needed then + (* let _ = print_endline rootid in *) + S.add rootid roots_needed + else if String.length childid >= 3 && String.sub childid 0 2 = String.sub "v__" 0 2 then + roots_needed + else + S.add rootid roots_needed + ) decls S.empty in + List.filter + (fun (_,(_,childid)) -> + S.mem childid roots_needed || + String.length childid < 3 || + not (String.sub childid 0 2 = String.sub "v__" 0 2)) + decls in + + let (letbinds,decls) = + let (decls,_) = List.split decls in + List.split decls in + + let decls = strip_exp >> List.fold_left (fun f g x -> f (g x)) (fun b -> b) decls in + + + (* at this point shouldn't have P_as patterns in P_vector_concat patterns any more, + all P_as and P_id vectors should have their declarations in decls. + Now flatten all vector_concat patterns *) + + let flatten = + let p_vector_concat ps = + let aux p acc = match p with + | (P_aux (P_vector_concat pats,_)) -> pats @ acc + | pat -> pat :: acc in + P_vector_concat (List.fold_right aux ps []) in + {id_pat_alg with p_vector_concat = p_vector_concat} in + + let pat = fold_pat flatten pat in + + (* at this point pat should be a flat pattern: no vector_concat patterns + with vector_concats patterns as direct child-nodes anymore *) + + let range a b = + let rec aux a b = if a > b then [] else a :: aux (a+1) b in + if a > b then List.rev (aux b a) else aux a b in + + let remove_vector_concats = + let p_vector_concat ps = + let aux acc (P_aux (p,annot),is_last) = + let env = get_env_annot annot in + let eff = get_eff_annot annot in + let (l,_) = annot in + let wild _ = P_aux (P_wild,(Parse_ast.Generated l, Some (env, bit_typ, eff))) in + match p, get_vector_typ_args (get_typ_annot annot) with + | P_vector ps,_ -> acc @ ps + | _, (_,Nexp_aux (Nexp_constant length,_),_,_) -> + acc @ (List.map wild (range 0 (length - 1))) + | _, _ -> + if is_last then acc @ [wild 0] + else raise + (Reporting_basic.err_unreachable l + ("remove_vector_concats: Non-vector in vector-concat pattern " ^ + string_of_typ (get_typ_annot annot))) in + + let has_length (P_aux (p,annot)) = + match get_vector_typ_args (get_typ_annot annot) with + | (_,Nexp_aux (Nexp_constant length,_),_,_) -> true + | _ -> false in + + let ps_tagged = tag_last ps in + let ps' = List.fold_left aux [] ps_tagged in + let last_has_length ps = List.exists (fun (p,b) -> b && has_length p) ps_tagged in + + if last_has_length ps then + P_vector ps' + else + (* If the last vector pattern in the vector_concat pattern has unknown + length we misuse the P_vector_concat constructor's argument to place in + the following way: P_vector_concat [x;y; ... ;z] should be mapped to the + pattern-match x :: y :: .. z, i.e. if x : 'a, then z : vector 'a. *) + P_vector_concat ps' in + + {id_pat_alg with p_vector_concat = p_vector_concat} in + + let pat = fold_pat remove_vector_concats pat in + + (pat,letbinds,decls) + +let map_check_exp f exp = check_exp (get_env_exp exp) (f exp) (get_typ_exp exp) + +(* assumes there are no more E_internal expressions *) +let rewrite_exp_remove_vector_concat_pat rewriters (E_aux (exp,(l,annot)) as full_exp) = + let rewrap e = E_aux (e,(l,annot)) in + let recheck f exp = check_exp (get_env_exp exp) (f exp) (get_typ_exp exp) in + let rewrite_rec = rewriters.rewrite_exp rewriters in + let rewrite_base = rewrite_exp rewriters in + match exp with + | E_case (e,ps) -> + let aux (Pat_aux (Pat_exp (pat,body),annot')) = + let (pat,_,decls) = remove_vector_concat_pat pat in + Pat_aux (Pat_exp (pat,map_check_exp (rewrite_rec >> decls) body),annot') in + rewrap (E_case (rewrite_rec e, List.map aux ps)) + | E_let (LB_aux (LB_val_explicit (typ,pat,v),annot'),body) -> + let (pat,_,decls) = remove_vector_concat_pat pat in + let body' = check_exp (get_env_exp body) (decls (rewrite_rec body)) (get_typ_exp body) in + rewrap (E_let (LB_aux (LB_val_explicit (typ,pat,rewrite_rec v),annot'), + map_check_exp (rewrite_rec >> decls) body)) + | E_let (LB_aux (LB_val_implicit (pat,v),annot'),body) -> + let (pat,_,decls) = remove_vector_concat_pat pat in + rewrap (E_let (LB_aux (LB_val_implicit (pat,rewrite_rec v),annot'), + map_check_exp (rewrite_rec >> decls) body)) + | exp -> rewrite_base full_exp + +let rewrite_fun_remove_vector_concat_pat + rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) = + let rewrite_funcl (FCL_aux (FCL_Funcl(id,pat,exp),(l,annot))) = + let (pat',_,decls) = remove_vector_concat_pat pat in + let exp' = map_check_exp (rewriters.rewrite_exp rewriters >> decls) exp in + (FCL_aux (FCL_Funcl (id,pat',exp'),(l,annot))) + in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot)) + +let rewrite_defs_remove_vector_concat (Defs defs) = + let rewriters = + {rewrite_exp = rewrite_exp_remove_vector_concat_pat; + rewrite_pat = rewrite_pat; + rewrite_let = rewrite_let; + rewrite_lexp = rewrite_lexp; + rewrite_fun = rewrite_fun_remove_vector_concat_pat; + rewrite_def = rewrite_def; + rewrite_defs = rewrite_defs_base} in + let rewrite_def d = + let d = rewriters.rewrite_def rewriters d in + match d with + | DEF_val (LB_aux (LB_val_explicit (t,pat,exp),a)) -> + let (pat,letbinds,_) = remove_vector_concat_pat pat in + let defvals = List.map (fun lb -> DEF_val lb) letbinds in + [DEF_val (LB_aux (LB_val_explicit (t,pat,exp),a))] @ defvals + | DEF_val (LB_aux (LB_val_implicit (pat,exp),a)) -> + let (pat,letbinds,_) = remove_vector_concat_pat pat in + let defvals = List.map (fun lb -> DEF_val lb) letbinds in + [DEF_val (LB_aux (LB_val_implicit (pat,exp),a))] @ defvals + | d -> [d] in + Defs (List.flatten (List.map rewrite_def defs)) + +let rec contains_bitvector_pat (P_aux (pat,annot)) = match pat with +| P_lit _ | P_wild | P_id _ -> false +| P_as (pat,_) | P_typ (_,pat) -> contains_bitvector_pat pat +| P_vector _ | P_vector_concat _ | P_vector_indexed _ -> + is_bitvector_typ (get_typ_annot annot) +| P_app (_,pats) | P_tup pats | P_list pats -> + List.exists contains_bitvector_pat pats +| P_record (fpats,_) -> + List.exists (fun (FP_aux (FP_Fpat (_,pat),_)) -> contains_bitvector_pat pat) fpats + +let remove_bitvector_pat pat = + + (* first introduce names for bitvector patterns *) + let name_bitvector_roots = + { p_lit = (fun lit -> P_lit lit) + ; p_typ = (fun (typ,p) -> P_typ (typ,p false)) + ; p_wild = P_wild + ; p_as = (fun (pat,id) -> P_as (pat true,id)) + ; p_id = (fun id -> P_id id) + ; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps)) + ; p_record = (fun (fpats,b) -> P_record (fpats, b)) + ; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps)) + ; p_vector_indexed = (fun ps -> P_vector_indexed (List.map (fun (i,p) -> (i,p false)) ps)) + ; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps)) + ; p_tup = (fun ps -> P_tup (List.map (fun p -> p false) ps)) + ; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps)) + ; p_aux = + (fun (pat,annot) contained_in_p_as -> + let t = get_typ_annot annot in + let (l,_) = annot in + match pat, is_bitvector_typ t, contained_in_p_as with + | P_vector _, true, false + | P_vector_indexed _, true, false -> + P_aux (P_as (P_aux (pat,annot),fresh_id "b__" l), annot) + | _ -> P_aux (pat,annot) + ) + ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) + ; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false)) + } in + let pat = (fold_pat name_bitvector_roots pat) false in + + (* Then collect guard expressions testing whether the literal bits of a + bitvector pattern match those of a given bitvector, and collect let + bindings for the bits bound by P_id or P_as patterns *) + + (* Helper functions for generating guard expressions *) + let access_bit_exp (rootid,rannot) l idx = + let root : tannot exp = E_aux (E_id rootid,rannot) in + E_aux (E_vector_access (root,simple_num l idx), simple_annot l bit_typ) in + + let test_bit_exp rootid l t idx exp = + let rannot = simple_annot l t in + let elem = access_bit_exp (rootid,rannot) l idx in + let eqid = Id_aux (Id "==", Parse_ast.Generated l) in + let eqannot = simple_annot l bool_typ in + let eqexp : tannot exp = E_aux (E_app_infix(elem,eqid,exp), eqannot) in + Some (eqexp) in + + let test_subvec_exp rootid l typ i j lits = + let (start, length, ord, _) = get_vector_typ_args typ in + let length' = nconstant (List.length lits) in + let start' = + if order_is_inc ord then nconstant 0 + else nminus length' (nconstant 1) in + let typ' = vector_typ start' length' ord bit_typ in + let subvec_exp = + match start, length with + | Nexp_aux (Nexp_constant s, _), Nexp_aux (Nexp_constant l, _) + when s = i && l = List.length lits -> + E_id rootid + | _ -> + (*if vec_start t = i && vec_length t = List.length lits + then E_id rootid + else*) E_vector_subrange ( + E_aux (E_id rootid, simple_annot l typ), + simple_num l i, + simple_num l j) in + E_aux (E_app_infix( + E_aux (subvec_exp, simple_annot l typ'), + Id_aux (Id "==", Parse_ast.Generated l), + E_aux (E_vector lits, simple_annot l typ')), + simple_annot l bool_typ) in + + let letbind_bit_exp rootid l typ idx id = + let rannot = simple_annot l typ in + let elem = access_bit_exp (rootid,rannot) l idx in + let e = P_aux (P_id id, simple_annot l bit_typ) in + let letbind = LB_aux (LB_val_implicit (e,elem), simple_annot l bit_typ) in + let letexp = (fun body -> + let (E_aux (_,(_,bannot))) = body in + E_aux (E_let (letbind,body), (Parse_ast.Generated l, bannot))) in + (letexp, letbind) in + + (* Helper functions for composing guards *) + let bitwise_and exp1 exp2 = + let (E_aux (_,(l,_))) = exp1 in + let andid = Id_aux (Id "&", Parse_ast.Generated l) in + E_aux (E_app_infix(exp1,andid,exp2), simple_annot l bool_typ) in + + let compose_guards guards = + List.fold_right (Util.option_binop bitwise_and) guards None in + + let flatten_guards_decls gd = + let (guards,decls,letbinds) = Util.split3 gd in + (compose_guards guards, (List.fold_right (@@) decls), List.flatten letbinds) in + + (* Collect guards and let bindings *) + let guard_bitvector_pat = + let collect_guards_decls ps rootid t = + let (start,_,ord,_) = get_vector_typ_args t in + let rec collect current (guards,dls) idx ps = + let idx' = if order_is_inc ord then idx + 1 else idx - 1 in + (match ps with + | pat :: ps' -> + (match pat with + | P_aux (P_lit lit, (l,annot)) -> + let e = E_aux (E_lit lit, (Parse_ast.Generated l, annot)) in + let current' = (match current with + | Some (l,i,j,lits) -> Some (l,i,idx,lits @ [e]) + | None -> Some (l,idx,idx,[e])) in + collect current' (guards, dls) idx' ps' + | P_aux (P_as (pat',id), (l,annot)) -> + let dl = letbind_bit_exp rootid l t idx id in + collect current (guards, dls @ [dl]) idx (pat' :: ps') + | _ -> + let dls' = (match pat with + | P_aux (P_id id, (l,annot)) -> + dls @ [letbind_bit_exp rootid l t idx id] + | _ -> dls) in + let guards' = (match current with + | Some (l,i,j,lits) -> + guards @ [Some (test_subvec_exp rootid l t i j lits)] + | None -> guards) in + collect None (guards', dls') idx' ps') + | [] -> + let guards' = (match current with + | Some (l,i,j,lits) -> + guards @ [Some (test_subvec_exp rootid l t i j lits)] + | None -> guards) in + (guards',dls)) in + let (guards,dls) = match start with + | Nexp_aux (Nexp_constant s, _) -> + collect None ([],[]) s ps + | _ -> + let (P_aux (_, (l,_))) = pat in + raise (Reporting_basic.err_unreachable l + "guard_bitvector_pat called on pattern with non-constant start index") in + let (decls,letbinds) = List.split dls in + (compose_guards guards, List.fold_right (@@) decls, letbinds) in + + let collect_guards_decls_indexed ips rootid t = + let rec guard_decl (idx,pat) = (match pat with + | P_aux (P_lit lit, (l,annot)) -> + let exp = E_aux (E_lit lit, (l,annot)) in + (test_bit_exp rootid l t idx exp, (fun b -> b), []) + | P_aux (P_as (pat',id), (l,annot)) -> + let (guard,decls,letbinds) = guard_decl (idx,pat') in + let (letexp,letbind) = letbind_bit_exp rootid l t idx id in + (guard, decls >> letexp, letbind :: letbinds) + | P_aux (P_id id, (l,annot)) -> + let (letexp,letbind) = letbind_bit_exp rootid l t idx id in + (None, letexp, [letbind]) + | _ -> (None, (fun b -> b), [])) in + let (guards,decls,letbinds) = Util.split3 (List.map guard_decl ips) in + (compose_guards guards, List.fold_right (@@) decls, List.flatten letbinds) in + + { p_lit = (fun lit -> (P_lit lit, (None, (fun b -> b), []))) + ; p_wild = (P_wild, (None, (fun b -> b), [])) + ; p_as = (fun ((pat,gdls),id) -> (P_as (pat,id), gdls)) + ; p_typ = (fun (typ,(pat,gdls)) -> (P_typ (typ,pat), gdls)) + ; p_id = (fun id -> (P_id id, (None, (fun b -> b), []))) + ; p_app = (fun (id,ps) -> let (ps,gdls) = List.split ps in + (P_app (id,ps), flatten_guards_decls gdls)) + ; p_record = (fun (ps,b) -> let (ps,gdls) = List.split ps in + (P_record (ps,b), flatten_guards_decls gdls)) + ; p_vector = (fun ps -> let (ps,gdls) = List.split ps in + (P_vector ps, flatten_guards_decls gdls)) + ; p_vector_indexed = (fun p -> let (is,p) = List.split p in + let (ps,gdls) = List.split p in + let ps = List.combine is ps in + (P_vector_indexed ps, flatten_guards_decls gdls)) + ; p_vector_concat = (fun ps -> let (ps,gdls) = List.split ps in + (P_vector_concat ps, flatten_guards_decls gdls)) + ; p_tup = (fun ps -> let (ps,gdls) = List.split ps in + (P_tup ps, flatten_guards_decls gdls)) + ; p_list = (fun ps -> let (ps,gdls) = List.split ps in + (P_list ps, flatten_guards_decls gdls)) + ; p_aux = (fun ((pat,gdls),annot) -> + let t = get_typ_annot annot in + (match pat, is_bitvector_typ t with + | P_as (P_aux (P_vector ps, _), id), true -> + (P_aux (P_id id, annot), collect_guards_decls ps id t) + | P_as (P_aux (P_vector_indexed ips, _), id), true -> + (P_aux (P_id id, annot), collect_guards_decls_indexed ips id t) + | _, _ -> (P_aux (pat,annot), gdls))) + ; fP_aux = (fun ((fpat,gdls),annot) -> (FP_aux (fpat,annot), gdls)) + ; fP_Fpat = (fun (id,(pat,gdls)) -> (FP_Fpat (id,pat), gdls)) + } in + fold_pat guard_bitvector_pat pat + +let remove_wildcards pre (P_aux (_,(l,_)) as pat) = + fold_pat + {id_pat_alg with + p_aux = function + | (P_wild,(l,annot)) -> P_aux (P_id (fresh_id pre l),(l,annot)) + | (p,annot) -> P_aux (p,annot) } + pat + +(* Check if one pattern subsumes the other, and if so, calculate a + substitution of variables that are used in the same position. + TODO: Check somewhere that there are no variable clashes (the same variable + name used in different positions of the patterns) + *) +let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) = + let rewrap p = P_aux (p,annot1) in + let subsumes_list s pats1 pats2 = + if List.length pats1 = List.length pats2 + then + let subs = List.map2 s pats1 pats2 in + List.fold_right + (fun p acc -> match p, acc with + | Some subst, Some substs -> Some (subst @ substs) + | _ -> None) + subs (Some []) + else None in + match p1, p2 with + | P_lit (L_aux (lit1,_)), P_lit (L_aux (lit2,_)) -> + if lit1 = lit2 then Some [] else None + | P_as (pat1,_), _ -> subsumes_pat pat1 pat2 + | _, P_as (pat2,_) -> subsumes_pat pat1 pat2 + | P_typ (_,pat1), _ -> subsumes_pat pat1 pat2 + | _, P_typ (_,pat2) -> subsumes_pat pat1 pat2 + | P_id (Id_aux (id1,_) as aid1), P_id (Id_aux (id2,_) as aid2) -> + if id1 = id2 then Some [] + else if Env.lookup_id aid1 (get_env_annot annot1) = Unbound && + Env.lookup_id aid2 (get_env_annot annot2) = Unbound + then Some [(id2,id1)] else None + | P_id id1, _ -> + if Env.lookup_id id1 (get_env_annot annot1) = Unbound then Some [] else None + | P_wild, _ -> Some [] + | P_app (Id_aux (id1,l1),args1), P_app (Id_aux (id2,_),args2) -> + if id1 = id2 then subsumes_list subsumes_pat args1 args2 else None + | P_record (fps1,b1), P_record (fps2,b2) -> + if b1 = b2 then subsumes_list subsumes_fpat fps1 fps2 else None + | P_vector pats1, P_vector pats2 + | P_vector_concat pats1, P_vector_concat pats2 + | P_tup pats1, P_tup pats2 + | P_list pats1, P_list pats2 -> + subsumes_list subsumes_pat pats1 pats2 + | P_vector_indexed ips1, P_vector_indexed ips2 -> + let (is1,ps1) = List.split ips1 in + let (is2,ps2) = List.split ips2 in + if is1 = is2 then subsumes_list subsumes_pat ps1 ps2 else None + | _ -> None +and subsumes_fpat (FP_aux (FP_Fpat (id1,pat1),_)) (FP_aux (FP_Fpat (id2,pat2),_)) = + if id1 = id2 then subsumes_pat pat1 pat2 else None + +let equiv_pats pat1 pat2 = + match subsumes_pat pat1 pat2, subsumes_pat pat2 pat1 with + | Some _, Some _ -> true + | _, _ -> false + +let subst_id_pat pat (id1,id2) = + let p_id (Id_aux (id,l)) = (if id = id1 then P_id (Id_aux (id2,l)) else P_id (Id_aux (id,l))) in + fold_pat {id_pat_alg with p_id = p_id} pat + +let subst_id_exp exp (id1,id2) = + (* TODO Don't substitute bound occurrences inside let expressions etc *) + let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in + fold_exp {id_exp_alg with e_id = e_id} exp + +let rec pat_to_exp (P_aux (pat,(l,annot))) = + let rewrap e = E_aux (e,(l,annot)) in + match pat with + | P_lit lit -> rewrap (E_lit lit) + | P_wild -> raise (Reporting_basic.err_unreachable l + "pat_to_exp given wildcard pattern") + | P_as (pat,id) -> rewrap (E_id id) + | P_typ (_,pat) -> pat_to_exp pat + | P_id id -> rewrap (E_id id) + | P_app (id,pats) -> rewrap (E_app (id, List.map pat_to_exp pats)) + | P_record (fpats,b) -> + rewrap (E_record (FES_aux (FES_Fexps (List.map fpat_to_fexp fpats,b),(l,annot)))) + | P_vector pats -> rewrap (E_vector (List.map pat_to_exp pats)) + | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l + "pat_to_exp not implemented for P_vector_concat") + (* We assume that vector concatenation patterns have been transformed + away already *) + | P_tup pats -> rewrap (E_tuple (List.map pat_to_exp pats)) + | P_list pats -> rewrap (E_list (List.map pat_to_exp pats)) + | P_vector_indexed ipats -> raise (Reporting_basic.err_unreachable l + "pat_to_exp not implemented for P_vector_indexed") (* TODO *) +and fpat_to_fexp (FP_aux (FP_Fpat (id,pat),(l,annot))) = + FE_aux (FE_Fexp (id, pat_to_exp pat),(l,annot)) + +let case_exp e t cs = + let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in + let ps = List.map pexp cs in + (* let efr = union_effs (List.map get_eff_pexp ps) in *) + fix_eff_exp (E_aux (E_case (e,ps), (get_loc_exp e, Some (get_env_exp e, t, no_effect)))) + +let rewrite_guarded_clauses l cs = + let rec group clauses = + let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in + let rec group_aux current acc = (function + | ((pat,guard,body,annot) as c) :: cs -> + let (current_pat,_,_) = current in + (match subsumes_pat current_pat pat with + | Some substs -> + let pat' = List.fold_left subst_id_pat pat substs in + let guard' = (match guard with + | Some exp -> Some (List.fold_left subst_id_exp exp substs) + | None -> None) in + let body' = List.fold_left subst_id_exp body substs in + let c' = (pat',guard',body',annot) in + group_aux (add_clause current c') acc cs + | None -> + let pat = remove_wildcards "g__" pat in + group_aux (pat,[c],annot) (acc @ [current]) cs) + | [] -> acc @ [current]) in + let groups = match clauses with + | ((pat,guard,body,annot) as c) :: cs -> + group_aux (remove_wildcards "g__" pat, [c], annot) [] cs + | _ -> + raise (Reporting_basic.err_unreachable l + "group given empty list in rewrite_guarded_clauses") in + List.map (fun cs -> if_pexp cs) groups + and if_pexp (pat,cs,annot) = (match cs with + | c :: _ -> + (* fix_effsum_pexp (pexp *) + let body = if_exp pat cs in + let pexp = fix_effsum_pexp (Pat_aux (Pat_exp (pat,body),annot)) in + let (Pat_aux (Pat_exp (_,_),annot)) = pexp in + (pat, body, annot) + | [] -> + raise (Reporting_basic.err_unreachable l + "if_pexp given empty list in rewrite_guarded_clauses")) + and if_exp current_pat = (function + | (pat,guard,body,annot) :: ((pat',guard',body',annot') as c') :: cs -> + (match guard with + | Some exp -> + let else_exp = + if equiv_pats current_pat pat' + then if_exp current_pat (c' :: cs) + else case_exp (pat_to_exp current_pat) (get_typ_annot annot') (group (c' :: cs)) in + fix_eff_exp (E_aux (E_if (exp,body,else_exp), annot)) + | None -> body) + | [(pat,guard,body,annot)] -> body + | [] -> + raise (Reporting_basic.err_unreachable l + "if_exp given empty list in rewrite_guarded_clauses")) in + group cs + +let rewrite_exp_remove_bitvector_pat rewriters (E_aux (exp,(l,annot)) as full_exp) = + let rewrap e = E_aux (e,(l,annot)) in + let rewrite_rec = rewriters.rewrite_exp rewriters in + let rewrite_base = rewrite_exp rewriters in + match exp with + | E_case (e,ps) + when List.exists (fun (Pat_aux (Pat_exp (pat,_),_)) -> contains_bitvector_pat pat) ps -> + let clause (Pat_aux (Pat_exp (pat,body),annot')) = + let (pat',(guard,decls,_)) = remove_bitvector_pat pat in + let body' = decls (rewrite_rec body) in + (pat',guard,body',annot') in + let clauses = rewrite_guarded_clauses l (List.map clause ps) in + if (effectful e) then + let e = rewrite_rec e in + let (E_aux (_,(el,eannot))) = e in + let pat_e' = fresh_id_pat "p__" (el,eannot) in + let exp_e' = pat_to_exp pat_e' in + (* let fresh = fresh_id "p__" el in + let exp_e' = E_aux (E_id fresh, gen_annot l (get_type e) pure_e) in + let pat_e' = P_aux (P_id fresh, gen_annot l (get_type e) pure_e) in *) + let letbind_e = LB_aux (LB_val_implicit (pat_e',e), (el,eannot)) in + let exp' = case_exp exp_e' (get_typ_exp full_exp) clauses in + rewrap (E_let (letbind_e, exp')) + else case_exp e (get_typ_exp full_exp) clauses + | E_let (LB_aux (LB_val_explicit (typ,pat,v),annot'),body) -> + let (pat,(_,decls,_)) = remove_bitvector_pat pat in + rewrap (E_let (LB_aux (LB_val_explicit (typ,pat,rewrite_rec v),annot'), + decls (rewrite_rec body))) + | E_let (LB_aux (LB_val_implicit (pat,v),annot'),body) -> + let (pat,(_,decls,_)) = remove_bitvector_pat pat in + rewrap (E_let (LB_aux (LB_val_implicit (pat,rewrite_rec v),annot'), + decls (rewrite_rec body))) + | _ -> rewrite_base full_exp + +let rewrite_fun_remove_bitvector_pat + rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) = + let _ = reset_fresh_name_counter () in + (* TODO Can there be clauses with different id's in one FD_function? *) + let funcls = match funcls with + | (FCL_aux (FCL_Funcl(id,_,_),_) :: _) -> + let clause (FCL_aux (FCL_Funcl(_,pat,exp),annot)) = + let (pat,(guard,decls,_)) = remove_bitvector_pat pat in + let exp = decls (rewriters.rewrite_exp rewriters exp) in + (pat,guard,exp,annot) in + let cs = rewrite_guarded_clauses l (List.map clause funcls) in + List.map (fun (pat,exp,annot) -> FCL_aux (FCL_Funcl(id,pat,exp),annot)) cs + | _ -> funcls (* TODO is the empty list possible here? *) in + FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot)) + +let rewrite_defs_remove_bitvector_pats (Defs defs) = + let rewriters = + {rewrite_exp = rewrite_exp_remove_bitvector_pat; + rewrite_pat = rewrite_pat; + rewrite_let = rewrite_let; + rewrite_lexp = rewrite_lexp; + rewrite_fun = rewrite_fun_remove_bitvector_pat; + rewrite_def = rewrite_def; + rewrite_defs = rewrite_defs_base } in + let rewrite_def d = + let d = rewriters.rewrite_def rewriters d in + match d with + | DEF_val (LB_aux (LB_val_explicit (t,pat,exp),a)) -> + let (pat',(_,_,letbinds)) = remove_bitvector_pat pat in + let defvals = List.map (fun lb -> DEF_val lb) letbinds in + [DEF_val (LB_aux (LB_val_explicit (t,pat',exp),a))] @ defvals + | DEF_val (LB_aux (LB_val_implicit (pat,exp),a)) -> + let (pat',(_,_,letbinds)) = remove_bitvector_pat pat in + let defvals = List.map (fun lb -> DEF_val lb) letbinds in + [DEF_val (LB_aux (LB_val_implicit (pat',exp),a))] @ defvals + | d -> [d] in + Defs (List.flatten (List.map rewrite_def defs)) + + +(*Expects to be called after rewrite_defs; thus the following should not appear: + internal_exp of any form + lit vectors in patterns or expressions + *) +let rewrite_exp_lift_assign_intro rewriters ((E_aux (exp,((l,_) as annot))) as full_exp) = + let rewrap e = E_aux (e,annot) in + let rewrap_effects e eff = + E_aux (e, (l,Some (get_env_annot annot, get_typ_annot annot, eff))) in + let rewrite_rec = rewriters.rewrite_exp rewriters in + let rewrite_base = rewrite_exp rewriters in + match exp with + | E_block exps -> + let rec walker exps = match exps with + | [] -> [] + | (E_aux(E_assign((LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),_)) as le,e), + ((l, Some (env,typ,eff)) as annot)) as exp)::exps -> + (match Env.lookup_id id env with + | Unbound -> + let le' = rewriters.rewrite_lexp rewriters le in + let e' = rewrite_base e in + let exps' = walker exps in + let effects = union_eff_exps exps' in + let block = E_aux (E_block exps', (l, Some (env, unit_typ, effects))) in + [fix_eff_exp (E_aux (E_internal_let(le', e', block), annot))] + | _ -> (rewrite_rec exp)::(walker exps)) + (*| ((E_aux(E_if(c,t,e),(l,annot))) as exp)::exps -> + let vars_t = introduced_variables t in + let vars_e = introduced_variables e in + let new_vars = Envmap.intersect vars_t vars_e in + if Envmap.is_empty new_vars + then (rewrite_base exp)::walker exps + else + let new_nmap = match nmap with + | None -> Some(Nexpmap.empty,new_vars) + | Some(nm,s) -> Some(nm, Envmap.union new_vars s) in + let c' = rewrite_base c in + let t' = rewriters.rewrite_exp rewriters new_nmap t in + let e' = rewriters.rewrite_exp rewriters new_nmap e in + let exps' = walker exps in + fst ((Envmap.fold + (fun (res,effects) i (t,e) -> + let bitlit = E_aux (E_lit (L_aux(L_zero, Parse_ast.Generated l)), + (Parse_ast.Generated l, simple_annot bit_t)) in + let rangelit = E_aux (E_lit (L_aux (L_num 0, Parse_ast.Generated l)), + (Parse_ast.Generated l, simple_annot nat_t)) in + let set_exp = + match t.t with + | Tid "bit" | Tabbrev(_,{t=Tid "bit"}) -> bitlit + | Tapp("range", _) | Tapp("atom", _) -> rangelit + | Tapp("vector", [_;_;_;TA_typ ( {t=Tid "bit"} | {t=Tabbrev(_,{t=Tid "bit"})})]) + | Tapp(("reg"|"register"),[TA_typ ({t = Tapp("vector", + [_;_;_;TA_typ ( {t=Tid "bit"} + | {t=Tabbrev(_,{t=Tid "bit"})})])})]) + | Tabbrev(_,{t = Tapp("vector", + [_;_;_;TA_typ ( {t=Tid "bit"} + | {t=Tabbrev(_,{t=Tid "bit"})})])}) -> + E_aux (E_vector_indexed([], Def_val_aux(Def_val_dec bitlit, + (Parse_ast.Generated l,simple_annot bit_t))), + (Parse_ast.Generated l, simple_annot t)) + | _ -> e in + let unioneffs = union_effects effects (get_effsum_exp set_exp) in + ([E_aux (E_internal_let (LEXP_aux (LEXP_id (Id_aux (Id i, Parse_ast.Generated l)), + (Parse_ast.Generated l, (tag_annot t Emp_intro))), + set_exp, + E_aux (E_block res, (Parse_ast.Generated l, (simple_annot_efr unit_t effects)))), + (Parse_ast.Generated l, simple_annot_efr unit_t unioneffs))],unioneffs))) + (E_aux(E_if(c',t',e'),(Parse_ast.Generated l, annot))::exps',eff_union_exps (c'::t'::e'::exps')) new_vars)*) + | e::exps -> (rewrite_rec e)::(walker exps) + in + rewrap (E_block (walker exps)) + | E_assign(((LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),lannot)) as le),e) -> + let le' = rewriters.rewrite_lexp rewriters le in + let e' = rewrite_base e in + let effects = get_eff_exp e' in + (match Env.lookup_id id (get_env_annot annot) with + | Unbound -> + rewrap_effects + (E_internal_let(le', e', E_aux(E_block [], simple_annot l unit_typ))) + effects + | Local _ -> + let effects' = union_effects effects (get_eff_annot lannot) in + let annot' = Some (get_env_annot annot, unit_typ, effects') in + E_aux((E_assign(le', e')),(l, annot')) + | _ -> rewrite_base full_exp) + | _ -> rewrite_base full_exp + +let rewrite_lexp_lift_assign_intro rewriters ((LEXP_aux(lexp,annot)) as le) = + let rewrap le = LEXP_aux(le,annot) in + let rewrite_base = rewrite_lexp rewriters in + match lexp, annot with + | (LEXP_id id | LEXP_cast (_,id)), (l, Some (env, typ, eff)) -> + (match Env.lookup_id id env with + | Unbound | Local _ -> + LEXP_aux (lexp, (l, Some (env, typ, union_effects eff (mk_effect [BE_lset])))) + | _ -> rewrap lexp) + | _ -> rewrite_base le + + +let rewrite_defs_exp_lift_assign defs = rewrite_defs_base + {rewrite_exp = rewrite_exp_lift_assign_intro; + rewrite_pat = rewrite_pat; + rewrite_let = rewrite_let; + rewrite_lexp = rewrite_lexp_lift_assign_intro; + rewrite_fun = rewrite_fun; + rewrite_def = rewrite_def; + rewrite_defs = rewrite_defs_base} defs + +(*let rewrite_exp_separate_ints rewriters ((E_aux (exp,((l,_) as annot))) as full_exp) = + (*let tparms,t,tag,nexps,eff,cum_eff,bounds = match annot with + | Base((tparms,t),tag,nexps,eff,cum_eff,bounds) -> tparms,t,tag,nexps,eff,cum_eff,bounds + | _ -> [],unit_t,Emp_local,[],pure_e,pure_e,nob in*) + let rewrap e = E_aux (e,annot) in + (*let rewrap_effects e effsum = + E_aux (e,(l,Base ((tparms,t),tag,nexps,eff,effsum,bounds))) in*) + let rewrite_rec = rewriters.rewrite_exp rewriters in + let rewrite_base = rewrite_exp rewriters in + match exp with + | E_lit (L_aux (((L_num _) as lit),_)) -> + (match (is_within_machine64 t nexps) with + | Yes -> let _ = Printf.eprintf "Rewriter of num_const, within 64bit int yes\n" in rewrite_base full_exp + | Maybe -> let _ = Printf.eprintf "Rewriter of num_const, within 64bit int maybe\n" in rewrite_base full_exp + | No -> let _ = Printf.eprintf "Rewriter of num_const, within 64bit int no\n" in E_aux(E_app(Id_aux (Id "integer_of_int",l),[rewrite_base full_exp]), + (l, Base((tparms,t),External(None),nexps,eff,cum_eff,bounds)))) + | E_cast (typ, exp) -> rewrap (E_cast (typ, rewrite_rec exp)) + | E_app (id,exps) -> rewrap (E_app (id,List.map rewrite_rec exps)) + | E_app_infix(el,id,er) -> rewrap (E_app_infix(rewrite_rec el,id,rewrite_rec er)) + | E_for (id, e1, e2, e3, o, body) -> + rewrap (E_for (id, rewrite_rec e1, rewrite_rec e2, rewrite_rec e3, o, rewrite_rec body)) + | E_vector_access (vec,index) -> rewrap (E_vector_access (rewrite_rec vec,rewrite_rec index)) + | E_vector_subrange (vec,i1,i2) -> + rewrap (E_vector_subrange (rewrite_rec vec,rewrite_rec i1,rewrite_rec i2)) + | E_vector_update (vec,index,new_v) -> + rewrap (E_vector_update (rewrite_rec vec,rewrite_rec index,rewrite_rec new_v)) + | E_vector_update_subrange (vec,i1,i2,new_v) -> + rewrap (E_vector_update_subrange (rewrite_rec vec,rewrite_rec i1,rewrite_rec i2,rewrite_rec new_v)) + | E_case (exp ,pexps) -> + rewrap (E_case (rewrite_rec exp, + (List.map + (fun (Pat_aux (Pat_exp(p,e),pannot)) -> + Pat_aux (Pat_exp(rewriters.rewrite_pat rewriters nmap p,rewrite_rec e),pannot)) pexps))) + | E_let (letbind,body) -> rewrap (E_let(rewriters.rewrite_let rewriters nmap letbind,rewrite_rec body)) + | E_internal_let (lexp,exp,body) -> + rewrap (E_internal_let (rewriters.rewrite_lexp rewriters nmap lexp, rewrite_rec exp, rewrite_rec body)) + | _ -> rewrite_base full_exp + +let rewrite_defs_separate_numbs defs = rewrite_defs_base + {rewrite_exp = rewrite_exp_separate_ints; + rewrite_pat = rewrite_pat; + rewrite_let = rewrite_let; (*will likely need a new one?*) + rewrite_lexp = rewrite_lexp; (*will likely need a new one?*) + rewrite_fun = rewrite_fun; + rewrite_def = rewrite_def; + rewrite_defs = rewrite_defs_base} defs*) + +let rewrite_defs_ocaml defs = + let defs_sorted = top_sort_defs defs in + let defs_vec_concat_removed = rewrite_defs_remove_vector_concat defs_sorted in + let defs_lifted_assign = rewrite_defs_exp_lift_assign defs_vec_concat_removed in +(* let defs_separate_nums = rewrite_defs_separate_numbs defs_lifted_assign in *) + defs_lifted_assign + +let rewrite_defs_remove_blocks = + let letbind_wild v body = + let (E_aux (_,(l,tannot))) = v in + let annot_pat = (simple_annot l (get_typ_exp v)) in + let annot_lb = (Parse_ast.Generated l, tannot) in + let annot_let = (Parse_ast.Generated l, Some (get_env_exp body, get_typ_exp body, union_eff_exps [v;body])) in + E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_wild,annot_pat),v),annot_lb),body),annot_let) in + + let rec f l = function + | [] -> E_aux (E_lit (L_aux (L_unit,Parse_ast.Generated l)), (simple_annot l unit_typ)) + | [e] -> e (* check with Kathy if that annotation is fine *) + | e :: es -> letbind_wild e (f l es) in + + let e_aux = function + | (E_block es,(l,_)) -> f l es + | (e,annot) -> E_aux (e,annot) in + + let alg = { id_exp_alg with e_aux = e_aux } in + + rewrite_defs_base + {rewrite_exp = (fun _ -> fold_exp alg) + ; rewrite_pat = rewrite_pat + ; rewrite_let = rewrite_let + ; rewrite_lexp = rewrite_lexp + ; rewrite_fun = rewrite_fun + ; rewrite_def = rewrite_def + ; rewrite_defs = rewrite_defs_base + } + + + +let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = + (* body is a function : E_id variable -> actual body *) + match get_typ_exp v with + | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> + let (E_aux (_,(l,annot))) = v in + let e = E_aux (E_lit (L_aux (L_unit,Parse_ast.Generated l)),(simple_annot l unit_typ)) in + let body = body e in + let annot_pat = simple_annot l unit_typ in + let annot_lb = annot_pat in + let annot_let = (Parse_ast.Generated l, Some (get_env_exp body, get_typ_exp body, union_eff_exps [v;body])) in + let pat = P_aux (P_wild,annot_pat) in + + E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) + | _ -> + let (E_aux (_,((l,_) as annot))) = v in + let id = fresh_id "w__" l in + let e_id = E_aux (E_id id, annot) in + let body = body e_id in + + let annot_pat = simple_annot l (get_typ_exp v) in + let annot_lb = annot_pat in + let annot_let = (Parse_ast.Generated l, Some (get_env_exp body, get_typ_exp body, union_eff_exps [v;body])) in + let pat = P_aux (P_id id,annot_pat) in + + E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) + + +let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list -> 'a exp) : 'a exp = + match l with + | [] -> k [] + | exp :: exps -> f exp (fun exp -> mapCont f exps (fun exps -> k (exp :: exps))) + +let rewrite_defs_letbind_effects = + + let rec value ((E_aux (exp_aux,_)) as exp) = + not (effectful exp) && not (updates_vars exp) + and value_optdefault (Def_val_aux (o,_)) = match o with + | Def_val_empty -> true + | Def_val_dec e -> value e + and value_fexps (FES_aux (FES_Fexps (fexps,_),_)) = + List.fold_left (fun b (FE_aux (FE_Fexp (_,e),_)) -> b && value e) true fexps in + + + let rec n_exp_name (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = + n_exp exp (fun exp -> if value exp then k exp else letbind exp k) + + and n_exp_pure (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = + n_exp exp (fun exp -> if not (effectful exp || updates_vars exp) then k exp else letbind exp k) + + and n_exp_nameL (exps : 'a exp list) (k : 'a exp list -> 'a exp) : 'a exp = + mapCont n_exp_name exps k + + and n_fexp (fexp : 'a fexp) (k : 'a fexp -> 'a exp) : 'a exp = + let (FE_aux (FE_Fexp (id,exp),annot)) = fexp in + n_exp_name exp (fun exp -> + k (fix_effsum_fexp (FE_aux (FE_Fexp (id,exp),annot)))) + + and n_fexpL (fexps : 'a fexp list) (k : 'a fexp list -> 'a exp) : 'a exp = + mapCont n_fexp fexps k + + and n_pexp (newreturn : bool) (pexp : 'a pexp) (k : 'a pexp -> 'a exp) : 'a exp = + let (Pat_aux (Pat_exp (pat,exp),annot)) = pexp in + k (fix_effsum_pexp (Pat_aux (Pat_exp (pat,n_exp_term newreturn exp), annot))) + + and n_pexpL (newreturn : bool) (pexps : 'a pexp list) (k : 'a pexp list -> 'a exp) : 'a exp = + mapCont (n_pexp newreturn) pexps k + + and n_fexps (fexps : 'a fexps) (k : 'a fexps -> 'a exp) : 'a exp = + let (FES_aux (FES_Fexps (fexps_aux,b),annot)) = fexps in + n_fexpL fexps_aux (fun fexps_aux -> + k (fix_effsum_fexps (FES_aux (FES_Fexps (fexps_aux,b),annot)))) + + and n_opt_default (opt_default : 'a opt_default) (k : 'a opt_default -> 'a exp) : 'a exp = + let (Def_val_aux (opt_default,annot)) = opt_default in + match opt_default with + | Def_val_empty -> k (Def_val_aux (Def_val_empty,annot)) + | Def_val_dec exp -> + n_exp_name exp (fun exp -> + k (fix_effsum_opt_default (Def_val_aux (Def_val_dec exp,annot)))) + + and n_lb (lb : 'a letbind) (k : 'a letbind -> 'a exp) : 'a exp = + let (LB_aux (lb,annot)) = lb in + match lb with + | LB_val_explicit (typ,pat,exp1) -> + n_exp exp1 (fun exp1 -> + k (fix_effsum_lb (LB_aux (LB_val_explicit (typ,pat,exp1),annot)))) + | LB_val_implicit (pat,exp1) -> + n_exp exp1 (fun exp1 -> + k (fix_effsum_lb (LB_aux (LB_val_implicit (pat,exp1),annot)))) + + and n_lexp (lexp : 'a lexp) (k : 'a lexp -> 'a exp) : 'a exp = + let (LEXP_aux (lexp_aux,annot)) = lexp in + match lexp_aux with + | LEXP_id _ -> k lexp + | LEXP_memory (id,es) -> + n_exp_nameL es (fun es -> + k (fix_effsum_lexp (LEXP_aux (LEXP_memory (id,es),annot)))) + | LEXP_cast (typ,id) -> + k (fix_effsum_lexp (LEXP_aux (LEXP_cast (typ,id),annot))) + | LEXP_vector (lexp,e) -> + n_lexp lexp (fun lexp -> + n_exp_name e (fun e -> + k (fix_effsum_lexp (LEXP_aux (LEXP_vector (lexp,e),annot))))) + | LEXP_vector_range (lexp,e1,e2) -> + n_lexp lexp (fun lexp -> + n_exp_name e1 (fun e1 -> + n_exp_name e2 (fun e2 -> + k (fix_effsum_lexp (LEXP_aux (LEXP_vector_range (lexp,e1,e2),annot)))))) + | LEXP_field (lexp,id) -> + n_lexp lexp (fun lexp -> + k (fix_effsum_lexp (LEXP_aux (LEXP_field (lexp,id),annot)))) + + and n_exp_term (newreturn : bool) (exp : 'a exp) : 'a exp = + let (E_aux (_,(l,tannot))) = exp in + let exp = + if newreturn then + E_aux (E_internal_return exp,(Parse_ast.Generated l, tannot)) + else + exp in + (* n_exp_term forces an expression to be translated into a form + "let .. let .. let .. in EXP" where EXP has no effect and does not update + variables *) + n_exp_pure exp (fun exp -> exp) + + and n_exp (E_aux (exp_aux,annot) as exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = + + let rewrap e = fix_eff_exp (E_aux (e,annot)) in + + match exp_aux with + | E_block es -> failwith "E_block should have been removed till now" + | E_nondet _ -> failwith "E_nondet not supported" + | E_id id -> k exp + | E_lit _ -> k exp + | E_cast (typ,exp') -> + n_exp_name exp' (fun exp' -> + k (rewrap (E_cast (typ,exp')))) + | E_app (id,exps) -> + n_exp_nameL exps (fun exps -> + k (rewrap (E_app (id,exps)))) + | E_app_infix (exp1,id,exp2) -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + k (rewrap (E_app_infix (exp1,id,exp2))))) + | E_tuple exps -> + n_exp_nameL exps (fun exps -> + k (rewrap (E_tuple exps))) + | E_if (exp1,exp2,exp3) -> + n_exp_name exp1 (fun exp1 -> + let (E_aux (_,annot2)) = exp2 in + let (E_aux (_,annot3)) = exp3 in + let newreturn = effectful exp2 || effectful exp3 in + let exp2 = n_exp_term newreturn exp2 in + let exp3 = n_exp_term newreturn exp3 in + k (rewrap (E_if (exp1,exp2,exp3)))) + | E_for (id,start,stop,by,dir,body) -> + n_exp_name start (fun start -> + n_exp_name stop (fun stop -> + n_exp_name by (fun by -> + let body = n_exp_term (effectful body) body in + k (rewrap (E_for (id,start,stop,by,dir,body)))))) + | E_vector exps -> + n_exp_nameL exps (fun exps -> + k (rewrap (E_vector exps))) + | E_vector_indexed (exps,opt_default) -> + let (is,exps) = List.split exps in + n_exp_nameL exps (fun exps -> + n_opt_default opt_default (fun opt_default -> + let exps = List.combine is exps in + k (rewrap (E_vector_indexed (exps,opt_default))))) + | E_vector_access (exp1,exp2) -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + k (rewrap (E_vector_access (exp1,exp2))))) + | E_vector_subrange (exp1,exp2,exp3) -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + n_exp_name exp3 (fun exp3 -> + k (rewrap (E_vector_subrange (exp1,exp2,exp3)))))) + | E_vector_update (exp1,exp2,exp3) -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + n_exp_name exp3 (fun exp3 -> + k (rewrap (E_vector_update (exp1,exp2,exp3)))))) + | E_vector_update_subrange (exp1,exp2,exp3,exp4) -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + n_exp_name exp3 (fun exp3 -> + n_exp_name exp4 (fun exp4 -> + k (rewrap (E_vector_update_subrange (exp1,exp2,exp3,exp4))))))) + | E_vector_append (exp1,exp2) -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + k (rewrap (E_vector_append (exp1,exp2))))) + | E_list exps -> + n_exp_nameL exps (fun exps -> + k (rewrap (E_list exps))) + | E_cons (exp1,exp2) -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + k (rewrap (E_cons (exp1,exp2))))) + | E_record fexps -> + n_fexps fexps (fun fexps -> + k (rewrap (E_record fexps))) + | E_record_update (exp1,fexps) -> + n_exp_name exp1 (fun exp1 -> + n_fexps fexps (fun fexps -> + k (rewrap (E_record_update (exp1,fexps))))) + | E_field (exp1,id) -> + n_exp_name exp1 (fun exp1 -> + k (rewrap (E_field (exp1,id)))) + | E_case (exp1,pexps) -> + let newreturn = + List.fold_left + (fun b (Pat_aux (_,annot)) -> b || effectful_effs (get_eff_annot annot)) + false pexps in + n_exp_name exp1 (fun exp1 -> + n_pexpL newreturn pexps (fun pexps -> + k (rewrap (E_case (exp1,pexps))))) + | E_let (lb,body) -> + n_lb lb (fun lb -> + rewrap (E_let (lb,n_exp body k))) + | E_sizeof nexp -> + k (rewrap (E_sizeof nexp)) + | E_sizeof_internal annot -> + k (rewrap (E_sizeof_internal annot)) + | E_assign (lexp,exp1) -> + n_lexp lexp (fun lexp -> + n_exp_name exp1 (fun exp1 -> + k (rewrap (E_assign (lexp,exp1))))) + | E_exit exp' -> k (E_aux (E_exit (n_exp_term (effectful exp') exp'),annot)) + | E_assert (exp1,exp2) -> + n_exp exp1 (fun exp1 -> + n_exp exp2 (fun exp2 -> + k (rewrap (E_assert (exp1,exp2))))) + | E_internal_cast (annot',exp') -> + n_exp_name exp' (fun exp' -> + k (rewrap (E_internal_cast (annot',exp')))) + | E_internal_exp _ -> k exp + | E_internal_exp_user _ -> k exp + | E_internal_let (lexp,exp1,exp2) -> + n_lexp lexp (fun lexp -> + n_exp exp1 (fun exp1 -> + rewrap (E_internal_let (lexp,exp1,n_exp exp2 k)))) + | E_internal_return exp1 -> + n_exp_name exp1 (fun exp1 -> + k (rewrap (E_internal_return exp1))) + | E_comment str -> + k (rewrap (E_comment str)) + | E_comment_struc exp' -> + n_exp exp' (fun exp' -> + k (rewrap (E_comment_struc exp'))) + | E_return exp' -> + n_exp_name exp' (fun exp' -> + k (rewrap (E_return exp'))) + | E_internal_plet _ -> failwith "E_internal_plet should not be here yet" in + + let rewrite_fun _ (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),fdannot)) = + let newreturn = + List.fold_left + (fun b (FCL_aux (FCL_Funcl(id,pat,exp),annot)) -> + b || effectful_effs (get_eff_annot annot)) false funcls in + let rewrite_funcl (FCL_aux (FCL_Funcl(id,pat,exp),annot)) = + let _ = reset_fresh_name_counter () in + FCL_aux (FCL_Funcl (id,pat,n_exp_term newreturn exp),annot) + in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),fdannot) in + rewrite_defs_base + {rewrite_exp = rewrite_exp + ; rewrite_pat = rewrite_pat + ; rewrite_let = rewrite_let + ; rewrite_lexp = rewrite_lexp + ; rewrite_fun = rewrite_fun + ; rewrite_def = rewrite_def + ; rewrite_defs = rewrite_defs_base + } + +let rewrite_defs_effectful_let_expressions = + + let e_let (lb,body) = + match lb with + | LB_aux (LB_val_explicit (_,pat,exp'),annot') + | LB_aux (LB_val_implicit (pat,exp'),annot') -> + if effectful exp' + then E_internal_plet (pat,exp',body) + else E_let (lb,body) in + + let e_internal_let = fun (lexp,exp1,exp2) -> + if effectful exp1 then + match lexp with + | LEXP_aux (LEXP_id id,annot) + | LEXP_aux (LEXP_cast (_,id),annot) -> + E_internal_plet (P_aux (P_id id,annot),exp1,exp2) + | _ -> failwith "E_internal_plet with unexpected lexp" + else E_internal_let (lexp,exp1,exp2) in + + let alg = { id_exp_alg with e_let = e_let; e_internal_let = e_internal_let } in + rewrite_defs_base + {rewrite_exp = (fun _ -> fold_exp alg) + ; rewrite_pat = rewrite_pat + ; rewrite_let = rewrite_let + ; rewrite_lexp = rewrite_lexp + ; rewrite_fun = rewrite_fun + ; rewrite_def = rewrite_def + ; rewrite_defs = rewrite_defs_base + } + + +(* Now all expressions have no blocks anymore, any term is a sequence of let-expressions, + * internal let-expressions, or internal plet-expressions ended by a term that does not + * access memory or registers and does not update variables *) + +let dedup eq = + List.fold_left (fun acc e -> if List.exists (eq e) acc then acc else e :: acc) [] + +let eqidtyp (id1,_) (id2,_) = + let name1 = match id1 with Id_aux ((Id name | DeIid name),_) -> name in + let name2 = match id2 with Id_aux ((Id name | DeIid name),_) -> name in + name1 = name2 + +let find_updated_vars exp = + let ( @@ ) (a,b) (a',b') = (a @ a',b @ b') in + let lapp2 (l : (('a list * 'b list) list)) : ('a list * 'b list) = + List.fold_left + (fun ((intros_acc : 'a list),(updates_acc : 'b list)) (intros,updates) -> + (intros_acc @ intros, updates_acc @ updates)) ([],[]) l in + + let (intros,updates) = + fold_exp + { e_aux = (fun (e,_) -> e) + ; e_id = (fun _ -> ([],[])) + ; e_lit = (fun _ -> ([],[])) + ; e_cast = (fun (_,e) -> e) + ; e_block = (fun es -> lapp2 es) + ; e_nondet = (fun es -> lapp2 es) + ; e_app = (fun (_,es) -> lapp2 es) + ; e_app_infix = (fun (e1,_,e2) -> e1 @@ e2) + ; e_tuple = (fun es -> lapp2 es) + ; e_if = (fun (e1,e2,e3) -> e1 @@ e2 @@ e3) + ; e_for = (fun (_,e1,e2,e3,_,e4) -> e1 @@ e2 @@ e3 @@ e4) + ; e_vector = (fun es -> lapp2 es) + ; e_vector_indexed = (fun (es,opt) -> opt @@ lapp2 (List.map snd es)) + ; e_vector_access = (fun (e1,e2) -> e1 @@ e2) + ; e_vector_subrange = (fun (e1,e2,e3) -> e1 @@ e2 @@ e3) + ; e_vector_update = (fun (e1,e2,e3) -> e1 @@ e2 @@ e3) + ; e_vector_update_subrange = (fun (e1,e2,e3,e4) -> e1 @@ e2 @@ e3 @@ e4) + ; e_vector_append = (fun (e1,e2) -> e1 @@ e2) + ; e_list = (fun es -> lapp2 es) + ; e_cons = (fun (e1,e2) -> e1 @@ e2) + ; e_record = (fun fexps -> fexps) + ; e_record_update = (fun (e1,fexp) -> e1 @@ fexp) + ; e_field = (fun (e1,id) -> e1) + ; e_case = (fun (e1,pexps) -> e1 @@ lapp2 pexps) + ; e_let = (fun (lb,e2) -> lb @@ e2) + ; e_assign = (fun ((ids,acc),e2) -> ([],ids) @@ acc @@ e2) + ; e_exit = (fun e1 -> ([],[])) + ; e_return = (fun e1 -> e1) + ; e_assert = (fun (e1,e2) -> ([],[])) + ; e_internal_cast = (fun (_,e1) -> e1) + ; e_internal_exp = (fun _ -> ([],[])) + ; e_internal_exp_user = (fun _ -> ([],[])) + ; e_internal_let = + (fun (([id],acc),e2,e3) -> + let (xs,ys) = ([id],[]) @@ acc @@ e2 @@ e3 in + let ys = List.filter (fun id2 -> not (eqidtyp id id2)) ys in + (xs,ys)) + ; e_internal_plet = (fun (_, e1, e2) -> e1 @@ e2) + ; e_internal_return = (fun e -> e) + ; lEXP_id = (fun id -> (Some id,[],([],[]))) + ; lEXP_memory = (fun (_,es) -> (None,[],lapp2 es)) + ; lEXP_cast = (fun (_,id) -> (Some id,[],([],[]))) + ; lEXP_tup = (fun tups -> failwith "FORCHRISTOPHER:: this needs implementing, not sure what you want to do") + ; lEXP_vector = (fun ((ids,acc),e1) -> (None,ids,acc @@ e1)) + ; lEXP_vector_range = (fun ((ids,acc),e1,e2) -> (None,ids,acc @@ e1 @@ e2)) + ; lEXP_field = (fun ((ids,acc),_) -> (None,ids,acc)) + ; lEXP_aux = + (function + | ((Some id,ids,acc),(annot)) -> + (match Env.lookup_id id (get_env_annot annot) with + | Unbound | Local _ -> ((id,annot) :: ids,acc) + | _ -> (ids,acc)) + | ((_,ids,acc),_) -> (ids,acc) + ) + ; fE_Fexp = (fun (_,e) -> e) + ; fE_aux = (fun (fexp,_) -> fexp) + ; fES_Fexps = (fun (fexps,_) -> lapp2 fexps) + ; fES_aux = (fun (fexp,_) -> fexp) + ; def_val_empty = ([],[]) + ; def_val_dec = (fun e -> e) + ; def_val_aux = (fun (defval,_) -> defval) + ; pat_exp = (fun (_,e) -> e) + ; pat_aux = (fun (pexp,_) -> pexp) + ; lB_val_explicit = (fun (_,_,e) -> e) + ; lB_val_implicit = (fun (_,e) -> e) + ; lB_aux = (fun (lb,_) -> lb) + ; pat_alg = id_pat_alg + } exp in + dedup eqidtyp updates + +let swaptyp typ (l,tannot) = match tannot with + | Some (env, typ', eff) -> (l, Some (env, typ, eff)) + | _ -> raise (Reporting_basic.err_unreachable l "swaptyp called with empty type annotation") + +let mktup l es = + match es with + | [] -> E_aux (E_lit (L_aux (L_unit,Parse_ast.Generated l)),(simple_annot l unit_typ)) + | [e] -> e + | e :: _ -> + let effs = + List.fold_left (fun acc e -> union_effects acc (get_eff_exp e)) no_effect es in + let typ = mk_typ (Typ_tup (List.map get_typ_exp es)) in + E_aux (E_tuple es,(Parse_ast.Generated l, Some (get_env_exp e, typ, effs))) + +let mktup_pat l es = + match es with + | [] -> P_aux (P_wild,(simple_annot l unit_typ)) + | [E_aux (E_id id,_) as exp] -> + P_aux (P_id id,(simple_annot l (get_typ_exp exp))) + | _ -> + let typ = mk_typ (Typ_tup (List.map get_typ_exp es)) in + let pats = List.map (function + | (E_aux (E_id id,_) as exp) -> + P_aux (P_id id,(simple_annot l (get_typ_exp exp))) + | exp -> + P_aux (P_wild,(simple_annot l (get_typ_exp exp)))) es in + P_aux (P_tup pats,(simple_annot l typ)) + + +type 'a updated_term = + | Added_vars of 'a exp * 'a pat + | Same_vars of 'a exp + +let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = + + let rec add_vars overwrite ((E_aux (expaux,annot)) as exp) vars = + match expaux with + | E_let (lb,exp) -> + let exp = add_vars overwrite exp vars in + E_aux (E_let (lb,exp),swaptyp (get_typ_exp exp) annot) + | E_internal_let (lexp,exp1,exp2) -> + let exp2 = add_vars overwrite exp2 vars in + E_aux (E_internal_let (lexp,exp1,exp2), swaptyp (get_typ_exp exp2) annot) + | E_internal_plet (pat,exp1,exp2) -> + let exp2 = add_vars overwrite exp2 vars in + E_aux (E_internal_plet (pat,exp1,exp2), swaptyp (get_typ_exp exp2) annot) + | E_internal_return exp2 -> + let exp2 = add_vars overwrite exp2 vars in + E_aux (E_internal_return exp2,swaptyp (get_typ_exp exp2) annot) + | _ -> + (* after rewrite_defs_letbind_effects there cannot be terms that have + effects/update local variables in "tail-position": check n_exp_term + and where it is used. *) + if overwrite then + match get_typ_exp exp with + | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> vars + | _ -> raise (Reporting_basic.err_unreachable l + "add_vars: left-over unit expression in tail position after rewriting") + else + let typ' = Typ_aux (Typ_tup [get_typ_exp exp;get_typ_exp vars], Parse_ast.Generated l) in + E_aux (E_tuple [exp;vars],swaptyp typ' annot) in + + let rewrite (E_aux (expaux,((el,_) as annot))) (P_aux (_,(pl,pannot)) as pat) = + let overwrite = match get_typ_annot annot with + | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> true + | _ -> false in + match expaux with + | E_for(id,exp1,exp2,exp3,order,exp4) -> + (* Translate for loops into calls to one of the foreach combinators. + The loop body becomes a function of the loop variable and any + mutable local variables that are updated inside the loop. + Since the foreach* combinators are higher-order functions, + they cannot be represented faithfully in the AST. The following + code abuses the parameters of an E_app node, embedding the loop body + function as an expression followed by the list of variables it + expects. In (Lem) pretty-printing, this turned into an anonymous + function and passed to foreach*. *) + let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars exp4) in + let vartuple = mktup el vars in + let exp4 = rewrite_var_updates (add_vars overwrite exp4 vartuple) in + let (E_aux (_,(_,annot4))) = exp4 in + let fname = match effectful exp4,order with + | false, Ord_aux (Ord_inc,_) -> "foreach_inc" + | false, Ord_aux (Ord_dec,_) -> "foreach_dec" + | true, Ord_aux (Ord_inc,_) -> "foreachM_inc" + | true, Ord_aux (Ord_dec,_) -> "foreachM_dec" in + let funcl = Id_aux (Id fname,Parse_ast.Generated el) in + let loopvar = + (* Don't bother with creating a range type annotation, since the + Lem pretty-printing does not use it. *) + (* let (bf,tf) = match get_typ_exp exp1 with + | {t = Tapp ("atom",[TA_nexp f])} -> (TA_nexp f,TA_nexp f) + | {t = Tapp ("reg", [TA_typ {t = Tapp ("atom",[TA_nexp f])}])} -> (TA_nexp f,TA_nexp f) + | {t = Tapp ("range",[TA_nexp bf;TA_nexp tf])} -> (TA_nexp bf,TA_nexp tf) + | {t = Tapp ("reg", [TA_typ {t = Tapp ("range",[TA_nexp bf;TA_nexp tf])}])} -> (TA_nexp bf,TA_nexp tf) + | {t = Tapp (name,_)} -> failwith (name ^ " shouldn't be here") in + let (bt,tt) = match get_typ_exp exp2 with + | {t = Tapp ("atom",[TA_nexp t])} -> (TA_nexp t,TA_nexp t) + | {t = Tapp ("atom",[TA_typ {t = Tapp ("atom", [TA_nexp t])}])} -> (TA_nexp t,TA_nexp t) + | {t = Tapp ("range",[TA_nexp bt;TA_nexp tt])} -> (TA_nexp bt,TA_nexp tt) + | {t = Tapp ("atom",[TA_typ {t = Tapp ("range",[TA_nexp bt;TA_nexp tt])}])} -> (TA_nexp bt,TA_nexp tt) + | {t = Tapp (name,_)} -> failwith (name ^ " shouldn't be here") in + let t = {t = Tapp ("range",match order with + | Ord_aux (Ord_inc,_) -> [bf;tt] + | Ord_aux (Ord_dec,_) -> [tf;bt])} in *) + E_aux (E_id id, simple_annot l int_typ) in + let v = E_aux (E_app (funcl,[loopvar;mktup el [exp1;exp2;exp3];exp4;vartuple]), + (Parse_ast.Generated el, annot4)) in + let pat = + if overwrite then mktup_pat el vars + else P_aux (P_tup [pat; mktup_pat pl vars], + simple_annot pl (get_typ_exp v)) in + Added_vars (v,pat) + | E_if (c,e1,e2) -> + let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) + (dedup eqidtyp (find_updated_vars e1 @ find_updated_vars e2)) in + if vars = [] then + (Same_vars (E_aux (E_if (c,rewrite_var_updates e1,rewrite_var_updates e2),annot))) + else + let vartuple = mktup el vars in + let e1 = rewrite_var_updates (add_vars overwrite e1 vartuple) in + let e2 = rewrite_var_updates (add_vars overwrite e2 vartuple) in + (* after rewrite_defs_letbind_effects c has no variable updates *) + let env = get_env_annot annot in + let typ = get_typ_exp e1 in + let eff = union_eff_exps [e1;e2] in + let v = E_aux (E_if (c,e1,e2), (Parse_ast.Generated el, Some (env, typ, eff))) in + let pat = + if overwrite then mktup_pat el vars + else P_aux (P_tup [pat; mktup_pat pl vars], + (simple_annot pl (get_typ_exp v))) in + Added_vars (v,pat) + | E_case (e1,ps) -> + (* after rewrite_defs_letbind_effects e1 needs no rewriting *) + let vars = + let f acc (Pat_aux (Pat_exp (_,e),_)) = acc @ find_updated_vars e in + List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) + (dedup eqidtyp (List.fold_left f [] ps)) in + if vars = [] then + let ps = List.map (fun (Pat_aux (Pat_exp (p,e),a)) -> Pat_aux (Pat_exp (p,rewrite_var_updates e),a)) ps in + Same_vars (E_aux (E_case (e1,ps),annot)) + else + let vartuple = mktup el vars in + let typ = + let (Pat_aux (Pat_exp (_,first),_)) = List.hd ps in + get_typ_exp first in + let (ps,typ,effs) = + let f (acc,typ,effs) (Pat_aux (Pat_exp (p,e),pannot)) = + let etyp = get_typ_exp e in + let () = assert (string_of_typ etyp = string_of_typ typ) in + let e = rewrite_var_updates (add_vars overwrite e vartuple) in + let pannot = simple_annot pl (get_typ_exp e) in + let effs = union_effects effs (get_eff_exp e) in + let pat' = Pat_aux (Pat_exp (p,e),pannot) in + (acc @ [pat'],typ,effs) in + List.fold_left f ([],typ,no_effect) ps in + let v = E_aux (E_case (e1,ps), (Parse_ast.Generated pl, Some (get_env_annot annot, typ, effs))) in + let pat = + if overwrite then mktup_pat el vars + else P_aux (P_tup [pat; mktup_pat pl vars], + (simple_annot pl (get_typ_exp v))) in + Added_vars (v,pat) + | E_assign (lexp,vexp) -> + let effs = match get_eff_annot annot with + | Effect_aux (Effect_set effs, _) -> effs + | _ -> + raise (Reporting_basic.err_unreachable l + "assignment without effects annotation") in + if not (List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs) then + Same_vars (E_aux (E_assign (lexp,vexp),annot)) + else + (match lexp with + | LEXP_aux (LEXP_id id,annot) -> + let pat = P_aux (P_id id, simple_annot pl (get_typ_exp vexp)) in + Added_vars (vexp,pat) + | LEXP_aux (LEXP_cast (_,id),annot) -> + let pat = P_aux (P_id id, simple_annot pl (get_typ_exp vexp)) in + Added_vars (vexp,pat) + | LEXP_aux (LEXP_vector (LEXP_aux (LEXP_id id,((l2,_) as annot2)),i),((l1,_) as annot)) -> + let eid = E_aux (E_id id, simple_annot l2 (get_typ_annot annot2)) in + let vexp = E_aux (E_vector_update (eid,i,vexp), + simple_annot l1 (get_typ_annot annot)) in + let pat = P_aux (P_id id, simple_annot pl (get_typ_exp vexp)) in + Added_vars (vexp,pat) + | LEXP_aux (LEXP_vector_range (LEXP_aux (LEXP_id id,((l2,_) as annot2)),i,j), + ((l,_) as annot)) -> + let eid = E_aux (E_id id, simple_annot l2 (get_typ_annot annot2)) in + let vexp = E_aux (E_vector_update_subrange (eid,i,j,vexp), + simple_annot l (get_typ_annot annot)) in + let pat = P_aux (P_id id, simple_annot pl (get_typ_exp vexp)) in + Added_vars (vexp,pat)) + | _ -> + (* after rewrite_defs_letbind_effects this expression is pure and updates + no variables: check n_exp_term and where it's used. *) + Same_vars (E_aux (expaux,annot)) in + + match expaux with + | E_let (lb,body) -> + let body = rewrite_var_updates body in + let (eff,lb) = match lb with + | LB_aux (LB_val_implicit (pat,v),lbannot) -> + (match rewrite v pat with + | Added_vars (v,pat) -> + let (E_aux (_,(l,_))) = v in + let lbannot = (simple_annot l (get_typ_exp v)) in + (get_eff_exp v,LB_aux (LB_val_implicit (pat,v),lbannot)) + | Same_vars v -> (get_eff_exp v,LB_aux (LB_val_implicit (pat,v),lbannot))) + | LB_aux (LB_val_explicit (typ,pat,v),lbannot) -> + (match rewrite v pat with + | Added_vars (v,pat) -> + let (E_aux (_,(l,_))) = v in + let lbannot = (simple_annot l (get_typ_exp v)) in + (get_eff_exp v,LB_aux (LB_val_implicit (pat,v),lbannot)) + | Same_vars v -> (get_eff_exp v,LB_aux (LB_val_explicit (typ,pat,v),lbannot))) in + let tannot = Some (get_env_annot annot, get_typ_exp body, union_effects eff (get_eff_exp body)) in + E_aux (E_let (lb,body),(Parse_ast.Generated l,tannot)) + | E_internal_let (lexp,v,body) -> + (* Rewrite E_internal_let into E_let and call recursively *) + let id = match lexp with + | LEXP_aux (LEXP_id id,_) -> id + | LEXP_aux (LEXP_cast (_,id),_) -> id in + let env = get_env_annot annot in + let vtyp = get_typ_exp v in + let veff = get_eff_exp v in + let bodyenv = get_env_exp body in + let bodytyp = get_typ_exp body in + let bodyeff = get_eff_exp body in + let pat = P_aux (P_id id, (simple_annot l vtyp)) in + let lbannot = (Parse_ast.Generated l, Some (env, vtyp, veff)) in + let lb = LB_aux (LB_val_implicit (pat,v),lbannot) in + let exp = E_aux (E_let (lb,body),(Parse_ast.Generated l, Some (bodyenv, bodytyp, union_effects veff bodyeff))) in + rewrite_var_updates exp + | E_internal_plet (pat,v,body) -> + failwith "rewrite_var_updates: E_internal_plet shouldn't be introduced yet" + (* There are no expressions that have effects or variable updates in + "tail-position": check the definition nexp_term and where it is used. *) + | _ -> exp + +let replace_memwrite_e_assign exp = + let e_aux = fun (expaux,annot) -> + match expaux with + | E_assign (LEXP_aux (LEXP_memory (id,args),_),v) -> E_aux (E_app (id,args @ [v]),annot) + | _ -> E_aux (expaux,annot) in + fold_exp { id_exp_alg with e_aux = e_aux } exp + + + +let remove_reference_types exp = + + let rec rewrite_t (Typ_aux (t_aux,a)) = (Typ_aux (rewrite_t_aux t_aux,a)) + and rewrite_t_aux t_aux = match t_aux with + | Typ_app (Id_aux (Id "reg",_), [Typ_arg_aux (Typ_arg_typ (Typ_aux (t_aux2, _)), _)]) -> + rewrite_t_aux t_aux2 + | Typ_app (name,t_args) -> Typ_app (name,List.map rewrite_t_arg t_args) + | Typ_fn (t1,t2,eff) -> Typ_fn (rewrite_t t1,rewrite_t t2,eff) + | Typ_tup ts -> Typ_tup (List.map rewrite_t ts) + | _ -> t_aux + and rewrite_t_arg t_arg = match t_arg with + | Typ_arg_aux (Typ_arg_typ t, a) -> Typ_arg_aux (Typ_arg_typ (rewrite_t t), a) + | _ -> t_arg in + + let rec rewrite_annot = function + | (l, None) -> (l, None) + | (l, Some (env, typ, eff)) -> (l, Some (env, rewrite_t typ, eff)) in + + map_exp_annot rewrite_annot exp + + + +let rewrite_defs_remove_superfluous_letbinds = + + let rec small (E_aux (exp,_)) = match exp with + | E_id _ + | E_lit _ -> true + | E_cast (_,e) -> small e + | E_list es -> List.for_all small es + | E_cons (e1,e2) -> small e1 && small e2 + | E_sizeof _ -> true + | _ -> false in + + let e_aux (exp,annot) = match exp with + | E_let (lb,exp2) -> + begin match lb,exp2 with + (* 'let x = EXP1 in x' can be replaced with 'EXP1' *) + | LB_aux (LB_val_explicit (_,P_aux (P_id (Id_aux (id,_)),_),exp1),_), + E_aux (E_id (Id_aux (id',_)),_) + | LB_aux (LB_val_explicit (_,P_aux (P_id (Id_aux (id,_)),_),exp1),_), + E_aux (E_cast (_,E_aux (E_id (Id_aux (id',_)),_)),_) + | LB_aux (LB_val_implicit (P_aux (P_id (Id_aux (id,_)),_),exp1),_), + E_aux (E_id (Id_aux (id',_)),_) + | LB_aux (LB_val_implicit (P_aux (P_id (Id_aux (id,_)),_),exp1),_), + E_aux (E_cast (_,E_aux (E_id (Id_aux (id',_)),_)),_) + when id = id' -> + exp1 + (* "let x = EXP1 in return x" can be replaced with 'return (EXP1)', at + least when EXP1 is 'small' enough *) + | LB_aux (LB_val_explicit (_,P_aux (P_id (Id_aux (id,_)),_),exp1),_), + E_aux (E_internal_return (E_aux (E_id (Id_aux (id',_)),_)),_) + | LB_aux (LB_val_implicit (P_aux (P_id (Id_aux (id,_)),_),exp1),_), + E_aux (E_internal_return (E_aux (E_id (Id_aux (id',_)),_)),_) + when id = id' && small exp1 -> + let (E_aux (_,e1annot)) = exp1 in + E_aux (E_internal_return (exp1),e1annot) + | _ -> E_aux (exp,annot) + end + | _ -> E_aux (exp,annot) in + + let alg = { id_exp_alg with e_aux = e_aux } in + rewrite_defs_base + { rewrite_exp = (fun _ -> fold_exp alg) + ; rewrite_pat = rewrite_pat + ; rewrite_let = rewrite_let + ; rewrite_lexp = rewrite_lexp + ; rewrite_fun = rewrite_fun + ; rewrite_def = rewrite_def + ; rewrite_defs = rewrite_defs_base + } + + +let rewrite_defs_remove_superfluous_returns = + + let has_unittype e = match get_typ_exp e with + | Typ_aux (Typ_id Id_aux (Id "unit", _), _) -> true + | _ -> false in + + let e_aux (exp,annot) = match exp with + | E_internal_plet (pat,exp1,exp2) -> + begin match pat,exp2 with + | P_aux (P_lit (L_aux (lit,_)),_), + E_aux (E_internal_return (E_aux (E_lit (L_aux (lit',_)),_)),_) + when lit = lit' -> + exp1 + | P_aux (P_wild,pannot), + E_aux (E_internal_return (E_aux (E_lit (L_aux (L_unit,_)),_)),_) + when has_unittype exp1 -> + exp1 + | P_aux (P_id (Id_aux (id,_)),_), + E_aux (E_internal_return (E_aux (E_id (Id_aux (id',_)),_)),_) + when id = id' -> + exp1 + | _ -> E_aux (exp,annot) + end + | _ -> E_aux (exp,annot) in + + let alg = { id_exp_alg with e_aux = e_aux } in + rewrite_defs_base + {rewrite_exp = (fun _ -> fold_exp alg) + ; rewrite_pat = rewrite_pat + ; rewrite_let = rewrite_let + ; rewrite_lexp = rewrite_lexp + ; rewrite_fun = rewrite_fun + ; rewrite_def = rewrite_def + ; rewrite_defs = rewrite_defs_base + } + + +let rewrite_defs_remove_e_assign = + let rewrite_exp _ e = + replace_memwrite_e_assign (remove_reference_types (rewrite_var_updates e)) in + rewrite_defs_base + { rewrite_exp = rewrite_exp + ; rewrite_pat = rewrite_pat + ; rewrite_let = rewrite_let + ; rewrite_lexp = rewrite_lexp + ; rewrite_fun = rewrite_fun + ; rewrite_def = rewrite_def + ; rewrite_defs = rewrite_defs_base + } + + +let rewrite_defs_lem = + top_sort_defs >> + rewrite_defs_remove_vector_concat >> + rewrite_defs_remove_bitvector_pats >> + rewrite_defs_exp_lift_assign >> + rewrite_defs_remove_blocks >> + rewrite_defs_letbind_effects >> + rewrite_defs_remove_e_assign >> + rewrite_defs_effectful_let_expressions >> + rewrite_defs_remove_superfluous_letbinds >> + rewrite_defs_remove_superfluous_returns + diff --git a/src/rewriter_new_tc.mli b/src/rewriter_new_tc.mli new file mode 100644 index 00000000..03c94d2a --- /dev/null +++ b/src/rewriter_new_tc.mli @@ -0,0 +1,150 @@ +(**************************************************************************) +(* Sail *) +(* *) +(* Copyright (c) 2013-2017 *) +(* Kathyrn Gray *) +(* Shaked Flur *) +(* Stephen Kell *) +(* Gabriel Kerneis *) +(* Robert Norton-Wright *) +(* Christopher Pulte *) +(* Peter Sewell *) +(* *) +(* All rights reserved. *) +(* *) +(* This software was developed by the University of Cambridge Computer *) +(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *) +(* (REMS) project, funded by EPSRC grant EP/K008528/1. *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(**************************************************************************) + +open Big_int +open Ast +open Type_check_new + +type 'a rewriters = { rewrite_exp : 'a rewriters -> 'a exp -> 'a exp; + rewrite_lexp : 'a rewriters -> 'a lexp -> 'a lexp; + rewrite_pat : 'a rewriters -> 'a pat -> 'a pat; + rewrite_let : 'a rewriters -> 'a letbind -> 'a letbind; + rewrite_fun : 'a rewriters -> 'a fundef -> 'a fundef; + rewrite_def : 'a rewriters -> 'a def -> 'a def; + rewrite_defs : 'a rewriters -> 'a defs -> 'a defs; + } + +val rewrite_exp : tannot rewriters -> tannot exp -> tannot exp +val rewrite_defs : tannot defs -> tannot defs +val rewrite_defs_ocaml : tannot defs -> tannot defs (*Perform rewrites to exclude AST nodes not supported for ocaml out*) +val rewrite_defs_lem : tannot defs -> tannot defs (*Perform rewrites to exclude AST nodes not supported for lem out*) + +(* the type of interpretations of pattern-matching expressions *) +type ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg = + { p_lit : lit -> 'pat_aux + ; p_wild : 'pat_aux + ; p_as : 'pat * id -> 'pat_aux + ; p_typ : Ast.typ * 'pat -> 'pat_aux + ; p_id : id -> 'pat_aux + ; p_app : id * 'pat list -> 'pat_aux + ; p_record : 'fpat list * bool -> 'pat_aux + ; p_vector : 'pat list -> 'pat_aux + ; p_vector_indexed : (int * 'pat) list -> 'pat_aux + ; p_vector_concat : 'pat list -> 'pat_aux + ; p_tup : 'pat list -> 'pat_aux + ; p_list : 'pat list -> 'pat_aux + ; p_aux : 'pat_aux * 'a annot -> 'pat + ; fP_aux : 'fpat_aux * 'a annot -> 'fpat + ; fP_Fpat : id * 'pat -> 'fpat_aux + } + +(* fold over pat_aux expressions *) + + +(* the type of interpretations of expressions *) +type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux, + 'opt_default_aux,'opt_default,'pexp,'pexp_aux,'letbind_aux,'letbind, + 'pat,'pat_aux,'fpat,'fpat_aux) exp_alg = + { e_block : 'exp list -> 'exp_aux + ; e_nondet : 'exp list -> 'exp_aux + ; e_id : id -> 'exp_aux + ; e_lit : lit -> 'exp_aux + ; e_cast : Ast.typ * 'exp -> 'exp_aux + ; e_app : id * 'exp list -> 'exp_aux + ; e_app_infix : 'exp * id * 'exp -> 'exp_aux + ; e_tuple : 'exp list -> 'exp_aux + ; e_if : 'exp * 'exp * 'exp -> 'exp_aux + ; e_for : id * 'exp * 'exp * 'exp * Ast.order * 'exp -> 'exp_aux + ; e_vector : 'exp list -> 'exp_aux + ; e_vector_indexed : (int * 'exp) list * 'opt_default -> 'exp_aux + ; e_vector_access : 'exp * 'exp -> 'exp_aux + ; e_vector_subrange : 'exp * 'exp * 'exp -> 'exp_aux + ; e_vector_update : 'exp * 'exp * 'exp -> 'exp_aux + ; e_vector_update_subrange : 'exp * 'exp * 'exp * 'exp -> 'exp_aux + ; e_vector_append : 'exp * 'exp -> 'exp_aux + ; e_list : 'exp list -> 'exp_aux + ; e_cons : 'exp * 'exp -> 'exp_aux + ; e_record : 'fexps -> 'exp_aux + ; e_record_update : 'exp * 'fexps -> 'exp_aux + ; e_field : 'exp * id -> 'exp_aux + ; e_case : 'exp * 'pexp list -> 'exp_aux + ; e_let : 'letbind * 'exp -> 'exp_aux + ; e_assign : 'lexp * 'exp -> 'exp_aux + ; e_exit : 'exp -> 'exp_aux + ; e_return : 'exp -> 'exp_aux + ; e_assert : 'exp * 'exp -> 'exp_aux + ; e_internal_cast : 'a annot * 'exp -> 'exp_aux + ; e_internal_exp : 'a annot -> 'exp_aux + ; e_internal_exp_user : 'a annot * 'a annot -> 'exp_aux + ; e_internal_let : 'lexp * 'exp * 'exp -> 'exp_aux + ; e_internal_plet : 'pat * 'exp * 'exp -> 'exp_aux + ; e_internal_return : 'exp -> 'exp_aux + ; e_aux : 'exp_aux * 'a annot -> 'exp + ; lEXP_id : id -> 'lexp_aux + ; lEXP_memory : id * 'exp list -> 'lexp_aux + ; lEXP_cast : Ast.typ * id -> 'lexp_aux + ; lEXP_tup : 'lexp list -> 'lexp_aux + ; lEXP_vector : 'lexp * 'exp -> 'lexp_aux + ; lEXP_vector_range : 'lexp * 'exp * 'exp -> 'lexp_aux + ; lEXP_field : 'lexp * id -> 'lexp_aux + ; lEXP_aux : 'lexp_aux * 'a annot -> 'lexp + ; fE_Fexp : id * 'exp -> 'fexp_aux + ; fE_aux : 'fexp_aux * 'a annot -> 'fexp + ; fES_Fexps : 'fexp list * bool -> 'fexps_aux + ; fES_aux : 'fexps_aux * 'a annot -> 'fexps + ; def_val_empty : 'opt_default_aux + ; def_val_dec : 'exp -> 'opt_default_aux + ; def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default + ; pat_exp : 'pat * 'exp -> 'pexp_aux + ; pat_aux : 'pexp_aux * 'a annot -> 'pexp + ; lB_val_explicit : typschm * 'pat * 'exp -> 'letbind_aux + ; lB_val_implicit : 'pat * 'exp -> 'letbind_aux + ; lB_aux : 'letbind_aux * 'a annot -> 'letbind + ; pat_alg : ('a,'pat,'pat_aux,'fpat,'fpat_aux) pat_alg + } + +(* fold over expressions *) +val fold_exp : ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux, + 'opt_default_aux,'opt_default,'pexp,'pexp_aux,'letbind_aux,'letbind, + 'pat,'pat_aux,'fpat,'fpat_aux) exp_alg -> 'a exp -> 'exp + +val id_pat_alg : ('a,'a pat, 'a pat_aux, 'a fpat, 'a fpat_aux) pat_alg diff --git a/src/spec_analysis_new_tc.ml b/src/spec_analysis_new_tc.ml new file mode 100644 index 00000000..777990aa --- /dev/null +++ b/src/spec_analysis_new_tc.ml @@ -0,0 +1,667 @@ +(**************************************************************************) +(* Sail *) +(* *) +(* Copyright (c) 2013-2017 *) +(* Kathyrn Gray *) +(* Shaked Flur *) +(* Stephen Kell *) +(* Gabriel Kerneis *) +(* Robert Norton-Wright *) +(* Christopher Pulte *) +(* Peter Sewell *) +(* *) +(* All rights reserved. *) +(* *) +(* This software was developed by the University of Cambridge Computer *) +(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *) +(* (REMS) project, funded by EPSRC grant EP/K008528/1. *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(**************************************************************************) + +open Ast +open Util +open Ast_util + +module Nameset = Set.Make(String) + +let mt = Nameset.empty + +let set_to_string n = + let rec list_to_string = function + | [] -> "" + | [n] -> n + | n::ns -> n ^ ", " ^ list_to_string ns in + list_to_string (Nameset.elements n) + + +(*Query a spec for its default order if one is provided. Assumes Inc if not *) +(* let get_default_order_sp (DT_aux(spec,_)) = + match spec with + | DT_order (Ord_aux(o,_)) -> + (match o with + | Ord_inc -> Some {order = Oinc} + | Ord_dec -> Some { order = Odec} + | _ -> Some {order = Oinc}) + | _ -> None + +let get_default_order_def = function + | DEF_default def_spec -> get_default_order_sp def_spec + | _ -> None + +let rec default_order (Defs defs) = + match defs with + | [] -> { order = Oinc } (*When no order is specified, we assume that it's inc*) + | def::defs -> + match get_default_order_def def with + | None -> default_order (Defs defs) + | Some o -> o *) + +(*Is within range*) + +(* let check_in_range (candidate : big_int) (range : typ) : bool = + match range.t with + | Tapp("range", [TA_nexp min; TA_nexp max]) | Tabbrev(_,{t=Tapp("range", [TA_nexp min; TA_nexp max])}) -> + let min,max = + match min.nexp,max.nexp with + | (Nconst min, Nconst max) + | (Nconst min, N2n(_, Some max)) + | (N2n(_, Some min), Nconst max) + | (N2n(_, Some min), N2n(_, Some max)) + -> min, max + | (Nneg n, Nconst max) | (Nneg n, N2n(_, Some max))-> + (match n.nexp with + | Nconst abs_min | N2n(_,Some abs_min) -> + (minus_big_int abs_min), max + | _ -> assert false (*Put a better error message here*)) + | (Nconst min,Nneg n) | (N2n(_, Some min), Nneg n) -> + (match n.nexp with + | Nconst abs_max | N2n(_,Some abs_max) -> + min, (minus_big_int abs_max) + | _ -> assert false (*Put a better error message here*)) + | (Nneg nmin, Nneg nmax) -> + ((match nmin.nexp with + | Nconst abs_min | N2n(_,Some abs_min) -> (minus_big_int abs_min) + | _ -> assert false (*Put a better error message here*)), + (match nmax.nexp with + | Nconst abs_max | N2n(_,Some abs_max) -> (minus_big_int abs_max) + | _ -> assert false (*Put a better error message here*))) + | _ -> assert false + in le_big_int min candidate && le_big_int candidate max + | _ -> assert false + +(*Rmove me when switch to zarith*) +let rec power_big_int b n = + if eq_big_int n zero_big_int + then unit_big_int + else mult_big_int b (power_big_int b (sub_big_int n unit_big_int)) + +let unpower_of_2 b = + let two = big_int_of_int 2 in + let four = big_int_of_int 4 in + let eight = big_int_of_int 8 in + let sixteen = big_int_of_int 16 in + let thirty_two = big_int_of_int 32 in + let sixty_four = big_int_of_int 64 in + let onetwentyeight = big_int_of_int 128 in + let twofiftysix = big_int_of_int 256 in + let fivetwelve = big_int_of_int 512 in + let oneotwentyfour = big_int_of_int 1024 in + let to_the_sixteen = big_int_of_int 65536 in + let to_the_thirtytwo = big_int_of_string "4294967296" in + let to_the_sixtyfour = big_int_of_string "18446744073709551616" in + let ck i = eq_big_int b i in + if ck unit_big_int then zero_big_int + else if ck two then unit_big_int + else if ck four then two + else if ck eight then big_int_of_int 3 + else if ck sixteen then four + else if ck thirty_two then big_int_of_int 5 + else if ck sixty_four then big_int_of_int 6 + else if ck onetwentyeight then big_int_of_int 7 + else if ck twofiftysix then eight + else if ck fivetwelve then big_int_of_int 9 + else if ck oneotwentyfour then big_int_of_int 10 + else if ck to_the_sixteen then sixteen + else if ck to_the_thirtytwo then thirty_two + else if ck to_the_sixtyfour then sixty_four + else let rec unpower b power = + if eq_big_int b unit_big_int + then power + else (unpower (div_big_int b two) (succ_big_int power)) in + unpower b zero_big_int + +let is_within_range candidate range constraints = + let candidate_actual = match candidate.t with + | Tabbrev(_,t) -> t + | _ -> candidate in + match candidate_actual.t with + | Tapp("atom", [TA_nexp n]) -> + (match n.nexp with + | Nconst i | N2n(_,Some i) -> if check_in_range i range then Yes else No + | _ -> Maybe) + | Tapp("range", [TA_nexp bot; TA_nexp top]) -> + (match bot.nexp,top.nexp with + | Nconst b, Nconst t | Nconst b, N2n(_,Some t) | N2n(_, Some b), Nconst t | N2n(_,Some b), N2n(_, Some t) -> + let at_least_in = check_in_range b range in + let at_most_in = check_in_range t range in + if at_least_in && at_most_in + then Yes + else if at_least_in || at_most_in + then Maybe + else No + | _ -> Maybe) + | Tapp("vector", [_; TA_nexp size ; _; _]) -> + (match size.nexp with + | Nconst i | N2n(_, Some i) -> + if check_in_range (power_big_int (big_int_of_int 2) i) range + then Yes + else No + | _ -> Maybe) + | _ -> Maybe + +let is_within_machine64 candidate constraints = is_within_range candidate int64_t constraints *) + +(************************************************************************************************) +(*FV finding analysis: identifies the free variables of a function, expression, etc *) + +let conditional_add typ_or_exp bound used id = + let known_list = + if typ_or_exp (*true for typ*) + then ["bit";"vector";"unit";"string";"int";"bool";"boolean"] + else ["=="; "!="; "|";"~";"&";"add_int"] in + let i = (string_of_id id) in + if Nameset.mem i bound || List.mem i known_list + then used + else Nameset.add i used + +let conditional_add_typ = conditional_add true +let conditional_add_exp = conditional_add false + + +let nameset_bigunion = List.fold_left Nameset.union mt + + +let rec free_type_names_t consider_var (Typ_aux (t, _)) = match t with + | Typ_var name -> if consider_var then Nameset.add (string_of_kid name) mt else mt + | Typ_id name -> Nameset.add (string_of_id name) mt + | Typ_fn (t1,t2,_) -> Nameset.union (free_type_names_t consider_var t1) + (free_type_names_t consider_var t2) + | Typ_tup ts -> free_type_names_ts consider_var ts + | Typ_app (name,targs) -> Nameset.add (string_of_id name) (free_type_names_t_args consider_var targs) + | Typ_wild -> mt +and free_type_names_ts consider_var ts = nameset_bigunion (List.map (free_type_names_t consider_var) ts) +and free_type_names_maybe_t consider_var = function + | Some t -> free_type_names_t consider_var t + | None -> mt +and free_type_names_t_arg consider_var = function + | Typ_arg_aux (Typ_arg_typ t, _) -> free_type_names_t consider_var t + | _ -> mt +and free_type_names_t_args consider_var targs = + nameset_bigunion (List.map (free_type_names_t_arg consider_var) targs) + + +let rec free_type_names_tannot consider_var = function + | None -> mt + | Some (_, t, _) -> free_type_names_t consider_var t + + +let rec fv_of_typ consider_var bound used (Typ_aux (t,_)) : Nameset.t = + match t with + | Typ_wild -> used + | Typ_var (Kid_aux (Var v,l)) -> + if consider_var + then conditional_add_typ bound used (Ast.Id_aux (Ast.Id v,l)) + else used + | Typ_id id -> conditional_add_typ bound used id + | Typ_fn(arg,ret,_) -> fv_of_typ consider_var bound (fv_of_typ consider_var bound used arg) ret + | Typ_tup ts -> List.fold_right (fun t n -> fv_of_typ consider_var bound n t) ts used + | Typ_app(id,targs) -> + List.fold_right (fun ta n -> fv_of_targ consider_var bound n ta) targs (conditional_add_typ bound used id) + +and fv_of_targ consider_var bound used (Ast.Typ_arg_aux(targ,_)) : Nameset.t = match targ with + | Typ_arg_typ t -> fv_of_typ consider_var bound used t + | Typ_arg_nexp n -> fv_of_nexp consider_var bound used n + | _ -> used + +and fv_of_nexp consider_var bound used (Ast.Nexp_aux(n,_)) = match n with + | Nexp_id id -> conditional_add_typ bound used id + | Nexp_var (Ast.Kid_aux (Ast.Var i,_)) -> + if consider_var + then conditional_add_typ bound used (Ast.Id_aux (Ast.Id i, Parse_ast.Unknown)) + else used + | Nexp_times (n1,n2) | Ast.Nexp_sum (n1,n2) | Ast.Nexp_minus(n1,n2) -> + fv_of_nexp consider_var bound (fv_of_nexp consider_var bound used n1) n2 + | Nexp_exp n | Ast.Nexp_neg n -> fv_of_nexp consider_var bound used n + | _ -> used + +let typq_bindings (TypQ_aux(tq,_)) = match tq with + | TypQ_tq quants -> + List.fold_right (fun (QI_aux (qi,_)) bounds -> + match qi with + | QI_id (KOpt_aux(k,_)) -> + (match k with + | KOpt_none (Kid_aux (Var s,_)) -> Nameset.add s bounds + | KOpt_kind (_, Kid_aux (Var s,_)) -> Nameset.add s bounds) + | _ -> bounds) quants mt + | TypQ_no_forall -> mt + +let fv_of_typschm consider_var bound used (Ast.TypSchm_aux ((Ast.TypSchm_ts(typq,typ)),_)) = + let ts_bound = if consider_var then typq_bindings typq else mt in + ts_bound, fv_of_typ consider_var (Nameset.union bound ts_bound) used typ + +let rec pat_bindings consider_var bound used (P_aux(p,(_,tannot))) = + let list_fv bound used ps = List.fold_right (fun p (b,n) -> pat_bindings consider_var b n p) ps (bound, used) in + match p with + | P_as(p,id) -> let b,ns = pat_bindings consider_var bound used p in + Nameset.add (string_of_id id) b,ns + | P_typ(t,p) -> + let used = Nameset.union (free_type_names_tannot consider_var tannot) used in + let ns = fv_of_typ consider_var bound used t in pat_bindings consider_var bound ns p + | P_id id -> + let used = Nameset.union (free_type_names_tannot consider_var tannot) used in + Nameset.add (string_of_id id) bound,used + | P_app(id,pats) -> + let used = Nameset.union (free_type_names_tannot consider_var tannot) used in + list_fv bound (Nameset.add (string_of_id id) used) pats + | P_record(fpats,_) -> + List.fold_right (fun (Ast.FP_aux(Ast.FP_Fpat(_,p),_)) (b,n) -> + pat_bindings consider_var bound used p) fpats (bound,used) + | P_vector pats | Ast.P_vector_concat pats | Ast.P_tup pats | Ast.P_list pats -> list_fv bound used pats + | P_vector_indexed ipats -> + List.fold_right (fun (_,p) (b,n) -> pat_bindings consider_var b n p) ipats (bound,used) + | _ -> bound,used + +let rec fv_of_exp consider_var bound used set (E_aux (e,(_,tannot))) : (Nameset.t * Nameset.t * Nameset.t) = + let list_fv b n s es = List.fold_right (fun e (b,n,s) -> fv_of_exp consider_var b n s e) es (b,n,s) in + match e with + | E_block es | Ast.E_nondet es | Ast.E_tuple es | Ast.E_vector es | Ast.E_list es -> + list_fv bound used set es + | E_id id -> + let used = conditional_add_exp bound used id in + let used = Nameset.union (free_type_names_tannot consider_var tannot) used in + bound,used,set + | E_cast (t,e) -> + let u = fv_of_typ consider_var (if consider_var then bound else mt) used t in + fv_of_exp consider_var bound u set e + | E_app(id,es) -> + let us = conditional_add_exp bound used id in + list_fv bound us set es + | E_app_infix(l,id,r) -> + let us = conditional_add_exp bound used id in + list_fv bound us set [l;r] + | E_if(c,t,e) -> list_fv bound used set [c;t;e] + | E_for(id,from,to_,by,_,body) -> + let _,used,set = list_fv bound used set [from;to_;by] in + fv_of_exp consider_var (Nameset.add (string_of_id id) bound) used set body + | E_vector_indexed (es_i,(Ast.Def_val_aux(default,_))) -> + let bound,used,set = + List.fold_right + (fun (_,e) (b,u,s) -> fv_of_exp consider_var b u s e) es_i (bound,used,set) in + (match default with + | Def_val_empty -> bound,used,set + | Def_val_dec e -> fv_of_exp consider_var bound used set e) + | E_vector_access(v,i) -> list_fv bound used set [v;i] + | E_vector_subrange(v,i1,i2) -> list_fv bound used set [v;i1;i2] + | E_vector_update(v,i,e) -> list_fv bound used set [v;i;e] + | E_vector_update_subrange(v,i1,i2,e) -> list_fv bound used set [v;i1;i2;e] + | E_vector_append(e1,e2) | E_cons(e1,e2) -> list_fv bound used set [e1;e2] + | E_record (FES_aux(FES_Fexps(fexps,_),_)) -> + let used = Nameset.union (free_type_names_tannot consider_var tannot) used in + List.fold_right + (fun (FE_aux(FE_Fexp(_,e),_)) (b,u,s) -> fv_of_exp consider_var b u s e) fexps (bound,used,set) + | E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) -> + let b,u,s = fv_of_exp consider_var bound used set e in + List.fold_right + (fun (FE_aux(FE_Fexp(_,e),_)) (b,u,s) -> fv_of_exp consider_var b u s e) fexps (b,u,s) + | E_field(e,_) -> fv_of_exp consider_var bound used set e + | E_case(e,pes) -> + let b,u,s = fv_of_exp consider_var bound used set e in + fv_of_pes consider_var b u s pes + | E_let(lebind,e) -> + let b,u,s = fv_of_let consider_var bound used set lebind in + fv_of_exp consider_var b u s e + | E_assign(lexp,e) -> + let b,u,s = fv_of_lexp consider_var bound used set lexp in + let _,used,set = fv_of_exp consider_var bound u s e in + b,used,set + | E_exit e -> fv_of_exp consider_var bound used set e + | E_assert(c,m) -> list_fv bound used set [c;m] + | _ -> bound,used,set + +and fv_of_pes consider_var bound used set pes = + match pes with + | [] -> bound,used,set + | Pat_aux(Pat_exp (p,e),_)::pes -> + let bound_p,us_p = pat_bindings consider_var bound used p in + let bound_e,us_e,set_e = fv_of_exp consider_var bound_p us_p set e in + fv_of_pes consider_var bound us_e set_e pes + +and fv_of_let consider_var bound used set (LB_aux(lebind,_)) = match lebind with + | LB_val_explicit(typsch,pat,exp) -> + let bound_t,us_t = fv_of_typschm consider_var bound used typsch in + let bound_p, us_p = pat_bindings consider_var (Nameset.union bound bound_t) used pat in + let _,us_e,set_e = fv_of_exp consider_var (Nameset.union bound bound_t) used set exp in + (Nameset.union bound_t bound_p),Nameset.union us_t (Nameset.union us_p us_e),set_e + | LB_val_implicit(pat,exp) -> + let bound_p, us_p = pat_bindings consider_var bound used pat in + let _,us_e,set_e = fv_of_exp consider_var bound used set exp in + bound_p,Nameset.union us_p us_e,set_e + +and fv_of_lexp consider_var bound used set (LEXP_aux(lexp,(_,tannot))) = + match lexp with + | LEXP_id id -> + let used = Nameset.union (free_type_names_tannot consider_var tannot) used in + let i = string_of_id id in + if Nameset.mem i bound + then bound, used, Nameset.add i set + else Nameset.add i bound, Nameset.add i used, set + | LEXP_cast(typ,id) -> + let used = Nameset.union (free_type_names_tannot consider_var tannot) used in + let i = string_of_id id in + let used_t = fv_of_typ consider_var bound used typ in + if Nameset.mem i bound + then bound, used_t, Nameset.add i set + else Nameset.add i bound, Nameset.add i used_t, set + | LEXP_tup(tups) -> + List.fold_right (fun l (b,u,s) -> fv_of_lexp consider_var b u s l) tups (bound,used,set) + | LEXP_memory(id,args) -> + let (bound,used,set) = + List.fold_right + (fun e (b,u,s) -> + fv_of_exp consider_var b u s e) args (bound,used,set) in + bound,Nameset.add (string_of_id id) used,set + | LEXP_field(lexp,_) -> fv_of_lexp consider_var bound used set lexp + | LEXP_vector(lexp,exp) -> + let bound_l,used,set = fv_of_lexp consider_var bound used set lexp in + let _,used,set = fv_of_exp consider_var bound used set exp in + bound_l,used,set + | LEXP_vector_range(lexp,e1,e2) -> + let bound_l,used,set = fv_of_lexp consider_var bound used set lexp in + let _,used,set = fv_of_exp consider_var bound used set e1 in + let _,used,set = fv_of_exp consider_var bound used set e2 in + bound_l,used,set + +let init_env s = Nameset.singleton s + +let typ_variants consider_var bound tunions = + List.fold_right + (fun (Tu_aux(t,_)) (b,n) -> match t with + | Tu_id id -> Nameset.add (string_of_id id) b,n + | Tu_ty_id(t,id) -> Nameset.add (string_of_id id) b, fv_of_typ consider_var b n t) + tunions + (bound,mt) + +let fv_of_kind_def consider_var (KD_aux(k,_)) = match k with + | KD_nabbrev(_,id,_,nexp) -> init_env (string_of_id id), fv_of_nexp consider_var mt mt nexp + | KD_abbrev(_,id,_,typschm) -> + init_env (string_of_id id), snd (fv_of_typschm consider_var mt mt typschm) + | KD_record(_,id,_,typq,tids,_) -> + let binds = init_env (string_of_id id) in + let bounds = if consider_var then typq_bindings typq else mt in + binds, List.fold_right (fun (t,_) n -> fv_of_typ consider_var bounds n t) tids mt + | KD_variant(_,id,_,typq,tunions,_) -> + let bindings = Nameset.add (string_of_id id) (if consider_var then typq_bindings typq else mt) in + typ_variants consider_var bindings tunions + | KD_enum(_,id,_,ids,_) -> + Nameset.of_list (List.map string_of_id (id::ids)),mt + | KD_register(_,id,n1,n2,_) -> + init_env (string_of_id id), fv_of_nexp consider_var mt (fv_of_nexp consider_var mt mt n1) n2 + +let fv_of_type_def consider_var (TD_aux(t,_)) = match t with + | TD_abbrev(id,_,typschm) -> init_env (string_of_id id), snd (fv_of_typschm consider_var mt mt typschm) + | TD_record(id,_,typq,tids,_) -> + let binds = init_env (string_of_id id) in + let bounds = if consider_var then typq_bindings typq else mt in + binds, List.fold_right (fun (t,_) n -> fv_of_typ consider_var bounds n t) tids mt + | TD_variant(id,_,typq,tunions,_) -> + let bindings = Nameset.add (string_of_id id) (if consider_var then typq_bindings typq else mt) in + typ_variants consider_var bindings tunions + | TD_enum(id,_,ids,_) -> + Nameset.of_list (List.map string_of_id (id::ids)),mt + | TD_register(id,n1,n2,_) -> + init_env (string_of_id id), fv_of_nexp consider_var mt (fv_of_nexp consider_var mt mt n1) n2 + +let fv_of_tannot_opt consider_var (Typ_annot_opt_aux (t,_)) = + match t with + | Typ_annot_opt_some (typq,typ) -> + let bindings = if consider_var then typq_bindings typq else mt in + let free = fv_of_typ consider_var bindings mt typ in + (bindings,free) + +(*Unlike the other fv, the bound returns are the names bound by the pattern for use in the exp*) +let fv_of_funcl consider_var base_bounds (FCL_aux(FCL_Funcl(id,pat,exp),l)) = + let pat_bs,pat_ns = pat_bindings consider_var base_bounds mt pat in + let _, exp_ns, exp_sets = fv_of_exp consider_var pat_bs pat_ns mt exp in + (pat_bs,exp_ns,exp_sets) + +let fv_of_fun consider_var (FD_aux (FD_function(rec_opt,tannot_opt,_,funcls),_)) = + let fun_name = match funcls with + | [] -> failwith "fv_of_fun fell off the end looking for the function name" + | FCL_aux(FCL_Funcl(id,_,_),_)::_ -> string_of_id id in + let base_bounds = match rec_opt with + | Rec_aux(Ast.Rec_rec,_) -> init_env fun_name + | _ -> mt in + let base_bounds,ns_r = match tannot_opt with + | Typ_annot_opt_aux(Typ_annot_opt_some (typq, typ),_) -> + let bindings = if consider_var then typq_bindings typq else mt in + let bound = Nameset.union bindings base_bounds in + bound, fv_of_typ consider_var bound mt typ in + let ns = List.fold_right (fun (FCL_aux(FCL_Funcl(_,pat,exp),_)) ns -> + let pat_bs,pat_ns = pat_bindings consider_var base_bounds ns pat in + let _, exp_ns,_ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty exp in + exp_ns) funcls mt in + init_env fun_name,Nameset.union ns ns_r + +let fv_of_vspec consider_var (VS_aux(vspec,_)) = match vspec with + | VS_val_spec(ts,id) | VS_extern_no_rename (ts,id) | VS_extern_spec(ts,id,_) + | VS_cast_spec(ts,id) -> + init_env ("val:" ^ (string_of_id id)), snd (fv_of_typschm consider_var mt mt ts) + +let rec find_scattered_of name = function + | [] -> [] + | DEF_scattered (SD_aux(sda,_) as sd):: defs -> + (match sda with + | SD_scattered_function(_,_,_,id) + | SD_scattered_funcl(FCL_aux(FCL_Funcl(id,_,_),_)) + | SD_scattered_unioncl(id,_) -> + if name = string_of_id id + then [sd] else [] + | _ -> [])@ + (find_scattered_of name defs) + | _::defs -> find_scattered_of name defs + +let rec fv_of_scattered consider_var consider_scatter_as_one all_defs (SD_aux(sd,_)) = match sd with + | SD_scattered_function(_,tannot_opt,_,id) -> + let b,ns = (match tannot_opt with + | Typ_annot_opt_aux(Typ_annot_opt_some (typq, typ),_) -> + let bindings = if consider_var then typq_bindings typq else mt in + bindings, fv_of_typ consider_var bindings mt typ) in + init_env (string_of_id id),ns + | SD_scattered_funcl (FCL_aux(FCL_Funcl(id,pat,exp),_)) -> + let pat_bs,pat_ns = pat_bindings consider_var mt mt pat in + let _,exp_ns,_ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty exp in + let scattered_binds = match pat with + | P_aux(P_app(pid,_),_) -> init_env ((string_of_id id) ^ "/" ^ (string_of_id pid)) + | _ -> mt in + scattered_binds, exp_ns + | SD_scattered_variant (id,_,_) -> + let name = string_of_id id in + let uses = + if consider_scatter_as_one + then + let variant_defs = find_scattered_of name all_defs in + let pieces_uses = + List.fold_right (fun (binds,uses) all_uses -> Nameset.union uses all_uses) + (List.map (fv_of_scattered consider_var false []) variant_defs) mt in + Nameset.remove name pieces_uses + else mt in + init_env name, uses + | SD_scattered_unioncl(id, type_union) -> + let typ_name = string_of_id id in + let b = init_env typ_name in + let (b,r) = typ_variants consider_var b [type_union] in + (Nameset.remove typ_name b, Nameset.add typ_name r) + | SD_scattered_end id -> + let name = string_of_id id in + let uses = if consider_scatter_as_one + (*Note: if this is a function ending, the dec is included *) + then + let scattered_defs = find_scattered_of name all_defs in + List.fold_right (fun (binds,uses) all_uses -> Nameset.union (Nameset.union binds uses) all_uses) + (List.map (fv_of_scattered consider_var false []) scattered_defs) (init_env name) + else init_env name in + init_env (name ^ "/end"), uses + +let fv_of_rd consider_var (DEC_aux (d,_)) = match d with + | DEC_reg(t,id) -> + init_env (string_of_id id), fv_of_typ consider_var mt mt t + | DEC_alias(id,alias) -> + init_env (string_of_id id),mt + | DEC_typ_alias(t,id,alias) -> + init_env (string_of_id id), mt + +let fv_of_def consider_var consider_scatter_as_one all_defs = function + | DEF_kind kdef -> fv_of_kind_def consider_var kdef + | DEF_type tdef -> fv_of_type_def consider_var tdef + | DEF_fundef fdef -> fv_of_fun consider_var fdef + | DEF_val lebind -> ((fun (b,u,_) -> (b,u)) (fv_of_let consider_var mt mt mt lebind)) + | DEF_spec vspec -> fv_of_vspec consider_var vspec + | DEF_overload (id,ids) -> init_env (string_of_id id), List.fold_left (fun ns id -> Nameset.add (string_of_id id) ns) mt ids + | DEF_default def -> mt,mt + | DEF_scattered sdef -> fv_of_scattered consider_var consider_scatter_as_one all_defs sdef + | DEF_reg_dec rdec -> fv_of_rd consider_var rdec + | DEF_comm _ -> mt,mt + +let group_defs consider_scatter_as_one (Ast.Defs defs) = + List.map (fun d -> (fv_of_def false consider_scatter_as_one defs d,d)) defs + +(******************************************************************************* + * Reorder defs take 2 +*) + +(*remove all of ns1 instances from ns2*) +let remove_all ns1 ns2 = + List.fold_right Nameset.remove (Nameset.elements ns1) ns2 + +let remove_from_all_uses bs dbts = + List.map (fun ((b,uses),d) -> (b,remove_all bs uses),d) dbts + +let remove_local_or_lib_vars dbts = + let bound_in_dbts = List.fold_right (fun ((b,_),_) bounds -> Nameset.union b bounds) dbts mt in + let is_bound_in_defs s = Nameset.mem s bound_in_dbts in + let rec remove_from_uses = function + | [] -> [] + | ((b,uses),d)::defs -> + ((b,(Nameset.filter is_bound_in_defs uses)),d)::remove_from_uses defs in + remove_from_uses dbts + +let compare_dbts ((_,u1),_) ((_,u2),_) = Pervasives.compare (Nameset.cardinal u1) (Nameset.cardinal u2) + +let rec print_dependencies orig_queue work_queue names = + match work_queue with + | [] -> () + | ((binds,uses),_)::wq -> + (if not(Nameset.is_empty(Nameset.inter names binds)) + then ((Printf.eprintf "binds of %s has uses of %s\n" (set_to_string binds) (set_to_string uses)); + print_dependencies orig_queue orig_queue uses)); + print_dependencies orig_queue wq names + +let rec topological_sort work_queue defs = + match work_queue with + | [] -> List.rev defs + | ((binds,uses),def)::wq -> + (*Assumes work queue given in sorted order, invariant mantained on appropriate recursive calls*) + if (Nameset.cardinal uses = 0) + then (*let _ = Printf.eprintf "Adding def that binds %s to definitions\n" (set_to_string binds) in*) + topological_sort (remove_from_all_uses binds wq) (def::defs) + else if not(Nameset.is_empty(Nameset.inter binds uses)) + then topological_sort (((binds,(remove_all binds uses)),def)::wq) defs + else + match List.stable_sort compare_dbts work_queue with (*We wait to sort until there are no 0 dependency nodes on top*) + | [] -> failwith "sort shrunk the list???" + | (((n,uses),_)::_) as wq -> + if (Nameset.cardinal uses = 0) + then topological_sort wq defs + else let _ = Printf.eprintf "Uses on failure are %s, binds are %s\n" (set_to_string uses) (set_to_string n) + in let _ = print_dependencies wq wq uses in failwith "A dependency was unmet" + +let rec add_to_partial_order ((binds,uses),def) = function + | [] -> +(* let _ = Printf.eprintf "add_to_partial_order for def with bindings %s, uses %s.\n Eol case.\n" (set_to_string binds) (set_to_string uses) in*) + [(binds,uses),def] + | (((bf,uf),deff)::defs as full_defs) -> + (*let _ = Printf.eprintf "add_to_partial_order for def with bindings %s, uses %s.\n None eol case. With first def binding %s, uses %s\n" (set_to_string binds) (set_to_string uses) (set_to_string bf) (set_to_string uf) in*) + if Nameset.is_empty uses + then ((binds,uses),def)::full_defs + else if Nameset.subset binds uf (*deff relies on def, so def must be defined first*) + then ((binds,uses),def)::((bf,(remove_all binds uf)),deff)::defs + else if Nameset.subset bf uses (*def relies at least on deff, but maybe more, push in*) + then ((bf,uf),deff)::(add_to_partial_order ((binds,(remove_all bf uses)),def) defs) + else (*These two are unrelated but new def might need to go further in*) + ((bf,uf),deff)::(add_to_partial_order ((binds,uses),def) defs) + +let rec gather_defs name already_included def_bind_triples = + match def_bind_triples with + | [] -> [],already_included,mt + | ((binds,uses),def)::def_bind_triples -> + let (defs,already_included,requires) = gather_defs name already_included def_bind_triples in + let bound_names = Nameset.elements binds in + if List.mem name already_included || List.exists (fun b -> List.mem b already_included) bound_names + then (defs,already_included,requires) + else + let uses = List.fold_right Nameset.remove already_included uses in + if Nameset.mem name binds + then (def::defs,(bound_names@already_included), Nameset.remove name (Nameset.union uses requires)) + else (defs,already_included,requires) + +let rec gather_all names already_included def_bind_triples = + let rec gather ns already_included defs reqs = match ns with + | [] -> defs,already_included,reqs + | name::ns -> + if List.mem name already_included + then gather ns already_included defs (Nameset.remove name reqs) + else + let (new_defs,already_included,new_reqs) = gather_defs name already_included def_bind_triples in + gather ns already_included (new_defs@defs) (Nameset.remove name (Nameset.union new_reqs reqs)) + in + let (defs,already_included,reqs) = gather names already_included [] mt in + if Nameset.is_empty reqs + then defs + else (gather_all (Nameset.elements reqs) already_included def_bind_triples)@defs + +let restrict_defs defs name_list = + let defsno = gather_all name_list [] (group_defs false defs) in + let rdbts = group_defs true (Defs defsno) in + (*let partial_order = + List.fold_left (fun po d -> add_to_partial_order d po) [] rdbts in + let defs = List.map snd partial_order in*) + let defs = topological_sort (List.sort compare_dbts (remove_local_or_lib_vars rdbts)) [] in + Defs defs + + +let top_sort_defs defs = + let rdbts = group_defs true defs in + let defs = topological_sort (List.stable_sort compare_dbts (remove_local_or_lib_vars rdbts)) [] in + Defs defs diff --git a/src/spec_analysis_new_tc.mli b/src/spec_analysis_new_tc.mli new file mode 100644 index 00000000..7c6f3685 --- /dev/null +++ b/src/spec_analysis_new_tc.mli @@ -0,0 +1,70 @@ +(**************************************************************************) +(* Sail *) +(* *) +(* Copyright (c) 2013-2017 *) +(* Kathyrn Gray *) +(* Shaked Flur *) +(* Stephen Kell *) +(* Gabriel Kerneis *) +(* Robert Norton-Wright *) +(* Christopher Pulte *) +(* Peter Sewell *) +(* *) +(* All rights reserved. *) +(* *) +(* This software was developed by the University of Cambridge Computer *) +(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *) +(* (REMS) project, funded by EPSRC grant EP/K008528/1. *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(**************************************************************************) + +open Ast +open Util +open Type_check_new + +(*Determines if the first typ is within the range of the the second typ, + using the constraints provided when the first typ contains variables. + It is an error for second typ to be anything other than a range type + If the first typ is a vector, then determines if the max representable + number is in the range of the second; it is an error for the first typ + to be anything other than a vector, a range, an atom, or a bit (after + suitable unwrapping of abbreviations, reg, and registers). +*) +(* val is_within_range: typ -> typ -> nexp_range list -> triple +val is_within_machine64 : typ -> nexp_range list -> triple *) + +(* free variables and dependencies *) + +(*fv_of_def consider_ty_vars consider_scatter_as_one all_defs all_defs def -> (bound_by_def, free_in_def) *) +(* val fv_of_def: bool -> bool -> ('a def) list -> 'a def -> Nameset.t * Nameset.t *) + +(*group_defs consider_scatter_as_one all_defs -> ((bound_by_def, free_in_def), def) list *) +(* val group_defs : bool -> 'a defs -> ((Nameset.t * Nameset.t) * ('a def)) list *) + +(*reodering definitions, initial functions *) +(* produce a new ordering for defs, limiting to those listed in the list, which respects dependencies *) +(* val restrict_defs : 'a defs -> string list -> 'a defs *) + +val top_sort_defs : tannot defs -> tannot defs diff --git a/src/type_check_new.mli b/src/type_check_new.mli index 971ace5c..d4fe97e7 100644 --- a/src/type_check_new.mli +++ b/src/type_check_new.mli @@ -131,6 +131,11 @@ val mk_effect : base_effect_aux list -> effect val union_effects : effect -> effect -> effect val equal_effects : effect -> effect -> bool +val nconstant : int -> nexp +val nminus : nexp -> nexp -> nexp +val nsum : nexp -> nexp -> nexp +val nvar : kid -> nexp + (* Sail builtin types. *) val int_typ : typ val nat_typ : typ @@ -152,6 +157,7 @@ type tannot = (Env.t * typ * effect) option (* Strip the type annotations from an expression. *) val strip_exp : 'a exp -> unit exp +val strip_pat : 'a pat -> unit pat (* Check an expression has some type. Returns a fully annotated version of the expression, where each subexpression is annotated diff --git a/src/util.ml b/src/util.ml index 9b76c118..d2d4eea7 100644 --- a/src/util.ml +++ b/src/util.ml @@ -203,6 +203,12 @@ let option_bind f = function | None -> None | Some(o) -> f o +let rec option_binop f x y = match x, y with + | None, None -> None + | Some x, None -> Some x + | None, Some y -> Some y + | Some x, Some y -> Some (f x y) + let changed2 f g x h y = match (g x, h y) with | (None,None) -> None diff --git a/src/util.mli b/src/util.mli index 099839bb..cfd6a19e 100644 --- a/src/util.mli +++ b/src/util.mli @@ -77,6 +77,12 @@ val option_bind : ('a -> 'b option) -> 'a option -> 'b option whereas [option_default d (Some x)] returns [x]. *) val option_default : 'a -> 'a option -> 'a +(** [option_binop f None None] returns [None], while + [option_binop f (Some x) None] and [option_binop f None (Some x)] + return [Some x], and [option_binop f (Some x) (Some y)] returns + [Some (f x y)] *) +val option_binop : ('a -> 'a -> 'a) -> 'a option -> 'a option -> 'a option + (** [option_get_exn exn None] throws the exception [exn], whereas [option_get_exn exn (Some x)] returns [x]. *) val option_get_exn : exn -> 'a option -> 'a |
