diff options
Diffstat (limited to 'src/rewrites.ml')
| -rw-r--r-- | src/rewrites.ml | 146 |
1 files changed, 145 insertions, 1 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index c6e2743e..33b50459 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -4760,7 +4760,150 @@ let minimise_recursive_functions (Defs defs) = | d -> d in Defs (List.map rewrite_def defs) +(* Make recursive functions with a measure use the measure as an + explicit recursion limit, enforced by an assertion. *) +let rewrite_explicit_measure (Defs defs) = + let scan_function measures = function + | FD_aux (FD_function (Rec_aux (Rec_measure (mpat,mexp),rl),topt,effopt, + FCL_aux (FCL_Funcl (id,_),_)::_),ann) -> + Bindings.add id (mpat,mexp) measures + | _ -> measures + in + let scan_def measures = function + | DEF_fundef fd -> scan_function measures fd + | _ -> measures + in + let measures = List.fold_left scan_def Bindings.empty defs in + let add_escape eff = + union_effects eff (mk_effect [BE_escape]) + in + (* NB: the Coq backend relies on recognising the #rec# prefix *) + let rec_id = function + | Id_aux (Id id,l) + | Id_aux (DeIid id,l) -> Id_aux (Id ("#rec#" ^ id),Generated l) + in + let limit = mk_id "#reclimit" in + (* Add helper function with extra argument to spec *) + let rewrite_spec (VS_aux (VS_val_spec (typsch,id,extern,flag),ann) as vs) = + match Bindings.find id measures with + | _ -> begin + match typsch with + | TypSchm_aux (TypSchm_ts (tq, + Typ_aux (Typ_fn (args,res,eff),typl)),tsl) -> + [VS_aux (VS_val_spec ( + TypSchm_aux (TypSchm_ts (tq, + Typ_aux (Typ_fn (args@[int_typ],res,add_escape eff),typl)),tsl) + ,rec_id id,extern,flag),ann); + VS_aux (VS_val_spec ( + TypSchm_aux (TypSchm_ts (tq, + Typ_aux (Typ_fn (args,res,add_escape eff),typl)),tsl) + ,id,extern,flag),ann)] + | _ -> [vs] (* TODO warn *) + end + | exception Not_found -> [vs] + in + (* Add extra argument and assertion to each funcl, and rewrite recursive calls *) + let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),ann) as fcl) = + let loc = Parse_ast.Generated (fst ann) in + let P_aux (pat,pann),guard,body,ann = destruct_pexp pexp in + let extra_pat = P_aux (P_id limit,(loc,empty_tannot)) in + let pat = match pat with + | P_tup pats -> P_tup (pats@[extra_pat]) + | p -> P_tup [P_aux (p,pann);extra_pat] + in + let assert_exp = + E_aux (E_assert + (E_aux (E_app (mk_id "gteq_int", + [E_aux (E_id limit,(loc,empty_tannot)); + E_aux (E_lit (L_aux (L_num Big_int.zero,loc)),(loc,empty_tannot))]), + (loc,empty_tannot)), + (E_aux (E_lit (L_aux (L_string "recursion limit reached",loc)),(loc,empty_tannot)))), + (loc,empty_tannot)) + in + let tick = + E_aux (E_app (mk_id "sub_int", + [E_aux (E_id limit,(loc,empty_tannot)); + E_aux (E_lit (L_aux (L_num (Big_int.of_int 1),loc)),(loc,empty_tannot))]), + (loc,empty_tannot)) + in + let open Rewriter in + let body = + fold_exp { id_exp_alg with + e_app = (fun (f,args) -> + if Id.compare f id == 0 + then E_app (rec_id id, args@[tick]) + else E_app (f, args)) + } body + in + let body = E_aux (E_block [assert_exp; body],(loc,empty_tannot)) in + FCL_aux (FCL_Funcl (rec_id id, construct_pexp (P_aux (pat,pann),guard,body,ann)),ann) + in + let rewrite_function (FD_aux (FD_function (r,t,e,fcls),ann) as fd) = + let loc = Parse_ast.Generated (fst ann) in + match fcls with + | FCL_aux (FCL_Funcl (id,_),fcl_ann)::_ -> begin + match Bindings.find id measures with + | (measure_pat, measure_exp) -> + let e = match e with + | Effect_opt_aux (Effect_opt_pure, _) -> + Effect_opt_aux (Effect_opt_effect (mk_effect [BE_escape]), loc) + | Effect_opt_aux (Effect_opt_effect eff,_) -> + Effect_opt_aux (Effect_opt_effect (add_escape eff), loc) + in + let arg_typs = match Env.get_val_spec id (env_of_annot fcl_ann) with + | _, Typ_aux (Typ_fn (args,_,_),_) -> args + | _, _ -> raise (Reporting.err_unreachable (fst ann) __POS__ + "Function doesn't have function type") + in + let measure_pats = match arg_typs, measure_pat with + | [_], _ -> [measure_pat] + | _, P_aux (P_tup ps,_) -> ps + | _, _ -> [measure_pat] + in + let mk_wrap i (P_aux (p,(l,_))) = + let id = + match p with + | P_id id + | P_typ (_,(P_aux (P_id id,_))) -> id + | P_wild + | P_typ (_,(P_aux (P_wild,_))) -> + mk_id ("_arg" ^ string_of_int i) + | _ -> raise (Reporting.err_todo l "Measure patterns can only be identifiers or wildcards") + in + P_aux (P_id id,(loc,empty_tannot)), + E_aux (E_id id,(loc,empty_tannot)) + in + let wpats,wexps = List.split (Util.list_mapi mk_wrap measure_pats) in + let wpat = match wpats with + | [wpat] -> wpat + | _ -> P_aux (P_tup wpats,(loc,empty_tannot)) + in + let wbody = E_aux (E_app (rec_id id,wexps@[measure_exp]),(loc,empty_tannot)) in + let wrapper = + FCL_aux (FCL_Funcl (id, Pat_aux (Pat_exp (wpat,wbody),(loc,empty_tannot))),(loc,empty_tannot)) + in + let new_rec = + Rec_aux (Rec_measure (P_aux (P_tup (List.map (fun _ -> P_aux (P_wild,(loc,empty_tannot))) measure_pats @ [P_aux (P_id limit,(loc,empty_tannot))]),(loc,empty_tannot)), E_aux (E_id limit, (loc,empty_tannot))), loc) + in + [FD_aux (FD_function (new_rec,t,e,List.map rewrite_funcl fcls),ann); + FD_aux (FD_function (Rec_aux (Rec_nonrec,loc),t,e,[wrapper]),ann)] + | exception Not_found -> [fd] + end + | _ -> [fd] + in + let rewrite_def = function + | DEF_spec vs -> List.map (fun vs -> DEF_spec vs) (rewrite_spec vs) + | DEF_fundef fd -> List.map (fun f -> DEF_fundef f) (rewrite_function fd) + | d -> [d] + in + Defs (List.flatten (List.map rewrite_def defs)) + let recheck_defs defs = fst (Type_error.check initial_env defs) +let recheck_defs_without_effects defs = + let () = opt_no_effects := true in + let result,_ = Type_error.check initial_env defs in + let () = opt_no_effects := false in + result let remove_mapping_valspecs (Defs defs) = let allowed_def def = @@ -4888,8 +5031,9 @@ let rewrite_defs_coq = [ ("sizeof", rewrite_sizeof); ("early_return", rewrite_defs_early_return); ("make_cases_exhaustive", MakeExhaustive.rewrite); + ("rewrite_explicit_measure", rewrite_explicit_measure); + ("recheck_defs_without_effects", recheck_defs_without_effects); ("fix_val_specs", rewrite_fix_val_specs); - ("recheck_defs", recheck_defs); ("remove_blocks", rewrite_defs_remove_blocks); ("letbind_effects", rewrite_defs_letbind_effects); ("remove_e_assign", rewrite_defs_remove_e_assign); |
