diff options
| author | Alasdair Armstrong | 2018-01-25 19:08:55 +0000 |
|---|---|---|
| committer | Alasdair Armstrong | 2018-01-25 19:08:55 +0000 |
| commit | b2d580f7154f2e0d55ac710663bde16fd074720c (patch) | |
| tree | 93f0151ff5b655e8ff11639cda7166f81018707f /src/monomorphise.ml | |
| parent | b7e388f0193a89608687760f50e476c059f0f49c (diff) | |
| parent | 98493e9de3e591d565d6d8c4f081f3dcb5346125 (diff) | |
Merge branch 'sail2' of https://bitbucket.org/Peter_Sewell/sail into sail2
Diffstat (limited to 'src/monomorphise.ml')
| -rw-r--r-- | src/monomorphise.ml | 456 |
1 files changed, 378 insertions, 78 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 9098f03d..2e7ad08f 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -790,11 +790,14 @@ let construct_lit_vector args = in aux [] args (* We may need to split up a pattern match if (1) we've been told to case split - on a variable by the user, or (2) we monomorphised a constructor that's used + on a variable by the user or analysis, or (2) we monomorphised a constructor that's used in the pattern. *) type split = | NoSplit - | VarSplit of (tannot pat * (id * tannot Ast.exp) list) list + | VarSplit of (tannot pat * (* pattern for this case *) + (id * tannot Ast.exp) list * (* substitutions for arguments *) + (Parse_ast.l * (int * (id * tannot exp) list)) list) (* optional locations of case expressions to reduce *) + list | ConstrSplit of (tannot pat * nexp KBindings.t) list let threaded_map f state l = @@ -846,6 +849,88 @@ let keep_undef_typ value = E_aux (E_cast (typ_of_annot eann,value),(Generated Unknown,snd eann)) | _ -> value +let freshen_id = + let counter = ref 0 in + fun id -> + let n = !counter in + let () = counter := n + 1 in + match id with + | Id_aux (Id x, l) -> Id_aux (Id (x ^ "#m" ^ string_of_int n),Generated l) + | Id_aux (DeIid x, l) -> Id_aux (DeIid (x ^ "#m" ^ string_of_int n),Generated l) + +(* TODO: only freshen bindings that might be shadowed *) +let rec freshen_pat_bindings p = + let rec aux (P_aux (p,(l,annot)) as pat) = + let mkp p = P_aux (p,(Generated l, annot)) in + match p with + | P_lit _ + | P_wild -> pat, [] + | P_as (p,_) -> aux p + | P_typ (typ,p) -> let p',vs = aux p in mkp (P_typ (typ,p')),vs + | P_id id -> let id' = freshen_id id in mkp (P_id id'),[id,E_aux (E_id id',(Generated Unknown,None))] + | P_var (p,_) -> aux p + | P_app (id,args) -> + let args',vs = List.split (List.map aux args) in + mkp (P_app (id,args')),List.concat vs + | P_record (fpats,flag) -> + let fpats,vs = List.split (List.map auxr fpats) in + mkp (P_record (fpats,flag)),List.concat vs + | P_vector ps -> + let ps,vs = List.split (List.map aux ps) in + mkp (P_vector ps),List.concat vs + | P_vector_concat ps -> + let ps,vs = List.split (List.map aux ps) in + mkp (P_vector_concat ps),List.concat vs + | P_tup ps -> + let ps,vs = List.split (List.map aux ps) in + mkp (P_tup ps),List.concat vs + | P_list ps -> + let ps,vs = List.split (List.map aux ps) in + mkp (P_list ps),List.concat vs + | P_cons (p1,p2) -> + let p1,vs1 = aux p1 in + let p2,vs2 = aux p2 in + mkp (P_cons (p1, p2)), vs1@vs2 + and auxr (FP_aux (FP_Fpat (id,p),(l,annot))) = + let p,vs = aux p in + FP_aux (FP_Fpat (id, p),(Generated l,annot)), vs + in aux p + +(* Use the location pairs in choices to reduce case expressions at the first + location to the given case at the second. *) +let apply_pat_choices choices = + let rewrite_constraint (NC_aux (nc,l) as nconstr) = + match List.assoc l choices with + | choice,_ -> begin + match nc with + | NC_set (kid,is) -> + E_constraint (NC_aux ((if choice < List.length is then NC_true else NC_false), Generated l)) + | _ -> E_constraint nconstr + end + | exception Not_found -> E_constraint nconstr + in + let rewrite_case (e,cases) = + match List.assoc (exp_loc e) choices with + | choice,subst -> + (match List.nth cases choice with + | Pat_aux (Pat_exp (p,E_aux (e,_)),_) -> + let dummyannot = (Generated Unknown,None) in + (* TODO: use a proper substitution *) + List.fold_left (fun e (id,e') -> + E_let (LB_aux (LB_val (P_aux (P_id id, dummyannot),e'),dummyannot),E_aux (e,dummyannot))) e subst + | Pat_aux (Pat_when _,(l,_)) -> + raise (Reporting_basic.err_unreachable l + "Pattern acquired a guard after analysis!") + | exception Not_found -> + raise (Reporting_basic.err_unreachable (exp_loc e) + "Unable to find case I found earlier!")) + | exception Not_found -> E_case (e,cases) + in + let open Rewriter in + fold_exp { id_exp_alg with + e_constraint = rewrite_constraint; + e_case = rewrite_case } + let split_defs continue_anyway splits defs = let split_constructors (Defs defs) = let sc_type_union q (Tu_aux (tu,l) as tua) = @@ -1289,26 +1374,26 @@ let split_defs continue_anyway splits defs = Err_general (pat_l, ("Cannot split type " ^ string_of_typ typ ^ " for variable " ^ v ^ ": " ^ msg)) in if continue_anyway - then (print_error error; [P_aux (P_id var,(pat_l,annot)),[]]) + then (print_error error; [P_aux (P_id var,(pat_l,annot)),[],[]]) else raise (Fatal_error error) in match ty with | Typ_id (Id_aux (Id "bool",_)) -> - [P_aux (P_lit (L_aux (L_true,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_true,new_l)),(new_l,annot))]; - P_aux (P_lit (L_aux (L_false,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_false,new_l)),(new_l,annot))]] + [P_aux (P_lit (L_aux (L_true,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_true,new_l)),(new_l,annot))],[]; + P_aux (P_lit (L_aux (L_false,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_false,new_l)),(new_l,annot))],[]] | Typ_id id -> (try (* enumerations *) let ns = Env.get_enum id env in List.map (fun n -> (P_aux (P_id (renew_id n),(l,annot)), - [var,E_aux (E_id (renew_id n),(new_l,annot))])) ns + [var,E_aux (E_id (renew_id n),(new_l,annot))],[])) ns with Type_error _ -> match id with | Id_aux (Id "bit",_) -> List.map (fun b -> P_aux (P_lit (L_aux (b,new_l)),(l,annot)), - [var,E_aux (E_lit (L_aux (b,new_l)),(new_l, annot))]) + [var,E_aux (E_lit (L_aux (b,new_l)),(new_l, annot))],[]) [L_zero; L_one] | _ -> cannot ("don't know about type " ^ string_of_id id)) @@ -1318,7 +1403,7 @@ let split_defs continue_anyway splits defs = let lits = make_vectors (Big_int.to_int sz) in List.map (fun lit -> P_aux (P_lit lit,(l,annot)), - [var,E_aux (E_lit lit,(new_l,annot))]) lits + [var,E_aux (E_lit lit,(new_l,annot))],[]) lits | _ -> cannot ("length not constant, " ^ string_of_nexp len) ) @@ -1328,7 +1413,7 @@ let split_defs continue_anyway splits defs = let mk_lit i = let lit = L_aux (L_num i,new_l) in P_aux (P_lit lit,(l,annot)), - [var,E_aux (E_lit lit,(new_l,annot))] + [var,E_aux (E_lit lit,(new_l,annot))],[] in match value with | Nexp_constant i -> [mk_lit i] @@ -1353,16 +1438,16 @@ let split_defs continue_anyway splits defs = | Generated l -> [] (* Could do match_l l, but only want to split user-written patterns *) | Range (p,q) -> let matches = - List.filter (fun ((filename,line),_) -> + List.filter (fun ((filename,line),_,_) -> p.Lexing.pos_fname = filename && p.Lexing.pos_lnum <= line && line <= q.Lexing.pos_lnum) ls - in List.map snd matches + in List.map (fun (_,var,optpats) -> (var,optpats)) matches in let split_pat vars p = - let id_matches = function - | Id_aux (Id x,_) -> List.mem x vars - | Id_aux (DeIid x,_) -> List.mem x vars + let id_match = function + | Id_aux (Id x,_) -> (try Some (List.assoc x vars) with Not_found -> None) + | Id_aux (DeIid x,_) -> (try Some (List.assoc x vars) with Not_found -> None) in let rec list f = function @@ -1370,45 +1455,62 @@ let split_defs continue_anyway splits defs = | h::t -> let t' = match list f t with - | None -> [t,[]] + | None -> [t,[],[]] | Some t' -> t' in let h' = match f h with - | None -> [h,[]] + | None -> [h,[],[]] | Some ps -> ps in - Some (List.concat (List.map (fun (h,hsubs) -> List.map (fun (t,tsubs) -> (h::t,hsubs@tsubs)) t') h')) + Some (List.concat + (List.map (fun (h,hsubs,hpchoices) -> + List.map (fun (t,tsubs,tpchoices) -> + (h::t, hsubs@tsubs, hpchoices@tpchoices)) t') h')) in let rec spl (P_aux (p,(l,annot))) = let relist f ctx ps = optmap (list f ps) (fun ps -> - List.map (fun (ps,sub) -> P_aux (ctx ps,(l,annot)),sub) ps) + List.map (fun (ps,sub,pchoices) -> P_aux (ctx ps,(l,annot)),sub,pchoices) ps) in let re f p = optmap (spl p) - (fun ps -> List.map (fun (p,sub) -> (P_aux (f p,(l,annot)), sub)) ps) + (fun ps -> List.map (fun (p,sub,pchoices) -> (P_aux (f p,(l,annot)), sub, pchoices)) ps) in let fpat (FP_aux ((FP_Fpat (id,p),annot))) = optmap (spl p) - (fun ps -> List.map (fun (p,sub) -> FP_aux (FP_Fpat (id,p), annot), sub) ps) + (fun ps -> List.map (fun (p,sub,pchoices) -> FP_aux (FP_Fpat (id,p), annot), sub, pchoices) ps) in match p with | P_lit _ | P_wild | P_var _ -> None - | P_as (p',id) when id_matches id -> + | P_as (p',id) when id_match id <> None -> raise (Reporting_basic.err_general l ("Cannot split " ^ string_of_id id ^ " on 'as' pattern")) | P_as (p',id) -> re (fun p -> P_as (p,id)) p' | P_typ (t,p') -> re (fun p -> P_typ (t,p)) p' - | P_id id when id_matches id -> - Some (split id l annot) - | P_id _ -> - None + | P_id id -> + (match id_match id with + | None -> None + (* Total case split *) + | Some None -> Some (split id l annot) + (* Where the analysis proposed a specific case split, propagate a + literal as normal, but perform a more careful transformation + otherwise *) + | Some (Some (pats,l)) -> + Some (List.mapi (fun i p -> + match p with + | P_aux (P_lit lit,_) when (match lit with L_aux (L_undef,_) -> false | _ -> true) -> + p,[id,E_aux (E_lit lit,(Generated Unknown,None))],[l,(i,[])] + | _ -> + let p',subst = freshen_pat_bindings p in + P_aux (P_as (p',id),(l,annot)),[],[l,(i,subst)]) + pats) + ) | P_app (id,ps) -> relist spl (fun ps -> P_app (id,ps)) ps | P_record (fps,flag) -> @@ -1425,10 +1527,10 @@ let split_defs continue_anyway splits defs = match spl p1, spl p2 with | None, None -> None | p1', p2' -> - let p1' = match p1' with None -> [p1,[]] | Some p1' -> p1' in - let p2' = match p2' with None -> [p2,[]] | Some p2' -> p2' in - let ps = List.map (fun (p1',subs1) -> List.map (fun (p2',subs2) -> - P_aux (P_cons (p1',p2'),(l,annot)),subs1@subs2) p2') p1' in + let p1' = match p1' with None -> [p1,[],[]] | Some p1' -> p1' in + let p2' = match p2' with None -> [p2,[],[]] | Some p2' -> p2' in + let ps = List.map (fun (p1',subs1,pchoices1) -> List.map (fun (p2',subs2,pchoices2) -> + P_aux (P_cons (p1',p2'),(l,annot)),subs1@subs2,pchoices1@pchoices2) p2') p1' in Some (List.concat ps) in spl p in @@ -1481,7 +1583,7 @@ let split_defs continue_anyway splits defs = | lvs -> let pvs = bindings_from_pat p in let pvs = List.map string_of_id pvs in - let overlap = List.exists (fun v -> List.mem v pvs) lvs in + let overlap = List.exists (fun (v,_) -> List.mem v pvs) lvs in let () = if overlap then Reporting_basic.print_err false true l "Monomorphisation" @@ -1569,8 +1671,9 @@ let split_defs continue_anyway splits defs = | NoSplit -> nosplit | VarSplit patsubsts -> if check_split_size patsubsts (pat_loc p) then - List.map (fun (pat',substs) -> + List.map (fun (pat',substs,pchoices) -> let exp' = subst_exp substs e in + let exp' = apply_pat_choices pchoices exp' in Pat_aux (Pat_exp (pat', map_exp exp'),l)) patsubsts else nosplit @@ -1586,9 +1689,11 @@ let split_defs continue_anyway splits defs = | NoSplit -> nosplit | VarSplit patsubsts -> if check_split_size patsubsts (pat_loc p) then - List.map (fun (pat',substs) -> + List.map (fun (pat',substs,pchoices) -> let exp1' = subst_exp substs e1 in + let exp1' = apply_pat_choices pchoices exp1' in let exp2' = subst_exp substs e2 in + let exp2' = apply_pat_choices pchoices exp2' in Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)) patsubsts else nosplit @@ -1923,6 +2028,44 @@ let rewrite_size_parameters env (Defs defs) = end +let is_id env id = + let ids = Env.get_overloads (Id_aux (id,Parse_ast.Unknown)) env in + let ids = id :: List.map (fun (Id_aux (id,_)) -> id) ids in + fun (Id_aux (x,_)) -> List.mem x ids + +(* Type-agnostic pattern comparison for merging below *) + +let lit_eq' (L_aux (l1,_)) (L_aux (l2,_)) = + match l1, l2 with + | L_num n1, L_num n2 -> Big_int.equal n1 n2 + | _,_ -> l1 = l2 + +let forall2 p x y = + try List.for_all2 p x y with Invalid_argument _ -> false + +let rec pat_eq (P_aux (p1,_)) (P_aux (p2,_)) = + match p1, p2 with + | P_lit lit1, P_lit lit2 -> lit_eq' lit1 lit2 + | P_wild, P_wild -> true + | P_as (p1',id1), P_as (p2',id2) -> Id.compare id1 id2 == 0 && pat_eq p1' p2' + | P_typ (_,p1'), P_typ (_,p2') -> pat_eq p1' p2' + | P_id id1, P_id id2 -> Id.compare id1 id2 == 0 + | P_var (p1',kid1), P_var (p2',kid2) -> Kid.compare kid1 kid2 == 0 && pat_eq p1' p2' + | P_app (id1,args1), P_app (id2,args2) -> + Id.compare id1 id2 == 0 && forall2 pat_eq args1 args2 + | P_record (fpats1, flag1), P_record (fpats2, flag2) -> + flag1 == flag2 && forall2 fpat_eq fpats1 fpats2 + | P_vector ps1, P_vector ps2 + | P_vector_concat ps1, P_vector_concat ps2 + | P_tup ps1, P_tup ps2 + | P_list ps1, P_list ps2 -> List.for_all2 pat_eq ps1 ps2 + | P_cons (p1',p1''), P_cons (p2',p2'') -> pat_eq p1' p2' && pat_eq p1'' p2'' + | _,_ -> false +and fpat_eq (FP_aux (FP_Fpat (id1,p1),_)) (FP_aux (FP_Fpat (id2,p2),_)) = + Id.compare id1 id2 == 0 && pat_eq p1 p2 + + + module Analysis = struct @@ -1935,11 +2078,19 @@ let id_pair_compare (id,l) (id',l') = | 0 -> compare l l' | x -> x +(* Usually we do a full case split on an argument, but sometimes we find a + case expression in the function body that suggests a more compact case + splitting. *) +type match_detail = + | Total + | Partial of tannot pat list * Parse_ast.l + (* Arguments that we might split on *) -module ArgSet = Set.Make (struct +module ArgSplits = Map.Make (struct type t = id * loc let compare = id_pair_compare end) +type arg_splits = match_detail ArgSplits.t (* Arguments that we should look at in callers *) module CallerArgSet = Set.Make (struct @@ -1967,13 +2118,19 @@ module StringSet = Set.Make (struct end) type dependencies = - | Have of ArgSet.t * CallerArgSet.t * CallerKidSet.t + | Have of arg_splits * CallerArgSet.t * CallerKidSet.t (* args to split inside fn * caller args to split * caller kids that are bitvector parameters *) | Unknown of Parse_ast.l * string -let string_of_argset s = - String.concat ", " (List.map (fun (id,l) -> string_of_id id ^ "." ^ string_of_loc l) - (ArgSet.elements s)) +let string_of_match_detail = function + | Total -> "[total]" + | Partial (pats,_) -> "[" ^ String.concat " | " (List.map string_of_pat pats) ^ "]" + +let string_of_argsplits s = + String.concat ", " + (List.map (fun ((id,l),detail) -> + string_of_id id ^ "." ^ string_of_loc l ^ string_of_match_detail detail) + (ArgSplits.bindings s)) let string_of_callerset s = String.concat ", " (List.map (fun (id,arg) -> string_of_id id ^ "." ^ string_of_int arg) @@ -1984,8 +2141,8 @@ let string_of_callerkidset s = (CallerKidSet.elements s)) let string_of_dep = function - | Have (argset,callset,kidset) -> - "Have (" ^ string_of_argset argset ^ "; " ^ string_of_callerset callset ^ "; " ^ + | Have (args,callset,kidset) -> + "Have (" ^ string_of_argsplits args ^ "; " ^ string_of_callerset callset ^ "; " ^ string_of_callerkidset kidset ^ ")" | Unknown (l,msg) -> "Unknown " ^ msg ^ " at " ^ Reporting_basic.loc_to_string l @@ -1995,7 +2152,7 @@ let string_of_dep = function the end for the interprocedural phase. *) type result = { - split : ArgSet.t; + split : arg_splits; failures : StringSet.t Failures.t; (* Dependencies for arguments and type variables of each fn called, so that if the fn uses one for a bitvector size we can track it back *) @@ -2005,21 +2162,29 @@ type result = { } let empty = { - split = ArgSet.empty; + split = ArgSplits.empty; failures = Failures.empty; split_on_call = Bindings.empty; split_in_caller = CallerArgSet.empty; kid_in_caller = CallerKidSet.empty } +let merge_detail _ x y = + match x,y with + | None, x -> x + | x, None -> x + | Some (Partial (ps1,l1)), Some (Partial (ps2,l2)) + when l1 = l2 && forall2 pat_eq ps1 ps2 -> x + | _ -> Some Total + let dmerge x y = match x,y with | Unknown (l,s), _ -> Unknown (l,s) | _, Unknown (l,s) -> Unknown (l,s) | Have (a,c,k), Have (a',c',k') -> - Have (ArgSet.union a a', CallerArgSet.union c c', CallerKidSet.union k k') + Have (ArgSplits.merge merge_detail a a', CallerArgSet.union c c', CallerKidSet.union k k') -let dempty = Have (ArgSet.empty, CallerArgSet.empty, CallerKidSet.empty) +let dempty = Have (ArgSplits.empty, CallerArgSet.empty, CallerKidSet.empty) let dopt_merge k x y = match x, y with @@ -2053,7 +2218,7 @@ let failure_merge _ x y = | Some x, Some y -> Some (StringSet.union x y) let merge rs rs' = { - split = ArgSet.union rs.split rs'.split; + split = ArgSplits.merge merge_detail rs.split rs'.split; failures = Failures.merge failure_merge rs.failures rs'.failures; split_on_call = Bindings.merge call_arg_merge rs.split_on_call rs'.split_on_call; split_in_caller = CallerArgSet.union rs.split_in_caller rs'.split_in_caller; @@ -2150,6 +2315,75 @@ let deps_of_uvar kid_deps arg_deps = function | U_effect _ -> dempty | U_typ typ -> deps_of_typ kid_deps arg_deps typ +let mk_subrange_pattern vannot vstart vend = + let (_,len,ord,typ) = vector_typ_args_of (Env.base_typ_of (env_of_annot vannot) (typ_of_annot vannot)) in + match ord with + | Ord_aux (Ord_var _,_) -> None + | Ord_aux (ord',_) -> + let vstart,vend = if ord' = Ord_inc then vstart,vend else vend,vstart + in + let dummyl = Generated Unknown in + match len with + | Nexp_aux (Nexp_constant len,_) -> + Some (fun pat -> + let end_len = Big_int.pred (Big_int.sub len vend) in + (* Wrap pat in its type; in particular the type checker won't + manage P_wild in the middle of a P_vector_concat *) + let pat = P_aux (P_typ (pat_typ_of pat, pat),(Generated (pat_loc pat),None)) in + let pats = if Big_int.greater end_len Big_int.zero then + [pat;P_aux (P_typ (vector_typ (nconstant end_len) ord typ, + P_aux (P_wild,(dummyl,None))),(dummyl,None))] + else [pat] + in + let pats = if Big_int.greater vstart Big_int.zero then + (P_aux (P_typ (vector_typ (nconstant vstart) ord typ, + P_aux (P_wild,(dummyl,None))),(dummyl,None)))::pats + else pats + in + let pats = if ord' = Ord_inc then pats else List.rev pats + in + P_aux (P_vector_concat pats,(Generated (fst vannot),None))) + | _ -> None + +(* If the expression matched on in a case expression is a function argument, + and has no other dependencies, we can try to use the pattern match directly + rather than doing a full case split. *) +let refine_dependency env (E_aux (e,(l,annot)) as exp) pexps = + let check_dep id ctx = + match Bindings.find id env.var_deps with + | Have (args,callargs,callkids) -> + if CallerArgSet.is_empty callargs && CallerKidSet.is_empty callkids then + match ArgSplits.bindings args with + | [(id',loc),Total] when Id.compare id id' == 0 -> + (match Util.map_all (function + | Pat_aux (Pat_exp (pat,_),_) -> Some (ctx pat) + | Pat_aux (Pat_when (_,_,_),_) -> None) pexps + with + | Some pats -> + if l = Parse_ast.Unknown then + (Reporting_basic.print_error + (Reporting_basic.Err_general + (l, "No location for pattern match: " ^ string_of_exp exp)); + None) + else + Some (Have (ArgSplits.singleton (id,loc) (Partial (pats,l)),callargs,callkids)) + | None -> None) + | _ -> None + else None + | Unknown _ -> None + | exception Not_found -> None + in + match e with + | E_id id -> check_dep id (fun x -> x) + | E_app (fn_id, [E_aux (E_id id,vannot); + E_aux (E_lit (L_aux (L_num vstart,_)),_); + E_aux (E_lit (L_aux (L_num vend,_)),_)]) + when is_id (env_of exp) (Id "vector_subrange") fn_id -> + (match mk_subrange_pattern vannot vstart vend with + | Some mk_pat -> check_dep id mk_pat + | None -> None) + | _ -> None + (* Takes an environment of dependencies on vars, type vars, and flow control, and dependencies on mutable variables. The latter are quite conservative, we currently drop variables assigned inside loops, for example. *) @@ -2275,6 +2509,10 @@ let rec analyse_exp fn_id env assigns (E_aux (e,(l,annot)) as exp) = | E_field (e,_) -> analyse_exp fn_id env assigns e | E_case (e,cases) -> let deps,assigns,r = analyse_exp fn_id env assigns e in + let deps = match refine_dependency env e cases with + | Some deps -> deps + | None -> deps + in let analyse_case (Pat_aux (pexp,_)) = match pexp with | Pat_exp (pat,e1) -> @@ -2365,7 +2603,7 @@ let rec analyse_exp fn_id env assigns (E_aux (e,(l,annot)) as exp) = match deps_of_nexp env.kid_deps [] size with | Have (args,caller,caller_kids) -> { r with - split = ArgSet.union r.split args; + split = ArgSplits.merge merge_detail r.split args; split_in_caller = CallerArgSet.union r.split_in_caller caller; kid_in_caller = CallerKidSet.union r.kid_in_caller caller_kids } @@ -2412,17 +2650,31 @@ let translate_id (Id_aux (_,l) as id) = | _ -> None in aux l -let initial_env fn_id (TypQ_aux (tq,_)) pat = +let initial_env fn_id (TypQ_aux (tq,_)) pat set_assertions = let pats = match pat with | P_aux (P_tup pats,_) -> pats | _ -> [pat] in + let default_split annot = + let env = env_of_annot annot in + let Typ_aux (typ,_) = Env.base_typ_of env (typ_of_annot annot) in + match typ with + | Typ_app (Id_aux (Id "atom",_),[Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid,_)),_)]) -> + (match KBindings.find kid set_assertions with + | (l,is) -> + let l' = Generated l in + let pats = List.map (fun n -> P_aux (P_lit (L_aux (L_num n,l')),(l',None))) is in + let pats = pats @ [P_aux (P_wild,(l',None))] in + Partial (pats,l) + | exception Not_found -> Total) + | _ -> Total + in let arg i pat = - let rec aux (P_aux (p,(l,_))) = + let rec aux (P_aux (p,(l,annot))) = let of_list pats = let ss,vs,ks = split3 (List.map aux pats) in - let s = List.fold_left ArgSet.union ArgSet.empty ss in + let s = List.fold_left (ArgSplits.merge merge_detail) ArgSplits.empty ss in let v = List.fold_left dep_bindings_merge Bindings.empty vs in let k = List.fold_left dep_kbindings_merge KBindings.empty ks in s,v,k @@ -2430,27 +2682,37 @@ let initial_env fn_id (TypQ_aux (tq,_)) pat = match p with | P_lit _ | P_wild - -> ArgSet.empty,Bindings.empty,KBindings.empty + -> ArgSplits.empty,Bindings.empty,KBindings.empty | P_as (pat,id) -> begin let s,v,k = aux pat in match translate_id id with - | Some id' -> ArgSet.add id' s, Bindings.add id (Have (ArgSet.singleton id',CallerArgSet.empty,CallerKidSet.empty)) v,k - | None -> s, Bindings.add id (Unknown (l, ("Unable to give location for " ^ string_of_id id))) v, k + | Some id' -> + ArgSplits.add id' Total s, + Bindings.add id (Have (ArgSplits.singleton id' Total,CallerArgSet.empty,CallerKidSet.empty)) v, + k + | None -> + s, + Bindings.add id (Unknown (l, ("Unable to give location for " ^ string_of_id id))) v, + k end | P_typ (_,pat) -> aux pat | P_id id -> begin match translate_id id with | Some id' -> - let s = ArgSet.singleton id' in - s, Bindings.singleton id (Have (s,CallerArgSet.empty,CallerKidSet.empty)), KBindings.empty + let s = ArgSplits.singleton id' (default_split (l,annot)) in + s, + Bindings.singleton id (Have (s,CallerArgSet.empty,CallerKidSet.empty)), + KBindings.empty | None -> - ArgSet.empty, Bindings.singleton id (Unknown (l, ("Unable to give location for " ^ string_of_id id))), KBindings.empty + ArgSplits.empty, + Bindings.singleton id (Unknown (l, ("Unable to give location for " ^ string_of_id id))), + KBindings.empty end | P_var (pat,kid) -> let s,v,k = aux pat in - s,v,KBindings.add kid (Have (ArgSet.empty,CallerArgSet.singleton (fn_id,i),CallerKidSet.empty)) k + s,v,KBindings.add kid (Have (ArgSplits.empty,CallerArgSet.singleton (fn_id,i),CallerKidSet.empty)) k | P_app (_,pats) -> of_list pats | P_record (fpats,_) -> of_list (List.map (fun (FP_aux (FP_Fpat (_,p),_)) -> p) fpats) | P_vector pats @@ -2463,7 +2725,7 @@ let initial_env fn_id (TypQ_aux (tq,_)) pat = in let quant k = function | QI_aux (QI_id (KOpt_aux ((KOpt_none kid | KOpt_kind (_,kid)),_)),_) -> - KBindings.add kid (Have (ArgSet.empty,CallerArgSet.empty,CallerKidSet.singleton (fn_id,kid))) k + KBindings.add kid (Have (ArgSplits.empty,CallerArgSet.empty,CallerKidSet.singleton (fn_id,kid))) k | QI_aux (QI_const _,_) -> k in let kid_quant_deps = @@ -2476,8 +2738,48 @@ let initial_env fn_id (TypQ_aux (tq,_)) pat = let kid_deps = List.fold_left dep_kbindings_merge kid_quant_deps kid_deps in { var_deps = var_deps; kid_deps = kid_deps } +(* When there's more than one pick the first *) +let merge_set_asserts _ x y = + match x, y with + | None, _ -> y + | _, _ -> x +let merge_set_asserts_by_kid sets1 sets2 = + KBindings.merge merge_set_asserts sets1 sets2 + +(* Find all the easily reached set assertions in a function body, to use as + case splits *) +let rec find_set_assertions (E_aux (e,_)) = + match e with + | E_block es + | E_nondet es -> + List.fold_left merge_set_asserts_by_kid KBindings.empty (List.map find_set_assertions es) + | E_cast (_,e) -> find_set_assertions e + | E_let (LB_aux (LB_val (p,e1),_),e2) -> + let sets1 = find_set_assertions e1 in + let sets2 = find_set_assertions e2 in + let kbound = kids_bound_by_pat p in + let sets2 = KBindings.filter (fun kid _ -> not (KidSet.mem kid kbound)) sets2 in + merge_set_asserts_by_kid sets1 sets2 + | E_assert (E_aux (e1,_),_) -> begin + match e1 with + | E_constraint (NC_aux (NC_set (kid,is),l)) -> KBindings.singleton kid (l,is) + | _ -> KBindings.empty + end + | _ -> KBindings.empty + +let print_set_assertions set_assertions = + if KBindings.is_empty set_assertions then + print_endline "No top-level set assertions found." + else begin + print_endline "Top-level set assertions found:"; + KBindings.iter (fun k (l,is) -> + print_endline (string_of_kid k ^ " " ^ + String.concat "," (List.map Big_int.to_string is))) + set_assertions + end + let print_result r = - let _ = print_endline (" splits: " ^ string_of_argset r.split) in + let _ = print_endline (" splits: " ^ string_of_argsplits r.split) in let print_kbinding kid dep = let _ = print_endline (" " ^ string_of_kid kid ^ ": " ^ string_of_dep dep) in () @@ -2500,20 +2802,19 @@ let print_result r = () let analyse_funcl debug tenv (FCL_aux (FCL_Funcl (id,pexp),_)) = + let _ = if debug > 2 then print_endline (string_of_id id) else () in let pat,guard,body,_ = destruct_pexp pexp in let (tq,_) = Env.get_val_spec id tenv in - let aenv = initial_env id tq pat in + let set_assertions = find_set_assertions body in + let _ = if debug > 2 then print_set_assertions set_assertions in + let aenv = initial_env id tq pat set_assertions in let _,_,r = analyse_exp id aenv Bindings.empty body in let r = match guard with | None -> r | Some exp -> let _,_,r' = analyse_exp id aenv Bindings.empty exp in merge r r' in - let _ = - if debug > 2 then - (print_endline (string_of_id id); - print_result r) - else () + let _ = if debug > 2 then print_result r else () in r let analyse_def debug env = function @@ -2533,25 +2834,25 @@ let analyse_defs debug env (Defs defs) = let splits,fails = CallerKidSet.fold add_kid caller_kids (splits,fails) in splits, fails | Unknown (l,msg) -> - ArgSet.empty , Failures.singleton l (StringSet.singleton ("Unable to monomorphise dependency: " ^ msg)) + ArgSplits.empty , Failures.singleton l (StringSet.singleton ("Unable to monomorphise dependency: " ^ msg)) and chase_kid_caller (id,kid) = match Bindings.find id r.split_on_call with | (_,kid_deps) -> begin match KBindings.find kid kid_deps with | deps -> chase_deps deps - | exception Not_found -> ArgSet.empty,Failures.empty + | exception Not_found -> ArgSplits.empty,Failures.empty end - | exception Not_found -> ArgSet.empty,Failures.empty + | exception Not_found -> ArgSplits.empty,Failures.empty and chase_arg_caller (id,i) = match Bindings.find id r.split_on_call with | (arg_deps,_) -> chase_deps (List.nth arg_deps i) - | exception Not_found -> ArgSet.empty,Failures.empty + | exception Not_found -> ArgSplits.empty,Failures.empty and add_arg arg (splits,fails) = let splits',fails' = chase_arg_caller arg in - ArgSet.union splits splits', Failures.merge failure_merge fails fails' + ArgSplits.merge merge_detail splits splits', Failures.merge failure_merge fails fails' and add_kid k (splits,fails) = let splits',fails' = chase_kid_caller k in - ArgSet.union splits splits', Failures.merge failure_merge fails fails' + ArgSplits.merge merge_detail splits splits', Failures.merge failure_merge fails fails' in let _ = if debug > 1 then print_result r else () in let splits,fails = CallerArgSet.fold add_arg r.split_in_caller (r.split,r.failures) in @@ -2559,7 +2860,7 @@ let analyse_defs debug env (Defs defs) = let _ = if debug > 0 then (print_endline "Final splits:"; - print_endline (string_of_argset splits)) + print_endline (string_of_argsplits splits)) else () in let _ = @@ -2573,8 +2874,11 @@ let analyse_defs debug env (Defs defs) = in splits let argset_to_list splits = - let l = ArgSet.elements splits in - let argelt (id,(file,loc)) = ((file,loc),string_of_id id) in + let l = ArgSplits.bindings splits in + let argelt = function + | ((id,(file,loc)),Total) -> ((file,loc),string_of_id id,None) + | ((id,(file,loc)),Partial (pats,l)) -> ((file,loc),string_of_id id,Some (pats,l)) + in List.map argelt l end @@ -2598,11 +2902,6 @@ let is_constant_vec_typ env typ = | _ -> false) | _ -> false -let is_id env id = - let ids = Env.get_overloads (Id_aux (id,Parse_ast.Unknown)) env in - let ids = id :: List.map (fun (Id_aux (id,_)) -> id) ids in - fun (Id_aux (x,_)) -> List.mem x ids - (* We have to add casts in here with appropriate length information so that the type checker knows the expected return types. *) @@ -2816,7 +3115,8 @@ let monomorphise opts splits env defs = if opts.auto then Analysis.argset_to_list (Analysis.analyse_defs opts.debug_analysis env defs) else [] in - let defs = split_defs opts.all_split_errors (new_splits@splits) defs in + let splits = new_splits @ (List.map (fun (loc,id) -> (loc,id,None)) splits) in + let defs = split_defs opts.all_split_errors splits defs in (* TODO: stop if opts.all_split_errors && something went wrong *) (* TODO: currently doing this because constant propagation leaves numeric literals as int, try to avoid this later; also use final env for DEF_spec case above, because the |
