summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/rewrites.ml54
1 files changed, 54 insertions, 0 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 25d1467f..c274ded4 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -4081,6 +4081,56 @@ let rewrite_defs_remove_superfluous_letbinds =
; rewrite_defs = rewrite_defs_base
}
+(* FIXME: We shouldn't allow nested not-patterns *)
+let rewrite_defs_not_pats =
+ let rewrite_pexp (pexp_aux, annot) =
+ let rewrite_pexp' pat exp orig_guard =
+ let guards = ref [] in
+ let not_counter = ref 0 in
+ let rewrite_not_pat (pat_aux, annot) =
+ match pat_aux with
+ | P_not pat ->
+ incr not_counter;
+ let np_id = mk_id ("np#" ^ string_of_int !not_counter) in
+ let guard =
+ mk_exp (E_case (mk_exp (E_id np_id),
+ [mk_pexp (Pat_exp (strip_pat pat, mk_lit_exp L_false));
+ mk_pexp (Pat_exp (mk_pat P_wild, mk_lit_exp L_true))]))
+ in
+ guards := (np_id, typ_of_annot annot, guard) :: !guards;
+ P_aux (P_id np_id, annot)
+
+ | _ -> P_aux (pat_aux, annot)
+ in
+ let pat = fold_pat { id_pat_alg with p_aux = rewrite_not_pat } pat in
+ begin match !guards with
+ | [] ->
+ Pat_aux (pexp_aux, annot)
+ | guards ->
+ let guard_exp =
+ match orig_guard, guards with
+ | Some guard, _ ->
+ List.fold_left (fun exp1 (_, _, exp2) -> mk_exp (E_app_infix (exp1, mk_id "&", exp2))) guard guards
+ | None, (_, _, guard) :: guards ->
+ List.fold_left (fun exp1 (_, _, exp2) -> mk_exp (E_app_infix (exp1, mk_id "&", exp2))) guard guards
+ | _ -> raise (Reporting.err_unreachable (fst annot) __POS__ "Case in not-pattern re-writing should be unreachable")
+ in
+ (* We need to construct an environment to check the match guard in *)
+ let env = env_of_pat pat in
+ let env = List.fold_left (fun env (np_id, np_typ, _) -> Env.add_local np_id (Immutable, np_typ) env) env guards in
+ let guard_exp = Type_check.check_exp env guard_exp bool_typ in
+ Pat_aux (Pat_when (pat, guard_exp, exp), annot)
+ end
+ in
+ match pexp_aux with
+ | Pat_exp (pat, exp) ->
+ rewrite_pexp' pat exp None
+ | Pat_when (pat, guard, exp) ->
+ rewrite_pexp' pat exp (Some (strip_exp guard))
+ in
+ let rw_exp = { id_exp_alg with pat_aux = rewrite_pexp } in
+ rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rw_exp) }
+
let rewrite_defs_remove_superfluous_returns =
let add_opt_cast typopt1 typopt2 annot exp =
@@ -4872,6 +4922,7 @@ let rewrite_defs_lem = [
("recheck_defs", if_mono recheck_defs);
("rewrite_undefined", rewrite_undefined_if_gen false);
("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list);
+ ("remove_not_pats", rewrite_defs_not_pats);
("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem);
("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
@@ -4914,6 +4965,7 @@ let rewrite_defs_coq = [
("mapping_builtins", rewrite_defs_mapping_patterns);
("rewrite_undefined", rewrite_undefined_if_gen true);
("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list);
+ ("remove_not_pats", rewrite_defs_not_pats);
("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem);
("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
@@ -4962,6 +5014,7 @@ let rewrite_defs_ocaml = [
("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
("simple_assignments", rewrite_simple_assignments);
+ ("remove_not_pats", rewrite_defs_not_pats);
("remove_vector_concat", rewrite_defs_remove_vector_concat);
("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats);
("remove_numeral_pats", rewrite_defs_remove_numeral_pats);
@@ -4983,6 +5036,7 @@ let rewrite_defs_c = [
("mapping_builtins", rewrite_defs_mapping_patterns);
("rewrite_undefined", rewrite_undefined_if_gen false);
("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list);
+ ("remove_not_pats", rewrite_defs_not_pats);
("pat_lits", rewrite_defs_pat_lits (fun _ -> true));
("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);