diff options
| author | Alasdair Armstrong | 2018-07-24 18:09:18 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2018-07-24 18:09:18 +0100 |
| commit | 6b4f407ad34ca7d4d8a89a5a4d401ac80c7413b0 (patch) | |
| tree | ed09b22b7ea4ca20fbcc89b761f1955caea85041 /src/rewrites.ml | |
| parent | dafb09e7c26840dce3d522fef3cf359729ca5b61 (diff) | |
| parent | 8114501b7b956ee4a98fa8599c7efee62fc19206 (diff) | |
Merge remote-tracking branch 'origin/sail2' into c_fixes
Diffstat (limited to 'src/rewrites.ml')
| -rw-r--r-- | src/rewrites.ml | 352 |
1 files changed, 340 insertions, 12 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index 214ca571..246a2670 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1168,6 +1168,25 @@ let case_exp e t cs = (* let efr = union_effs (List.map effect_of_pexp ps) in *) fix_eff_exp (annot_exp (E_case (e,ps)) l env t) +(* Rewrite guarded patterns into a combination of if-expressions and + unguarded pattern matches + + Strategy: + - Split clauses into groups where the first pattern subsumes all the + following ones + - Translate the groups in reverse order, using the next group as a + fall-through target, if there is one + - Within a group, + - translate the sequence of clauses to an if-then-else cascade using the + guards as long as the patterns are equivalent modulo substitution, or + - recursively translate the remaining clauses to a pattern match if + there is a difference in the patterns. + + TODO: Compare this more closely with the algorithm in the CPP'18 paper of + Spector-Zabusky et al, who seem to use the opposite grouping and merging + strategy to ours: group *mutually exclusive* clauses, and try to merge them + into a pattern match first instead of an if-then-else cascade. +*) let rewrite_guarded_clauses l cs = let rec group fallthrough clauses = let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in @@ -2285,6 +2304,7 @@ let rewrite_type_def_typs rw_typ rw_typquant rw_typschm (TD_aux (td, annot)) = let rewrite_dec_spec_typs rw_typ (DEC_aux (ds, annot)) = match ds with | DEC_reg (typ, id) -> DEC_aux (DEC_reg (rw_typ typ, id), annot) + | DEC_config (id, typ, exp) -> DEC_aux (DEC_config (id, rw_typ typ, exp), annot) | _ -> assert false (* Remove overload definitions and cast val specs from the @@ -2606,7 +2626,7 @@ let rewrite_defs_letbind_effects = match lexp_aux with | LEXP_id _ -> k lexp | LEXP_deref exp -> - n_exp exp (fun exp -> + n_exp_name exp (fun exp -> k (fix_eff_lexp (LEXP_aux (LEXP_deref exp, annot)))) | LEXP_memory (id,es) -> n_exp_nameL es (fun es -> @@ -3305,8 +3325,15 @@ let rewrite_defs_mapping_patterns = in pexp_rewriters rewrite_pexp +let rewrite_lit_lem (L_aux (lit, _)) = match lit with + | L_num _ | L_string _ | L_hex _ | L_bin _ | L_real _ -> true + | _ -> false + +let rewrite_no_strings (L_aux (lit, _)) = match lit with + | L_string _ -> false + | _ -> true -let rewrite_defs_pat_lits = +let rewrite_defs_pat_lits rewrite_lit = let rewrite_pexp (Pat_aux (pexp_aux, annot) as pexp) = let guards = ref [] in let counter = ref 0 in @@ -3314,7 +3341,7 @@ let rewrite_defs_pat_lits = let rewrite_pat = function (* HACK: ignore strings for now *) | P_lit (L_aux (L_string _, _)) as p_aux, p_annot -> P_aux (p_aux, p_annot) - | P_lit lit, p_annot -> + | P_lit lit, p_annot when rewrite_lit lit -> let env = env_of_annot p_annot in let typ = typ_of_annot p_annot in let id = mk_id ("p" ^ string_of_int !counter ^ "#") in @@ -3513,11 +3540,12 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let eff = union_eff_exps [c;e1;e2] in let v = E_aux (E_if (c,e1,e2), (gen_loc el, Some (env, typ, eff))) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) - | E_case (e1,ps) -> - (* after rewrite_defs_letbind_effects e1 needs no rewriting *) + | E_case (e1,ps) | E_try (e1, ps) -> + let is_case = match expaux with E_case _ -> true | _ -> false in let vars, varpats = - ps - |> List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e) + (* for E_case, e1 needs no rewriting after rewrite_defs_letbind_effects *) + ((if is_case then [] else [e1]) @ + List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e) ps) |> List.map find_updated_vars |> List.fold_left IdSet.union IdSet.empty |> IdSet.inter used_vars @@ -3528,8 +3556,10 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = Pat_aux (Pat_exp (p,rewrite_var_updates e),a) | Pat_aux (Pat_when (p,g,e),a) -> Pat_aux (Pat_when (p,g,rewrite_var_updates e),a)) ps in - Same_vars (E_aux (E_case (e1,ps),annot)) + let expaux = if is_case then E_case (e1, ps) else E_try (e1, ps) in + Same_vars (E_aux (expaux, annot)) else + let e1 = if is_case then e1 else rewrite_var_updates (add_vars overwrite e1 vars) in let rewrite_pexp (Pat_aux (pexp, (l, _))) = match pexp with | Pat_exp (pat, exp) -> let exp = rewrite_var_updates (add_vars overwrite exp vars) in @@ -3538,10 +3568,12 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | Pat_when _ -> raise (Reporting_basic.err_unreachable l "Guarded patterns should have been rewritten already") in + let ps = List.map rewrite_pexp ps in + let expaux = if is_case then E_case (e1, ps) else E_try (e1, ps) in let typ = match ps with | Pat_aux ((Pat_exp (_,first)|Pat_when (_,_,first)),_) :: _ -> typ_of first | _ -> unit_typ in - let v = fix_eff_exp (annot_exp (E_case (e1, List.map rewrite_pexp ps)) pl env typ) in + let v = fix_eff_exp (annot_exp expaux pl env typ) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) | E_assign (lexp,vexp) -> let mk_id_pat id = match Env.lookup_id id env with @@ -3969,6 +4001,263 @@ let rewrite_defs_realise_mappings (Defs defs) = in Defs (List.map rewrite_def defs |> List.flatten) + +(* Rewrite to make all pattern matches in Coq output exhaustive. + Assumes that guards, vector patterns, etc have been rewritten already. *) + +let opt_coq_warn_nonexhaustive = ref false + +module MakeExhaustive = +struct + +type rlit = + | RL_unit + | RL_zero + | RL_one + | RL_true + | RL_false + | RL_inf + +let string_of_rlit = function + | RL_unit -> "()" + | RL_zero -> "bitzero" + | RL_one -> "bitone" + | RL_true -> "true" + | RL_false -> "false" + | RL_inf -> "..." + +let rlit_of_lit (L_aux (l,_)) = + match l with + | L_unit -> RL_unit + | L_zero -> RL_zero + | L_one -> RL_one + | L_true -> RL_true + | L_false -> RL_false + | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> RL_inf + | L_undef -> assert false + +let inv_rlit_of_lit (L_aux (l,_)) = + match l with + | L_unit -> [] + | L_zero -> [RL_one] + | L_one -> [RL_zero] + | L_true -> [RL_false] + | L_false -> [RL_true] + | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> [RL_inf] + | L_undef -> assert false + +type residual_pattern = + | RP_any + | RP_lit of rlit + | RP_enum of id + | RP_app of id * residual_pattern list + | RP_tup of residual_pattern list + | RP_nil + | RP_cons of residual_pattern * residual_pattern + +let rec string_of_rp = function + | RP_any -> "_" + | RP_lit rlit -> string_of_rlit rlit + | RP_enum id -> string_of_id id + | RP_app (f,args) -> string_of_id f ^ "(" ^ String.concat "," (List.map string_of_rp args) ^ ")" + | RP_tup rps -> "(" ^ String.concat "," (List.map string_of_rp rps) ^ ")" + | RP_nil -> "[| |]" + | RP_cons (rp1,rp2) -> string_of_rp rp1 ^ "::" ^ string_of_rp rp2 + +type ctx = { + env : Env.t; + enum_to_rest: (residual_pattern list) Bindings.t; + constructor_to_rest: (residual_pattern list) Bindings.t +} + +let make_enum_mappings ids m = + IdSet.fold (fun id m -> + Bindings.add id + (List.map (fun e -> RP_enum e) (IdSet.elements (IdSet.remove id ids))) m) + ids + m + +let make_cstr_mappings env ids m = + let ids = IdSet.elements ids in + let constructors = List.map + (fun id -> + let _,ty = Env.get_val_spec id env in + let args = match ty with + | Typ_aux (Typ_fn (Typ_aux (Typ_tup tys,_),_,_),_) -> List.map (fun _ -> RP_any) tys + | _ -> [RP_any] + in RP_app (id,args)) ids in + let rec aux ids acc l = + match ids, l with + | [], [] -> m + | id::ids, rp::t -> + let m = aux ids (acc@[rp]) t in + Bindings.add id (acc@t) m + | _ -> assert false + in aux ids [] constructors + +let ctx_from_pattern_completeness_ctx env = + let ctx = Env.pattern_completeness_ctx env in + { env = env; + enum_to_rest = Bindings.fold (fun _ ids m -> make_enum_mappings ids m) + ctx.Pattern_completeness.enums Bindings.empty; + constructor_to_rest = Bindings.fold (fun _ ids m -> make_cstr_mappings env ids m) + ctx.Pattern_completeness.variants Bindings.empty + } + +let printprefix = ref " " + +let rec remove_clause_from_pattern ctx (P_aux (rm_pat,ann)) res_pat = + let subpats rm_pats res_pats = + let res_pats' = List.map2 (remove_clause_from_pattern ctx) rm_pats res_pats in + let rec aux acc fixed residual = + match fixed, residual with + | [], [] -> [] + | (fh::ft), (rh::rt) -> + let rt' = aux (acc@[fh]) ft rt in + let newr = List.map (fun x -> acc @ (x::ft)) rh in + newr @ rt' + | _,_ -> assert false (* impossible because we managed map2 above *) + in aux [] res_pats res_pats' + in + let inconsistent () = + raise (Reporting_basic.err_unreachable (fst ann) + ("Inconsistency during exhaustiveness analysis with " ^ + string_of_rp res_pat)) + in + (*let _ = print_endline (!printprefix ^ "pat: " ^string_of_pat (P_aux (rm_pat,ann))) in + let _ = print_endline (!printprefix ^ "res_pat: " ^string_of_rp res_pat) in + let _ = printprefix := " " ^ !printprefix in*) + let rp' = + match rm_pat with + | P_wild -> [] + | P_id id when (match Env.lookup_id id ctx.env with Unbound | Local _ -> true | _ -> false) -> [] + | P_lit lit -> + (match res_pat with + | RP_any -> List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit) + | RP_lit RL_inf -> [res_pat] + | RP_lit lit' -> if lit' = rlit_of_lit lit then [] else [res_pat] + | _ -> inconsistent ()) + | P_as (p,_) + | P_typ (_,p) + | P_var (p,_) + -> remove_clause_from_pattern ctx p res_pat + | P_id id -> + (match Env.lookup_id id ctx.env with + | Enum enum -> + (match res_pat with + | RP_any -> Bindings.find id ctx.enum_to_rest + | RP_enum id' -> if Id.compare id id' == 0 then [] else [res_pat] + | _ -> inconsistent ()) + | _ -> assert false) + | P_tup rm_pats -> + let previous_res_pats = + match res_pat with + | RP_tup res_pats -> res_pats + | RP_any -> List.map (fun _ -> RP_any) rm_pats + | _ -> inconsistent () + in + let res_pats' = subpats rm_pats previous_res_pats in + List.map (fun rps -> RP_tup rps) res_pats' + | P_app (id,args) -> + (match res_pat with + | RP_app (id',residual_args) -> + if Id.compare id id' == 0 then + let res_pats' = subpats args residual_args in + List.map (fun rps -> RP_app (id,rps)) res_pats' + else [res_pat] + | RP_any -> + let res_args = subpats args (List.map (fun _ -> RP_any) args) in + (List.map (fun l -> (RP_app (id,l))) res_args) @ + (Bindings.find id ctx.constructor_to_rest) + | _ -> inconsistent () + ) + | P_list ps -> + (match ps with + | p1::ptl -> remove_clause_from_pattern ctx (P_aux (P_cons (p1,P_aux (P_list ptl,ann)),ann)) res_pat + | [] -> + match res_pat with + | RP_any -> [RP_cons (RP_any,RP_any)] + | RP_cons _ -> [res_pat] + | RP_nil -> [] + | _ -> inconsistent ()) + | P_cons (p1,p2) -> begin + let rp',rps = + match res_pat with + | RP_cons (rp1,rp2) -> [], Some [rp1;rp2] + | RP_any -> [RP_nil], Some [RP_any;RP_any] + | RP_nil -> [RP_nil], None + | _ -> inconsistent () + in + match rps with + | None -> rp' + | Some rps -> + let res_pats = subpats [p1;p2] rps in + rp' @ List.map (function [rp1;rp2] -> RP_cons (rp1,rp2) | _ -> assert false) res_pats + end + | P_record _ -> + raise (Reporting_basic.err_unreachable (fst ann) + "Record pattern not supported") + | P_vector _ + | P_vector_concat _ + | P_string_append _ -> + raise (Reporting_basic.err_unreachable (fst ann) + "Found pattern that should have been rewritten away in earlier stage") + + (*in let _ = printprefix := String.sub (!printprefix) 0 (String.length !printprefix - 2) + in let _ = print_endline (!printprefix ^ "res_pats': " ^ String.concat "; " (List.map string_of_rp rp'))*) + in rp' + +let process_pexp env = + let ctx = ctx_from_pattern_completeness_ctx env in + fun rps patexp -> + (*let _ = print_endline ("res_pats: " ^ String.concat "; " (List.map string_of_rp rps)) in + let _ = print_endline ("pat: " ^ string_of_pexp patexp) in*) + match patexp with + | Pat_aux (Pat_exp (p,_),_) -> + List.concat (List.map (remove_clause_from_pattern ctx p) rps) + | Pat_aux (Pat_when _,(l,_)) -> + raise (Reporting_basic.err_unreachable l + "Guarded pattern should have been rewritten away") + +let rewrite_case (e,ann) = + match e with + | E_case (e1,cases) -> + begin + let env = env_of_annot ann in + let rps = List.fold_left (process_pexp env) [RP_any] cases in + match rps with + | [] -> E_aux (E_case (e1,cases),ann) + | (example::_) -> + + let _ = + if !opt_coq_warn_nonexhaustive + then Reporting_basic.print_err false false + (fst ann) "Non-exhaustive matching" ("Example: " ^ string_of_rp example) in + + let l = Parse_ast.Generated Parse_ast.Unknown in + let p = P_aux (P_wild, (l, None)) in + let ann' = Some (env, typ_of_annot ann, mk_effect [BE_escape]) in + (* TODO: use an expression that specifically indicates a failed pattern match *) + let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)),(l,None))),(l,ann')) in + E_aux (E_case (e1,cases@[Pat_aux (Pat_exp (p,b),(l,None))]),ann) + end + | _ -> E_aux (e,ann) + +let rewrite = + let alg = { id_exp_alg with e_aux = rewrite_case } in + rewrite_defs_base + { rewrite_exp = (fun _ -> fold_exp alg) + ; rewrite_pat = rewrite_pat + ; rewrite_let = rewrite_let + ; rewrite_lexp = rewrite_lexp + ; rewrite_fun = rewrite_fun + ; rewrite_def = rewrite_def + ; rewrite_defs = rewrite_defs_base + } + + +end + let recheck_defs defs = fst (Type_error.check initial_env defs) let remove_mapping_valspecs (Defs defs) = @@ -3985,7 +4274,7 @@ let rewrite_defs_lem = [ ("remove_mapping_valspecs", remove_mapping_valspecs); ("pat_string_append", rewrite_defs_pat_string_append); ("mapping_builtins", rewrite_defs_mapping_patterns); - ("pat_lits", rewrite_defs_pat_lits); + ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem); ("vector_concat_assignments", rewrite_vector_concat_assignments); ("tuple_assignments", rewrite_tuple_assignments); ("simple_assignments", rewrite_simple_assignments); @@ -4017,13 +4306,52 @@ let rewrite_defs_lem = [ ("recheck_defs", recheck_defs) ] +let rewrite_defs_coq = [ + ("realise_mappings", rewrite_defs_realise_mappings); + ("remove_mapping_valspecs", remove_mapping_valspecs); + ("pat_string_append", rewrite_defs_pat_string_append); + ("mapping_builtins", rewrite_defs_mapping_patterns); + ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem); + ("vector_concat_assignments", rewrite_vector_concat_assignments); + ("tuple_assignments", rewrite_tuple_assignments); + ("simple_assignments", rewrite_simple_assignments); + ("remove_vector_concat", rewrite_defs_remove_vector_concat); + ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats); + ("remove_numeral_pats", rewrite_defs_remove_numeral_pats); + ("guarded_pats", rewrite_defs_guarded_pats); + ("bitvector_exps", rewrite_bitvector_exps); + (* ("register_ref_writes", rewrite_register_ref_writes); *) + ("nexp_ids", rewrite_defs_nexp_ids); + ("fix_val_specs", rewrite_fix_val_specs); + ("split_execute", rewrite_split_fun_constr_pats "execute"); + ("recheck_defs", recheck_defs); + ("exp_lift_assign", rewrite_defs_exp_lift_assign); + (* ("constraint", rewrite_constraint); *) + (* ("remove_assert", rewrite_defs_remove_assert); *) + ("top_sort_defs", top_sort_defs); + ("trivial_sizeof", rewrite_trivial_sizeof); + ("sizeof", rewrite_sizeof); + ("early_return", rewrite_defs_early_return); + ("make_cases_exhaustive", MakeExhaustive.rewrite); + ("fix_val_specs", rewrite_fix_val_specs); + ("recheck_defs", recheck_defs); + ("remove_blocks", rewrite_defs_remove_blocks); + ("letbind_effects", rewrite_defs_letbind_effects); + ("remove_e_assign", rewrite_defs_remove_e_assign); + ("internal_lets", rewrite_defs_internal_lets); + ("remove_superfluous_letbinds", rewrite_defs_remove_superfluous_letbinds); + ("remove_superfluous_returns", rewrite_defs_remove_superfluous_returns); + ("merge function clauses", merge_funcls); + ("recheck_defs", recheck_defs) + ] + let rewrite_defs_ocaml = [ (* ("undefined", rewrite_undefined); *) ("no_effect_check", (fun defs -> opt_no_effects := true; defs)); ("realise_mappings", rewrite_defs_realise_mappings); ("pat_string_append", rewrite_defs_pat_string_append); ("mapping_builtins", rewrite_defs_mapping_patterns); - ("pat_lits", rewrite_defs_pat_lits); + ("pat_lits", rewrite_defs_pat_lits rewrite_no_strings); ("vector_concat_assignments", rewrite_vector_concat_assignments); ("tuple_assignments", rewrite_tuple_assignments); ("simple_assignments", rewrite_simple_assignments); @@ -4045,7 +4373,7 @@ let rewrite_defs_c = [ ("realise_mappings", rewrite_defs_realise_mappings); ("pat_string_append", rewrite_defs_pat_string_append); ("mapping_builtins", rewrite_defs_mapping_patterns); - ("pat_lits", rewrite_defs_pat_lits); + ("pat_lits", rewrite_defs_pat_lits rewrite_no_strings); ("vector_concat_assignments", rewrite_vector_concat_assignments); ("tuple_assignments", rewrite_tuple_assignments); ("simple_assignments", rewrite_simple_assignments); |
