diff options
Diffstat (limited to 'src/rewriter.ml')
| -rw-r--r-- | src/rewriter.ml | 665 |
1 files changed, 387 insertions, 278 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml index fd1479a7..002d7630 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -78,6 +78,8 @@ let effect_of_lb (LB_aux (_,(_,a))) = effect_of_annot a let get_loc_exp (E_aux (_,(l,_))) = l let gen_loc l = Parse_ast.Generated l +let gen_vs (id, spec) = Initial_check.val_spec_of_string dec_ord (mk_id id) spec + let simple_annot l typ = (gen_loc l, Some (Env.empty, typ, no_effect)) let simple_num l n = E_aux ( E_lit (L_aux (L_num n, gen_loc l)), @@ -246,18 +248,6 @@ let effectful_effs = function let effectful eaux = effectful_effs (effect_of (propagate_exp_effect eaux)) let effectful_pexp pexp = effectful_effs (snd (propagate_pexp_effect pexp)) -let updates_vars_effs = function - | Effect_aux (Effect_set effs, _) -> - List.exists - (fun (BE_aux (be,_)) -> - match be with - | BE_lset -> true - | _ -> false - ) effs - | _ -> true - -let updates_vars eaux = updates_vars_effs (effect_of eaux) - let id_to_string (Id_aux(id,l)) = match id with | Id(s) -> s @@ -535,7 +525,8 @@ let rewrite_exp rewriters (E_aux (exp,(l,annot))) = raise (Reporting_basic.err_unreachable l ("Internal_exp_user given unexpected types " ^ (t_to_string tu) ^ ", " ^ (t_to_string ti)))) | _ -> raise (Reporting_basic.err_unreachable l ("Internal_exp_user given none Base annot")))*) - | E_internal_let _ -> raise (Reporting_basic.err_unreachable l "Internal let found before it should have been introduced") + | E_internal_let (lexp, e1, e2) -> + rewrap (E_internal_let (rewriters.rewrite_lexp rewriters lexp, rewriters.rewrite_exp rewriters e1, rewriters.rewrite_exp rewriters e2)) | E_internal_return _ -> raise (Reporting_basic.err_unreachable l "Internal return found before it should have been introduced") | E_internal_plet _ -> raise (Reporting_basic.err_unreachable l " Internal plet found before it should have been introduced") | _ -> rewrap exp @@ -1306,16 +1297,25 @@ let rewrite_sizeof (Defs defs) = let exp' = fold_exp { id_exp_alg with e_sizeof = e_sizeof kid_nmap } exp in FCL_aux (FCL_Funcl (id, rewrite_pat pat, exp'), annot) in let funcls = List.map rewrite_funcl_params funcls in - (nvars, FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) in + let fd = FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot) in + let params_map = + if KidSet.is_empty nvars then params_map else + Bindings.add (id_of_fundef fd) nvars params_map in + (params_map, FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) in let rewrite_sizeof_def (params_map, defs) = function - | DEF_fundef fd as def -> - let (nvars, fd') = rewrite_sizeof_fun params_map fd in - let id = id_of_fundef fd in - let params_map' = - if KidSet.is_empty nvars then params_map - else Bindings.add id nvars params_map in + | DEF_fundef fd -> + let (params_map', fd') = rewrite_sizeof_fun params_map fd in (params_map', defs @ [DEF_fundef fd']) + | DEF_internal_mutrec fds -> + let rewrite_fd (params_map, fds) fd = + let (params_map', fd') = rewrite_sizeof_fun params_map fd in + (params_map', fds @ [fd']) in + (* TODO Split rewrite_sizeof_fun into an analysis and a rewrite pass, + so that we can call the analysis until a fixpoint is reached and then + rewrite the mutually recursive functions *) + let (params_map', fds') = List.fold_left rewrite_fd (params_map, []) fds in + (params_map', defs @ [DEF_internal_mutrec fds']) | DEF_val (LB_aux (lb, annot)) -> begin let lb' = match lb with @@ -1354,11 +1354,18 @@ let rewrite_sizeof (Defs defs) = let (params_map, defs) = List.fold_left rewrite_sizeof_def (Bindings.empty, []) defs in let defs = List.map (rewrite_sizeof_valspec params_map) defs in - Defs defs - (* FIXME: Won't re-check due to flow typing and E_constraint re-write before E_sizeof re-write. - Requires the typechecker to be more smart about different representations for valid flow typing constraints. fst (check initial_env (Defs defs)) - *) + +let rewrite_defs_remove_assert defs = + let e_assert ((E_aux (eaux, (l, _)) as exp), str) = match eaux with + | E_constraint _ -> + E_assert (exp, str) + | _ -> + E_assert (E_aux (E_lit (mk_lit L_true), simple_annot l bool_typ), str) in + rewrite_defs_base + { rewriters_base with + rewrite_exp = (fun _ -> fold_exp { id_exp_alg with e_assert = e_assert}) } + defs let remove_vector_concat_pat pat = @@ -1874,7 +1881,7 @@ let contains_bitvector_pexp = function (* Rewrite bitvector patterns to guarded patterns *) -let remove_bitvector_pat pat = +let remove_bitvector_pat (P_aux (_, (l, _)) as pat) = let env = try pat_env_of pat with _ -> Env.empty in @@ -1906,54 +1913,39 @@ let remove_bitvector_pat pat = ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) ; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false)) } in - let pat = (fold_pat name_bitvector_roots pat) false in + let pat, env = bind_pat env + (strip_pat ((fold_pat name_bitvector_roots pat) false)) + (pat_typ_of pat) in (* Then collect guard expressions testing whether the literal bits of a bitvector pattern match those of a given bitvector, and collect let bindings for the bits bound by P_id or P_as patterns *) (* Helper functions for generating guard expressions *) + let mk_exp e_aux = E_aux (e_aux, (l, ())) in + let mk_num_exp i = mk_lit_exp (L_num i) in + let check_eq_exp l r = + let exp = mk_exp (E_app_infix (l, Id_aux (DeIid "==", Parse_ast.Unknown), r)) in + check_exp env exp bool_typ in + let access_bit_exp rootid l typ idx = - let root = annot_exp (E_id rootid) l env typ in - (* FIXME *) - annot_exp (E_vector_access (root, simple_num l idx)) l env bit_typ in - (*let env = env_of_annot rannot in - let t = Env.base_typ_of env (typ_of_annot rannot) in - let (_, _, ord, _) = vector_typ_args_of t in - let access_id = if is_order_inc ord then "bitvector_access_inc" else "bitvector_access_dec" in - E_aux (E_app (mk_id access_id, [root; simple_num l idx]), simple_annot l bit_typ) in*) + let access_aux = E_vector_access (mk_exp (E_id rootid), mk_num_exp idx) in + check_exp env (mk_exp access_aux) bit_typ in let test_bit_exp rootid l typ idx exp = - let rannot = (l, Some (env_of exp, typ, no_effect)) in let elem = access_bit_exp rootid l typ idx in - Some (annot_exp (E_app (mk_id "eq", [elem; exp])) l env bool_typ) in + Some (check_eq_exp (strip_exp elem) (strip_exp exp)) in let test_subvec_exp rootid l typ i j lits = let (start, length, ord, _) = vector_typ_args_of typ in - let length' = nint (List.length lits) in - let start' = - if is_order_inc ord then nint 0 - else nminus length' (nint 1) in - let typ' = vector_typ start' length' ord bit_typ in let subvec_exp = match start, length with | Nexp_aux (Nexp_constant s, _), Nexp_aux (Nexp_constant l, _) when eq_big_int s i && eq_big_int l (big_int_of_int (List.length lits)) -> - E_id rootid + mk_exp (E_id rootid) | _ -> - (*if vec_start t = i && vec_length t = List.length lits - then E_id rootid - else*) - E_vector_subrange ( - annot_exp (E_id rootid) l env typ, - simple_num l i, - simple_num l j) in - (* let subrange_id = if is_order_inc ord then "bitvector_subrange_inc" else "bitvector_subrange_dec" in - E_app (mk_id subrange_id, [E_aux (E_id rootid, simple_annot l typ); simple_num l i; simple_num l j]) in *) - annot_exp (E_app( - Id_aux (Id "eq_vec", gen_loc l), - [annot_exp subvec_exp l env typ'; - annot_exp (E_vector lits) l env typ'])) l env bool_typ in + mk_exp (E_vector_subrange (mk_exp (E_id rootid), mk_num_exp i, mk_num_exp j)) in + check_eq_exp subvec_exp (mk_exp (E_vector (List.map strip_exp lits))) in let letbind_bit_exp rootid l typ idx id = let rannot = simple_annot l typ in @@ -2099,7 +2091,6 @@ let rewrite_exp_remove_bitvector_pat rewriters (E_aux (exp,(l,annot)) as full_ex let rewrite_fun_remove_bitvector_pat rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) = let _ = reset_fresh_name_counter () in - (* TODO Can there be clauses with different id's in one FD_function? *) let funcls = match funcls with | (FCL_aux (FCL_Funcl(id,_,_),_) :: _) -> let clause (FCL_aux (FCL_Funcl(_,pat,exp),annot)) = @@ -2108,7 +2099,7 @@ let rewrite_fun_remove_bitvector_pat (pat,guard,exp,annot) in let cs = rewrite_guarded_clauses l (List.map clause funcls) in List.map (fun (pat,exp,annot) -> FCL_aux (FCL_Funcl(id,pat,exp),annot)) cs - | _ -> funcls (* TODO is the empty list possible here? *) in + | _ -> funcls in FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot)) let rewrite_defs_remove_bitvector_pats (Defs defs) = @@ -2171,6 +2162,10 @@ let id_is_local_var id env = match Env.lookup_id id env with | Local _ -> true | _ -> false +let id_is_unbound id env = match Env.lookup_id id env with + | Unbound -> true + | _ -> false + let rec lexp_is_local (LEXP_aux (lexp, _)) env = match lexp with | LEXP_memory _ -> false | LEXP_id id @@ -2180,10 +2175,6 @@ let rec lexp_is_local (LEXP_aux (lexp, _)) env = match lexp with | LEXP_vector_range (lexp,_,_) | LEXP_field (lexp,_) -> lexp_is_local lexp env -let id_is_unbound id env = match Env.lookup_id id env with - | Unbound -> true - | _ -> false - let rec lexp_is_local_intro (LEXP_aux (lexp, _)) env = match lexp with | LEXP_memory _ -> false | LEXP_id id @@ -2197,21 +2188,49 @@ let lexp_is_effectful (LEXP_aux (_, (_, annot))) = match annot with | Some (_, _, eff) -> effectful_effs eff | _ -> false -let rec rewrite_local_lexp ((LEXP_aux(lexp,((l,_) as annot))) as le) = - match lexp with - | LEXP_id _ | LEXP_cast (_, _) | LEXP_tup _ -> (le, (fun exp -> exp)) - | LEXP_vector (lexp, e) -> - let (lhs, rhs) = rewrite_local_lexp lexp in - (lhs, (fun exp -> rhs (E_aux (E_vector_update (lexp_to_exp lexp, e, exp), annot)))) - | LEXP_vector_range (lexp, e1, e2) -> - let (lhs, rhs) = rewrite_local_lexp lexp in - (lhs, (fun exp -> rhs (E_aux (E_vector_update_subrange (lexp_to_exp lexp, e1, e2, exp), annot)))) - | LEXP_field (lexp, id) -> - let (lhs, rhs) = rewrite_local_lexp lexp in - let (LEXP_aux (_, recannot)) = lexp in - let field_update exp = FES_aux (FES_Fexps ([FE_aux (FE_Fexp (id, exp), annot)], false), annot) in - (lhs, (fun exp -> rhs (E_aux (E_record_update (lexp_to_exp lexp, field_update exp), recannot)))) - | _ -> raise (Reporting_basic.err_unreachable l ("Unsupported lexp: " ^ string_of_lexp le)) +let rec rewrite_lexp_to_rhs (do_rewrite : tannot lexp -> bool) ((LEXP_aux(lexp,((l,_) as annot))) as le) = + if do_rewrite le then + match lexp with + | LEXP_id _ | LEXP_cast (_, _) | LEXP_tup _ -> (le, (fun exp -> exp)) + | LEXP_vector (lexp, e) -> + let (lhs, rhs) = rewrite_lexp_to_rhs do_rewrite lexp in + (lhs, (fun exp -> rhs (E_aux (E_vector_update (lexp_to_exp lexp, e, exp), annot)))) + | LEXP_vector_range (lexp, e1, e2) -> + let (lhs, rhs) = rewrite_lexp_to_rhs do_rewrite lexp in + (lhs, (fun exp -> rhs (E_aux (E_vector_update_subrange (lexp_to_exp lexp, e1, e2, exp), annot)))) + | LEXP_field (lexp, id) -> + begin + let (lhs, rhs) = rewrite_lexp_to_rhs do_rewrite lexp in + let (LEXP_aux (_, lannot)) = lexp in + let env = env_of_annot lannot in + match Env.expand_synonyms env (typ_of_annot lannot) with + | Typ_aux (Typ_app (Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id regtyp_id, _)), _)]), _) + | Typ_aux (Typ_id regtyp_id, _) when Env.is_regtyp regtyp_id env -> + let base, top, ranges = Env.get_regtyp regtyp_id env in + let range, _ = + try List.find (fun (_, fid) -> Id.compare fid id = 0) ranges with + | Not_found -> + raise (Reporting_basic.err_typ l ("Field " ^ string_of_id id ^ " doesn't exist for register type " ^ string_of_id regtyp_id)) + in + let lexp_exp = E_aux (E_app (mk_id ("cast_" ^ string_of_id regtyp_id), [lexp_to_exp lexp]), (l, None)) in + let n, m = match range with + | BF_aux (BF_single n, _) -> n, n + | BF_aux (BF_range (n, m), _) -> n, m + | _ -> raise (Reporting_basic.err_unreachable l ("Unsupported lexp: " ^ string_of_lexp le)) in + let rhs' exp = rhs (E_aux (E_vector_update_subrange (lexp_exp, simple_num l n, simple_num l m, exp), lannot)) in + (lhs, rhs') + | Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _) when Env.is_record rectyp_id env -> + let field_update exp = FES_aux (FES_Fexps ([FE_aux (FE_Fexp (id, exp), annot)], false), annot) in + (lhs, (fun exp -> rhs (E_aux (E_record_update (lexp_to_exp lexp, field_update exp), lannot)))) + | _ -> raise (Reporting_basic.err_unreachable l ("Unsupported lexp: " ^ string_of_lexp le)) + end + | _ -> raise (Reporting_basic.err_unreachable l ("Unsupported lexp: " ^ string_of_lexp le)) + else (le, (fun exp -> exp)) + +let updates_vars exp = + let e_assign ((_, lexp), (u, exp)) = + (u || lexp_is_local lexp (env_of exp), E_assign (lexp, exp)) in + fst (fold_exp { (compute_exp_alg false (||)) with e_assign = e_assign } exp) (*Expects to be called after rewrite_defs; thus the following should not appear: internal_exp of any form @@ -2229,7 +2248,7 @@ let rewrite_exp_lift_assign_intro rewriters ((E_aux (exp,((l,_) as annot))) as f | [] -> [] | (E_aux(E_assign(le,e), ((l, Some (env,typ,eff)) as annot)) as exp)::exps when lexp_is_local_intro le env && not (lexp_is_effectful le) -> - let (le', re') = rewrite_local_lexp le in + let (le', re') = rewrite_lexp_to_rhs (fun _ -> true) le in let e' = re' (rewrite_base e) in let exps' = walker exps in let effects = union_eff_exps exps' in @@ -2279,13 +2298,15 @@ let rewrite_exp_lift_assign_intro rewriters ((E_aux (exp,((l,_) as annot))) as f (E_aux(E_if(c',t',e'),(Parse_ast.Generated l, annot))::exps',eff_union_exps (c'::t'::e'::exps')) new_vars)*) | e::exps -> (rewrite_rec e)::(walker exps) in - rewrap (E_block (walker exps)) + check_exp (env_of full_exp) + (E_aux (E_block (List.map strip_exp (walker exps)), (l, ()))) (typ_of full_exp) | E_assign(le,e) when lexp_is_local_intro le (env_of full_exp) && not (lexp_is_effectful le) -> - let (le', re') = rewrite_local_lexp le in + let (le', re') = rewrite_lexp_to_rhs (fun _ -> true) le in let e' = re' (rewrite_base e) in let block = annot_exp (E_block []) l (env_of full_exp) unit_typ in - fix_eff_exp (E_aux (E_internal_let(le', e', block), annot)) + check_exp (env_of full_exp) + (strip_exp (E_aux (E_internal_let(le', e', block), annot))) (typ_of full_exp) | _ -> rewrite_base full_exp let rewrite_lexp_lift_assign_intro rewriters ((LEXP_aux(lexp,annot)) as le) = @@ -2309,6 +2330,79 @@ let rewrite_defs_exp_lift_assign defs = rewrite_defs_base rewrite_def = rewrite_def; rewrite_defs = rewrite_defs_base} defs + +(* Rewrite assignments to register references into calls to a builtin function + "write_reg_ref" (in the Lem shallow embedding). For example, if GPR is a + vector of register references, then + GPR[i] := exp; + becomes + write_reg_ref (vector_access (GPR, i)) exp + *) +let rewrite_register_ref_writes (Defs defs) = + let (Defs write_reg_spec) = fst (check Env.empty (Defs (List.map gen_vs + [("write_reg_ref", "forall ('a : Type). (register('a), 'a) -> unit effect {wreg}")]))) in + let lexp_ref_exp (LEXP_aux (_, annot) as lexp) = + try + let exp = infer_exp (env_of_annot annot) (strip_exp (lexp_to_exp lexp)) in + if is_reftyp (typ_of exp) then Some exp else None + with | _ -> None in + let e_assign (lexp, exp) = + let (lhs, rhs) = rewrite_lexp_to_rhs (fun le -> lexp_ref_exp le = None) lexp in + match lexp_ref_exp lhs with + | Some (E_aux (_, annot) as lhs_exp) -> + let lhs = LEXP_aux (LEXP_memory (mk_id "write_reg_ref", [lhs_exp]), annot) in + E_assign (lhs, rhs exp) + | None -> E_assign (lexp, exp) in + let rewrite_exp _ = fold_exp { id_exp_alg with e_assign = e_assign } in + + let generate_field_accessors l env id n1 n2 fields = + let i1, i2 = match n1, n2 with + | Nexp_aux(Nexp_constant i1, _),Nexp_aux(Nexp_constant i2, _) -> i1, i2 + | _ -> raise (Reporting_basic.err_typ l + ("Non-constant indices in register type " ^ string_of_id id)) in + let dir_b = i1 < i2 in + let dir = (if dir_b then "true" else "false") in + let ord = Ord_aux ((if dir_b then Ord_inc else Ord_dec), Parse_ast.Unknown) in + let size = if dir_b then succ_big_int (sub_big_int i2 i1) else succ_big_int (sub_big_int i1 i2) in + let rtyp = mk_id_typ id in + let vtyp = vector_typ (nconstant i1) (nconstant size) ord bit_typ in + let accessors (fr, fid) = + let i, j = match fr with + | BF_aux (BF_single i, _) -> (i, i) + | BF_aux (BF_range (i, j), _) -> (i, j) + | _ -> raise (Reporting_basic.err_unreachable l "unsupported field type") in + let mk_num_exp i = mk_lit_exp (L_num i) in + let reg_pat, reg_env = bind_pat env (mk_pat (P_typ (rtyp, mk_pat (P_id (mk_id "reg"))))) rtyp in + let inferred_get = infer_exp reg_env (mk_exp (E_vector_subrange + (mk_exp (E_id (mk_id "reg")), mk_num_exp i, mk_num_exp j))) in + let ftyp = typ_of inferred_get in + let v_pat, v_env = bind_pat reg_env (mk_pat (P_typ (ftyp, mk_pat (P_id (mk_id "v"))))) ftyp in + let inferred_set = infer_exp v_env (mk_exp (E_vector_update_subrange + (mk_exp (E_id (mk_id "reg")), mk_num_exp i, mk_num_exp j, mk_exp (E_id (mk_id "v"))))) in + let set_args = P_aux (P_tup [reg_pat; v_pat], (l, Some (env, tuple_typ [rtyp; ftyp], no_effect))) in + let fsuffix = "_" ^ string_of_id id ^ "_" ^ string_of_id fid in + let rec_opt = Rec_aux (Rec_nonrec, l) in + let tannot ret_typ = Typ_annot_opt_aux (Typ_annot_opt_some (TypQ_aux (TypQ_tq [], l), ret_typ), l) in + let eff_opt = Effect_opt_aux (Effect_opt_pure, l) in + let mk_funcl id pat exp = FCL_aux (FCL_Funcl (mk_id id, pat, exp), (l, None)) in + let mk_fundef id pat exp ret_typ = DEF_fundef (FD_aux (FD_function (rec_opt, tannot ret_typ, eff_opt, [mk_funcl id pat exp]), (l, None))) in + [mk_fundef ("get" ^ fsuffix) reg_pat inferred_get ftyp; + mk_fundef ("set" ^ fsuffix) set_args inferred_set (typ_of inferred_set)] in + List.concat (List.map accessors fields) in + + let rewriters = { rewriters_base with rewrite_exp = rewrite_exp } in + let rec rewrite ds = match ds with + | (DEF_type (TD_aux (TD_register (id, n1, n2, fields), (l, Some (env, _, _)))) as d) :: ds -> + let (Defs d), env = check env (Defs [d]) in + d @ (generate_field_accessors l env id n1 n2 fields) @ rewrite ds + | d::ds -> (rewriters.rewrite_def rewriters d)::(rewrite ds) + | [] -> [] in + Defs (rewrite (write_reg_spec @ defs)) + + (* rewrite_defs_base { rewriters_base with rewrite_exp = rewrite_exp } + (Defs (write_reg_spec @ defs)) *) + + (*let rewrite_exp_separate_ints rewriters ((E_aux (exp,((l,_) as annot))) as full_exp) = (*let tparms,t,tag,nexps,eff,cum_eff,bounds = match annot with | Base((tparms,t),tag,nexps,eff,cum_eff,bounds) -> tparms,t,tag,nexps,eff,cum_eff,bounds @@ -2356,7 +2450,14 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base rewrite_def = rewrite_def; rewrite_defs = rewrite_defs_base} defs*) -let rewrite_defs_early_return = +(* Remove redundant return statements, and translate remaining ones into an + (effectful) call to builtin function "early_return" (in the Lem shallow + embedding). + + TODO: Maybe separate generic removal of redundant returns, and Lem-specific + rewriting of early returns + *) +let rewrite_defs_early_return (Defs defs) = let is_return (E_aux (exp, _)) = match exp with | E_return _ -> true | _ -> false in @@ -2394,13 +2495,15 @@ let rewrite_defs_early_return = else E_case (e, pes) in let e_aux (exp, (l, annot)) = - let full_exp = fix_eff_exp (E_aux (exp, (l, annot))) in - match annot with - | Some (env, typ, eff) when is_return full_exp -> + let full_exp = propagate_exp_effect (E_aux (exp, (l, annot))) in + let env = env_of full_exp in + match full_exp with + | E_aux (E_return exp, (l, Some (env, typ, eff))) -> (* Add escape effect annotation, since we use the exception mechanism of the state monad to implement early return in the Lem backend *) let annot' = Some (env, typ, union_effects eff (mk_effect [BE_escape])) in - E_aux (exp, (l, annot')) + let exp' = annot_exp (E_cast (typ_of exp, exp)) l env (typ_of exp) in + E_aux (E_app (mk_id "early_return", [exp']), (l, annot')) | _ -> full_exp in let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pat, exp), a)) = @@ -2423,7 +2526,12 @@ let rewrite_defs_early_return = FD_aux (FD_function (rec_opt, tannot_opt, effect_opt, List.map (rewrite_funcl_early_return rewriters) funcls), a) in - rewrite_defs_base { rewriters_base with rewrite_fun = rewrite_fun_early_return } + let (Defs early_ret_spec) = fst (check Env.empty (Defs [gen_vs + ("early_return", "forall ('a : Type) ('b : Type). 'a -> 'b effect {escape}")])) in + + rewrite_defs_base + { rewriters_base with rewrite_fun = rewrite_fun_early_return } + (Defs (early_ret_spec @ defs)) (* Propagate effects of functions, if effect checking and propagation have not been performed already by the type checker. *) @@ -2497,6 +2605,11 @@ let rewrite_fix_val_specs (Defs defs) = Rec_aux (Rec_rec, Parse_ast.Unknown) else recopt in + let tannotopt = match tannotopt, funcls with + | Typ_annot_opt_aux (Typ_annot_opt_some (typq, typ), l), + FCL_aux (FCL_Funcl (_, _, exp), _) :: _ -> + Typ_annot_opt_aux (Typ_annot_opt_some (typq, rewrite_typ_nexp_ids (env_of exp) typ), l) + | _ -> tannotopt in (val_specs, FD_aux (FD_function (recopt, tannotopt, effopt, funcls), a)) in let rec rewrite_fundefs (val_specs, fundefs) = @@ -2784,7 +2897,7 @@ let rewrite_simple_assignments defs = let env = env_of_annot annot in match e_aux with | E_assign (lexp, exp) -> - let (lexp, rhs) = rewrite_local_lexp lexp in + let (lexp, rhs) = rewrite_lexp_to_rhs (fun _ -> true) lexp in let assign = mk_exp (E_assign (strip_lexp lexp, strip_exp (rhs exp))) in check_exp env assign unit_typ | _ -> E_aux (e_aux, annot) @@ -2801,13 +2914,9 @@ let rewrite_defs_remove_blocks = let l = get_loc_exp v in let env = env_of v in let typ = typ_of v in - annot_exp (E_let (annot_letbind (P_wild, v) l env typ, body)) l env (typ_of body) in - (* let pat = annot_pat P_wild l env typ in - let (E_aux (_,(l,tannot))) = v in - let annot_pat = (simple_annot l (typ_of v)) in - let annot_lb = (gen_loc l, tannot) in - let annot_let = (gen_loc l, Some (env_of body, typ_of body, union_eff_exps [v;body])) in - E_aux (E_let (LB_aux (LB_val (P_aux (P_wild,annot_pat),v),annot_lb),body),annot_let) in *) + let wild = P_typ (typ, annot_pat P_wild l env typ) in + let e_aux = E_let (annot_letbind (wild, v) l env typ, body) in + propagate_exp_effect (annot_exp e_aux l env (typ_of body)) in let rec f l = function | [] -> E_aux (E_lit (L_aux (L_unit,gen_loc l)), (simple_annot l unit_typ)) @@ -2839,11 +2948,13 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = | Some (env, Typ_aux (Typ_id tid, _), eff) when string_of_id tid = "unit" -> let body = body (annot_exp (E_lit (mk_lit L_unit)) l env unit_typ) in let body_typ = try typ_of body with _ -> unit_typ in - let lb = annot_letbind (P_wild, v) l env unit_typ in + let wild = P_typ (typ_of v, annot_pat P_wild l env (typ_of v)) in + let lb = annot_letbind (wild, v) l env unit_typ in propagate_exp_effect (annot_exp (E_let (lb, body)) l env body_typ) | Some (env, typ, eff) -> let id = fresh_id "w__" l in - let lb = annot_letbind (P_id id, v) l env typ in + let pat = P_typ (typ_of v, annot_pat (P_id id) l env (typ_of v)) in + let lb = annot_letbind (pat, v) l env typ in let body = body (annot_exp (E_id id) l env typ) in propagate_exp_effect (annot_exp (E_let (lb, body)) l env (typ_of body)) | None -> @@ -3058,8 +3169,8 @@ let rewrite_defs_letbind_effects = k (rewrap (E_assign (lexp,exp1))))) | E_exit exp' -> k (E_aux (E_exit (n_exp_term (effectful exp') exp'),annot)) | E_assert (exp1,exp2) -> - n_exp exp1 (fun exp1 -> - n_exp exp2 (fun exp2 -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> k (rewrap (E_assert (exp1,exp2))))) | E_internal_cast (annot',exp') -> n_exp_name exp' (fun exp' -> @@ -3114,7 +3225,7 @@ let rewrite_defs_letbind_effects = ; rewrite_defs = rewrite_defs_base } -let rewrite_defs_effectful_let_expressions = +let rewrite_defs_internal_lets = let rec pat_of_local_lexp (LEXP_aux (lexp, ((l, _) as annot))) = match lexp with | LEXP_id id -> P_aux (P_id id, annot) @@ -3124,26 +3235,31 @@ let rewrite_defs_effectful_let_expressions = let e_let (lb,body) = match lb with - | LB_aux (LB_val (P_aux (P_wild, _), E_aux (E_assign ((LEXP_aux (_, annot) as le), exp), _)), _) + | LB_aux (LB_val (P_aux ((P_wild | P_typ (_, P_aux (P_wild, _))), _), + E_aux (E_assign ((LEXP_aux (_, annot) as le), exp), (l, _))), _) when lexp_is_local le (env_of_annot annot) && not (lexp_is_effectful le) -> - (* Rewrite assignments to local variables into let bindings *) - let (lhs, rhs) = rewrite_local_lexp le in - E_let (LB_aux (LB_val (pat_of_local_lexp lhs, rhs exp), annot), body) + (* Rewrite assignments to local variables into let bindings *) + let (lhs, rhs) = rewrite_lexp_to_rhs (fun _ -> true) le in + let (LEXP_aux (_, lannot)) = lhs in + let ltyp = typ_of_annot lannot in + let rhs = annot_exp (E_cast (ltyp, rhs exp)) l (env_of_annot lannot) ltyp in + E_let (LB_aux (LB_val (pat_of_local_lexp lhs, rhs), annot), body) | LB_aux (LB_val (pat,exp'),annot') -> if effectful exp' then E_internal_plet (pat,exp',body) else E_let (lb,body) in let e_internal_let = fun (lexp,exp1,exp2) -> - match lexp with - | LEXP_aux (LEXP_id id,annot) - | LEXP_aux (LEXP_cast (_,id),annot) -> - if effectful exp1 then - E_internal_plet (P_aux (P_id id,annot),exp1,exp2) - else - let lb = LB_aux (LB_val (P_aux (P_id id,annot), exp1), annot) in - E_let (lb, exp2) + let paux, annot = match lexp with + | LEXP_aux (LEXP_id id, annot) -> + (P_id id, annot) + | LEXP_aux (LEXP_cast (typ, id), annot) -> + (P_typ (typ, P_aux (P_id id, annot)), annot) | _ -> failwith "E_internal_let with unexpected lexp" in + if effectful exp1 then + E_internal_plet (P_aux (paux, annot), exp1, exp2) + else + E_let (LB_aux (LB_val (P_aux (paux, annot), exp1), annot), exp2) in let alg = { id_exp_alg with e_let = e_let; e_internal_let = e_internal_let } in rewrite_defs_base @@ -3170,56 +3286,31 @@ let eqidtyp (id1,_) (id2,_) = name1 = name2 let find_introduced_vars exp = - let e_aux ((ids,e_aux),annot) = - let ids = match e_aux, annot with - | E_internal_let (LEXP_aux (LEXP_id id, _), _, _), (_, Some (env, _, _)) - | E_internal_let (LEXP_aux (LEXP_cast (_, id), _), _, _), (_, Some (env, _, _)) - when id_is_unbound id env -> IdSet.add id ids - | _ -> ids in - (ids, E_aux (e_aux, annot)) in + let lEXP_aux ((ids, lexp), annot) = + let ids = match lexp with + | LEXP_id id | LEXP_cast (_, id) + when id_is_unbound id (env_of_annot annot) -> IdSet.add id ids + | _ -> ids in + (ids, LEXP_aux (lexp, annot)) in fst (fold_exp - { (compute_exp_alg IdSet.empty IdSet.union) with e_aux = e_aux } exp) + { (compute_exp_alg IdSet.empty IdSet.union) with lEXP_aux = lEXP_aux } exp) let find_updated_vars exp = let intros = find_introduced_vars exp in - let e_aux ((ids,e_aux),annot) = - let ids = match e_aux, annot with - | E_assign (LEXP_aux (LEXP_id id, _), _), (_, Some (env, _, _)) - | E_assign (LEXP_aux (LEXP_cast (_, id), _), _), (_, Some (env, _, _)) - when id_is_local_var id env && not (IdSet.mem id intros) -> - (id, annot) :: ids - | _ -> ids in - (ids, E_aux (e_aux, annot)) in + let lEXP_aux ((ids, lexp), annot) = + let ids = match lexp with + | LEXP_id id | LEXP_cast (_, id) + when id_is_local_var id (env_of_annot annot) && not (IdSet.mem id intros) -> + (id, annot) :: ids + | _ -> ids in + (ids, LEXP_aux (lexp, annot)) in dedup eqidtyp (fst (fold_exp - { (compute_exp_alg [] (@)) with e_aux = e_aux } exp)) + { (compute_exp_alg [] (@)) with lEXP_aux = lEXP_aux } exp)) let swaptyp typ (l,tannot) = match tannot with | Some (env, typ', eff) -> (l, Some (env, typ, eff)) | _ -> raise (Reporting_basic.err_unreachable l "swaptyp called with empty type annotation") -let mktup l es = - match es with - | [] -> annot_exp (E_lit (mk_lit L_unit)) (gen_loc l) Env.empty unit_typ - | [e] -> e - | e :: _ -> - let typ = mk_typ (Typ_tup (List.map typ_of es)) in - propagate_exp_effect (annot_exp (E_tuple es) (gen_loc l) (env_of e) typ) - -let mktup_pat l es = - match es with - | [] -> annot_pat P_wild (gen_loc l) Env.empty unit_typ - | [E_aux (E_id id,_) as exp] -> - annot_pat (P_id id) (gen_loc l) (env_of exp) (typ_of exp) - | exp :: _ -> - let typ = mk_typ (Typ_tup (List.map typ_of es)) in - let pats = List.map (function - | (E_aux (E_id id,_) as exp) -> - annot_pat (P_id id) (gen_loc l) (env_of exp) (typ_of exp) - | exp -> - annot_pat P_wild (gen_loc l) (env_of exp) (typ_of exp)) es in - annot_pat (P_tup pats) (gen_loc l) (env_of exp) typ - - type 'a updated_term = | Added_vars of 'a exp * 'a pat | Same_vars of 'a exp @@ -3254,9 +3345,33 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = else let typ' = Typ_aux (Typ_tup [typ_of exp;typ_of vars], gen_loc l) in E_aux (E_tuple [exp;vars],swaptyp typ' annot) in - - let rewrite (E_aux (expaux,((el,_) as annot))) (P_aux (_,(pl,pannot)) as pat) = - let overwrite = match typ_of_annot annot with + + let mk_varstup l es = + let exp_to_pat (E_aux (eaux, annot) as exp) = match eaux with + | E_lit lit -> + P_aux (P_lit lit, annot) + | E_id id -> + annot_pat (P_id id) l (env_of exp) (typ_of exp) + | _ -> raise (Reporting_basic.err_unreachable l + ("Failed to extract pattern from expression " ^ string_of_exp exp)) in + match es with + | [] -> + annot_exp (E_lit (mk_lit L_unit)) (gen_loc l) Env.empty unit_typ, + annot_pat P_wild (gen_loc l) Env.empty unit_typ + | [e] -> + let e = infer_exp (env_of e) (strip_exp e) in + e, annot_pat (P_typ (typ_of e, exp_to_pat e)) l (env_of e) (typ_of e) + | e :: _ -> + let infer_e e = infer_exp (env_of e) (strip_exp e) in + let es = List.map infer_e es in + let pats = List.map exp_to_pat es in + let typ = tuple_typ (List.map typ_of es) in + annot_exp (E_tuple es) l (env_of e) typ, + annot_pat (P_typ (typ, annot_pat (P_tup pats) l (env_of e) typ)) l (env_of e) typ in + + let rewrite (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (_,(pl,pannot)) as pat) = + let env = env_of_annot annot in + let overwrite = match typ_of full_exp with | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> true | _ -> false in match expaux with @@ -3271,62 +3386,44 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = expects. In (Lem) pretty-printing, this turned into an anonymous function and passed to foreach*. *) let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars exp4) in - let vartuple = mktup el vars in - let exp4 = rewrite_var_updates (add_vars overwrite exp4 vartuple) in - let (E_aux (_,(_,annot4))) = exp4 in - let fname = match effectful exp4,order with - | false, Ord_aux (Ord_inc,_) -> "foreach_inc" - | false, Ord_aux (Ord_dec,_) -> "foreach_dec" - | true, Ord_aux (Ord_inc,_) -> "foreachM_inc" - | true, Ord_aux (Ord_dec,_) -> "foreachM_dec" - | _ -> raise (Reporting_basic.err_unreachable el - "Could not determine foreach combinator") in - let funcl = Id_aux (Id fname,gen_loc el) in - let loopvar = - (* Don't bother with creating a range type annotation, since the - Lem pretty-printing does not use it. *) - (* let (bf,tf) = match typ_of exp1 with - | {t = Tapp ("atom",[TA_nexp f])} -> (TA_nexp f,TA_nexp f) - | {t = Tapp ("reg", [TA_typ {t = Tapp ("atom",[TA_nexp f])}])} -> (TA_nexp f,TA_nexp f) - | {t = Tapp ("range",[TA_nexp bf;TA_nexp tf])} -> (TA_nexp bf,TA_nexp tf) - | {t = Tapp ("reg", [TA_typ {t = Tapp ("range",[TA_nexp bf;TA_nexp tf])}])} -> (TA_nexp bf,TA_nexp tf) - | {t = Tapp (name,_)} -> failwith (name ^ " shouldn't be here") in - let (bt,tt) = match typ_of exp2 with - | {t = Tapp ("atom",[TA_nexp t])} -> (TA_nexp t,TA_nexp t) - | {t = Tapp ("atom",[TA_typ {t = Tapp ("atom", [TA_nexp t])}])} -> (TA_nexp t,TA_nexp t) - | {t = Tapp ("range",[TA_nexp bt;TA_nexp tt])} -> (TA_nexp bt,TA_nexp tt) - | {t = Tapp ("atom",[TA_typ {t = Tapp ("range",[TA_nexp bt;TA_nexp tt])}])} -> (TA_nexp bt,TA_nexp tt) - | {t = Tapp (name,_)} -> failwith (name ^ " shouldn't be here") in - let t = {t = Tapp ("range",match order with - | Ord_aux (Ord_inc,_) -> [bf;tt] - | Ord_aux (Ord_dec,_) -> [tf;bt])} in *) - annot_exp (E_id id) l env int_typ in - let v = E_aux (E_app (funcl,[loopvar;mktup el [exp1;exp2;exp3];exp4;vartuple]), - (gen_loc el, annot4)) in + let varstuple, varspat = mk_varstup el vars in + let varstyp = typ_of varstuple in + let exp4 = rewrite_var_updates (add_vars overwrite exp4 varstuple) in + let ord_exp, lower, upper = match destruct_range (typ_of exp1), destruct_range (typ_of exp2) with + | None, _ | _, None -> + raise (Reporting_basic.err_unreachable el "Could not determine loop bounds") + | Some (l1, u1), Some (l2, u2) -> + if is_order_inc order + then (annot_exp (E_lit (mk_lit L_true)) el env bool_typ, l1, u2) + else (annot_exp (E_lit (mk_lit L_false)) el env bool_typ, l2, u1) in + let lvar_kid = mk_kid ("loop_" ^ string_of_id id) in + let lvar_nc = nc_and (nc_lteq lower (nvar lvar_kid)) (nc_lteq (nvar lvar_kid) upper) in + let lvar_typ = mk_typ (Typ_exist ([lvar_kid], lvar_nc, atom_typ (nvar lvar_kid))) in + let lvar_pat = P_typ (lvar_typ, annot_pat (P_var ( + annot_pat (P_id id) el env (atom_typ (nvar lvar_kid)), + lvar_kid)) el env lvar_typ) in + let lb = annot_letbind (lvar_pat, exp1) el env lvar_typ in + let body = annot_exp (E_let (lb, exp4)) el env (typ_of exp4) in + let v = annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; varstuple; body])) el env (typ_of body) in let pat = - if overwrite then mktup_pat el vars - else annot_pat (P_tup [pat; mktup_pat pl vars]) pl env (typ_of v) in + if overwrite then varspat + else annot_pat (P_tup [pat; varspat]) pl env (typ_of v) in Added_vars (v,pat) | E_loop(loop,cond,body) -> let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars body) in - let vartuple = mktup el vars in - (* let cond = rewrite_var_updates (add_vars false cond vartuple) in *) - let body = rewrite_var_updates (add_vars overwrite body vartuple) in + let varstuple, varspat = mk_varstup el vars in + let varstyp = typ_of varstuple in + (* let cond = rewrite_var_updates (add_vars false cond varstuple) in *) + let body = rewrite_var_updates (add_vars overwrite body varstuple) in let (E_aux (_,(_,bannot))) = body in - let fname = match loop, effectful cond, effectful body with - | While, false, false -> "while_PP" - | While, false, true -> "while_PM" - | While, true, false -> "while_MP" - | While, true, true -> "while_MM" - | Until, false, false -> "until_PP" - | Until, false, true -> "until_PM" - | Until, true, false -> "until_MP" - | Until, true, true -> "until_MM" in + let fname = match loop with + | While -> "while" + | Until -> "until" in let funcl = Id_aux (Id fname,gen_loc el) in - let v = E_aux (E_app (funcl,[cond;body;vartuple]), (gen_loc el, bannot)) in + let v = E_aux (E_app (funcl,[cond;varstuple;body]), (gen_loc el, bannot)) in let pat = - if overwrite then mktup_pat el vars - else annot_pat (P_tup [pat; mktup_pat pl vars]) pl env (typ_of v) in + if overwrite then varspat + else annot_pat (P_tup [pat; varspat]) pl env (typ_of v) in Added_vars (v,pat) | E_if (c,e1,e2) -> let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) @@ -3334,17 +3431,18 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = if vars = [] then (Same_vars (E_aux (E_if (c,rewrite_var_updates e1,rewrite_var_updates e2),annot))) else - let vartuple = mktup el vars in - let e1 = rewrite_var_updates (add_vars overwrite e1 vartuple) in - let e2 = rewrite_var_updates (add_vars overwrite e2 vartuple) in + let varstuple, varspat = mk_varstup el vars in + let varstyp = typ_of varstuple in + let e1 = rewrite_var_updates (add_vars overwrite e1 varstuple) in + let e2 = rewrite_var_updates (add_vars overwrite e2 varstuple) in (* after rewrite_defs_letbind_effects c has no variable updates *) let env = env_of_annot annot in let typ = typ_of e1 in let eff = union_eff_exps [e1;e2] in let v = E_aux (E_if (c,e1,e2), (gen_loc el, Some (env, typ, eff))) in let pat = - if overwrite then mktup_pat el vars - else annot_pat (P_tup [pat; mktup_pat pl vars]) pl env (typ_of v) in + if overwrite then varspat + else annot_pat (P_tup [pat; varspat]) pl env (typ_of v) in Added_vars (v,pat) | E_case (e1,ps) -> (* after rewrite_defs_letbind_effects e1 needs no rewriting *) @@ -3361,10 +3459,11 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = Pat_aux (Pat_when (p,g,rewrite_var_updates e),a)) ps in Same_vars (E_aux (E_case (e1,ps),annot)) else - let vartuple = mktup el vars in + let varstuple, varspat = mk_varstup el vars in + let varstyp = typ_of varstuple in let rewrite_pexp (Pat_aux (pexp, (l, _))) = match pexp with | Pat_exp (pat, exp) -> - let exp = rewrite_var_updates (add_vars overwrite exp vartuple) in + let exp = rewrite_var_updates (add_vars overwrite exp varstuple) in let pannot = (l, Some (env_of exp, typ_of exp, effect_of exp)) in Pat_aux (Pat_exp (pat, exp), pannot) | Pat_when _ -> @@ -3374,36 +3473,26 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | Pat_aux ((Pat_exp (_,first)|Pat_when (_,_,first)),_) :: _ -> typ_of first | _ -> unit_typ in let v = propagate_exp_effect (annot_exp (E_case (e1, List.map rewrite_pexp ps)) pl env typ) in - (* let (ps,typ,effs) = - let f (acc,typ,effs) (Pat_aux (Pat_exp (p,e),pannot)) = - let etyp = typ_of e in - let () = assert (string_of_typ etyp = string_of_typ typ) in - let e = rewrite_var_updates (add_vars overwrite e vartuple) in - let pannot = simple_annot pl (typ_of e) in - let effs = union_effects effs (effect_of e) in - let pat' = Pat_aux (Pat_exp (p,e),pannot) in - (acc @ [pat'],typ,effs) in - List.fold_left f ([],typ,no_effect) ps in - let v = E_aux (E_case (e1,ps), (gen_loc pl, Some (env_of_annot annot, typ, effs))) in *) let pat = - if overwrite then mktup_pat el vars - else annot_pat (P_tup [pat; mktup_pat pl vars]) pl env (typ_of v) in + if overwrite then varspat + else annot_pat (P_tup [pat; varspat]) pl env (typ_of v) in Added_vars (v,pat) | E_assign (lexp,vexp) -> - let effs = match effect_of_annot (snd annot) with - | Effect_aux (Effect_set effs, _) -> effs + let mk_id_pat id = match Env.lookup_id id env with + | Local (_, typ) -> + annot_pat (P_typ (typ, annot_pat (P_id id) pl env typ)) pl env typ | _ -> - raise (Reporting_basic.err_unreachable l - "assignment without effects annotation") in + raise (Reporting_basic.err_unreachable pl + ("Failed to look up type of variable " ^ string_of_id id)) in if effectful exp then Same_vars (E_aux (E_assign (lexp,vexp),annot)) else (match lexp with | LEXP_aux (LEXP_id id,annot) -> let pat = annot_pat (P_id id) pl env (typ_of vexp) in - Added_vars (vexp,pat) - | LEXP_aux (LEXP_cast (_,id),annot) -> - let pat = annot_pat (P_id id) pl env (typ_of vexp) in + Added_vars (vexp, mk_id_pat id) + | LEXP_aux (LEXP_cast (typ,id),annot) -> + let pat = annot_pat (P_typ (typ, annot_pat (P_id id) pl env (typ_of vexp))) pl env typ in Added_vars (vexp,pat) | LEXP_aux (LEXP_vector (LEXP_aux (LEXP_id id,((l2,_) as annot2)),i),((l1,_) as annot)) -> let eid = annot_exp (E_id id) l2 env (typ_of_annot annot2) in @@ -3433,27 +3522,17 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = propagate_exp_effect (annot_exp (E_let (lb, body)) l env (typ_of body)) | E_internal_let (lexp,v,body) -> (* Rewrite E_internal_let into E_let and call recursively *) - let id = match lexp with - | LEXP_aux (LEXP_id id,_) -> id - | LEXP_aux (LEXP_cast (_,id),_) -> id + let paux, typ = match lexp with + | LEXP_aux (LEXP_id id, _) -> + P_id id, typ_of v + | LEXP_aux (LEXP_cast (typ, id), _) -> + P_typ (typ, annot_pat (P_id id) l env (typ_of v)), typ | _ -> raise (Reporting_basic.err_unreachable l "E_internal_let with a lexp that is not a variable") in - let pat = annot_pat (P_id id) l env (typ_of v) in - let lb = annot_letbind (P_id id, v) l env (typ_of v) in + let lb = annot_letbind (paux, v) l env typ in let exp = propagate_exp_effect (annot_exp (E_let (lb, body)) l env (typ_of body)) in rewrite_var_updates exp - (* let env = env_of_annot annot in - let vtyp = typ_of v in - let veff = effect_of v in - let bodyenv = env_of body in - let bodytyp = typ_of body in - let bodyeff = effect_of body in - let pat = P_aux (P_id id, (simple_annot l vtyp)) in - let lbannot = (gen_loc l, Some (env, vtyp, veff)) in - let lb = LB_aux (LB_val (pat,v),lbannot) in - let exp = E_aux (E_let (lb,body),(gen_loc l, Some (bodyenv, bodytyp, union_effects veff bodyeff))) in - rewrite_var_updates exp *) | E_internal_plet (pat,v,body) -> failwith "rewrite_var_updates: E_internal_plet shouldn't be introduced yet" (* There are no expressions that have effects or variable updates in @@ -3497,17 +3576,17 @@ let rewrite_defs_remove_superfluous_letbinds = | E_let (lb,exp2) -> begin match lb,exp2 with (* 'let x = EXP1 in x' can be replaced with 'EXP1' *) - | LB_aux (LB_val (P_aux (P_id (Id_aux (id,_)),_),exp1),_), - E_aux (E_id (Id_aux (id',_)),_) - | LB_aux (LB_val (P_aux (P_id (Id_aux (id,_)),_),exp1),_), - E_aux (E_cast (_,E_aux (E_id (Id_aux (id',_)),_)),_) - when id = id' -> + | LB_aux (LB_val (P_aux (P_id id, _), exp1), _), + E_aux (E_id id', _) + | LB_aux (LB_val (P_aux (P_id id, _), exp1), _), + E_aux (E_cast (_,E_aux (E_id id', _)), _) + when Id.compare id id' == 0 && id_is_unbound id (env_of_annot annot) -> exp1 (* "let x = EXP1 in return x" can be replaced with 'return (EXP1)', at least when EXP1 is 'small' enough *) - | LB_aux (LB_val (P_aux (P_id (Id_aux (id,_)),_),exp1),_), - E_aux (E_internal_return (E_aux (E_id (Id_aux (id',_)),_)),_) - when id = id' && small exp1 -> + | LB_aux (LB_val (P_aux (P_id id, _), exp1), _), + E_aux (E_internal_return (E_aux (E_id id', _)), _) + when Id.compare id id' == 0 && small exp1 && id_is_unbound id (env_of_annot annot) -> let (E_aux (_,e1annot)) = exp1 in E_aux (E_internal_return (exp1),e1annot) | _ -> E_aux (exp,annot) @@ -3532,21 +3611,44 @@ let rewrite_defs_remove_superfluous_returns = | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> true | _ -> false in + let untyp_pat = function + | P_aux (P_typ (typ, pat), _) -> pat, Some typ + | pat -> pat, None in + + let uncast_internal_return = function + | E_aux (E_internal_return (E_aux (E_cast (typ, exp), _)), a) -> + E_aux (E_internal_return exp, a), Some typ + | exp -> exp, None in + let e_aux (exp,annot) = match exp with - | E_internal_plet (pat,exp1,exp2) when effectful exp1 -> - begin match pat,exp2 with - | P_aux (P_lit (L_aux (lit,_)),_), - E_aux (E_internal_return (E_aux (E_lit (L_aux (lit',_)),_)),_) - when lit = lit' -> - exp1 - | P_aux (P_wild,pannot), - E_aux (E_internal_return (E_aux (E_lit (L_aux (L_unit,_)),_)),_) - when has_unittype exp1 -> - exp1 - | P_aux (P_id (Id_aux (id,_)),_), - E_aux (E_internal_return (E_aux (E_id (Id_aux (id',_)),_)),_) - when id = id' -> - exp1 + | E_let (LB_aux (LB_val (pat, exp1), _), exp2) + | E_internal_plet (pat, exp1, exp2) + when effectful exp1 -> + begin match untyp_pat pat, uncast_internal_return exp2 with + | (P_aux (P_lit (L_aux (lit,_)),_), ptyp), + (E_aux (E_internal_return (E_aux (E_lit (L_aux (lit',_)),_)), a), etyp) + when lit = lit' -> + begin + match ptyp, etyp with + | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp1), a) + | None, None -> exp1 + end + | (P_aux (P_wild,pannot), ptyp), + (E_aux (E_internal_return (E_aux (E_lit (L_aux (L_unit,_)),_)), a), etyp) + when has_unittype exp1 -> + begin + match ptyp, etyp with + | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp1), a) + | None, None -> exp1 + end + | (P_aux (P_id id,_), ptyp), + (E_aux (E_internal_return (E_aux (E_id id',_)), a), etyp) + when Id.compare id id' == 0 && id_is_unbound id (env_of_annot annot) -> + begin + match ptyp, etyp with + | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp1), a) + | None, None -> exp1 + end | _ -> E_aux (exp,annot) end | _ -> E_aux (exp,annot) in @@ -3563,7 +3665,11 @@ let rewrite_defs_remove_superfluous_returns = } -let rewrite_defs_remove_e_assign = +let rewrite_defs_remove_e_assign (Defs defs) = + let (Defs loop_specs) = fst (check Env.empty (Defs (List.map gen_vs + [("foreach", "forall ('vars : Type). (int, int, int, bool, 'vars, 'vars) -> 'vars"); + ("while", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars"); + ("until", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars")]))) in let rewrite_exp _ e = replace_memwrite_e_assign (remove_reference_types (rewrite_var_updates e)) in rewrite_defs_base @@ -3574,32 +3680,35 @@ let rewrite_defs_remove_e_assign = ; rewrite_fun = rewrite_fun ; rewrite_def = rewrite_def ; rewrite_defs = rewrite_defs_base - } + } (Defs (loop_specs @ defs)) let recheck_defs defs = fst (check initial_env defs) let rewrite_defs_lem = [ - ("top_sort_defs", top_sort_defs); ("tuple_vector_assignments", rewrite_tuple_vector_assignments); ("tuple_assignments", rewrite_tuple_assignments); (* ("simple_assignments", rewrite_simple_assignments); *) - ("constraint", rewrite_constraint); - ("trivial_sizeof", rewrite_trivial_sizeof); - ("sizeof", rewrite_sizeof); ("remove_vector_concat", rewrite_defs_remove_vector_concat); ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats); ("guarded_pats", rewrite_defs_guarded_pats); - (* ("recheck_defs", recheck_defs); *) + ("exp_lift_assign", rewrite_defs_exp_lift_assign); + ("register_ref_writes", rewrite_register_ref_writes); + ("recheck_defs", recheck_defs); + (* ("constraint", rewrite_constraint); *) + (* ("remove_assert", rewrite_defs_remove_assert); *) + ("top_sort_defs", top_sort_defs); + ("trivial_sizeof", rewrite_trivial_sizeof); + ("sizeof", rewrite_sizeof); ("early_return", rewrite_defs_early_return); ("nexp_ids", rewrite_defs_nexp_ids); ("fix_val_specs", rewrite_fix_val_specs); - ("exp_lift_assign", rewrite_defs_exp_lift_assign); ("remove_blocks", rewrite_defs_remove_blocks); ("letbind_effects", rewrite_defs_letbind_effects); ("remove_e_assign", rewrite_defs_remove_e_assign); - ("effectful_let_expressions", rewrite_defs_effectful_let_expressions); + ("internal_lets", rewrite_defs_internal_lets); ("remove_superfluous_letbinds", rewrite_defs_remove_superfluous_letbinds); - ("remove_superfluous_returns", rewrite_defs_remove_superfluous_returns) + ("remove_superfluous_returns", rewrite_defs_remove_superfluous_returns); + ("recheck_defs", recheck_defs) ] let rewrite_defs_ocaml = [ |
