diff options
| author | Thomas Bauereiss | 2017-07-26 12:03:12 +0100 |
|---|---|---|
| committer | Thomas Bauereiss | 2017-07-26 12:03:12 +0100 |
| commit | 26e59493cde0ffbf1868426fe3bec158f2dbaad0 (patch) | |
| tree | 2193492e4989608eb5d2fef9ed3a60aa84b9c316 /src | |
| parent | eae4d12ad793809482252be0b459bb7e634b5482 (diff) | |
Improve rewriting of sizeof expressions
If some type-level variables in a sizeof expression in a function body cannot
be directly extracted from the parameters of the function, add a new parameter
for each unresolved parameter, and rewrite calls to the function accordingly
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 13 | ||||
| -rw-r--r-- | src/ast_util.mli | 2 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 2 | ||||
| -rw-r--r-- | src/rewriter.ml | 190 | ||||
| -rw-r--r-- | src/type_check.ml | 24 |
5 files changed, 178 insertions, 53 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 04ce5b07..5bb4e0a6 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -330,6 +330,19 @@ let rec string_of_index_range (BF_aux (ir, _)) = | BF_range (n, m) -> string_of_int n ^ " .. " ^ string_of_int m | BF_concat (ir1, ir2) -> "(" ^ string_of_index_range ir1 ^ ") : (" ^ string_of_index_range ir2 ^ ")" +let id_of_fundef (FD_aux (FD_function (_, _, _, funcls), (l, _))) = + match (List.fold_right + (fun (FCL_aux (FCL_Funcl (id, _, _), _)) id' -> + match id' with + | Some id' -> if string_of_id id' = string_of_id id then Some id' + else raise (Reporting_basic.err_typ l + ("Function declaration expects all definitions to have the same name, " + ^ string_of_id id ^ " differs from other definitions of " ^ string_of_id id')) + | None -> Some id) funcls None) + with + | Some id -> id + | None -> raise (Reporting_basic.err_typ l "funcl list is empty") + module Kid = struct type t = kid let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2) diff --git a/src/ast_util.mli b/src/ast_util.mli index d7f68412..6e22d173 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -84,6 +84,8 @@ val string_of_pat : 'a pat -> string val string_of_letbind : 'a letbind -> string val string_of_index_range : index_range -> string +val id_of_fundef : 'a fundef -> id + module Id : sig type t = id val compare : id -> id -> int diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 9f479cdc..95ddc580 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -274,7 +274,7 @@ let rec doc_pat_lem regtypes apat_needed (P_aux (p,(l,annot)) as pa) = match p w | Id_aux (Id "None",_) -> string "Nothing" (* workaround temporary issue *) | _ -> doc_id_lem id end | P_as(p,id) -> parens (separate space [doc_pat_lem regtypes true p; string "as"; doc_id_lem id]) - | P_typ(typ,p) -> doc_op colon (doc_pat_lem regtypes true p) (doc_typ_lem regtypes typ) + | P_typ(typ,p) -> parens (doc_op colon (doc_pat_lem regtypes true p) (doc_typ_lem regtypes typ)) | P_vector pats -> let ppp = (separate space) diff --git a/src/rewriter.ml b/src/rewriter.ml index 560159d2..166c31f0 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -337,8 +337,8 @@ let rewrite_pat rewriters (P_aux (pat,(l,annot))) = (vector_string_to_bit_list l lit) in rewrap (P_vector ps) | P_lit _ | P_wild | P_id _ -> rewrap pat - | P_as(pat,id) -> rewrap (P_as( rewrite pat, id)) - | P_typ(typ,pat) -> rewrite pat + | P_as(pat,id) -> rewrap (P_as(rewrite pat, id)) + | P_typ(typ,pat) -> rewrap (P_typ(typ, rewrite pat)) | P_app(id ,pats) -> rewrap (P_app(id, List.map rewrite pats)) | P_record(fpats,_) -> rewrap (P_record(List.map (fun (FP_aux(FP_Fpat(id,pat),pannot)) -> FP_aux(FP_Fpat(id, rewrite pat), pannot)) fpats, @@ -958,13 +958,22 @@ let compute_exp_alg bot join = ; pat_alg = compute_pat_alg bot join } -let rewrite_sizeof defs = + +(* Rewrite sizeof expressions with type-level variables to + term-level expressions + + For each type-level variable used in a sizeof expressions whose value cannot + be directly extracted from existing parameters of the surrounding function, + a further parameter is added; calls to the function are rewritten + accordingly (possibly causing further rewriting in the calling function) *) +let rewrite_sizeof (Defs defs) = let sizeof_frees exp = fst (fold_exp { (compute_exp_alg KidSet.empty KidSet.union) with e_sizeof = (fun nexp -> (nexp_frees nexp, E_sizeof nexp)) } exp) in + (* Collect nexps whose values can be obtained directly from a pattern bind *) let nexps_from_params pat = fst (fold_pat { (compute_pat_alg [] (@)) with @@ -986,13 +995,14 @@ let rewrite_sizeof defs = | _ -> [] in (v @ v', P_aux (pat,annot)))} pat) in + (* Substitute collected values in sizeof expressions *) let rec e_sizeof nmap (Nexp_aux (nexp, l) as nexp_aux) = try snd (List.find (fun (nexp,_) -> nexp_identical nexp nexp_aux) nmap) with | Not_found -> let binop nexp1 op nexp2 = E_app_infix ( E_aux (e_sizeof nmap nexp1, simple_annot l (atom_typ nexp1)), - Id_aux (Id op, Unknown), + Id_aux (Id op, Parse_ast.Unknown), E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2)) ) in (match nexp with @@ -1002,31 +1012,125 @@ let rewrite_sizeof defs = | Nexp_minus (nexp1, nexp2) -> binop nexp1 "-" nexp2 | _ -> E_sizeof nexp_aux) in - let rewrite_sizeof_exp nmap rewriters exp = - let exp = rewriters_base.rewrite_exp rewriters exp in - fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp in - - let rewrite_sizeof_fun rewriters - (FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) = - let rewrite_funcl_body (FCL_aux (FCL_Funcl (id,pat,exp), annot)) = + (* Rewrite calls to functions which have had parameters added to pass values + of type-level variables; these are added as sizeof expressions first, and + then further rewritten as above. *) + let e_app_aux param_map (exp, ((l,_) as annot)) = + let full_exp = E_aux (exp, annot) in + match exp with + | E_app (f, args) -> + if Bindings.mem f param_map then + (* Retrieve instantiation of the type variables of the called function + for the given parameters in the current environment *) + let inst = instantiation_of full_exp in + let kid_exp kid = begin + match KBindings.find kid inst with + | U_nexp nexp -> E_aux (E_sizeof nexp, simple_annot l (atom_typ nexp)) + | _ -> + raise (Reporting_basic.err_unreachable l + ("failed to infer nexp for type variable " ^ string_of_kid kid ^ + " of function " ^ string_of_id f)) + end in + let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in + E_aux (E_app (f, kid_exps @ args), annot) + else full_exp + | _ -> full_exp in + + let rewrite_sizeof_fun params_map + (FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) = + let rewrite_funcl_body (FCL_aux (FCL_Funcl (id,pat,exp), annot)) (funcls,nvars) = let body_env = env_of exp in let body_typ = typ_of exp in let nmap = nexps_from_params pat in - let exp = - try check_exp body_env - (strip_exp (fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp)) - body_typ - with - | Type_error _ -> exp in - FCL_aux (FCL_Funcl (id,pat,exp), annot) in - let funcls = List.map rewrite_funcl_body funcls in - FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot) in + (* first rewrite calls to other functions... *) + let exp' = fold_exp { id_exp_alg with e_aux = e_app_aux params_map } exp in + (* ... then rewrite sizeof expressions in current function body *) + let exp'' = fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp' in + (FCL_aux (FCL_Funcl (id,pat,exp''), annot) :: funcls, + KidSet.union nvars (sizeof_frees exp'')) in + let (funcls, nvars) = List.fold_right rewrite_funcl_body funcls ([], KidSet.empty) in + (* Add a parameter for each remaining free type-level variable in a + sizeof expression *) + let kid_typ kid = atom_typ (nvar kid) in + let kid_annot kid = simple_annot l (kid_typ kid) in + let kid_pat kid = + P_aux (P_typ (kid_typ kid, + P_aux (P_id (Id_aux (Id (string_of_kid kid), l)), + kid_annot kid)), kid_annot kid) in + let kid_eaux kid = E_id (Id_aux (Id (string_of_kid kid), l)) in + let kid_typs = List.map kid_typ (KidSet.elements nvars) in + let kid_pats = List.map kid_pat (KidSet.elements nvars) in + let kid_nmap = List.map (fun kid -> (nvar kid, kid_eaux kid)) (KidSet.elements nvars) in + let rewrite_funcl_params (FCL_aux (FCL_Funcl (id, pat, exp), annot) as funcl) = + let rec rewrite_pat (P_aux (pat,(l,_)) as paux) = + if KidSet.is_empty nvars then paux else + match pat_typ_of paux with + | Typ_aux (Typ_tup _, _) -> + (match pat with + | P_tup pats -> + P_aux (P_tup (kid_pats @ pats), (l, None)) + | P_wild -> paux + | P_typ (Typ_aux (Typ_tup typs, l), pat) -> + P_aux (P_typ (Typ_aux (Typ_tup (kid_typs @ typs), l), + rewrite_pat pat), (l, None)) + | P_as (_, id) | P_id id -> + (* adding parameters here would change the type of id; + we should remove the P_as/P_id here and add a let-binding to the body *) + raise (Reporting_basic.err_todo l + "rewriting as- or id-patterns for sizeof expressions not yet implemented") + | _ -> + raise (Reporting_basic.err_unreachable l + "unexpected pattern while rewriting function parameters for sizeof expressions")) + | _ -> P_aux (P_tup (kid_pats @ [paux]), (l, None)) in + let exp' = fold_exp { id_exp_alg with e_sizeof = e_sizeof kid_nmap } exp in + FCL_aux (FCL_Funcl (id, rewrite_pat pat, exp'), annot) in + let funcls = List.map rewrite_funcl_params funcls in + (nvars, FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) in + + let rewrite_sizeof_fundef (params_map, defs) = function + | DEF_fundef fd -> + let (nvars, fd') = rewrite_sizeof_fun params_map fd in + let params_map' = + if KidSet.is_empty nvars then params_map + else Bindings.add (id_of_fundef fd) nvars params_map in + (params_map', defs @ [DEF_fundef fd']) + | def -> + (params_map, defs @ [def]) in + + let rewrite_sizeof_valspec params_map def = + let rewrite_typschm (TypSchm_aux (TypSchm_ts (tq, typ), l) as ts) id = + if Bindings.mem id params_map then + let kid_typs = List.map (fun kid -> atom_typ (nvar kid)) + (KidSet.elements (Bindings.find id params_map)) in + let typ' = match typ with + | Typ_aux (Typ_fn (vtyp_arg, vtyp_ret, declared_eff), vl) -> + let vtyp_arg' = begin + match vtyp_arg with + | Typ_aux (Typ_tup typs, vl) -> + Typ_aux (Typ_tup (kid_typs @ typs), vl) + | _ -> Typ_aux (Typ_tup (kid_typs @ [vtyp_arg]), vl) + end in + Typ_aux (Typ_fn (vtyp_arg', vtyp_ret, declared_eff), vl) + | _ -> raise (Reporting_basic.err_typ l + "val spec with non-function type") in + TypSchm_aux (TypSchm_ts (tq, typ'), l) + else ts in + match def with + | DEF_spec (VS_aux (VS_val_spec (typschm, id), a)) -> + DEF_spec (VS_aux (VS_val_spec (rewrite_typschm typschm id, id), a)) + | DEF_spec (VS_aux (VS_extern_no_rename (typschm, id), a)) -> + DEF_spec (VS_aux (VS_extern_no_rename (rewrite_typschm typschm id, id), a)) + | DEF_spec (VS_aux (VS_extern_spec (typschm, id, e), a)) -> + DEF_spec (VS_aux (VS_extern_spec (rewrite_typschm typschm id, id, e), a)) + | DEF_spec (VS_aux (VS_cast_spec (typschm, id), a)) -> + DEF_spec (VS_aux (VS_cast_spec (rewrite_typschm typschm id, id), a)) + | _ -> def in + + let (params_map, defs) = List.fold_left rewrite_sizeof_fundef + (Bindings.empty, []) defs in + let defs = List.map (rewrite_sizeof_valspec params_map) defs in + fst (check initial_env (Defs defs)) - rewrite_defs_base - { rewriters_base with - rewrite_exp = rewrite_sizeof_exp []; - rewrite_fun = rewrite_sizeof_fun } - defs let remove_vector_concat_pat pat = @@ -1040,7 +1144,7 @@ let remove_vector_concat_pat pat = ) } in - let pat = remove_typed_patterns pat in + (* let pat = remove_typed_patterns pat in *) let fresh_id_v = fresh_id "v__" in @@ -1080,10 +1184,11 @@ let remove_vector_concat_pat pat = (* introduce names for all unnamed child nodes of P_vector_concat *) let name_vector_concat_elements = let p_vector_concat pats = - let aux ((P_aux (p,((l,_) as a))) as pat) = match p with + let rec aux ((P_aux (p,((l,_) as a))) as pat) = match p with | P_vector _ -> P_aux (P_as (pat,fresh_id_v l),a) | P_id id -> P_aux (P_id id,a) | P_as (p,id) -> P_aux (P_as (p,id),a) + | P_typ (typ, pat) -> P_aux (P_typ (typ, aux pat),a) | P_wild -> P_aux (P_wild,a) | _ -> raise @@ -1111,7 +1216,7 @@ let remove_vector_concat_pat pat = pat_alg = *) (* build a let-expression of the form "let child = root[i..j] in body" *) - let letbind_vec (rootid,rannot) (child,cannot) (i,j) = + let letbind_vec typ_opt (rootid,rannot) (child,cannot) (i,j) = let (l,_) = cannot in let (Id_aux (Id rootname,_)) = rootid in let (Id_aux (Id childname,_)) = child in @@ -1122,7 +1227,11 @@ let remove_vector_concat_pat pat = let subv = fix_eff_exp (E_aux (E_vector_subrange (root, index_i, index_j), cannot)) in - let letbind = fix_eff_lb (LB_aux (LB_val_implicit (P_aux (P_id child,cannot),subv),cannot)) in + let id_pat = + match typ_opt with + | Some typ -> P_aux (P_typ (typ, P_aux (P_id child,cannot)), cannot) + | None -> P_aux (P_id child,cannot) in + let letbind = fix_eff_lb (LB_aux (LB_val_implicit (id_pat,subv),cannot)) in (letbind, (fun body -> fix_eff_exp (E_aux (E_let (letbind,body), simple_annot l (typ_of body)))), (rootname,childname)) in @@ -1136,7 +1245,7 @@ let remove_vector_concat_pat pat = | _ -> raise (Reporting_basic.err_unreachable (fst rannot') ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern"))) in - let aux (pos,pat_acc,decl_acc) (P_aux (p,cannot),is_last) = + let rec aux typ_opt (pos,pat_acc,decl_acc) (P_aux (p,cannot),is_last) = let ctyp = Env.base_typ_of (env_of_annot cannot) (typ_of_annot cannot) in let (_,length,ord,_) = vector_typ_args_of ctyp in (*)| (_,length,ord,_) ->*) @@ -1154,12 +1263,13 @@ let remove_vector_concat_pat pat = (* if we see a named vector pattern, remove the name and remember to declare it later *) | P_as (P_aux (p,cannot),cname) -> - let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in + let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in (pos', pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)]) (* if we see a P_id variable, remember to declare it later *) | P_id cname -> - let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in + let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in (pos', pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)]) + | P_typ (typ, pat) -> aux (Some typ) (pos,pat_acc,decl_acc) (pat, is_last) (* normal vector patterns are fine *) | _ -> (pos', pat_acc @ [P_aux (p,cannot)],decl_acc) ) (* non-vector patterns aren't *) @@ -1171,7 +1281,7 @@ let remove_vector_concat_pat pat = string_of_typ (typ_of_annot cannot)) )*) in let pats_tagged = tag_last pats in - let (_,pats',decls') = List.fold_left aux (start,[],[]) pats_tagged in + let (_,pats',decls') = List.fold_left (aux None) (start,[],[]) pats_tagged in (* abuse P_vector_concat as a P_vector_const pattern: it has the of patterns as an argument but they're meant to be consed together *) @@ -1935,12 +2045,12 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base rewrite_def = rewrite_def; rewrite_defs = rewrite_defs_base} defs*) -let rewrite_defs_ocaml defs = - let defs_sorted = top_sort_defs defs in - let defs_vec_concat_removed = rewrite_defs_remove_vector_concat defs_sorted in - let defs_lifted_assign = rewrite_defs_exp_lift_assign defs_vec_concat_removed in -(* let defs_separate_nums = rewrite_defs_separate_numbs defs_lifted_assign in *) - defs_lifted_assign +let rewrite_defs_ocaml = + top_sort_defs >> + rewrite_defs_remove_vector_concat >> + rewrite_sizeof >> + rewrite_defs_exp_lift_assign (* >> + rewrite_defs_separate_numbs *) let rewrite_defs_remove_blocks = let letbind_wild v body = @@ -2746,9 +2856,9 @@ let rewrite_defs_remove_e_assign = let rewrite_defs_lem = top_sort_defs >> - rewrite_sizeof >> rewrite_defs_remove_vector_concat >> rewrite_defs_remove_bitvector_pats >> + rewrite_sizeof >> rewrite_defs_exp_lift_assign >> rewrite_defs_remove_blocks >> rewrite_defs_letbind_effects >> diff --git a/src/type_check.ml b/src/type_check.ml index ff5f1512..c2351a8a 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -1889,6 +1889,17 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env | P_lit lit -> annot_pat (P_lit lit) (infer_lit env lit), env + | P_vector (pat :: pats) -> + let fold_pats (pats, env) pat = + let inferred_pat, env = infer_pat env pat in + pats @ [inferred_pat], env + in + let ((inferred_pat :: inferred_pats) as pats), env = + List.fold_left fold_pats ([], env) (pat :: pats) in + let len = nexp_simp (nconstant (List.length pats)) in + let etyp = pat_typ_of inferred_pat in + List.map (fun pat -> typ_equality l env etyp (pat_typ_of pat)) pats; + annot_pat (P_vector pats) (lvector_typ env len etyp), env | P_vector_concat (pat :: pats) -> let fold_pats (pats, env) pat = let inferred_pat, env = infer_pat env pat in @@ -2579,18 +2590,7 @@ let check_tannotopt typq ret_typ = function else typ_error l (string_of_bind (typq, ret_typ) ^ " and " ^ string_of_bind (annot_typq, annot_ret_typ) ^ " do not match between function and val spec") let check_fundef env (FD_aux (FD_function (recopt, tannotopt, effectopt, funcls), (l, _)) as fd_aux) = - let id = - match (List.fold_right - (fun (FCL_aux (FCL_Funcl (id, _, _), _)) id' -> - match id' with - | Some id' -> if string_of_id id' = string_of_id id then Some id' - else typ_error l ("Function declaration expects all definitions to have the same name, " - ^ string_of_id id ^ " differs from other definitions of " ^ string_of_id id') - | None -> Some id) funcls None) - with - | Some id -> id - | None -> typ_error l "funcl list is empty" - in + let id = id_of_fundef fd_aux in typ_print ("\nChecking function " ^ string_of_id id); let have_val_spec, (quant, (Typ_aux (Typ_fn (vtyp_arg, vtyp_ret, declared_eff), vl) as typ)), env = try true, Env.get_val_spec id env, env with |
