diff options
| author | Brian Campbell | 2017-08-23 11:11:08 +0100 |
|---|---|---|
| committer | Brian Campbell | 2017-08-23 11:11:08 +0100 |
| commit | 22c2e970e9e52ff60b8262d02b4f50ad12174fd8 (patch) | |
| tree | e05bc639514a511d4d39399b8a263e817897e4fe /src/rewriter.ml | |
| parent | 2a6f3b8e42a4cb4cececb79a9011346b5b25ce80 (diff) | |
| parent | c380d2d0b51be71871085ac7d085268f5baccb56 (diff) | |
Merge branch 'experiments' into mono-experiments
Diffstat (limited to 'src/rewriter.ml')
| -rw-r--r-- | src/rewriter.ml | 415 |
1 files changed, 288 insertions, 127 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml index ef4a209c..d61939ee 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -566,14 +566,7 @@ let rewrite_lexp rewriters (LEXP_aux(lexp,(l,annot))) = let rewrite_fun rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) = let rewrite_funcl (FCL_aux (FCL_Funcl(id,pat,exp),(l,annot))) = let _ = reset_fresh_name_counter () in - (*let _ = Printf.eprintf "Rewriting function %s, pattern %s\n" - (match id with (Id_aux (Id i,_)) -> i) (Pretty_print.pat_to_string pat) in*) - (*let map = get_map_tannot fdannot in - let map = - match map with - | None -> None - | Some m -> Some(m, Envmap.empty) in*) - (FCL_aux (FCL_Funcl (id,rewriters.rewrite_pat rewriters pat, + (FCL_aux (FCL_Funcl (id,rewriters.rewrite_pat rewriters pat, rewriters.rewrite_exp rewriters exp),(l,annot))) in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot)) @@ -943,12 +936,12 @@ let compute_exp_alg bot join = ; e_tuple = split_join (fun es -> E_tuple es) ; e_if = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_if (e1,e2,e3))) ; e_for = (fun (id,(v1,e1),(v2,e2),(v3,e3),order,(v4,e4)) -> - (join_list [v1;v2;v3;v4], E_for (id,e1,e2,e3,order,e4))) + (join_list [v1;v2;v3;v4], E_for (id,e1,e2,e3,order,e4))) ; e_vector = split_join (fun es -> E_vector es) ; e_vector_indexed = (fun (es,(v2,opt2)) -> - let (is,es) = List.split es in - let (vs,es) = List.split es in - (join_list (vs @ [v2]), E_vector_indexed (List.combine is es,opt2))) + let (is,es) = List.split es in + let (vs,es) = List.split es in + (join_list (vs @ [v2]), E_vector_indexed (List.combine is es,opt2))) ; e_vector_access = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_access (e1,e2))) ; e_vector_subrange = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_subrange (e1,e2,e3))) ; e_vector_update = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_update (e1,e2,e3))) @@ -960,8 +953,8 @@ let compute_exp_alg bot join = ; e_record_update = (fun ((v1,e1),(vf,fexp)) -> (join v1 vf, E_record_update (e1,fexp))) ; e_field = (fun ((v1,e1),id) -> (v1, E_field (e1,id))) ; e_case = (fun ((v1,e1),pexps) -> - let (vps,pexps) = List.split pexps in - (join_list (v1::vps), E_case (e1,pexps))) + let (vps,pexps) = List.split pexps in + (join_list (v1::vps), E_case (e1,pexps))) ; e_let = (fun ((vl,lb),(v2,e2)) -> (join vl v2, E_let (lb,e2))) ; e_assign = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, E_assign (lexp,e2))) ; e_sizeof = (fun nexp -> (bot, E_sizeof nexp)) @@ -975,27 +968,27 @@ let compute_exp_alg bot join = ; e_comment = (fun c -> (bot, E_comment c)) ; e_comment_struc = (fun (v,e) -> (bot, E_comment_struc e)) (* ignore value by default, since it is comes from a comment *) ; e_internal_let = (fun ((vl, lexp), (v2,e2), (v3,e3)) -> - (join_list [vl;v2;v3], E_internal_let (lexp,e2,e3))) + (join_list [vl;v2;v3], E_internal_let (lexp,e2,e3))) ; e_internal_plet = (fun ((vp,pat), (v1,e1), (v2,e2)) -> - (join_list [vp;v1;v2], E_internal_plet (pat,e1,e2))) + (join_list [vp;v1;v2], E_internal_plet (pat,e1,e2))) ; e_internal_return = (fun (v,e) -> (v, E_internal_return e)) ; e_aux = (fun ((v,e),annot) -> (v, E_aux (e,annot))) ; lEXP_id = (fun id -> (bot, LEXP_id id)) ; lEXP_memory = (fun (id,es) -> split_join (fun es -> LEXP_memory (id,es)) es) ; lEXP_cast = (fun (typ,id) -> (bot, LEXP_cast (typ,id))) ; lEXP_tup = (fun ls -> - let (vs,ls) = List.split ls in - (join_list vs, LEXP_tup ls)) + let (vs,ls) = List.split ls in + (join_list vs, LEXP_tup ls)) ; lEXP_vector = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, LEXP_vector (lexp,e2))) ; lEXP_vector_range = (fun ((vl,lexp),(v2,e2),(v3,e3)) -> - (join_list [vl;v2;v3], LEXP_vector_range (lexp,e2,e3))) + (join_list [vl;v2;v3], LEXP_vector_range (lexp,e2,e3))) ; lEXP_field = (fun ((vl,lexp),id) -> (vl, LEXP_field (lexp,id))) ; lEXP_aux = (fun ((vl,lexp),annot) -> (vl, LEXP_aux (lexp,annot))) ; fE_Fexp = (fun (id,(v,e)) -> (v, FE_Fexp (id,e))) ; fE_aux = (fun ((vf,fexp),annot) -> (vf, FE_aux (fexp,annot))) ; fES_Fexps = (fun (fexps,b) -> - let (vs,fexps) = List.split fexps in - (join_list vs, FES_Fexps (fexps,b))) + let (vs,fexps) = List.split fexps in + (join_list vs, FES_Fexps (fexps,b))) ; fES_aux = (fun ((vf,fexp),annot) -> (vf, FES_aux (fexp,annot))) ; def_val_empty = (bot, Def_val_empty) ; def_val_dec = (fun (v,e) -> (v, Def_val_dec e)) @@ -1009,6 +1002,43 @@ let compute_exp_alg bot join = ; pat_alg = compute_pat_alg bot join } +(* Re-write trivial sizeof expressions - trivial meaning that the + value of the sizeof can be directly inferred from the type + variables in scope. *) +let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp = + let extract_typ_var l env nexp (id, (_, typ)) = + let var = E_aux (E_id id, (l, Some (env, typ, no_effect))) in + match destruct_atom_nexp env typ with + | Some size when prove env (nc_eq size nexp) -> Some var + | _ -> + begin + match destruct_vector env typ with + | Some (_, len, _, _) when prove env (nc_eq len nexp) -> + Some (E_aux (E_app (mk_id "length", [var]), (l, Some (env, atom_typ len, no_effect)))) + | _ -> None + end + in + let rewrite_e_aux (E_aux (e_aux, (l, _)) as orig_exp) = + let env = env_of orig_exp in + match e_aux with + | E_sizeof (Nexp_aux (Nexp_constant c, _) as nexp) -> + E_aux (E_lit (L_aux (L_num c, l)), (l, Some (env, atom_typ nexp, no_effect))) + | E_sizeof nexp -> + begin + let locals = Env.get_locals env in + let exps = Bindings.bindings locals + |> List.map (extract_typ_var l env nexp) + |> List.map (fun opt -> match opt with Some x -> [x] | None -> []) + |> List.concat + in + match exps with + | (exp :: _) -> exp + | [] -> orig_exp + end + | _ -> orig_exp + in + let rewrite_e_constraint = { id_exp_alg with e_aux = (fun (exp, annot) -> rewrite_e_aux (E_aux (exp, annot))) } in + rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_e_constraint) }, rewrite_e_aux (* Rewrite sizeof expressions with type-level variables to term-level expressions @@ -1020,78 +1050,91 @@ let compute_exp_alg bot join = let rewrite_sizeof (Defs defs) = let sizeof_frees exp = fst (fold_exp - { (compute_exp_alg KidSet.empty KidSet.union) with - e_sizeof = (fun nexp -> (nexp_frees nexp, E_sizeof nexp)) } - exp) in + { (compute_exp_alg KidSet.empty KidSet.union) with + e_sizeof = (fun nexp -> (nexp_frees nexp, E_sizeof nexp)) } + exp) in (* Collect nexps whose values can be obtained directly from a pattern bind *) let nexps_from_params pat = fst (fold_pat - { (compute_pat_alg [] (@)) with - p_aux = (fun ((v,pat),((l,_) as annot)) -> - let v' = match pat with - | P_id id | P_as (_, id) -> - let (Typ_aux (typ,_) as typ_aux) = typ_of_annot annot in - (match typ with - | Typ_app (atom, [Typ_arg_aux (Typ_arg_nexp nexp, _)]) - when string_of_id atom = "atom" -> - [nexp, E_id id] - | Typ_app (vector, _) when string_of_id vector = "vector" -> - let id_length = Id_aux (Id "length", Parse_ast.Generated l) in - (try - (match Env.get_val_spec id_length (env_of_annot annot) with - | _ -> - let (_,len,_,_) = vector_typ_args_of typ_aux in - let exp = E_app (id_length, [E_aux (E_id id, annot)]) in - [len, exp]) - with - | _ -> []) - | _ -> []) - | _ -> [] in - (v @ v', P_aux (pat,annot)))} pat) in + { (compute_pat_alg [] (@)) with + p_aux = (fun ((v,pat),((l,_) as annot)) -> + let v' = match pat with + | P_id id | P_as (_, id) -> + let (Typ_aux (typ,_) as typ_aux) = typ_of_annot annot in + (match typ with + | Typ_app (atom, [Typ_arg_aux (Typ_arg_nexp nexp, _)]) + when string_of_id atom = "atom" -> + [nexp, E_id id] + | Typ_app (vector, _) when string_of_id vector = "vector" -> + let id_length = Id_aux (Id "length", Parse_ast.Generated l) in + (try + (match Env.get_val_spec id_length (env_of_annot annot) with + | _ -> + let (_,len,_,_) = vector_typ_args_of typ_aux in + let exp = E_app (id_length, [E_aux (E_id id, annot)]) in + [len, exp]) + with + | _ -> []) + | _ -> []) + | _ -> [] in + (v @ v', P_aux (pat,annot)))} pat) in (* Substitute collected values in sizeof expressions *) let rec e_sizeof nmap (Nexp_aux (nexp, l) as nexp_aux) = try snd (List.find (fun (nexp,_) -> nexp_identical nexp nexp_aux) nmap) with | Not_found -> - let binop nexp1 op nexp2 = E_app_infix ( - E_aux (e_sizeof nmap nexp1, simple_annot l (atom_typ nexp1)), - Id_aux (Id op, Parse_ast.Unknown), - E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2)) - ) in - let (Nexp_aux (nexp, l) as nexp_aux) = simplify_nexp nexp_aux in - (match nexp with - | Nexp_constant i -> E_lit (L_aux (L_num i, l)) - | Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2 - | Nexp_sum (nexp1, nexp2) -> binop nexp1 "+" nexp2 - | Nexp_minus (nexp1, nexp2) -> binop nexp1 "-" nexp2 - | _ -> E_sizeof nexp_aux) in + let binop nexp1 op nexp2 = E_app_infix ( + E_aux (e_sizeof nmap nexp1, simple_annot l (atom_typ nexp1)), + Id_aux (Id op, Parse_ast.Unknown), + E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2)) + ) in + let (Nexp_aux (nexp, l) as nexp_aux) = simplify_nexp nexp_aux in + (match nexp with + | Nexp_constant i -> E_lit (L_aux (L_num i, l)) + | Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2 + | Nexp_sum (nexp1, nexp2) -> binop nexp1 "+" nexp2 + | Nexp_minus (nexp1, nexp2) -> binop nexp1 "-" nexp2 + | _ -> E_sizeof nexp_aux) in + + let ex_regex = Str.regexp "'ex[0-9]+" in (* Rewrite calls to functions which have had parameters added to pass values of type-level variables; these are added as sizeof expressions first, and then further rewritten as above. *) - let e_app_aux param_map ((exp, exp_orig), ((l,_) as annot)) = + let e_app_aux param_map ((exp, exp_orig), ((l, Some (env, _, _)) as annot)) = let full_exp = E_aux (exp, annot) in let orig_exp = E_aux (exp_orig, annot) in match exp with | E_app (f, args) -> - if Bindings.mem f param_map then - (* Retrieve instantiation of the type variables of the called function + if Bindings.mem f param_map then + (* Retrieve instantiation of the type variables of the called function for the given parameters in the original environment *) - let inst = instantiation_of orig_exp in - let inst = KBindings.fold (fun kid uvar b -> KBindings.add (orig_kid kid) uvar b) inst KBindings.empty in - let kid_exp kid = begin - match KBindings.find (orig_kid kid) inst with - | U_nexp nexp -> E_aux (E_sizeof nexp, simple_annot l (atom_typ nexp)) - | _ -> - raise (Reporting_basic.err_unreachable l - ("failed to infer nexp for type variable " ^ string_of_kid kid ^ - " of function " ^ string_of_id f)) - end in - let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in - (E_aux (E_app (f, kid_exps @ args), annot), orig_exp) - else (full_exp, orig_exp) + let inst = instantiation_of orig_exp in + (* Rewrite the inst using orig_kid so that each type variable has it's + original name rather than a mangled typechecker name *) + let inst = KBindings.fold (fun kid uvar b -> KBindings.add (orig_kid kid) uvar b) inst KBindings.empty in + let kid_exp kid = begin + (* We really don't want to see an existential here! *) + assert (not (Str.string_match ex_regex (string_of_kid kid) 0)); + let uvar = try Some (KBindings.find (orig_kid kid) inst) with Not_found -> None in + match uvar with + | Some (U_nexp nexp) -> + let sizeof = E_aux (E_sizeof nexp, (l, Some (env, atom_typ nexp, no_effect))) in + rewrite_trivial_sizeof_exp sizeof + (* If the type variable is Not_found then it was probably + introduced by a P_var pattern, so it likely exists as + a variable in scope. It can't be an existential because the assert rules that out. *) + | None -> E_aux (E_id (id_of_kid (orig_kid kid)), simple_annot l (atom_typ (nvar (orig_kid kid)))) + | _ -> + raise (Reporting_basic.err_unreachable l + ("failed to infer nexp for type variable " ^ string_of_kid kid ^ + " of function " ^ string_of_id f)) + end in + let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in + (E_aux (E_app (f, kid_exps @ args), annot), orig_exp) + else (full_exp, orig_exp) | _ -> (full_exp, orig_exp) in (* Plug this into a folding algorithm that also keeps around a copy of the @@ -1162,7 +1205,7 @@ let rewrite_sizeof (Defs defs) = } in let rewrite_sizeof_fun params_map - (FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) = + (FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) = let rewrite_funcl_body (FCL_aux (FCL_Funcl (id,pat,exp), annot)) (funcls,nvars) = let body_env = env_of exp in let body_typ = typ_of exp in @@ -1172,7 +1215,7 @@ let rewrite_sizeof (Defs defs) = (* ... then rewrite sizeof expressions in current function body *) let exp'' = fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp' in (FCL_aux (FCL_Funcl (id,pat,exp''), annot) :: funcls, - KidSet.union nvars (sizeof_frees exp'')) in + KidSet.union nvars (sizeof_frees exp'')) in let (funcls, nvars) = List.fold_right rewrite_funcl_body funcls ([], KidSet.empty) in (* Add a parameter for each remaining free type-level variable in a sizeof expression *) @@ -1180,83 +1223,86 @@ let rewrite_sizeof (Defs defs) = let kid_annot kid = simple_annot l (kid_typ kid) in let kid_pat kid = P_aux (P_typ (kid_typ kid, - P_aux (P_id (Id_aux (Id (string_of_kid kid), l)), - kid_annot kid)), kid_annot kid) in - let kid_eaux kid = E_id (Id_aux (Id (string_of_kid kid), l)) in + P_aux (P_id (Id_aux (Id (string_of_id (id_of_kid kid) ^ "__tv"), l)), + kid_annot kid)), kid_annot kid) in + let kid_eaux kid = E_id (Id_aux (Id (string_of_id (id_of_kid kid) ^ "__tv"), l)) in let kid_typs = List.map kid_typ (KidSet.elements nvars) in let kid_pats = List.map kid_pat (KidSet.elements nvars) in let kid_nmap = List.map (fun kid -> (nvar kid, kid_eaux kid)) (KidSet.elements nvars) in let rewrite_funcl_params (FCL_aux (FCL_Funcl (id, pat, exp), annot) as funcl) = let rec rewrite_pat (P_aux (pat,(l,_)) as paux) = if KidSet.is_empty nvars then paux else - match pat_typ_of paux with - | Typ_aux (Typ_tup _, _) -> - (match pat with - | P_tup pats -> - P_aux (P_tup (kid_pats @ pats), (l, None)) - | P_wild -> paux - | P_typ (Typ_aux (Typ_tup typs, l), pat) -> - P_aux (P_typ (Typ_aux (Typ_tup (kid_typs @ typs), l), - rewrite_pat pat), (l, None)) - | P_as (_, id) | P_id id -> - (* adding parameters here would change the type of id; + match pat_typ_of paux with + | Typ_aux (Typ_tup _, _) -> + (match pat with + | P_tup pats -> + P_aux (P_tup (kid_pats @ pats), (l, None)) + | P_wild -> paux + | P_typ (Typ_aux (Typ_tup typs, l), pat) -> + P_aux (P_typ (Typ_aux (Typ_tup (kid_typs @ typs), l), + rewrite_pat pat), (l, None)) + | P_as (_, id) | P_id id -> + (* adding parameters here would change the type of id; we should remove the P_as/P_id here and add a let-binding to the body *) - raise (Reporting_basic.err_todo l - "rewriting as- or id-patterns for sizeof expressions not yet implemented") - | _ -> - raise (Reporting_basic.err_unreachable l - "unexpected pattern while rewriting function parameters for sizeof expressions")) - | _ -> P_aux (P_tup (kid_pats @ [paux]), (l, None)) in + raise (Reporting_basic.err_todo l + "rewriting as- or id-patterns for sizeof expressions not yet implemented") + | _ -> + raise (Reporting_basic.err_unreachable l + "unexpected pattern while rewriting function parameters for sizeof expressions")) + | _ -> P_aux (P_tup (kid_pats @ [paux]), (l, None)) in let exp' = fold_exp { id_exp_alg with e_sizeof = e_sizeof kid_nmap } exp in FCL_aux (FCL_Funcl (id, rewrite_pat pat, exp'), annot) in let funcls = List.map rewrite_funcl_params funcls in (nvars, FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) in let rewrite_sizeof_fundef (params_map, defs) = function - | DEF_fundef fd -> - let (nvars, fd') = rewrite_sizeof_fun params_map fd in - let id = id_of_fundef fd in - let params_map' = - if KidSet.is_empty nvars then params_map - else Bindings.add id nvars params_map in - (params_map', defs @ [DEF_fundef fd']) - | def -> - (params_map, defs @ [def]) in + | DEF_fundef fd as def -> + let (nvars, fd') = rewrite_sizeof_fun params_map fd in + let id = id_of_fundef fd in + let params_map' = + if KidSet.is_empty nvars then params_map + else Bindings.add id nvars params_map in + (params_map', defs @ [DEF_fundef fd']) + | def -> + (params_map, defs @ [def]) in let rewrite_sizeof_valspec params_map def = let rewrite_typschm (TypSchm_aux (TypSchm_ts (tq, typ), l) as ts) id = if Bindings.mem id params_map then let kid_typs = List.map (fun kid -> atom_typ (nvar kid)) - (KidSet.elements (Bindings.find id params_map)) in + (KidSet.elements (Bindings.find id params_map)) in let typ' = match typ with - | Typ_aux (Typ_fn (vtyp_arg, vtyp_ret, declared_eff), vl) -> - let vtyp_arg' = begin - match vtyp_arg with - | Typ_aux (Typ_tup typs, vl) -> - Typ_aux (Typ_tup (kid_typs @ typs), vl) - | _ -> Typ_aux (Typ_tup (kid_typs @ [vtyp_arg]), vl) - end in - Typ_aux (Typ_fn (vtyp_arg', vtyp_ret, declared_eff), vl) - | _ -> raise (Reporting_basic.err_typ l - "val spec with non-function type") in + | Typ_aux (Typ_fn (vtyp_arg, vtyp_ret, declared_eff), vl) -> + let vtyp_arg' = begin + match vtyp_arg with + | Typ_aux (Typ_tup typs, vl) -> + Typ_aux (Typ_tup (kid_typs @ typs), vl) + | _ -> Typ_aux (Typ_tup (kid_typs @ [vtyp_arg]), vl) + end in + Typ_aux (Typ_fn (vtyp_arg', vtyp_ret, declared_eff), vl) + | _ -> raise (Reporting_basic.err_typ l + "val spec with non-function type") in TypSchm_aux (TypSchm_ts (tq, typ'), l) else ts in match def with | DEF_spec (VS_aux (VS_val_spec (typschm, id), a)) -> - DEF_spec (VS_aux (VS_val_spec (rewrite_typschm typschm id, id), a)) + DEF_spec (VS_aux (VS_val_spec (rewrite_typschm typschm id, id), a)) | DEF_spec (VS_aux (VS_extern_no_rename (typschm, id), a)) -> - DEF_spec (VS_aux (VS_extern_no_rename (rewrite_typschm typschm id, id), a)) + DEF_spec (VS_aux (VS_extern_no_rename (rewrite_typschm typschm id, id), a)) | DEF_spec (VS_aux (VS_extern_spec (typschm, id, e), a)) -> - DEF_spec (VS_aux (VS_extern_spec (rewrite_typschm typschm id, id, e), a)) + DEF_spec (VS_aux (VS_extern_spec (rewrite_typschm typschm id, id, e), a)) | DEF_spec (VS_aux (VS_cast_spec (typschm, id), a)) -> - DEF_spec (VS_aux (VS_cast_spec (rewrite_typschm typschm id, id), a)) + DEF_spec (VS_aux (VS_cast_spec (rewrite_typschm typschm id, id), a)) | _ -> def in let (params_map, defs) = List.fold_left rewrite_sizeof_fundef - (Bindings.empty, []) defs in + (Bindings.empty, []) defs in let defs = List.map (rewrite_sizeof_valspec params_map) defs in + Defs defs + (* FIXME: Won't re-check due to flow typing and E_constraint re-write before E_sizeof re-write. + Requires the typechecker to be more smart about different representations for valid flow typing constraints. fst (check initial_env (Defs defs)) - + *) let remove_vector_concat_pat pat = @@ -2282,6 +2328,7 @@ let rewrite_defs_early_return = rewrite_defs_base { rewriters_base with rewrite_fun = rewrite_fun_early_return } +(* Turn constraints into numeric expressions with sizeof *) let rewrite_constraint = let rec rewrite_nc (NC_aux (nc_aux, l)) = mk_exp (rewrite_nc_aux nc_aux) and rewrite_nc_aux = function @@ -2294,7 +2341,7 @@ let rewrite_constraint = | NC_false -> E_lit (mk_lit L_true) | NC_true -> E_lit (mk_lit L_false) | NC_nat_set_bounded (kid, ints) -> - unaux_exp (rewrite_nc (List.fold_left (fun nc int -> nc_or nc (nc_eq (nvar kid) (nconstant int))) nc_true ints)) + unaux_exp (rewrite_nc (List.fold_left (fun nc int -> nc_or nc (nc_eq (nvar kid) (nconstant int))) nc_true ints)) in let rewrite_e_aux (E_aux (e_aux, _) as exp) = match e_aux with @@ -2307,13 +2354,127 @@ let rewrite_constraint = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_e_constraint) } +let rewrite_type_union_typs rw_typ (Tu_aux (tu, annot)) = + match tu with + | Tu_id id -> Tu_aux (Tu_id id, annot) + | Tu_ty_id (typ, id) -> Tu_aux (Tu_ty_id (rw_typ typ, id), annot) + +let rewrite_type_def_typs rw_typ rw_typquant rw_typschm (TD_aux (td, annot)) = + match td with + | TD_abbrev (id, nso, typschm) -> TD_aux (TD_abbrev (id, nso, rw_typschm typschm), annot) + | TD_record (id, nso, typq, typ_ids, flag) -> + TD_aux (TD_record (id, nso, rw_typquant typq, List.map (fun (typ, id) -> (rw_typ typ, id)) typ_ids, flag), annot) + | TD_variant (id, nso, typq, tus, flag) -> + TD_aux (TD_variant (id, nso, rw_typquant typq, List.map (rewrite_type_union_typs rw_typ) tus, flag), annot) + | TD_enum (id, nso, ids, flag) -> TD_aux (TD_enum (id, nso, ids, flag), annot) + | TD_register (id, n1, n2, ranges) -> TD_aux (TD_register (id, n1, n2, ranges), annot) + +(* FIXME: other reg_dec types *) +let rewrite_dec_spec_typs rw_typ (DEC_aux (ds, annot)) = + match ds with + | DEC_reg (typ, id) -> DEC_aux (DEC_reg (rw_typ typ, id), annot) + | _ -> assert false + +(* Remove overload definitions and cast val specs from the + specification because the interpreter doesn't know about them.*) +let rewrite_overload_cast (Defs defs) = + let remove_cast_vs (VS_aux (vs_aux, annot)) = + match vs_aux with + | VS_val_spec (typschm, id) -> VS_aux (VS_val_spec (typschm, id), annot) + | VS_extern_no_rename (typschm, id) -> VS_aux (VS_val_spec (typschm, id), annot) + | VS_extern_spec (typschm, id, e) -> VS_aux (VS_extern_spec (typschm, id, e), annot) + | VS_cast_spec (typschm, id) -> VS_aux (VS_val_spec (typschm, id), annot) + in + let simple_def = function + | DEF_spec vs -> DEF_spec (remove_cast_vs vs) + | def -> def + in + let is_overload = function + | DEF_overload _ -> true + | _ -> false + in + let defs = List.map simple_def defs in + Defs (List.filter (fun def -> not (is_overload def)) defs) + +(* This pass aims to remove all the Num quantifiers from the specification. *) +let rewrite_simple_types (Defs defs) = + let is_simple = function + | QI_aux (QI_id kopt, annot) as qi when is_typ_kopt kopt || is_order_kopt kopt -> true + | _ -> false + in + let simple_typquant (TypQ_aux (tq_aux, annot)) = + match tq_aux with + | TypQ_no_forall -> TypQ_aux (TypQ_no_forall, annot) + | TypQ_tq quants -> TypQ_aux (TypQ_tq (List.filter (fun q -> is_simple q) quants), annot) + in + let rec simple_typ (Typ_aux (typ_aux, l) as typ) = Typ_aux (simple_typ_aux typ_aux, l) + and simple_typ_aux = function + | Typ_wild -> Typ_wild + | Typ_id id -> Typ_id id + | Typ_app (id, [_; _; _; Typ_arg_aux (Typ_arg_typ typ, l)]) when Id.compare id (mk_id "vector") = 0 -> + Typ_app (mk_id "list", [Typ_arg_aux (Typ_arg_typ (simple_typ typ), l)]) + | Typ_app (id, [_]) when Id.compare id (mk_id "atom") = 0 -> + Typ_id (mk_id "int") + | Typ_app (id, [_; _]) when Id.compare id (mk_id "range") = 0 -> + Typ_id (mk_id "int") + | Typ_fn (typ1, typ2, effs) -> Typ_fn (simple_typ typ1, simple_typ typ2, effs) + | Typ_tup typs -> Typ_tup (List.map simple_typ typs) + | Typ_exist (_, _, Typ_aux (typ, l)) -> simple_typ_aux typ + | typ_aux -> typ_aux + in + let simple_typschm (TypSchm_aux (TypSchm_ts (typq, typ), annot)) = + TypSchm_aux (TypSchm_ts (simple_typquant typq, simple_typ typ), annot) + in + let simple_vs (VS_aux (vs_aux, annot)) = + match vs_aux with + | VS_val_spec (typschm, id) -> VS_aux (VS_val_spec (simple_typschm typschm, id), annot) + | VS_extern_no_rename (typschm, id) -> VS_aux (VS_val_spec (simple_typschm typschm, id), annot) + | VS_extern_spec (typschm, id, e) -> VS_aux (VS_extern_spec (simple_typschm typschm, id, e), annot) + | VS_cast_spec (typschm, id) -> VS_aux (VS_cast_spec (simple_typschm typschm, id), annot) + in + let rec simple_lit (L_aux (lit_aux, l) as lit) = + match lit_aux with + | L_bin _ | L_hex _ -> + E_list (List.map (fun b -> E_aux (E_lit b, simple_annot l bit_typ)) (vector_string_to_bit_list l lit_aux)) + | _ -> E_lit lit + in + let simple_def = function + | DEF_spec vs -> DEF_spec (simple_vs vs) + | DEF_type td -> DEF_type (rewrite_type_def_typs simple_typ simple_typquant simple_typschm td) + | DEF_reg_dec ds -> DEF_reg_dec (rewrite_dec_spec_typs simple_typ ds) + | def -> def + in + let simple_pat = { + id_pat_alg with + p_typ = (fun (typ, pat) -> P_typ (simple_typ typ, pat)); + p_var = (fun kid -> P_id (id_of_kid kid)); + p_vector = (fun pats -> P_list pats) + } in + let simple_exp = { + id_exp_alg with + e_lit = simple_lit; + e_vector = (fun exps -> E_list exps); + e_cast = (fun (typ, exp) -> E_cast (simple_typ typ, exp)); + e_assert = (fun (E_aux (_, annot), str) -> E_assert (E_aux (E_lit (mk_lit L_true), annot), str)); + lEXP_cast = (fun (typ, lexp) -> LEXP_cast (simple_typ typ, lexp)); + pat_alg = simple_pat + } in + let simple_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp simple_exp); + rewrite_pat = (fun _ -> fold_pat simple_pat) } + in + let defs = Defs (List.map simple_def defs) in + rewrite_defs_base simple_defs defs + let rewrite_defs_ocaml = [ - top_sort_defs; + (* top_sort_defs; *) rewrite_defs_remove_vector_concat; rewrite_constraint; + rewrite_trivial_sizeof; rewrite_sizeof; - rewrite_defs_exp_lift_assign (* ; - rewrite_defs_separate_numbs *) + rewrite_simple_types; + rewrite_overload_cast; + (* rewrite_defs_exp_lift_assign *) + (* rewrite_defs_separate_numbs *) ] let rewrite_defs_remove_blocks = @@ -2460,7 +2621,7 @@ let rewrite_defs_letbind_effects = | LEXP_vector_range (lexp,e1,e2) -> n_lexp lexp (fun lexp -> n_exp_name e1 (fun e1 -> - n_exp_name e2 (fun e2 -> + n_exp_name e2 (fun e2 -> k (fix_eff_lexp (LEXP_aux (LEXP_vector_range (lexp,e1,e2),annot)))))) | LEXP_field (lexp,id) -> n_lexp lexp (fun lexp -> |
