summaryrefslogtreecommitdiff
path: root/src/rewrites.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewrites.ml')
-rw-r--r--src/rewrites.ml28
1 files changed, 18 insertions, 10 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 2a5799d3..11b1d469 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -4476,8 +4476,9 @@ let rewrite_explicit_measure env (Defs defs) =
Bindings.add id (mpat,mexp) measures
| _ -> measures
in
- let scan_def measures = function
+ let rec scan_def measures = function
| DEF_fundef fd -> scan_function measures fd
+ | DEF_internal_mutrec fds -> List.fold_left scan_function measures fds
| _ -> measures
in
let measures = List.fold_left scan_def Bindings.empty defs in
@@ -4510,7 +4511,7 @@ let rewrite_explicit_measure env (Defs defs) =
| exception Not_found -> [vs]
in
(* Add extra argument and assertion to each funcl, and rewrite recursive calls *)
- let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),fcl_ann) as fcl) =
+ let rewrite_funcl recset (FCL_aux (FCL_Funcl (id,pexp),fcl_ann) as fcl) =
let loc = Parse_ast.Generated (fst fcl_ann) in
let P_aux (pat,pann),guard,body,ann = destruct_pexp pexp in
let extra_pat = P_aux (P_id limit,(loc,empty_tannot)) in
@@ -4537,15 +4538,15 @@ let rewrite_explicit_measure env (Defs defs) =
let body =
fold_exp { id_exp_alg with
e_app = (fun (f,args) ->
- if Id.compare f id == 0
- then E_app (rec_id id, args@[tick])
+ if IdSet.mem f recset
+ then E_app (rec_id f, args@[tick])
else E_app (f, args))
} body
in
let body = E_aux (E_block [assert_exp; body],(loc,empty_tannot)) in
FCL_aux (FCL_Funcl (rec_id id, construct_pexp (P_aux (pat,pann),guard,body,ann)),fcl_ann)
in
- let rewrite_function (FD_aux (FD_function (r,t,e,fcls),ann) as fd) =
+ let rewrite_function recset (FD_aux (FD_function (r,t,e,fcls),ann) as fd) =
let loc = Parse_ast.Generated (fst ann) in
match fcls with
| FCL_aux (FCL_Funcl (id,_),fcl_ann)::_ -> begin
@@ -4593,15 +4594,22 @@ let rewrite_explicit_measure env (Defs defs) =
let new_rec =
Rec_aux (Rec_measure (P_aux (P_tup (List.map (fun _ -> P_aux (P_wild,(loc,empty_tannot))) measure_pats @ [P_aux (P_id limit,(loc,empty_tannot))]),(loc,empty_tannot)), E_aux (E_id limit, (loc,empty_tannot))), loc)
in
- [FD_aux (FD_function (new_rec,t,e,List.map rewrite_funcl fcls),ann);
- FD_aux (FD_function (Rec_aux (Rec_nonrec,loc),t,e,[wrapper]),ann)]
- | exception Not_found -> [fd]
+ FD_aux (FD_function (new_rec,t,e,List.map (rewrite_funcl recset) fcls),ann),
+ [FD_aux (FD_function (Rec_aux (Rec_nonrec,loc),t,e,[wrapper]),ann)]
+ | exception Not_found -> fd,[]
end
- | _ -> [fd]
+ | _ -> fd,[]
in
let rewrite_def = function
| DEF_spec vs -> List.map (fun vs -> DEF_spec vs) (rewrite_spec vs)
- | DEF_fundef fd -> List.map (fun f -> DEF_fundef f) (rewrite_function fd)
+ | DEF_fundef fd ->
+ let fd,extra = rewrite_function (IdSet.singleton (id_of_fundef fd)) fd in
+ List.map (fun f -> DEF_fundef f) (fd::extra)
+ | (DEF_internal_mutrec fds) as d ->
+ let recset = ids_of_def d in
+ let fds,extras = List.split (List.map (rewrite_function recset) fds) in
+ let extras = List.concat extras in
+ (DEF_internal_mutrec fds)::(List.map (fun f -> DEF_fundef f) extras)
| d -> [d]
in
Defs (List.flatten (List.map rewrite_def defs))