summaryrefslogtreecommitdiff
path: root/src/rewrites.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewrites.ml')
-rw-r--r--src/rewrites.ml174
1 files changed, 113 insertions, 61 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index d926dfac..b328307d 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -1723,14 +1723,21 @@ let rewrite_defs_early_return (Defs defs) =
| E_lit (L_aux (L_unit, _)) -> true
| _ -> false in
- let is_return (E_aux (exp, _)) = match exp with
+ let rec is_return (E_aux (exp, _)) = match exp with
| E_return _ -> true
+ | E_cast (_, e) -> is_return e
| _ -> false in
- let get_return (E_aux (e, (l, _)) as exp) = match e with
+ let rec get_return (E_aux (e, annot) as exp) = match e with
| E_return e -> e
+ | E_cast (typ, e) -> E_aux (E_cast (typ, get_return e), annot)
| _ -> exp in
+ let contains_return exp =
+ fst (fold_exp
+ { (compute_exp_alg false (||))
+ with e_return = (fun (_, r) -> (true, E_return r)) } exp) in
+
let e_if (e1, e2, e3) =
if is_return e2 && is_return e3 then
let (E_aux (_, annot)) = get_return e2 in
@@ -1792,6 +1799,12 @@ let rewrite_defs_early_return (Defs defs) =
E_return (E_aux (E_var (lexp, exp1, ret_exp2), annot))
else E_var (lexp, exp1, exp2) in
+ let e_app (id, es) =
+ try E_return (get_return (List.find is_return es))
+ with
+ | Not_found -> E_app (id, es)
+ in
+
let e_aux (exp, (l, annot)) =
let full_exp = propagate_exp_effect (E_aux (exp, (l, annot))) in
let env = env_of full_exp in
@@ -1804,20 +1817,51 @@ let rewrite_defs_early_return (Defs defs) =
E_aux (E_app (mk_id "early_return", [exp']), (l, annot'))
| _ -> full_exp in
+ (* Make sure that all final leaves of an expression (e.g. all branches of
+ the last if-expression) are wrapped in a return statement. This allows
+ the above rewriting to uniformly pull these returns back out, even if
+ originally only one of the branches of the last if-expression was a
+ return, and the other an "exit()", for example. *)
+ let rec add_final_return nested (E_aux (e, annot) as exp) =
+ let rewrap e = E_aux (e, annot) in
+ match e with
+ | E_return _ -> exp
+ | E_cast (typ, e') ->
+ begin
+ let (E_aux (e_aux', annot') as e') = add_final_return nested e' in
+ match e_aux' with
+ | E_return e' -> rewrap (E_return (rewrap (E_cast (typ, e'))))
+ | _ -> rewrap (E_cast (typ, e'))
+ end
+ | E_block ((_ :: _) as es) ->
+ rewrap (E_block (Util.butlast es @ [add_final_return true (Util.last es)]))
+ | E_if (c, t, e) ->
+ rewrap (E_if (c, add_final_return true t, add_final_return true e))
+ | E_let (lb, exp) ->
+ rewrap (E_let (lb, add_final_return true exp))
+ | E_var (lexp, e1, e2) ->
+ rewrap (E_var (lexp, e1, add_final_return true e2))
+ | _ ->
+ if nested && not (contains_return exp) then rewrap (E_return exp) else exp
+ in
+
let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pexp), a)) =
let pat,guard,exp,pannot = destruct_pexp pexp in
- (* Try to pull out early returns as far as possible *)
- let exp' =
- fold_exp
- { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case;
- e_let = e_let; e_internal_let = e_internal_let }
- exp in
- (* Remove early return if we can pull it out completely, and rewrite
- remaining early returns to "early_return" calls *)
let exp =
- fold_exp
- { id_exp_alg with e_aux = e_aux }
- (if is_return exp' then get_return exp' else exp) in
+ if contains_return exp then
+ (* Try to pull out early returns as far as possible *)
+ let exp' =
+ fold_exp
+ { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case;
+ e_let = e_let; e_internal_let = e_internal_let; e_app = e_app }
+ (add_final_return false exp) in
+ (* Remove early return if we can pull it out completely, and rewrite
+ remaining early returns to "early_return" calls *)
+ fold_exp
+ { id_exp_alg with e_aux = e_aux }
+ (if is_return exp' then get_return exp' else exp)
+ else exp
+ in
let a = match a with
| (l, Some (env, typ, eff)) ->
(l, Some (env, typ, union_effects eff (effect_of exp)))
@@ -1978,6 +2022,8 @@ let rewrite_fix_val_specs (Defs defs) =
| E_aux (E_app_infix (_, f, _) as exp, (l, Some (env, typ, eff)))
| E_aux (E_app (f, _) as exp, (l, Some (env, typ, eff))) ->
let vs = find_vs env val_specs f in
+ (* The (updated) environment is used later by fix_eff_exp to look up
+ the actual effects of a function call *)
let env = Env.update_val_spec f vs env in
E_aux (exp, (l, Some (env, typ, union_effects eff (eff_of_vs vs))))
| e_aux -> e_aux
@@ -1992,8 +2038,14 @@ let rewrite_fix_val_specs (Defs defs) =
let vs, eff = match find_vs (env_of_annot (l, annot)) val_specs id with
| (tq, Typ_aux (Typ_fn (args_t, ret_t, eff), a)) ->
let eff' = union_effects eff (effect_of exp) in
- let args_t' = rewrite_typ_nexp_ids (env_of exp) (pat_typ_of pat) in
- let ret_t' = rewrite_typ_nexp_ids (env_of exp) (typ_of exp) in
+ (* TODO We currently expand type synonyms here to make life easier
+ for the Lem pretty-printer later on, as it currently does not have
+ access to the environment when printing val specs (and unexpanded
+ type synonyms nested in existentials might cause problems).
+ A more robust solution would be to make the (global) environment
+ more easily available to the pretty-printer. *)
+ let args_t' = Env.expand_synonyms (env_of exp) args_t in
+ let ret_t' = Env.expand_synonyms (env_of exp) ret_t in
(tq, Typ_aux (Typ_fn (args_t', ret_t', eff'), a)), eff'
| _ -> assert false (* find_vs must return a function type *)
in
@@ -2014,7 +2066,7 @@ let rewrite_fix_val_specs (Defs defs) =
let tannotopt = match tannotopt, funcls with
| Typ_annot_opt_aux (Typ_annot_opt_some (typq, typ), l),
FCL_aux (FCL_Funcl (_, Pat_aux ((Pat_exp (_, exp) | Pat_when (_, _, exp)), _)), _) :: _ ->
- Typ_annot_opt_aux (Typ_annot_opt_some (typq, rewrite_typ_nexp_ids (env_of exp) typ), l)
+ Typ_annot_opt_aux (Typ_annot_opt_some (typq, Env.expand_synonyms (env_of exp) typ), l)
| _ -> tannotopt in
(val_specs, FD_aux (FD_function (recopt, tannotopt, effopt, funcls), a)) in
@@ -2572,9 +2624,9 @@ let rewrite_defs_letbind_effects =
k (rewrap (E_case (exp1,pexps)))))
| E_try (exp1,pexps) ->
let newreturn = effectful exp1 || List.exists effectful_pexp pexps in
- n_exp_name exp1 (fun exp1 ->
+ let exp1 = n_exp_term newreturn exp1 in
n_pexpL newreturn pexps (fun pexps ->
- k (rewrap (E_try (exp1,pexps)))))
+ k (rewrap (E_try (exp1,pexps))))
| E_let (lb,body) ->
n_lb lb (fun lb ->
rewrap (E_let (lb,n_exp body k)))
@@ -2797,7 +2849,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
exp, P_aux (P_id id, a))
|> List.split in
- let rewrite (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (paux,(pl,pannot)) as pat) =
+ let rewrite used_vars (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (paux,(pl,pannot)) as pat) =
let env = env_of_annot annot in
let overwrite = match paux with
| P_wild | P_typ (_, P_aux (P_wild, _)) -> true
@@ -2813,7 +2865,11 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
function as an expression followed by the list of variables it
expects. In (Lem) pretty-printing, this turned into an anonymous
function and passed to foreach*. *)
- let vars, varpats = mk_var_exps_pats pl env (find_updated_vars exp4) in
+ let vars, varpats =
+ find_updated_vars exp4
+ |> IdSet.inter used_vars
+ |> mk_var_exps_pats pl env
+ in
let exp4 = rewrite_var_updates (add_vars overwrite exp4 vars) in
let ord_exp, lower, upper = match destruct_range (typ_of exp1), destruct_range (typ_of exp2) with
| None, _ | _, None ->
@@ -2833,7 +2889,11 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let v = annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; tuple_exp vars; body])) el env (typ_of body) in
Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats))
| E_loop(loop,cond,body) ->
- let vars, varpats = mk_var_exps_pats pl env (find_updated_vars body) in
+ let vars, varpats =
+ find_updated_vars body
+ |> IdSet.inter used_vars
+ |> mk_var_exps_pats pl env
+ in
let body = rewrite_var_updates (add_vars overwrite body vars) in
let (E_aux (_,(_,bannot))) = body in
let fname = match loop with
@@ -2845,6 +2905,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
| E_if (c,e1,e2) ->
let vars, varpats =
IdSet.union (find_updated_vars e1) (find_updated_vars e2)
+ |> IdSet.inter used_vars
|> mk_var_exps_pats pl env in
if vars = [] then
(Same_vars (E_aux (E_if (c,rewrite_var_updates e1,rewrite_var_updates e2),annot)))
@@ -2864,6 +2925,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
|> List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e)
|> List.map find_updated_vars
|> List.fold_left IdSet.union IdSet.empty
+ |> IdSet.inter used_vars
|> mk_var_exps_pats pl env in
if vars = [] then
let ps = List.map (function
@@ -2924,7 +2986,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
| E_let (lb,body) ->
let body = rewrite_var_updates body in
let (LB_aux (LB_val (pat, v), lbannot)) = lb in
- let lb = match rewrite v pat with
+ let lb = match rewrite (find_used_vars body) v pat with
| Added_vars (v, P_aux (pat, _)) ->
annot_letbind (pat, v) (get_loc_exp v) env (typ_of v)
| Same_vars v -> LB_aux (LB_val (pat, v),lbannot) in
@@ -2987,26 +3049,19 @@ let remove_reference_types exp =
let rewrite_defs_remove_superfluous_letbinds =
let e_aux (exp,annot) = match exp with
- | E_let (lb,exp2) ->
- begin match lb,exp2 with
+ | E_let (LB_aux (LB_val (pat, exp1), _), exp2) ->
+ begin match untyp_pat pat, uncast_exp exp1, uncast_exp exp2 with
(* 'let x = EXP1 in x' can be replaced with 'EXP1' *)
- | LB_aux (LB_val (P_aux (P_id id, _), exp1), _),
- E_aux (E_id id', _)
- | LB_aux (LB_val (P_aux (P_typ (_, P_aux (P_id id, _)), _), exp1), _),
- E_aux (E_id id', _)
- | LB_aux (LB_val (P_aux (P_id id, _), exp1), _),
- E_aux (E_cast (_,E_aux (E_id id', _)), _)
- when Id.compare id id' = 0 && id_is_unbound id (env_of_annot annot) ->
+ | (P_aux (P_id id, _), _), _, (E_aux (E_id id', _), _)
+ when Id.compare id id' = 0 ->
exp1
(* "let _ = () in exp" can be replaced with exp *)
- | LB_aux (LB_val (P_aux (P_wild, _), E_aux (E_lit (L_aux (L_unit, _)), _)), _),
- exp2 ->
+ | (P_aux (P_wild, _), _), (E_aux (E_lit (L_aux (L_unit, _)), _), _), _ ->
exp2
(* "let x = EXP1 in return x" can be replaced with 'return (EXP1)', at
least when EXP1 is 'small' enough *)
- | LB_aux (LB_val (P_aux (P_id id, _), exp1), _),
- E_aux (E_internal_return (E_aux (E_id id', _)), _)
- when Id.compare id id' = 0 && small exp1 && id_is_unbound id (env_of_annot annot) ->
+ | (P_aux (P_id id, _), _), _, (E_aux (E_internal_return (E_aux (E_id id', _)), _), _)
+ when Id.compare id id' = 0 && small exp1 ->
let (E_aux (_,e1annot)) = exp1 in
E_aux (E_internal_return (exp1),e1annot)
| _ -> E_aux (exp,annot)
@@ -3027,44 +3082,41 @@ let rewrite_defs_remove_superfluous_letbinds =
let rewrite_defs_remove_superfluous_returns =
- let untyp_pat = function
- | P_aux (P_typ (typ, pat), _) -> pat, Some typ
- | pat -> pat, None in
-
- let uncast_internal_return = function
- | E_aux (E_internal_return (E_aux (E_cast (typ, exp), _)), a) ->
- E_aux (E_internal_return exp, a), Some typ
- | exp -> exp, None in
+ let add_opt_cast typopt1 typopt2 annot exp =
+ match typopt1, typopt2 with
+ | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp), annot)
+ | None, None -> exp
+ in
let e_aux (exp,annot) = match exp with
| E_let (LB_aux (LB_val (pat, exp1), _), exp2)
| E_internal_plet (pat, exp1, exp2)
when effectful exp1 ->
- begin match untyp_pat pat, uncast_internal_return exp2 with
+ begin match untyp_pat pat, uncast_exp exp2 with
| (P_aux (P_lit (L_aux (lit,_)),_), ptyp),
(E_aux (E_internal_return (E_aux (E_lit (L_aux (lit',_)),_)), a), etyp)
when lit = lit' ->
- begin
- match ptyp, etyp with
- | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp1), a)
- | None, None -> exp1
- end
+ add_opt_cast ptyp etyp a exp1
| (P_aux (P_wild,pannot), ptyp),
(E_aux (E_internal_return (E_aux (E_lit (L_aux (L_unit,_)),_)), a), etyp)
when is_unit_typ (typ_of exp1) ->
- begin
- match ptyp, etyp with
- | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp1), a)
- | None, None -> exp1
- end
+ add_opt_cast ptyp etyp a exp1
| (P_aux (P_id id,_), ptyp),
(E_aux (E_internal_return (E_aux (E_id id',_)), a), etyp)
- when Id.compare id id' == 0 && id_is_unbound id (env_of_annot annot) ->
- begin
- match ptyp, etyp with
- | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp1), a)
- | None, None -> exp1
- end
+ when Id.compare id id' == 0 ->
+ add_opt_cast ptyp etyp a exp1
+ | (P_aux (P_tup ps, _), ptyp),
+ (E_aux (E_internal_return (E_aux (E_tuple es, _)), a), etyp)
+ when List.length ps = List.length es ->
+ let same_id (P_aux (p, _)) (E_aux (e, _)) = match p, e with
+ | P_id id, E_id id' -> Id.compare id id' == 0
+ | _, _ -> false
+ in
+ let ps = List.map fst (List.map untyp_pat ps) in
+ let es = List.map fst (List.map uncast_exp es) in
+ if List.for_all2 same_id ps es
+ then add_opt_cast ptyp etyp a exp1
+ else E_aux (exp,annot)
| _ -> E_aux (exp,annot)
end
| _ -> E_aux (exp,annot) in