summaryrefslogtreecommitdiff
path: root/src/rewriter.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewriter.ml')
-rw-r--r--src/rewriter.ml99
1 files changed, 88 insertions, 11 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 33fe3c25..1f8452ba 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -66,7 +66,13 @@ let effect_of_fexp (FE_aux (_,(_,a))) = effect_of_annot a
let effect_of_fexps (FES_aux (FES_Fexps (fexps,_),_)) =
List.fold_left union_effects no_effect (List.map effect_of_fexp fexps)
let effect_of_opt_default (Def_val_aux (_,(_,a))) = effect_of_annot a
-let effect_of_pexp (Pat_aux (_,(_,a))) = effect_of_annot a
+(* The typechecker does not seem to annotate pexps themselves *)
+let effect_of_pexp (Pat_aux (pexp,(_,a))) = match a with
+ | Some (_, _, eff) -> eff
+ | None ->
+ (match pexp with
+ | Pat_exp (_, e) -> effect_of e
+ | Pat_when (_, g, e) -> union_effects (effect_of g) (effect_of e))
let effect_of_lb (LB_aux (_,(_,a))) = effect_of_annot a
let get_loc_exp (E_aux (_,(l,_))) = l
@@ -1048,6 +1054,7 @@ let rewrite_sizeof (Defs defs) =
Id_aux (Id op, Parse_ast.Unknown),
E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2))
) in
+ let (Nexp_aux (nexp, l) as nexp_aux) = simplify_nexp nexp_aux in
(match nexp with
| Nexp_constant i -> E_lit (L_aux (L_num i, l))
| Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2
@@ -1058,14 +1065,15 @@ let rewrite_sizeof (Defs defs) =
(* Rewrite calls to functions which have had parameters added to pass values
of type-level variables; these are added as sizeof expressions first, and
then further rewritten as above. *)
- let e_app_aux param_map (exp, ((l,_) as annot)) =
+ let e_app_aux param_map ((exp, exp_orig), ((l,_) as annot)) =
let full_exp = E_aux (exp, annot) in
+ let orig_exp = E_aux (exp_orig, annot) in
match exp with
| E_app (f, args) ->
if Bindings.mem f param_map then
(* Retrieve instantiation of the type variables of the called function
- for the given parameters in the current environment *)
- let inst = instantiation_of full_exp in
+ for the given parameters in the original environment *)
+ let inst = instantiation_of orig_exp in
let kid_exp kid = begin
match KBindings.find kid inst with
| U_nexp nexp -> E_aux (E_sizeof nexp, simple_annot l (atom_typ nexp))
@@ -1075,9 +1083,75 @@ let rewrite_sizeof (Defs defs) =
" of function " ^ string_of_id f))
end in
let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in
- E_aux (E_app (f, kid_exps @ args), annot)
- else full_exp
- | _ -> full_exp in
+ (E_aux (E_app (f, kid_exps @ args), annot), orig_exp)
+ else (full_exp, orig_exp)
+ | _ -> (full_exp, orig_exp) in
+
+ (* Plug this into a folding algorithm that also keeps around a copy of the
+ original expressions, which we use to infer instantiations of type variables
+ in the original environments *)
+ let copy_exp_alg =
+ { e_block = (fun es -> let (es, es') = List.split es in (E_block es, E_block es'))
+ ; e_nondet = (fun es -> let (es, es') = List.split es in (E_nondet es, E_nondet es'))
+ ; e_id = (fun id -> (E_id id, E_id id))
+ ; e_lit = (fun lit -> (E_lit lit, E_lit lit))
+ ; e_cast = (fun (typ,(e,e')) -> (E_cast (typ,e), E_cast (typ,e')))
+ ; e_app = (fun (id,es) -> let (es, es') = List.split es in (E_app (id,es), E_app (id,es')))
+ ; e_app_infix = (fun ((e1,e1'),id,(e2,e2')) -> (E_app_infix (e1,id,e2), E_app_infix (e1',id,e2')))
+ ; e_tuple = (fun es -> let (es, es') = List.split es in (E_tuple es, E_tuple es'))
+ ; e_if = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_if (e1,e2,e3), E_if (e1',e2',e3')))
+ ; e_for = (fun (id,(e1,e1'),(e2,e2'),(e3,e3'),order,(e4,e4')) -> (E_for (id,e1,e2,e3,order,e4), E_for (id,e1',e2',e3',order,e4')))
+ ; e_vector = (fun es -> let (es, es') = List.split es in (E_vector es, E_vector es'))
+ ; e_vector_indexed = (fun (es,(opt2,opt2')) -> let (is, es) = List.split es in let (es, es') = List.split es in let (es, es') = (List.combine is es, List.combine is es') in (E_vector_indexed (es,opt2), E_vector_indexed (es',opt2')))
+ ; e_vector_access = (fun ((e1,e1'),(e2,e2')) -> (E_vector_access (e1,e2), E_vector_access (e1',e2')))
+ ; e_vector_subrange = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_vector_subrange (e1,e2,e3), E_vector_subrange (e1',e2',e3')))
+ ; e_vector_update = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_vector_update (e1,e2,e3), E_vector_update (e1',e2',e3')))
+ ; e_vector_update_subrange = (fun ((e1,e1'),(e2,e2'),(e3,e3'),(e4,e4')) -> (E_vector_update_subrange (e1,e2,e3,e4), E_vector_update_subrange (e1',e2',e3',e4')))
+ ; e_vector_append = (fun ((e1,e1'),(e2,e2')) -> (E_vector_append (e1,e2), E_vector_append (e1',e2')))
+ ; e_list = (fun es -> let (es, es') = List.split es in (E_list es, E_list es'))
+ ; e_cons = (fun ((e1,e1'),(e2,e2')) -> (E_cons (e1,e2), E_cons (e1',e2')))
+ ; e_record = (fun (fexps, fexps') -> (E_record fexps, E_record fexps'))
+ ; e_record_update = (fun ((e1,e1'),(fexp,fexp')) -> (E_record_update (e1,fexp), E_record_update (e1',fexp')))
+ ; e_field = (fun ((e1,e1'),id) -> (E_field (e1,id), E_field (e1',id)))
+ ; e_case = (fun ((e1,e1'),pexps) -> let (pexps, pexps') = List.split pexps in (E_case (e1,pexps), E_case (e1',pexps')))
+ ; e_let = (fun ((lb,lb'),(e2,e2')) -> (E_let (lb,e2), E_let (lb',e2')))
+ ; e_assign = (fun ((lexp,lexp'),(e2,e2')) -> (E_assign (lexp,e2), E_assign (lexp',e2')))
+ ; e_sizeof = (fun nexp -> (E_sizeof nexp, E_sizeof nexp))
+ ; e_exit = (fun (e1,e1') -> (E_exit (e1), E_exit (e1')))
+ ; e_return = (fun (e1,e1') -> (E_return e1, E_return e1'))
+ ; e_assert = (fun ((e1,e1'),(e2,e2')) -> (E_assert(e1,e2), E_assert(e1',e2')) )
+ ; e_internal_cast = (fun (a,(e1,e1')) -> (E_internal_cast (a,e1), E_internal_cast (a,e1')))
+ ; e_internal_exp = (fun a -> (E_internal_exp a, E_internal_exp a))
+ ; e_internal_exp_user = (fun (a1,a2) -> (E_internal_exp_user (a1,a2), E_internal_exp_user (a1,a2)))
+ ; e_comment = (fun c -> (E_comment c, E_comment c))
+ ; e_comment_struc = (fun (e,e') -> (E_comment_struc e, E_comment_struc e'))
+ ; e_internal_let = (fun ((lexp,lexp'), (e2,e2'), (e3,e3')) -> (E_internal_let (lexp,e2,e3), E_internal_let (lexp',e2',e3')))
+ ; e_internal_plet = (fun (pat, (e1,e1'), (e2,e2')) -> (E_internal_plet (pat,e1,e2), E_internal_plet (pat,e1',e2')))
+ ; e_internal_return = (fun (e,e') -> (E_internal_return e, E_internal_return e'))
+ ; e_aux = (fun ((e,e'),annot) -> (E_aux (e,annot), E_aux (e',annot)))
+ ; lEXP_id = (fun id -> (LEXP_id id, LEXP_id id))
+ ; lEXP_memory = (fun (id,es) -> let (es, es') = List.split es in (LEXP_memory (id,es), LEXP_memory (id,es')))
+ ; lEXP_cast = (fun (typ,id) -> (LEXP_cast (typ,id), LEXP_cast (typ,id)))
+ ; lEXP_tup = (fun tups -> let (tups,tups') = List.split tups in (LEXP_tup tups, LEXP_tup tups'))
+ ; lEXP_vector = (fun ((lexp,lexp'),(e2,e2')) -> (LEXP_vector (lexp,e2), LEXP_vector (lexp',e2')))
+ ; lEXP_vector_range = (fun ((lexp,lexp'),(e2,e2'),(e3,e3')) -> (LEXP_vector_range (lexp,e2,e3), LEXP_vector_range (lexp',e2',e3')))
+ ; lEXP_field = (fun ((lexp,lexp'),id) -> (LEXP_field (lexp,id), LEXP_field (lexp',id)))
+ ; lEXP_aux = (fun ((lexp,lexp'),annot) -> (LEXP_aux (lexp,annot), LEXP_aux (lexp',annot)))
+ ; fE_Fexp = (fun (id,(e,e')) -> (FE_Fexp (id,e), FE_Fexp (id,e')))
+ ; fE_aux = (fun ((fexp,fexp'),annot) -> (FE_aux (fexp,annot), FE_aux (fexp',annot)))
+ ; fES_Fexps = (fun (fexps,b) -> let (fexps, fexps') = List.split fexps in (FES_Fexps (fexps,b), FES_Fexps (fexps',b)))
+ ; fES_aux = (fun ((fexp,fexp'),annot) -> (FES_aux (fexp,annot), FES_aux (fexp',annot)))
+ ; def_val_empty = (Def_val_empty, Def_val_empty)
+ ; def_val_dec = (fun (e,e') -> (Def_val_dec e, Def_val_dec e'))
+ ; def_val_aux = (fun ((defval,defval'),aux) -> (Def_val_aux (defval,aux), Def_val_aux (defval',aux)))
+ ; pat_exp = (fun (pat,(e,e')) -> (Pat_exp (pat,e), Pat_exp (pat,e')))
+ ; pat_when = (fun (pat,(e1,e1'),(e2,e2')) -> (Pat_when (pat,e1,e2), Pat_when (pat,e1',e2')))
+ ; pat_aux = (fun ((pexp,pexp'),a) -> (Pat_aux (pexp,a), Pat_aux (pexp',a)))
+ ; lB_val_explicit = (fun (typ,pat,(e,e')) -> (LB_val_explicit (typ,pat,e), LB_val_explicit (typ,pat,e')))
+ ; lB_val_implicit = (fun (pat,(e,e')) -> (LB_val_implicit (pat,e), LB_val_implicit (pat,e')))
+ ; lB_aux = (fun ((lb,lb'),annot) -> (LB_aux (lb,annot), LB_aux (lb',annot)))
+ ; pat_alg = id_pat_alg
+ } in
let rewrite_sizeof_fun params_map
(FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) =
@@ -1086,7 +1160,7 @@ let rewrite_sizeof (Defs defs) =
let body_typ = typ_of exp in
let nmap = nexps_from_params pat in
(* first rewrite calls to other functions... *)
- let exp' = fold_exp { id_exp_alg with e_aux = e_app_aux params_map } exp in
+ let exp' = fst (fold_exp { copy_exp_alg with e_aux = e_app_aux params_map } exp) in
(* ... then rewrite sizeof expressions in current function body *)
let exp'' = fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp' in
(FCL_aux (FCL_Funcl (id,pat,exp''), annot) :: funcls,
@@ -1133,9 +1207,10 @@ let rewrite_sizeof (Defs defs) =
let rewrite_sizeof_fundef (params_map, defs) = function
| DEF_fundef fd ->
let (nvars, fd') = rewrite_sizeof_fun params_map fd in
+ let id = id_of_fundef fd in
let params_map' =
if KidSet.is_empty nvars then params_map
- else Bindings.add (id_of_fundef fd) nvars params_map in
+ else Bindings.add id nvars params_map in
(params_map', defs @ [DEF_fundef fd'])
| def ->
(params_map, defs @ [def]) in
@@ -1947,7 +2022,7 @@ let rewrite_defs_remove_bitvector_pats (Defs defs) =
let defvals = List.map (fun lb -> DEF_val lb) letbinds in
[DEF_val (LB_aux (LB_val_implicit (pat',exp),a))] @ defvals
| d -> [d] in
- Defs (List.flatten (List.map rewrite_def defs))
+ fst (check initial_env (Defs (List.flatten (List.map rewrite_def defs))))
(* Remove pattern guards by rewriting them to if-expressions within the
@@ -2399,7 +2474,7 @@ let rewrite_defs_letbind_effects =
| E_case (exp1,pexps) ->
let newreturn =
List.fold_left
- (fun b (Pat_aux (_,(_,annot))) -> b || effectful_effs (effect_of_annot annot))
+ (fun b pexp -> b || effectful_effs (effect_of_pexp pexp))
false pexps in
n_exp_name exp1 (fun exp1 ->
n_pexpL newreturn pexps (fun pexps ->
@@ -2969,6 +3044,7 @@ let rewrite_defs_remove_e_assign =
; rewrite_defs = rewrite_defs_base
}
+let recheck_defs defs = fst (check initial_env defs)
let rewrite_defs_lem =[
top_sort_defs;
@@ -2976,6 +3052,7 @@ let rewrite_defs_lem =[
rewrite_defs_remove_vector_concat;
rewrite_defs_remove_bitvector_pats;
rewrite_defs_guarded_pats;
+ (* recheck_defs; *)
rewrite_defs_exp_lift_assign;
rewrite_defs_remove_blocks;
rewrite_defs_letbind_effects;