diff options
Diffstat (limited to 'src/rewrites.ml')
| -rw-r--r-- | src/rewrites.ml | 74 |
1 files changed, 58 insertions, 16 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index e4ac71cd..40772828 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1087,7 +1087,7 @@ let remove_bitvector_pat (P_aux (_, (l, _)) as pat) = ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) ; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false)) } in - let pat, env = bind_pat env + let pat, env = bind_pat_no_guard env (strip_pat ((fold_pat name_bitvector_roots pat) false)) (pat_typ_of pat) in @@ -1624,6 +1624,10 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base rewriting of early returns *) let rewrite_defs_early_return (Defs defs) = + let is_unit (E_aux (exp, _)) = match exp with + | E_lit (L_aux (L_unit, _)) -> true + | _ -> false in + let is_return (E_aux (exp, _)) = match exp with | E_return _ -> true | _ -> false in @@ -1632,7 +1636,35 @@ let rewrite_defs_early_return (Defs defs) = | E_return e -> e | _ -> exp in - let e_block es = + let e_if (e1, e2, e3) = + if is_return e2 && is_return e3 then + let (E_aux (_, annot)) = get_return e2 in + E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot)) + else E_if (e1, e2, e3) in + + let rec e_block es = + (* If one of the branches of an if-expression in a block is an early + return, fold the rest of the block after the if-expression into the + other branch *) + let fold_if_return exp block = match exp with + | E_aux (E_if (c, t, (E_aux (_, annot) as e)), _) when is_return t -> + let annot = match block with + | [] -> annot + | _ -> let (E_aux (_, annot)) = Util.last block in annot + in + let block = if is_unit e then block else e :: block in + let e' = E_aux (e_block block, annot) in + [E_aux (e_if (c, t, e'), annot)] + | E_aux (E_if (c, (E_aux (_, annot) as t), e), _) when is_return e -> + let annot = match block with + | [] -> annot + | _ -> let (E_aux (_, annot)) = Util.last block in annot + in + let block = if is_unit t then block else t :: block in + let t' = E_aux (e_block block, annot) in + [E_aux (e_if (c, t', e), annot)] + | _ -> exp :: block in + let es = List.fold_right fold_if_return es [] in match es with | [E_aux (e, _)] -> e | _ :: _ when is_return (Util.last es) -> @@ -1640,12 +1672,6 @@ let rewrite_defs_early_return (Defs defs) = E_return (E_aux (E_block (Util.butlast es @ [get_return e]), annot)) | _ -> E_block es in - let e_if (e1, e2, e3) = - if is_return e2 && is_return e3 then - let (E_aux (_, annot)) = get_return e2 in - E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot)) - else E_if (e1, e2, e3) in - let e_case (e, pes) = let is_return_pexp (Pat_aux (pexp, _)) = match pexp with | Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in @@ -1660,6 +1686,17 @@ let rewrite_defs_early_return (Defs defs) = then E_return (E_aux (E_case (e, List.map get_return_pexp pes), annot)) else E_case (e, pes) in + let e_let (lb, exp) = + let (E_aux (_, annot) as ret_exp) = get_return exp in + if is_return exp then E_return (E_aux (E_let (lb, ret_exp), annot)) + else E_let (lb, exp) in + + let e_internal_let (lexp, exp1, exp2) = + let (E_aux (_, annot) as ret_exp2) = get_return exp2 in + if is_return exp2 then + E_return (E_aux (E_var (lexp, exp1, ret_exp2), annot)) + else E_var (lexp, exp1, exp2) in + let e_aux (exp, (l, annot)) = let full_exp = propagate_exp_effect (E_aux (exp, (l, annot))) in let env = env_of full_exp in @@ -1674,14 +1711,18 @@ let rewrite_defs_early_return (Defs defs) = let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pexp), a)) = let pat,guard,exp,pannot = destruct_pexp pexp in + (* Try to pull out early returns as far as possible *) + let exp' = + fold_exp + { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case; + e_let = e_let; e_internal_let = e_internal_let } + exp in + (* Remove early return if we can pull it out completely, and rewrite + remaining early returns to "early_return" calls *) let exp = - exp - (* Pull early returns out as far as possible *) - |> fold_exp { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case } - (* Remove singleton E_return *) - |> get_return - (* Fix effect annotations *) - |> fold_exp { id_exp_alg with e_aux = e_aux } in + fold_exp + { id_exp_alg with e_aux = e_aux } + (if is_return exp' then get_return exp' else exp) in let a = match a with | (l, Some (env, typ, eff)) -> (l, Some (env, typ, union_effects eff (effect_of exp))) @@ -2141,7 +2182,7 @@ let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list | [] -> k [] | exp :: exps -> f exp (fun exp -> mapCont f exps (fun exps -> k (exp :: exps))) -let rewrite_defs_letbind_effects = +let rewrite_defs_letbind_effects = let rec value ((E_aux (exp_aux,_)) as exp) = not (effectful exp || updates_vars exp) @@ -2235,6 +2276,7 @@ let rewrite_defs_letbind_effects = let exp = if newreturn then (* let typ = try typ_of exp with _ -> unit_typ in *) + let exp = annot_exp (E_cast (typ_of exp, exp)) l (env_of exp) (typ_of exp) in annot_exp (E_internal_return exp) l (env_of exp) (typ_of exp) else exp in |
