diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/rewrites.ml | 63 |
1 files changed, 49 insertions, 14 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index 7e852092..86560415 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1682,7 +1682,33 @@ let rewrite_defs_early_return (Defs defs) = | E_return e -> e | _ -> exp in - let e_block es = + let e_if (e1, e2, e3) = + if is_return e2 && is_return e3 then + let (E_aux (_, annot)) = get_return e2 in + E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot)) + else E_if (e1, e2, e3) in + + let rec e_block es = + (* If one of the branches of an if-expression in a block is an early + return, fold the rest of the block after the if-expression into the + other branch *) + let fold_if_return exp block = match exp with + | E_aux (E_if (c, t, e), annot) when is_return t -> + let annot = match block with + | [] -> annot + | _ -> let (E_aux (_, annot)) = Util.last block in annot + in + let e' = E_aux (e_block (e :: block), annot) in + [E_aux (e_if (c, t, e'), annot)] + | E_aux (E_if (c, t, e), annot) when is_return e -> + let annot = match block with + | [] -> annot + | _ -> let (E_aux (_, annot)) = Util.last block in annot + in + let t' = E_aux (e_block (t :: block), annot) in + [E_aux (e_if (c, t', e), annot)] + | _ -> exp :: block in + let es = List.fold_right fold_if_return es [] in match es with | [E_aux (e, _)] -> e | _ :: _ when is_return (Util.last es) -> @@ -1690,12 +1716,6 @@ let rewrite_defs_early_return (Defs defs) = E_return (E_aux (E_block (Util.butlast es @ [get_return e]), annot)) | _ -> E_block es in - let e_if (e1, e2, e3) = - if is_return e2 && is_return e3 then - let (E_aux (_, annot)) = get_return e2 in - E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot)) - else E_if (e1, e2, e3) in - let e_case (e, pes) = let is_return_pexp (Pat_aux (pexp, _)) = match pexp with | Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in @@ -1710,6 +1730,17 @@ let rewrite_defs_early_return (Defs defs) = then E_return (E_aux (E_case (e, List.map get_return_pexp pes), annot)) else E_case (e, pes) in + let e_let (lb, exp) = + let (E_aux (_, annot) as ret_exp) = get_return exp in + if is_return exp then E_return (E_aux (E_let (lb, ret_exp), annot)) + else E_let (lb, exp) in + + let e_internal_let (pat, exp1, exp2) = + let (E_aux (_, annot) as ret_exp2) = get_return exp2 in + if is_return exp2 then + E_return (E_aux (E_internal_let (pat, exp1, ret_exp2), annot)) + else E_internal_let (pat, exp1, exp2) in + let e_aux (exp, (l, annot)) = let full_exp = propagate_exp_effect (E_aux (exp, (l, annot))) in let env = env_of full_exp in @@ -1724,14 +1755,18 @@ let rewrite_defs_early_return (Defs defs) = let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pexp), a)) = let pat,guard,exp,pannot = destruct_pexp pexp in + (* Try to pull out early returns as far as possible *) + let exp' = + fold_exp + { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case; + e_let = e_let; e_internal_let = e_internal_let } + exp in + (* Remove early return if we can pull it out completely, and rewrite + remaining early returns to "early_return" calls *) let exp = - exp - (* Pull early returns out as far as possible *) - |> fold_exp { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case } - (* Remove singleton E_return *) - |> get_return - (* Fix effect annotations *) - |> fold_exp { id_exp_alg with e_aux = e_aux } in + fold_exp + { id_exp_alg with e_aux = e_aux } + (if is_return exp' then get_return exp' else exp) in let a = match a with | (l, Some (env, typ, eff)) -> (l, Some (env, typ, union_effects eff (effect_of exp))) |
