summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBrian Campbell2018-07-23 16:14:12 +0100
committerBrian Campbell2018-07-23 16:15:53 +0100
commitf3ef82fee78d40c628d319dab4cc35a41c638e8e (patch)
tree94a4ac32a05c69f0ef9a69d99e6207e9777f9e68 /src
parent4c25326519d00bc781d6ee33ca507d1d525af686 (diff)
Coq: make all pattern matches in the output exhaustive
Uses previous stage to deal with (e.g.) guards. New option -dcoq_warn_nonex tells you where all of the extra default cases were added.
Diffstat (limited to 'src')
-rw-r--r--src/process_file.ml2
-rw-r--r--src/rewrites.ml296
-rw-r--r--src/rewrites.mli7
-rw-r--r--src/sail.ml3
-rw-r--r--src/type_check.mli1
5 files changed, 308 insertions, 1 deletions
diff --git a/src/process_file.ml b/src/process_file.ml
index 9ed52e8d..c3e1b510 100644
--- a/src/process_file.ml
+++ b/src/process_file.ml
@@ -398,7 +398,7 @@ let rewrite rewriters defs =
let rewrite_ast = rewrite [("initial", Rewriter.rewrite_defs)]
let rewrite_undefined bitvectors = rewrite [("undefined", fun x -> Rewrites.rewrite_undefined bitvectors x)]
let rewrite_ast_lem = rewrite Rewrites.rewrite_defs_lem
-let rewrite_ast_coq = rewrite Rewrites.rewrite_defs_lem
+let rewrite_ast_coq = rewrite Rewrites.rewrite_defs_coq
let rewrite_ast_ocaml = rewrite Rewrites.rewrite_defs_ocaml
let rewrite_ast_c ast =
ast
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 8fe30d6b..246a2670 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -4001,6 +4001,263 @@ let rewrite_defs_realise_mappings (Defs defs) =
in
Defs (List.map rewrite_def defs |> List.flatten)
+
+(* Rewrite to make all pattern matches in Coq output exhaustive.
+ Assumes that guards, vector patterns, etc have been rewritten already. *)
+
+let opt_coq_warn_nonexhaustive = ref false
+
+module MakeExhaustive =
+struct
+
+type rlit =
+ | RL_unit
+ | RL_zero
+ | RL_one
+ | RL_true
+ | RL_false
+ | RL_inf
+
+let string_of_rlit = function
+ | RL_unit -> "()"
+ | RL_zero -> "bitzero"
+ | RL_one -> "bitone"
+ | RL_true -> "true"
+ | RL_false -> "false"
+ | RL_inf -> "..."
+
+let rlit_of_lit (L_aux (l,_)) =
+ match l with
+ | L_unit -> RL_unit
+ | L_zero -> RL_zero
+ | L_one -> RL_one
+ | L_true -> RL_true
+ | L_false -> RL_false
+ | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> RL_inf
+ | L_undef -> assert false
+
+let inv_rlit_of_lit (L_aux (l,_)) =
+ match l with
+ | L_unit -> []
+ | L_zero -> [RL_one]
+ | L_one -> [RL_zero]
+ | L_true -> [RL_false]
+ | L_false -> [RL_true]
+ | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> [RL_inf]
+ | L_undef -> assert false
+
+type residual_pattern =
+ | RP_any
+ | RP_lit of rlit
+ | RP_enum of id
+ | RP_app of id * residual_pattern list
+ | RP_tup of residual_pattern list
+ | RP_nil
+ | RP_cons of residual_pattern * residual_pattern
+
+let rec string_of_rp = function
+ | RP_any -> "_"
+ | RP_lit rlit -> string_of_rlit rlit
+ | RP_enum id -> string_of_id id
+ | RP_app (f,args) -> string_of_id f ^ "(" ^ String.concat "," (List.map string_of_rp args) ^ ")"
+ | RP_tup rps -> "(" ^ String.concat "," (List.map string_of_rp rps) ^ ")"
+ | RP_nil -> "[| |]"
+ | RP_cons (rp1,rp2) -> string_of_rp rp1 ^ "::" ^ string_of_rp rp2
+
+type ctx = {
+ env : Env.t;
+ enum_to_rest: (residual_pattern list) Bindings.t;
+ constructor_to_rest: (residual_pattern list) Bindings.t
+}
+
+let make_enum_mappings ids m =
+ IdSet.fold (fun id m ->
+ Bindings.add id
+ (List.map (fun e -> RP_enum e) (IdSet.elements (IdSet.remove id ids))) m)
+ ids
+ m
+
+let make_cstr_mappings env ids m =
+ let ids = IdSet.elements ids in
+ let constructors = List.map
+ (fun id ->
+ let _,ty = Env.get_val_spec id env in
+ let args = match ty with
+ | Typ_aux (Typ_fn (Typ_aux (Typ_tup tys,_),_,_),_) -> List.map (fun _ -> RP_any) tys
+ | _ -> [RP_any]
+ in RP_app (id,args)) ids in
+ let rec aux ids acc l =
+ match ids, l with
+ | [], [] -> m
+ | id::ids, rp::t ->
+ let m = aux ids (acc@[rp]) t in
+ Bindings.add id (acc@t) m
+ | _ -> assert false
+ in aux ids [] constructors
+
+let ctx_from_pattern_completeness_ctx env =
+ let ctx = Env.pattern_completeness_ctx env in
+ { env = env;
+ enum_to_rest = Bindings.fold (fun _ ids m -> make_enum_mappings ids m)
+ ctx.Pattern_completeness.enums Bindings.empty;
+ constructor_to_rest = Bindings.fold (fun _ ids m -> make_cstr_mappings env ids m)
+ ctx.Pattern_completeness.variants Bindings.empty
+ }
+
+let printprefix = ref " "
+
+let rec remove_clause_from_pattern ctx (P_aux (rm_pat,ann)) res_pat =
+ let subpats rm_pats res_pats =
+ let res_pats' = List.map2 (remove_clause_from_pattern ctx) rm_pats res_pats in
+ let rec aux acc fixed residual =
+ match fixed, residual with
+ | [], [] -> []
+ | (fh::ft), (rh::rt) ->
+ let rt' = aux (acc@[fh]) ft rt in
+ let newr = List.map (fun x -> acc @ (x::ft)) rh in
+ newr @ rt'
+ | _,_ -> assert false (* impossible because we managed map2 above *)
+ in aux [] res_pats res_pats'
+ in
+ let inconsistent () =
+ raise (Reporting_basic.err_unreachable (fst ann)
+ ("Inconsistency during exhaustiveness analysis with " ^
+ string_of_rp res_pat))
+ in
+ (*let _ = print_endline (!printprefix ^ "pat: " ^string_of_pat (P_aux (rm_pat,ann))) in
+ let _ = print_endline (!printprefix ^ "res_pat: " ^string_of_rp res_pat) in
+ let _ = printprefix := " " ^ !printprefix in*)
+ let rp' =
+ match rm_pat with
+ | P_wild -> []
+ | P_id id when (match Env.lookup_id id ctx.env with Unbound | Local _ -> true | _ -> false) -> []
+ | P_lit lit ->
+ (match res_pat with
+ | RP_any -> List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit)
+ | RP_lit RL_inf -> [res_pat]
+ | RP_lit lit' -> if lit' = rlit_of_lit lit then [] else [res_pat]
+ | _ -> inconsistent ())
+ | P_as (p,_)
+ | P_typ (_,p)
+ | P_var (p,_)
+ -> remove_clause_from_pattern ctx p res_pat
+ | P_id id ->
+ (match Env.lookup_id id ctx.env with
+ | Enum enum ->
+ (match res_pat with
+ | RP_any -> Bindings.find id ctx.enum_to_rest
+ | RP_enum id' -> if Id.compare id id' == 0 then [] else [res_pat]
+ | _ -> inconsistent ())
+ | _ -> assert false)
+ | P_tup rm_pats ->
+ let previous_res_pats =
+ match res_pat with
+ | RP_tup res_pats -> res_pats
+ | RP_any -> List.map (fun _ -> RP_any) rm_pats
+ | _ -> inconsistent ()
+ in
+ let res_pats' = subpats rm_pats previous_res_pats in
+ List.map (fun rps -> RP_tup rps) res_pats'
+ | P_app (id,args) ->
+ (match res_pat with
+ | RP_app (id',residual_args) ->
+ if Id.compare id id' == 0 then
+ let res_pats' = subpats args residual_args in
+ List.map (fun rps -> RP_app (id,rps)) res_pats'
+ else [res_pat]
+ | RP_any ->
+ let res_args = subpats args (List.map (fun _ -> RP_any) args) in
+ (List.map (fun l -> (RP_app (id,l))) res_args) @
+ (Bindings.find id ctx.constructor_to_rest)
+ | _ -> inconsistent ()
+ )
+ | P_list ps ->
+ (match ps with
+ | p1::ptl -> remove_clause_from_pattern ctx (P_aux (P_cons (p1,P_aux (P_list ptl,ann)),ann)) res_pat
+ | [] ->
+ match res_pat with
+ | RP_any -> [RP_cons (RP_any,RP_any)]
+ | RP_cons _ -> [res_pat]
+ | RP_nil -> []
+ | _ -> inconsistent ())
+ | P_cons (p1,p2) -> begin
+ let rp',rps =
+ match res_pat with
+ | RP_cons (rp1,rp2) -> [], Some [rp1;rp2]
+ | RP_any -> [RP_nil], Some [RP_any;RP_any]
+ | RP_nil -> [RP_nil], None
+ | _ -> inconsistent ()
+ in
+ match rps with
+ | None -> rp'
+ | Some rps ->
+ let res_pats = subpats [p1;p2] rps in
+ rp' @ List.map (function [rp1;rp2] -> RP_cons (rp1,rp2) | _ -> assert false) res_pats
+ end
+ | P_record _ ->
+ raise (Reporting_basic.err_unreachable (fst ann)
+ "Record pattern not supported")
+ | P_vector _
+ | P_vector_concat _
+ | P_string_append _ ->
+ raise (Reporting_basic.err_unreachable (fst ann)
+ "Found pattern that should have been rewritten away in earlier stage")
+
+ (*in let _ = printprefix := String.sub (!printprefix) 0 (String.length !printprefix - 2)
+ in let _ = print_endline (!printprefix ^ "res_pats': " ^ String.concat "; " (List.map string_of_rp rp'))*)
+ in rp'
+
+let process_pexp env =
+ let ctx = ctx_from_pattern_completeness_ctx env in
+ fun rps patexp ->
+ (*let _ = print_endline ("res_pats: " ^ String.concat "; " (List.map string_of_rp rps)) in
+ let _ = print_endline ("pat: " ^ string_of_pexp patexp) in*)
+ match patexp with
+ | Pat_aux (Pat_exp (p,_),_) ->
+ List.concat (List.map (remove_clause_from_pattern ctx p) rps)
+ | Pat_aux (Pat_when _,(l,_)) ->
+ raise (Reporting_basic.err_unreachable l
+ "Guarded pattern should have been rewritten away")
+
+let rewrite_case (e,ann) =
+ match e with
+ | E_case (e1,cases) ->
+ begin
+ let env = env_of_annot ann in
+ let rps = List.fold_left (process_pexp env) [RP_any] cases in
+ match rps with
+ | [] -> E_aux (E_case (e1,cases),ann)
+ | (example::_) ->
+
+ let _ =
+ if !opt_coq_warn_nonexhaustive
+ then Reporting_basic.print_err false false
+ (fst ann) "Non-exhaustive matching" ("Example: " ^ string_of_rp example) in
+
+ let l = Parse_ast.Generated Parse_ast.Unknown in
+ let p = P_aux (P_wild, (l, None)) in
+ let ann' = Some (env, typ_of_annot ann, mk_effect [BE_escape]) in
+ (* TODO: use an expression that specifically indicates a failed pattern match *)
+ let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)),(l,None))),(l,ann')) in
+ E_aux (E_case (e1,cases@[Pat_aux (Pat_exp (p,b),(l,None))]),ann)
+ end
+ | _ -> E_aux (e,ann)
+
+let rewrite =
+ let alg = { id_exp_alg with e_aux = rewrite_case } in
+ rewrite_defs_base
+ { rewrite_exp = (fun _ -> fold_exp alg)
+ ; rewrite_pat = rewrite_pat
+ ; rewrite_let = rewrite_let
+ ; rewrite_lexp = rewrite_lexp
+ ; rewrite_fun = rewrite_fun
+ ; rewrite_def = rewrite_def
+ ; rewrite_defs = rewrite_defs_base
+ }
+
+
+end
+
let recheck_defs defs = fst (Type_error.check initial_env defs)
let remove_mapping_valspecs (Defs defs) =
@@ -4049,6 +4306,45 @@ let rewrite_defs_lem = [
("recheck_defs", recheck_defs)
]
+let rewrite_defs_coq = [
+ ("realise_mappings", rewrite_defs_realise_mappings);
+ ("remove_mapping_valspecs", remove_mapping_valspecs);
+ ("pat_string_append", rewrite_defs_pat_string_append);
+ ("mapping_builtins", rewrite_defs_mapping_patterns);
+ ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem);
+ ("vector_concat_assignments", rewrite_vector_concat_assignments);
+ ("tuple_assignments", rewrite_tuple_assignments);
+ ("simple_assignments", rewrite_simple_assignments);
+ ("remove_vector_concat", rewrite_defs_remove_vector_concat);
+ ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats);
+ ("remove_numeral_pats", rewrite_defs_remove_numeral_pats);
+ ("guarded_pats", rewrite_defs_guarded_pats);
+ ("bitvector_exps", rewrite_bitvector_exps);
+ (* ("register_ref_writes", rewrite_register_ref_writes); *)
+ ("nexp_ids", rewrite_defs_nexp_ids);
+ ("fix_val_specs", rewrite_fix_val_specs);
+ ("split_execute", rewrite_split_fun_constr_pats "execute");
+ ("recheck_defs", recheck_defs);
+ ("exp_lift_assign", rewrite_defs_exp_lift_assign);
+ (* ("constraint", rewrite_constraint); *)
+ (* ("remove_assert", rewrite_defs_remove_assert); *)
+ ("top_sort_defs", top_sort_defs);
+ ("trivial_sizeof", rewrite_trivial_sizeof);
+ ("sizeof", rewrite_sizeof);
+ ("early_return", rewrite_defs_early_return);
+ ("make_cases_exhaustive", MakeExhaustive.rewrite);
+ ("fix_val_specs", rewrite_fix_val_specs);
+ ("recheck_defs", recheck_defs);
+ ("remove_blocks", rewrite_defs_remove_blocks);
+ ("letbind_effects", rewrite_defs_letbind_effects);
+ ("remove_e_assign", rewrite_defs_remove_e_assign);
+ ("internal_lets", rewrite_defs_internal_lets);
+ ("remove_superfluous_letbinds", rewrite_defs_remove_superfluous_letbinds);
+ ("remove_superfluous_returns", rewrite_defs_remove_superfluous_returns);
+ ("merge function clauses", merge_funcls);
+ ("recheck_defs", recheck_defs)
+ ]
+
let rewrite_defs_ocaml = [
(* ("undefined", rewrite_undefined); *)
("no_effect_check", (fun defs -> opt_no_effects := true; defs));
diff --git a/src/rewrites.mli b/src/rewrites.mli
index 70cb75af..7d6bc0b2 100644
--- a/src/rewrites.mli
+++ b/src/rewrites.mli
@@ -66,6 +66,13 @@ val rewrite_defs_interpreter : (string * (tannot defs -> tannot defs)) list
(* Perform rewrites to exclude AST nodes not supported for lem out*)
val rewrite_defs_lem : (string * (tannot defs -> tannot defs)) list
+(* Perform rewrites to exclude AST nodes not supported for coq out*)
+val rewrite_defs_coq : (string * (tannot defs -> tannot defs)) list
+
+(* Warn about matches where we add a default case for Coq because they're not
+ exhaustive *)
+val opt_coq_warn_nonexhaustive : bool ref
+
(* Perform rewrites to exclude AST nodes not supported for C compilation *)
val rewrite_defs_c : (string * (tannot defs -> tannot defs)) list
diff --git a/src/sail.ml b/src/sail.ml
index 64e60c23..f0903146 100644
--- a/src/sail.ml
+++ b/src/sail.ml
@@ -151,6 +151,9 @@ let options = Arg.align ([
( "-dcoq_undef_axioms",
Arg.Set Pretty_print_coq.opt_undef_axioms,
"Generate axioms for functions that are declared but not defined");
+ ( "-dcoq_warn_nonex",
+ Arg.Set Rewrites.opt_coq_warn_nonexhaustive,
+ "Generate warnings for non-exhaustive pattern matches in the Coq backend");
( "-latex_prefix",
Arg.String (fun prefix -> Latex.opt_prefix_latex := prefix),
" set a custom prefix for generated latex command (default sail)");
diff --git a/src/type_check.mli b/src/type_check.mli
index 665981e9..31a5a8dd 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -197,6 +197,7 @@ module Env : sig
val empty : t
+ val pattern_completeness_ctx : t -> Pattern_completeness.ctx
end
(** Push all the type variables and constraints from a typquant into