summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/monomorphise.ml132
1 files changed, 119 insertions, 13 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 9cc1c404..2e7ad08f 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -899,6 +899,16 @@ let rec freshen_pat_bindings p =
(* Use the location pairs in choices to reduce case expressions at the first
location to the given case at the second. *)
let apply_pat_choices choices =
+ let rewrite_constraint (NC_aux (nc,l) as nconstr) =
+ match List.assoc l choices with
+ | choice,_ -> begin
+ match nc with
+ | NC_set (kid,is) ->
+ E_constraint (NC_aux ((if choice < List.length is then NC_true else NC_false), Generated l))
+ | _ -> E_constraint nconstr
+ end
+ | exception Not_found -> E_constraint nconstr
+ in
let rewrite_case (e,cases) =
match List.assoc (exp_loc e) choices with
| choice,subst ->
@@ -917,7 +927,9 @@ let apply_pat_choices choices =
| exception Not_found -> E_case (e,cases)
in
let open Rewriter in
- fold_exp { id_exp_alg with e_case = rewrite_case }
+ fold_exp { id_exp_alg with
+ e_constraint = rewrite_constraint;
+ e_case = rewrite_case }
let split_defs continue_anyway splits defs =
let split_constructors (Defs defs) =
@@ -1484,11 +1496,19 @@ let split_defs continue_anyway splits defs =
| P_id id ->
(match id_match id with
| None -> None
+ (* Total case split *)
| Some None -> Some (split id l annot)
+ (* Where the analysis proposed a specific case split, propagate a
+ literal as normal, but perform a more careful transformation
+ otherwise *)
| Some (Some (pats,l)) ->
Some (List.mapi (fun i p ->
- let p',subst = freshen_pat_bindings p in
- P_aux (P_as (p',id),(l,annot)),[],[l,(i,subst)])
+ match p with
+ | P_aux (P_lit lit,_) when (match lit with L_aux (L_undef,_) -> false | _ -> true) ->
+ p,[id,E_aux (E_lit lit,(Generated Unknown,None))],[l,(i,[])]
+ | _ ->
+ let p',subst = freshen_pat_bindings p in
+ P_aux (P_as (p',id),(l,annot)),[],[l,(i,subst)])
pats)
)
| P_app (id,ps) ->
@@ -2013,6 +2033,37 @@ let is_id env id =
let ids = id :: List.map (fun (Id_aux (id,_)) -> id) ids in
fun (Id_aux (x,_)) -> List.mem x ids
+(* Type-agnostic pattern comparison for merging below *)
+
+let lit_eq' (L_aux (l1,_)) (L_aux (l2,_)) =
+ match l1, l2 with
+ | L_num n1, L_num n2 -> Big_int.equal n1 n2
+ | _,_ -> l1 = l2
+
+let forall2 p x y =
+ try List.for_all2 p x y with Invalid_argument _ -> false
+
+let rec pat_eq (P_aux (p1,_)) (P_aux (p2,_)) =
+ match p1, p2 with
+ | P_lit lit1, P_lit lit2 -> lit_eq' lit1 lit2
+ | P_wild, P_wild -> true
+ | P_as (p1',id1), P_as (p2',id2) -> Id.compare id1 id2 == 0 && pat_eq p1' p2'
+ | P_typ (_,p1'), P_typ (_,p2') -> pat_eq p1' p2'
+ | P_id id1, P_id id2 -> Id.compare id1 id2 == 0
+ | P_var (p1',kid1), P_var (p2',kid2) -> Kid.compare kid1 kid2 == 0 && pat_eq p1' p2'
+ | P_app (id1,args1), P_app (id2,args2) ->
+ Id.compare id1 id2 == 0 && forall2 pat_eq args1 args2
+ | P_record (fpats1, flag1), P_record (fpats2, flag2) ->
+ flag1 == flag2 && forall2 fpat_eq fpats1 fpats2
+ | P_vector ps1, P_vector ps2
+ | P_vector_concat ps1, P_vector_concat ps2
+ | P_tup ps1, P_tup ps2
+ | P_list ps1, P_list ps2 -> List.for_all2 pat_eq ps1 ps2
+ | P_cons (p1',p1''), P_cons (p2',p2'') -> pat_eq p1' p2' && pat_eq p1'' p2''
+ | _,_ -> false
+and fpat_eq (FP_aux (FP_Fpat (id1,p1),_)) (FP_aux (FP_Fpat (id2,p2),_)) =
+ Id.compare id1 id2 == 0 && pat_eq p1 p2
+
module Analysis =
@@ -2122,7 +2173,9 @@ let merge_detail _ x y =
match x,y with
| None, x -> x
| x, None -> x
- | Some _, Some _ -> Some Total (* TODO preserve equivalent patterns *)
+ | Some (Partial (ps1,l1)), Some (Partial (ps2,l2))
+ when l1 = l2 && forall2 pat_eq ps1 ps2 -> x
+ | _ -> Some Total
let dmerge x y =
match x,y with
@@ -2597,14 +2650,28 @@ let translate_id (Id_aux (_,l) as id) =
| _ -> None
in aux l
-let initial_env fn_id (TypQ_aux (tq,_)) pat =
+let initial_env fn_id (TypQ_aux (tq,_)) pat set_assertions =
let pats =
match pat with
| P_aux (P_tup pats,_) -> pats
| _ -> [pat]
in
+ let default_split annot =
+ let env = env_of_annot annot in
+ let Typ_aux (typ,_) = Env.base_typ_of env (typ_of_annot annot) in
+ match typ with
+ | Typ_app (Id_aux (Id "atom",_),[Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid,_)),_)]) ->
+ (match KBindings.find kid set_assertions with
+ | (l,is) ->
+ let l' = Generated l in
+ let pats = List.map (fun n -> P_aux (P_lit (L_aux (L_num n,l')),(l',None))) is in
+ let pats = pats @ [P_aux (P_wild,(l',None))] in
+ Partial (pats,l)
+ | exception Not_found -> Total)
+ | _ -> Total
+ in
let arg i pat =
- let rec aux (P_aux (p,(l,_))) =
+ let rec aux (P_aux (p,(l,annot))) =
let of_list pats =
let ss,vs,ks = split3 (List.map aux pats) in
let s = List.fold_left (ArgSplits.merge merge_detail) ArgSplits.empty ss in
@@ -2634,7 +2701,7 @@ let initial_env fn_id (TypQ_aux (tq,_)) pat =
begin
match translate_id id with
| Some id' ->
- let s = ArgSplits.singleton id' Total in
+ let s = ArgSplits.singleton id' (default_split (l,annot)) in
s,
Bindings.singleton id (Have (s,CallerArgSet.empty,CallerKidSet.empty)),
KBindings.empty
@@ -2671,6 +2738,46 @@ let initial_env fn_id (TypQ_aux (tq,_)) pat =
let kid_deps = List.fold_left dep_kbindings_merge kid_quant_deps kid_deps in
{ var_deps = var_deps; kid_deps = kid_deps }
+(* When there's more than one pick the first *)
+let merge_set_asserts _ x y =
+ match x, y with
+ | None, _ -> y
+ | _, _ -> x
+let merge_set_asserts_by_kid sets1 sets2 =
+ KBindings.merge merge_set_asserts sets1 sets2
+
+(* Find all the easily reached set assertions in a function body, to use as
+ case splits *)
+let rec find_set_assertions (E_aux (e,_)) =
+ match e with
+ | E_block es
+ | E_nondet es ->
+ List.fold_left merge_set_asserts_by_kid KBindings.empty (List.map find_set_assertions es)
+ | E_cast (_,e) -> find_set_assertions e
+ | E_let (LB_aux (LB_val (p,e1),_),e2) ->
+ let sets1 = find_set_assertions e1 in
+ let sets2 = find_set_assertions e2 in
+ let kbound = kids_bound_by_pat p in
+ let sets2 = KBindings.filter (fun kid _ -> not (KidSet.mem kid kbound)) sets2 in
+ merge_set_asserts_by_kid sets1 sets2
+ | E_assert (E_aux (e1,_),_) -> begin
+ match e1 with
+ | E_constraint (NC_aux (NC_set (kid,is),l)) -> KBindings.singleton kid (l,is)
+ | _ -> KBindings.empty
+ end
+ | _ -> KBindings.empty
+
+let print_set_assertions set_assertions =
+ if KBindings.is_empty set_assertions then
+ print_endline "No top-level set assertions found."
+ else begin
+ print_endline "Top-level set assertions found:";
+ KBindings.iter (fun k (l,is) ->
+ print_endline (string_of_kid k ^ " " ^
+ String.concat "," (List.map Big_int.to_string is)))
+ set_assertions
+ end
+
let print_result r =
let _ = print_endline (" splits: " ^ string_of_argsplits r.split) in
let print_kbinding kid dep =
@@ -2695,20 +2802,19 @@ let print_result r =
()
let analyse_funcl debug tenv (FCL_aux (FCL_Funcl (id,pexp),_)) =
+ let _ = if debug > 2 then print_endline (string_of_id id) else () in
let pat,guard,body,_ = destruct_pexp pexp in
let (tq,_) = Env.get_val_spec id tenv in
- let aenv = initial_env id tq pat in
+ let set_assertions = find_set_assertions body in
+ let _ = if debug > 2 then print_set_assertions set_assertions in
+ let aenv = initial_env id tq pat set_assertions in
let _,_,r = analyse_exp id aenv Bindings.empty body in
let r = match guard with
| None -> r
| Some exp -> let _,_,r' = analyse_exp id aenv Bindings.empty exp in
merge r r'
in
- let _ =
- if debug > 2 then
- (print_endline (string_of_id id);
- print_result r)
- else ()
+ let _ = if debug > 2 then print_result r else ()
in r
let analyse_def debug env = function