summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorThomas Bauereiss2017-07-26 12:03:12 +0100
committerThomas Bauereiss2017-07-26 12:03:12 +0100
commit26e59493cde0ffbf1868426fe3bec158f2dbaad0 (patch)
tree2193492e4989608eb5d2fef9ed3a60aa84b9c316 /src
parenteae4d12ad793809482252be0b459bb7e634b5482 (diff)
Improve rewriting of sizeof expressions
If some type-level variables in a sizeof expression in a function body cannot be directly extracted from the parameters of the function, add a new parameter for each unresolved parameter, and rewrite calls to the function accordingly
Diffstat (limited to 'src')
-rw-r--r--src/ast_util.ml13
-rw-r--r--src/ast_util.mli2
-rw-r--r--src/pretty_print_lem.ml2
-rw-r--r--src/rewriter.ml190
-rw-r--r--src/type_check.ml24
5 files changed, 178 insertions, 53 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 04ce5b07..5bb4e0a6 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -330,6 +330,19 @@ let rec string_of_index_range (BF_aux (ir, _)) =
| BF_range (n, m) -> string_of_int n ^ " .. " ^ string_of_int m
| BF_concat (ir1, ir2) -> "(" ^ string_of_index_range ir1 ^ ") : (" ^ string_of_index_range ir2 ^ ")"
+let id_of_fundef (FD_aux (FD_function (_, _, _, funcls), (l, _))) =
+ match (List.fold_right
+ (fun (FCL_aux (FCL_Funcl (id, _, _), _)) id' ->
+ match id' with
+ | Some id' -> if string_of_id id' = string_of_id id then Some id'
+ else raise (Reporting_basic.err_typ l
+ ("Function declaration expects all definitions to have the same name, "
+ ^ string_of_id id ^ " differs from other definitions of " ^ string_of_id id'))
+ | None -> Some id) funcls None)
+ with
+ | Some id -> id
+ | None -> raise (Reporting_basic.err_typ l "funcl list is empty")
+
module Kid = struct
type t = kid
let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2)
diff --git a/src/ast_util.mli b/src/ast_util.mli
index d7f68412..6e22d173 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -84,6 +84,8 @@ val string_of_pat : 'a pat -> string
val string_of_letbind : 'a letbind -> string
val string_of_index_range : index_range -> string
+val id_of_fundef : 'a fundef -> id
+
module Id : sig
type t = id
val compare : id -> id -> int
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 9f479cdc..95ddc580 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -274,7 +274,7 @@ let rec doc_pat_lem regtypes apat_needed (P_aux (p,(l,annot)) as pa) = match p w
| Id_aux (Id "None",_) -> string "Nothing" (* workaround temporary issue *)
| _ -> doc_id_lem id end
| P_as(p,id) -> parens (separate space [doc_pat_lem regtypes true p; string "as"; doc_id_lem id])
- | P_typ(typ,p) -> doc_op colon (doc_pat_lem regtypes true p) (doc_typ_lem regtypes typ)
+ | P_typ(typ,p) -> parens (doc_op colon (doc_pat_lem regtypes true p) (doc_typ_lem regtypes typ))
| P_vector pats ->
let ppp =
(separate space)
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 560159d2..166c31f0 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -337,8 +337,8 @@ let rewrite_pat rewriters (P_aux (pat,(l,annot))) =
(vector_string_to_bit_list l lit) in
rewrap (P_vector ps)
| P_lit _ | P_wild | P_id _ -> rewrap pat
- | P_as(pat,id) -> rewrap (P_as( rewrite pat, id))
- | P_typ(typ,pat) -> rewrite pat
+ | P_as(pat,id) -> rewrap (P_as(rewrite pat, id))
+ | P_typ(typ,pat) -> rewrap (P_typ(typ, rewrite pat))
| P_app(id ,pats) -> rewrap (P_app(id, List.map rewrite pats))
| P_record(fpats,_) ->
rewrap (P_record(List.map (fun (FP_aux(FP_Fpat(id,pat),pannot)) -> FP_aux(FP_Fpat(id, rewrite pat), pannot)) fpats,
@@ -958,13 +958,22 @@ let compute_exp_alg bot join =
; pat_alg = compute_pat_alg bot join
}
-let rewrite_sizeof defs =
+
+(* Rewrite sizeof expressions with type-level variables to
+ term-level expressions
+
+ For each type-level variable used in a sizeof expressions whose value cannot
+ be directly extracted from existing parameters of the surrounding function,
+ a further parameter is added; calls to the function are rewritten
+ accordingly (possibly causing further rewriting in the calling function) *)
+let rewrite_sizeof (Defs defs) =
let sizeof_frees exp =
fst (fold_exp
{ (compute_exp_alg KidSet.empty KidSet.union) with
e_sizeof = (fun nexp -> (nexp_frees nexp, E_sizeof nexp)) }
exp) in
+ (* Collect nexps whose values can be obtained directly from a pattern bind *)
let nexps_from_params pat =
fst (fold_pat
{ (compute_pat_alg [] (@)) with
@@ -986,13 +995,14 @@ let rewrite_sizeof defs =
| _ -> [] in
(v @ v', P_aux (pat,annot)))} pat) in
+ (* Substitute collected values in sizeof expressions *)
let rec e_sizeof nmap (Nexp_aux (nexp, l) as nexp_aux) =
try snd (List.find (fun (nexp,_) -> nexp_identical nexp nexp_aux) nmap)
with
| Not_found ->
let binop nexp1 op nexp2 = E_app_infix (
E_aux (e_sizeof nmap nexp1, simple_annot l (atom_typ nexp1)),
- Id_aux (Id op, Unknown),
+ Id_aux (Id op, Parse_ast.Unknown),
E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2))
) in
(match nexp with
@@ -1002,31 +1012,125 @@ let rewrite_sizeof defs =
| Nexp_minus (nexp1, nexp2) -> binop nexp1 "-" nexp2
| _ -> E_sizeof nexp_aux) in
- let rewrite_sizeof_exp nmap rewriters exp =
- let exp = rewriters_base.rewrite_exp rewriters exp in
- fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp in
-
- let rewrite_sizeof_fun rewriters
- (FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) =
- let rewrite_funcl_body (FCL_aux (FCL_Funcl (id,pat,exp), annot)) =
+ (* Rewrite calls to functions which have had parameters added to pass values
+ of type-level variables; these are added as sizeof expressions first, and
+ then further rewritten as above. *)
+ let e_app_aux param_map (exp, ((l,_) as annot)) =
+ let full_exp = E_aux (exp, annot) in
+ match exp with
+ | E_app (f, args) ->
+ if Bindings.mem f param_map then
+ (* Retrieve instantiation of the type variables of the called function
+ for the given parameters in the current environment *)
+ let inst = instantiation_of full_exp in
+ let kid_exp kid = begin
+ match KBindings.find kid inst with
+ | U_nexp nexp -> E_aux (E_sizeof nexp, simple_annot l (atom_typ nexp))
+ | _ ->
+ raise (Reporting_basic.err_unreachable l
+ ("failed to infer nexp for type variable " ^ string_of_kid kid ^
+ " of function " ^ string_of_id f))
+ end in
+ let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in
+ E_aux (E_app (f, kid_exps @ args), annot)
+ else full_exp
+ | _ -> full_exp in
+
+ let rewrite_sizeof_fun params_map
+ (FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) =
+ let rewrite_funcl_body (FCL_aux (FCL_Funcl (id,pat,exp), annot)) (funcls,nvars) =
let body_env = env_of exp in
let body_typ = typ_of exp in
let nmap = nexps_from_params pat in
- let exp =
- try check_exp body_env
- (strip_exp (fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp))
- body_typ
- with
- | Type_error _ -> exp in
- FCL_aux (FCL_Funcl (id,pat,exp), annot) in
- let funcls = List.map rewrite_funcl_body funcls in
- FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot) in
+ (* first rewrite calls to other functions... *)
+ let exp' = fold_exp { id_exp_alg with e_aux = e_app_aux params_map } exp in
+ (* ... then rewrite sizeof expressions in current function body *)
+ let exp'' = fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp' in
+ (FCL_aux (FCL_Funcl (id,pat,exp''), annot) :: funcls,
+ KidSet.union nvars (sizeof_frees exp'')) in
+ let (funcls, nvars) = List.fold_right rewrite_funcl_body funcls ([], KidSet.empty) in
+ (* Add a parameter for each remaining free type-level variable in a
+ sizeof expression *)
+ let kid_typ kid = atom_typ (nvar kid) in
+ let kid_annot kid = simple_annot l (kid_typ kid) in
+ let kid_pat kid =
+ P_aux (P_typ (kid_typ kid,
+ P_aux (P_id (Id_aux (Id (string_of_kid kid), l)),
+ kid_annot kid)), kid_annot kid) in
+ let kid_eaux kid = E_id (Id_aux (Id (string_of_kid kid), l)) in
+ let kid_typs = List.map kid_typ (KidSet.elements nvars) in
+ let kid_pats = List.map kid_pat (KidSet.elements nvars) in
+ let kid_nmap = List.map (fun kid -> (nvar kid, kid_eaux kid)) (KidSet.elements nvars) in
+ let rewrite_funcl_params (FCL_aux (FCL_Funcl (id, pat, exp), annot) as funcl) =
+ let rec rewrite_pat (P_aux (pat,(l,_)) as paux) =
+ if KidSet.is_empty nvars then paux else
+ match pat_typ_of paux with
+ | Typ_aux (Typ_tup _, _) ->
+ (match pat with
+ | P_tup pats ->
+ P_aux (P_tup (kid_pats @ pats), (l, None))
+ | P_wild -> paux
+ | P_typ (Typ_aux (Typ_tup typs, l), pat) ->
+ P_aux (P_typ (Typ_aux (Typ_tup (kid_typs @ typs), l),
+ rewrite_pat pat), (l, None))
+ | P_as (_, id) | P_id id ->
+ (* adding parameters here would change the type of id;
+ we should remove the P_as/P_id here and add a let-binding to the body *)
+ raise (Reporting_basic.err_todo l
+ "rewriting as- or id-patterns for sizeof expressions not yet implemented")
+ | _ ->
+ raise (Reporting_basic.err_unreachable l
+ "unexpected pattern while rewriting function parameters for sizeof expressions"))
+ | _ -> P_aux (P_tup (kid_pats @ [paux]), (l, None)) in
+ let exp' = fold_exp { id_exp_alg with e_sizeof = e_sizeof kid_nmap } exp in
+ FCL_aux (FCL_Funcl (id, rewrite_pat pat, exp'), annot) in
+ let funcls = List.map rewrite_funcl_params funcls in
+ (nvars, FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) in
+
+ let rewrite_sizeof_fundef (params_map, defs) = function
+ | DEF_fundef fd ->
+ let (nvars, fd') = rewrite_sizeof_fun params_map fd in
+ let params_map' =
+ if KidSet.is_empty nvars then params_map
+ else Bindings.add (id_of_fundef fd) nvars params_map in
+ (params_map', defs @ [DEF_fundef fd'])
+ | def ->
+ (params_map, defs @ [def]) in
+
+ let rewrite_sizeof_valspec params_map def =
+ let rewrite_typschm (TypSchm_aux (TypSchm_ts (tq, typ), l) as ts) id =
+ if Bindings.mem id params_map then
+ let kid_typs = List.map (fun kid -> atom_typ (nvar kid))
+ (KidSet.elements (Bindings.find id params_map)) in
+ let typ' = match typ with
+ | Typ_aux (Typ_fn (vtyp_arg, vtyp_ret, declared_eff), vl) ->
+ let vtyp_arg' = begin
+ match vtyp_arg with
+ | Typ_aux (Typ_tup typs, vl) ->
+ Typ_aux (Typ_tup (kid_typs @ typs), vl)
+ | _ -> Typ_aux (Typ_tup (kid_typs @ [vtyp_arg]), vl)
+ end in
+ Typ_aux (Typ_fn (vtyp_arg', vtyp_ret, declared_eff), vl)
+ | _ -> raise (Reporting_basic.err_typ l
+ "val spec with non-function type") in
+ TypSchm_aux (TypSchm_ts (tq, typ'), l)
+ else ts in
+ match def with
+ | DEF_spec (VS_aux (VS_val_spec (typschm, id), a)) ->
+ DEF_spec (VS_aux (VS_val_spec (rewrite_typschm typschm id, id), a))
+ | DEF_spec (VS_aux (VS_extern_no_rename (typschm, id), a)) ->
+ DEF_spec (VS_aux (VS_extern_no_rename (rewrite_typschm typschm id, id), a))
+ | DEF_spec (VS_aux (VS_extern_spec (typschm, id, e), a)) ->
+ DEF_spec (VS_aux (VS_extern_spec (rewrite_typschm typschm id, id, e), a))
+ | DEF_spec (VS_aux (VS_cast_spec (typschm, id), a)) ->
+ DEF_spec (VS_aux (VS_cast_spec (rewrite_typschm typschm id, id), a))
+ | _ -> def in
+
+ let (params_map, defs) = List.fold_left rewrite_sizeof_fundef
+ (Bindings.empty, []) defs in
+ let defs = List.map (rewrite_sizeof_valspec params_map) defs in
+ fst (check initial_env (Defs defs))
- rewrite_defs_base
- { rewriters_base with
- rewrite_exp = rewrite_sizeof_exp [];
- rewrite_fun = rewrite_sizeof_fun }
- defs
let remove_vector_concat_pat pat =
@@ -1040,7 +1144,7 @@ let remove_vector_concat_pat pat =
)
} in
- let pat = remove_typed_patterns pat in
+ (* let pat = remove_typed_patterns pat in *)
let fresh_id_v = fresh_id "v__" in
@@ -1080,10 +1184,11 @@ let remove_vector_concat_pat pat =
(* introduce names for all unnamed child nodes of P_vector_concat *)
let name_vector_concat_elements =
let p_vector_concat pats =
- let aux ((P_aux (p,((l,_) as a))) as pat) = match p with
+ let rec aux ((P_aux (p,((l,_) as a))) as pat) = match p with
| P_vector _ -> P_aux (P_as (pat,fresh_id_v l),a)
| P_id id -> P_aux (P_id id,a)
| P_as (p,id) -> P_aux (P_as (p,id),a)
+ | P_typ (typ, pat) -> P_aux (P_typ (typ, aux pat),a)
| P_wild -> P_aux (P_wild,a)
| _ ->
raise
@@ -1111,7 +1216,7 @@ let remove_vector_concat_pat pat =
pat_alg = *)
(* build a let-expression of the form "let child = root[i..j] in body" *)
- let letbind_vec (rootid,rannot) (child,cannot) (i,j) =
+ let letbind_vec typ_opt (rootid,rannot) (child,cannot) (i,j) =
let (l,_) = cannot in
let (Id_aux (Id rootname,_)) = rootid in
let (Id_aux (Id childname,_)) = child in
@@ -1122,7 +1227,11 @@ let remove_vector_concat_pat pat =
let subv = fix_eff_exp (E_aux (E_vector_subrange (root, index_i, index_j), cannot)) in
- let letbind = fix_eff_lb (LB_aux (LB_val_implicit (P_aux (P_id child,cannot),subv),cannot)) in
+ let id_pat =
+ match typ_opt with
+ | Some typ -> P_aux (P_typ (typ, P_aux (P_id child,cannot)), cannot)
+ | None -> P_aux (P_id child,cannot) in
+ let letbind = fix_eff_lb (LB_aux (LB_val_implicit (id_pat,subv),cannot)) in
(letbind,
(fun body -> fix_eff_exp (E_aux (E_let (letbind,body), simple_annot l (typ_of body)))),
(rootname,childname)) in
@@ -1136,7 +1245,7 @@ let remove_vector_concat_pat pat =
| _ ->
raise (Reporting_basic.err_unreachable (fst rannot')
("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern"))) in
- let aux (pos,pat_acc,decl_acc) (P_aux (p,cannot),is_last) =
+ let rec aux typ_opt (pos,pat_acc,decl_acc) (P_aux (p,cannot),is_last) =
let ctyp = Env.base_typ_of (env_of_annot cannot) (typ_of_annot cannot) in
let (_,length,ord,_) = vector_typ_args_of ctyp in
(*)| (_,length,ord,_) ->*)
@@ -1154,12 +1263,13 @@ let remove_vector_concat_pat pat =
(* if we see a named vector pattern, remove the name and remember to
declare it later *)
| P_as (P_aux (p,cannot),cname) ->
- let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in
+ let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in
(pos', pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)])
(* if we see a P_id variable, remember to declare it later *)
| P_id cname ->
- let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in
+ let (lb,decl,info) = letbind_vec typ_opt (rootid,rannot) (cname,cannot) (pos,index_j) in
(pos', pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)])
+ | P_typ (typ, pat) -> aux (Some typ) (pos,pat_acc,decl_acc) (pat, is_last)
(* normal vector patterns are fine *)
| _ -> (pos', pat_acc @ [P_aux (p,cannot)],decl_acc) )
(* non-vector patterns aren't *)
@@ -1171,7 +1281,7 @@ let remove_vector_concat_pat pat =
string_of_typ (typ_of_annot cannot))
)*) in
let pats_tagged = tag_last pats in
- let (_,pats',decls') = List.fold_left aux (start,[],[]) pats_tagged in
+ let (_,pats',decls') = List.fold_left (aux None) (start,[],[]) pats_tagged in
(* abuse P_vector_concat as a P_vector_const pattern: it has the of
patterns as an argument but they're meant to be consed together *)
@@ -1935,12 +2045,12 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base
rewrite_def = rewrite_def;
rewrite_defs = rewrite_defs_base} defs*)
-let rewrite_defs_ocaml defs =
- let defs_sorted = top_sort_defs defs in
- let defs_vec_concat_removed = rewrite_defs_remove_vector_concat defs_sorted in
- let defs_lifted_assign = rewrite_defs_exp_lift_assign defs_vec_concat_removed in
-(* let defs_separate_nums = rewrite_defs_separate_numbs defs_lifted_assign in *)
- defs_lifted_assign
+let rewrite_defs_ocaml =
+ top_sort_defs >>
+ rewrite_defs_remove_vector_concat >>
+ rewrite_sizeof >>
+ rewrite_defs_exp_lift_assign (* >>
+ rewrite_defs_separate_numbs *)
let rewrite_defs_remove_blocks =
let letbind_wild v body =
@@ -2746,9 +2856,9 @@ let rewrite_defs_remove_e_assign =
let rewrite_defs_lem =
top_sort_defs >>
- rewrite_sizeof >>
rewrite_defs_remove_vector_concat >>
rewrite_defs_remove_bitvector_pats >>
+ rewrite_sizeof >>
rewrite_defs_exp_lift_assign >>
rewrite_defs_remove_blocks >>
rewrite_defs_letbind_effects >>
diff --git a/src/type_check.ml b/src/type_check.ml
index ff5f1512..c2351a8a 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -1889,6 +1889,17 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) =
annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env
| P_lit lit ->
annot_pat (P_lit lit) (infer_lit env lit), env
+ | P_vector (pat :: pats) ->
+ let fold_pats (pats, env) pat =
+ let inferred_pat, env = infer_pat env pat in
+ pats @ [inferred_pat], env
+ in
+ let ((inferred_pat :: inferred_pats) as pats), env =
+ List.fold_left fold_pats ([], env) (pat :: pats) in
+ let len = nexp_simp (nconstant (List.length pats)) in
+ let etyp = pat_typ_of inferred_pat in
+ List.map (fun pat -> typ_equality l env etyp (pat_typ_of pat)) pats;
+ annot_pat (P_vector pats) (lvector_typ env len etyp), env
| P_vector_concat (pat :: pats) ->
let fold_pats (pats, env) pat =
let inferred_pat, env = infer_pat env pat in
@@ -2579,18 +2590,7 @@ let check_tannotopt typq ret_typ = function
else typ_error l (string_of_bind (typq, ret_typ) ^ " and " ^ string_of_bind (annot_typq, annot_ret_typ) ^ " do not match between function and val spec")
let check_fundef env (FD_aux (FD_function (recopt, tannotopt, effectopt, funcls), (l, _)) as fd_aux) =
- let id =
- match (List.fold_right
- (fun (FCL_aux (FCL_Funcl (id, _, _), _)) id' ->
- match id' with
- | Some id' -> if string_of_id id' = string_of_id id then Some id'
- else typ_error l ("Function declaration expects all definitions to have the same name, "
- ^ string_of_id id ^ " differs from other definitions of " ^ string_of_id id')
- | None -> Some id) funcls None)
- with
- | Some id -> id
- | None -> typ_error l "funcl list is empty"
- in
+ let id = id_of_fundef fd_aux in
typ_print ("\nChecking function " ^ string_of_id id);
let have_val_spec, (quant, (Typ_aux (Typ_fn (vtyp_arg, vtyp_ret, declared_eff), vl) as typ)), env =
try true, Env.get_val_spec id env, env with