diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/pretty_print_coq.ml | 103 | ||||
| -rw-r--r-- | src/rewrites.ml | 146 |
2 files changed, 180 insertions, 69 deletions
diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index a851c5fa..0b9fe077 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -913,6 +913,11 @@ let general_typ_of_annot annot = let general_typ_of (E_aux (_,annot)) = general_typ_of_annot annot +let is_prefix s s' = + let l = String.length s in + String.length s' >= l && + String.sub s' 0 l = s + let prefix_recordtype = true let report = Reporting.err_unreachable let doc_exp, doc_let = @@ -1181,7 +1186,7 @@ let doc_exp, doc_let = if Env.is_extern f env "coq" then string (Env.get_extern f env "coq"), true, false, false else if IdSet.mem f ctxt.recursive_ids - then string "_rec_" ^^ doc_id f, false, false, true + then doc_id f, false, false, true else doc_id f, false, false, false in let (tqs,fn_ty) = Env.get_val_spec_orig f env in let arg_typs, ret_typ, eff = match fn_ty with @@ -1236,10 +1241,13 @@ let doc_exp, doc_let = then hang 2 (call ^^ break 1 ^^ parens (flow (comma ^^ break 1) (List.map2 (doc_arg false) args arg_typs))) else let main_call = call :: List.map2 (doc_arg true) args arg_typs in - let all = if is_rec then main_call @ - [parens (string "_limit - 1"); - parens (string "Acc_inv _acc (_limit_is_limit _limit_ok)")] - else main_call + let all = + if is_rec then main_call @ + [parens (string "_limit_reduces _acc")] + else match f with + | Id_aux (Id x,_) when is_prefix "#rec#" x -> + main_call @ [parens (string "Zwf_well_founded _ _")] + | _ -> main_call in hang 2 (flow (break 1) all) in (* Decide whether to unpack an existential result, pack one, or cast. @@ -2122,66 +2130,31 @@ let doc_funcl rec_opt (FCL_aux(FCL_Funcl(id, pexp), annot)) = then string "M" ^^ space ^^ parens (doc_typ ctxt ret_typ) else doc_typ ctxt ret_typ in - let intropp, idpp, accpp, measurepp, fixupspp, postpp = match rec_opt with - | Rec_aux (Rec_measure (meas_pat,meas_exp),_) -> - let check_ids (arg_pat,_) m_pat = - match arg_pat, m_pat with - | P_aux ((P_id arg_id | P_typ (_,P_aux (P_id arg_id,_))),_), - P_aux ((P_id m_id | P_typ (_,P_aux (P_id m_id,_))),_) -> - if Id.compare arg_id m_id == 0 then () else - failwith "TODO" - | _, P_aux (P_wild,_) -> () (* TODO generalise *) - | _ -> failwith "TODO" - in - let idpp = doc_id id in - let recidpp = string "_rec_" ^^ idpp in - let patnames = List.map (function - | P_aux (P_id id,_), _ -> doc_id id - | P_aux (P_typ (_,P_aux (P_id id,_)),_), _ -> doc_id id - | p,_ -> raise (Reporting.err_unreachable (pat_loc p) __POS__ - "Pattern has not been reduced to a simple binder")) - pats in - let quantnames, constrnames = typquant_names_separate ctxt tq in - let atomconstrsnames = List.map (fun _ -> underscore) atom_constrs in - let fixupspp = Util.map_filter (fun (pat,typ) -> - match pat_is_plain_binder env pat with - | Some id -> begin - match destruct_exist env (expand_range_type typ) with - | Some (_, NC_aux (NC_true,_), _) -> None - | Some ([kid], nc, - Typ_aux (Typ_app (Id_aux (Id "atom",_), - [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_)]),_)) - when Kid.compare kid kid' == 0 -> - Some (string "let " ^^ doc_id id ^^ string " := projT1 " ^^ doc_id id ^^ string " in") - | _ -> None - end - | None -> None) pats - in - let no_fixups = match fixupspp with [] -> true | _ -> false in - let measure_pp = - match pats, meas_pat with - | _, P_aux (P_tup ps,_) when List.length pats = List.length ps -> - let () = List.iter2 check_ids pats ps in - doc_exp ctxt no_fixups meas_exp - | [pat], _ -> - let () = check_ids pat meas_pat in - doc_exp ctxt no_fixups meas_exp - | _, _ -> failwith "TODO" - in - let measure_pp = match fixupspp with - [] -> measure_pp - | _ -> parens (flow (break 1) fixupspp ^/^ measure_pp) + let idpp = doc_id id in + let intropp, accpp, measurepp, fixupspp = match rec_opt with + | Rec_aux (Rec_measure _,_) -> + let fixupspp = + Util.map_filter (fun (pat,typ) -> + match pat_is_plain_binder env pat with + | Some id -> begin + match destruct_exist env (expand_range_type typ) with + | Some (_, NC_aux (NC_true,_), _) -> None + | Some ([kid], nc, + Typ_aux (Typ_app (Id_aux (Id "atom",_), + [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_)]),_)) + when Kid.compare kid kid' == 0 -> + Some (string "let " ^^ doc_id id ^^ string " := projT1 " ^^ doc_id id ^^ string " in") + | _ -> None + end + | None -> None) pats in string "Fixpoint", - recidpp, - [parens (string "_limit : Z"); - parens (string "_acc : Acc (Zwf 0) _limit")], + [parens (string "_acc : Acc (Zwf 0) _rec_limit")], [string "{struct _acc}"], - fixupspp, - hardline ^^ string "Definition " ^^ idpp ^/^ flow (break 1) (quantspp @ patspp :: constrspp @ atom_constrs) ^/^ coloneq ^/^ recidpp ^/^ flow (break 1) (quantnames @ patnames @ constrnames @ atomconstrsnames) ^/^ measure_pp ^/^ string "(Zwf_well_founded _ _)." + fixupspp | Rec_aux (r,_) -> let d = match r with Rec_nonrec -> "Definition" | _ -> "Fixpoint" in - string d, doc_id id, [], [], [], empty + string d, [], [], [] in (* Work around Coq bug 7975 about pattern binders followed by implicit arguments *) let implicitargs = @@ -2202,17 +2175,11 @@ let doc_funcl rec_opt (FCL_aux(FCL_Funcl(id, pexp), annot)) = "guarded pattern expression should have been rewritten before pretty-printing") in let bodypp = doc_fun_body ctxt exp in let bodypp = if effectful eff || not build_ex then bodypp else string "build_ex" ^^ parens bodypp in - let bodypp = match rec_opt with - | Rec_aux (Rec_measure _,_) -> - string "assert_exp' (_limit >? 0) \"termination limit reached\" >>= fun _limit_ok =>" ^/^ - separate (break 1) fixupspp ^/^ - bodypp - | _ -> bodypp - in + let bodypp = separate (break 1) fixupspp ^/^ bodypp in group (prefix 3 1 (flow (break 1) ([intropp; idpp] @ quantspp @ [patspp] @ constrspp @ [atom_constr_pp] @ accpp) ^/^ flow (break 1) (measurepp @ [colon; retpp; coloneq])) - (bodypp ^^ dot)) ^^ postpp ^^ implicitargs + (bodypp ^^ dot)) ^^ implicitargs let get_id = function | [] -> failwith "FD_function with empty list" diff --git a/src/rewrites.ml b/src/rewrites.ml index c6e2743e..33b50459 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -4760,7 +4760,150 @@ let minimise_recursive_functions (Defs defs) = | d -> d in Defs (List.map rewrite_def defs) +(* Make recursive functions with a measure use the measure as an + explicit recursion limit, enforced by an assertion. *) +let rewrite_explicit_measure (Defs defs) = + let scan_function measures = function + | FD_aux (FD_function (Rec_aux (Rec_measure (mpat,mexp),rl),topt,effopt, + FCL_aux (FCL_Funcl (id,_),_)::_),ann) -> + Bindings.add id (mpat,mexp) measures + | _ -> measures + in + let scan_def measures = function + | DEF_fundef fd -> scan_function measures fd + | _ -> measures + in + let measures = List.fold_left scan_def Bindings.empty defs in + let add_escape eff = + union_effects eff (mk_effect [BE_escape]) + in + (* NB: the Coq backend relies on recognising the #rec# prefix *) + let rec_id = function + | Id_aux (Id id,l) + | Id_aux (DeIid id,l) -> Id_aux (Id ("#rec#" ^ id),Generated l) + in + let limit = mk_id "#reclimit" in + (* Add helper function with extra argument to spec *) + let rewrite_spec (VS_aux (VS_val_spec (typsch,id,extern,flag),ann) as vs) = + match Bindings.find id measures with + | _ -> begin + match typsch with + | TypSchm_aux (TypSchm_ts (tq, + Typ_aux (Typ_fn (args,res,eff),typl)),tsl) -> + [VS_aux (VS_val_spec ( + TypSchm_aux (TypSchm_ts (tq, + Typ_aux (Typ_fn (args@[int_typ],res,add_escape eff),typl)),tsl) + ,rec_id id,extern,flag),ann); + VS_aux (VS_val_spec ( + TypSchm_aux (TypSchm_ts (tq, + Typ_aux (Typ_fn (args,res,add_escape eff),typl)),tsl) + ,id,extern,flag),ann)] + | _ -> [vs] (* TODO warn *) + end + | exception Not_found -> [vs] + in + (* Add extra argument and assertion to each funcl, and rewrite recursive calls *) + let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),ann) as fcl) = + let loc = Parse_ast.Generated (fst ann) in + let P_aux (pat,pann),guard,body,ann = destruct_pexp pexp in + let extra_pat = P_aux (P_id limit,(loc,empty_tannot)) in + let pat = match pat with + | P_tup pats -> P_tup (pats@[extra_pat]) + | p -> P_tup [P_aux (p,pann);extra_pat] + in + let assert_exp = + E_aux (E_assert + (E_aux (E_app (mk_id "gteq_int", + [E_aux (E_id limit,(loc,empty_tannot)); + E_aux (E_lit (L_aux (L_num Big_int.zero,loc)),(loc,empty_tannot))]), + (loc,empty_tannot)), + (E_aux (E_lit (L_aux (L_string "recursion limit reached",loc)),(loc,empty_tannot)))), + (loc,empty_tannot)) + in + let tick = + E_aux (E_app (mk_id "sub_int", + [E_aux (E_id limit,(loc,empty_tannot)); + E_aux (E_lit (L_aux (L_num (Big_int.of_int 1),loc)),(loc,empty_tannot))]), + (loc,empty_tannot)) + in + let open Rewriter in + let body = + fold_exp { id_exp_alg with + e_app = (fun (f,args) -> + if Id.compare f id == 0 + then E_app (rec_id id, args@[tick]) + else E_app (f, args)) + } body + in + let body = E_aux (E_block [assert_exp; body],(loc,empty_tannot)) in + FCL_aux (FCL_Funcl (rec_id id, construct_pexp (P_aux (pat,pann),guard,body,ann)),ann) + in + let rewrite_function (FD_aux (FD_function (r,t,e,fcls),ann) as fd) = + let loc = Parse_ast.Generated (fst ann) in + match fcls with + | FCL_aux (FCL_Funcl (id,_),fcl_ann)::_ -> begin + match Bindings.find id measures with + | (measure_pat, measure_exp) -> + let e = match e with + | Effect_opt_aux (Effect_opt_pure, _) -> + Effect_opt_aux (Effect_opt_effect (mk_effect [BE_escape]), loc) + | Effect_opt_aux (Effect_opt_effect eff,_) -> + Effect_opt_aux (Effect_opt_effect (add_escape eff), loc) + in + let arg_typs = match Env.get_val_spec id (env_of_annot fcl_ann) with + | _, Typ_aux (Typ_fn (args,_,_),_) -> args + | _, _ -> raise (Reporting.err_unreachable (fst ann) __POS__ + "Function doesn't have function type") + in + let measure_pats = match arg_typs, measure_pat with + | [_], _ -> [measure_pat] + | _, P_aux (P_tup ps,_) -> ps + | _, _ -> [measure_pat] + in + let mk_wrap i (P_aux (p,(l,_))) = + let id = + match p with + | P_id id + | P_typ (_,(P_aux (P_id id,_))) -> id + | P_wild + | P_typ (_,(P_aux (P_wild,_))) -> + mk_id ("_arg" ^ string_of_int i) + | _ -> raise (Reporting.err_todo l "Measure patterns can only be identifiers or wildcards") + in + P_aux (P_id id,(loc,empty_tannot)), + E_aux (E_id id,(loc,empty_tannot)) + in + let wpats,wexps = List.split (Util.list_mapi mk_wrap measure_pats) in + let wpat = match wpats with + | [wpat] -> wpat + | _ -> P_aux (P_tup wpats,(loc,empty_tannot)) + in + let wbody = E_aux (E_app (rec_id id,wexps@[measure_exp]),(loc,empty_tannot)) in + let wrapper = + FCL_aux (FCL_Funcl (id, Pat_aux (Pat_exp (wpat,wbody),(loc,empty_tannot))),(loc,empty_tannot)) + in + let new_rec = + Rec_aux (Rec_measure (P_aux (P_tup (List.map (fun _ -> P_aux (P_wild,(loc,empty_tannot))) measure_pats @ [P_aux (P_id limit,(loc,empty_tannot))]),(loc,empty_tannot)), E_aux (E_id limit, (loc,empty_tannot))), loc) + in + [FD_aux (FD_function (new_rec,t,e,List.map rewrite_funcl fcls),ann); + FD_aux (FD_function (Rec_aux (Rec_nonrec,loc),t,e,[wrapper]),ann)] + | exception Not_found -> [fd] + end + | _ -> [fd] + in + let rewrite_def = function + | DEF_spec vs -> List.map (fun vs -> DEF_spec vs) (rewrite_spec vs) + | DEF_fundef fd -> List.map (fun f -> DEF_fundef f) (rewrite_function fd) + | d -> [d] + in + Defs (List.flatten (List.map rewrite_def defs)) + let recheck_defs defs = fst (Type_error.check initial_env defs) +let recheck_defs_without_effects defs = + let () = opt_no_effects := true in + let result,_ = Type_error.check initial_env defs in + let () = opt_no_effects := false in + result let remove_mapping_valspecs (Defs defs) = let allowed_def def = @@ -4888,8 +5031,9 @@ let rewrite_defs_coq = [ ("sizeof", rewrite_sizeof); ("early_return", rewrite_defs_early_return); ("make_cases_exhaustive", MakeExhaustive.rewrite); + ("rewrite_explicit_measure", rewrite_explicit_measure); + ("recheck_defs_without_effects", recheck_defs_without_effects); ("fix_val_specs", rewrite_fix_val_specs); - ("recheck_defs", recheck_defs); ("remove_blocks", rewrite_defs_remove_blocks); ("letbind_effects", rewrite_defs_letbind_effects); ("remove_e_assign", rewrite_defs_remove_e_assign); |
