diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 63 | ||||
| -rw-r--r-- | src/ast_util.mli | 19 | ||||
| -rw-r--r-- | src/initial_check.ml | 4 | ||||
| -rw-r--r-- | src/monomorphise.ml | 53 | ||||
| -rw-r--r-- | src/pretty_print_coq.ml | 41 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 7 | ||||
| -rw-r--r-- | src/pretty_print_sail.ml | 23 | ||||
| -rw-r--r-- | src/rewrites.ml | 2 | ||||
| -rw-r--r-- | src/spec_analysis.ml | 7 | ||||
| -rw-r--r-- | src/specialize.ml | 16 | ||||
| -rw-r--r-- | src/type_check.ml | 161 | ||||
| -rw-r--r-- | src/type_check.mli | 14 |
12 files changed, 249 insertions, 161 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index bd7a51bb..02e297cb 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -128,7 +128,7 @@ let mk_val_spec vs_aux = let kopt_kid (KOpt_aux (KOpt_kind (_, kid), _)) = kid let kopt_kind (KOpt_aux (KOpt_kind (k, _), _)) = k - + let is_nat_kopt = function | KOpt_aux (KOpt_kind (K_aux (K_int, _), _), _) -> true | _ -> false @@ -144,7 +144,7 @@ let is_typ_kopt = function let is_bool_kopt = function | KOpt_aux (KOpt_kind (K_aux (K_bool, _), _), _) -> true | _ -> false - + let string_of_kid = function | Kid_aux (Var v, _) -> v @@ -153,6 +153,27 @@ module Kid = struct let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2) end +module Kind = struct + type t = kind + let compare (K_aux (aux1, _)) (K_aux (aux2, _)) = + match aux1, aux2 with + | K_int, K_int -> 0 + | K_type, K_type -> 0 + | K_order, K_order -> 0 + | K_bool, K_bool -> 0 + | K_int, _ -> 1 | _, K_int -> -1 + | K_type, _ -> 1 | _, K_type -> -1 + | K_order, _ -> 1 | _, K_order -> -1 +end + +module KOpt = struct + type t = kinded_id + let compare kopt1 kopt2 = + let lex_ord c1 c2 = if c1 = 0 then c2 else c1 in + lex_ord (Kid.compare (kopt_kid kopt1) (kopt_kid kopt2)) + (Kind.compare (kopt_kind kopt1) (kopt_kind kopt2)) +end + module Id = struct type t = id let compare id1 id2 = @@ -200,6 +221,8 @@ module Bindings = Map.Make(Id) module IdSet = Set.Make(Id) module KBindings = Map.Make(Kid) module KidSet = Set.Make(Kid) +module KOptSet = Set.Make(KOpt) +module KOptMap = Map.Make(KOpt) module NexpSet = Set.Make(Nexp) module NexpMap = Map.Make(Nexp) @@ -389,6 +412,13 @@ let arg_order ?loc:(l=Parse_ast.Unknown) ord = A_aux (A_order ord, l) let arg_typ ?loc:(l=Parse_ast.Unknown) typ = A_aux (A_typ typ, l) let arg_bool ?loc:(l=Parse_ast.Unknown) nc = A_aux (A_bool nc, l) +let arg_kopt (KOpt_aux (KOpt_kind (K_aux (k, _), v), l)) = + match k with + | K_int -> arg_nexp (nvar v) + | K_order -> arg_order (Ord_aux (Ord_var v, l)) + | K_bool -> arg_bool (nc_var v) + | K_type -> arg_typ (mk_typ (Typ_var v)) + let nc_not nc = mk_nc (NC_app (mk_id "not", [arg_bool nc])) let mk_typschm typq typ = TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown) @@ -644,6 +674,9 @@ let string_of_kind_aux = function let string_of_kind (K_aux (k, _)) = string_of_kind_aux k +let string_of_kinded_id (KOpt_aux (KOpt_kind (k, kid), _)) = + "(" ^ string_of_kid kid ^ " : " ^ string_of_kind k ^ ")" + let string_of_base_effect = function | BE_aux (beff, _) -> string_of_base_effect_aux beff @@ -687,7 +720,7 @@ and string_of_typ_aux = function ^ string_of_typ typ_ret ^ " effect " ^ string_of_effect eff | Typ_bidir (typ1, typ2) -> string_of_typ typ1 ^ " <-> " ^ string_of_typ typ2 | Typ_exist (kids, nc, typ) -> - "{" ^ string_of_list " " string_of_kid kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}" + "{" ^ string_of_list " " string_of_kinded_id kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}" and string_of_typ_arg = function | A_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg and string_of_typ_arg_aux = function @@ -995,7 +1028,7 @@ module Typ = struct | n -> n) | Typ_tup ts1, Typ_tup ts2 -> Util.compare_list compare ts1 ts2 | Typ_exist (ks1,nc1,t1), Typ_exist (ks2,nc2,t2) -> - (match Util.compare_list Kid.compare ks1 ks2 with + (match Util.compare_list KOpt.compare ks1 ks2 with | 0 -> (match NC.compare nc1 nc2 with | 0 -> compare t1 t2 | n -> n) @@ -1016,8 +1049,10 @@ module Typ = struct | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2 | A_typ t1, A_typ t2 -> compare t1 t2 | A_order o1, A_order o2 -> order_compare o1 o2 + | A_bool nc1, A_bool nc2 -> NC.compare nc1 nc2 | A_nexp _, _ -> -1 | _, A_nexp _ -> 1 | A_typ _, _ -> -1 | _, A_typ _ -> 1 + | A_order _, _ -> -1 | _, A_order _ -> 1 end module TypMap = Map.Make(Typ) @@ -1179,7 +1214,7 @@ and tyvars_of_typ (Typ_aux (t,_)) = KidSet.empty tas | Typ_exist (kids, nc, t) -> let s = KidSet.union (tyvars_of_typ t) (tyvars_of_constraint nc) in - List.fold_left (fun s k -> KidSet.remove k s) s kids + List.fold_left (fun s k -> KidSet.remove k s) s (List.map kopt_kid kids) and tyvars_of_typ_arg (A_aux (ta,_)) = match ta with | A_nexp nexp -> tyvars_of_nexp nexp @@ -1387,6 +1422,11 @@ let locate_id f (Id_aux (name, l)) = Id_aux (name, f l) let locate_kid f (Kid_aux (name, l)) = Kid_aux (name, f l) +let locate_kind f (K_aux (kind, l)) = K_aux (kind, f l) + +let locate_kinded_id f (KOpt_aux (KOpt_kind (k, kid), l)) = + KOpt_aux (KOpt_kind (locate_kind f k, locate_kid f kid), f l) + let locate_lit f (L_aux (lit, l)) = L_aux (lit, f l) let locate_base_effect f (BE_aux (base_effect, l)) = BE_aux (base_effect, f l) @@ -1427,10 +1467,12 @@ let rec locate_nc f (NC_aux (nc_aux, l)) = | NC_and (nc1, nc2) -> NC_and (locate_nc f nc1, locate_nc f nc2) | NC_true -> NC_true | NC_false -> NC_false + | NC_var v -> NC_var (locate_kid f v) + | NC_app (id, args) -> NC_app (locate_id f id, List.map (locate_typ_arg f) args) in NC_aux (nc_aux, f l) -let rec locate_typ f (Typ_aux (typ_aux, l)) = +and locate_typ f (Typ_aux (typ_aux, l)) = let typ_aux = match typ_aux with | Typ_internal_unknown -> Typ_internal_unknown | Typ_id id -> Typ_id (locate_id f id) @@ -1439,7 +1481,7 @@ let rec locate_typ f (Typ_aux (typ_aux, l)) = Typ_fn (List.map (locate_typ f) arg_typs, locate_typ f ret_typ, locate_effect f effect) | Typ_bidir (typ1, typ2) -> Typ_bidir (locate_typ f typ1, locate_typ f typ2) | Typ_tup typs -> Typ_tup (List.map (locate_typ f) typs) - | Typ_exist (kids, constr, typ) -> Typ_exist (List.map (locate_kid f) kids, locate_nc f constr, locate_typ f typ) + | Typ_exist (kopts, constr, typ) -> Typ_exist (List.map (locate_kinded_id f) kopts, locate_nc f constr, locate_typ f typ) | Typ_app (id, typ_args) -> Typ_app (locate_id f id, List.map (locate_typ_arg f) typ_args) in Typ_aux (typ_aux, f l) @@ -1449,6 +1491,7 @@ and locate_typ_arg f (A_aux (typ_arg_aux, l)) = | A_nexp nexp -> A_nexp (locate_nexp f nexp) | A_typ typ -> A_typ (locate_typ f typ) | A_order ord -> A_order (locate_order f ord) + | A_bool nc -> A_bool (locate_nc f nc) in A_aux (typ_arg_aux, f l) @@ -1638,8 +1681,10 @@ and typ_subst_aux sv subst = function | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst sv subst typ1, typ_subst sv subst typ2) | Typ_tup typs -> Typ_tup (List.map (typ_subst sv subst) typs) | Typ_app (f, args) -> Typ_app (f, List.map (typ_arg_subst sv subst) args) - | Typ_exist (kids, nc, typ) when KidSet.mem sv (KidSet.of_list kids) -> Typ_exist (kids, nc, typ) - | Typ_exist (kids, nc, typ) -> Typ_exist (kids, constraint_subst sv subst nc, typ_subst sv subst typ) + | Typ_exist (kopts, nc, typ) when KidSet.mem sv (KidSet.of_list (List.map kopt_kid kopts)) -> + Typ_exist (kopts, nc, typ) + | Typ_exist (kopts, nc, typ) -> + Typ_exist (kopts, constraint_subst sv subst nc, typ_subst sv subst typ) and typ_arg_subst sv subst (A_aux (arg, l)) = A_aux (typ_arg_subst_aux sv subst arg, l) and typ_arg_subst_aux sv subst = function diff --git a/src/ast_util.mli b/src/ast_util.mli index 8155acde..ca3a9598 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -180,6 +180,7 @@ val arg_nexp : ?loc:l -> nexp -> typ_arg val arg_order : ?loc:l -> order -> typ_arg val arg_typ : ?loc:l -> typ -> typ_arg val arg_bool : ?loc:l -> n_constraint -> typ_arg +val arg_kopt : kinded_id -> typ_arg (* Functions for working with type quantifiers *) val quant_add : quant_item -> typquant -> typquant @@ -260,6 +261,16 @@ module Kid : sig val compare : kid -> kid -> int end +module Kind : sig + type t = kind + val compare : kind -> kind -> int +end + +module KOpt : sig + type t = kinded_id + val compare : kinded_id -> kinded_id -> int +end + module Nexp : sig type t = nexp val compare : nexp -> nexp -> int @@ -288,6 +299,14 @@ module NexpMap : sig include Map.S with type key = nexp end +module KOptSet : sig + include Set.S with type elt = kinded_id +end + +module KOptMap : sig + include Map.S with type key = kinded_id +end + module BESet : sig include Set.S with type elt = base_effect end diff --git a/src/initial_check.ml b/src/initial_check.ml index 62f0dcf4..d394fde9 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -160,7 +160,7 @@ let rec to_ast_typ ctx (P.ATyp_aux (aux, l)) = | P.ATyp_exist (kids, nc, atyp) -> let kids = List.map to_ast_var kids in let ctx = { ctx with kinds = List.fold_left (fun kinds kid -> KBindings.add kid K_int kinds) ctx.kinds kids } in - Typ_exist (kids, to_ast_constraint ctx nc, to_ast_typ ctx atyp) + Typ_exist (List.map (mk_kopt K_int) kids, to_ast_constraint ctx nc, to_ast_typ ctx atyp) | P.ATyp_base (id, kind, nc) -> raise (Reporting.err_unreachable l __POS__ "TODO") | _ -> raise (Reporting.err_typ l "Invalid type") @@ -984,7 +984,7 @@ let generate_enum_functions vs_ids (Defs defs) = (* Create a function that converts from an enum to a number. *) let from_enum = let kid = mk_kid "e" in - let to_typ = mk_typ (Typ_exist ([kid], range_constraint kid, atom_typ (nvar kid))) in + let to_typ = mk_typ (Typ_exist ([mk_kopt K_int kid], range_constraint kid, atom_typ (nvar kid))) in let name = prepend_id "num_of_" id in let pexp n id = mk_pexp (Pat_exp (mk_pat (P_id id), mk_lit_exp (L_num (Big_int.of_int n)))) in let funcl = diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 9206332d..0e362d3b 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -139,9 +139,9 @@ let subst_src_typ substs t = | Typ_bidir (t1, t2) -> re (Typ_bidir (s_styp substs t1, s_styp substs t2)) | Typ_tup ts -> re (Typ_tup (List.map (s_styp substs) ts)) | Typ_app (id,tas) -> re (Typ_app (id,List.map (s_starg substs) tas)) - | Typ_exist (kids,nc,t) -> - let substs = List.fold_left (fun sub v -> KBindings.remove v sub) substs kids in - re (Typ_exist (kids,nc,s_styp substs t)) + | Typ_exist (kopts,nc,t) -> + let substs = List.fold_left (fun sub kopt -> KBindings.remove (kopt_kid kopt) sub) substs kopts in + re (Typ_exist (kopts,nc,s_styp substs t)) | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" and s_starg substs (A_aux (ta,l) as targ) = match ta with @@ -330,13 +330,15 @@ let rec inst_src_type insts (Typ_aux (ty,l) as typ) = (fun arg (insts,args) -> let insts,arg = inst_src_typ_arg insts arg in insts,arg::args) args (insts,[]) in insts, Typ_aux (Typ_app (id,ts),l) - | Typ_exist (kids, nc, t) -> begin + | Typ_exist (kopts, nc, t) -> begin + (* TODO handle non-integer existentials *) + let kids = List.map kopt_kid kopts in let kid_insts, insts' = peel (kids,insts) in let kids', t' = apply_kid_insts kid_insts t in (* TODO: subst in nc *) match kids' with | [] -> insts', t' - | _ -> insts', Typ_aux (Typ_exist (kids', nc, t'), l) + | _ -> insts', Typ_aux (Typ_exist (List.map (mk_kopt K_int) kids', nc, t'), l) end | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" and inst_src_typ_arg insts (A_aux (ta,l) as tyarg) = @@ -408,7 +410,9 @@ let split_src_type id ty (TypQ_aux (q,ql)) = | Typ_app (_, tas) -> (KidSet.empty,[[],typ]) (* We only support sizes for bitvectors mentioned explicitly, not any buried inside another type *) - | Typ_exist (kids, nc, t) -> + | Typ_exist (kopts, nc, t) -> + (* TODO handle non integer existentials *) + let kids = List.map kopt_kid kopts in let (vars,tys) = size_nvars_ty t in let find_insts k (insts,nc) = let inst,nc' = @@ -426,7 +430,7 @@ let split_src_type id ty (TypQ_aux (q,ql)) = (* Typ_exist is not allowed an empty list of kids *) match kids with | [] -> ty - | _ -> Typ_aux (Typ_exist (kids, nc', ty),l) + | _ -> Typ_aux (Typ_exist (List.map (mk_kopt K_int) kids, nc', ty),l) in inst@inst0, ty in let tys = List.concat (List.map (fun instty -> List.map (ty_and_inst instty) insts) tys) in @@ -524,7 +528,9 @@ let refine_constructor refinements l env id args = let arg_ty = typ_of_args args in match Type_check.destruct_exist (Type_check.Env.expand_synonyms env constr_ty) with | None -> None - | Some (kids,nc,constr_ty) -> + | Some (kopts,nc,constr_ty) -> + (* TODO: Handle non-integer existentials *) + let kids = List.map kopt_kid kopts in let bindings = Type_check.unify l env (tyvars_of_typ constr_ty) constr_ty arg_ty in let find_kid kid = try Some (KBindings.find kid bindings) with Not_found -> None in let bindings = List.map find_kid kids in @@ -730,7 +736,8 @@ let fabricate_nexp l tannot = | Some (env,typ,_) -> match Type_check.destruct_exist (Type_check.Env.expand_synonyms env typ) with | None -> nint 32 - | Some (kids,nc,typ') -> fabricate_nexp_exist env l typ kids nc typ' + (* TODO: check this *) + | Some (kopts,nc,typ') -> fabricate_nexp_exist env l typ (List.map kopt_kid kopts) nc typ' let atom_typ_kid kid = function | Typ_aux (Typ_app (Id_aux (Id "atom",_), @@ -746,23 +753,23 @@ let reduce_cast typ exp l annot = let env = env_of_annot (l,annot) in let typ' = Env.base_typ_of env typ in match exp, destruct_exist (Env.expand_synonyms env typ') with - | E_aux (E_lit (L_aux (L_num n,_)),_), Some ([kid],nc,typ'') when atom_typ_kid kid typ'' -> - let nc_env = Env.add_typ_var l kid K_int env in - let nc_env = Env.add_constraint (nc_eq (nvar kid) (nconstant n)) nc_env in + | E_aux (E_lit (L_aux (L_num n,_)),_), Some ([kopt],nc,typ'') when atom_typ_kid (kopt_kid kopt) typ'' -> + let nc_env = Env.add_typ_var l kopt env in + let nc_env = Env.add_constraint (nc_eq (nvar (kopt_kid kopt)) (nconstant n)) nc_env in if prove nc_env nc then exp else raise (Reporting.err_unreachable l __POS__ ("Constant propagation error: literal " ^ Big_int.to_string n ^ " does not satisfy constraint " ^ string_of_n_constraint nc)) - | E_aux (E_lit (L_aux (L_undef,_)),_), Some ([kid],nc,typ'') when atom_typ_kid kid typ'' -> - let nexp = fabricate_nexp_exist env Unknown typ [kid] nc typ'' in - let newtyp = subst_src_typ (KBindings.singleton kid nexp) typ'' in + | E_aux (E_lit (L_aux (L_undef,_)),_), Some ([kopt],nc,typ'') when atom_typ_kid (kopt_kid kopt) typ'' -> + let nexp = fabricate_nexp_exist env Unknown typ [kopt_kid kopt] nc typ'' in + let newtyp = subst_src_typ (KBindings.singleton (kopt_kid kopt) nexp) typ'' in E_aux (E_cast (newtyp, exp), (Generated l,replace_typ newtyp annot)) | E_aux (E_cast (_, (E_aux (E_lit (L_aux (L_undef,_)),_) as exp)),_), - Some ([kid],nc,typ'') when atom_typ_kid kid typ'' -> - let nexp = fabricate_nexp_exist env Unknown typ [kid] nc typ'' in - let newtyp = subst_src_typ (KBindings.singleton kid nexp) typ'' in + Some ([kopt],nc,typ'') when atom_typ_kid (kopt_kid kopt) typ'' -> + let nexp = fabricate_nexp_exist env Unknown typ [kopt_kid kopt] nc typ'' in + let newtyp = subst_src_typ (KBindings.singleton (kopt_kid kopt) nexp) typ'' in E_aux (E_cast (newtyp, exp), (Generated l,replace_typ newtyp annot)) | _ -> E_aux (E_cast (typ,exp),(l,annot)) @@ -2185,8 +2192,8 @@ let rec sizes_of_typ (Typ_aux (t,l)) = "Function type on expression") | Typ_bidir _ -> raise (Reporting.err_general l "Mapping type on expression") | Typ_tup typs -> kidset_bigunion (List.map sizes_of_typ typs) - | Typ_exist (kids,_,typ) -> - List.fold_left (fun s k -> KidSet.remove k s) (sizes_of_typ typ) kids + | Typ_exist (kopts,_,typ) -> + List.fold_left (fun s k -> KidSet.remove (kopt_kid k) s) (sizes_of_typ typ) kopts | Typ_app (Id_aux (Id "vector",_), [A_aux (A_nexp size,_); _;A_aux (A_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) -> @@ -3184,11 +3191,11 @@ let rec analyse_exp fn_id env assigns (E_aux (e,(l,annot)) as exp) = let env, tenv, typ = match destruct_exist (Env.expand_synonyms tenv typ) with | None -> env, tenv, typ - | Some (kids, nc, typ) -> + | Some (kopts, nc, typ) -> { env with kid_deps = - List.fold_left (fun kds kid -> KBindings.add kid deps kds) env.kid_deps kids }, + List.fold_left (fun kds kopt -> KBindings.add (kopt_kid kopt) deps kds) env.kid_deps kopts }, Env.add_constraint nc - (List.fold_left (fun tenv kid -> Env.add_typ_var l kid K_int tenv) tenv kids), + (List.fold_left (fun tenv kopt -> Env.add_typ_var l kopt tenv) tenv kopts), typ in if is_bitvector_typ typ then diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index 6dfc6191..35aa9e20 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -277,7 +277,7 @@ let rec coq_nvars_of_typ (Typ_aux (t,l)) = List.fold_left (fun s ta -> KidSet.union s (coq_nvars_of_typ_arg ta)) KidSet.empty tas (* TODO: remove appropriate bound variables *) - | Typ_exist (kids,_,t) -> trec t + | Typ_exist (_,_,t) -> trec t | Typ_bidir _ -> unreachable l __POS__ "Coq doesn't support bidir types" | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" and coq_nvars_of_typ_arg (A_aux (ta,_)) = @@ -359,11 +359,11 @@ let maybe_expand_range_type (Typ_aux (typ,l) as full_typ) = let kid = mk_kid "rangevar" in let var = nvar kid in let nc = nc_and (nc_lteq low var) (nc_lteq var high) in - Some (Typ_aux (Typ_exist ([kid], nc, atom_typ var),Parse_ast.Generated l)) + Some (Typ_aux (Typ_exist ([mk_kopt K_int kid], nc, atom_typ var),Parse_ast.Generated l)) | Typ_id (Id_aux (Id "nat",_)) -> let kid = mk_kid "n" in let var = nvar kid in - Some (Typ_aux (Typ_exist ([kid], nc_gteq var (nconstant Nat_big_num.zero), atom_typ var), + Some (Typ_aux (Typ_exist ([mk_kopt K_int kid], nc_gteq var (nconstant Nat_big_num.zero), atom_typ var), Parse_ast.Generated l)) | _ -> None @@ -449,24 +449,25 @@ let doc_typ, doc_atomic_typ = * if we add a new Typ constructor *) let tpp = typ ty in if atyp_needed then parens tpp else tpp - | Typ_exist (kids,nc,ty') -> begin - let kids,nc,ty' = match maybe_expand_range_type ty' with - | Some (Typ_aux (Typ_exist (kids',nc',ty'),_)) -> - kids'@kids,nc_and nc nc',ty' - | _ -> kids,nc,ty' + (* TODO: handle non-integer kopts *) + | Typ_exist (kopts,nc,ty') -> begin + let kopts,nc,ty' = match maybe_expand_range_type ty' with + | Some (Typ_aux (Typ_exist (kopts',nc',ty'),_)) -> + kopts'@kopts,nc_and nc nc',ty' + | _ -> kopts,nc,ty' in match ty' with | Typ_aux (Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp nexp,_)]),_) -> - begin match nexp, kids with - | (Nexp_aux (Nexp_var kid,_)), [kid'] when Kid.compare kid kid' == 0 -> + begin match nexp, kopts with + | (Nexp_aux (Nexp_var kid,_)), [kopt] when Kid.compare kid (kopt_kid kopt) == 0 -> braces (separate space [doc_var ctx kid; colon; string "Z"; ampersand; doc_arithfact ctx nc]) | _ -> let var = mk_kid "_atom" in (* TODO collision avoid *) let nc = nice_and (nc_eq (nvar var) nexp) nc in braces (separate space [doc_var ctx var; colon; string "Z"; - ampersand; doc_arithfact ctx ~exists:kids nc]) + ampersand; doc_arithfact ctx ~exists:(List.map kopt_kid kopts) nc]) end | Typ_aux (Typ_app (Id_aux (Id "vector",_), [A_aux (A_nexp m, _); @@ -474,7 +475,7 @@ let doc_typ, doc_atomic_typ = A_aux (A_typ elem_typ, _)]),_) -> (* TODO: proper handling of m, complex elem type, dedup with above *) let var = mk_kid "_vec" in (* TODO collision avoid *) - let kid_set = KidSet.of_list kids in + let kid_set = KidSet.of_list (List.map kopt_kid kopts) in let m_pp = doc_nexp ctx ~skip_vars:kid_set m in let tpp, len_pp = match elem_typ with | Typ_aux (Typ_id (Id_aux (Id "bit",_)),_) -> @@ -489,7 +490,7 @@ let doc_typ, doc_atomic_typ = braces (separate space [doc_var ctx var; colon; tpp; ampersand; - doc_arithfact ctx ~exists:kids ?extra:length_constraint_pp nc]) + doc_arithfact ctx ~exists:(List.map kopt_kid kopts) ?extra:length_constraint_pp nc]) | _ -> raise (Reporting.err_todo l ("Non-atom existential type not yet supported in Coq: " ^ @@ -858,7 +859,7 @@ let replace_atom_return_type ret_typ = match ret_typ with | Typ_aux (Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp nexp,_)]),l) -> let kid = mk_kid "_retval" in (* TODO: collision avoidance *) - true, Typ_aux (Typ_exist ([kid], nc_eq (nvar kid) nexp, atom_typ (nvar kid)),Parse_ast.Generated l) + true, Typ_aux (Typ_exist ([mk_kopt K_int kid], nc_eq (nvar kid) nexp, atom_typ (nvar kid)),Parse_ast.Generated l) | _ -> false, ret_typ let is_range_from_atom env (Typ_aux (argty,_)) (Typ_aux (fnty,_)) = @@ -2031,15 +2032,15 @@ let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) = when not (is_enum env id) -> begin let full_typ = (expand_range_type exp_typ) in match destruct_exist (Env.expand_synonyms env full_typ) with - | Some ([kid], NC_aux (NC_true,_), + | Some ([kopt], NC_aux (NC_true,_), Typ_aux (Typ_app (Id_aux (Id "atom",_), - [A_aux (A_nexp (Nexp_aux (Nexp_var kid',_)),_)]),_)) - when Kid.compare kid kid' == 0 -> + [A_aux (A_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_)) + when Kid.compare (kopt_kid kopt) kid == 0 -> parens (separate space [doc_id id; colon; string "Z"]) - | Some ([kid], nc, + | Some ([kopt], nc, Typ_aux (Typ_app (Id_aux (Id "atom",_), - [A_aux (A_nexp (Nexp_aux (Nexp_var kid',_)),_)]),_)) - when Kid.compare kid kid' == 0 -> + [A_aux (A_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_)) + when Kid.compare (kopt_kid kopt) kid == 0 -> (used_a_pattern := true; squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt typ])) | _ -> diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 6e2a2b55..ac0195aa 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -313,7 +313,8 @@ let doc_typ_lem, doc_atomic_typ_lem = * if we add a new Typ constructor *) let tpp = typ ty in if atyp_needed then parens tpp else tpp - | Typ_exist (kids,_,ty) -> begin + | Typ_exist (kopts,_,ty) when List.for_all is_nat_kopt kopts -> begin + let kids = List.map kopt_kid kopts in let tpp = typ ty in let visible_vars = lem_tyvars_of_typ ty in match List.filter (fun kid -> KidSet.mem kid visible_vars) kids with @@ -323,6 +324,7 @@ let doc_typ_lem, doc_atomic_typ_lem = String.concat ", " (List.map string_of_kid bad) ^ " escape into Lem")) end + | Typ_exist _ -> unreachable l __POS__ "Non-integer existentials currently unsupported in Lem" (* TODO *) | Typ_bidir _ -> unreachable l __POS__ "Lem doesn't support bidir types" | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" and doc_typ_arg_lem (A_aux(t,_)) = match t with @@ -528,7 +530,8 @@ let rec typ_needs_printed (Typ_aux (t,_) as typ) = match t with | Typ_app (Id_aux (Id "itself",_),_) -> true | Typ_app (_, targs) -> is_bitvector_typ typ || List.exists typ_needs_printed_arg targs | Typ_fn (ts,t,_) -> List.exists typ_needs_printed ts || typ_needs_printed t - | Typ_exist (kids,_,t) -> + | Typ_exist (kopts,_,t) -> + let kids = List.map kopt_kid kopts in (* TODO: Check this *) let visible_kids = KidSet.inter (KidSet.of_list kids) (lem_tyvars_of_typ t) in typ_needs_printed t && KidSet.is_empty visible_kids | _ -> false diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index d779b3a7..e0223105 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -65,6 +65,12 @@ let doc_id (Id_aux (id_aux, _)) = let doc_kid kid = string (Ast_util.string_of_kid kid) +let doc_kopt = function + | kopt when is_nat_kopt kopt -> doc_kid (kopt_kid kopt) + | kopt when is_typ_kopt kopt -> parens (separate space [doc_kid (kopt_kid kopt); colon; string "Type"]) + | kopt when is_order_kopt kopt -> parens (separate space [doc_kid (kopt_kid kopt); colon; string "Order"]) + | kopt -> parens (separate space [doc_kid (kopt_kid kopt); colon; string "Bool"]) + let doc_int n = string (Big_int.to_string n) let docstring (l, _) = match l with @@ -166,13 +172,13 @@ and doc_typ ?(simple=false) (Typ_aux (typ_aux, l)) = | Typ_tup typs -> parens (separate_map (string ", ") doc_typ typs) | Typ_var kid -> doc_kid kid (* Resugar set types like {|1, 2, 3|} *) - | Typ_exist ([kid1], - NC_aux (NC_set (kid2, ints), _), - Typ_aux (Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var kid3, _)), _)]), _)) - when Kid.compare kid1 kid2 == 0 && Kid.compare kid2 kid3 == 0 && Id.compare (mk_id "atom") id == 0 -> + | Typ_exist ([kopt], + NC_aux (NC_set (kid1, ints), _), + Typ_aux (Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var kid2, _)), _)]), _)) + when Kid.compare (kopt_kid kopt) kid1 == 0 && Kid.compare kid1 kid2 == 0 && Id.compare (mk_id "atom") id == 0 -> enclose (string "{|") (string "|}") (separate_map (string ", ") doc_int ints) - | Typ_exist (kids, nc, typ) -> - braces (separate_map space doc_kid kids ^^ comma ^^ space ^^ doc_nc nc ^^ dot ^^ space ^^ doc_typ typ) + | Typ_exist (kopts, nc, typ) -> + braces (separate_map space doc_kopt kopts ^^ comma ^^ space ^^ doc_nc nc ^^ dot ^^ space ^^ doc_typ typ) | Typ_fn (typs, typ, Effect_aux (Effect_set [], _)) -> separate space [doc_arg_typs typs; string "->"; doc_typ typ] | Typ_fn (typs, typ, Effect_aux (Effect_set effs, _)) -> @@ -194,11 +200,6 @@ and doc_arg_typs = function | [typ] -> doc_typ typ | typs -> parens (separate_map (comma ^^ space) doc_typ typs) -let doc_kopt = function - | kopt when is_nat_kopt kopt -> doc_kid (kopt_kid kopt) - | kopt when is_typ_kopt kopt -> parens (separate space [doc_kid (kopt_kid kopt); colon; string "Type"]) - | kopt -> parens (separate space [doc_kid (kopt_kid kopt); colon; string "Order"]) - let doc_quants quants = let doc_qi_kopt (QI_aux (qi_aux, _)) = match qi_aux with diff --git a/src/rewrites.ml b/src/rewrites.ml index d8f1af75..79b5f619 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -3745,7 +3745,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = (* Bind the loop variable in the body, annotated with constraints *) let lvar_kid = mk_kid ("loop_" ^ string_of_id id) in let lvar_nc = nc_and constr (nc_and (nc_lteq lower (nvar lvar_kid)) (nc_lteq (nvar lvar_kid) upper)) in - let lvar_typ = mk_typ (Typ_exist (lvar_kid :: kids, lvar_nc, atom_typ (nvar lvar_kid))) in + let lvar_typ = mk_typ (Typ_exist (List.map (mk_kopt K_int) (lvar_kid :: kids), lvar_nc, atom_typ (nvar lvar_kid))) in let lvar_pat = unaux_pat (add_p_typ lvar_typ (annot_pat (P_var ( annot_pat (P_id id) el env (atom_typ (nvar lvar_kid)), TP_aux (TP_var lvar_kid, gen_loc el))) el env lvar_typ)) in diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml index 65614b8d..0f8db0ff 100644 --- a/src/spec_analysis.ml +++ b/src/spec_analysis.ml @@ -94,7 +94,7 @@ let rec free_type_names_t consider_var (Typ_aux (t, l)) = match t with (free_type_names_t consider_var t2) | Typ_tup ts -> free_type_names_ts consider_var ts | Typ_app (name,targs) -> Nameset.add (string_of_id name) (free_type_names_t_args consider_var targs) - | Typ_exist (kids,_,t') -> List.fold_left (fun s kid -> Nameset.remove (string_of_kid kid) s) (free_type_names_t consider_var t') kids + | Typ_exist (kopts,_,t') -> List.fold_left (fun s kopt -> Nameset.remove (string_of_kid (kopt_kid kopt)) s) (free_type_names_t consider_var t') kopts | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" and free_type_names_ts consider_var ts = nameset_bigunion (List.map (free_type_names_t consider_var) ts) and free_type_names_maybe_t consider_var = function @@ -126,7 +126,10 @@ let rec fv_of_typ consider_var bound used (Typ_aux (t,l)) : Nameset.t = | Typ_tup ts -> List.fold_right (fun t n -> fv_of_typ consider_var bound n t) ts used | Typ_app(id,targs) -> List.fold_right (fun ta n -> fv_of_targ consider_var bound n ta) targs (conditional_add_typ bound used id) - | Typ_exist (kids,_,t') -> fv_of_typ consider_var (List.fold_left (fun b (Kid_aux (Var v,_)) -> Nameset.add v b) bound kids) used t' + | Typ_exist (kopts,_,t') -> + fv_of_typ consider_var + (List.fold_left (fun b (KOpt_aux (KOpt_kind (_, (Kid_aux (Var v,_))), _)) -> Nameset.add v b) bound kopts) + used t' | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" and fv_of_targ consider_var bound used (Ast.A_aux(targ,_)) : Nameset.t = match targ with diff --git a/src/specialize.ml b/src/specialize.ml index 583de600..1ba57bd0 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -100,13 +100,13 @@ let rec polymorphic_functions is_kopt (Defs defs) = let string_of_instantiation instantiation = let open Type_check in - let kid_names = ref KBindings.empty in + let kid_names = ref KOptMap.empty in let kid_counter = ref 0 in let kid_name kid = - try KBindings.find kid !kid_names with + try KOptMap.find kid !kid_names with | Not_found -> begin let n = string_of_int !kid_counter in - kid_names := KBindings.add kid n !kid_names; + kid_names := KOptMap.add kid n !kid_names; incr kid_counter; n end @@ -117,7 +117,7 @@ let string_of_instantiation instantiation = | Nexp_aux (nexp, _) -> string_of_nexp_aux nexp and string_of_nexp_aux = function | Nexp_id id -> string_of_id id - | Nexp_var kid -> kid_name kid + | Nexp_var kid -> kid_name (mk_kopt K_int kid) | Nexp_constant c -> Big_int.to_string c | Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")" | Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")" @@ -131,7 +131,7 @@ let string_of_instantiation instantiation = | Typ_aux (typ, l) -> string_of_typ_aux typ and string_of_typ_aux = function | Typ_id id -> string_of_id id - | Typ_var kid -> kid_name kid + | Typ_var kid -> kid_name (mk_kopt K_type kid) | Typ_tup typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")" | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")" | Typ_fn (arg_typs, ret_typ, eff) -> @@ -158,7 +158,7 @@ let string_of_instantiation instantiation = | NC_aux (NC_and (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" | NC_aux (NC_set (kid, ns), _) -> - kid_name kid ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}" + kid_name (mk_kopt K_int kid) ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}" | NC_aux (NC_true, _) -> "true" | NC_aux (NC_false, _) -> "false" in @@ -249,7 +249,7 @@ let rec typ_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) = | Typ_var kid -> KidSet.singleton kid | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_frees ~exs:exs) typs) | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_frees ~exs:exs) args) - | Typ_exist (kids, nc, typ) -> typ_frees ~exs:(KidSet.of_list kids) typ + | Typ_exist (kopts, nc, typ) -> typ_frees ~exs:(KidSet.of_list (List.map kopt_kid kopts)) typ | Typ_fn (arg_typs, ret_typ, _) -> List.fold_left KidSet.union (typ_frees ~exs:exs ret_typ) (List.map (typ_frees ~exs:exs) arg_typs) | Typ_bidir (t1, t2) -> KidSet.union (typ_frees ~exs:exs t1) (typ_frees ~exs:exs t2) @@ -266,7 +266,7 @@ let rec typ_int_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) = | Typ_var kid -> KidSet.empty | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_int_frees ~exs:exs) typs) | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_int_frees ~exs:exs) args) - | Typ_exist (kids, nc, typ) -> typ_int_frees ~exs:(KidSet.of_list kids) typ + | Typ_exist (kopts, nc, typ) -> typ_int_frees ~exs:(KidSet.of_list (List.map kopt_kid kopts)) typ | Typ_fn (arg_typs, ret_typ, _) -> List.fold_left KidSet.union (typ_int_frees ~exs:exs ret_typ) (List.map (typ_int_frees ~exs:exs) arg_typs) | Typ_bidir (t1, t2) -> KidSet.union (typ_int_frees ~exs:exs t1) (typ_int_frees ~exs:exs t2) diff --git a/src/type_check.ml b/src/type_check.ml index f1d9b961..5774a46f 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -194,7 +194,8 @@ and strip_typ_aux : typ_aux -> typ_aux = function | Typ_fn (arg_typs, ret_typ, effect) -> Typ_fn (List.map strip_typ arg_typs, strip_typ ret_typ, strip_effect effect) | Typ_bidir (typ1, typ2) -> Typ_bidir (strip_typ typ1, strip_typ typ2) | Typ_tup typs -> Typ_tup (List.map strip_typ typs) - | Typ_exist (kids, constr, typ) -> Typ_exist ((List.map strip_kid kids), strip_n_constraint constr, strip_typ typ) + | Typ_exist (kopts, constr, typ) -> + Typ_exist ((List.map strip_kinded_id kopts), strip_n_constraint constr, strip_typ typ) | Typ_app (id, args) -> Typ_app (strip_id id, List.map strip_typ_arg args) and strip_typ : typ -> typ = function | Typ_aux (typ_aux, _) -> Typ_aux (strip_typ_aux typ_aux, Parse_ast.Unknown) @@ -216,17 +217,21 @@ and strip_kind = function let ex_counter = ref 0 -let fresh_existential ?name:(n="") () = +let fresh_existential ?name:(n="") k = let fresh = Kid_aux (Var ("'ex" ^ string_of_int !ex_counter ^ "#" ^ n), Parse_ast.Unknown) in - incr ex_counter; fresh + incr ex_counter; mk_kopt k fresh let destruct_exist' typ = match typ with - | Typ_aux (Typ_exist (kids, nc, typ), _) -> - let fresh_kids = List.map (fun kid -> (kid, fresh_existential ~name:(string_of_id (id_of_kid kid)) ())) kids in - let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_nexp (nvar fresh)) nc) nc fresh_kids in - let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_nexp (nvar fresh)) typ) typ fresh_kids in - Some (List.map snd fresh_kids, nc, typ) + | Typ_aux (Typ_exist (kopts, nc, typ), _) -> + let fresh_kopts = + List.map (fun kopt -> (kopt_kid kopt, + fresh_existential ~name:(string_of_id (id_of_kid (kopt_kid kopt))) (unaux_kind (kopt_kind kopt)))) + kopts + in + let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_kopt fresh) nc) nc fresh_kopts in + let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_kopt fresh) typ) typ fresh_kopts in + Some (List.map snd fresh_kopts, nc, typ) | _ -> None (** Destructure and canonicalise a numeric type into a list of type @@ -240,23 +245,23 @@ let destruct_exist' typ = let destruct_numeric typ = match destruct_exist' typ, typ with | Some (kids, nc, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _)), _ when string_of_id id = "atom" -> - Some (kids, nc, nexp) + Some (List.map kopt_kid kids, nc, nexp) | None, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _) when string_of_id id = "atom" -> Some ([], nc_true, nexp) | None, Typ_aux (Typ_app (id, [A_aux (A_nexp lo, _); A_aux (A_nexp hi, _)]), _) when string_of_id id = "range" -> - let kid = fresh_existential () in + let kid = kopt_kid (fresh_existential K_int) in Some ([kid], nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi), nvar kid) | None, Typ_aux (Typ_id id, _) when string_of_id id = "nat" -> - let kid = fresh_existential () in + let kid = kopt_kid (fresh_existential K_int) in Some ([kid], nc_lteq (nint 0) (nvar kid), nvar kid) | None, Typ_aux (Typ_id id, _) when string_of_id id = "int" -> - let kid = fresh_existential () in + let kid = kopt_kid (fresh_existential K_int) in Some ([kid], nc_true, nvar kid) | _, _ -> None let destruct_exist typ = match destruct_numeric typ with - | Some (kids, nc, nexp) -> Some (kids, nc, atom_typ nexp) + | Some (kids, nc, nexp) -> Some (List.map (mk_kopt K_int) kids, nc, atom_typ nexp) | None -> destruct_exist' typ @@ -303,7 +308,7 @@ module Env : sig val get_typ_var_loc : kid -> t -> Ast.l val get_typ_vars : t -> kind_aux KBindings.t val get_typ_var_locs : t -> Ast.l KBindings.t - val add_typ_var : l -> kid -> kind_aux -> t -> t + val add_typ_var : l -> kinded_id -> t -> t val get_ret_typ : t -> typ option val add_ret_typ : typ -> t -> t val add_typ_synonym : id -> (t -> typ_arg list -> typ_arg) -> t -> t @@ -545,7 +550,7 @@ end = struct end with | Not_found -> Typ_aux (Typ_id id, l)) - | Typ_exist (kids, nc, typ) -> + | Typ_exist (kopts, nc, typ) -> (* When expanding an existential synonym we need to take care to add the type variables and constraints to the environment, so we can check constraints attached to type @@ -554,24 +559,27 @@ end = struct scope while doing this. *) let rebindings = ref [] in - let rename_kid kid = if KBindings.mem kid env.typ_vars then prepend_kid "syn#" kid else kid in - let add_typ_var env kid = + let rename_kopt (KOpt_aux (KOpt_kind (k, kid), l) as kopt) = + if KBindings.mem kid env.typ_vars then + KOpt_aux (KOpt_kind (k, prepend_kid "syn#" kid), l) + else kopt + in + let add_typ_var env (KOpt_aux (KOpt_kind (k, kid), l) as kopt) = try let (l, _) = KBindings.find kid env.typ_vars in rebindings := kid :: !rebindings; - { env with typ_vars = KBindings.add (prepend_kid "syn#" kid) (l, K_int) env.typ_vars } + { env with typ_vars = KBindings.add (prepend_kid "syn#" kid) (l, unaux_kind k) env.typ_vars } with | Not_found -> - { env with typ_vars = KBindings.add kid (l, K_int) env.typ_vars } + { env with typ_vars = KBindings.add kid (l, unaux_kind k) env.typ_vars } in - let env = List.fold_left add_typ_var env kids in - let kids = List.map rename_kid kids in + let env = List.fold_left add_typ_var env kopts in + let kopts = List.map rename_kopt kopts in let nc = List.fold_left (fun nc kid -> constraint_subst kid (arg_nexp (nvar (prepend_kid "syn#" kid))) nc) nc !rebindings in let typ = List.fold_left (fun typ kid -> typ_subst kid (arg_nexp (nvar (prepend_kid "syn#" kid))) typ) typ !rebindings in - typ_debug (lazy ("Synonym existential: {" ^ string_of_list " " string_of_kid kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}")); let env = { env with constraints = nc :: env.constraints } in - Typ_aux (Typ_exist (kids, nc, expand_synonyms env typ), l) + Typ_aux (Typ_exist (kopts, nc, expand_synonyms env typ), l) | Typ_var v -> Typ_aux (Typ_var v, l) and expand_synonyms_arg env (A_aux (typ_arg, l)) = match typ_arg with @@ -623,9 +631,9 @@ end = struct check_args_typquant id env args (infer_kind env id) | Typ_app (id, _) -> typ_error l ("Undefined type " ^ string_of_id id) | Typ_exist ([], _, _) -> typ_error l ("Existential must have some type variables") - | Typ_exist (kids, nc, typ) when KidSet.is_empty exs -> - wf_constraint ~exs:(KidSet.of_list kids) env nc; - wf_typ ~exs:(KidSet.of_list kids) { env with constraints = nc :: env.constraints } typ + | Typ_exist (kopts, nc, typ) when KidSet.is_empty exs -> + wf_constraint ~exs:(KidSet.of_list (List.map kopt_kid kopts)) env nc; + wf_typ ~exs:(KidSet.of_list (List.map kopt_kid kopts)) { env with constraints = nc :: env.constraints } typ | Typ_exist (_, _, _) -> typ_error l ("Nested existentials are not allowed") | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" and wf_typ_arg ?exs:(exs=KidSet.empty) env (A_aux (typ_arg_aux, _)) = @@ -748,7 +756,7 @@ end = struct let existential_arg typq = function | None -> typq | Some (exs, nc, _) -> - List.fold_left (fun typq kid -> quant_add (mk_qi_id K_int kid) typq) (quant_add (mk_qi_nc nc) typq) exs + List.fold_left (fun typq kopt -> quant_add (mk_qi_kopt kopt) typq) (quant_add (mk_qi_nc nc) typq) exs in let typq = List.fold_left existential_arg typq base_args in let arg_typs = List.map2 (fun typ -> function Some (_, _, typ) -> typ | None -> typ) arg_typs base_args in @@ -999,9 +1007,9 @@ end = struct with | Not_found -> Unbound - let add_typ_var l kid k env = + let add_typ_var l (KOpt_aux (KOpt_kind (K_aux (k, _), kid), _) as kopt) env = if KBindings.mem kid env.typ_vars - then typ_error (kid_loc kid) ("type variable " ^ string_of_kid kid ^ " is already bound") + then typ_error (kid_loc kid) ("type variable " ^ string_of_kinded_id kopt ^ " is already bound") else begin typ_print (lazy (adding ^ "type variable " ^ string_of_kid kid ^ " : " ^ string_of_kind_aux k)); @@ -1107,7 +1115,7 @@ let add_typquant l (quant : typquant) (env : Env.t) : Env.t = | QI_aux (qi, _) -> add_quant_item_aux env qi and add_quant_item_aux env = function | QI_const constr -> Env.add_constraint constr env - | QI_id (KOpt_aux (KOpt_kind (K_aux (k, _), kid), _)) -> Env.add_typ_var l kid k env + | QI_id kopt -> Env.add_typ_var l kopt env in match quant with | TypQ_aux (TypQ_no_forall, _) -> env @@ -1123,24 +1131,24 @@ let default_order_error_string = let dvector_typ env n typ = vector_typ n (Env.get_default_order env) typ -let add_existential l kids nc env = - let env = List.fold_left (fun env kid -> Env.add_typ_var l kid K_int env) env kids in +let add_existential l kopts nc env = + let env = List.fold_left (fun env kopt -> Env.add_typ_var l kopt env) env kopts in Env.add_constraint nc env -let add_typ_vars l kids env = List.fold_left (fun env kid -> Env.add_typ_var l kid K_int env) env kids +let add_typ_vars l kopts env = List.fold_left (fun env kopt -> Env.add_typ_var l kopt env) env kopts let is_exist = function | Typ_aux (Typ_exist (_, _, _), _) -> true | _ -> false let exist_typ constr typ = - let fresh_kid = fresh_existential () in - mk_typ (Typ_exist ([fresh_kid], constr fresh_kid, typ fresh_kid)) + let fresh = fresh_existential K_int in + mk_typ (Typ_exist ([fresh], constr (kopt_kid fresh), typ (kopt_kid fresh))) let bind_numeric l typ env = match destruct_numeric (Env.expand_synonyms env typ) with | Some (kids, nc, nexp) -> - nexp, add_existential l kids nc env + nexp, add_existential l (List.map (mk_kopt K_int) kids) nc env | None -> typ_error l ("Expected " ^ string_of_typ typ ^ " to be numeric") (** Pull an (potentially)-existentially qualified type into the global @@ -1151,14 +1159,14 @@ let bind_existential l typ env = | None -> typ, env let destruct_range env typ = - let kids, constr, (Typ_aux (typ_aux, _)) = + let kopts, constr, (Typ_aux (typ_aux, _)) = Util.option_default ([], nc_true, typ) (destruct_exist (Env.expand_synonyms env typ)) in match typ_aux with | Typ_app (f, [A_aux (A_nexp n, _)]) - when string_of_id f = "atom" -> Some (kids, constr, n, n) + when string_of_id f = "atom" -> Some (List.map kopt_kid kopts, constr, n, n) | Typ_app (f, [A_aux (A_nexp n1, _); A_aux (A_nexp n2, _)]) - when string_of_id f = "range" -> Some (kids, constr, n1, n2) + when string_of_id f = "range" -> Some (List.map kopt_kid kopts, constr, n1, n2) | _ -> None let destruct_vector env typ = @@ -1312,7 +1320,7 @@ let rec typ_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) = | Typ_var kid -> KidSet.singleton kid | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_frees ~exs:exs) typs) | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_frees ~exs:exs) args) - | Typ_exist (kids, nc, typ) -> typ_frees ~exs:(KidSet.of_list kids) typ + | Typ_exist (kopts, nc, typ) -> typ_frees ~exs:(KidSet.of_list (List.map kopt_kid kopts)) typ | Typ_fn (arg_typs, ret_typ, _) -> List.fold_left KidSet.union (typ_frees ~exs:exs ret_typ) (List.map (typ_frees ~exs:exs) arg_typs) | Typ_bidir (typ1, typ2) -> KidSet.union (typ_frees ~exs:exs typ1) (typ_frees ~exs:exs typ2) and typ_arg_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) = @@ -1379,8 +1387,8 @@ let typ_identical env typ1 typ2 = try Id.compare f1 f2 = 0 && List.for_all2 typ_arg_identical args1 args2 with | Invalid_argument _ -> false end - | Typ_exist (kids1, nc1, typ1), Typ_exist (kids2, nc2, typ2) when List.length kids1 = List.length kids2 -> - List.for_all2 (fun k1 k2 -> Kid.compare k1 k2 = 0) kids1 kids2 && nc_identical nc1 nc2 && typ_identical' typ1 typ2 + | Typ_exist (kopts1, nc1, typ1), Typ_exist (kopts2, nc2, typ2) when List.length kopts1 = List.length kopts2 -> + List.for_all2 (fun k1 k2 -> KOpt.compare k1 k2 = 0) kopts1 kopts2 && nc_identical nc1 nc2 && typ_identical' typ1 typ2 | _, _ -> false and typ_arg_identical (A_aux (arg1, _)) (A_aux (arg2, _)) = match arg1, arg2 with @@ -1573,26 +1581,28 @@ let destruct_atom_kid env typ = only care about Int-kinded kids because those are the only type that can appear in an existential. *) -let rec kid_order_nexp kids (Nexp_aux (aux, l) as nexp) = +let rec kid_order_nexp kind_map (Nexp_aux (aux, l) as nexp) = match aux with - | Nexp_var kid when KidSet.mem kid kids -> ([kid], KidSet.remove kid kids) - | Nexp_var _ | Nexp_id _ | Nexp_constant _ -> ([], kids) - | Nexp_exp nexp | Nexp_neg nexp -> kid_order_nexp kids nexp + | Nexp_var kid when KBindings.mem kid kind_map -> + ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) + | Nexp_var _ | Nexp_id _ | Nexp_constant _ -> ([], kind_map) + | Nexp_exp nexp | Nexp_neg nexp -> kid_order_nexp kind_map nexp | Nexp_times (nexp1, nexp2) | Nexp_sum (nexp1, nexp2) | Nexp_minus (nexp1, nexp2) -> - let (ord, kids) = kid_order_nexp kids nexp1 in + let (ord, kids) = kid_order_nexp kind_map nexp1 in let (ord', kids) = kid_order_nexp kids nexp2 in (ord @ ord', kids) | Nexp_app (id, nexps) -> - List.fold_left (fun (ord, kids) nexp -> let (ord', kids) = kid_order_nexp kids nexp in (ord @ ord', kids)) ([], kids) nexps + List.fold_left (fun (ord, kids) nexp -> let (ord', kids) = kid_order_nexp kids nexp in (ord @ ord', kids)) ([], kind_map) nexps -let rec kid_order kids (Typ_aux (aux, l) as typ) = +let rec kid_order kind_map (Typ_aux (aux, l) as typ) = match aux with - | Typ_var kid when KidSet.mem kid kids -> ([kid], KidSet.remove kid kids) - | Typ_id _ | Typ_var _ -> ([], kids) + | Typ_var kid when KBindings.mem kid kind_map -> + ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) + | Typ_id _ | Typ_var _ -> ([], kind_map) | Typ_tup typs -> - List.fold_left (fun (ord, kids) typ -> let (ord', kids) = kid_order kids typ in (ord @ ord', kids)) ([], kids) typs + List.fold_left (fun (ord, kids) typ -> let (ord', kids) = kid_order kids typ in (ord @ ord', kids)) ([], kind_map) typs | Typ_app (_, args) -> - List.fold_left (fun (ord, kids) arg -> let (ord', kids) = kid_order_arg kids arg in (ord @ ord', kids)) ([], kids) args + List.fold_left (fun (ord, kids) arg -> let (ord', kids) = kid_order_arg kids arg in (ord @ ord', kids)) ([], kind_map) args | Typ_fn _ | Typ_bidir _ | Typ_exist _ -> typ_error l ("Existential or function type cannot appear within existential type: " ^ string_of_typ typ) | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" and kid_order_arg kids (A_aux (aux, l) as arg) = @@ -1613,13 +1623,14 @@ let rec alpha_equivalent env typ1 typ2 = | Typ_fn (arg_typs, ret_typ, eff) -> Typ_fn (List.map relabel arg_typs, relabel ret_typ, eff) | Typ_bidir (typ1, typ2) -> Typ_bidir (relabel typ1, relabel typ2) | Typ_tup typs -> Typ_tup (List.map relabel typs) - | Typ_exist (kids, nc, typ) -> - let (kids, _) = kid_order (KidSet.of_list kids) typ in - let kids = List.map (fun kid -> (kid, new_kid ())) kids in - let nc = List.fold_left (fun nc (kid, nk) -> constraint_subst kid (arg_nexp (nvar nk)) nc) nc kids in - let typ = List.fold_left (fun nc (kid, nk) -> typ_subst kid (arg_nexp (nvar nk)) nc) typ kids in - let kids = List.map snd kids in - Typ_exist (kids, nc, typ) + | Typ_exist (kopts, nc, typ) -> + let kind_map = List.fold_left (fun m kopt -> KBindings.add (kopt_kid kopt) (kopt_kind kopt) m) KBindings.empty kopts in + let (kopts, _) = kid_order kind_map typ in + let kopts = List.map (fun kopt -> (kopt_kid kopt, mk_kopt (unaux_kind (kopt_kind kopt)) (new_kid ()))) kopts in + let nc = List.fold_left (fun nc (kid, nk) -> constraint_subst kid (arg_kopt nk) nc) nc kopts in + let typ = List.fold_left (fun nc (kid, nk) -> typ_subst kid (arg_kopt nk) nc) typ kopts in + let kopts = List.map snd kopts in + Typ_exist (kopts, nc, typ) | Typ_app (id, args) -> Typ_app (id, List.map relabel_arg args) in @@ -1699,11 +1710,11 @@ let rec subtyp l env typ1 typ2 = | _, _ when alpha_equivalent env typ1 typ2 -> () (* Special cases for two numeric (atom) types *) | Some (kids1, nc1, nexp1), Some ([], _, nexp2) -> - let env = add_existential l kids1 nc1 env in + let env = add_existential l (List.map (mk_kopt K_int) kids1) nc1 env in if prove env (nc_eq nexp1 nexp2) then () else typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env)) | Some (kids1, nc1, nexp1), Some (kids2, nc2, nexp2) -> - let env = add_existential l kids1 nc1 env in - let env = add_typ_vars l (KidSet.elements (KidSet.inter (nexp_frees nexp2) (KidSet.of_list kids2))) env in + let env = add_existential l (List.map (mk_kopt K_int) kids1) nc1 env in + let env = add_typ_vars l (List.map (mk_kopt K_int) (KidSet.elements (KidSet.inter (nexp_frees nexp2) (KidSet.of_list kids2)))) env in let kids2 = KidSet.elements (KidSet.diff (KidSet.of_list kids2) (nexp_frees nexp2)) in if not (kids2 = []) then typ_error l ("Universally quantified constraint generated: " ^ Util.string_of_list ", " string_of_kid kids2) else (); let env = Env.add_constraint (nc_eq nexp1 nexp2) env in @@ -1711,13 +1722,13 @@ let rec subtyp l env typ1 typ2 = else typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env)) | _, _ -> match destruct_exist' typ1, destruct_exist (canonicalize env typ2) with - | Some (kids, nc, typ1), _ -> - let env = add_existential l kids nc env in subtyp l env typ1 typ2 - | None, Some (kids, nc, typ2) -> + | Some (kopts, nc, typ1), _ -> + let env = add_existential l kopts nc env in subtyp l env typ1 typ2 + | None, Some (kopts, nc, typ2) -> typ_debug (lazy "Subtype check with unification"); let typ1 = canonicalize env typ1 in - let env = add_typ_vars l kids env in - let kids' = KidSet.elements (KidSet.diff (KidSet.of_list kids) (typ_frees typ2)) in + let env = add_typ_vars l kopts env in + let kids' = KidSet.elements (KidSet.diff (KidSet.of_list (List.map kopt_kid kopts)) (typ_frees typ2)) in if not (kids' = []) then typ_error l "Universally quantified constraint generated" else (); let unifiers = try unify l env (KidSet.diff (tyvars_of_typ typ2) (tyvars_of_typ typ1)) typ2 typ1 with @@ -2760,7 +2771,7 @@ and bind_typ_pat env (TP_aux (typ_pat_aux, l) as typ_pat) (Typ_aux (typ_aux, _) begin match typ_nexps typ with | [nexp] -> - Env.add_constraint (nc_eq (nvar kid) nexp) (Env.add_typ_var l kid K_int env) + Env.add_constraint (nc_eq (nvar kid) nexp) (Env.add_typ_var l (mk_kopt K_int kid) env) | [] -> typ_error l ("No numeric expressions in " ^ string_of_typ typ ^ " to bind " ^ string_of_kid kid ^ " to") | nexps -> @@ -2773,7 +2784,7 @@ and bind_typ_pat_arg env (TP_aux (typ_pat_aux, l) as typ_pat) (A_aux (typ_arg_au match typ_pat_aux, typ_arg_aux with | TP_wild, _ -> env | TP_var kid, A_nexp nexp -> - Env.add_constraint (nc_eq (nvar kid) nexp) (Env.add_typ_var l kid K_int env) + Env.add_constraint (nc_eq (nvar kid) nexp) (Env.add_typ_var l (mk_kopt K_int kid) env) | _, A_typ typ -> bind_typ_pat env typ_pat typ | _, A_order _ -> typ_error l "Cannot bind type pattern against order" | _, _ -> typ_error l ("Couldn't bind type argument " ^ string_of_typ_arg typ_arg ^ " with " ^ string_of_typ_pat typ_pat) @@ -3117,7 +3128,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = match destruct_numeric (typ_of inferred_f), destruct_numeric (typ_of inferred_t) with | Some (kids1, nc1, nexp1), Some (kids2, nc2, nexp2) -> let loop_kid = mk_kid ("loop_" ^ string_of_id v) in - let env = List.fold_left (fun env kid -> Env.add_typ_var l kid K_int env) env (loop_kid :: kids1 @ kids2) in + let env = List.fold_left (fun env kid -> Env.add_typ_var l (mk_kopt K_int kid) env) env (loop_kid :: kids1 @ kids2) in let env = Env.add_constraint (nc_and nc1 nc2) env in let env = Env.add_constraint (nc_and (nc_lteq nexp1 (nvar loop_kid)) (nc_lteq (nvar loop_kid) nexp2)) env in let loop_vtyp = atom_typ (nvar loop_kid) in @@ -3311,17 +3322,17 @@ and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ = typ_raise l (Err_unresolved_quants (f, !quants, Env.get_locals env, Env.get_constraints env)) else (); - let ty_vars = List.map fst (KBindings.bindings (Env.get_typ_vars env)) in - let existentials = List.filter (fun kid -> not (KBindings.mem kid universals)) ty_vars in + let ty_vars = KBindings.bindings (Env.get_typ_vars env) |> List.map (fun (v, k) -> mk_kopt k v) in + let existentials = List.filter (fun kopt -> not (KBindings.mem (kopt_kid kopt) universals)) ty_vars in let num_new_ncs = List.length (Env.get_constraints env) - List.length universal_constraints in let ex_constraints = take num_new_ncs (Env.get_constraints env) in - typ_debug (lazy ("Existentials: " ^ string_of_list ", " string_of_kid existentials)); + typ_debug (lazy ("Existentials: " ^ string_of_list ", " string_of_kinded_id existentials)); typ_debug (lazy ("Existential constraints: " ^ string_of_list ", " string_of_n_constraint ex_constraints)); let universals = KBindings.bindings universals |> List.map fst |> KidSet.of_list in let typ_ret = - if KidSet.is_empty (KidSet.of_list existentials) || KidSet.is_empty (KidSet.diff (typ_frees !typ_ret) universals) + if KidSet.is_empty (KidSet.of_list (List.map kopt_kid existentials)) || KidSet.is_empty (KidSet.diff (typ_frees !typ_ret) universals) then !typ_ret else mk_typ (Typ_exist (existentials, List.fold_left nc_and nc_true ex_constraints, !typ_ret)) in diff --git a/src/type_check.mli b/src/type_check.mli index 47b9d172..0a0e18f7 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -134,7 +134,7 @@ module Env : sig val get_typ_var_locs : t -> Ast.l KBindings.t - val add_typ_var : Ast.l -> kid -> kind_aux -> t -> t + val add_typ_var : Ast.l -> kinded_id -> t -> t val is_record : id -> t -> bool @@ -208,9 +208,12 @@ end an environment *) val add_typquant : Ast.l -> typquant -> Env.t -> Env.t -val destruct_exist : Env.t -> typ -> (kid list * n_constraint * typ) option +(** Safely destructure an existential type. Returns None if the type + is not existential. This function will pick a fresh name for the + existential to ensure that no name-clashes occur. *) +val destruct_exist : typ -> (kinded_id list * n_constraint * typ) option -val add_existential : Ast.l -> kid list -> n_constraint -> Env.t -> Env.t +val add_existential : Ast.l -> kinded_id list -> n_constraint -> Env.t -> Env.t (** When the typechecker creates new type variables it gives them fresh names of the form 'fvXXX#name, where XXX is a number (not @@ -349,11 +352,6 @@ val expected_typ_of : Ast.l * tannot -> typ option val destruct_atom_nexp : Env.t -> typ -> nexp option -(** Safely destructure an existential type. Returns None if the type - is not existential. This function will pick a fresh name for the - existential to ensure that no name-clashes occur. *) -val destruct_exist : typ -> (kid list * n_constraint * typ) option - val destruct_range : Env.t -> typ -> (kid list * n_constraint * nexp * nexp) option val destruct_numeric : typ -> (kid list * n_constraint * nexp) option |
