summaryrefslogtreecommitdiff
path: root/src/rewriter.ml
diff options
context:
space:
mode:
authorBrian Campbell2017-08-23 11:11:08 +0100
committerBrian Campbell2017-08-23 11:11:08 +0100
commit22c2e970e9e52ff60b8262d02b4f50ad12174fd8 (patch)
treee05bc639514a511d4d39399b8a263e817897e4fe /src/rewriter.ml
parent2a6f3b8e42a4cb4cececb79a9011346b5b25ce80 (diff)
parentc380d2d0b51be71871085ac7d085268f5baccb56 (diff)
Merge branch 'experiments' into mono-experiments
Diffstat (limited to 'src/rewriter.ml')
-rw-r--r--src/rewriter.ml415
1 files changed, 288 insertions, 127 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml
index ef4a209c..d61939ee 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -566,14 +566,7 @@ let rewrite_lexp rewriters (LEXP_aux(lexp,(l,annot))) =
let rewrite_fun rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) =
let rewrite_funcl (FCL_aux (FCL_Funcl(id,pat,exp),(l,annot))) =
let _ = reset_fresh_name_counter () in
- (*let _ = Printf.eprintf "Rewriting function %s, pattern %s\n"
- (match id with (Id_aux (Id i,_)) -> i) (Pretty_print.pat_to_string pat) in*)
- (*let map = get_map_tannot fdannot in
- let map =
- match map with
- | None -> None
- | Some m -> Some(m, Envmap.empty) in*)
- (FCL_aux (FCL_Funcl (id,rewriters.rewrite_pat rewriters pat,
+ (FCL_aux (FCL_Funcl (id,rewriters.rewrite_pat rewriters pat,
rewriters.rewrite_exp rewriters exp),(l,annot)))
in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot))
@@ -943,12 +936,12 @@ let compute_exp_alg bot join =
; e_tuple = split_join (fun es -> E_tuple es)
; e_if = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_if (e1,e2,e3)))
; e_for = (fun (id,(v1,e1),(v2,e2),(v3,e3),order,(v4,e4)) ->
- (join_list [v1;v2;v3;v4], E_for (id,e1,e2,e3,order,e4)))
+ (join_list [v1;v2;v3;v4], E_for (id,e1,e2,e3,order,e4)))
; e_vector = split_join (fun es -> E_vector es)
; e_vector_indexed = (fun (es,(v2,opt2)) ->
- let (is,es) = List.split es in
- let (vs,es) = List.split es in
- (join_list (vs @ [v2]), E_vector_indexed (List.combine is es,opt2)))
+ let (is,es) = List.split es in
+ let (vs,es) = List.split es in
+ (join_list (vs @ [v2]), E_vector_indexed (List.combine is es,opt2)))
; e_vector_access = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_access (e1,e2)))
; e_vector_subrange = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_subrange (e1,e2,e3)))
; e_vector_update = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_update (e1,e2,e3)))
@@ -960,8 +953,8 @@ let compute_exp_alg bot join =
; e_record_update = (fun ((v1,e1),(vf,fexp)) -> (join v1 vf, E_record_update (e1,fexp)))
; e_field = (fun ((v1,e1),id) -> (v1, E_field (e1,id)))
; e_case = (fun ((v1,e1),pexps) ->
- let (vps,pexps) = List.split pexps in
- (join_list (v1::vps), E_case (e1,pexps)))
+ let (vps,pexps) = List.split pexps in
+ (join_list (v1::vps), E_case (e1,pexps)))
; e_let = (fun ((vl,lb),(v2,e2)) -> (join vl v2, E_let (lb,e2)))
; e_assign = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, E_assign (lexp,e2)))
; e_sizeof = (fun nexp -> (bot, E_sizeof nexp))
@@ -975,27 +968,27 @@ let compute_exp_alg bot join =
; e_comment = (fun c -> (bot, E_comment c))
; e_comment_struc = (fun (v,e) -> (bot, E_comment_struc e)) (* ignore value by default, since it is comes from a comment *)
; e_internal_let = (fun ((vl, lexp), (v2,e2), (v3,e3)) ->
- (join_list [vl;v2;v3], E_internal_let (lexp,e2,e3)))
+ (join_list [vl;v2;v3], E_internal_let (lexp,e2,e3)))
; e_internal_plet = (fun ((vp,pat), (v1,e1), (v2,e2)) ->
- (join_list [vp;v1;v2], E_internal_plet (pat,e1,e2)))
+ (join_list [vp;v1;v2], E_internal_plet (pat,e1,e2)))
; e_internal_return = (fun (v,e) -> (v, E_internal_return e))
; e_aux = (fun ((v,e),annot) -> (v, E_aux (e,annot)))
; lEXP_id = (fun id -> (bot, LEXP_id id))
; lEXP_memory = (fun (id,es) -> split_join (fun es -> LEXP_memory (id,es)) es)
; lEXP_cast = (fun (typ,id) -> (bot, LEXP_cast (typ,id)))
; lEXP_tup = (fun ls ->
- let (vs,ls) = List.split ls in
- (join_list vs, LEXP_tup ls))
+ let (vs,ls) = List.split ls in
+ (join_list vs, LEXP_tup ls))
; lEXP_vector = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, LEXP_vector (lexp,e2)))
; lEXP_vector_range = (fun ((vl,lexp),(v2,e2),(v3,e3)) ->
- (join_list [vl;v2;v3], LEXP_vector_range (lexp,e2,e3)))
+ (join_list [vl;v2;v3], LEXP_vector_range (lexp,e2,e3)))
; lEXP_field = (fun ((vl,lexp),id) -> (vl, LEXP_field (lexp,id)))
; lEXP_aux = (fun ((vl,lexp),annot) -> (vl, LEXP_aux (lexp,annot)))
; fE_Fexp = (fun (id,(v,e)) -> (v, FE_Fexp (id,e)))
; fE_aux = (fun ((vf,fexp),annot) -> (vf, FE_aux (fexp,annot)))
; fES_Fexps = (fun (fexps,b) ->
- let (vs,fexps) = List.split fexps in
- (join_list vs, FES_Fexps (fexps,b)))
+ let (vs,fexps) = List.split fexps in
+ (join_list vs, FES_Fexps (fexps,b)))
; fES_aux = (fun ((vf,fexp),annot) -> (vf, FES_aux (fexp,annot)))
; def_val_empty = (bot, Def_val_empty)
; def_val_dec = (fun (v,e) -> (v, Def_val_dec e))
@@ -1009,6 +1002,43 @@ let compute_exp_alg bot join =
; pat_alg = compute_pat_alg bot join
}
+(* Re-write trivial sizeof expressions - trivial meaning that the
+ value of the sizeof can be directly inferred from the type
+ variables in scope. *)
+let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp =
+ let extract_typ_var l env nexp (id, (_, typ)) =
+ let var = E_aux (E_id id, (l, Some (env, typ, no_effect))) in
+ match destruct_atom_nexp env typ with
+ | Some size when prove env (nc_eq size nexp) -> Some var
+ | _ ->
+ begin
+ match destruct_vector env typ with
+ | Some (_, len, _, _) when prove env (nc_eq len nexp) ->
+ Some (E_aux (E_app (mk_id "length", [var]), (l, Some (env, atom_typ len, no_effect))))
+ | _ -> None
+ end
+ in
+ let rewrite_e_aux (E_aux (e_aux, (l, _)) as orig_exp) =
+ let env = env_of orig_exp in
+ match e_aux with
+ | E_sizeof (Nexp_aux (Nexp_constant c, _) as nexp) ->
+ E_aux (E_lit (L_aux (L_num c, l)), (l, Some (env, atom_typ nexp, no_effect)))
+ | E_sizeof nexp ->
+ begin
+ let locals = Env.get_locals env in
+ let exps = Bindings.bindings locals
+ |> List.map (extract_typ_var l env nexp)
+ |> List.map (fun opt -> match opt with Some x -> [x] | None -> [])
+ |> List.concat
+ in
+ match exps with
+ | (exp :: _) -> exp
+ | [] -> orig_exp
+ end
+ | _ -> orig_exp
+ in
+ let rewrite_e_constraint = { id_exp_alg with e_aux = (fun (exp, annot) -> rewrite_e_aux (E_aux (exp, annot))) } in
+ rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_e_constraint) }, rewrite_e_aux
(* Rewrite sizeof expressions with type-level variables to
term-level expressions
@@ -1020,78 +1050,91 @@ let compute_exp_alg bot join =
let rewrite_sizeof (Defs defs) =
let sizeof_frees exp =
fst (fold_exp
- { (compute_exp_alg KidSet.empty KidSet.union) with
- e_sizeof = (fun nexp -> (nexp_frees nexp, E_sizeof nexp)) }
- exp) in
+ { (compute_exp_alg KidSet.empty KidSet.union) with
+ e_sizeof = (fun nexp -> (nexp_frees nexp, E_sizeof nexp)) }
+ exp) in
(* Collect nexps whose values can be obtained directly from a pattern bind *)
let nexps_from_params pat =
fst (fold_pat
- { (compute_pat_alg [] (@)) with
- p_aux = (fun ((v,pat),((l,_) as annot)) ->
- let v' = match pat with
- | P_id id | P_as (_, id) ->
- let (Typ_aux (typ,_) as typ_aux) = typ_of_annot annot in
- (match typ with
- | Typ_app (atom, [Typ_arg_aux (Typ_arg_nexp nexp, _)])
- when string_of_id atom = "atom" ->
- [nexp, E_id id]
- | Typ_app (vector, _) when string_of_id vector = "vector" ->
- let id_length = Id_aux (Id "length", Parse_ast.Generated l) in
- (try
- (match Env.get_val_spec id_length (env_of_annot annot) with
- | _ ->
- let (_,len,_,_) = vector_typ_args_of typ_aux in
- let exp = E_app (id_length, [E_aux (E_id id, annot)]) in
- [len, exp])
- with
- | _ -> [])
- | _ -> [])
- | _ -> [] in
- (v @ v', P_aux (pat,annot)))} pat) in
+ { (compute_pat_alg [] (@)) with
+ p_aux = (fun ((v,pat),((l,_) as annot)) ->
+ let v' = match pat with
+ | P_id id | P_as (_, id) ->
+ let (Typ_aux (typ,_) as typ_aux) = typ_of_annot annot in
+ (match typ with
+ | Typ_app (atom, [Typ_arg_aux (Typ_arg_nexp nexp, _)])
+ when string_of_id atom = "atom" ->
+ [nexp, E_id id]
+ | Typ_app (vector, _) when string_of_id vector = "vector" ->
+ let id_length = Id_aux (Id "length", Parse_ast.Generated l) in
+ (try
+ (match Env.get_val_spec id_length (env_of_annot annot) with
+ | _ ->
+ let (_,len,_,_) = vector_typ_args_of typ_aux in
+ let exp = E_app (id_length, [E_aux (E_id id, annot)]) in
+ [len, exp])
+ with
+ | _ -> [])
+ | _ -> [])
+ | _ -> [] in
+ (v @ v', P_aux (pat,annot)))} pat) in
(* Substitute collected values in sizeof expressions *)
let rec e_sizeof nmap (Nexp_aux (nexp, l) as nexp_aux) =
try snd (List.find (fun (nexp,_) -> nexp_identical nexp nexp_aux) nmap)
with
| Not_found ->
- let binop nexp1 op nexp2 = E_app_infix (
- E_aux (e_sizeof nmap nexp1, simple_annot l (atom_typ nexp1)),
- Id_aux (Id op, Parse_ast.Unknown),
- E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2))
- ) in
- let (Nexp_aux (nexp, l) as nexp_aux) = simplify_nexp nexp_aux in
- (match nexp with
- | Nexp_constant i -> E_lit (L_aux (L_num i, l))
- | Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2
- | Nexp_sum (nexp1, nexp2) -> binop nexp1 "+" nexp2
- | Nexp_minus (nexp1, nexp2) -> binop nexp1 "-" nexp2
- | _ -> E_sizeof nexp_aux) in
+ let binop nexp1 op nexp2 = E_app_infix (
+ E_aux (e_sizeof nmap nexp1, simple_annot l (atom_typ nexp1)),
+ Id_aux (Id op, Parse_ast.Unknown),
+ E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2))
+ ) in
+ let (Nexp_aux (nexp, l) as nexp_aux) = simplify_nexp nexp_aux in
+ (match nexp with
+ | Nexp_constant i -> E_lit (L_aux (L_num i, l))
+ | Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2
+ | Nexp_sum (nexp1, nexp2) -> binop nexp1 "+" nexp2
+ | Nexp_minus (nexp1, nexp2) -> binop nexp1 "-" nexp2
+ | _ -> E_sizeof nexp_aux) in
+
+ let ex_regex = Str.regexp "'ex[0-9]+" in
(* Rewrite calls to functions which have had parameters added to pass values
of type-level variables; these are added as sizeof expressions first, and
then further rewritten as above. *)
- let e_app_aux param_map ((exp, exp_orig), ((l,_) as annot)) =
+ let e_app_aux param_map ((exp, exp_orig), ((l, Some (env, _, _)) as annot)) =
let full_exp = E_aux (exp, annot) in
let orig_exp = E_aux (exp_orig, annot) in
match exp with
| E_app (f, args) ->
- if Bindings.mem f param_map then
- (* Retrieve instantiation of the type variables of the called function
+ if Bindings.mem f param_map then
+ (* Retrieve instantiation of the type variables of the called function
for the given parameters in the original environment *)
- let inst = instantiation_of orig_exp in
- let inst = KBindings.fold (fun kid uvar b -> KBindings.add (orig_kid kid) uvar b) inst KBindings.empty in
- let kid_exp kid = begin
- match KBindings.find (orig_kid kid) inst with
- | U_nexp nexp -> E_aux (E_sizeof nexp, simple_annot l (atom_typ nexp))
- | _ ->
- raise (Reporting_basic.err_unreachable l
- ("failed to infer nexp for type variable " ^ string_of_kid kid ^
- " of function " ^ string_of_id f))
- end in
- let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in
- (E_aux (E_app (f, kid_exps @ args), annot), orig_exp)
- else (full_exp, orig_exp)
+ let inst = instantiation_of orig_exp in
+ (* Rewrite the inst using orig_kid so that each type variable has it's
+ original name rather than a mangled typechecker name *)
+ let inst = KBindings.fold (fun kid uvar b -> KBindings.add (orig_kid kid) uvar b) inst KBindings.empty in
+ let kid_exp kid = begin
+ (* We really don't want to see an existential here! *)
+ assert (not (Str.string_match ex_regex (string_of_kid kid) 0));
+ let uvar = try Some (KBindings.find (orig_kid kid) inst) with Not_found -> None in
+ match uvar with
+ | Some (U_nexp nexp) ->
+ let sizeof = E_aux (E_sizeof nexp, (l, Some (env, atom_typ nexp, no_effect))) in
+ rewrite_trivial_sizeof_exp sizeof
+ (* If the type variable is Not_found then it was probably
+ introduced by a P_var pattern, so it likely exists as
+ a variable in scope. It can't be an existential because the assert rules that out. *)
+ | None -> E_aux (E_id (id_of_kid (orig_kid kid)), simple_annot l (atom_typ (nvar (orig_kid kid))))
+ | _ ->
+ raise (Reporting_basic.err_unreachable l
+ ("failed to infer nexp for type variable " ^ string_of_kid kid ^
+ " of function " ^ string_of_id f))
+ end in
+ let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in
+ (E_aux (E_app (f, kid_exps @ args), annot), orig_exp)
+ else (full_exp, orig_exp)
| _ -> (full_exp, orig_exp) in
(* Plug this into a folding algorithm that also keeps around a copy of the
@@ -1162,7 +1205,7 @@ let rewrite_sizeof (Defs defs) =
} in
let rewrite_sizeof_fun params_map
- (FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) =
+ (FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) =
let rewrite_funcl_body (FCL_aux (FCL_Funcl (id,pat,exp), annot)) (funcls,nvars) =
let body_env = env_of exp in
let body_typ = typ_of exp in
@@ -1172,7 +1215,7 @@ let rewrite_sizeof (Defs defs) =
(* ... then rewrite sizeof expressions in current function body *)
let exp'' = fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp' in
(FCL_aux (FCL_Funcl (id,pat,exp''), annot) :: funcls,
- KidSet.union nvars (sizeof_frees exp'')) in
+ KidSet.union nvars (sizeof_frees exp'')) in
let (funcls, nvars) = List.fold_right rewrite_funcl_body funcls ([], KidSet.empty) in
(* Add a parameter for each remaining free type-level variable in a
sizeof expression *)
@@ -1180,83 +1223,86 @@ let rewrite_sizeof (Defs defs) =
let kid_annot kid = simple_annot l (kid_typ kid) in
let kid_pat kid =
P_aux (P_typ (kid_typ kid,
- P_aux (P_id (Id_aux (Id (string_of_kid kid), l)),
- kid_annot kid)), kid_annot kid) in
- let kid_eaux kid = E_id (Id_aux (Id (string_of_kid kid), l)) in
+ P_aux (P_id (Id_aux (Id (string_of_id (id_of_kid kid) ^ "__tv"), l)),
+ kid_annot kid)), kid_annot kid) in
+ let kid_eaux kid = E_id (Id_aux (Id (string_of_id (id_of_kid kid) ^ "__tv"), l)) in
let kid_typs = List.map kid_typ (KidSet.elements nvars) in
let kid_pats = List.map kid_pat (KidSet.elements nvars) in
let kid_nmap = List.map (fun kid -> (nvar kid, kid_eaux kid)) (KidSet.elements nvars) in
let rewrite_funcl_params (FCL_aux (FCL_Funcl (id, pat, exp), annot) as funcl) =
let rec rewrite_pat (P_aux (pat,(l,_)) as paux) =
if KidSet.is_empty nvars then paux else
- match pat_typ_of paux with
- | Typ_aux (Typ_tup _, _) ->
- (match pat with
- | P_tup pats ->
- P_aux (P_tup (kid_pats @ pats), (l, None))
- | P_wild -> paux
- | P_typ (Typ_aux (Typ_tup typs, l), pat) ->
- P_aux (P_typ (Typ_aux (Typ_tup (kid_typs @ typs), l),
- rewrite_pat pat), (l, None))
- | P_as (_, id) | P_id id ->
- (* adding parameters here would change the type of id;
+ match pat_typ_of paux with
+ | Typ_aux (Typ_tup _, _) ->
+ (match pat with
+ | P_tup pats ->
+ P_aux (P_tup (kid_pats @ pats), (l, None))
+ | P_wild -> paux
+ | P_typ (Typ_aux (Typ_tup typs, l), pat) ->
+ P_aux (P_typ (Typ_aux (Typ_tup (kid_typs @ typs), l),
+ rewrite_pat pat), (l, None))
+ | P_as (_, id) | P_id id ->
+ (* adding parameters here would change the type of id;
we should remove the P_as/P_id here and add a let-binding to the body *)
- raise (Reporting_basic.err_todo l
- "rewriting as- or id-patterns for sizeof expressions not yet implemented")
- | _ ->
- raise (Reporting_basic.err_unreachable l
- "unexpected pattern while rewriting function parameters for sizeof expressions"))
- | _ -> P_aux (P_tup (kid_pats @ [paux]), (l, None)) in
+ raise (Reporting_basic.err_todo l
+ "rewriting as- or id-patterns for sizeof expressions not yet implemented")
+ | _ ->
+ raise (Reporting_basic.err_unreachable l
+ "unexpected pattern while rewriting function parameters for sizeof expressions"))
+ | _ -> P_aux (P_tup (kid_pats @ [paux]), (l, None)) in
let exp' = fold_exp { id_exp_alg with e_sizeof = e_sizeof kid_nmap } exp in
FCL_aux (FCL_Funcl (id, rewrite_pat pat, exp'), annot) in
let funcls = List.map rewrite_funcl_params funcls in
(nvars, FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) in
let rewrite_sizeof_fundef (params_map, defs) = function
- | DEF_fundef fd ->
- let (nvars, fd') = rewrite_sizeof_fun params_map fd in
- let id = id_of_fundef fd in
- let params_map' =
- if KidSet.is_empty nvars then params_map
- else Bindings.add id nvars params_map in
- (params_map', defs @ [DEF_fundef fd'])
- | def ->
- (params_map, defs @ [def]) in
+ | DEF_fundef fd as def ->
+ let (nvars, fd') = rewrite_sizeof_fun params_map fd in
+ let id = id_of_fundef fd in
+ let params_map' =
+ if KidSet.is_empty nvars then params_map
+ else Bindings.add id nvars params_map in
+ (params_map', defs @ [DEF_fundef fd'])
+ | def ->
+ (params_map, defs @ [def]) in
let rewrite_sizeof_valspec params_map def =
let rewrite_typschm (TypSchm_aux (TypSchm_ts (tq, typ), l) as ts) id =
if Bindings.mem id params_map then
let kid_typs = List.map (fun kid -> atom_typ (nvar kid))
- (KidSet.elements (Bindings.find id params_map)) in
+ (KidSet.elements (Bindings.find id params_map)) in
let typ' = match typ with
- | Typ_aux (Typ_fn (vtyp_arg, vtyp_ret, declared_eff), vl) ->
- let vtyp_arg' = begin
- match vtyp_arg with
- | Typ_aux (Typ_tup typs, vl) ->
- Typ_aux (Typ_tup (kid_typs @ typs), vl)
- | _ -> Typ_aux (Typ_tup (kid_typs @ [vtyp_arg]), vl)
- end in
- Typ_aux (Typ_fn (vtyp_arg', vtyp_ret, declared_eff), vl)
- | _ -> raise (Reporting_basic.err_typ l
- "val spec with non-function type") in
+ | Typ_aux (Typ_fn (vtyp_arg, vtyp_ret, declared_eff), vl) ->
+ let vtyp_arg' = begin
+ match vtyp_arg with
+ | Typ_aux (Typ_tup typs, vl) ->
+ Typ_aux (Typ_tup (kid_typs @ typs), vl)
+ | _ -> Typ_aux (Typ_tup (kid_typs @ [vtyp_arg]), vl)
+ end in
+ Typ_aux (Typ_fn (vtyp_arg', vtyp_ret, declared_eff), vl)
+ | _ -> raise (Reporting_basic.err_typ l
+ "val spec with non-function type") in
TypSchm_aux (TypSchm_ts (tq, typ'), l)
else ts in
match def with
| DEF_spec (VS_aux (VS_val_spec (typschm, id), a)) ->
- DEF_spec (VS_aux (VS_val_spec (rewrite_typschm typschm id, id), a))
+ DEF_spec (VS_aux (VS_val_spec (rewrite_typschm typschm id, id), a))
| DEF_spec (VS_aux (VS_extern_no_rename (typschm, id), a)) ->
- DEF_spec (VS_aux (VS_extern_no_rename (rewrite_typschm typschm id, id), a))
+ DEF_spec (VS_aux (VS_extern_no_rename (rewrite_typschm typschm id, id), a))
| DEF_spec (VS_aux (VS_extern_spec (typschm, id, e), a)) ->
- DEF_spec (VS_aux (VS_extern_spec (rewrite_typschm typschm id, id, e), a))
+ DEF_spec (VS_aux (VS_extern_spec (rewrite_typschm typschm id, id, e), a))
| DEF_spec (VS_aux (VS_cast_spec (typschm, id), a)) ->
- DEF_spec (VS_aux (VS_cast_spec (rewrite_typschm typschm id, id), a))
+ DEF_spec (VS_aux (VS_cast_spec (rewrite_typschm typschm id, id), a))
| _ -> def in
let (params_map, defs) = List.fold_left rewrite_sizeof_fundef
- (Bindings.empty, []) defs in
+ (Bindings.empty, []) defs in
let defs = List.map (rewrite_sizeof_valspec params_map) defs in
+ Defs defs
+ (* FIXME: Won't re-check due to flow typing and E_constraint re-write before E_sizeof re-write.
+ Requires the typechecker to be more smart about different representations for valid flow typing constraints.
fst (check initial_env (Defs defs))
-
+ *)
let remove_vector_concat_pat pat =
@@ -2282,6 +2328,7 @@ let rewrite_defs_early_return =
rewrite_defs_base { rewriters_base with rewrite_fun = rewrite_fun_early_return }
+(* Turn constraints into numeric expressions with sizeof *)
let rewrite_constraint =
let rec rewrite_nc (NC_aux (nc_aux, l)) = mk_exp (rewrite_nc_aux nc_aux)
and rewrite_nc_aux = function
@@ -2294,7 +2341,7 @@ let rewrite_constraint =
| NC_false -> E_lit (mk_lit L_true)
| NC_true -> E_lit (mk_lit L_false)
| NC_nat_set_bounded (kid, ints) ->
- unaux_exp (rewrite_nc (List.fold_left (fun nc int -> nc_or nc (nc_eq (nvar kid) (nconstant int))) nc_true ints))
+ unaux_exp (rewrite_nc (List.fold_left (fun nc int -> nc_or nc (nc_eq (nvar kid) (nconstant int))) nc_true ints))
in
let rewrite_e_aux (E_aux (e_aux, _) as exp) =
match e_aux with
@@ -2307,13 +2354,127 @@ let rewrite_constraint =
rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_e_constraint) }
+let rewrite_type_union_typs rw_typ (Tu_aux (tu, annot)) =
+ match tu with
+ | Tu_id id -> Tu_aux (Tu_id id, annot)
+ | Tu_ty_id (typ, id) -> Tu_aux (Tu_ty_id (rw_typ typ, id), annot)
+
+let rewrite_type_def_typs rw_typ rw_typquant rw_typschm (TD_aux (td, annot)) =
+ match td with
+ | TD_abbrev (id, nso, typschm) -> TD_aux (TD_abbrev (id, nso, rw_typschm typschm), annot)
+ | TD_record (id, nso, typq, typ_ids, flag) ->
+ TD_aux (TD_record (id, nso, rw_typquant typq, List.map (fun (typ, id) -> (rw_typ typ, id)) typ_ids, flag), annot)
+ | TD_variant (id, nso, typq, tus, flag) ->
+ TD_aux (TD_variant (id, nso, rw_typquant typq, List.map (rewrite_type_union_typs rw_typ) tus, flag), annot)
+ | TD_enum (id, nso, ids, flag) -> TD_aux (TD_enum (id, nso, ids, flag), annot)
+ | TD_register (id, n1, n2, ranges) -> TD_aux (TD_register (id, n1, n2, ranges), annot)
+
+(* FIXME: other reg_dec types *)
+let rewrite_dec_spec_typs rw_typ (DEC_aux (ds, annot)) =
+ match ds with
+ | DEC_reg (typ, id) -> DEC_aux (DEC_reg (rw_typ typ, id), annot)
+ | _ -> assert false
+
+(* Remove overload definitions and cast val specs from the
+ specification because the interpreter doesn't know about them.*)
+let rewrite_overload_cast (Defs defs) =
+ let remove_cast_vs (VS_aux (vs_aux, annot)) =
+ match vs_aux with
+ | VS_val_spec (typschm, id) -> VS_aux (VS_val_spec (typschm, id), annot)
+ | VS_extern_no_rename (typschm, id) -> VS_aux (VS_val_spec (typschm, id), annot)
+ | VS_extern_spec (typschm, id, e) -> VS_aux (VS_extern_spec (typschm, id, e), annot)
+ | VS_cast_spec (typschm, id) -> VS_aux (VS_val_spec (typschm, id), annot)
+ in
+ let simple_def = function
+ | DEF_spec vs -> DEF_spec (remove_cast_vs vs)
+ | def -> def
+ in
+ let is_overload = function
+ | DEF_overload _ -> true
+ | _ -> false
+ in
+ let defs = List.map simple_def defs in
+ Defs (List.filter (fun def -> not (is_overload def)) defs)
+
+(* This pass aims to remove all the Num quantifiers from the specification. *)
+let rewrite_simple_types (Defs defs) =
+ let is_simple = function
+ | QI_aux (QI_id kopt, annot) as qi when is_typ_kopt kopt || is_order_kopt kopt -> true
+ | _ -> false
+ in
+ let simple_typquant (TypQ_aux (tq_aux, annot)) =
+ match tq_aux with
+ | TypQ_no_forall -> TypQ_aux (TypQ_no_forall, annot)
+ | TypQ_tq quants -> TypQ_aux (TypQ_tq (List.filter (fun q -> is_simple q) quants), annot)
+ in
+ let rec simple_typ (Typ_aux (typ_aux, l) as typ) = Typ_aux (simple_typ_aux typ_aux, l)
+ and simple_typ_aux = function
+ | Typ_wild -> Typ_wild
+ | Typ_id id -> Typ_id id
+ | Typ_app (id, [_; _; _; Typ_arg_aux (Typ_arg_typ typ, l)]) when Id.compare id (mk_id "vector") = 0 ->
+ Typ_app (mk_id "list", [Typ_arg_aux (Typ_arg_typ (simple_typ typ), l)])
+ | Typ_app (id, [_]) when Id.compare id (mk_id "atom") = 0 ->
+ Typ_id (mk_id "int")
+ | Typ_app (id, [_; _]) when Id.compare id (mk_id "range") = 0 ->
+ Typ_id (mk_id "int")
+ | Typ_fn (typ1, typ2, effs) -> Typ_fn (simple_typ typ1, simple_typ typ2, effs)
+ | Typ_tup typs -> Typ_tup (List.map simple_typ typs)
+ | Typ_exist (_, _, Typ_aux (typ, l)) -> simple_typ_aux typ
+ | typ_aux -> typ_aux
+ in
+ let simple_typschm (TypSchm_aux (TypSchm_ts (typq, typ), annot)) =
+ TypSchm_aux (TypSchm_ts (simple_typquant typq, simple_typ typ), annot)
+ in
+ let simple_vs (VS_aux (vs_aux, annot)) =
+ match vs_aux with
+ | VS_val_spec (typschm, id) -> VS_aux (VS_val_spec (simple_typschm typschm, id), annot)
+ | VS_extern_no_rename (typschm, id) -> VS_aux (VS_val_spec (simple_typschm typschm, id), annot)
+ | VS_extern_spec (typschm, id, e) -> VS_aux (VS_extern_spec (simple_typschm typschm, id, e), annot)
+ | VS_cast_spec (typschm, id) -> VS_aux (VS_cast_spec (simple_typschm typschm, id), annot)
+ in
+ let rec simple_lit (L_aux (lit_aux, l) as lit) =
+ match lit_aux with
+ | L_bin _ | L_hex _ ->
+ E_list (List.map (fun b -> E_aux (E_lit b, simple_annot l bit_typ)) (vector_string_to_bit_list l lit_aux))
+ | _ -> E_lit lit
+ in
+ let simple_def = function
+ | DEF_spec vs -> DEF_spec (simple_vs vs)
+ | DEF_type td -> DEF_type (rewrite_type_def_typs simple_typ simple_typquant simple_typschm td)
+ | DEF_reg_dec ds -> DEF_reg_dec (rewrite_dec_spec_typs simple_typ ds)
+ | def -> def
+ in
+ let simple_pat = {
+ id_pat_alg with
+ p_typ = (fun (typ, pat) -> P_typ (simple_typ typ, pat));
+ p_var = (fun kid -> P_id (id_of_kid kid));
+ p_vector = (fun pats -> P_list pats)
+ } in
+ let simple_exp = {
+ id_exp_alg with
+ e_lit = simple_lit;
+ e_vector = (fun exps -> E_list exps);
+ e_cast = (fun (typ, exp) -> E_cast (simple_typ typ, exp));
+ e_assert = (fun (E_aux (_, annot), str) -> E_assert (E_aux (E_lit (mk_lit L_true), annot), str));
+ lEXP_cast = (fun (typ, lexp) -> LEXP_cast (simple_typ typ, lexp));
+ pat_alg = simple_pat
+ } in
+ let simple_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp simple_exp);
+ rewrite_pat = (fun _ -> fold_pat simple_pat) }
+ in
+ let defs = Defs (List.map simple_def defs) in
+ rewrite_defs_base simple_defs defs
+
let rewrite_defs_ocaml = [
- top_sort_defs;
+ (* top_sort_defs; *)
rewrite_defs_remove_vector_concat;
rewrite_constraint;
+ rewrite_trivial_sizeof;
rewrite_sizeof;
- rewrite_defs_exp_lift_assign (* ;
- rewrite_defs_separate_numbs *)
+ rewrite_simple_types;
+ rewrite_overload_cast;
+ (* rewrite_defs_exp_lift_assign *)
+ (* rewrite_defs_separate_numbs *)
]
let rewrite_defs_remove_blocks =
@@ -2460,7 +2621,7 @@ let rewrite_defs_letbind_effects =
| LEXP_vector_range (lexp,e1,e2) ->
n_lexp lexp (fun lexp ->
n_exp_name e1 (fun e1 ->
- n_exp_name e2 (fun e2 ->
+ n_exp_name e2 (fun e2 ->
k (fix_eff_lexp (LEXP_aux (LEXP_vector_range (lexp,e1,e2),annot))))))
| LEXP_field (lexp,id) ->
n_lexp lexp (fun lexp ->