summaryrefslogtreecommitdiff
path: root/src/rewrites.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewrites.ml')
-rw-r--r--src/rewrites.ml74
1 files changed, 58 insertions, 16 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index e4ac71cd..40772828 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -1087,7 +1087,7 @@ let remove_bitvector_pat (P_aux (_, (l, _)) as pat) =
; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot))
; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false))
} in
- let pat, env = bind_pat env
+ let pat, env = bind_pat_no_guard env
(strip_pat ((fold_pat name_bitvector_roots pat) false))
(pat_typ_of pat) in
@@ -1624,6 +1624,10 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base
rewriting of early returns
*)
let rewrite_defs_early_return (Defs defs) =
+ let is_unit (E_aux (exp, _)) = match exp with
+ | E_lit (L_aux (L_unit, _)) -> true
+ | _ -> false in
+
let is_return (E_aux (exp, _)) = match exp with
| E_return _ -> true
| _ -> false in
@@ -1632,7 +1636,35 @@ let rewrite_defs_early_return (Defs defs) =
| E_return e -> e
| _ -> exp in
- let e_block es =
+ let e_if (e1, e2, e3) =
+ if is_return e2 && is_return e3 then
+ let (E_aux (_, annot)) = get_return e2 in
+ E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot))
+ else E_if (e1, e2, e3) in
+
+ let rec e_block es =
+ (* If one of the branches of an if-expression in a block is an early
+ return, fold the rest of the block after the if-expression into the
+ other branch *)
+ let fold_if_return exp block = match exp with
+ | E_aux (E_if (c, t, (E_aux (_, annot) as e)), _) when is_return t ->
+ let annot = match block with
+ | [] -> annot
+ | _ -> let (E_aux (_, annot)) = Util.last block in annot
+ in
+ let block = if is_unit e then block else e :: block in
+ let e' = E_aux (e_block block, annot) in
+ [E_aux (e_if (c, t, e'), annot)]
+ | E_aux (E_if (c, (E_aux (_, annot) as t), e), _) when is_return e ->
+ let annot = match block with
+ | [] -> annot
+ | _ -> let (E_aux (_, annot)) = Util.last block in annot
+ in
+ let block = if is_unit t then block else t :: block in
+ let t' = E_aux (e_block block, annot) in
+ [E_aux (e_if (c, t', e), annot)]
+ | _ -> exp :: block in
+ let es = List.fold_right fold_if_return es [] in
match es with
| [E_aux (e, _)] -> e
| _ :: _ when is_return (Util.last es) ->
@@ -1640,12 +1672,6 @@ let rewrite_defs_early_return (Defs defs) =
E_return (E_aux (E_block (Util.butlast es @ [get_return e]), annot))
| _ -> E_block es in
- let e_if (e1, e2, e3) =
- if is_return e2 && is_return e3 then
- let (E_aux (_, annot)) = get_return e2 in
- E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot))
- else E_if (e1, e2, e3) in
-
let e_case (e, pes) =
let is_return_pexp (Pat_aux (pexp, _)) = match pexp with
| Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in
@@ -1660,6 +1686,17 @@ let rewrite_defs_early_return (Defs defs) =
then E_return (E_aux (E_case (e, List.map get_return_pexp pes), annot))
else E_case (e, pes) in
+ let e_let (lb, exp) =
+ let (E_aux (_, annot) as ret_exp) = get_return exp in
+ if is_return exp then E_return (E_aux (E_let (lb, ret_exp), annot))
+ else E_let (lb, exp) in
+
+ let e_internal_let (lexp, exp1, exp2) =
+ let (E_aux (_, annot) as ret_exp2) = get_return exp2 in
+ if is_return exp2 then
+ E_return (E_aux (E_var (lexp, exp1, ret_exp2), annot))
+ else E_var (lexp, exp1, exp2) in
+
let e_aux (exp, (l, annot)) =
let full_exp = propagate_exp_effect (E_aux (exp, (l, annot))) in
let env = env_of full_exp in
@@ -1674,14 +1711,18 @@ let rewrite_defs_early_return (Defs defs) =
let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pexp), a)) =
let pat,guard,exp,pannot = destruct_pexp pexp in
+ (* Try to pull out early returns as far as possible *)
+ let exp' =
+ fold_exp
+ { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case;
+ e_let = e_let; e_internal_let = e_internal_let }
+ exp in
+ (* Remove early return if we can pull it out completely, and rewrite
+ remaining early returns to "early_return" calls *)
let exp =
- exp
- (* Pull early returns out as far as possible *)
- |> fold_exp { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case }
- (* Remove singleton E_return *)
- |> get_return
- (* Fix effect annotations *)
- |> fold_exp { id_exp_alg with e_aux = e_aux } in
+ fold_exp
+ { id_exp_alg with e_aux = e_aux }
+ (if is_return exp' then get_return exp' else exp) in
let a = match a with
| (l, Some (env, typ, eff)) ->
(l, Some (env, typ, union_effects eff (effect_of exp)))
@@ -2141,7 +2182,7 @@ let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list
| [] -> k []
| exp :: exps -> f exp (fun exp -> mapCont f exps (fun exps -> k (exp :: exps)))
-let rewrite_defs_letbind_effects =
+let rewrite_defs_letbind_effects =
let rec value ((E_aux (exp_aux,_)) as exp) =
not (effectful exp || updates_vars exp)
@@ -2235,6 +2276,7 @@ let rewrite_defs_letbind_effects =
let exp =
if newreturn then
(* let typ = try typ_of exp with _ -> unit_typ in *)
+ let exp = annot_exp (E_cast (typ_of exp, exp)) l (env_of exp) (typ_of exp) in
annot_exp (E_internal_return exp) l (env_of exp) (typ_of exp)
else
exp in