summaryrefslogtreecommitdiff
path: root/src/rewriter.ml
diff options
context:
space:
mode:
authorAlasdair Armstrong2017-08-10 23:28:43 +0100
committerAlasdair Armstrong2017-08-10 23:28:43 +0100
commit01f382196302e378c377c96bf249236e06d7291c (patch)
tree69bfa09d2ec3d8011740f3f322e37f8112c5e0a9 /src/rewriter.ml
parentde787176067f4569af1ed4133b0edf72d4dcd4a1 (diff)
parent588c45e84642425fe9530f4ef6a44753cc54a0f8 (diff)
Merge remote-tracking branch 'origin/sail_new_tc' into experiments
Conflicts: src/pretty_print_common.ml
Diffstat (limited to 'src/rewriter.ml')
-rw-r--r--src/rewriter.ml205
1 files changed, 172 insertions, 33 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 8cf682bf..8da8aacf 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -66,7 +66,13 @@ let effect_of_fexp (FE_aux (_,(_,a))) = effect_of_annot a
let effect_of_fexps (FES_aux (FES_Fexps (fexps,_),_)) =
List.fold_left union_effects no_effect (List.map effect_of_fexp fexps)
let effect_of_opt_default (Def_val_aux (_,(_,a))) = effect_of_annot a
-let effect_of_pexp (Pat_aux (_,(_,a))) = effect_of_annot a
+(* The typechecker does not seem to annotate pexps themselves *)
+let effect_of_pexp (Pat_aux (pexp,(_,a))) = match a with
+ | Some (_, _, eff) -> eff
+ | None ->
+ (match pexp with
+ | Pat_exp (_, e) -> effect_of e
+ | Pat_when (_, g, e) -> union_effects (effect_of g) (effect_of e))
let effect_of_lb (LB_aux (_,(_,a))) = effect_of_annot a
let get_loc_exp (E_aux (_,(l,_))) = l
@@ -143,8 +149,8 @@ let fix_eff_exp (E_aux (e,((l,_) as annot))) = match snd annot with
List.fold_left union_effects (effect_of e) (List.map effect_of_pexp pexps)
| E_let (lb,e) -> union_effects (effect_of_lb lb) (effect_of e)
| E_assign (lexp,e) -> union_effects (effect_of_lexp lexp) (effect_of e)
- | E_exit e -> effect_of e
- | E_return e -> effect_of e
+ | E_exit e -> union_effects eff (effect_of e)
+ | E_return e -> union_effects eff (effect_of e)
| E_sizeof _ | E_sizeof_internal _ | E_constraint _ -> no_effect
| E_assert (c,m) -> eff
| E_comment _ | E_comment_struc _ -> no_effect
@@ -1048,6 +1054,7 @@ let rewrite_sizeof (Defs defs) =
Id_aux (Id op, Parse_ast.Unknown),
E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2))
) in
+ let (Nexp_aux (nexp, l) as nexp_aux) = simplify_nexp nexp_aux in
(match nexp with
| Nexp_constant i -> E_lit (L_aux (L_num i, l))
| Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2
@@ -1058,14 +1065,15 @@ let rewrite_sizeof (Defs defs) =
(* Rewrite calls to functions which have had parameters added to pass values
of type-level variables; these are added as sizeof expressions first, and
then further rewritten as above. *)
- let e_app_aux param_map (exp, ((l,_) as annot)) =
+ let e_app_aux param_map ((exp, exp_orig), ((l,_) as annot)) =
let full_exp = E_aux (exp, annot) in
+ let orig_exp = E_aux (exp_orig, annot) in
match exp with
| E_app (f, args) ->
if Bindings.mem f param_map then
(* Retrieve instantiation of the type variables of the called function
- for the given parameters in the current environment *)
- let inst = instantiation_of full_exp in
+ for the given parameters in the original environment *)
+ let inst = instantiation_of orig_exp in
let kid_exp kid = begin
match KBindings.find kid inst with
| U_nexp nexp -> E_aux (E_sizeof nexp, simple_annot l (atom_typ nexp))
@@ -1075,9 +1083,75 @@ let rewrite_sizeof (Defs defs) =
" of function " ^ string_of_id f))
end in
let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in
- E_aux (E_app (f, kid_exps @ args), annot)
- else full_exp
- | _ -> full_exp in
+ (E_aux (E_app (f, kid_exps @ args), annot), orig_exp)
+ else (full_exp, orig_exp)
+ | _ -> (full_exp, orig_exp) in
+
+ (* Plug this into a folding algorithm that also keeps around a copy of the
+ original expressions, which we use to infer instantiations of type variables
+ in the original environments *)
+ let copy_exp_alg =
+ { e_block = (fun es -> let (es, es') = List.split es in (E_block es, E_block es'))
+ ; e_nondet = (fun es -> let (es, es') = List.split es in (E_nondet es, E_nondet es'))
+ ; e_id = (fun id -> (E_id id, E_id id))
+ ; e_lit = (fun lit -> (E_lit lit, E_lit lit))
+ ; e_cast = (fun (typ,(e,e')) -> (E_cast (typ,e), E_cast (typ,e')))
+ ; e_app = (fun (id,es) -> let (es, es') = List.split es in (E_app (id,es), E_app (id,es')))
+ ; e_app_infix = (fun ((e1,e1'),id,(e2,e2')) -> (E_app_infix (e1,id,e2), E_app_infix (e1',id,e2')))
+ ; e_tuple = (fun es -> let (es, es') = List.split es in (E_tuple es, E_tuple es'))
+ ; e_if = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_if (e1,e2,e3), E_if (e1',e2',e3')))
+ ; e_for = (fun (id,(e1,e1'),(e2,e2'),(e3,e3'),order,(e4,e4')) -> (E_for (id,e1,e2,e3,order,e4), E_for (id,e1',e2',e3',order,e4')))
+ ; e_vector = (fun es -> let (es, es') = List.split es in (E_vector es, E_vector es'))
+ ; e_vector_indexed = (fun (es,(opt2,opt2')) -> let (is, es) = List.split es in let (es, es') = List.split es in let (es, es') = (List.combine is es, List.combine is es') in (E_vector_indexed (es,opt2), E_vector_indexed (es',opt2')))
+ ; e_vector_access = (fun ((e1,e1'),(e2,e2')) -> (E_vector_access (e1,e2), E_vector_access (e1',e2')))
+ ; e_vector_subrange = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_vector_subrange (e1,e2,e3), E_vector_subrange (e1',e2',e3')))
+ ; e_vector_update = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_vector_update (e1,e2,e3), E_vector_update (e1',e2',e3')))
+ ; e_vector_update_subrange = (fun ((e1,e1'),(e2,e2'),(e3,e3'),(e4,e4')) -> (E_vector_update_subrange (e1,e2,e3,e4), E_vector_update_subrange (e1',e2',e3',e4')))
+ ; e_vector_append = (fun ((e1,e1'),(e2,e2')) -> (E_vector_append (e1,e2), E_vector_append (e1',e2')))
+ ; e_list = (fun es -> let (es, es') = List.split es in (E_list es, E_list es'))
+ ; e_cons = (fun ((e1,e1'),(e2,e2')) -> (E_cons (e1,e2), E_cons (e1',e2')))
+ ; e_record = (fun (fexps, fexps') -> (E_record fexps, E_record fexps'))
+ ; e_record_update = (fun ((e1,e1'),(fexp,fexp')) -> (E_record_update (e1,fexp), E_record_update (e1',fexp')))
+ ; e_field = (fun ((e1,e1'),id) -> (E_field (e1,id), E_field (e1',id)))
+ ; e_case = (fun ((e1,e1'),pexps) -> let (pexps, pexps') = List.split pexps in (E_case (e1,pexps), E_case (e1',pexps')))
+ ; e_let = (fun ((lb,lb'),(e2,e2')) -> (E_let (lb,e2), E_let (lb',e2')))
+ ; e_assign = (fun ((lexp,lexp'),(e2,e2')) -> (E_assign (lexp,e2), E_assign (lexp',e2')))
+ ; e_sizeof = (fun nexp -> (E_sizeof nexp, E_sizeof nexp))
+ ; e_exit = (fun (e1,e1') -> (E_exit (e1), E_exit (e1')))
+ ; e_return = (fun (e1,e1') -> (E_return e1, E_return e1'))
+ ; e_assert = (fun ((e1,e1'),(e2,e2')) -> (E_assert(e1,e2), E_assert(e1',e2')) )
+ ; e_internal_cast = (fun (a,(e1,e1')) -> (E_internal_cast (a,e1), E_internal_cast (a,e1')))
+ ; e_internal_exp = (fun a -> (E_internal_exp a, E_internal_exp a))
+ ; e_internal_exp_user = (fun (a1,a2) -> (E_internal_exp_user (a1,a2), E_internal_exp_user (a1,a2)))
+ ; e_comment = (fun c -> (E_comment c, E_comment c))
+ ; e_comment_struc = (fun (e,e') -> (E_comment_struc e, E_comment_struc e'))
+ ; e_internal_let = (fun ((lexp,lexp'), (e2,e2'), (e3,e3')) -> (E_internal_let (lexp,e2,e3), E_internal_let (lexp',e2',e3')))
+ ; e_internal_plet = (fun (pat, (e1,e1'), (e2,e2')) -> (E_internal_plet (pat,e1,e2), E_internal_plet (pat,e1',e2')))
+ ; e_internal_return = (fun (e,e') -> (E_internal_return e, E_internal_return e'))
+ ; e_aux = (fun ((e,e'),annot) -> (E_aux (e,annot), E_aux (e',annot)))
+ ; lEXP_id = (fun id -> (LEXP_id id, LEXP_id id))
+ ; lEXP_memory = (fun (id,es) -> let (es, es') = List.split es in (LEXP_memory (id,es), LEXP_memory (id,es')))
+ ; lEXP_cast = (fun (typ,id) -> (LEXP_cast (typ,id), LEXP_cast (typ,id)))
+ ; lEXP_tup = (fun tups -> let (tups,tups') = List.split tups in (LEXP_tup tups, LEXP_tup tups'))
+ ; lEXP_vector = (fun ((lexp,lexp'),(e2,e2')) -> (LEXP_vector (lexp,e2), LEXP_vector (lexp',e2')))
+ ; lEXP_vector_range = (fun ((lexp,lexp'),(e2,e2'),(e3,e3')) -> (LEXP_vector_range (lexp,e2,e3), LEXP_vector_range (lexp',e2',e3')))
+ ; lEXP_field = (fun ((lexp,lexp'),id) -> (LEXP_field (lexp,id), LEXP_field (lexp',id)))
+ ; lEXP_aux = (fun ((lexp,lexp'),annot) -> (LEXP_aux (lexp,annot), LEXP_aux (lexp',annot)))
+ ; fE_Fexp = (fun (id,(e,e')) -> (FE_Fexp (id,e), FE_Fexp (id,e')))
+ ; fE_aux = (fun ((fexp,fexp'),annot) -> (FE_aux (fexp,annot), FE_aux (fexp',annot)))
+ ; fES_Fexps = (fun (fexps,b) -> let (fexps, fexps') = List.split fexps in (FES_Fexps (fexps,b), FES_Fexps (fexps',b)))
+ ; fES_aux = (fun ((fexp,fexp'),annot) -> (FES_aux (fexp,annot), FES_aux (fexp',annot)))
+ ; def_val_empty = (Def_val_empty, Def_val_empty)
+ ; def_val_dec = (fun (e,e') -> (Def_val_dec e, Def_val_dec e'))
+ ; def_val_aux = (fun ((defval,defval'),aux) -> (Def_val_aux (defval,aux), Def_val_aux (defval',aux)))
+ ; pat_exp = (fun (pat,(e,e')) -> (Pat_exp (pat,e), Pat_exp (pat,e')))
+ ; pat_when = (fun (pat,(e1,e1'),(e2,e2')) -> (Pat_when (pat,e1,e2), Pat_when (pat,e1',e2')))
+ ; pat_aux = (fun ((pexp,pexp'),a) -> (Pat_aux (pexp,a), Pat_aux (pexp',a)))
+ ; lB_val_explicit = (fun (typ,pat,(e,e')) -> (LB_val_explicit (typ,pat,e), LB_val_explicit (typ,pat,e')))
+ ; lB_val_implicit = (fun (pat,(e,e')) -> (LB_val_implicit (pat,e), LB_val_implicit (pat,e')))
+ ; lB_aux = (fun ((lb,lb'),annot) -> (LB_aux (lb,annot), LB_aux (lb',annot)))
+ ; pat_alg = id_pat_alg
+ } in
let rewrite_sizeof_fun params_map
(FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) =
@@ -1086,7 +1160,7 @@ let rewrite_sizeof (Defs defs) =
let body_typ = typ_of exp in
let nmap = nexps_from_params pat in
(* first rewrite calls to other functions... *)
- let exp' = fold_exp { id_exp_alg with e_aux = e_app_aux params_map } exp in
+ let exp' = fst (fold_exp { copy_exp_alg with e_aux = e_app_aux params_map } exp) in
(* ... then rewrite sizeof expressions in current function body *)
let exp'' = fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp' in
(FCL_aux (FCL_Funcl (id,pat,exp''), annot) :: funcls,
@@ -1133,9 +1207,10 @@ let rewrite_sizeof (Defs defs) =
let rewrite_sizeof_fundef (params_map, defs) = function
| DEF_fundef fd ->
let (nvars, fd') = rewrite_sizeof_fun params_map fd in
+ let id = id_of_fundef fd in
let params_map' =
if KidSet.is_empty nvars then params_map
- else Bindings.add (id_of_fundef fd) nvars params_map in
+ else Bindings.add id nvars params_map in
(params_map', defs @ [DEF_fundef fd'])
| def ->
(params_map, defs @ [def]) in
@@ -1947,7 +2022,7 @@ let rewrite_defs_remove_bitvector_pats (Defs defs) =
let defvals = List.map (fun lb -> DEF_val lb) letbinds in
[DEF_val (LB_aux (LB_val_implicit (pat',exp),a))] @ defvals
| d -> [d] in
- Defs (List.flatten (List.map rewrite_def defs))
+ fst (check initial_env (Defs (List.flatten (List.map rewrite_def defs))))
(* Remove pattern guards by rewriting them to if-expressions within the
@@ -1999,7 +2074,7 @@ let rewrite_exp_lift_assign_intro rewriters ((E_aux (exp,((l,_) as annot))) as f
| (E_aux(E_assign((LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),_)) as le,e),
((l, Some (env,typ,eff)) as annot)) as exp)::exps ->
(match Env.lookup_id id env with
- | Unbound ->
+ | Unbound | Local _ ->
let le' = rewriters.rewrite_lexp rewriters le in
let e' = rewrite_base e in
let exps' = walker exps in
@@ -2136,12 +2211,72 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base
rewrite_def = rewrite_def;
rewrite_defs = rewrite_defs_base} defs*)
-let rewrite_defs_ocaml =
- top_sort_defs >>
- rewrite_defs_remove_vector_concat >>
- rewrite_sizeof >>
- rewrite_defs_exp_lift_assign (* >>
+let rewrite_defs_early_return =
+ let is_return (E_aux (exp, _)) = match exp with
+ | E_return _ -> true
+ | _ -> false in
+
+ let get_return (E_aux (e, (l, _)) as exp) = match e with
+ | E_return e -> e
+ | _ -> exp in
+
+ let e_block es =
+ (* let rec walker = function
+ | e :: es -> if is_return e then [e] else e :: walker es
+ | [] -> [] in
+ let es = walker es in *)
+ match es with
+ | [E_aux (e, _)] -> e
+ | _ -> E_block es in
+
+ let e_if (e1, e2, e3) =
+ if is_return e2 && is_return e3 then E_if (e1, get_return e2, get_return e3)
+ else E_if (e1, e2, e3) in
+
+ let e_case (e, pes) =
+ let is_return_pexp (Pat_aux (pexp, _)) = match pexp with
+ | Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in
+ let get_return_pexp (Pat_aux (pexp, a)) = match pexp with
+ | Pat_exp (p, e) -> Pat_aux (Pat_exp (p, get_return e), a)
+ | Pat_when (p, g, e) -> Pat_aux (Pat_when (p, g, get_return e), a) in
+ if List.for_all is_return_pexp pes
+ then E_return (E_aux (E_case (e, List.map get_return_pexp pes), (Parse_ast.Unknown, None)))
+ else E_case (e, pes) in
+
+ let e_aux (exp, (l, annot)) =
+ let full_exp = fix_eff_exp (E_aux (exp, (l, annot))) in
+ match annot with
+ | Some (env, typ, eff) when is_return full_exp ->
+ (* Add escape effect annotation, since we use the exception mechanism
+ of the state monad to implement early return in the Lem backend *)
+ let annot' = Some (env, typ, union_effects eff (mk_effect [BE_escape])) in
+ E_aux (exp, (l, annot'))
+ | _ -> full_exp in
+
+ let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pat, exp), a)) =
+ let exp = fold_exp
+ { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case;
+ e_aux = e_aux } exp in
+ let a = match a with
+ | (l, Some (env, typ, eff)) ->
+ (l, Some (env, typ, union_effects eff (effect_of exp)))
+ | _ -> a in
+ FCL_aux (FCL_Funcl (id, pat, get_return exp), a) in
+
+ let rewrite_fun_early_return rewriters
+ (FD_aux (FD_function (rec_opt, tannot_opt, effect_opt, funcls), a)) =
+ FD_aux (FD_function (rec_opt, tannot_opt, effect_opt,
+ List.map (rewrite_funcl_early_return rewriters) funcls), a) in
+
+ rewrite_defs_base { rewriters_base with rewrite_fun = rewrite_fun_early_return }
+
+let rewrite_defs_ocaml = [
+ top_sort_defs;
+ rewrite_defs_remove_vector_concat;
+ rewrite_sizeof;
+ rewrite_defs_exp_lift_assign (* ;
rewrite_defs_separate_numbs *)
+ ]
let rewrite_defs_remove_blocks =
let letbind_wild v body =
@@ -2398,7 +2533,7 @@ let rewrite_defs_letbind_effects =
| E_case (exp1,pexps) ->
let newreturn =
List.fold_left
- (fun b (Pat_aux (_,(_,annot))) -> b || effectful_effs (effect_of_annot annot))
+ (fun b pexp -> b || effectful_effs (effect_of_pexp pexp))
false pexps in
n_exp_name exp1 (fun exp1 ->
n_pexpL newreturn pexps (fun pexps ->
@@ -2766,7 +2901,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
| _ ->
raise (Reporting_basic.err_unreachable l
"assignment without effects annotation") in
- if not (List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs) then
+ if effectful exp then
Same_vars (E_aux (E_assign (lexp,vexp),annot))
else
(match lexp with
@@ -2968,18 +3103,22 @@ let rewrite_defs_remove_e_assign =
; rewrite_defs = rewrite_defs_base
}
-
-let rewrite_defs_lem =
- top_sort_defs >>
- rewrite_sizeof >>
- rewrite_defs_remove_vector_concat >>
- rewrite_defs_remove_bitvector_pats >>
- rewrite_defs_guarded_pats >>
- rewrite_defs_exp_lift_assign >>
- rewrite_defs_remove_blocks >>
- rewrite_defs_letbind_effects >>
- rewrite_defs_remove_e_assign >>
- rewrite_defs_effectful_let_expressions >>
- rewrite_defs_remove_superfluous_letbinds >>
+let recheck_defs defs = fst (check initial_env defs)
+
+let rewrite_defs_lem =[
+ top_sort_defs;
+ rewrite_sizeof;
+ rewrite_defs_remove_vector_concat;
+ rewrite_defs_remove_bitvector_pats;
+ rewrite_defs_guarded_pats;
+ (* recheck_defs; *)
+ rewrite_defs_early_return;
+ rewrite_defs_exp_lift_assign;
+ rewrite_defs_remove_blocks;
+ rewrite_defs_letbind_effects;
+ rewrite_defs_remove_e_assign;
+ rewrite_defs_effectful_let_expressions;
+ rewrite_defs_remove_superfluous_letbinds;
rewrite_defs_remove_superfluous_returns
+ ]