diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/_tags | 2 | ||||
| -rw-r--r-- | src/ast_util.ml | 7 | ||||
| -rw-r--r-- | src/pretty_print_sail.ml | 1 | ||||
| -rw-r--r-- | src/process_file.ml | 2 | ||||
| -rw-r--r-- | src/rewriter.ml | 350 | ||||
| -rw-r--r-- | src/sail.ml | 1 | ||||
| -rw-r--r-- | src/type_check.ml | 89 | ||||
| -rw-r--r-- | src/type_check.mli | 16 |
8 files changed, 303 insertions, 165 deletions
@@ -1,7 +1,7 @@ true: -traverse, debug, use_menhir <**/*.ml>: bin_annot, annot <lem_interp> or <test>: include -<sail.{byte,native}>: use_pprint, use_nums, use_unix +<sail.{byte,native}>: use_pprint, use_nums, use_unix, use_str <pprint> or <pprint/src>: include # see http://caml.inria.fr/mantis/view.php?id=4943 diff --git a/src/ast_util.ml b/src/ast_util.ml index ad14f0f1..96823b25 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -312,6 +312,13 @@ let rec string_of_exp (E_aux (exp, _)) = | E_throw exp -> "throw " ^ string_of_exp exp | E_cons (x, xs) -> string_of_exp x ^ " :: " ^ string_of_exp xs | E_list xs -> "[||" ^ string_of_list ", " string_of_exp xs ^ "||]" + | E_internal_cast _ -> "INTERNAL CAST" + | E_internal_exp _ -> "INTERNAL EXP" + | E_sizeof_internal _ -> "INTERNAL SIZEOF" + | E_internal_exp_user _ -> "INTERNAL EXP USER" + | E_comment _ -> "INTERNAL COMMENT" + | E_comment_struc _ -> "INTERNAL COMMENT STRUC" + | E_internal_let _ -> "INTERNAL LET" | _ -> "INTERNAL" and string_of_pexp (Pat_aux (pexp, _)) = match pexp with diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index b6d33d6d..b0b63ec1 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -345,6 +345,7 @@ let doc_exp, doc_let = | E_internal_exp_user _ -> raise (Reporting_basic.err_unreachable Unknown ("internal_exp_user not rewritten away")) | E_internal_cast ((_, Overload (_, _,_ )), _) | E_internal_exp _ -> assert false *) + | _ -> failwith ("Cannot print: " ^ Ast_util.string_of_exp expr) and let_exp (LB_aux(lb,_)) = match lb with | LB_val_explicit(ts,pat,e) -> (match ts with diff --git a/src/process_file.ml b/src/process_file.ml index 1c133ce8..6a3fc24c 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -225,8 +225,6 @@ let output libpath out_arg files = files let rewrite_step defs rewriter = - (* print_endline "=============================== REWRITE STEP"; - Pretty_print.pp_defs stdout defs; *) let defs = rewriter defs in let _ = match !(opt_ddump_rewrite_ast) with | Some (f, i) -> diff --git a/src/rewriter.ml b/src/rewriter.ml index ef4a209c..c96ea632 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -943,12 +943,12 @@ let compute_exp_alg bot join = ; e_tuple = split_join (fun es -> E_tuple es) ; e_if = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_if (e1,e2,e3))) ; e_for = (fun (id,(v1,e1),(v2,e2),(v3,e3),order,(v4,e4)) -> - (join_list [v1;v2;v3;v4], E_for (id,e1,e2,e3,order,e4))) + (join_list [v1;v2;v3;v4], E_for (id,e1,e2,e3,order,e4))) ; e_vector = split_join (fun es -> E_vector es) ; e_vector_indexed = (fun (es,(v2,opt2)) -> - let (is,es) = List.split es in - let (vs,es) = List.split es in - (join_list (vs @ [v2]), E_vector_indexed (List.combine is es,opt2))) + let (is,es) = List.split es in + let (vs,es) = List.split es in + (join_list (vs @ [v2]), E_vector_indexed (List.combine is es,opt2))) ; e_vector_access = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_access (e1,e2))) ; e_vector_subrange = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_subrange (e1,e2,e3))) ; e_vector_update = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_update (e1,e2,e3))) @@ -960,8 +960,8 @@ let compute_exp_alg bot join = ; e_record_update = (fun ((v1,e1),(vf,fexp)) -> (join v1 vf, E_record_update (e1,fexp))) ; e_field = (fun ((v1,e1),id) -> (v1, E_field (e1,id))) ; e_case = (fun ((v1,e1),pexps) -> - let (vps,pexps) = List.split pexps in - (join_list (v1::vps), E_case (e1,pexps))) + let (vps,pexps) = List.split pexps in + (join_list (v1::vps), E_case (e1,pexps))) ; e_let = (fun ((vl,lb),(v2,e2)) -> (join vl v2, E_let (lb,e2))) ; e_assign = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, E_assign (lexp,e2))) ; e_sizeof = (fun nexp -> (bot, E_sizeof nexp)) @@ -975,27 +975,27 @@ let compute_exp_alg bot join = ; e_comment = (fun c -> (bot, E_comment c)) ; e_comment_struc = (fun (v,e) -> (bot, E_comment_struc e)) (* ignore value by default, since it is comes from a comment *) ; e_internal_let = (fun ((vl, lexp), (v2,e2), (v3,e3)) -> - (join_list [vl;v2;v3], E_internal_let (lexp,e2,e3))) + (join_list [vl;v2;v3], E_internal_let (lexp,e2,e3))) ; e_internal_plet = (fun ((vp,pat), (v1,e1), (v2,e2)) -> - (join_list [vp;v1;v2], E_internal_plet (pat,e1,e2))) + (join_list [vp;v1;v2], E_internal_plet (pat,e1,e2))) ; e_internal_return = (fun (v,e) -> (v, E_internal_return e)) ; e_aux = (fun ((v,e),annot) -> (v, E_aux (e,annot))) ; lEXP_id = (fun id -> (bot, LEXP_id id)) ; lEXP_memory = (fun (id,es) -> split_join (fun es -> LEXP_memory (id,es)) es) ; lEXP_cast = (fun (typ,id) -> (bot, LEXP_cast (typ,id))) ; lEXP_tup = (fun ls -> - let (vs,ls) = List.split ls in - (join_list vs, LEXP_tup ls)) + let (vs,ls) = List.split ls in + (join_list vs, LEXP_tup ls)) ; lEXP_vector = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, LEXP_vector (lexp,e2))) ; lEXP_vector_range = (fun ((vl,lexp),(v2,e2),(v3,e3)) -> - (join_list [vl;v2;v3], LEXP_vector_range (lexp,e2,e3))) + (join_list [vl;v2;v3], LEXP_vector_range (lexp,e2,e3))) ; lEXP_field = (fun ((vl,lexp),id) -> (vl, LEXP_field (lexp,id))) ; lEXP_aux = (fun ((vl,lexp),annot) -> (vl, LEXP_aux (lexp,annot))) ; fE_Fexp = (fun (id,(v,e)) -> (v, FE_Fexp (id,e))) ; fE_aux = (fun ((vf,fexp),annot) -> (vf, FE_aux (fexp,annot))) ; fES_Fexps = (fun (fexps,b) -> - let (vs,fexps) = List.split fexps in - (join_list vs, FES_Fexps (fexps,b))) + let (vs,fexps) = List.split fexps in + (join_list vs, FES_Fexps (fexps,b))) ; fES_aux = (fun ((vf,fexp),annot) -> (vf, FES_aux (fexp,annot))) ; def_val_empty = (bot, Def_val_empty) ; def_val_dec = (fun (v,e) -> (v, Def_val_dec e)) @@ -1009,6 +1009,43 @@ let compute_exp_alg bot join = ; pat_alg = compute_pat_alg bot join } +(* Re-write trivial sizeof expressions - trivial meaning that the + value of the sizeof can be directly inferred from the type + variables in scope. *) +let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp = + let extract_typ_var l env nexp (id, (_, typ)) = + let var = E_aux (E_id id, (l, Some (env, typ, no_effect))) in + match destruct_atom_nexp env typ with + | Some size when prove env (nc_eq size nexp) -> Some var + | _ -> + begin + match destruct_vector env typ with + | Some (_, len, _, _) when prove env (nc_eq len nexp) -> + Some (E_aux (E_app (mk_id "length", [var]), (l, Some (env, atom_typ len, no_effect)))) + | _ -> None + end + in + let rewrite_e_aux (E_aux (e_aux, (l, _)) as orig_exp) = + let env = env_of orig_exp in + match e_aux with + | E_sizeof (Nexp_aux (Nexp_constant c, _) as nexp) -> + E_aux (E_lit (L_aux (L_num c, l)), (l, Some (env, atom_typ nexp, no_effect))) + | E_sizeof nexp -> + begin + let locals = Env.get_locals env in + let exps = Bindings.bindings locals + |> List.map (extract_typ_var l env nexp) + |> List.map (fun opt -> match opt with Some x -> [x] | None -> []) + |> List.concat + in + match exps with + | (exp :: _) -> exp + | [] -> orig_exp + end + | _ -> orig_exp + in + let rewrite_e_constraint = { id_exp_alg with e_aux = (fun (exp, annot) -> rewrite_e_aux (E_aux (exp, annot))) } in + rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_e_constraint) }, rewrite_e_aux (* Rewrite sizeof expressions with type-level variables to term-level expressions @@ -1020,78 +1057,91 @@ let compute_exp_alg bot join = let rewrite_sizeof (Defs defs) = let sizeof_frees exp = fst (fold_exp - { (compute_exp_alg KidSet.empty KidSet.union) with - e_sizeof = (fun nexp -> (nexp_frees nexp, E_sizeof nexp)) } - exp) in + { (compute_exp_alg KidSet.empty KidSet.union) with + e_sizeof = (fun nexp -> (nexp_frees nexp, E_sizeof nexp)) } + exp) in (* Collect nexps whose values can be obtained directly from a pattern bind *) let nexps_from_params pat = fst (fold_pat - { (compute_pat_alg [] (@)) with - p_aux = (fun ((v,pat),((l,_) as annot)) -> - let v' = match pat with - | P_id id | P_as (_, id) -> - let (Typ_aux (typ,_) as typ_aux) = typ_of_annot annot in - (match typ with - | Typ_app (atom, [Typ_arg_aux (Typ_arg_nexp nexp, _)]) - when string_of_id atom = "atom" -> - [nexp, E_id id] - | Typ_app (vector, _) when string_of_id vector = "vector" -> - let id_length = Id_aux (Id "length", Parse_ast.Generated l) in - (try - (match Env.get_val_spec id_length (env_of_annot annot) with - | _ -> - let (_,len,_,_) = vector_typ_args_of typ_aux in - let exp = E_app (id_length, [E_aux (E_id id, annot)]) in - [len, exp]) - with - | _ -> []) - | _ -> []) - | _ -> [] in - (v @ v', P_aux (pat,annot)))} pat) in + { (compute_pat_alg [] (@)) with + p_aux = (fun ((v,pat),((l,_) as annot)) -> + let v' = match pat with + | P_id id | P_as (_, id) -> + let (Typ_aux (typ,_) as typ_aux) = typ_of_annot annot in + (match typ with + | Typ_app (atom, [Typ_arg_aux (Typ_arg_nexp nexp, _)]) + when string_of_id atom = "atom" -> + [nexp, E_id id] + | Typ_app (vector, _) when string_of_id vector = "vector" -> + let id_length = Id_aux (Id "length", Parse_ast.Generated l) in + (try + (match Env.get_val_spec id_length (env_of_annot annot) with + | _ -> + let (_,len,_,_) = vector_typ_args_of typ_aux in + let exp = E_app (id_length, [E_aux (E_id id, annot)]) in + [len, exp]) + with + | _ -> []) + | _ -> []) + | _ -> [] in + (v @ v', P_aux (pat,annot)))} pat) in (* Substitute collected values in sizeof expressions *) let rec e_sizeof nmap (Nexp_aux (nexp, l) as nexp_aux) = try snd (List.find (fun (nexp,_) -> nexp_identical nexp nexp_aux) nmap) with | Not_found -> - let binop nexp1 op nexp2 = E_app_infix ( - E_aux (e_sizeof nmap nexp1, simple_annot l (atom_typ nexp1)), - Id_aux (Id op, Parse_ast.Unknown), - E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2)) - ) in - let (Nexp_aux (nexp, l) as nexp_aux) = simplify_nexp nexp_aux in - (match nexp with - | Nexp_constant i -> E_lit (L_aux (L_num i, l)) - | Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2 - | Nexp_sum (nexp1, nexp2) -> binop nexp1 "+" nexp2 - | Nexp_minus (nexp1, nexp2) -> binop nexp1 "-" nexp2 - | _ -> E_sizeof nexp_aux) in + let binop nexp1 op nexp2 = E_app_infix ( + E_aux (e_sizeof nmap nexp1, simple_annot l (atom_typ nexp1)), + Id_aux (Id op, Parse_ast.Unknown), + E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2)) + ) in + let (Nexp_aux (nexp, l) as nexp_aux) = simplify_nexp nexp_aux in + (match nexp with + | Nexp_constant i -> E_lit (L_aux (L_num i, l)) + | Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2 + | Nexp_sum (nexp1, nexp2) -> binop nexp1 "+" nexp2 + | Nexp_minus (nexp1, nexp2) -> binop nexp1 "-" nexp2 + | _ -> E_sizeof nexp_aux) in + + let ex_regex = Str.regexp "'ex[0-9]+" in (* Rewrite calls to functions which have had parameters added to pass values of type-level variables; these are added as sizeof expressions first, and then further rewritten as above. *) - let e_app_aux param_map ((exp, exp_orig), ((l,_) as annot)) = + let e_app_aux param_map ((exp, exp_orig), ((l, Some (env, _, _)) as annot)) = let full_exp = E_aux (exp, annot) in let orig_exp = E_aux (exp_orig, annot) in match exp with | E_app (f, args) -> - if Bindings.mem f param_map then - (* Retrieve instantiation of the type variables of the called function + if Bindings.mem f param_map then + (* Retrieve instantiation of the type variables of the called function for the given parameters in the original environment *) - let inst = instantiation_of orig_exp in - let inst = KBindings.fold (fun kid uvar b -> KBindings.add (orig_kid kid) uvar b) inst KBindings.empty in - let kid_exp kid = begin - match KBindings.find (orig_kid kid) inst with - | U_nexp nexp -> E_aux (E_sizeof nexp, simple_annot l (atom_typ nexp)) - | _ -> - raise (Reporting_basic.err_unreachable l - ("failed to infer nexp for type variable " ^ string_of_kid kid ^ - " of function " ^ string_of_id f)) - end in - let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in - (E_aux (E_app (f, kid_exps @ args), annot), orig_exp) - else (full_exp, orig_exp) + let inst = instantiation_of orig_exp in + (* Rewrite the inst using orig_kid so that each type variable has it's + original name rather than a mangled typechecker name *) + let inst = KBindings.fold (fun kid uvar b -> KBindings.add (orig_kid kid) uvar b) inst KBindings.empty in + let kid_exp kid = begin + (* We really don't want to see an existential here! *) + assert (not (Str.string_match ex_regex (string_of_kid kid) 0)); + let uvar = try Some (KBindings.find (orig_kid kid) inst) with Not_found -> None in + match uvar with + | Some (U_nexp nexp) -> + let sizeof = E_aux (E_sizeof nexp, (l, Some (env, atom_typ nexp, no_effect))) in + rewrite_trivial_sizeof_exp sizeof + (* If the type variable is Not_found then it was probably + introduced by a P_var pattern, so it likely exists as + a variable in scope. It can't be an existential because the assert rules that out. *) + | None -> E_aux (E_id (id_of_kid (orig_kid kid)), simple_annot l (atom_typ (nvar (orig_kid kid)))) + | _ -> + raise (Reporting_basic.err_unreachable l + ("failed to infer nexp for type variable " ^ string_of_kid kid ^ + " of function " ^ string_of_id f)) + end in + let kid_exps = List.map kid_exp (KidSet.elements (Bindings.find f param_map)) in + (E_aux (E_app (f, kid_exps @ args), annot), orig_exp) + else (full_exp, orig_exp) | _ -> (full_exp, orig_exp) in (* Plug this into a folding algorithm that also keeps around a copy of the @@ -1162,7 +1212,7 @@ let rewrite_sizeof (Defs defs) = } in let rewrite_sizeof_fun params_map - (FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) = + (FD_aux (FD_function (rec_opt,tannot,eff,funcls),((l,_) as annot))) = let rewrite_funcl_body (FCL_aux (FCL_Funcl (id,pat,exp), annot)) (funcls,nvars) = let body_env = env_of exp in let body_typ = typ_of exp in @@ -1172,7 +1222,7 @@ let rewrite_sizeof (Defs defs) = (* ... then rewrite sizeof expressions in current function body *) let exp'' = fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp' in (FCL_aux (FCL_Funcl (id,pat,exp''), annot) :: funcls, - KidSet.union nvars (sizeof_frees exp'')) in + KidSet.union nvars (sizeof_frees exp'')) in let (funcls, nvars) = List.fold_right rewrite_funcl_body funcls ([], KidSet.empty) in (* Add a parameter for each remaining free type-level variable in a sizeof expression *) @@ -1180,8 +1230,8 @@ let rewrite_sizeof (Defs defs) = let kid_annot kid = simple_annot l (kid_typ kid) in let kid_pat kid = P_aux (P_typ (kid_typ kid, - P_aux (P_id (Id_aux (Id (string_of_kid kid), l)), - kid_annot kid)), kid_annot kid) in + P_aux (P_id (Id_aux (Id (string_of_kid kid), l)), + kid_annot kid)), kid_annot kid) in let kid_eaux kid = E_id (Id_aux (Id (string_of_kid kid), l)) in let kid_typs = List.map kid_typ (KidSet.elements nvars) in let kid_pats = List.map kid_pat (KidSet.elements nvars) in @@ -1189,74 +1239,77 @@ let rewrite_sizeof (Defs defs) = let rewrite_funcl_params (FCL_aux (FCL_Funcl (id, pat, exp), annot) as funcl) = let rec rewrite_pat (P_aux (pat,(l,_)) as paux) = if KidSet.is_empty nvars then paux else - match pat_typ_of paux with - | Typ_aux (Typ_tup _, _) -> - (match pat with - | P_tup pats -> - P_aux (P_tup (kid_pats @ pats), (l, None)) - | P_wild -> paux - | P_typ (Typ_aux (Typ_tup typs, l), pat) -> - P_aux (P_typ (Typ_aux (Typ_tup (kid_typs @ typs), l), - rewrite_pat pat), (l, None)) - | P_as (_, id) | P_id id -> - (* adding parameters here would change the type of id; + match pat_typ_of paux with + | Typ_aux (Typ_tup _, _) -> + (match pat with + | P_tup pats -> + P_aux (P_tup (kid_pats @ pats), (l, None)) + | P_wild -> paux + | P_typ (Typ_aux (Typ_tup typs, l), pat) -> + P_aux (P_typ (Typ_aux (Typ_tup (kid_typs @ typs), l), + rewrite_pat pat), (l, None)) + | P_as (_, id) | P_id id -> + (* adding parameters here would change the type of id; we should remove the P_as/P_id here and add a let-binding to the body *) - raise (Reporting_basic.err_todo l - "rewriting as- or id-patterns for sizeof expressions not yet implemented") - | _ -> - raise (Reporting_basic.err_unreachable l - "unexpected pattern while rewriting function parameters for sizeof expressions")) - | _ -> P_aux (P_tup (kid_pats @ [paux]), (l, None)) in + raise (Reporting_basic.err_todo l + "rewriting as- or id-patterns for sizeof expressions not yet implemented") + | _ -> + raise (Reporting_basic.err_unreachable l + "unexpected pattern while rewriting function parameters for sizeof expressions")) + | _ -> P_aux (P_tup (kid_pats @ [paux]), (l, None)) in let exp' = fold_exp { id_exp_alg with e_sizeof = e_sizeof kid_nmap } exp in FCL_aux (FCL_Funcl (id, rewrite_pat pat, exp'), annot) in let funcls = List.map rewrite_funcl_params funcls in (nvars, FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) in let rewrite_sizeof_fundef (params_map, defs) = function - | DEF_fundef fd -> - let (nvars, fd') = rewrite_sizeof_fun params_map fd in - let id = id_of_fundef fd in - let params_map' = - if KidSet.is_empty nvars then params_map - else Bindings.add id nvars params_map in - (params_map', defs @ [DEF_fundef fd']) - | def -> - (params_map, defs @ [def]) in + | DEF_fundef fd as def -> + let (nvars, fd') = rewrite_sizeof_fun params_map fd in + let id = id_of_fundef fd in + let params_map' = + if KidSet.is_empty nvars then params_map + else Bindings.add id nvars params_map in + (params_map', defs @ [DEF_fundef fd']) + | def -> + (params_map, defs @ [def]) in let rewrite_sizeof_valspec params_map def = let rewrite_typschm (TypSchm_aux (TypSchm_ts (tq, typ), l) as ts) id = if Bindings.mem id params_map then let kid_typs = List.map (fun kid -> atom_typ (nvar kid)) - (KidSet.elements (Bindings.find id params_map)) in + (KidSet.elements (Bindings.find id params_map)) in let typ' = match typ with - | Typ_aux (Typ_fn (vtyp_arg, vtyp_ret, declared_eff), vl) -> - let vtyp_arg' = begin - match vtyp_arg with - | Typ_aux (Typ_tup typs, vl) -> - Typ_aux (Typ_tup (kid_typs @ typs), vl) - | _ -> Typ_aux (Typ_tup (kid_typs @ [vtyp_arg]), vl) - end in - Typ_aux (Typ_fn (vtyp_arg', vtyp_ret, declared_eff), vl) - | _ -> raise (Reporting_basic.err_typ l - "val spec with non-function type") in + | Typ_aux (Typ_fn (vtyp_arg, vtyp_ret, declared_eff), vl) -> + let vtyp_arg' = begin + match vtyp_arg with + | Typ_aux (Typ_tup typs, vl) -> + Typ_aux (Typ_tup (kid_typs @ typs), vl) + | _ -> Typ_aux (Typ_tup (kid_typs @ [vtyp_arg]), vl) + end in + Typ_aux (Typ_fn (vtyp_arg', vtyp_ret, declared_eff), vl) + | _ -> raise (Reporting_basic.err_typ l + "val spec with non-function type") in TypSchm_aux (TypSchm_ts (tq, typ'), l) else ts in match def with | DEF_spec (VS_aux (VS_val_spec (typschm, id), a)) -> - DEF_spec (VS_aux (VS_val_spec (rewrite_typschm typschm id, id), a)) + DEF_spec (VS_aux (VS_val_spec (rewrite_typschm typschm id, id), a)) | DEF_spec (VS_aux (VS_extern_no_rename (typschm, id), a)) -> - DEF_spec (VS_aux (VS_extern_no_rename (rewrite_typschm typschm id, id), a)) + DEF_spec (VS_aux (VS_extern_no_rename (rewrite_typschm typschm id, id), a)) | DEF_spec (VS_aux (VS_extern_spec (typschm, id, e), a)) -> - DEF_spec (VS_aux (VS_extern_spec (rewrite_typschm typschm id, id, e), a)) + DEF_spec (VS_aux (VS_extern_spec (rewrite_typschm typschm id, id, e), a)) | DEF_spec (VS_aux (VS_cast_spec (typschm, id), a)) -> - DEF_spec (VS_aux (VS_cast_spec (rewrite_typschm typschm id, id), a)) + DEF_spec (VS_aux (VS_cast_spec (rewrite_typschm typschm id, id), a)) | _ -> def in let (params_map, defs) = List.fold_left rewrite_sizeof_fundef - (Bindings.empty, []) defs in + (Bindings.empty, []) defs in let defs = List.map (rewrite_sizeof_valspec params_map) defs in + Defs defs + (* FIXME: Won't re-check due to flow typing and E_constraint re-write before E_sizeof re-write. + Requires the typechecker to be more smart about different representations for valid flow typing constraints. fst (check initial_env (Defs defs)) - + *) let remove_vector_concat_pat pat = @@ -2282,6 +2335,7 @@ let rewrite_defs_early_return = rewrite_defs_base { rewriters_base with rewrite_fun = rewrite_fun_early_return } +(* Turn constraints into numeric expressions with sizeof *) let rewrite_constraint = let rec rewrite_nc (NC_aux (nc_aux, l)) = mk_exp (rewrite_nc_aux nc_aux) and rewrite_nc_aux = function @@ -2307,13 +2361,77 @@ let rewrite_constraint = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_e_constraint) } +let rewrite_simple_types (Defs defs) = + let is_constraint = function + | QI_aux (QI_id _, annot) as qi -> false + | _ -> true + in + + let simple_typquant (TypQ_aux (tq_aux, annot)) = + match tq_aux with + | TypQ_no_forall -> TypQ_aux (TypQ_no_forall, annot) + | TypQ_tq quants -> TypQ_aux (TypQ_tq (List.filter (fun q -> not (is_constraint q)) quants), annot) + in + + let rec simple_typ (Typ_aux (typ_aux, l) as typ) = Typ_aux (simple_typ_aux typ_aux, l) + and simple_typ_aux = function + | Typ_wild -> Typ_wild + | Typ_id id -> Typ_id id + | Typ_app (id, [_; _; _; Typ_arg_aux (Typ_arg_typ typ, l)]) when Id.compare id (mk_id "vector") = 0 -> + Typ_app (mk_id "list", [Typ_arg_aux (Typ_arg_typ (simple_typ typ), l)]) + | Typ_app (id, [_]) when Id.compare id (mk_id "atom") = 0 -> + Typ_id (mk_id "int") + | Typ_app (id, [_; _]) when Id.compare id (mk_id "range") = 0 -> + Typ_id (mk_id "int") + | Typ_fn (typ1, typ2, effs) -> Typ_fn (simple_typ typ1, simple_typ typ2, effs) + | Typ_tup typs -> Typ_tup (List.map simple_typ typs) + | Typ_exist (_, _, Typ_aux (typ, l)) -> simple_typ_aux typ + | typ_aux -> typ_aux + in + + let simple_typschm (TypSchm_aux (TypSchm_ts (typq, typ), annot)) = + TypSchm_aux (TypSchm_ts (simple_typquant typq, simple_typ typ), annot) + in + + let simple_vs (VS_aux (vs_aux, annot)) = + match vs_aux with + | VS_val_spec (typschm, id) -> VS_aux (VS_val_spec (simple_typschm typschm, id), annot) + | VS_extern_no_rename (typschm, id) -> VS_aux (VS_val_spec (simple_typschm typschm, id), annot) + | VS_extern_spec (typschm, id, e) -> VS_aux (VS_extern_spec (simple_typschm typschm, id, e), annot) + | VS_cast_spec (typschm, id) -> VS_aux (VS_cast_spec (simple_typschm typschm, id), annot) + in + + let simple_def def = + match def with + | DEF_spec vs -> DEF_spec (simple_vs vs) + | _ -> def + in + + let simple_pat = { + id_pat_alg with + p_typ = (fun (typ, pat) -> P_typ (simple_typ typ, pat)); + p_var = (fun kid -> P_id (id_of_kid kid)) + } in + let simple_exp = { + id_exp_alg with + e_cast = (fun (typ, exp) -> E_cast (simple_typ typ, exp)); + e_assert = (fun (E_aux (_, annot), str) -> E_assert (E_aux (E_lit (mk_lit L_true), annot), str)); + lEXP_cast = (fun (typ, lexp) -> LEXP_cast (simple_typ typ, lexp)); + pat_alg = simple_pat + } in + let simple_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp simple_exp) } in + + let defs = Defs (List.map simple_def defs) in + rewrite_defs_base simple_defs defs + let rewrite_defs_ocaml = [ - top_sort_defs; + (* top_sort_defs; *) rewrite_defs_remove_vector_concat; rewrite_constraint; + rewrite_trivial_sizeof; rewrite_sizeof; - rewrite_defs_exp_lift_assign (* ; - rewrite_defs_separate_numbs *) + (* rewrite_defs_exp_lift_assign *) + (* rewrite_defs_separate_numbs *) ] let rewrite_defs_remove_blocks = diff --git a/src/sail.ml b/src/sail.ml index dbc0eff4..cf366d42 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -158,7 +158,6 @@ let main() = else ()); (if !(opt_print_ocaml) then let ast_ocaml = rewrite_ast_ocaml ast in - print_endline "Finished re-writing ocaml"; if !(opt_libs_ocaml) = [] then output "" (Ocaml_out None) [out_name,ast_ocaml] else output "" (Ocaml_out (Some (List.hd !opt_libs_ocaml))) [out_name,ast_ocaml] diff --git a/src/type_check.ml b/src/type_check.ml index 2fcba97d..28169c41 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -417,6 +417,7 @@ module Env : sig val is_record : id -> t -> bool val get_accessor : id -> id -> t -> typquant * typ val add_local : id -> mut * typ -> t -> t + val get_locals : t -> (mut * typ) Bindings.t val add_variant : id -> typquant * type_union list -> t -> t val add_union_id : id -> typquant * typ -> t -> t val add_flow : id -> (typ -> typ) -> t -> t @@ -792,6 +793,8 @@ end = struct try Bindings.find id env.regtyps with | Not_found -> typ_error (id_loc id) (string_of_id id ^ " is not a register type") + let get_locals env = env.locals + let lookup_id id env = try let (mut, typ) = Bindings.find id env.locals in @@ -1286,27 +1289,30 @@ let ord_identical (Ord_aux (ord1, _)) (Ord_aux (ord2, _)) = | Ord_dec, Ord_dec -> true | _, _ -> false -let rec typ_identical (Typ_aux (typ1, _)) (Typ_aux (typ2, _)) = - match typ1, typ2 with - | Typ_wild, Typ_wild -> true - | Typ_id v1, Typ_id v2 -> Id.compare v1 v2 = 0 - | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 = 0 - | Typ_tup typs1, Typ_tup typs2 -> - begin - try List.for_all2 typ_identical typs1 typs2 with - | Invalid_argument _ -> false - end - | Typ_app (f1, args1), Typ_app (f2, args2) -> - begin - try Id.compare f1 f2 = 0 && List.for_all2 typ_arg_identical args1 args2 with - | Invalid_argument _ -> false - end - | _, _ -> false -and typ_arg_identical (Typ_arg_aux (arg1, _)) (Typ_arg_aux (arg2, _)) = - match arg1, arg2 with - | Typ_arg_nexp n1, Typ_arg_nexp n2 -> nexp_identical n1 n2 - | Typ_arg_typ typ1, Typ_arg_typ typ2 -> typ_identical typ1 typ2 - | Typ_arg_order ord1, Typ_arg_order ord2 -> ord_identical ord1 ord2 +let typ_identical env typ1 typ2 = + let rec typ_identical' (Typ_aux (typ1, _)) (Typ_aux (typ2, _)) = + match typ1, typ2 with + | Typ_wild, Typ_wild -> true + | Typ_id v1, Typ_id v2 -> Id.compare v1 v2 = 0 + | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 = 0 + | Typ_tup typs1, Typ_tup typs2 -> + begin + try List.for_all2 typ_identical' typs1 typs2 with + | Invalid_argument _ -> false + end + | Typ_app (f1, args1), Typ_app (f2, args2) -> + begin + try Id.compare f1 f2 = 0 && List.for_all2 typ_arg_identical args1 args2 with + | Invalid_argument _ -> false + end + | _, _ -> false + and typ_arg_identical (Typ_arg_aux (arg1, _)) (Typ_arg_aux (arg2, _)) = + match arg1, arg2 with + | Typ_arg_nexp n1, Typ_arg_nexp n2 -> nexp_identical n1 n2 + | Typ_arg_typ typ1, Typ_arg_typ typ2 -> typ_identical' typ1 typ2 + | Typ_arg_order ord1, Typ_arg_order ord2 -> ord_identical ord1 ord2 + in + typ_identical' (Env.expand_synonyms env typ1) (Env.expand_synonyms env typ2) type uvar = | U_nexp of nexp @@ -1513,6 +1519,7 @@ let uv_nexp_constraint env (kid, uvar) = | _ -> env let rec subtyp l env typ1 typ2 = + typ_print ("Subtype " ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2); match destruct_exist env typ1, destruct_exist env typ2 with | Some (kids, nc, typ1), _ -> let env = List.fold_left (fun env kid -> Env.add_typ_var kid BK_nat env) env kids in @@ -1673,13 +1680,13 @@ let destruct_atom (Typ_aux (typ_aux, _)) = when string_of_id f = "range" && c1 = c2 -> c1 | _ -> assert false -let destruct_atom_nexp (Typ_aux (typ_aux, _)) = - match typ_aux with - | Typ_app (f, [Typ_arg_aux (Typ_arg_nexp n, _)]) - when string_of_id f = "atom" -> n - | Typ_app (f, [Typ_arg_aux (Typ_arg_nexp n, _); Typ_arg_aux (Typ_arg_nexp _, _)]) - when string_of_id f = "range" -> n - | _ -> assert false +let destruct_atom_nexp env typ = + match Env.expand_synonyms env typ with + | Typ_aux (Typ_app (f, [Typ_arg_aux (Typ_arg_nexp n, _)]), _) + when string_of_id f = "atom" -> Some n + | Typ_aux (Typ_app (f, [Typ_arg_aux (Typ_arg_nexp n, _); Typ_arg_aux (Typ_arg_nexp _, _)]), _) + when string_of_id f = "range" -> Some n + | _ -> None let restrict_range_upper c1 (Typ_aux (typ_aux, l) as typ) = match typ_aux with @@ -1726,20 +1733,20 @@ let apply_flow_constraint = function let rec infer_flow env (E_aux (exp_aux, (l, _))) = match exp_aux with | E_app (f, [x; y]) when string_of_id f = "lteq_atom_atom" -> - let n1 = destruct_atom_nexp (typ_of x) in - let n2 = destruct_atom_nexp (typ_of y) in + let Some n1 = destruct_atom_nexp env (typ_of x) in + let Some n2 = destruct_atom_nexp env (typ_of y) in [], [nc_lteq n1 n2] | E_app (f, [x; y]) when string_of_id f = "gteq_atom_atom" -> - let n1 = destruct_atom_nexp (typ_of x) in - let n2 = destruct_atom_nexp (typ_of y) in + let Some n1 = destruct_atom_nexp env (typ_of x) in + let Some n2 = destruct_atom_nexp env (typ_of y) in [], [nc_gteq n1 n2] | E_app (f, [x; y]) when string_of_id f = "lt_atom_atom" -> - let n1 = destruct_atom_nexp (typ_of x) in - let n2 = destruct_atom_nexp (typ_of y) in + let Some n1 = destruct_atom_nexp env (typ_of x) in + let Some n2 = destruct_atom_nexp env (typ_of y) in [], [nc_lt n1 n2] | E_app (f, [x; y]) when string_of_id f = "gt_atom_atom" -> - let n1 = destruct_atom_nexp (typ_of x) in - let n2 = destruct_atom_nexp (typ_of y) in + let Some n1 = destruct_atom_nexp env (typ_of x) in + let Some n2 = destruct_atom_nexp env (typ_of y) in [], [nc_gt n1 n2] | E_app (f, [E_aux (E_id v, _); y]) when string_of_id f = "lt_range_atom" -> let kid = Env.fresh_kid env in @@ -2680,7 +2687,7 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ = typ_debug ("Existential constraints: " ^ string_of_list ", " string_of_n_constraint ex_constraints); let typ_ret = - if KidSet.is_empty (KidSet.of_list existentials) + if KidSet.is_empty (KidSet.of_list existentials) || KidSet.is_empty (typ_frees typ_ret) then (typ_debug "Returning Existential"; typ_ret) else mk_typ (Typ_exist (existentials, List.fold_left nc_and nc_true ex_constraints, typ_ret)) in @@ -2976,10 +2983,10 @@ let infer_funtyp l env tannotopt funcls = let mk_val_spec typq typ id = DEF_spec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown), id), (Parse_ast.Unknown, None))) -let check_tannotopt typq ret_typ = function +let check_tannotopt env typq ret_typ = function | Typ_annot_opt_aux (Typ_annot_opt_none, _) -> () | Typ_annot_opt_aux (Typ_annot_opt_some (annot_typq, annot_ret_typ), l) -> - if typ_identical ret_typ annot_ret_typ + if typ_identical env ret_typ annot_ret_typ then () else typ_error l (string_of_bind (typq, ret_typ) ^ " and " ^ string_of_bind (annot_typq, annot_ret_typ) ^ " do not match between function and val spec") @@ -3003,7 +3010,7 @@ let check_fundef env (FD_aux (FD_function (recopt, tannotopt, effectopt, funcls) let (quant, typ) = infer_funtyp l env tannotopt funcls in false, (quant, typ), env in - check_tannotopt quant vtyp_ret tannotopt; + check_tannotopt env quant vtyp_ret tannotopt; typ_debug ("Checking fundef " ^ string_of_id id ^ " has type " ^ string_of_bind (quant, typ)); let funcl_env = add_typquant quant env in let funcls = List.map (fun funcl -> check_funcl funcl_env funcl typ) funcls in @@ -3033,7 +3040,7 @@ let check_val_spec env (VS_aux (vs, (l, _))) = | VS_extern_spec (TypSchm_aux (TypSchm_ts (quants, typ), _), id, ext) -> let env = Env.add_extern id ext env in (id, quants, typ, env) in - [DEF_spec (VS_aux (vs, (l, None)))], Env.add_val_spec id (quants, typ) env + [DEF_spec (VS_aux (vs, (l, None)))], Env.add_val_spec id (quants, Env.expand_synonyms env typ) env let check_default env (DT_aux (ds, l)) = match ds with diff --git a/src/type_check.mli b/src/type_check.mli index f22a6991..22812c89 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -73,6 +73,8 @@ module Env : sig (* Returns true if id is a register type, false otherwise *) val is_regtyp : id -> t -> bool + val get_locals : t -> (mut * typ) Bindings.t + (* Check if a local variable is mutable. Throws Type_error if it isn't a local variable. Probably best to use Env.lookup_id instead *) @@ -140,6 +142,13 @@ end (* Push all the type variables and constraints from a typquant into an environment *) val add_typquant : typquant -> Env.t -> Env.t +(* When the typechecker creates new type variables it gives them fresh + names of the form 'fvXXX#name, where XXX is a number (not + necessarily three digits), and name is the original name when the + type variable was created by renaming an exisiting type variable to + avoid shadowing. orig_kid takes such a type variable and strips out + the 'fvXXX# part. It returns the type variable unmodified if it is + not of this form. *) val orig_kid : kid -> kid (* Some handy utility functions for constructing types. *) @@ -219,6 +228,8 @@ val check_exp : Env.t -> unit exp -> typ -> tannot exp val infer_exp : Env.t -> unit exp -> tannot exp +val prove : Env.t -> n_constraint -> bool + val subtype_check : Env.t -> typ -> typ -> bool (* Partial functions: The expressions and patterns passed to these @@ -231,15 +242,12 @@ val env_of_annot : Ast.l * tannot -> Env.t val typ_of : tannot exp -> typ val typ_of_annot : Ast.l * tannot -> typ -val env_of : tannot exp -> Env.t - val pat_typ_of : tannot pat -> typ val effect_of : tannot exp -> effect val effect_of_annot : tannot -> effect -(* TODO: make return option *) -val destruct_atom_nexp : typ -> nexp +val destruct_atom_nexp : Env.t -> typ -> nexp option (* Safely destructure an existential type. Returns None if the type is not existential. This function will pick a fresh name for the |
