summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/pretty_print_coq.ml103
-rw-r--r--src/rewrites.ml146
2 files changed, 180 insertions, 69 deletions
diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml
index a851c5fa..0b9fe077 100644
--- a/src/pretty_print_coq.ml
+++ b/src/pretty_print_coq.ml
@@ -913,6 +913,11 @@ let general_typ_of_annot annot =
let general_typ_of (E_aux (_,annot)) = general_typ_of_annot annot
+let is_prefix s s' =
+ let l = String.length s in
+ String.length s' >= l &&
+ String.sub s' 0 l = s
+
let prefix_recordtype = true
let report = Reporting.err_unreachable
let doc_exp, doc_let =
@@ -1181,7 +1186,7 @@ let doc_exp, doc_let =
if Env.is_extern f env "coq"
then string (Env.get_extern f env "coq"), true, false, false
else if IdSet.mem f ctxt.recursive_ids
- then string "_rec_" ^^ doc_id f, false, false, true
+ then doc_id f, false, false, true
else doc_id f, false, false, false in
let (tqs,fn_ty) = Env.get_val_spec_orig f env in
let arg_typs, ret_typ, eff = match fn_ty with
@@ -1236,10 +1241,13 @@ let doc_exp, doc_let =
then hang 2 (call ^^ break 1 ^^ parens (flow (comma ^^ break 1) (List.map2 (doc_arg false) args arg_typs)))
else
let main_call = call :: List.map2 (doc_arg true) args arg_typs in
- let all = if is_rec then main_call @
- [parens (string "_limit - 1");
- parens (string "Acc_inv _acc (_limit_is_limit _limit_ok)")]
- else main_call
+ let all =
+ if is_rec then main_call @
+ [parens (string "_limit_reduces _acc")]
+ else match f with
+ | Id_aux (Id x,_) when is_prefix "#rec#" x ->
+ main_call @ [parens (string "Zwf_well_founded _ _")]
+ | _ -> main_call
in hang 2 (flow (break 1) all) in
(* Decide whether to unpack an existential result, pack one, or cast.
@@ -2122,66 +2130,31 @@ let doc_funcl rec_opt (FCL_aux(FCL_Funcl(id, pexp), annot)) =
then string "M" ^^ space ^^ parens (doc_typ ctxt ret_typ)
else doc_typ ctxt ret_typ
in
- let intropp, idpp, accpp, measurepp, fixupspp, postpp = match rec_opt with
- | Rec_aux (Rec_measure (meas_pat,meas_exp),_) ->
- let check_ids (arg_pat,_) m_pat =
- match arg_pat, m_pat with
- | P_aux ((P_id arg_id | P_typ (_,P_aux (P_id arg_id,_))),_),
- P_aux ((P_id m_id | P_typ (_,P_aux (P_id m_id,_))),_) ->
- if Id.compare arg_id m_id == 0 then () else
- failwith "TODO"
- | _, P_aux (P_wild,_) -> () (* TODO generalise *)
- | _ -> failwith "TODO"
- in
- let idpp = doc_id id in
- let recidpp = string "_rec_" ^^ idpp in
- let patnames = List.map (function
- | P_aux (P_id id,_), _ -> doc_id id
- | P_aux (P_typ (_,P_aux (P_id id,_)),_), _ -> doc_id id
- | p,_ -> raise (Reporting.err_unreachable (pat_loc p) __POS__
- "Pattern has not been reduced to a simple binder"))
- pats in
- let quantnames, constrnames = typquant_names_separate ctxt tq in
- let atomconstrsnames = List.map (fun _ -> underscore) atom_constrs in
- let fixupspp = Util.map_filter (fun (pat,typ) ->
- match pat_is_plain_binder env pat with
- | Some id -> begin
- match destruct_exist env (expand_range_type typ) with
- | Some (_, NC_aux (NC_true,_), _) -> None
- | Some ([kid], nc,
- Typ_aux (Typ_app (Id_aux (Id "atom",_),
- [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_)]),_))
- when Kid.compare kid kid' == 0 ->
- Some (string "let " ^^ doc_id id ^^ string " := projT1 " ^^ doc_id id ^^ string " in")
- | _ -> None
- end
- | None -> None) pats
- in
- let no_fixups = match fixupspp with [] -> true | _ -> false in
- let measure_pp =
- match pats, meas_pat with
- | _, P_aux (P_tup ps,_) when List.length pats = List.length ps ->
- let () = List.iter2 check_ids pats ps in
- doc_exp ctxt no_fixups meas_exp
- | [pat], _ ->
- let () = check_ids pat meas_pat in
- doc_exp ctxt no_fixups meas_exp
- | _, _ -> failwith "TODO"
- in
- let measure_pp = match fixupspp with
- [] -> measure_pp
- | _ -> parens (flow (break 1) fixupspp ^/^ measure_pp)
+ let idpp = doc_id id in
+ let intropp, accpp, measurepp, fixupspp = match rec_opt with
+ | Rec_aux (Rec_measure _,_) ->
+ let fixupspp =
+ Util.map_filter (fun (pat,typ) ->
+ match pat_is_plain_binder env pat with
+ | Some id -> begin
+ match destruct_exist env (expand_range_type typ) with
+ | Some (_, NC_aux (NC_true,_), _) -> None
+ | Some ([kid], nc,
+ Typ_aux (Typ_app (Id_aux (Id "atom",_),
+ [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_)]),_))
+ when Kid.compare kid kid' == 0 ->
+ Some (string "let " ^^ doc_id id ^^ string " := projT1 " ^^ doc_id id ^^ string " in")
+ | _ -> None
+ end
+ | None -> None) pats
in
string "Fixpoint",
- recidpp,
- [parens (string "_limit : Z");
- parens (string "_acc : Acc (Zwf 0) _limit")],
+ [parens (string "_acc : Acc (Zwf 0) _rec_limit")],
[string "{struct _acc}"],
- fixupspp,
- hardline ^^ string "Definition " ^^ idpp ^/^ flow (break 1) (quantspp @ patspp :: constrspp @ atom_constrs) ^/^ coloneq ^/^ recidpp ^/^ flow (break 1) (quantnames @ patnames @ constrnames @ atomconstrsnames) ^/^ measure_pp ^/^ string "(Zwf_well_founded _ _)."
+ fixupspp
| Rec_aux (r,_) ->
let d = match r with Rec_nonrec -> "Definition" | _ -> "Fixpoint" in
- string d, doc_id id, [], [], [], empty
+ string d, [], [], []
in
(* Work around Coq bug 7975 about pattern binders followed by implicit arguments *)
let implicitargs =
@@ -2202,17 +2175,11 @@ let doc_funcl rec_opt (FCL_aux(FCL_Funcl(id, pexp), annot)) =
"guarded pattern expression should have been rewritten before pretty-printing") in
let bodypp = doc_fun_body ctxt exp in
let bodypp = if effectful eff || not build_ex then bodypp else string "build_ex" ^^ parens bodypp in
- let bodypp = match rec_opt with
- | Rec_aux (Rec_measure _,_) ->
- string "assert_exp' (_limit >? 0) \"termination limit reached\" >>= fun _limit_ok =>" ^/^
- separate (break 1) fixupspp ^/^
- bodypp
- | _ -> bodypp
- in
+ let bodypp = separate (break 1) fixupspp ^/^ bodypp in
group (prefix 3 1
(flow (break 1) ([intropp; idpp] @ quantspp @ [patspp] @ constrspp @ [atom_constr_pp] @ accpp) ^/^
flow (break 1) (measurepp @ [colon; retpp; coloneq]))
- (bodypp ^^ dot)) ^^ postpp ^^ implicitargs
+ (bodypp ^^ dot)) ^^ implicitargs
let get_id = function
| [] -> failwith "FD_function with empty list"
diff --git a/src/rewrites.ml b/src/rewrites.ml
index c6e2743e..33b50459 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -4760,7 +4760,150 @@ let minimise_recursive_functions (Defs defs) =
| d -> d
in Defs (List.map rewrite_def defs)
+(* Make recursive functions with a measure use the measure as an
+ explicit recursion limit, enforced by an assertion. *)
+let rewrite_explicit_measure (Defs defs) =
+ let scan_function measures = function
+ | FD_aux (FD_function (Rec_aux (Rec_measure (mpat,mexp),rl),topt,effopt,
+ FCL_aux (FCL_Funcl (id,_),_)::_),ann) ->
+ Bindings.add id (mpat,mexp) measures
+ | _ -> measures
+ in
+ let scan_def measures = function
+ | DEF_fundef fd -> scan_function measures fd
+ | _ -> measures
+ in
+ let measures = List.fold_left scan_def Bindings.empty defs in
+ let add_escape eff =
+ union_effects eff (mk_effect [BE_escape])
+ in
+ (* NB: the Coq backend relies on recognising the #rec# prefix *)
+ let rec_id = function
+ | Id_aux (Id id,l)
+ | Id_aux (DeIid id,l) -> Id_aux (Id ("#rec#" ^ id),Generated l)
+ in
+ let limit = mk_id "#reclimit" in
+ (* Add helper function with extra argument to spec *)
+ let rewrite_spec (VS_aux (VS_val_spec (typsch,id,extern,flag),ann) as vs) =
+ match Bindings.find id measures with
+ | _ -> begin
+ match typsch with
+ | TypSchm_aux (TypSchm_ts (tq,
+ Typ_aux (Typ_fn (args,res,eff),typl)),tsl) ->
+ [VS_aux (VS_val_spec (
+ TypSchm_aux (TypSchm_ts (tq,
+ Typ_aux (Typ_fn (args@[int_typ],res,add_escape eff),typl)),tsl)
+ ,rec_id id,extern,flag),ann);
+ VS_aux (VS_val_spec (
+ TypSchm_aux (TypSchm_ts (tq,
+ Typ_aux (Typ_fn (args,res,add_escape eff),typl)),tsl)
+ ,id,extern,flag),ann)]
+ | _ -> [vs] (* TODO warn *)
+ end
+ | exception Not_found -> [vs]
+ in
+ (* Add extra argument and assertion to each funcl, and rewrite recursive calls *)
+ let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),ann) as fcl) =
+ let loc = Parse_ast.Generated (fst ann) in
+ let P_aux (pat,pann),guard,body,ann = destruct_pexp pexp in
+ let extra_pat = P_aux (P_id limit,(loc,empty_tannot)) in
+ let pat = match pat with
+ | P_tup pats -> P_tup (pats@[extra_pat])
+ | p -> P_tup [P_aux (p,pann);extra_pat]
+ in
+ let assert_exp =
+ E_aux (E_assert
+ (E_aux (E_app (mk_id "gteq_int",
+ [E_aux (E_id limit,(loc,empty_tannot));
+ E_aux (E_lit (L_aux (L_num Big_int.zero,loc)),(loc,empty_tannot))]),
+ (loc,empty_tannot)),
+ (E_aux (E_lit (L_aux (L_string "recursion limit reached",loc)),(loc,empty_tannot)))),
+ (loc,empty_tannot))
+ in
+ let tick =
+ E_aux (E_app (mk_id "sub_int",
+ [E_aux (E_id limit,(loc,empty_tannot));
+ E_aux (E_lit (L_aux (L_num (Big_int.of_int 1),loc)),(loc,empty_tannot))]),
+ (loc,empty_tannot))
+ in
+ let open Rewriter in
+ let body =
+ fold_exp { id_exp_alg with
+ e_app = (fun (f,args) ->
+ if Id.compare f id == 0
+ then E_app (rec_id id, args@[tick])
+ else E_app (f, args))
+ } body
+ in
+ let body = E_aux (E_block [assert_exp; body],(loc,empty_tannot)) in
+ FCL_aux (FCL_Funcl (rec_id id, construct_pexp (P_aux (pat,pann),guard,body,ann)),ann)
+ in
+ let rewrite_function (FD_aux (FD_function (r,t,e,fcls),ann) as fd) =
+ let loc = Parse_ast.Generated (fst ann) in
+ match fcls with
+ | FCL_aux (FCL_Funcl (id,_),fcl_ann)::_ -> begin
+ match Bindings.find id measures with
+ | (measure_pat, measure_exp) ->
+ let e = match e with
+ | Effect_opt_aux (Effect_opt_pure, _) ->
+ Effect_opt_aux (Effect_opt_effect (mk_effect [BE_escape]), loc)
+ | Effect_opt_aux (Effect_opt_effect eff,_) ->
+ Effect_opt_aux (Effect_opt_effect (add_escape eff), loc)
+ in
+ let arg_typs = match Env.get_val_spec id (env_of_annot fcl_ann) with
+ | _, Typ_aux (Typ_fn (args,_,_),_) -> args
+ | _, _ -> raise (Reporting.err_unreachable (fst ann) __POS__
+ "Function doesn't have function type")
+ in
+ let measure_pats = match arg_typs, measure_pat with
+ | [_], _ -> [measure_pat]
+ | _, P_aux (P_tup ps,_) -> ps
+ | _, _ -> [measure_pat]
+ in
+ let mk_wrap i (P_aux (p,(l,_))) =
+ let id =
+ match p with
+ | P_id id
+ | P_typ (_,(P_aux (P_id id,_))) -> id
+ | P_wild
+ | P_typ (_,(P_aux (P_wild,_))) ->
+ mk_id ("_arg" ^ string_of_int i)
+ | _ -> raise (Reporting.err_todo l "Measure patterns can only be identifiers or wildcards")
+ in
+ P_aux (P_id id,(loc,empty_tannot)),
+ E_aux (E_id id,(loc,empty_tannot))
+ in
+ let wpats,wexps = List.split (Util.list_mapi mk_wrap measure_pats) in
+ let wpat = match wpats with
+ | [wpat] -> wpat
+ | _ -> P_aux (P_tup wpats,(loc,empty_tannot))
+ in
+ let wbody = E_aux (E_app (rec_id id,wexps@[measure_exp]),(loc,empty_tannot)) in
+ let wrapper =
+ FCL_aux (FCL_Funcl (id, Pat_aux (Pat_exp (wpat,wbody),(loc,empty_tannot))),(loc,empty_tannot))
+ in
+ let new_rec =
+ Rec_aux (Rec_measure (P_aux (P_tup (List.map (fun _ -> P_aux (P_wild,(loc,empty_tannot))) measure_pats @ [P_aux (P_id limit,(loc,empty_tannot))]),(loc,empty_tannot)), E_aux (E_id limit, (loc,empty_tannot))), loc)
+ in
+ [FD_aux (FD_function (new_rec,t,e,List.map rewrite_funcl fcls),ann);
+ FD_aux (FD_function (Rec_aux (Rec_nonrec,loc),t,e,[wrapper]),ann)]
+ | exception Not_found -> [fd]
+ end
+ | _ -> [fd]
+ in
+ let rewrite_def = function
+ | DEF_spec vs -> List.map (fun vs -> DEF_spec vs) (rewrite_spec vs)
+ | DEF_fundef fd -> List.map (fun f -> DEF_fundef f) (rewrite_function fd)
+ | d -> [d]
+ in
+ Defs (List.flatten (List.map rewrite_def defs))
+
let recheck_defs defs = fst (Type_error.check initial_env defs)
+let recheck_defs_without_effects defs =
+ let () = opt_no_effects := true in
+ let result,_ = Type_error.check initial_env defs in
+ let () = opt_no_effects := false in
+ result
let remove_mapping_valspecs (Defs defs) =
let allowed_def def =
@@ -4888,8 +5031,9 @@ let rewrite_defs_coq = [
("sizeof", rewrite_sizeof);
("early_return", rewrite_defs_early_return);
("make_cases_exhaustive", MakeExhaustive.rewrite);
+ ("rewrite_explicit_measure", rewrite_explicit_measure);
+ ("recheck_defs_without_effects", recheck_defs_without_effects);
("fix_val_specs", rewrite_fix_val_specs);
- ("recheck_defs", recheck_defs);
("remove_blocks", rewrite_defs_remove_blocks);
("letbind_effects", rewrite_defs_letbind_effects);
("remove_e_assign", rewrite_defs_remove_e_assign);