diff options
| author | Robert Norton | 2018-02-22 17:23:48 +0000 |
|---|---|---|
| committer | Robert Norton | 2018-02-22 17:23:48 +0000 |
| commit | bac62a260ce9aa8f83bb71515daf1829133b0127 (patch) | |
| tree | 03b24eea504d09dc6fa3267fc9740aef6b66e446 /src/monomorphise.ml | |
| parent | 5308167903db5e81c07a5aff9f20c83f33afcb9c (diff) | |
| parent | c63741a21b5a1f77f85987f15f6aac3321a91f0a (diff) | |
Merge branch 'sail2' of github.com:rems-project/sail into sail2
Diffstat (limited to 'src/monomorphise.ml')
| -rw-r--r-- | src/monomorphise.ml | 365 |
1 files changed, 238 insertions, 127 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 71efcb22..d14097af 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -54,7 +54,7 @@ open Ast_util module Big_int = Nat_big_num open Type_check -let size_set_limit = 32 +let size_set_limit = 64 let optmap v f = match v with @@ -69,6 +69,11 @@ let bindings_union s1 s2 = | _, (Some x) -> Some x | (Some x), _ -> Some x | _, _ -> None) s1 s2 +let kbindings_union s1 s2 = + KBindings.merge (fun _ x y -> match x,y with + | _, (Some x) -> Some x + | (Some x), _ -> Some x + | _, _ -> None) s1 s2 let subst_nexp substs nexp = let rec s_snexp substs (Nexp_aux (ne,l) as nexp) = @@ -615,9 +620,9 @@ let bindings_from_pat p = and aux_fpat (FP_aux (FP_Fpat (_,p), _)) = aux_pat p in aux_pat p -let remove_bound env pat = +let remove_bound (substs,ksubsts) pat = let bound = bindings_from_pat pat in - List.fold_left (fun sub v -> Bindings.remove v sub) env bound + List.fold_left (fun sub v -> Bindings.remove v sub) substs bound, ksubsts (* Attempt simple pattern matches *) let lit_match = function @@ -721,6 +726,30 @@ let int_of_str_lit = function | L_bin bin -> Big_int.of_string ("0b" ^ bin) | _ -> assert false +let bits_of_lit = function + | L_bin bin -> bin + | L_hex hex -> hex_to_bin hex + | _ -> assert false + +let slice_lit (L_aux (lit,ll)) i len (Ord_aux (ord,_)) = + let i = Big_int.to_int i in + let len = Big_int.to_int len in + match match ord with + | Ord_inc -> Some i + | Ord_dec -> Some (len - i) + | Ord_var _ -> None + with + | None -> None + | Some i -> + match lit with + | L_bin bin -> Some (L_aux (L_bin (String.sub bin i len),Generated ll)) + | _ -> assert false + +let concat_vec lit1 lit2 = + let bits1 = bits_of_lit lit1 in + let bits2 = bits_of_lit lit2 in + L_bin (bits1 ^ bits2) + let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) = match l1,l2 with | (L_zero|L_false), (L_zero|L_false) @@ -758,16 +787,47 @@ let try_app (l,ann) (id,args) = | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit,_), _)] -> Some (E_aux (E_lit (L_aux (L_num (int_of_str_lit lit),new_l)),(l,ann))) | _ -> None + else if is_id "slice" then + match args with + | [E_aux (E_lit (L_aux ((L_hex _| L_bin _),_) as lit), + (_,Some (_,Typ_aux (Typ_app (_,[_;Typ_arg_aux (Typ_arg_order ord,_);_]),_),_))); + E_aux (E_lit L_aux (L_num i,_), _); + E_aux (E_lit L_aux (L_num len,_), _)] -> + (match slice_lit lit i len ord with + | Some lit' -> Some (E_aux (E_lit lit',(l,ann))) + | None -> None) + | _ -> None + else if is_id "bitvector_concat" then + match args with + | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit1,_), _); + E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit2,_), _)] -> + Some (E_aux (E_lit (L_aux (concat_vec lit1 lit2,new_l)),(l,ann))) + | _ -> None else if is_id "shl_int" then match args with | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> Some (E_aux (E_lit (L_aux (L_num (Big_int.shift_left i (Big_int.to_int j)),new_l)),(l,ann))) | _ -> None - else if is_id "mult_int" then + else if is_id "mult_int" || is_id "mult_range" then match args with | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> Some (E_aux (E_lit (L_aux (L_num (Big_int.mul i j),new_l)),(l,ann))) | _ -> None + else if is_id "quotient_nat" then + match args with + | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> + Some (E_aux (E_lit (L_aux (L_num (Big_int.div i j),new_l)),(l,ann))) + | _ -> None + else if is_id "add_range" then + match args with + | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> + Some (E_aux (E_lit (L_aux (L_num (Big_int.add i j),new_l)),(l,ann))) + | _ -> None + else if is_id "negate_range" then + match args with + | [E_aux (E_lit L_aux (L_num i,_),_)] -> + Some (E_aux (E_lit (L_aux (L_num (Big_int.negate i),new_l)),(l,ann))) + | _ -> None else if is_id "ex_int" then match args with | [E_aux (E_lit lit,(l,_))] -> Some (E_aux (E_lit lit,(l,ann))) @@ -1034,6 +1094,13 @@ let apply_pat_choices choices = e_assert = rewrite_assert; e_case = rewrite_case } +(* Check whether the current environment with the given kid assignments is + inconsistent (and hence whether the code is dead) *) +let is_env_inconsistent env ksubsts = + let env = KBindings.fold (fun k nexp env -> + Env.add_constraint (nc_eq (nvar k) nexp) env) ksubsts env in + prove env nc_false + let split_defs all_errors splits defs = let no_errors_happened = ref true in let split_constructors (Defs defs) = @@ -1065,8 +1132,13 @@ let split_defs all_errors splits defs = let (refinements, defs') = split_constructors defs in + (* COULD DO: dead code is only eliminated at if expressions, but we could + also cut out impossible case branches and code after assertions. *) + (* Constant propogation. Takes maps of immutable/mutable variables to subsitute. + The substs argument also contains the current type-level kid refinements + so that we can check for dead code. Extremely conservative about evaluation order of assignments in subexpressions, dropping assignments rather than committing to any particular order *) @@ -1123,7 +1195,7 @@ let split_defs all_errors splits defs = let env = Type_check.env_of_annot (l, annot) in (try match Env.lookup_id id env with - | Local (Immutable,_) -> Bindings.find id substs + | Local (Immutable,_) -> Bindings.find id (fst substs) | Local (Mutable,_) -> Bindings.find id assigns | _ -> exp with Not_found -> exp),assigns @@ -1154,20 +1226,48 @@ let split_defs all_errors splits defs = re (E_tuple es') assigns | E_if (e1,e2,e3) -> let e1',assigns = const_prop_exp substs assigns e1 in - let e2',assigns2 = const_prop_exp substs assigns e2 in - let e3',assigns3 = const_prop_exp substs assigns e3 in - (match drop_casts e1' with + let e1_no_casts = drop_casts e1' in + (match e1_no_casts with | E_aux (E_lit (L_aux ((L_true|L_false) as lit ,_)),_) -> - (match lit with L_true -> e2',assigns2 | _ -> e3',assigns3) + (match lit with + | L_true -> const_prop_exp substs assigns e2 + | _ -> const_prop_exp substs assigns e3) | _ -> - let assigns = isubst_minus_set assigns (assigned_vars e2) in - let assigns = isubst_minus_set assigns (assigned_vars e3) in - re (E_if (e1',e2',e3')) assigns) + (* If the guard is an equality check, propagate the value. *) + let env1 = env_of e1_no_casts in + let is_equal id = + List.exists (fun id' -> Id.compare id id' == 0) + (Env.get_overloads (Id_aux (DeIid "==", Parse_ast.Unknown)) + env1) + in + let substs_true = + match e1_no_casts with + | E_aux (E_app (id, [E_aux (E_id var,_); vl]),_) + | E_aux (E_app (id, [vl; E_aux (E_id var,_)]),_) + when is_equal id -> + if is_value vl then + (match Env.lookup_id var env1 with + | Local (Immutable,_) -> Bindings.add var vl (fst substs),snd substs + | _ -> substs) + else substs + | _ -> substs + in + (* Discard impossible branches *) + if is_env_inconsistent (env_of e2) (snd substs) then + const_prop_exp substs assigns e3 + else if is_env_inconsistent (env_of e3) (snd substs) then + const_prop_exp substs_true assigns e2 + else + let e2',assigns2 = const_prop_exp substs_true assigns e2 in + let e3',assigns3 = const_prop_exp substs assigns e3 in + let assigns = isubst_minus_set assigns (assigned_vars e2) in + let assigns = isubst_minus_set assigns (assigned_vars e3) in + re (E_if (e1',e2',e3')) assigns) | E_for (id,e1,e2,e3,ord,e4) -> (* Treat e1, e2 and e3 (from, to and by) as a non-det tuple *) let e1',e2',e3',assigns = non_det_exp_3 e1 e2 e3 in let assigns = isubst_minus_set assigns (assigned_vars e4) in - let e4',_ = const_prop_exp (Bindings.remove id substs) assigns e4 in + let e4',_ = const_prop_exp (Bindings.remove id (fst substs),snd substs) assigns e4 in re (E_for (id,e1',e2',e3',ord,e4')) assigns | E_loop (loop,e1,e2) -> let assigns = isubst_minus_set assigns (IdSet.union (assigned_vars e1) (assigned_vars e2)) in @@ -1227,7 +1327,7 @@ let split_defs all_errors splits defs = | Some (E_aux (_,(_,annot')) as exp,newbindings,kbindings) -> let exp = nexp_subst_exp (kbindings_from_list kbindings) exp in let newbindings_env = bindings_from_list newbindings in - let substs' = bindings_union substs newbindings_env in + let substs' = bindings_union (fst substs) newbindings_env, snd substs in const_prop_exp substs' assigns exp) | E_let (lb,e2) -> begin @@ -1245,7 +1345,7 @@ let split_defs all_errors splits defs = | Some (e'',bindings,kbindings) -> let e'' = nexp_subst_exp (kbindings_from_list kbindings) e'' in let bindings = bindings_from_list bindings in - let substs'' = bindings_union substs' bindings in + let substs'' = bindings_union (fst substs') bindings, snd substs' in const_prop_exp substs'' assigns e'' else plain () end @@ -1350,9 +1450,9 @@ let split_defs all_errors splits defs = let cases = List.map (function | FCL_aux (FCL_Funcl (_,pexp), ann) -> pexp) fcls in - match can_match_with_env env arg cases Bindings.empty Bindings.empty with + match can_match_with_env env arg cases (Bindings.empty,KBindings.empty) Bindings.empty with | Some (exp,bindings,kbindings) -> - let substs = bindings_from_list bindings in + let substs = bindings_from_list bindings, kbindings_from_list kbindings in let result,_ = const_prop_exp substs Bindings.empty exp in let result = match result with | E_aux (E_return e,_) -> e @@ -1361,7 +1461,7 @@ let split_defs all_errors splits defs = if is_value result then Some result else None | None -> None - and can_match_with_env env (E_aux (e,(l,annot)) as exp0) cases substs assigns = + and can_match_with_env env (E_aux (e,(l,annot)) as exp0) cases (substs,ksubsts) assigns = let rec findpat_generic check_pat description assigns = function | [] -> (Reporting_basic.print_err false true l "Monomorphisation" ("Failed to find a case for " ^ description); None) @@ -1373,7 +1473,7 @@ let split_defs all_errors splits defs = Some (exp, [(id', exp0)], []) | (Pat_aux (Pat_when (P_aux (P_id id',_),guard,exp),_))::tl when pat_id_is_variable env id' -> begin - let substs = Bindings.add id' exp0 substs in + let substs = Bindings.add id' exp0 substs, ksubsts in let (E_aux (guard,_)),assigns = const_prop_exp substs assigns guard in match guard with | E_lit (L_aux (L_true,_)) -> Some (exp,[(id',exp0)],[]) @@ -1385,7 +1485,8 @@ let split_defs all_errors splits defs = | DoesNotMatch -> findpat_generic check_pat description assigns tl | DoesMatch (vsubst,ksubst) -> begin let guard = nexp_subst_exp (kbindings_from_list ksubst) guard in - let substs = bindings_union substs (bindings_from_list vsubst) in + let substs = bindings_union substs (bindings_from_list vsubst), + kbindings_union ksubsts (kbindings_from_list ksubst) in let (E_aux (guard,_)),assigns = const_prop_exp substs assigns guard in match guard with | E_lit (L_aux (L_true,_)) -> Some (exp,vsubst,ksubst) @@ -1463,8 +1564,8 @@ let split_defs all_errors splits defs = can_match_with_env env exp in - let subst_exp substs exp = - let substs = bindings_from_list substs in + let subst_exp substs ksubsts exp = + let substs = bindings_from_list substs, ksubsts in fst (const_prop_exp substs Bindings.empty exp) in @@ -1813,8 +1914,9 @@ let split_defs all_errors splits defs = | VarSplit patsubsts -> if check_split_size patsubsts (pat_loc p) then List.map (fun (pat',substs,pchoices,ksubsts) -> - let exp' = nexp_subst_exp (kbindings_from_list ksubsts) e in - let exp' = subst_exp substs exp' in + let ksubsts = kbindings_from_list ksubsts in + let exp' = nexp_subst_exp ksubsts e in + let exp' = subst_exp substs ksubsts exp' in let exp' = apply_pat_choices pchoices exp' in let exp' = stop_at_false_assertions exp' in Pat_aux (Pat_exp (pat', map_exp exp'),l)) @@ -1833,11 +1935,12 @@ let split_defs all_errors splits defs = | VarSplit patsubsts -> if check_split_size patsubsts (pat_loc p) then List.map (fun (pat',substs,pchoices,ksubsts) -> - let exp1' = nexp_subst_exp (kbindings_from_list ksubsts) e1 in - let exp1' = subst_exp substs exp1' in + let ksubsts = kbindings_from_list ksubsts in + let exp1' = nexp_subst_exp ksubsts e1 in + let exp1' = subst_exp substs ksubsts exp1' in let exp1' = apply_pat_choices pchoices exp1' in - let exp2' = nexp_subst_exp (kbindings_from_list ksubsts) e2 in - let exp2' = subst_exp substs exp2' in + let exp2' = nexp_subst_exp ksubsts e2 in + let exp2' = subst_exp substs ksubsts exp2' in let exp2' = apply_pat_choices pchoices exp2' in let exp2' = stop_at_false_assertions exp2' in Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)) @@ -1917,27 +2020,27 @@ let findi f = let mapat f is xs = let rec aux n = function - | _, [] -> [] - | (i,_)::is, h::t when i = n -> + | [] -> [] + | h::t when Util.IntSet.mem n is -> let h' = f h in - let t' = aux (n+1) (is, t) in + let t' = aux (n+1) t in h'::t' - | is, h::t -> - let t' = aux (n+1) (is, t) in + | h::t -> + let t' = aux (n+1) t in h::t' - in aux 0 (is, xs) + in aux 0 xs let mapat_extra f is xs = let rec aux n = function - | _, [] -> [], [] - | (i,v)::is, h::t when i = n -> - let h',x = f v h in - let t',xs = aux (n+1) (is, t) in + | [] -> [], [] + | h::t when Util.IntSet.mem n is -> + let h',x = f h in + let t',xs = aux (n+1) t in h'::t',x::xs - | is, h::t -> - let t',xs = aux (n+1) (is, t) in + | h::t -> + let t',xs = aux (n+1) t in h::t',xs - in aux 0 (is, xs) + in aux 0 xs let tyvars_bound_in_pat pat = let open Rewriter in @@ -1975,34 +2078,45 @@ let sizes_of_annot = function | _,None -> KidSet.empty | _,Some (env,typ,_) -> sizes_of_typ (Env.base_typ_of env typ) -let change_parameter_pat kid = function - | P_aux (P_id var, (l,_)) - | P_aux (P_typ (_,P_aux (P_id var, (l,_))),_) - -> P_aux (P_id var, (l,None)), (var,kid) +let change_parameter_pat = function + | P_aux (P_id var, (l,Some (env,typ,_))) + | P_aux (P_typ (_,P_aux (P_id var, (l,Some (env,typ,_)))),_) -> + P_aux (P_id var, (l,None)), var | P_aux (_,(l,_)) -> raise (Reporting_basic.err_unreachable l "Expected variable pattern") (* We add code to change the itself('n) parameter into the corresponding integer. *) -let add_var_rebind exp (var,kid) = +let add_var_rebind exp var = let l = Generated Unknown in let annot = (l,None) in E_aux (E_let (LB_aux (LB_val (P_aux (P_id var,annot), E_aux (E_app (mk_id "size_itself_int",[E_aux (E_id var,annot)]),annot)),annot),exp),annot) (* atom('n) arguments to function calls need to be rewritten *) -let replace_with_the_value (E_aux (_,(l,_)) as exp) = +let replace_with_the_value bound_nexps (E_aux (_,(l,_)) as exp) = let env = env_of exp in let typ, wrap = match typ_of exp with | Typ_aux (Typ_exist (kids,nc,typ),l) -> typ, fun t -> Typ_aux (Typ_exist (kids,nc,t),l) | typ -> typ, fun x -> x in let typ = Env.expand_synonyms env typ in + let replace_size size = + (* TODO: pick simpler nexp when there's a choice (also in pretty printer) *) + let is_equal nexp = + prove env (NC_aux (NC_equal (size,nexp), Parse_ast.Unknown)) + in + if is_nexp_constant size then size else + match List.find is_equal bound_nexps with + | nexp -> nexp + | exception Not_found -> size + in let mk_exp nexp l l' = - E_aux (E_cast (wrap (Typ_aux (Typ_app (Id_aux (Id "itself",Generated Unknown), - [Typ_arg_aux (Typ_arg_nexp nexp,l')]),Generated Unknown)), - E_aux (E_app (Id_aux (Id "make_the_value",Generated Unknown),[exp]),(Generated l,None))), - (Generated l,None)) + let nexp = replace_size nexp in + E_aux (E_cast (wrap (Typ_aux (Typ_app (Id_aux (Id "itself",Generated Unknown), + [Typ_arg_aux (Typ_arg_nexp nexp,l')]),Generated Unknown)), + E_aux (E_app (Id_aux (Id "make_the_value",Generated Unknown),[exp]),(Generated l,None))), + (Generated l,None)) in match typ with | Typ_aux (Typ_app (Id_aux (Id "range",_), @@ -2032,91 +2146,77 @@ let replace_type env typ = let rewrite_size_parameters env (Defs defs) = let open Rewriter in - let size_vars pexp = - fst (fold_pexp - { (compute_exp_alg KidSet.empty KidSet.union) with - e_aux = (fun ((s,e),annot) -> KidSet.union s (sizes_of_annot annot), E_aux (e,annot)); - e_let = (fun ((sl,lb),(s2,e2)) -> KidSet.union sl (KidSet.diff s2 (tyvars_bound_in_lb lb)), E_let (lb,e2)); - e_for = (fun (id,(s1,e1),(s2,e2),(s3,e3),ord,(s4,e4)) -> - let kid = mk_kid ("loop_" ^ string_of_id id) in - KidSet.union s1 (KidSet.union s2 (KidSet.union s3 (KidSet.remove kid s4))), - E_for (id,e1,e2,e3,ord,e4)); - pat_exp = (fun ((sp,pat),(s,e)) -> KidSet.diff s (tyvars_bound_in_pat pat), Pat_exp (pat,e))} - pexp) - in - let exposed_sizes_funcl fnsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) = - let sizes = size_vars pexp in - let pat,guard,exp,pannot = destruct_pexp pexp in - let visible_tyvars = - KidSet.union - (Pretty_print_lem.lem_tyvars_of_typ (pat_typ_of pat)) - (Pretty_print_lem.lem_tyvars_of_typ (typ_of exp)) - in - let expose_tyvars = KidSet.diff sizes visible_tyvars in - KidSet.union fnsizes expose_tyvars - in - let sizes_funcl expose_tyvars fsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) = + let open Util in + + let sizes_funcl fsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) = let pat,guard,exp,pannot = destruct_pexp pexp in let parameters = match pat with | P_aux (P_tup ps,_) -> ps | _ -> [pat] in - let to_change = Util.map_filter - (fun kid -> - let check (P_aux (_,(_,Some (env,typ,_)))) = - match Env.expand_synonyms env typ with - Typ_aux (Typ_app(Id_aux (Id "range",_), - [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_); - Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid'',_)),_)]),_) -> - if Kid.compare kid kid' = 0 && Kid.compare kid kid'' = 0 then Some kid else None - | Typ_aux (Typ_app(Id_aux (Id "atom", _), - [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_)]), _) -> - if Kid.compare kid kid' = 0 then Some kid else None - | _ -> None - in match findi check parameters with - | None -> (Reporting_basic.print_error (Reporting_basic.Err_general (l, - ("Unable to find an argument for " ^ string_of_kid kid))); - None) - | Some i -> Some i) - (KidSet.elements expose_tyvars) + let add_parameter (i,nmap) (P_aux (_,(_,Some (env,typ,_)))) = + let nmap = + match Env.base_typ_of env typ with + Typ_aux (Typ_app(Id_aux (Id "range",_), + [Typ_arg_aux (Typ_arg_nexp nexp,_); + Typ_arg_aux (Typ_arg_nexp nexp',_)]),_) + when Nexp.compare nexp nexp' = 0 && not (NexpMap.mem nexp nmap) -> + NexpMap.add nexp i nmap + | Typ_aux (Typ_app(Id_aux (Id "atom", _), + [Typ_arg_aux (Typ_arg_nexp nexp,_)]), _) + when not (NexpMap.mem nexp nmap) -> + NexpMap.add nexp i nmap + | _ -> nmap + in (i+1,nmap) + in + let (_,nexp_map) = List.fold_left add_parameter (0,NexpMap.empty) parameters in + let nexp_list = NexpMap.bindings nexp_map in + let parameters_for = function + | Some (env,typ,_) -> + begin match Env.base_typ_of env typ with + | Typ_aux (Typ_app (Id_aux (Id "vector",_), [Typ_arg_aux (Typ_arg_nexp size,_);_;_]),_) + when not (is_nexp_constant size) -> + begin + match NexpMap.find size nexp_map with + | i -> IntSet.singleton i + | exception Not_found -> + (* Look for equivalent nexps, but only in consistent type env *) + if prove env (NC_aux (NC_false,Unknown)) then IntSet.empty else + match List.find (fun (nexp,i) -> + prove env (NC_aux (NC_equal (nexp,size),Unknown))) nexp_list with + | _, i -> IntSet.singleton i + | exception Not_found -> IntSet.empty + end + | _ -> IntSet.empty + end + | None -> IntSet.empty in - let ik_compare (i,k) (i',k') = - match compare (i : int) i' with - | 0 -> Kid.compare k k' - | x -> x + let parameters_to_rewrite = + fst (fold_pexp + { (compute_exp_alg IntSet.empty IntSet.union) with + e_aux = (fun ((s,e),(l,annot)) -> IntSet.union s (parameters_for annot),E_aux (e,(l,annot))) + } pexp) in - let to_change = List.sort ik_compare to_change in + let new_nexps = NexpSet.of_list (List.map fst + (List.filter (fun (nexp,i) -> IntSet.mem i parameters_to_rewrite) nexp_list)) in match Bindings.find id fsizes with - | old -> if List.for_all2 (fun x y -> ik_compare x y = 0) old to_change then fsizes else - let str l = String.concat "," (List.map (fun (i,k) -> string_of_int i ^ "." ^ string_of_kid k) l) in - raise (Reporting_basic.err_general l - ("Different size type variables in different clauses of " ^ string_of_id id ^ - " old: " ^ str old ^ " new: " ^ str to_change)) - | exception Not_found -> Bindings.add id to_change fsizes + | old,old_nexps -> Bindings.add id (IntSet.union old parameters_to_rewrite, + NexpSet.union old_nexps new_nexps) fsizes + | exception Not_found -> Bindings.add id (parameters_to_rewrite, new_nexps) fsizes in let sizes_def fsizes = function | DEF_fundef (FD_aux (FD_function (_,_,_,funcls),_)) -> - let expose_tyvars = List.fold_left exposed_sizes_funcl KidSet.empty funcls in - List.fold_left (sizes_funcl expose_tyvars) fsizes funcls + List.fold_left sizes_funcl fsizes funcls | _ -> fsizes in let fn_sizes = List.fold_left sizes_def Bindings.empty defs in - let rewrite_e_app (id,args) = - match Bindings.find id fn_sizes with - | [] -> E_app (id,args) - | to_change -> - let args' = mapat replace_with_the_value to_change args in - E_app (id,args') - | exception Not_found -> E_app (id,args) - in let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),(l,annot))) = let pat,guard,body,(pl,_) = destruct_pexp pexp in - let pat,guard,body = + let pat,guard,body, nexps = (* Update pattern and add itself -> nat wrapper to body *) match Bindings.find id fn_sizes with - | [] -> pat,guard,body - | to_change -> + | to_change,nexps -> let pat, vars = match pat with P_aux (P_tup pats,(l,_)) -> @@ -2124,13 +2224,10 @@ let rewrite_size_parameters env (Defs defs) = P_aux (P_tup pats,(l,None)), vars | P_aux (_,(l,_)) -> begin - match to_change with - | [0,kid] -> - let pat, var = change_parameter_pat kid pat in + if IntSet.is_empty to_change then pat, [] + else + let pat, var = change_parameter_pat pat in pat, [var] - | _ -> - raise (Reporting_basic.err_unreachable l - "Expected multiple parameters at single parameter") end in (* TODO: only add bindings that are necessary (esp for guards) *) @@ -2139,10 +2236,24 @@ let rewrite_size_parameters env (Defs defs) = | None -> None | Some exp -> Some (List.fold_left add_var_rebind exp vars) in - pat,guard,body - | exception Not_found -> pat,guard,body + pat,guard,body,nexps + | exception Not_found -> pat,guard,body,NexpSet.empty in (* Update function applications *) + let funcl_typ = typ_of_annot (l,annot) in + let already_visible_nexps = + NexpSet.union + (Pretty_print_lem.lem_nexps_of_typ funcl_typ) + (Pretty_print_lem.typeclass_nexps funcl_typ) + in + let bound_nexps = NexpSet.elements (NexpSet.union nexps already_visible_nexps) in + let rewrite_e_app (id,args) = + match Bindings.find id fn_sizes with + | to_change,_ -> + let args' = mapat (replace_with_the_value bound_nexps) to_change args in + E_app (id,args') + | exception Not_found -> E_app (id,args) + in let body = fold_exp { id_exp_alg with e_app = rewrite_e_app } body in let guard = match guard with | None -> None @@ -2156,8 +2267,7 @@ let rewrite_size_parameters env (Defs defs) = | DEF_spec (VS_aux (VS_val_spec (typschm,id,extern,cast),(l,annot))) as spec -> begin match Bindings.find id fn_sizes with - | [] -> spec - | to_change -> + | to_change,_ when not (IntSet.is_empty to_change) -> let typschm = match typschm with | TypSchm_aux (TypSchm_ts (tq,typ),l) -> let typ = match typ with @@ -2169,6 +2279,7 @@ let rewrite_size_parameters env (Defs defs) = in TypSchm_aux (TypSchm_ts (tq,typ),l) in DEF_spec (VS_aux (VS_val_spec (typschm,id,extern,cast),(l,None))) + | _ -> spec | exception Not_found -> spec end | def -> def |
