summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/rewrites.ml63
1 files changed, 49 insertions, 14 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 7e852092..86560415 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -1682,7 +1682,33 @@ let rewrite_defs_early_return (Defs defs) =
| E_return e -> e
| _ -> exp in
- let e_block es =
+ let e_if (e1, e2, e3) =
+ if is_return e2 && is_return e3 then
+ let (E_aux (_, annot)) = get_return e2 in
+ E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot))
+ else E_if (e1, e2, e3) in
+
+ let rec e_block es =
+ (* If one of the branches of an if-expression in a block is an early
+ return, fold the rest of the block after the if-expression into the
+ other branch *)
+ let fold_if_return exp block = match exp with
+ | E_aux (E_if (c, t, e), annot) when is_return t ->
+ let annot = match block with
+ | [] -> annot
+ | _ -> let (E_aux (_, annot)) = Util.last block in annot
+ in
+ let e' = E_aux (e_block (e :: block), annot) in
+ [E_aux (e_if (c, t, e'), annot)]
+ | E_aux (E_if (c, t, e), annot) when is_return e ->
+ let annot = match block with
+ | [] -> annot
+ | _ -> let (E_aux (_, annot)) = Util.last block in annot
+ in
+ let t' = E_aux (e_block (t :: block), annot) in
+ [E_aux (e_if (c, t', e), annot)]
+ | _ -> exp :: block in
+ let es = List.fold_right fold_if_return es [] in
match es with
| [E_aux (e, _)] -> e
| _ :: _ when is_return (Util.last es) ->
@@ -1690,12 +1716,6 @@ let rewrite_defs_early_return (Defs defs) =
E_return (E_aux (E_block (Util.butlast es @ [get_return e]), annot))
| _ -> E_block es in
- let e_if (e1, e2, e3) =
- if is_return e2 && is_return e3 then
- let (E_aux (_, annot)) = get_return e2 in
- E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot))
- else E_if (e1, e2, e3) in
-
let e_case (e, pes) =
let is_return_pexp (Pat_aux (pexp, _)) = match pexp with
| Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in
@@ -1710,6 +1730,17 @@ let rewrite_defs_early_return (Defs defs) =
then E_return (E_aux (E_case (e, List.map get_return_pexp pes), annot))
else E_case (e, pes) in
+ let e_let (lb, exp) =
+ let (E_aux (_, annot) as ret_exp) = get_return exp in
+ if is_return exp then E_return (E_aux (E_let (lb, ret_exp), annot))
+ else E_let (lb, exp) in
+
+ let e_internal_let (pat, exp1, exp2) =
+ let (E_aux (_, annot) as ret_exp2) = get_return exp2 in
+ if is_return exp2 then
+ E_return (E_aux (E_internal_let (pat, exp1, ret_exp2), annot))
+ else E_internal_let (pat, exp1, exp2) in
+
let e_aux (exp, (l, annot)) =
let full_exp = propagate_exp_effect (E_aux (exp, (l, annot))) in
let env = env_of full_exp in
@@ -1724,14 +1755,18 @@ let rewrite_defs_early_return (Defs defs) =
let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pexp), a)) =
let pat,guard,exp,pannot = destruct_pexp pexp in
+ (* Try to pull out early returns as far as possible *)
+ let exp' =
+ fold_exp
+ { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case;
+ e_let = e_let; e_internal_let = e_internal_let }
+ exp in
+ (* Remove early return if we can pull it out completely, and rewrite
+ remaining early returns to "early_return" calls *)
let exp =
- exp
- (* Pull early returns out as far as possible *)
- |> fold_exp { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case }
- (* Remove singleton E_return *)
- |> get_return
- (* Fix effect annotations *)
- |> fold_exp { id_exp_alg with e_aux = e_aux } in
+ fold_exp
+ { id_exp_alg with e_aux = e_aux }
+ (if is_return exp' then get_return exp' else exp) in
let a = match a with
| (l, Some (env, typ, eff)) ->
(l, Some (env, typ, union_effects eff (effect_of exp)))