diff options
Diffstat (limited to 'src/rewrites.ml')
| -rw-r--r-- | src/rewrites.ml | 174 |
1 files changed, 113 insertions, 61 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index d926dfac..b328307d 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1723,14 +1723,21 @@ let rewrite_defs_early_return (Defs defs) = | E_lit (L_aux (L_unit, _)) -> true | _ -> false in - let is_return (E_aux (exp, _)) = match exp with + let rec is_return (E_aux (exp, _)) = match exp with | E_return _ -> true + | E_cast (_, e) -> is_return e | _ -> false in - let get_return (E_aux (e, (l, _)) as exp) = match e with + let rec get_return (E_aux (e, annot) as exp) = match e with | E_return e -> e + | E_cast (typ, e) -> E_aux (E_cast (typ, get_return e), annot) | _ -> exp in + let contains_return exp = + fst (fold_exp + { (compute_exp_alg false (||)) + with e_return = (fun (_, r) -> (true, E_return r)) } exp) in + let e_if (e1, e2, e3) = if is_return e2 && is_return e3 then let (E_aux (_, annot)) = get_return e2 in @@ -1792,6 +1799,12 @@ let rewrite_defs_early_return (Defs defs) = E_return (E_aux (E_var (lexp, exp1, ret_exp2), annot)) else E_var (lexp, exp1, exp2) in + let e_app (id, es) = + try E_return (get_return (List.find is_return es)) + with + | Not_found -> E_app (id, es) + in + let e_aux (exp, (l, annot)) = let full_exp = propagate_exp_effect (E_aux (exp, (l, annot))) in let env = env_of full_exp in @@ -1804,20 +1817,51 @@ let rewrite_defs_early_return (Defs defs) = E_aux (E_app (mk_id "early_return", [exp']), (l, annot')) | _ -> full_exp in + (* Make sure that all final leaves of an expression (e.g. all branches of + the last if-expression) are wrapped in a return statement. This allows + the above rewriting to uniformly pull these returns back out, even if + originally only one of the branches of the last if-expression was a + return, and the other an "exit()", for example. *) + let rec add_final_return nested (E_aux (e, annot) as exp) = + let rewrap e = E_aux (e, annot) in + match e with + | E_return _ -> exp + | E_cast (typ, e') -> + begin + let (E_aux (e_aux', annot') as e') = add_final_return nested e' in + match e_aux' with + | E_return e' -> rewrap (E_return (rewrap (E_cast (typ, e')))) + | _ -> rewrap (E_cast (typ, e')) + end + | E_block ((_ :: _) as es) -> + rewrap (E_block (Util.butlast es @ [add_final_return true (Util.last es)])) + | E_if (c, t, e) -> + rewrap (E_if (c, add_final_return true t, add_final_return true e)) + | E_let (lb, exp) -> + rewrap (E_let (lb, add_final_return true exp)) + | E_var (lexp, e1, e2) -> + rewrap (E_var (lexp, e1, add_final_return true e2)) + | _ -> + if nested && not (contains_return exp) then rewrap (E_return exp) else exp + in + let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pexp), a)) = let pat,guard,exp,pannot = destruct_pexp pexp in - (* Try to pull out early returns as far as possible *) - let exp' = - fold_exp - { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case; - e_let = e_let; e_internal_let = e_internal_let } - exp in - (* Remove early return if we can pull it out completely, and rewrite - remaining early returns to "early_return" calls *) let exp = - fold_exp - { id_exp_alg with e_aux = e_aux } - (if is_return exp' then get_return exp' else exp) in + if contains_return exp then + (* Try to pull out early returns as far as possible *) + let exp' = + fold_exp + { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case; + e_let = e_let; e_internal_let = e_internal_let; e_app = e_app } + (add_final_return false exp) in + (* Remove early return if we can pull it out completely, and rewrite + remaining early returns to "early_return" calls *) + fold_exp + { id_exp_alg with e_aux = e_aux } + (if is_return exp' then get_return exp' else exp) + else exp + in let a = match a with | (l, Some (env, typ, eff)) -> (l, Some (env, typ, union_effects eff (effect_of exp))) @@ -1978,6 +2022,8 @@ let rewrite_fix_val_specs (Defs defs) = | E_aux (E_app_infix (_, f, _) as exp, (l, Some (env, typ, eff))) | E_aux (E_app (f, _) as exp, (l, Some (env, typ, eff))) -> let vs = find_vs env val_specs f in + (* The (updated) environment is used later by fix_eff_exp to look up + the actual effects of a function call *) let env = Env.update_val_spec f vs env in E_aux (exp, (l, Some (env, typ, union_effects eff (eff_of_vs vs)))) | e_aux -> e_aux @@ -1992,8 +2038,14 @@ let rewrite_fix_val_specs (Defs defs) = let vs, eff = match find_vs (env_of_annot (l, annot)) val_specs id with | (tq, Typ_aux (Typ_fn (args_t, ret_t, eff), a)) -> let eff' = union_effects eff (effect_of exp) in - let args_t' = rewrite_typ_nexp_ids (env_of exp) (pat_typ_of pat) in - let ret_t' = rewrite_typ_nexp_ids (env_of exp) (typ_of exp) in + (* TODO We currently expand type synonyms here to make life easier + for the Lem pretty-printer later on, as it currently does not have + access to the environment when printing val specs (and unexpanded + type synonyms nested in existentials might cause problems). + A more robust solution would be to make the (global) environment + more easily available to the pretty-printer. *) + let args_t' = Env.expand_synonyms (env_of exp) args_t in + let ret_t' = Env.expand_synonyms (env_of exp) ret_t in (tq, Typ_aux (Typ_fn (args_t', ret_t', eff'), a)), eff' | _ -> assert false (* find_vs must return a function type *) in @@ -2014,7 +2066,7 @@ let rewrite_fix_val_specs (Defs defs) = let tannotopt = match tannotopt, funcls with | Typ_annot_opt_aux (Typ_annot_opt_some (typq, typ), l), FCL_aux (FCL_Funcl (_, Pat_aux ((Pat_exp (_, exp) | Pat_when (_, _, exp)), _)), _) :: _ -> - Typ_annot_opt_aux (Typ_annot_opt_some (typq, rewrite_typ_nexp_ids (env_of exp) typ), l) + Typ_annot_opt_aux (Typ_annot_opt_some (typq, Env.expand_synonyms (env_of exp) typ), l) | _ -> tannotopt in (val_specs, FD_aux (FD_function (recopt, tannotopt, effopt, funcls), a)) in @@ -2572,9 +2624,9 @@ let rewrite_defs_letbind_effects = k (rewrap (E_case (exp1,pexps))))) | E_try (exp1,pexps) -> let newreturn = effectful exp1 || List.exists effectful_pexp pexps in - n_exp_name exp1 (fun exp1 -> + let exp1 = n_exp_term newreturn exp1 in n_pexpL newreturn pexps (fun pexps -> - k (rewrap (E_try (exp1,pexps))))) + k (rewrap (E_try (exp1,pexps)))) | E_let (lb,body) -> n_lb lb (fun lb -> rewrap (E_let (lb,n_exp body k))) @@ -2797,7 +2849,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = exp, P_aux (P_id id, a)) |> List.split in - let rewrite (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (paux,(pl,pannot)) as pat) = + let rewrite used_vars (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (paux,(pl,pannot)) as pat) = let env = env_of_annot annot in let overwrite = match paux with | P_wild | P_typ (_, P_aux (P_wild, _)) -> true @@ -2813,7 +2865,11 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = function as an expression followed by the list of variables it expects. In (Lem) pretty-printing, this turned into an anonymous function and passed to foreach*. *) - let vars, varpats = mk_var_exps_pats pl env (find_updated_vars exp4) in + let vars, varpats = + find_updated_vars exp4 + |> IdSet.inter used_vars + |> mk_var_exps_pats pl env + in let exp4 = rewrite_var_updates (add_vars overwrite exp4 vars) in let ord_exp, lower, upper = match destruct_range (typ_of exp1), destruct_range (typ_of exp2) with | None, _ | _, None -> @@ -2833,7 +2889,11 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let v = annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; tuple_exp vars; body])) el env (typ_of body) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) | E_loop(loop,cond,body) -> - let vars, varpats = mk_var_exps_pats pl env (find_updated_vars body) in + let vars, varpats = + find_updated_vars body + |> IdSet.inter used_vars + |> mk_var_exps_pats pl env + in let body = rewrite_var_updates (add_vars overwrite body vars) in let (E_aux (_,(_,bannot))) = body in let fname = match loop with @@ -2845,6 +2905,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | E_if (c,e1,e2) -> let vars, varpats = IdSet.union (find_updated_vars e1) (find_updated_vars e2) + |> IdSet.inter used_vars |> mk_var_exps_pats pl env in if vars = [] then (Same_vars (E_aux (E_if (c,rewrite_var_updates e1,rewrite_var_updates e2),annot))) @@ -2864,6 +2925,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = |> List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e) |> List.map find_updated_vars |> List.fold_left IdSet.union IdSet.empty + |> IdSet.inter used_vars |> mk_var_exps_pats pl env in if vars = [] then let ps = List.map (function @@ -2924,7 +2986,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | E_let (lb,body) -> let body = rewrite_var_updates body in let (LB_aux (LB_val (pat, v), lbannot)) = lb in - let lb = match rewrite v pat with + let lb = match rewrite (find_used_vars body) v pat with | Added_vars (v, P_aux (pat, _)) -> annot_letbind (pat, v) (get_loc_exp v) env (typ_of v) | Same_vars v -> LB_aux (LB_val (pat, v),lbannot) in @@ -2987,26 +3049,19 @@ let remove_reference_types exp = let rewrite_defs_remove_superfluous_letbinds = let e_aux (exp,annot) = match exp with - | E_let (lb,exp2) -> - begin match lb,exp2 with + | E_let (LB_aux (LB_val (pat, exp1), _), exp2) -> + begin match untyp_pat pat, uncast_exp exp1, uncast_exp exp2 with (* 'let x = EXP1 in x' can be replaced with 'EXP1' *) - | LB_aux (LB_val (P_aux (P_id id, _), exp1), _), - E_aux (E_id id', _) - | LB_aux (LB_val (P_aux (P_typ (_, P_aux (P_id id, _)), _), exp1), _), - E_aux (E_id id', _) - | LB_aux (LB_val (P_aux (P_id id, _), exp1), _), - E_aux (E_cast (_,E_aux (E_id id', _)), _) - when Id.compare id id' = 0 && id_is_unbound id (env_of_annot annot) -> + | (P_aux (P_id id, _), _), _, (E_aux (E_id id', _), _) + when Id.compare id id' = 0 -> exp1 (* "let _ = () in exp" can be replaced with exp *) - | LB_aux (LB_val (P_aux (P_wild, _), E_aux (E_lit (L_aux (L_unit, _)), _)), _), - exp2 -> + | (P_aux (P_wild, _), _), (E_aux (E_lit (L_aux (L_unit, _)), _), _), _ -> exp2 (* "let x = EXP1 in return x" can be replaced with 'return (EXP1)', at least when EXP1 is 'small' enough *) - | LB_aux (LB_val (P_aux (P_id id, _), exp1), _), - E_aux (E_internal_return (E_aux (E_id id', _)), _) - when Id.compare id id' = 0 && small exp1 && id_is_unbound id (env_of_annot annot) -> + | (P_aux (P_id id, _), _), _, (E_aux (E_internal_return (E_aux (E_id id', _)), _), _) + when Id.compare id id' = 0 && small exp1 -> let (E_aux (_,e1annot)) = exp1 in E_aux (E_internal_return (exp1),e1annot) | _ -> E_aux (exp,annot) @@ -3027,44 +3082,41 @@ let rewrite_defs_remove_superfluous_letbinds = let rewrite_defs_remove_superfluous_returns = - let untyp_pat = function - | P_aux (P_typ (typ, pat), _) -> pat, Some typ - | pat -> pat, None in - - let uncast_internal_return = function - | E_aux (E_internal_return (E_aux (E_cast (typ, exp), _)), a) -> - E_aux (E_internal_return exp, a), Some typ - | exp -> exp, None in + let add_opt_cast typopt1 typopt2 annot exp = + match typopt1, typopt2 with + | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp), annot) + | None, None -> exp + in let e_aux (exp,annot) = match exp with | E_let (LB_aux (LB_val (pat, exp1), _), exp2) | E_internal_plet (pat, exp1, exp2) when effectful exp1 -> - begin match untyp_pat pat, uncast_internal_return exp2 with + begin match untyp_pat pat, uncast_exp exp2 with | (P_aux (P_lit (L_aux (lit,_)),_), ptyp), (E_aux (E_internal_return (E_aux (E_lit (L_aux (lit',_)),_)), a), etyp) when lit = lit' -> - begin - match ptyp, etyp with - | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp1), a) - | None, None -> exp1 - end + add_opt_cast ptyp etyp a exp1 | (P_aux (P_wild,pannot), ptyp), (E_aux (E_internal_return (E_aux (E_lit (L_aux (L_unit,_)),_)), a), etyp) when is_unit_typ (typ_of exp1) -> - begin - match ptyp, etyp with - | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp1), a) - | None, None -> exp1 - end + add_opt_cast ptyp etyp a exp1 | (P_aux (P_id id,_), ptyp), (E_aux (E_internal_return (E_aux (E_id id',_)), a), etyp) - when Id.compare id id' == 0 && id_is_unbound id (env_of_annot annot) -> - begin - match ptyp, etyp with - | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp1), a) - | None, None -> exp1 - end + when Id.compare id id' == 0 -> + add_opt_cast ptyp etyp a exp1 + | (P_aux (P_tup ps, _), ptyp), + (E_aux (E_internal_return (E_aux (E_tuple es, _)), a), etyp) + when List.length ps = List.length es -> + let same_id (P_aux (p, _)) (E_aux (e, _)) = match p, e with + | P_id id, E_id id' -> Id.compare id id' == 0 + | _, _ -> false + in + let ps = List.map fst (List.map untyp_pat ps) in + let es = List.map fst (List.map uncast_exp es) in + if List.for_all2 same_id ps es + then add_opt_cast ptyp etyp a exp1 + else E_aux (exp,annot) | _ -> E_aux (exp,annot) end | _ -> E_aux (exp,annot) in |
