diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/monomorphise.ml | 132 |
1 files changed, 119 insertions, 13 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 9cc1c404..2e7ad08f 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -899,6 +899,16 @@ let rec freshen_pat_bindings p = (* Use the location pairs in choices to reduce case expressions at the first location to the given case at the second. *) let apply_pat_choices choices = + let rewrite_constraint (NC_aux (nc,l) as nconstr) = + match List.assoc l choices with + | choice,_ -> begin + match nc with + | NC_set (kid,is) -> + E_constraint (NC_aux ((if choice < List.length is then NC_true else NC_false), Generated l)) + | _ -> E_constraint nconstr + end + | exception Not_found -> E_constraint nconstr + in let rewrite_case (e,cases) = match List.assoc (exp_loc e) choices with | choice,subst -> @@ -917,7 +927,9 @@ let apply_pat_choices choices = | exception Not_found -> E_case (e,cases) in let open Rewriter in - fold_exp { id_exp_alg with e_case = rewrite_case } + fold_exp { id_exp_alg with + e_constraint = rewrite_constraint; + e_case = rewrite_case } let split_defs continue_anyway splits defs = let split_constructors (Defs defs) = @@ -1484,11 +1496,19 @@ let split_defs continue_anyway splits defs = | P_id id -> (match id_match id with | None -> None + (* Total case split *) | Some None -> Some (split id l annot) + (* Where the analysis proposed a specific case split, propagate a + literal as normal, but perform a more careful transformation + otherwise *) | Some (Some (pats,l)) -> Some (List.mapi (fun i p -> - let p',subst = freshen_pat_bindings p in - P_aux (P_as (p',id),(l,annot)),[],[l,(i,subst)]) + match p with + | P_aux (P_lit lit,_) when (match lit with L_aux (L_undef,_) -> false | _ -> true) -> + p,[id,E_aux (E_lit lit,(Generated Unknown,None))],[l,(i,[])] + | _ -> + let p',subst = freshen_pat_bindings p in + P_aux (P_as (p',id),(l,annot)),[],[l,(i,subst)]) pats) ) | P_app (id,ps) -> @@ -2013,6 +2033,37 @@ let is_id env id = let ids = id :: List.map (fun (Id_aux (id,_)) -> id) ids in fun (Id_aux (x,_)) -> List.mem x ids +(* Type-agnostic pattern comparison for merging below *) + +let lit_eq' (L_aux (l1,_)) (L_aux (l2,_)) = + match l1, l2 with + | L_num n1, L_num n2 -> Big_int.equal n1 n2 + | _,_ -> l1 = l2 + +let forall2 p x y = + try List.for_all2 p x y with Invalid_argument _ -> false + +let rec pat_eq (P_aux (p1,_)) (P_aux (p2,_)) = + match p1, p2 with + | P_lit lit1, P_lit lit2 -> lit_eq' lit1 lit2 + | P_wild, P_wild -> true + | P_as (p1',id1), P_as (p2',id2) -> Id.compare id1 id2 == 0 && pat_eq p1' p2' + | P_typ (_,p1'), P_typ (_,p2') -> pat_eq p1' p2' + | P_id id1, P_id id2 -> Id.compare id1 id2 == 0 + | P_var (p1',kid1), P_var (p2',kid2) -> Kid.compare kid1 kid2 == 0 && pat_eq p1' p2' + | P_app (id1,args1), P_app (id2,args2) -> + Id.compare id1 id2 == 0 && forall2 pat_eq args1 args2 + | P_record (fpats1, flag1), P_record (fpats2, flag2) -> + flag1 == flag2 && forall2 fpat_eq fpats1 fpats2 + | P_vector ps1, P_vector ps2 + | P_vector_concat ps1, P_vector_concat ps2 + | P_tup ps1, P_tup ps2 + | P_list ps1, P_list ps2 -> List.for_all2 pat_eq ps1 ps2 + | P_cons (p1',p1''), P_cons (p2',p2'') -> pat_eq p1' p2' && pat_eq p1'' p2'' + | _,_ -> false +and fpat_eq (FP_aux (FP_Fpat (id1,p1),_)) (FP_aux (FP_Fpat (id2,p2),_)) = + Id.compare id1 id2 == 0 && pat_eq p1 p2 + module Analysis = @@ -2122,7 +2173,9 @@ let merge_detail _ x y = match x,y with | None, x -> x | x, None -> x - | Some _, Some _ -> Some Total (* TODO preserve equivalent patterns *) + | Some (Partial (ps1,l1)), Some (Partial (ps2,l2)) + when l1 = l2 && forall2 pat_eq ps1 ps2 -> x + | _ -> Some Total let dmerge x y = match x,y with @@ -2597,14 +2650,28 @@ let translate_id (Id_aux (_,l) as id) = | _ -> None in aux l -let initial_env fn_id (TypQ_aux (tq,_)) pat = +let initial_env fn_id (TypQ_aux (tq,_)) pat set_assertions = let pats = match pat with | P_aux (P_tup pats,_) -> pats | _ -> [pat] in + let default_split annot = + let env = env_of_annot annot in + let Typ_aux (typ,_) = Env.base_typ_of env (typ_of_annot annot) in + match typ with + | Typ_app (Id_aux (Id "atom",_),[Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid,_)),_)]) -> + (match KBindings.find kid set_assertions with + | (l,is) -> + let l' = Generated l in + let pats = List.map (fun n -> P_aux (P_lit (L_aux (L_num n,l')),(l',None))) is in + let pats = pats @ [P_aux (P_wild,(l',None))] in + Partial (pats,l) + | exception Not_found -> Total) + | _ -> Total + in let arg i pat = - let rec aux (P_aux (p,(l,_))) = + let rec aux (P_aux (p,(l,annot))) = let of_list pats = let ss,vs,ks = split3 (List.map aux pats) in let s = List.fold_left (ArgSplits.merge merge_detail) ArgSplits.empty ss in @@ -2634,7 +2701,7 @@ let initial_env fn_id (TypQ_aux (tq,_)) pat = begin match translate_id id with | Some id' -> - let s = ArgSplits.singleton id' Total in + let s = ArgSplits.singleton id' (default_split (l,annot)) in s, Bindings.singleton id (Have (s,CallerArgSet.empty,CallerKidSet.empty)), KBindings.empty @@ -2671,6 +2738,46 @@ let initial_env fn_id (TypQ_aux (tq,_)) pat = let kid_deps = List.fold_left dep_kbindings_merge kid_quant_deps kid_deps in { var_deps = var_deps; kid_deps = kid_deps } +(* When there's more than one pick the first *) +let merge_set_asserts _ x y = + match x, y with + | None, _ -> y + | _, _ -> x +let merge_set_asserts_by_kid sets1 sets2 = + KBindings.merge merge_set_asserts sets1 sets2 + +(* Find all the easily reached set assertions in a function body, to use as + case splits *) +let rec find_set_assertions (E_aux (e,_)) = + match e with + | E_block es + | E_nondet es -> + List.fold_left merge_set_asserts_by_kid KBindings.empty (List.map find_set_assertions es) + | E_cast (_,e) -> find_set_assertions e + | E_let (LB_aux (LB_val (p,e1),_),e2) -> + let sets1 = find_set_assertions e1 in + let sets2 = find_set_assertions e2 in + let kbound = kids_bound_by_pat p in + let sets2 = KBindings.filter (fun kid _ -> not (KidSet.mem kid kbound)) sets2 in + merge_set_asserts_by_kid sets1 sets2 + | E_assert (E_aux (e1,_),_) -> begin + match e1 with + | E_constraint (NC_aux (NC_set (kid,is),l)) -> KBindings.singleton kid (l,is) + | _ -> KBindings.empty + end + | _ -> KBindings.empty + +let print_set_assertions set_assertions = + if KBindings.is_empty set_assertions then + print_endline "No top-level set assertions found." + else begin + print_endline "Top-level set assertions found:"; + KBindings.iter (fun k (l,is) -> + print_endline (string_of_kid k ^ " " ^ + String.concat "," (List.map Big_int.to_string is))) + set_assertions + end + let print_result r = let _ = print_endline (" splits: " ^ string_of_argsplits r.split) in let print_kbinding kid dep = @@ -2695,20 +2802,19 @@ let print_result r = () let analyse_funcl debug tenv (FCL_aux (FCL_Funcl (id,pexp),_)) = + let _ = if debug > 2 then print_endline (string_of_id id) else () in let pat,guard,body,_ = destruct_pexp pexp in let (tq,_) = Env.get_val_spec id tenv in - let aenv = initial_env id tq pat in + let set_assertions = find_set_assertions body in + let _ = if debug > 2 then print_set_assertions set_assertions in + let aenv = initial_env id tq pat set_assertions in let _,_,r = analyse_exp id aenv Bindings.empty body in let r = match guard with | None -> r | Some exp -> let _,_,r' = analyse_exp id aenv Bindings.empty exp in merge r r' in - let _ = - if debug > 2 then - (print_endline (string_of_id id); - print_result r) - else () + let _ = if debug > 2 then print_result r else () in r let analyse_def debug env = function |
