diff options
Diffstat (limited to 'src/rewriter.ml')
| -rw-r--r-- | src/rewriter.ml | 99 |
1 files changed, 88 insertions, 11 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml index 33fe3c25..1f8452ba 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -66,7 +66,13 @@ let effect_of_fexp (FE_aux (_,(_,a))) = effect_of_annot a let effect_of_fexps (FES_aux (FES_Fexps (fexps,_),_)) = List.fold_left union_effects no_effect (List.map effect_of_fexp fexps) let effect_of_opt_default (Def_val_aux (_,(_,a))) = effect_of_annot a -let effect_of_pexp (Pat_aux (_,(_,a))) = effect_of_annot a +(* The typechecker does not seem to annotate pexps themselves *) +let effect_of_pexp (Pat_aux (pexp,(_,a))) = match a with + | Some (_, _, eff) -> eff + | None -> + (match pexp with + | Pat_exp (_, e) -> effect_of e + | Pat_when (_, g, e) -> union_effects (effect_of g) (effect_of e)) let effect_of_lb (LB_aux (_,(_,a))) = effect_of_annot a let get_loc_exp (E_aux (_,(l,_))) = l @@ -1048,6 +1054,7 @@ let rewrite_sizeof (Defs defs) = Id_aux (Id op, Parse_ast.Unknown), E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2)) ) in + let (Nexp_aux (nexp, l) as nexp_aux) = simplify_nexp nexp_aux in (match nexp with | Nexp_constant i -> E_lit (L_aux (L_num i, l)) | Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2 @@ -1058,14 +1065,15 @@ let rewrite_sizeof (Defs defs) = (* Rewrite calls to functions which have had parameters added to pass values of type-level variables; these are added as sizeof expressions first, and then further rewritten as above. *) - let e_app_aux param_map (exp, ((l,_) as annot)) = + let e_app_aux param_map ((exp, exp_orig), ((l,_) as annot)) = let full_exp = E_aux (exp, annot) in + let orig_exp = E_aux (exp_orig, annot) in match exp with | E_app (f, args) -> if Bindings.mem f param_map then (* Retrieve instantiation of the type variables of the called function - for the given parameters in the current environment *) - let inst = instantiation_of full_exp in + for the given parameters in the original environment *) + let inst = instantiation_of orig_exp in let kid_exp kid = begin match KBindings.find kid inst with | U_nexp nexp -> E_aux (E_sizeof nexp, simple_annot l (atom_typ nexp)) @@ -1075,9 +1083,75 @@ let rewrite_sizeof (Defs defs) = " of function " ^ string_of_id f)) end in let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in - E_aux (E_app (f, kid_exps @ args), annot) - else full_exp - | _ -> full_exp in + (E_aux (E_app (f, kid_exps @ args), annot), orig_exp) + else (full_exp, orig_exp) + | _ -> (full_exp, orig_exp) in + + (* Plug this into a folding algorithm that also keeps around a copy of the + original expressions, which we use to infer instantiations of type variables + in the original environments *) + let copy_exp_alg = + { e_block = (fun es -> let (es, es') = List.split es in (E_block es, E_block es')) + ; e_nondet = (fun es -> let (es, es') = List.split es in (E_nondet es, E_nondet es')) + ; e_id = (fun id -> (E_id id, E_id id)) + ; e_lit = (fun lit -> (E_lit lit, E_lit lit)) + ; e_cast = (fun (typ,(e,e')) -> (E_cast (typ,e), E_cast (typ,e'))) + ; e_app = (fun (id,es) -> let (es, es') = List.split es in (E_app (id,es), E_app (id,es'))) + ; e_app_infix = (fun ((e1,e1'),id,(e2,e2')) -> (E_app_infix (e1,id,e2), E_app_infix (e1',id,e2'))) + ; e_tuple = (fun es -> let (es, es') = List.split es in (E_tuple es, E_tuple es')) + ; e_if = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_if (e1,e2,e3), E_if (e1',e2',e3'))) + ; e_for = (fun (id,(e1,e1'),(e2,e2'),(e3,e3'),order,(e4,e4')) -> (E_for (id,e1,e2,e3,order,e4), E_for (id,e1',e2',e3',order,e4'))) + ; e_vector = (fun es -> let (es, es') = List.split es in (E_vector es, E_vector es')) + ; e_vector_indexed = (fun (es,(opt2,opt2')) -> let (is, es) = List.split es in let (es, es') = List.split es in let (es, es') = (List.combine is es, List.combine is es') in (E_vector_indexed (es,opt2), E_vector_indexed (es',opt2'))) + ; e_vector_access = (fun ((e1,e1'),(e2,e2')) -> (E_vector_access (e1,e2), E_vector_access (e1',e2'))) + ; e_vector_subrange = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_vector_subrange (e1,e2,e3), E_vector_subrange (e1',e2',e3'))) + ; e_vector_update = (fun ((e1,e1'),(e2,e2'),(e3,e3')) -> (E_vector_update (e1,e2,e3), E_vector_update (e1',e2',e3'))) + ; e_vector_update_subrange = (fun ((e1,e1'),(e2,e2'),(e3,e3'),(e4,e4')) -> (E_vector_update_subrange (e1,e2,e3,e4), E_vector_update_subrange (e1',e2',e3',e4'))) + ; e_vector_append = (fun ((e1,e1'),(e2,e2')) -> (E_vector_append (e1,e2), E_vector_append (e1',e2'))) + ; e_list = (fun es -> let (es, es') = List.split es in (E_list es, E_list es')) + ; e_cons = (fun ((e1,e1'),(e2,e2')) -> (E_cons (e1,e2), E_cons (e1',e2'))) + ; e_record = (fun (fexps, fexps') -> (E_record fexps, E_record fexps')) + ; e_record_update = (fun ((e1,e1'),(fexp,fexp')) -> (E_record_update (e1,fexp), E_record_update (e1',fexp'))) + ; e_field = (fun ((e1,e1'),id) -> (E_field (e1,id), E_field (e1',id))) + ; e_case = (fun ((e1,e1'),pexps) -> let (pexps, pexps') = List.split pexps in (E_case (e1,pexps), E_case (e1',pexps'))) + ; e_let = (fun ((lb,lb'),(e2,e2')) -> (E_let (lb,e2), E_let (lb',e2'))) + ; e_assign = (fun ((lexp,lexp'),(e2,e2')) -> (E_assign (lexp,e2), E_assign (lexp',e2'))) + ; e_sizeof = (fun nexp -> (E_sizeof nexp, E_sizeof nexp)) + ; e_exit = (fun (e1,e1') -> (E_exit (e1), E_exit (e1'))) + ; e_return = (fun (e1,e1') -> (E_return e1, E_return e1')) + ; e_assert = (fun ((e1,e1'),(e2,e2')) -> (E_assert(e1,e2), E_assert(e1',e2')) ) + ; e_internal_cast = (fun (a,(e1,e1')) -> (E_internal_cast (a,e1), E_internal_cast (a,e1'))) + ; e_internal_exp = (fun a -> (E_internal_exp a, E_internal_exp a)) + ; e_internal_exp_user = (fun (a1,a2) -> (E_internal_exp_user (a1,a2), E_internal_exp_user (a1,a2))) + ; e_comment = (fun c -> (E_comment c, E_comment c)) + ; e_comment_struc = (fun (e,e') -> (E_comment_struc e, E_comment_struc e')) + ; e_internal_let = (fun ((lexp,lexp'), (e2,e2'), (e3,e3')) -> (E_internal_let (lexp,e2,e3), E_internal_let (lexp',e2',e3'))) + ; e_internal_plet = (fun (pat, (e1,e1'), (e2,e2')) -> (E_internal_plet (pat,e1,e2), E_internal_plet (pat,e1',e2'))) + ; e_internal_return = (fun (e,e') -> (E_internal_return e, E_internal_return e')) + ; e_aux = (fun ((e,e'),annot) -> (E_aux (e,annot), E_aux (e',annot))) + ; lEXP_id = (fun id -> (LEXP_id id, LEXP_id id)) + ; lEXP_memory = (fun (id,es) -> let (es, es') = List.split es in (LEXP_memory (id,es), LEXP_memory (id,es'))) + ; lEXP_cast = (fun (typ,id) -> (LEXP_cast (typ,id), LEXP_cast (typ,id))) + ; lEXP_tup = (fun tups -> let (tups,tups') = List.split tups in (LEXP_tup tups, LEXP_tup tups')) + ; lEXP_vector = (fun ((lexp,lexp'),(e2,e2')) -> (LEXP_vector (lexp,e2), LEXP_vector (lexp',e2'))) + ; lEXP_vector_range = (fun ((lexp,lexp'),(e2,e2'),(e3,e3')) -> (LEXP_vector_range (lexp,e2,e3), LEXP_vector_range (lexp',e2',e3'))) + ; lEXP_field = (fun ((lexp,lexp'),id) -> (LEXP_field (lexp,id), LEXP_field (lexp',id))) + ; lEXP_aux = (fun ((lexp,lexp'),annot) -> (LEXP_aux (lexp,annot), LEXP_aux (lexp',annot))) + ; fE_Fexp = (fun (id,(e,e')) -> (FE_Fexp (id,e), FE_Fexp (id,e'))) + ; fE_aux = (fun ((fexp,fexp'),annot) -> (FE_aux (fexp,annot), FE_aux (fexp',annot))) + ; fES_Fexps = (fun (fexps,b) -> let (fexps, fexps') = List.split fexps in (FES_Fexps (fexps,b), FES_Fexps (fexps',b))) + ; fES_aux = (fun ((fexp,fexp'),annot) -> (FES_aux (fexp,annot), FES_aux (fexp',annot))) + ; def_val_empty = (Def_val_empty, Def_val_empty) + ; def_val_dec = (fun (e,e') -> (Def_val_dec e, Def_val_dec e')) + ; def_val_aux = (fun ((defval,defval'),aux) -> (Def_val_aux (defval,aux), Def_val_aux (defval',aux))) + ; pat_exp = (fun (pat,(e,e')) -> (Pat_exp (pat,e), Pat_exp (pat,e'))) + ; pat_when = (fun (pat,(e1,e1'),(e2,e2')) -> (Pat_when (pat,e1,e2), Pat_when (pat,e1',e2'))) + ; pat_aux = (fun ((pexp,pexp'),a) -> (Pat_aux (pexp,a), Pat_aux (pexp',a))) + ; lB_val_explicit = (fun (typ,pat,(e,e')) -> (LB_val_explicit (typ,pat,e), LB_val_explicit (typ,pat,e'))) + ; lB_val_implicit = (fun (pat,(e,e')) -> (LB_val_implicit (pat,e), LB_val_implicit (pat,e'))) + ; lB_aux = (fun ((lb,lb'),annot) -> (LB_aux (lb,annot), LB_aux (lb',annot))) + ; pat_alg = id_pat_alg + } in let rewrite_sizeof_fun params_map (FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) = @@ -1086,7 +1160,7 @@ let rewrite_sizeof (Defs defs) = let body_typ = typ_of exp in let nmap = nexps_from_params pat in (* first rewrite calls to other functions... *) - let exp' = fold_exp { id_exp_alg with e_aux = e_app_aux params_map } exp in + let exp' = fst (fold_exp { copy_exp_alg with e_aux = e_app_aux params_map } exp) in (* ... then rewrite sizeof expressions in current function body *) let exp'' = fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp' in (FCL_aux (FCL_Funcl (id,pat,exp''), annot) :: funcls, @@ -1133,9 +1207,10 @@ let rewrite_sizeof (Defs defs) = let rewrite_sizeof_fundef (params_map, defs) = function | DEF_fundef fd -> let (nvars, fd') = rewrite_sizeof_fun params_map fd in + let id = id_of_fundef fd in let params_map' = if KidSet.is_empty nvars then params_map - else Bindings.add (id_of_fundef fd) nvars params_map in + else Bindings.add id nvars params_map in (params_map', defs @ [DEF_fundef fd']) | def -> (params_map, defs @ [def]) in @@ -1947,7 +2022,7 @@ let rewrite_defs_remove_bitvector_pats (Defs defs) = let defvals = List.map (fun lb -> DEF_val lb) letbinds in [DEF_val (LB_aux (LB_val_implicit (pat',exp),a))] @ defvals | d -> [d] in - Defs (List.flatten (List.map rewrite_def defs)) + fst (check initial_env (Defs (List.flatten (List.map rewrite_def defs)))) (* Remove pattern guards by rewriting them to if-expressions within the @@ -2399,7 +2474,7 @@ let rewrite_defs_letbind_effects = | E_case (exp1,pexps) -> let newreturn = List.fold_left - (fun b (Pat_aux (_,(_,annot))) -> b || effectful_effs (effect_of_annot annot)) + (fun b pexp -> b || effectful_effs (effect_of_pexp pexp)) false pexps in n_exp_name exp1 (fun exp1 -> n_pexpL newreturn pexps (fun pexps -> @@ -2969,6 +3044,7 @@ let rewrite_defs_remove_e_assign = ; rewrite_defs = rewrite_defs_base } +let recheck_defs defs = fst (check initial_env defs) let rewrite_defs_lem =[ top_sort_defs; @@ -2976,6 +3052,7 @@ let rewrite_defs_lem =[ rewrite_defs_remove_vector_concat; rewrite_defs_remove_bitvector_pats; rewrite_defs_guarded_pats; + (* recheck_defs; *) rewrite_defs_exp_lift_assign; rewrite_defs_remove_blocks; rewrite_defs_letbind_effects; |
