diff options
Diffstat (limited to 'src/rewrites.ml')
| -rw-r--r-- | src/rewrites.ml | 156 |
1 files changed, 151 insertions, 5 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index 30318e3f..0ad4c56e 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2204,9 +2204,11 @@ let rewrite_fix_val_specs (Defs defs) = (* Repeat once to cross-propagate effects between clauses *) let (val_specs, funcls) = List.fold_left rewrite_funcl (val_specs, []) funcls in let recopt = - if List.exists is_funcl_rec funcls then - Rec_aux (Rec_rec, Parse_ast.Unknown) - else recopt + match recopt with + | Rec_aux ((Rec_rec | Rec_measure _), _) -> recopt + | _ when List.exists is_funcl_rec funcls -> + Rec_aux (Rec_rec, Parse_ast.Unknown) + | _ -> recopt in let tannotopt = match tannotopt, funcls with | Typ_annot_opt_aux (Typ_annot_opt_some (typq, typ), l), @@ -4751,7 +4753,7 @@ let minimise_recursive_functions (Defs defs) = let rewrite_function (FD_aux (FD_function (recopt,topt,effopt,funcls),ann) as fd) = match recopt with | Rec_aux (Rec_nonrec, _) -> fd - | Rec_aux (Rec_rec, l) -> + | Rec_aux ((Rec_rec | Rec_measure _), l) -> if List.exists funcl_is_rec funcls then fd else FD_aux (FD_function (Rec_aux (Rec_nonrec, Generated l),topt,effopt,funcls),ann) @@ -4761,7 +4763,150 @@ let minimise_recursive_functions (Defs defs) = | d -> d in Defs (List.map rewrite_def defs) +(* Make recursive functions with a measure use the measure as an + explicit recursion limit, enforced by an assertion. *) +let rewrite_explicit_measure (Defs defs) = + let scan_function measures = function + | FD_aux (FD_function (Rec_aux (Rec_measure (mpat,mexp),rl),topt,effopt, + FCL_aux (FCL_Funcl (id,_),_)::_),ann) -> + Bindings.add id (mpat,mexp) measures + | _ -> measures + in + let scan_def measures = function + | DEF_fundef fd -> scan_function measures fd + | _ -> measures + in + let measures = List.fold_left scan_def Bindings.empty defs in + let add_escape eff = + union_effects eff (mk_effect [BE_escape]) + in + (* NB: the Coq backend relies on recognising the #rec# prefix *) + let rec_id = function + | Id_aux (Id id,l) + | Id_aux (DeIid id,l) -> Id_aux (Id ("#rec#" ^ id),Generated l) + in + let limit = mk_id "#reclimit" in + (* Add helper function with extra argument to spec *) + let rewrite_spec (VS_aux (VS_val_spec (typsch,id,extern,flag),ann) as vs) = + match Bindings.find id measures with + | _ -> begin + match typsch with + | TypSchm_aux (TypSchm_ts (tq, + Typ_aux (Typ_fn (args,res,eff),typl)),tsl) -> + [VS_aux (VS_val_spec ( + TypSchm_aux (TypSchm_ts (tq, + Typ_aux (Typ_fn (args@[int_typ],res,add_escape eff),typl)),tsl) + ,rec_id id,extern,flag),ann); + VS_aux (VS_val_spec ( + TypSchm_aux (TypSchm_ts (tq, + Typ_aux (Typ_fn (args,res,add_escape eff),typl)),tsl) + ,id,extern,flag),ann)] + | _ -> [vs] (* TODO warn *) + end + | exception Not_found -> [vs] + in + (* Add extra argument and assertion to each funcl, and rewrite recursive calls *) + let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),ann) as fcl) = + let loc = Parse_ast.Generated (fst ann) in + let P_aux (pat,pann),guard,body,ann = destruct_pexp pexp in + let extra_pat = P_aux (P_id limit,(loc,empty_tannot)) in + let pat = match pat with + | P_tup pats -> P_tup (pats@[extra_pat]) + | p -> P_tup [P_aux (p,pann);extra_pat] + in + let assert_exp = + E_aux (E_assert + (E_aux (E_app (mk_id "gteq_int", + [E_aux (E_id limit,(loc,empty_tannot)); + E_aux (E_lit (L_aux (L_num Big_int.zero,loc)),(loc,empty_tannot))]), + (loc,empty_tannot)), + (E_aux (E_lit (L_aux (L_string "recursion limit reached",loc)),(loc,empty_tannot)))), + (loc,empty_tannot)) + in + let tick = + E_aux (E_app (mk_id "sub_int", + [E_aux (E_id limit,(loc,empty_tannot)); + E_aux (E_lit (L_aux (L_num (Big_int.of_int 1),loc)),(loc,empty_tannot))]), + (loc,empty_tannot)) + in + let open Rewriter in + let body = + fold_exp { id_exp_alg with + e_app = (fun (f,args) -> + if Id.compare f id == 0 + then E_app (rec_id id, args@[tick]) + else E_app (f, args)) + } body + in + let body = E_aux (E_block [assert_exp; body],(loc,empty_tannot)) in + FCL_aux (FCL_Funcl (rec_id id, construct_pexp (P_aux (pat,pann),guard,body,ann)),ann) + in + let rewrite_function (FD_aux (FD_function (r,t,e,fcls),ann) as fd) = + let loc = Parse_ast.Generated (fst ann) in + match fcls with + | FCL_aux (FCL_Funcl (id,_),fcl_ann)::_ -> begin + match Bindings.find id measures with + | (measure_pat, measure_exp) -> + let e = match e with + | Effect_opt_aux (Effect_opt_pure, _) -> + Effect_opt_aux (Effect_opt_effect (mk_effect [BE_escape]), loc) + | Effect_opt_aux (Effect_opt_effect eff,_) -> + Effect_opt_aux (Effect_opt_effect (add_escape eff), loc) + in + let arg_typs = match Env.get_val_spec id (env_of_annot fcl_ann) with + | _, Typ_aux (Typ_fn (args,_,_),_) -> args + | _, _ -> raise (Reporting.err_unreachable (fst ann) __POS__ + "Function doesn't have function type") + in + let measure_pats = match arg_typs, measure_pat with + | [_], _ -> [measure_pat] + | _, P_aux (P_tup ps,_) -> ps + | _, _ -> [measure_pat] + in + let mk_wrap i (P_aux (p,(l,_))) = + let id = + match p with + | P_id id + | P_typ (_,(P_aux (P_id id,_))) -> id + | P_wild + | P_typ (_,(P_aux (P_wild,_))) -> + mk_id ("_arg" ^ string_of_int i) + | _ -> raise (Reporting.err_todo l "Measure patterns can only be identifiers or wildcards") + in + P_aux (P_id id,(loc,empty_tannot)), + E_aux (E_id id,(loc,empty_tannot)) + in + let wpats,wexps = List.split (Util.list_mapi mk_wrap measure_pats) in + let wpat = match wpats with + | [wpat] -> wpat + | _ -> P_aux (P_tup wpats,(loc,empty_tannot)) + in + let wbody = E_aux (E_app (rec_id id,wexps@[measure_exp]),(loc,empty_tannot)) in + let wrapper = + FCL_aux (FCL_Funcl (id, Pat_aux (Pat_exp (wpat,wbody),(loc,empty_tannot))),(loc,empty_tannot)) + in + let new_rec = + Rec_aux (Rec_measure (P_aux (P_tup (List.map (fun _ -> P_aux (P_wild,(loc,empty_tannot))) measure_pats @ [P_aux (P_id limit,(loc,empty_tannot))]),(loc,empty_tannot)), E_aux (E_id limit, (loc,empty_tannot))), loc) + in + [FD_aux (FD_function (new_rec,t,e,List.map rewrite_funcl fcls),ann); + FD_aux (FD_function (Rec_aux (Rec_nonrec,loc),t,e,[wrapper]),ann)] + | exception Not_found -> [fd] + end + | _ -> [fd] + in + let rewrite_def = function + | DEF_spec vs -> List.map (fun vs -> DEF_spec vs) (rewrite_spec vs) + | DEF_fundef fd -> List.map (fun f -> DEF_fundef f) (rewrite_function fd) + | d -> [d] + in + Defs (List.flatten (List.map rewrite_def defs)) + let recheck_defs defs = fst (Type_error.check initial_env defs) +let recheck_defs_without_effects defs = + let () = opt_no_effects := true in + let result,_ = Type_error.check initial_env defs in + let () = opt_no_effects := false in + result let remove_mapping_valspecs (Defs defs) = let allowed_def def = @@ -4889,8 +5034,9 @@ let rewrite_defs_coq = [ ("sizeof", rewrite_sizeof); ("early_return", rewrite_defs_early_return); ("make_cases_exhaustive", MakeExhaustive.rewrite); + ("rewrite_explicit_measure", rewrite_explicit_measure); + ("recheck_defs_without_effects", recheck_defs_without_effects); ("fix_val_specs", rewrite_fix_val_specs); - ("recheck_defs", recheck_defs); ("remove_blocks", rewrite_defs_remove_blocks); ("letbind_effects", rewrite_defs_letbind_effects); ("remove_e_assign", rewrite_defs_remove_e_assign); |
