summaryrefslogtreecommitdiff
path: root/src/rewrites.ml
diff options
context:
space:
mode:
authorAlasdair Armstrong2018-07-24 18:09:18 +0100
committerAlasdair Armstrong2018-07-24 18:09:18 +0100
commit6b4f407ad34ca7d4d8a89a5a4d401ac80c7413b0 (patch)
treeed09b22b7ea4ca20fbcc89b761f1955caea85041 /src/rewrites.ml
parentdafb09e7c26840dce3d522fef3cf359729ca5b61 (diff)
parent8114501b7b956ee4a98fa8599c7efee62fc19206 (diff)
Merge remote-tracking branch 'origin/sail2' into c_fixes
Diffstat (limited to 'src/rewrites.ml')
-rw-r--r--src/rewrites.ml352
1 files changed, 340 insertions, 12 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 214ca571..246a2670 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -1168,6 +1168,25 @@ let case_exp e t cs =
(* let efr = union_effs (List.map effect_of_pexp ps) in *)
fix_eff_exp (annot_exp (E_case (e,ps)) l env t)
+(* Rewrite guarded patterns into a combination of if-expressions and
+ unguarded pattern matches
+
+ Strategy:
+ - Split clauses into groups where the first pattern subsumes all the
+ following ones
+ - Translate the groups in reverse order, using the next group as a
+ fall-through target, if there is one
+ - Within a group,
+ - translate the sequence of clauses to an if-then-else cascade using the
+ guards as long as the patterns are equivalent modulo substitution, or
+ - recursively translate the remaining clauses to a pattern match if
+ there is a difference in the patterns.
+
+ TODO: Compare this more closely with the algorithm in the CPP'18 paper of
+ Spector-Zabusky et al, who seem to use the opposite grouping and merging
+ strategy to ours: group *mutually exclusive* clauses, and try to merge them
+ into a pattern match first instead of an if-then-else cascade.
+*)
let rewrite_guarded_clauses l cs =
let rec group fallthrough clauses =
let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in
@@ -2285,6 +2304,7 @@ let rewrite_type_def_typs rw_typ rw_typquant rw_typschm (TD_aux (td, annot)) =
let rewrite_dec_spec_typs rw_typ (DEC_aux (ds, annot)) =
match ds with
| DEC_reg (typ, id) -> DEC_aux (DEC_reg (rw_typ typ, id), annot)
+ | DEC_config (id, typ, exp) -> DEC_aux (DEC_config (id, rw_typ typ, exp), annot)
| _ -> assert false
(* Remove overload definitions and cast val specs from the
@@ -2606,7 +2626,7 @@ let rewrite_defs_letbind_effects =
match lexp_aux with
| LEXP_id _ -> k lexp
| LEXP_deref exp ->
- n_exp exp (fun exp ->
+ n_exp_name exp (fun exp ->
k (fix_eff_lexp (LEXP_aux (LEXP_deref exp, annot))))
| LEXP_memory (id,es) ->
n_exp_nameL es (fun es ->
@@ -3305,8 +3325,15 @@ let rewrite_defs_mapping_patterns =
in
pexp_rewriters rewrite_pexp
+let rewrite_lit_lem (L_aux (lit, _)) = match lit with
+ | L_num _ | L_string _ | L_hex _ | L_bin _ | L_real _ -> true
+ | _ -> false
+
+let rewrite_no_strings (L_aux (lit, _)) = match lit with
+ | L_string _ -> false
+ | _ -> true
-let rewrite_defs_pat_lits =
+let rewrite_defs_pat_lits rewrite_lit =
let rewrite_pexp (Pat_aux (pexp_aux, annot) as pexp) =
let guards = ref [] in
let counter = ref 0 in
@@ -3314,7 +3341,7 @@ let rewrite_defs_pat_lits =
let rewrite_pat = function
(* HACK: ignore strings for now *)
| P_lit (L_aux (L_string _, _)) as p_aux, p_annot -> P_aux (p_aux, p_annot)
- | P_lit lit, p_annot ->
+ | P_lit lit, p_annot when rewrite_lit lit ->
let env = env_of_annot p_annot in
let typ = typ_of_annot p_annot in
let id = mk_id ("p" ^ string_of_int !counter ^ "#") in
@@ -3513,11 +3540,12 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let eff = union_eff_exps [c;e1;e2] in
let v = E_aux (E_if (c,e1,e2), (gen_loc el, Some (env, typ, eff))) in
Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats))
- | E_case (e1,ps) ->
- (* after rewrite_defs_letbind_effects e1 needs no rewriting *)
+ | E_case (e1,ps) | E_try (e1, ps) ->
+ let is_case = match expaux with E_case _ -> true | _ -> false in
let vars, varpats =
- ps
- |> List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e)
+ (* for E_case, e1 needs no rewriting after rewrite_defs_letbind_effects *)
+ ((if is_case then [] else [e1]) @
+ List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e) ps)
|> List.map find_updated_vars
|> List.fold_left IdSet.union IdSet.empty
|> IdSet.inter used_vars
@@ -3528,8 +3556,10 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
Pat_aux (Pat_exp (p,rewrite_var_updates e),a)
| Pat_aux (Pat_when (p,g,e),a) ->
Pat_aux (Pat_when (p,g,rewrite_var_updates e),a)) ps in
- Same_vars (E_aux (E_case (e1,ps),annot))
+ let expaux = if is_case then E_case (e1, ps) else E_try (e1, ps) in
+ Same_vars (E_aux (expaux, annot))
else
+ let e1 = if is_case then e1 else rewrite_var_updates (add_vars overwrite e1 vars) in
let rewrite_pexp (Pat_aux (pexp, (l, _))) = match pexp with
| Pat_exp (pat, exp) ->
let exp = rewrite_var_updates (add_vars overwrite exp vars) in
@@ -3538,10 +3568,12 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
| Pat_when _ ->
raise (Reporting_basic.err_unreachable l
"Guarded patterns should have been rewritten already") in
+ let ps = List.map rewrite_pexp ps in
+ let expaux = if is_case then E_case (e1, ps) else E_try (e1, ps) in
let typ = match ps with
| Pat_aux ((Pat_exp (_,first)|Pat_when (_,_,first)),_) :: _ -> typ_of first
| _ -> unit_typ in
- let v = fix_eff_exp (annot_exp (E_case (e1, List.map rewrite_pexp ps)) pl env typ) in
+ let v = fix_eff_exp (annot_exp expaux pl env typ) in
Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats))
| E_assign (lexp,vexp) ->
let mk_id_pat id = match Env.lookup_id id env with
@@ -3969,6 +4001,263 @@ let rewrite_defs_realise_mappings (Defs defs) =
in
Defs (List.map rewrite_def defs |> List.flatten)
+
+(* Rewrite to make all pattern matches in Coq output exhaustive.
+ Assumes that guards, vector patterns, etc have been rewritten already. *)
+
+let opt_coq_warn_nonexhaustive = ref false
+
+module MakeExhaustive =
+struct
+
+type rlit =
+ | RL_unit
+ | RL_zero
+ | RL_one
+ | RL_true
+ | RL_false
+ | RL_inf
+
+let string_of_rlit = function
+ | RL_unit -> "()"
+ | RL_zero -> "bitzero"
+ | RL_one -> "bitone"
+ | RL_true -> "true"
+ | RL_false -> "false"
+ | RL_inf -> "..."
+
+let rlit_of_lit (L_aux (l,_)) =
+ match l with
+ | L_unit -> RL_unit
+ | L_zero -> RL_zero
+ | L_one -> RL_one
+ | L_true -> RL_true
+ | L_false -> RL_false
+ | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> RL_inf
+ | L_undef -> assert false
+
+let inv_rlit_of_lit (L_aux (l,_)) =
+ match l with
+ | L_unit -> []
+ | L_zero -> [RL_one]
+ | L_one -> [RL_zero]
+ | L_true -> [RL_false]
+ | L_false -> [RL_true]
+ | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> [RL_inf]
+ | L_undef -> assert false
+
+type residual_pattern =
+ | RP_any
+ | RP_lit of rlit
+ | RP_enum of id
+ | RP_app of id * residual_pattern list
+ | RP_tup of residual_pattern list
+ | RP_nil
+ | RP_cons of residual_pattern * residual_pattern
+
+let rec string_of_rp = function
+ | RP_any -> "_"
+ | RP_lit rlit -> string_of_rlit rlit
+ | RP_enum id -> string_of_id id
+ | RP_app (f,args) -> string_of_id f ^ "(" ^ String.concat "," (List.map string_of_rp args) ^ ")"
+ | RP_tup rps -> "(" ^ String.concat "," (List.map string_of_rp rps) ^ ")"
+ | RP_nil -> "[| |]"
+ | RP_cons (rp1,rp2) -> string_of_rp rp1 ^ "::" ^ string_of_rp rp2
+
+type ctx = {
+ env : Env.t;
+ enum_to_rest: (residual_pattern list) Bindings.t;
+ constructor_to_rest: (residual_pattern list) Bindings.t
+}
+
+let make_enum_mappings ids m =
+ IdSet.fold (fun id m ->
+ Bindings.add id
+ (List.map (fun e -> RP_enum e) (IdSet.elements (IdSet.remove id ids))) m)
+ ids
+ m
+
+let make_cstr_mappings env ids m =
+ let ids = IdSet.elements ids in
+ let constructors = List.map
+ (fun id ->
+ let _,ty = Env.get_val_spec id env in
+ let args = match ty with
+ | Typ_aux (Typ_fn (Typ_aux (Typ_tup tys,_),_,_),_) -> List.map (fun _ -> RP_any) tys
+ | _ -> [RP_any]
+ in RP_app (id,args)) ids in
+ let rec aux ids acc l =
+ match ids, l with
+ | [], [] -> m
+ | id::ids, rp::t ->
+ let m = aux ids (acc@[rp]) t in
+ Bindings.add id (acc@t) m
+ | _ -> assert false
+ in aux ids [] constructors
+
+let ctx_from_pattern_completeness_ctx env =
+ let ctx = Env.pattern_completeness_ctx env in
+ { env = env;
+ enum_to_rest = Bindings.fold (fun _ ids m -> make_enum_mappings ids m)
+ ctx.Pattern_completeness.enums Bindings.empty;
+ constructor_to_rest = Bindings.fold (fun _ ids m -> make_cstr_mappings env ids m)
+ ctx.Pattern_completeness.variants Bindings.empty
+ }
+
+let printprefix = ref " "
+
+let rec remove_clause_from_pattern ctx (P_aux (rm_pat,ann)) res_pat =
+ let subpats rm_pats res_pats =
+ let res_pats' = List.map2 (remove_clause_from_pattern ctx) rm_pats res_pats in
+ let rec aux acc fixed residual =
+ match fixed, residual with
+ | [], [] -> []
+ | (fh::ft), (rh::rt) ->
+ let rt' = aux (acc@[fh]) ft rt in
+ let newr = List.map (fun x -> acc @ (x::ft)) rh in
+ newr @ rt'
+ | _,_ -> assert false (* impossible because we managed map2 above *)
+ in aux [] res_pats res_pats'
+ in
+ let inconsistent () =
+ raise (Reporting_basic.err_unreachable (fst ann)
+ ("Inconsistency during exhaustiveness analysis with " ^
+ string_of_rp res_pat))
+ in
+ (*let _ = print_endline (!printprefix ^ "pat: " ^string_of_pat (P_aux (rm_pat,ann))) in
+ let _ = print_endline (!printprefix ^ "res_pat: " ^string_of_rp res_pat) in
+ let _ = printprefix := " " ^ !printprefix in*)
+ let rp' =
+ match rm_pat with
+ | P_wild -> []
+ | P_id id when (match Env.lookup_id id ctx.env with Unbound | Local _ -> true | _ -> false) -> []
+ | P_lit lit ->
+ (match res_pat with
+ | RP_any -> List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit)
+ | RP_lit RL_inf -> [res_pat]
+ | RP_lit lit' -> if lit' = rlit_of_lit lit then [] else [res_pat]
+ | _ -> inconsistent ())
+ | P_as (p,_)
+ | P_typ (_,p)
+ | P_var (p,_)
+ -> remove_clause_from_pattern ctx p res_pat
+ | P_id id ->
+ (match Env.lookup_id id ctx.env with
+ | Enum enum ->
+ (match res_pat with
+ | RP_any -> Bindings.find id ctx.enum_to_rest
+ | RP_enum id' -> if Id.compare id id' == 0 then [] else [res_pat]
+ | _ -> inconsistent ())
+ | _ -> assert false)
+ | P_tup rm_pats ->
+ let previous_res_pats =
+ match res_pat with
+ | RP_tup res_pats -> res_pats
+ | RP_any -> List.map (fun _ -> RP_any) rm_pats
+ | _ -> inconsistent ()
+ in
+ let res_pats' = subpats rm_pats previous_res_pats in
+ List.map (fun rps -> RP_tup rps) res_pats'
+ | P_app (id,args) ->
+ (match res_pat with
+ | RP_app (id',residual_args) ->
+ if Id.compare id id' == 0 then
+ let res_pats' = subpats args residual_args in
+ List.map (fun rps -> RP_app (id,rps)) res_pats'
+ else [res_pat]
+ | RP_any ->
+ let res_args = subpats args (List.map (fun _ -> RP_any) args) in
+ (List.map (fun l -> (RP_app (id,l))) res_args) @
+ (Bindings.find id ctx.constructor_to_rest)
+ | _ -> inconsistent ()
+ )
+ | P_list ps ->
+ (match ps with
+ | p1::ptl -> remove_clause_from_pattern ctx (P_aux (P_cons (p1,P_aux (P_list ptl,ann)),ann)) res_pat
+ | [] ->
+ match res_pat with
+ | RP_any -> [RP_cons (RP_any,RP_any)]
+ | RP_cons _ -> [res_pat]
+ | RP_nil -> []
+ | _ -> inconsistent ())
+ | P_cons (p1,p2) -> begin
+ let rp',rps =
+ match res_pat with
+ | RP_cons (rp1,rp2) -> [], Some [rp1;rp2]
+ | RP_any -> [RP_nil], Some [RP_any;RP_any]
+ | RP_nil -> [RP_nil], None
+ | _ -> inconsistent ()
+ in
+ match rps with
+ | None -> rp'
+ | Some rps ->
+ let res_pats = subpats [p1;p2] rps in
+ rp' @ List.map (function [rp1;rp2] -> RP_cons (rp1,rp2) | _ -> assert false) res_pats
+ end
+ | P_record _ ->
+ raise (Reporting_basic.err_unreachable (fst ann)
+ "Record pattern not supported")
+ | P_vector _
+ | P_vector_concat _
+ | P_string_append _ ->
+ raise (Reporting_basic.err_unreachable (fst ann)
+ "Found pattern that should have been rewritten away in earlier stage")
+
+ (*in let _ = printprefix := String.sub (!printprefix) 0 (String.length !printprefix - 2)
+ in let _ = print_endline (!printprefix ^ "res_pats': " ^ String.concat "; " (List.map string_of_rp rp'))*)
+ in rp'
+
+let process_pexp env =
+ let ctx = ctx_from_pattern_completeness_ctx env in
+ fun rps patexp ->
+ (*let _ = print_endline ("res_pats: " ^ String.concat "; " (List.map string_of_rp rps)) in
+ let _ = print_endline ("pat: " ^ string_of_pexp patexp) in*)
+ match patexp with
+ | Pat_aux (Pat_exp (p,_),_) ->
+ List.concat (List.map (remove_clause_from_pattern ctx p) rps)
+ | Pat_aux (Pat_when _,(l,_)) ->
+ raise (Reporting_basic.err_unreachable l
+ "Guarded pattern should have been rewritten away")
+
+let rewrite_case (e,ann) =
+ match e with
+ | E_case (e1,cases) ->
+ begin
+ let env = env_of_annot ann in
+ let rps = List.fold_left (process_pexp env) [RP_any] cases in
+ match rps with
+ | [] -> E_aux (E_case (e1,cases),ann)
+ | (example::_) ->
+
+ let _ =
+ if !opt_coq_warn_nonexhaustive
+ then Reporting_basic.print_err false false
+ (fst ann) "Non-exhaustive matching" ("Example: " ^ string_of_rp example) in
+
+ let l = Parse_ast.Generated Parse_ast.Unknown in
+ let p = P_aux (P_wild, (l, None)) in
+ let ann' = Some (env, typ_of_annot ann, mk_effect [BE_escape]) in
+ (* TODO: use an expression that specifically indicates a failed pattern match *)
+ let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)),(l,None))),(l,ann')) in
+ E_aux (E_case (e1,cases@[Pat_aux (Pat_exp (p,b),(l,None))]),ann)
+ end
+ | _ -> E_aux (e,ann)
+
+let rewrite =
+ let alg = { id_exp_alg with e_aux = rewrite_case } in
+ rewrite_defs_base
+ { rewrite_exp = (fun _ -> fold_exp alg)
+ ; rewrite_pat = rewrite_pat
+ ; rewrite_let = rewrite_let
+ ; rewrite_lexp = rewrite_lexp
+ ; rewrite_fun = rewrite_fun
+ ; rewrite_def = rewrite_def
+ ; rewrite_defs = rewrite_defs_base
+ }
+
+
+end
+
let recheck_defs defs = fst (Type_error.check initial_env defs)
let remove_mapping_valspecs (Defs defs) =
@@ -3985,7 +4274,7 @@ let rewrite_defs_lem = [
("remove_mapping_valspecs", remove_mapping_valspecs);
("pat_string_append", rewrite_defs_pat_string_append);
("mapping_builtins", rewrite_defs_mapping_patterns);
- ("pat_lits", rewrite_defs_pat_lits);
+ ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem);
("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
("simple_assignments", rewrite_simple_assignments);
@@ -4017,13 +4306,52 @@ let rewrite_defs_lem = [
("recheck_defs", recheck_defs)
]
+let rewrite_defs_coq = [
+ ("realise_mappings", rewrite_defs_realise_mappings);
+ ("remove_mapping_valspecs", remove_mapping_valspecs);
+ ("pat_string_append", rewrite_defs_pat_string_append);
+ ("mapping_builtins", rewrite_defs_mapping_patterns);
+ ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem);
+ ("vector_concat_assignments", rewrite_vector_concat_assignments);
+ ("tuple_assignments", rewrite_tuple_assignments);
+ ("simple_assignments", rewrite_simple_assignments);
+ ("remove_vector_concat", rewrite_defs_remove_vector_concat);
+ ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats);
+ ("remove_numeral_pats", rewrite_defs_remove_numeral_pats);
+ ("guarded_pats", rewrite_defs_guarded_pats);
+ ("bitvector_exps", rewrite_bitvector_exps);
+ (* ("register_ref_writes", rewrite_register_ref_writes); *)
+ ("nexp_ids", rewrite_defs_nexp_ids);
+ ("fix_val_specs", rewrite_fix_val_specs);
+ ("split_execute", rewrite_split_fun_constr_pats "execute");
+ ("recheck_defs", recheck_defs);
+ ("exp_lift_assign", rewrite_defs_exp_lift_assign);
+ (* ("constraint", rewrite_constraint); *)
+ (* ("remove_assert", rewrite_defs_remove_assert); *)
+ ("top_sort_defs", top_sort_defs);
+ ("trivial_sizeof", rewrite_trivial_sizeof);
+ ("sizeof", rewrite_sizeof);
+ ("early_return", rewrite_defs_early_return);
+ ("make_cases_exhaustive", MakeExhaustive.rewrite);
+ ("fix_val_specs", rewrite_fix_val_specs);
+ ("recheck_defs", recheck_defs);
+ ("remove_blocks", rewrite_defs_remove_blocks);
+ ("letbind_effects", rewrite_defs_letbind_effects);
+ ("remove_e_assign", rewrite_defs_remove_e_assign);
+ ("internal_lets", rewrite_defs_internal_lets);
+ ("remove_superfluous_letbinds", rewrite_defs_remove_superfluous_letbinds);
+ ("remove_superfluous_returns", rewrite_defs_remove_superfluous_returns);
+ ("merge function clauses", merge_funcls);
+ ("recheck_defs", recheck_defs)
+ ]
+
let rewrite_defs_ocaml = [
(* ("undefined", rewrite_undefined); *)
("no_effect_check", (fun defs -> opt_no_effects := true; defs));
("realise_mappings", rewrite_defs_realise_mappings);
("pat_string_append", rewrite_defs_pat_string_append);
("mapping_builtins", rewrite_defs_mapping_patterns);
- ("pat_lits", rewrite_defs_pat_lits);
+ ("pat_lits", rewrite_defs_pat_lits rewrite_no_strings);
("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
("simple_assignments", rewrite_simple_assignments);
@@ -4045,7 +4373,7 @@ let rewrite_defs_c = [
("realise_mappings", rewrite_defs_realise_mappings);
("pat_string_append", rewrite_defs_pat_string_append);
("mapping_builtins", rewrite_defs_mapping_patterns);
- ("pat_lits", rewrite_defs_pat_lits);
+ ("pat_lits", rewrite_defs_pat_lits rewrite_no_strings);
("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
("simple_assignments", rewrite_simple_assignments);