diff options
| author | Brian Campbell | 2018-07-23 16:14:12 +0100 |
|---|---|---|
| committer | Brian Campbell | 2018-07-23 16:15:53 +0100 |
| commit | f3ef82fee78d40c628d319dab4cc35a41c638e8e (patch) | |
| tree | 94a4ac32a05c69f0ef9a69d99e6207e9777f9e68 /src | |
| parent | 4c25326519d00bc781d6ee33ca507d1d525af686 (diff) | |
Coq: make all pattern matches in the output exhaustive
Uses previous stage to deal with (e.g.) guards.
New option -dcoq_warn_nonex tells you where all of the extra default
cases were added.
Diffstat (limited to 'src')
| -rw-r--r-- | src/process_file.ml | 2 | ||||
| -rw-r--r-- | src/rewrites.ml | 296 | ||||
| -rw-r--r-- | src/rewrites.mli | 7 | ||||
| -rw-r--r-- | src/sail.ml | 3 | ||||
| -rw-r--r-- | src/type_check.mli | 1 |
5 files changed, 308 insertions, 1 deletions
diff --git a/src/process_file.ml b/src/process_file.ml index 9ed52e8d..c3e1b510 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -398,7 +398,7 @@ let rewrite rewriters defs = let rewrite_ast = rewrite [("initial", Rewriter.rewrite_defs)] let rewrite_undefined bitvectors = rewrite [("undefined", fun x -> Rewrites.rewrite_undefined bitvectors x)] let rewrite_ast_lem = rewrite Rewrites.rewrite_defs_lem -let rewrite_ast_coq = rewrite Rewrites.rewrite_defs_lem +let rewrite_ast_coq = rewrite Rewrites.rewrite_defs_coq let rewrite_ast_ocaml = rewrite Rewrites.rewrite_defs_ocaml let rewrite_ast_c ast = ast diff --git a/src/rewrites.ml b/src/rewrites.ml index 8fe30d6b..246a2670 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -4001,6 +4001,263 @@ let rewrite_defs_realise_mappings (Defs defs) = in Defs (List.map rewrite_def defs |> List.flatten) + +(* Rewrite to make all pattern matches in Coq output exhaustive. + Assumes that guards, vector patterns, etc have been rewritten already. *) + +let opt_coq_warn_nonexhaustive = ref false + +module MakeExhaustive = +struct + +type rlit = + | RL_unit + | RL_zero + | RL_one + | RL_true + | RL_false + | RL_inf + +let string_of_rlit = function + | RL_unit -> "()" + | RL_zero -> "bitzero" + | RL_one -> "bitone" + | RL_true -> "true" + | RL_false -> "false" + | RL_inf -> "..." + +let rlit_of_lit (L_aux (l,_)) = + match l with + | L_unit -> RL_unit + | L_zero -> RL_zero + | L_one -> RL_one + | L_true -> RL_true + | L_false -> RL_false + | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> RL_inf + | L_undef -> assert false + +let inv_rlit_of_lit (L_aux (l,_)) = + match l with + | L_unit -> [] + | L_zero -> [RL_one] + | L_one -> [RL_zero] + | L_true -> [RL_false] + | L_false -> [RL_true] + | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> [RL_inf] + | L_undef -> assert false + +type residual_pattern = + | RP_any + | RP_lit of rlit + | RP_enum of id + | RP_app of id * residual_pattern list + | RP_tup of residual_pattern list + | RP_nil + | RP_cons of residual_pattern * residual_pattern + +let rec string_of_rp = function + | RP_any -> "_" + | RP_lit rlit -> string_of_rlit rlit + | RP_enum id -> string_of_id id + | RP_app (f,args) -> string_of_id f ^ "(" ^ String.concat "," (List.map string_of_rp args) ^ ")" + | RP_tup rps -> "(" ^ String.concat "," (List.map string_of_rp rps) ^ ")" + | RP_nil -> "[| |]" + | RP_cons (rp1,rp2) -> string_of_rp rp1 ^ "::" ^ string_of_rp rp2 + +type ctx = { + env : Env.t; + enum_to_rest: (residual_pattern list) Bindings.t; + constructor_to_rest: (residual_pattern list) Bindings.t +} + +let make_enum_mappings ids m = + IdSet.fold (fun id m -> + Bindings.add id + (List.map (fun e -> RP_enum e) (IdSet.elements (IdSet.remove id ids))) m) + ids + m + +let make_cstr_mappings env ids m = + let ids = IdSet.elements ids in + let constructors = List.map + (fun id -> + let _,ty = Env.get_val_spec id env in + let args = match ty with + | Typ_aux (Typ_fn (Typ_aux (Typ_tup tys,_),_,_),_) -> List.map (fun _ -> RP_any) tys + | _ -> [RP_any] + in RP_app (id,args)) ids in + let rec aux ids acc l = + match ids, l with + | [], [] -> m + | id::ids, rp::t -> + let m = aux ids (acc@[rp]) t in + Bindings.add id (acc@t) m + | _ -> assert false + in aux ids [] constructors + +let ctx_from_pattern_completeness_ctx env = + let ctx = Env.pattern_completeness_ctx env in + { env = env; + enum_to_rest = Bindings.fold (fun _ ids m -> make_enum_mappings ids m) + ctx.Pattern_completeness.enums Bindings.empty; + constructor_to_rest = Bindings.fold (fun _ ids m -> make_cstr_mappings env ids m) + ctx.Pattern_completeness.variants Bindings.empty + } + +let printprefix = ref " " + +let rec remove_clause_from_pattern ctx (P_aux (rm_pat,ann)) res_pat = + let subpats rm_pats res_pats = + let res_pats' = List.map2 (remove_clause_from_pattern ctx) rm_pats res_pats in + let rec aux acc fixed residual = + match fixed, residual with + | [], [] -> [] + | (fh::ft), (rh::rt) -> + let rt' = aux (acc@[fh]) ft rt in + let newr = List.map (fun x -> acc @ (x::ft)) rh in + newr @ rt' + | _,_ -> assert false (* impossible because we managed map2 above *) + in aux [] res_pats res_pats' + in + let inconsistent () = + raise (Reporting_basic.err_unreachable (fst ann) + ("Inconsistency during exhaustiveness analysis with " ^ + string_of_rp res_pat)) + in + (*let _ = print_endline (!printprefix ^ "pat: " ^string_of_pat (P_aux (rm_pat,ann))) in + let _ = print_endline (!printprefix ^ "res_pat: " ^string_of_rp res_pat) in + let _ = printprefix := " " ^ !printprefix in*) + let rp' = + match rm_pat with + | P_wild -> [] + | P_id id when (match Env.lookup_id id ctx.env with Unbound | Local _ -> true | _ -> false) -> [] + | P_lit lit -> + (match res_pat with + | RP_any -> List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit) + | RP_lit RL_inf -> [res_pat] + | RP_lit lit' -> if lit' = rlit_of_lit lit then [] else [res_pat] + | _ -> inconsistent ()) + | P_as (p,_) + | P_typ (_,p) + | P_var (p,_) + -> remove_clause_from_pattern ctx p res_pat + | P_id id -> + (match Env.lookup_id id ctx.env with + | Enum enum -> + (match res_pat with + | RP_any -> Bindings.find id ctx.enum_to_rest + | RP_enum id' -> if Id.compare id id' == 0 then [] else [res_pat] + | _ -> inconsistent ()) + | _ -> assert false) + | P_tup rm_pats -> + let previous_res_pats = + match res_pat with + | RP_tup res_pats -> res_pats + | RP_any -> List.map (fun _ -> RP_any) rm_pats + | _ -> inconsistent () + in + let res_pats' = subpats rm_pats previous_res_pats in + List.map (fun rps -> RP_tup rps) res_pats' + | P_app (id,args) -> + (match res_pat with + | RP_app (id',residual_args) -> + if Id.compare id id' == 0 then + let res_pats' = subpats args residual_args in + List.map (fun rps -> RP_app (id,rps)) res_pats' + else [res_pat] + | RP_any -> + let res_args = subpats args (List.map (fun _ -> RP_any) args) in + (List.map (fun l -> (RP_app (id,l))) res_args) @ + (Bindings.find id ctx.constructor_to_rest) + | _ -> inconsistent () + ) + | P_list ps -> + (match ps with + | p1::ptl -> remove_clause_from_pattern ctx (P_aux (P_cons (p1,P_aux (P_list ptl,ann)),ann)) res_pat + | [] -> + match res_pat with + | RP_any -> [RP_cons (RP_any,RP_any)] + | RP_cons _ -> [res_pat] + | RP_nil -> [] + | _ -> inconsistent ()) + | P_cons (p1,p2) -> begin + let rp',rps = + match res_pat with + | RP_cons (rp1,rp2) -> [], Some [rp1;rp2] + | RP_any -> [RP_nil], Some [RP_any;RP_any] + | RP_nil -> [RP_nil], None + | _ -> inconsistent () + in + match rps with + | None -> rp' + | Some rps -> + let res_pats = subpats [p1;p2] rps in + rp' @ List.map (function [rp1;rp2] -> RP_cons (rp1,rp2) | _ -> assert false) res_pats + end + | P_record _ -> + raise (Reporting_basic.err_unreachable (fst ann) + "Record pattern not supported") + | P_vector _ + | P_vector_concat _ + | P_string_append _ -> + raise (Reporting_basic.err_unreachable (fst ann) + "Found pattern that should have been rewritten away in earlier stage") + + (*in let _ = printprefix := String.sub (!printprefix) 0 (String.length !printprefix - 2) + in let _ = print_endline (!printprefix ^ "res_pats': " ^ String.concat "; " (List.map string_of_rp rp'))*) + in rp' + +let process_pexp env = + let ctx = ctx_from_pattern_completeness_ctx env in + fun rps patexp -> + (*let _ = print_endline ("res_pats: " ^ String.concat "; " (List.map string_of_rp rps)) in + let _ = print_endline ("pat: " ^ string_of_pexp patexp) in*) + match patexp with + | Pat_aux (Pat_exp (p,_),_) -> + List.concat (List.map (remove_clause_from_pattern ctx p) rps) + | Pat_aux (Pat_when _,(l,_)) -> + raise (Reporting_basic.err_unreachable l + "Guarded pattern should have been rewritten away") + +let rewrite_case (e,ann) = + match e with + | E_case (e1,cases) -> + begin + let env = env_of_annot ann in + let rps = List.fold_left (process_pexp env) [RP_any] cases in + match rps with + | [] -> E_aux (E_case (e1,cases),ann) + | (example::_) -> + + let _ = + if !opt_coq_warn_nonexhaustive + then Reporting_basic.print_err false false + (fst ann) "Non-exhaustive matching" ("Example: " ^ string_of_rp example) in + + let l = Parse_ast.Generated Parse_ast.Unknown in + let p = P_aux (P_wild, (l, None)) in + let ann' = Some (env, typ_of_annot ann, mk_effect [BE_escape]) in + (* TODO: use an expression that specifically indicates a failed pattern match *) + let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)),(l,None))),(l,ann')) in + E_aux (E_case (e1,cases@[Pat_aux (Pat_exp (p,b),(l,None))]),ann) + end + | _ -> E_aux (e,ann) + +let rewrite = + let alg = { id_exp_alg with e_aux = rewrite_case } in + rewrite_defs_base + { rewrite_exp = (fun _ -> fold_exp alg) + ; rewrite_pat = rewrite_pat + ; rewrite_let = rewrite_let + ; rewrite_lexp = rewrite_lexp + ; rewrite_fun = rewrite_fun + ; rewrite_def = rewrite_def + ; rewrite_defs = rewrite_defs_base + } + + +end + let recheck_defs defs = fst (Type_error.check initial_env defs) let remove_mapping_valspecs (Defs defs) = @@ -4049,6 +4306,45 @@ let rewrite_defs_lem = [ ("recheck_defs", recheck_defs) ] +let rewrite_defs_coq = [ + ("realise_mappings", rewrite_defs_realise_mappings); + ("remove_mapping_valspecs", remove_mapping_valspecs); + ("pat_string_append", rewrite_defs_pat_string_append); + ("mapping_builtins", rewrite_defs_mapping_patterns); + ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem); + ("vector_concat_assignments", rewrite_vector_concat_assignments); + ("tuple_assignments", rewrite_tuple_assignments); + ("simple_assignments", rewrite_simple_assignments); + ("remove_vector_concat", rewrite_defs_remove_vector_concat); + ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats); + ("remove_numeral_pats", rewrite_defs_remove_numeral_pats); + ("guarded_pats", rewrite_defs_guarded_pats); + ("bitvector_exps", rewrite_bitvector_exps); + (* ("register_ref_writes", rewrite_register_ref_writes); *) + ("nexp_ids", rewrite_defs_nexp_ids); + ("fix_val_specs", rewrite_fix_val_specs); + ("split_execute", rewrite_split_fun_constr_pats "execute"); + ("recheck_defs", recheck_defs); + ("exp_lift_assign", rewrite_defs_exp_lift_assign); + (* ("constraint", rewrite_constraint); *) + (* ("remove_assert", rewrite_defs_remove_assert); *) + ("top_sort_defs", top_sort_defs); + ("trivial_sizeof", rewrite_trivial_sizeof); + ("sizeof", rewrite_sizeof); + ("early_return", rewrite_defs_early_return); + ("make_cases_exhaustive", MakeExhaustive.rewrite); + ("fix_val_specs", rewrite_fix_val_specs); + ("recheck_defs", recheck_defs); + ("remove_blocks", rewrite_defs_remove_blocks); + ("letbind_effects", rewrite_defs_letbind_effects); + ("remove_e_assign", rewrite_defs_remove_e_assign); + ("internal_lets", rewrite_defs_internal_lets); + ("remove_superfluous_letbinds", rewrite_defs_remove_superfluous_letbinds); + ("remove_superfluous_returns", rewrite_defs_remove_superfluous_returns); + ("merge function clauses", merge_funcls); + ("recheck_defs", recheck_defs) + ] + let rewrite_defs_ocaml = [ (* ("undefined", rewrite_undefined); *) ("no_effect_check", (fun defs -> opt_no_effects := true; defs)); diff --git a/src/rewrites.mli b/src/rewrites.mli index 70cb75af..7d6bc0b2 100644 --- a/src/rewrites.mli +++ b/src/rewrites.mli @@ -66,6 +66,13 @@ val rewrite_defs_interpreter : (string * (tannot defs -> tannot defs)) list (* Perform rewrites to exclude AST nodes not supported for lem out*) val rewrite_defs_lem : (string * (tannot defs -> tannot defs)) list +(* Perform rewrites to exclude AST nodes not supported for coq out*) +val rewrite_defs_coq : (string * (tannot defs -> tannot defs)) list + +(* Warn about matches where we add a default case for Coq because they're not + exhaustive *) +val opt_coq_warn_nonexhaustive : bool ref + (* Perform rewrites to exclude AST nodes not supported for C compilation *) val rewrite_defs_c : (string * (tannot defs -> tannot defs)) list diff --git a/src/sail.ml b/src/sail.ml index 64e60c23..f0903146 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -151,6 +151,9 @@ let options = Arg.align ([ ( "-dcoq_undef_axioms", Arg.Set Pretty_print_coq.opt_undef_axioms, "Generate axioms for functions that are declared but not defined"); + ( "-dcoq_warn_nonex", + Arg.Set Rewrites.opt_coq_warn_nonexhaustive, + "Generate warnings for non-exhaustive pattern matches in the Coq backend"); ( "-latex_prefix", Arg.String (fun prefix -> Latex.opt_prefix_latex := prefix), " set a custom prefix for generated latex command (default sail)"); diff --git a/src/type_check.mli b/src/type_check.mli index 665981e9..31a5a8dd 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -197,6 +197,7 @@ module Env : sig val empty : t + val pattern_completeness_ctx : t -> Pattern_completeness.ctx end (** Push all the type variables and constraints from a typquant into |
