summaryrefslogtreecommitdiff
path: root/src/rewrites.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewrites.ml')
-rw-r--r--src/rewrites.ml146
1 files changed, 145 insertions, 1 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index c6e2743e..33b50459 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -4760,7 +4760,150 @@ let minimise_recursive_functions (Defs defs) =
| d -> d
in Defs (List.map rewrite_def defs)
+(* Make recursive functions with a measure use the measure as an
+ explicit recursion limit, enforced by an assertion. *)
+let rewrite_explicit_measure (Defs defs) =
+ let scan_function measures = function
+ | FD_aux (FD_function (Rec_aux (Rec_measure (mpat,mexp),rl),topt,effopt,
+ FCL_aux (FCL_Funcl (id,_),_)::_),ann) ->
+ Bindings.add id (mpat,mexp) measures
+ | _ -> measures
+ in
+ let scan_def measures = function
+ | DEF_fundef fd -> scan_function measures fd
+ | _ -> measures
+ in
+ let measures = List.fold_left scan_def Bindings.empty defs in
+ let add_escape eff =
+ union_effects eff (mk_effect [BE_escape])
+ in
+ (* NB: the Coq backend relies on recognising the #rec# prefix *)
+ let rec_id = function
+ | Id_aux (Id id,l)
+ | Id_aux (DeIid id,l) -> Id_aux (Id ("#rec#" ^ id),Generated l)
+ in
+ let limit = mk_id "#reclimit" in
+ (* Add helper function with extra argument to spec *)
+ let rewrite_spec (VS_aux (VS_val_spec (typsch,id,extern,flag),ann) as vs) =
+ match Bindings.find id measures with
+ | _ -> begin
+ match typsch with
+ | TypSchm_aux (TypSchm_ts (tq,
+ Typ_aux (Typ_fn (args,res,eff),typl)),tsl) ->
+ [VS_aux (VS_val_spec (
+ TypSchm_aux (TypSchm_ts (tq,
+ Typ_aux (Typ_fn (args@[int_typ],res,add_escape eff),typl)),tsl)
+ ,rec_id id,extern,flag),ann);
+ VS_aux (VS_val_spec (
+ TypSchm_aux (TypSchm_ts (tq,
+ Typ_aux (Typ_fn (args,res,add_escape eff),typl)),tsl)
+ ,id,extern,flag),ann)]
+ | _ -> [vs] (* TODO warn *)
+ end
+ | exception Not_found -> [vs]
+ in
+ (* Add extra argument and assertion to each funcl, and rewrite recursive calls *)
+ let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),ann) as fcl) =
+ let loc = Parse_ast.Generated (fst ann) in
+ let P_aux (pat,pann),guard,body,ann = destruct_pexp pexp in
+ let extra_pat = P_aux (P_id limit,(loc,empty_tannot)) in
+ let pat = match pat with
+ | P_tup pats -> P_tup (pats@[extra_pat])
+ | p -> P_tup [P_aux (p,pann);extra_pat]
+ in
+ let assert_exp =
+ E_aux (E_assert
+ (E_aux (E_app (mk_id "gteq_int",
+ [E_aux (E_id limit,(loc,empty_tannot));
+ E_aux (E_lit (L_aux (L_num Big_int.zero,loc)),(loc,empty_tannot))]),
+ (loc,empty_tannot)),
+ (E_aux (E_lit (L_aux (L_string "recursion limit reached",loc)),(loc,empty_tannot)))),
+ (loc,empty_tannot))
+ in
+ let tick =
+ E_aux (E_app (mk_id "sub_int",
+ [E_aux (E_id limit,(loc,empty_tannot));
+ E_aux (E_lit (L_aux (L_num (Big_int.of_int 1),loc)),(loc,empty_tannot))]),
+ (loc,empty_tannot))
+ in
+ let open Rewriter in
+ let body =
+ fold_exp { id_exp_alg with
+ e_app = (fun (f,args) ->
+ if Id.compare f id == 0
+ then E_app (rec_id id, args@[tick])
+ else E_app (f, args))
+ } body
+ in
+ let body = E_aux (E_block [assert_exp; body],(loc,empty_tannot)) in
+ FCL_aux (FCL_Funcl (rec_id id, construct_pexp (P_aux (pat,pann),guard,body,ann)),ann)
+ in
+ let rewrite_function (FD_aux (FD_function (r,t,e,fcls),ann) as fd) =
+ let loc = Parse_ast.Generated (fst ann) in
+ match fcls with
+ | FCL_aux (FCL_Funcl (id,_),fcl_ann)::_ -> begin
+ match Bindings.find id measures with
+ | (measure_pat, measure_exp) ->
+ let e = match e with
+ | Effect_opt_aux (Effect_opt_pure, _) ->
+ Effect_opt_aux (Effect_opt_effect (mk_effect [BE_escape]), loc)
+ | Effect_opt_aux (Effect_opt_effect eff,_) ->
+ Effect_opt_aux (Effect_opt_effect (add_escape eff), loc)
+ in
+ let arg_typs = match Env.get_val_spec id (env_of_annot fcl_ann) with
+ | _, Typ_aux (Typ_fn (args,_,_),_) -> args
+ | _, _ -> raise (Reporting.err_unreachable (fst ann) __POS__
+ "Function doesn't have function type")
+ in
+ let measure_pats = match arg_typs, measure_pat with
+ | [_], _ -> [measure_pat]
+ | _, P_aux (P_tup ps,_) -> ps
+ | _, _ -> [measure_pat]
+ in
+ let mk_wrap i (P_aux (p,(l,_))) =
+ let id =
+ match p with
+ | P_id id
+ | P_typ (_,(P_aux (P_id id,_))) -> id
+ | P_wild
+ | P_typ (_,(P_aux (P_wild,_))) ->
+ mk_id ("_arg" ^ string_of_int i)
+ | _ -> raise (Reporting.err_todo l "Measure patterns can only be identifiers or wildcards")
+ in
+ P_aux (P_id id,(loc,empty_tannot)),
+ E_aux (E_id id,(loc,empty_tannot))
+ in
+ let wpats,wexps = List.split (Util.list_mapi mk_wrap measure_pats) in
+ let wpat = match wpats with
+ | [wpat] -> wpat
+ | _ -> P_aux (P_tup wpats,(loc,empty_tannot))
+ in
+ let wbody = E_aux (E_app (rec_id id,wexps@[measure_exp]),(loc,empty_tannot)) in
+ let wrapper =
+ FCL_aux (FCL_Funcl (id, Pat_aux (Pat_exp (wpat,wbody),(loc,empty_tannot))),(loc,empty_tannot))
+ in
+ let new_rec =
+ Rec_aux (Rec_measure (P_aux (P_tup (List.map (fun _ -> P_aux (P_wild,(loc,empty_tannot))) measure_pats @ [P_aux (P_id limit,(loc,empty_tannot))]),(loc,empty_tannot)), E_aux (E_id limit, (loc,empty_tannot))), loc)
+ in
+ [FD_aux (FD_function (new_rec,t,e,List.map rewrite_funcl fcls),ann);
+ FD_aux (FD_function (Rec_aux (Rec_nonrec,loc),t,e,[wrapper]),ann)]
+ | exception Not_found -> [fd]
+ end
+ | _ -> [fd]
+ in
+ let rewrite_def = function
+ | DEF_spec vs -> List.map (fun vs -> DEF_spec vs) (rewrite_spec vs)
+ | DEF_fundef fd -> List.map (fun f -> DEF_fundef f) (rewrite_function fd)
+ | d -> [d]
+ in
+ Defs (List.flatten (List.map rewrite_def defs))
+
let recheck_defs defs = fst (Type_error.check initial_env defs)
+let recheck_defs_without_effects defs =
+ let () = opt_no_effects := true in
+ let result,_ = Type_error.check initial_env defs in
+ let () = opt_no_effects := false in
+ result
let remove_mapping_valspecs (Defs defs) =
let allowed_def def =
@@ -4888,8 +5031,9 @@ let rewrite_defs_coq = [
("sizeof", rewrite_sizeof);
("early_return", rewrite_defs_early_return);
("make_cases_exhaustive", MakeExhaustive.rewrite);
+ ("rewrite_explicit_measure", rewrite_explicit_measure);
+ ("recheck_defs_without_effects", recheck_defs_without_effects);
("fix_val_specs", rewrite_fix_val_specs);
- ("recheck_defs", recheck_defs);
("remove_blocks", rewrite_defs_remove_blocks);
("letbind_effects", rewrite_defs_letbind_effects);
("remove_e_assign", rewrite_defs_remove_e_assign);