diff options
| author | Alasdair Armstrong | 2018-12-10 20:39:16 +0000 |
|---|---|---|
| committer | Alasdair Armstrong | 2018-12-10 20:45:05 +0000 |
| commit | 5bc5f5dee8921f8d24260dae54177e00c291fcb1 (patch) | |
| tree | 89bbd7a947e8063bdbaac4abf364f6cccd2c3fdf | |
| parent | d8f0854ca9d80d3af8d6a4aaec778643eda9421c (diff) | |
Various changes:
* Improve type inference for numeric if statements (if_infer test)
* Correctly handle constraints for existentially quantified constructors (constraint_ctor test)
* Canonicalise all numeric types in function arguments, which
triggers some weird edge cases between parametric polymorphism and
subtyping of numeric arguments
* Because of this eq_int, eq_range, and eq_atom etc become identical
* Avoid duplicating destruct_exist in Env
* Handle some odd subtyping cases better
33 files changed, 476 insertions, 242 deletions
diff --git a/aarch64/prelude.sail b/aarch64/prelude.sail index 8cd18fac..505ca7b6 100755 --- a/aarch64/prelude.sail +++ b/aarch64/prelude.sail @@ -143,12 +143,12 @@ val UInt = { interpreter: "uint", c: "sail_unsigned", coq: "uint" -} : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1) +} : forall 'n. bits('n) -> {'m, 0 <= 'm <= 2 ^ 'n - 1. int('m)} val SInt = { c: "sail_signed", _: "sint" -} : forall 'n. bits('n) -> range(- (2 ^ ('n - 1)), 2 ^ ('n - 1) - 1) +} : forall 'n. bits('n) -> {'m, (- (2 ^ ('n - 1))) <= 'm <= 2 ^ ('n - 1) - 1. int('m)} val hex_slice = "hex_slice" : forall 'n 'm. (string, atom('n), atom('m)) -> bits('n - 'm) effect {escape} diff --git a/lib/flow.sail b/lib/flow.sail index cdc6b2fd..b9653828 100644 --- a/lib/flow.sail +++ b/lib/flow.sail @@ -20,34 +20,9 @@ val not_bool = {coq: "negb", _: "not"} : bool -> bool or_bool that are not shown here. */ val and_bool = {coq: "andb", _: "and_bool"} : (bool, bool) -> bool val or_bool = {coq: "orb", _: "or_bool"} : (bool, bool) -> bool - -val eq_atom = {ocaml: "eq_int", lem: "eq", c: "eq_int", coq: "Z.eqb"} : forall 'n 'm. (atom('n), atom('m)) -> bool - -val neq_atom = {lem: "neq", coq: "neq_atom"} : forall 'n 'm. (atom('n), atom('m)) -> bool - -function neq_atom (x, y) = not_bool(eq_atom(x, y)) - -val lteq_atom = {coq: "Z.leb", _: "lteq"} : forall 'n 'm. (atom('n), atom('m)) -> bool -val gteq_atom = {coq: "Z.geb", _: "gteq"} : forall 'n 'm. (atom('n), atom('m)) -> bool -val lt_atom = {coq: "Z.ltb", _: "lt"} : forall 'n 'm. (atom('n), atom('m)) -> bool -val gt_atom = {coq: "Z.gtb", _: "gt"} : forall 'n 'm. (atom('n), atom('m)) -> bool - -val lt_range_atom = {coq: "ltb_range_l", _: "lt"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool -val lteq_range_atom = {coq: "leb_range_l", _: "lteq"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool -val gt_range_atom = {coq: "gtb_range_l", _: "gt"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool -val gteq_range_atom = {coq: "geb_range_l", _: "gteq"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool -val lt_atom_range = {coq: "ltb_range_r", _: "lt"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool -val lteq_atom_range = {coq: "leb_range_r", _: "lteq"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool -val gt_atom_range = {coq: "gtb_range_r", _: "gt"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool -val gteq_atom_range = {coq: "geb_range_r", _: "gteq"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool - -val eq_range = {ocaml: "eq_int", lem: "eq", c: "eq_int", coq: "eq_range"} : forall 'n 'm 'o 'p. (range('n, 'm), range('o, 'p)) -> bool val eq_int = {ocaml: "eq_int", lem: "eq", c: "eq_int", coq: "Z.eqb"} : (int, int) -> bool val eq_bool = {ocaml: "eq_bool", lem: "eq", c: "eq_bool", coq: "Bool.eqb"} : (bool, bool) -> bool -val neq_range = {lem: "neq"} : forall 'n 'm 'o 'p. (range('n, 'm), range('o, 'p)) -> bool -function neq_range (x, y) = not_bool(eq_range(x, y)) - val neq_int = {lem: "neq"} : (int, int) -> bool function neq_int (x, y) = not_bool(eq_int(x, y)) @@ -59,15 +34,15 @@ val gteq_int = {coq: "Z.geb", _:"gteq"} : (int, int) -> bool val lt_int = {coq: "Z.ltb", _:"lt"} : (int, int) -> bool val gt_int = {coq: "Z.gtb", _:"gt"} : (int, int) -> bool -overload operator == = {eq_atom, eq_range, eq_int, eq_bit, eq_bool, eq_unit} -overload operator != = {neq_atom, neq_range, neq_int, neq_bool} +overload operator == = {eq_int, eq_bit, eq_bool, eq_unit} +overload operator != = {neq_int, neq_bool} overload operator | = {or_bool} overload operator & = {and_bool} -overload operator <= = {lteq_atom, lteq_range_atom, lteq_atom_range, lteq_int} -overload operator < = {lt_atom, lt_range_atom, lt_atom_range, lt_int} -overload operator >= = {gteq_atom, gteq_range_atom, gteq_atom_range, gteq_int} -overload operator > = {gt_atom, gt_range_atom, gt_atom_range, gt_int} +overload operator <= = {lteq_int} +overload operator < = {lt_int} +overload operator >= = {gteq_int} +overload operator > = {gt_int} $ifdef TEST diff --git a/src/ast_util.ml b/src/ast_util.ml index f6b8317d..46afe599 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -368,12 +368,17 @@ let nc_lteq n1 n2 = NC_aux (NC_bounded_le (n1, n2), Parse_ast.Unknown) let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown) let nc_lt n1 n2 = nc_lteq (nsum n1 (nint 1)) n2 let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nint 1)) -let nc_and nc1 nc2 = mk_nc (NC_and (nc1, nc2)) let nc_or nc1 nc2 = mk_nc (NC_or (nc1, nc2)) let nc_var kid = mk_nc (NC_var kid) let nc_true = mk_nc NC_true let nc_false = mk_nc NC_false +let nc_and nc1 nc2 = + match nc1, nc2 with + | _, NC_aux (NC_true, _) -> nc1 + | NC_aux (NC_true, _), _ -> nc2 + | _, _ -> mk_nc (NC_and (nc1, nc2)) + let arg_nexp ?loc:(l=Parse_ast.Unknown) n = A_aux (A_nexp n, l) let arg_order ?loc:(l=Parse_ast.Unknown) ord = A_aux (A_order ord, l) let arg_typ ?loc:(l=Parse_ast.Unknown) typ = A_aux (A_typ typ, l) @@ -685,7 +690,7 @@ and string_of_typ_arg_aux = function | A_order o -> string_of_order o | A_bool nc -> string_of_n_constraint nc and string_of_n_constraint = function - | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2 + | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " == " ^ string_of_nexp n2 | NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2 | NC_aux (NC_bounded_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2 | NC_aux (NC_bounded_le (n1, n2), _) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2 diff --git a/src/c_backend.ml b/src/c_backend.ml index 535a0b67..95ab51df 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -194,7 +194,7 @@ let rec ctyp_of_typ ctx typ = ensure that we don't cause any type variable clashes in local_env, and that we can optimize the existential based upon it's constraints. *) - begin match destruct_exist ctx.local_env typ with + begin match destruct_exist (Env.expand_synonyms ctx.local_env typ) with | Some (kids, nc, typ) -> let env = add_existential l kids nc ctx.local_env in ctyp_of_typ { ctx with local_env = env } typ diff --git a/src/initial_check.ml b/src/initial_check.ml index 0f1af63d..44f36892 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -777,6 +777,11 @@ let typschm_of_string str = let typschm, _ = to_ast_typschm initial_ctx typschm in typschm +let typ_of_string str = + let typ = Parser.typ_eof Lexer.token (Lexing.from_string str) in + let typ = to_ast_typ initial_ctx typ in + typ + let extern_of_string id str = mk_val_spec (VS_val_spec (typschm_of_string str, id, (fun _ -> Some (string_of_id id)), false)) let val_spec_of_string id str = mk_val_spec (VS_val_spec (typschm_of_string str, id, (fun _ -> None), false)) diff --git a/src/initial_check.mli b/src/initial_check.mli index 32def316..25187e4c 100644 --- a/src/initial_check.mli +++ b/src/initial_check.mli @@ -91,3 +91,4 @@ val extern_of_string : id -> string -> unit def val val_spec_of_string : id -> string -> unit def val exp_of_string : string -> unit exp +val typ_of_string : string -> typ diff --git a/src/isail.ml b/src/isail.ml index 195e5940..18c59e0b 100644 --- a/src/isail.ml +++ b/src/isail.ml @@ -270,6 +270,9 @@ let handle_input' input = let exp = Type_check.infer_exp !interactive_env exp in pretty_sail stdout (doc_typ (Type_check.typ_of exp)); print_newline () + | ":canon" -> + let typ = Initial_check.typ_of_string arg in + print_endline (string_of_typ (Type_check.canonicalize !interactive_env typ)) | ":v" | ":verbose" -> Type_check.opt_tc_debug := (!Type_check.opt_tc_debug + 1) mod 3; print_endline ("Verbosity: " ^ string_of_int !Type_check.opt_tc_debug) diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 74ef8376..113db3a2 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -522,7 +522,7 @@ let refine_constructor refinements l env id args = (* A constructor should always have a single argument. *) | Typ_aux (Typ_fn ([constr_ty],_,_),_) -> begin let arg_ty = typ_of_args args in - match Type_check.destruct_exist env constr_ty with + match Type_check.destruct_exist (Type_check.Env.expand_synonyms env constr_ty) with | None -> None | Some (kids,nc,constr_ty) -> let bindings = Type_check.unify l env (tyvars_of_typ constr_ty) constr_ty arg_ty in @@ -728,7 +728,7 @@ let fabricate_nexp l tannot = match destruct_tannot tannot with | None -> nint 32 | Some (env,typ,_) -> - match Type_check.destruct_exist env typ with + match Type_check.destruct_exist (Type_check.Env.expand_synonyms env typ) with | None -> nint 32 | Some (kids,nc,typ') -> fabricate_nexp_exist env l typ kids nc typ' @@ -745,7 +745,7 @@ let atom_typ_kid kid = function let reduce_cast typ exp l annot = let env = env_of_annot (l,annot) in let typ' = Env.base_typ_of env typ in - match exp, destruct_exist env typ' with + match exp, destruct_exist (Env.expand_synonyms env typ') with | E_aux (E_lit (L_aux (L_num n,_)),_), Some ([kid],nc,typ'') when atom_typ_kid kid typ'' -> let nc_env = Env.add_typ_var l kid K_int env in let nc_env = Env.add_constraint (nc_eq (nvar kid) (nconstant n)) nc_env in @@ -3182,7 +3182,7 @@ let rec analyse_exp fn_id env assigns (E_aux (e,(l,annot)) as exp) = | Some (tenv,typ,_) -> let typ = Env.base_typ_of tenv typ in let env, tenv, typ = - match destruct_exist tenv typ with + match destruct_exist (Env.expand_synonyms tenv typ) with | None -> env, tenv, typ | Some (kids, nc, typ) -> { env with kid_deps = diff --git a/src/parser.mly b/src/parser.mly index fa36591c..83e6936d 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -212,9 +212,11 @@ let rec desugar_rchain chain s e = %start file %start typschm_eof +%start typ_eof %start exp_eof %start def_eof %type <Parse_ast.typschm> typschm_eof +%type <Parse_ast.atyp> typ_eof %type <Parse_ast.exp> exp_eof %type <Parse_ast.def> def_eof %type <Parse_ast.defs> file @@ -349,6 +351,10 @@ tyarg: | Lparen typ_list Rparen { [], $2 } +typ_eof: + | typ Eof + { $1 } + typ: | typ0 { $1 } diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index 025156cc..f00a93b7 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -667,7 +667,7 @@ let is_ctor env id = match Env.lookup_id id env with let is_auto_decomposed_exist env typ = let typ = expand_range_type typ in - match destruct_exist env typ with + match destruct_exist (Env.expand_synonyms env typ) with | Some (_, _, typ') -> Some typ' | _ -> None @@ -905,7 +905,7 @@ let doc_exp, doc_let = debug ctxt (lazy (" at type " ^ string_of_typ typ)) in let typ = expand_range_type typ in - match destruct_exist env typ with + match destruct_exist typ with | None -> epp | Some _ -> let epp = string "build_ex" ^/^ epp in @@ -921,12 +921,12 @@ let doc_exp, doc_let = | _ -> let typ' = expand_range_type (Env.expand_synonyms (env_of exp) typ) in let build_ex, out_typ = - match destruct_exist env typ' with + match destruct_exist typ' with | Some (_,_,t) -> true, t | None -> false, typ' in let in_typ = expand_range_type (Env.expand_synonyms (env_of exp) (typ_of exp)) in - let in_typ = match destruct_exist env in_typ with Some (_,_,t) -> t | None -> in_typ in + let in_typ = match destruct_exist in_typ with Some (_,_,t) -> t | None -> in_typ in let autocast = (* Avoid using helper functions which simplify the nexps *) is_bitvector_typ in_typ && is_bitvector_typ out_typ && @@ -1528,7 +1528,7 @@ let doc_exp, doc_let = | P_aux (P_var (P_aux (P_typ (typ, P_aux (P_id id,_)),_),_),_) when not (is_enum (env_of e1) id) -> let full_typ = (expand_range_type typ) in - let binder = match destruct_exist (env_of e1) full_typ with + let binder = match destruct_exist (Env.expand_synonyms (env_of e1) full_typ) with | Some _ -> squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt typ]) | _ -> @@ -1975,7 +1975,7 @@ let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) = | _ -> failwith ("Function " ^ string_of_id id ^ " does not have function type") in let build_ex, ret_typ = replace_atom_return_type ret_typ in - let build_ex = match destruct_exist env (expand_range_type ret_typ) with + let build_ex = match destruct_exist (Env.expand_synonyms env (expand_range_type ret_typ)) with | Some _ -> true | _ -> build_ex in @@ -2035,7 +2035,7 @@ let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) = | P_typ (_,P_aux (P_id id,_)) when not (is_enum env id) -> begin let full_typ = (expand_range_type exp_typ) in - match destruct_exist env full_typ with + match destruct_exist (Env.expand_synonyms env full_typ) with | Some ([kid], NC_aux (NC_true,_), Typ_aux (Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp (Nexp_aux (Nexp_var kid',_)),_)]),_)) @@ -2255,7 +2255,7 @@ let doc_val pat exp = | None -> typpp, exp | Some typ -> let typ = expand_range_type (Env.expand_synonyms env typ) in - match destruct_exist env typ with + match destruct_exist typ with | None -> typpp, exp | Some _ -> empty, match exp with diff --git a/src/rewrites.ml b/src/rewrites.ml index d5601d08..d8f1af75 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -3729,7 +3729,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = in let exp4 = rewrite_var_updates (add_vars overwrite exp4 vars) in let ord_exp, kids, constr, lower, upper, lower_exp, upper_exp = - match destruct_numeric env (typ_of exp1), destruct_numeric env (typ_of exp2) with + match destruct_numeric (Env.expand_synonyms env (typ_of exp1)), destruct_numeric (Env.expand_synonyms env (typ_of exp2)) with | None, _ | _, None -> raise (Reporting.err_unreachable el __POS__ "Could not determine loop bounds") | Some (kids1, constr1, n1), Some (kids2, constr2, n2) -> diff --git a/src/type_check.ml b/src/type_check.ml index 42616361..459fe8d7 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -215,6 +215,52 @@ and strip_kinded_id_aux = function and strip_kind = function | K_aux (k_aux, _) -> K_aux (k_aux, Parse_ast.Unknown) +let ex_counter = ref 0 + +let fresh_existential ?name:(n="") () = + let fresh = Kid_aux (Var ("'ex" ^ string_of_int !ex_counter ^ "#" ^ n), Parse_ast.Unknown) in + incr ex_counter; fresh + +let destruct_exist' typ = + match typ with + | Typ_aux (Typ_exist (kids, nc, typ), _) -> + let fresh_kids = List.map (fun kid -> (kid, fresh_existential ~name:(string_of_id (id_of_kid kid)) ())) kids in + let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_nexp (nvar fresh)) nc) nc fresh_kids in + let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_nexp (nvar fresh)) typ) typ fresh_kids in + Some (List.map snd fresh_kids, nc, typ) + | _ -> None + +(** Destructure and canonicalise a numeric type into a list of type + variables, a constraint on those type variables, and an + N-expression that represents that numeric type in the + environment. For example: + - {'n, 'n <= 10. atom('n)} => ['n], 'n <= 10, 'n + - int => ['n], true, 'n (where x is fresh) + - atom('n) => [], true, 'n +**) +let destruct_numeric typ = + match destruct_exist' typ, typ with + | Some (kids, nc, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _)), _ when string_of_id id = "atom" -> + Some (kids, nc, nexp) + | None, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _) when string_of_id id = "atom" -> + Some ([], nc_true, nexp) + | None, Typ_aux (Typ_app (id, [A_aux (A_nexp lo, _); A_aux (A_nexp hi, _)]), _) when string_of_id id = "range" -> + let kid = fresh_existential () in + Some ([kid], nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi), nvar kid) + | None, Typ_aux (Typ_id id, _) when string_of_id id = "nat" -> + let kid = fresh_existential () in + Some ([kid], nc_lteq (nint 0) (nvar kid), nvar kid) + | None, Typ_aux (Typ_id id, _) when string_of_id id = "int" -> + let kid = fresh_existential () in + Some ([kid], nc_true, nvar kid) + | _, _ -> None + +let destruct_exist typ = + match destruct_numeric typ with + | Some (kids, nc, nexp) -> Some (kids, nc, atom_typ nexp) + | None -> destruct_exist' typ + + let adding = Util.("Adding " |> darkgray |> clear) (**************************************************************************) @@ -244,6 +290,7 @@ module Env : sig val get_variant : id -> t -> typquant * type_union list val add_mapping : id -> typquant * typ * typ -> t -> t val add_union_id : id -> typquant * typ -> t -> t + val get_union_id : id -> t -> typquant * typ val add_flow : id -> (typ -> typ) -> t -> t val get_flow : id -> t -> typ -> typ val remove_flow : id -> t -> t @@ -286,11 +333,7 @@ module Env : sig val fresh_kid : ?kid:kid -> t -> kid val expand_synonyms : t -> typ -> typ val expand_constraint_synonyms : t -> n_constraint -> n_constraint - val canonicalize : t -> typ -> typ val base_typ_of : t -> typ -> typ - val add_smt_op : id -> string -> t -> t - val get_smt_op : id -> t -> string - val have_smt_op : id -> t -> bool val allow_unknowns : t -> bool val set_allow_unknowns : bool -> t -> t @@ -332,7 +375,6 @@ end = struct records : (typquant * (typ * id) list) Bindings.t; accessors : (typquant * typ) Bindings.t; externs : (string -> string option) Bindings.t; - smt_ops : string Bindings.t; casts : id list; allow_casts : bool; allow_bindings : bool; @@ -361,7 +403,6 @@ end = struct records = Bindings.empty; accessors = Bindings.empty; externs = Bindings.empty; - smt_ops = Bindings.empty; casts = []; allow_bindings = true; allow_casts = true; @@ -434,21 +475,6 @@ end = struct let existing = try Bindings.find id env.overloads with Not_found -> [] in { env with overloads = Bindings.add id (existing @ ids) env.overloads } - let add_smt_op id str env = - typ_print (lazy (adding ^ "smt binding " ^ string_of_id id ^ " to " ^ str)); - { env with smt_ops = Bindings.add id str env.smt_ops } - - let get_smt_op (Id_aux (_, l) as id) env = - let rec first_smt_op = function - | id :: ids -> (try Bindings.find id env.smt_ops with Not_found -> first_smt_op ids) - | [] -> typ_error l ("No SMT op for " ^ string_of_id id) - in - try Bindings.find id env.smt_ops with - | Not_found -> first_smt_op (get_overloads id env) - - let have_smt_op id env = - try ignore(get_smt_op id env); true with Type_error _ -> false - let rec infer_kind env id = if Bindings.mem id builtin_typs then Bindings.find id builtin_typs @@ -566,53 +592,6 @@ end = struct | A_order _ | A_typ _ | A_bool _ -> arg | A_nexp n -> A_aux (A_nexp (f n), l) - let canonical env typ = - let typ = expand_synonyms env typ in - let counter = ref 0 in - let complex_nexps = ref KBindings.empty in - let simplify_nexp (Nexp_aux (nexp_aux, l) as nexp) = - match nexp_aux with - | Nexp_constant _ -> nexp (* Check this ? *) - | _ -> - let kid = Kid_aux (Var ("'c#" ^ string_of_int !counter), l) in - complex_nexps := KBindings.add kid nexp !complex_nexps; - incr counter; - Nexp_aux (Nexp_var kid, l) - in - let typ = map_nexps (fun nexp -> simplify_nexp (nexp_simp nexp)) typ in - let existentials = KBindings.bindings !complex_nexps |> List.map fst in - let constrs = List.fold_left (fun ncs (kid, nexp) -> nc_eq (nvar kid) nexp :: ncs) [] (KBindings.bindings !complex_nexps) in - existentials, constrs, typ - - let is_canonical env typ = - let typ = expand_synonyms env typ in - let counter = ref 0 in - let simplify_nexp (Nexp_aux (nexp_aux, l) as nexp) = - match nexp_aux with - | Nexp_constant _ -> nexp - | _ -> (incr counter; nexp) - in - let typ = map_nexps simplify_nexp typ in - not (!counter > 0) - - let rec canonicalize env typ = - match typ with - | Typ_aux (Typ_fn (arg_typs, ret_typ, effects), l) when List.for_all (is_canonical env) arg_typs -> - Typ_aux (Typ_fn (arg_typs, canonicalize env ret_typ, effects), l) - | Typ_aux (Typ_fn _, l) -> typ_error l ("Function type " ^ string_of_typ typ ^ " is not canonical") - | _ -> - let existentials, constrs, (Typ_aux (typ_aux, l) as typ) = canonical env typ in - if existentials = [] then - typ - else - let typ_aux = match typ_aux with - | Typ_tup _ | Typ_app _ -> Typ_exist (existentials, List.fold_left nc_and (List.hd constrs) (List.tl constrs), typ) - | Typ_exist (kids, nc, typ) -> Typ_exist (kids @ existentials, List.fold_left nc_and nc constrs, typ) - | Typ_fn _ | Typ_bidir _ | Typ_id _ | Typ_var _ -> assert false (* These must be simple *) - | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" - in - Typ_aux (typ_aux, l) - (* Check if a type, order, n-expression or constraint is well-formed. Throws a type error if the type is badly formed. *) let rec wf_typ ?exs:(exs=KidSet.empty) env typ = @@ -667,7 +646,6 @@ end = struct end | Nexp_constant _ -> () | Nexp_app (id, nexps) -> - let _ = get_smt_op id env in List.iter (fun n -> wf_nexp ~exs:exs env n) nexps | Nexp_times (nexp1, nexp2) -> wf_nexp ~exs:exs env nexp1; wf_nexp ~exs:exs env nexp2 | Nexp_sum (nexp1, nexp2) -> wf_nexp ~exs:exs env nexp1; wf_nexp ~exs:exs env nexp2 @@ -746,20 +724,6 @@ end = struct let ex_counter = ref 0 - (* TODO: Currently this is duplicated with destruct_exist outside of Env and deals with val spec arguments only. *) - let fresh_existential ?name:(n="") () = - let fresh = Kid_aux (Var ("'all" ^ string_of_int !ex_counter ^ "#" ^ n), Parse_ast.Unknown) in - incr ex_counter; fresh - - let destruct_exist env typ = - match expand_synonyms env typ with - | Typ_aux (Typ_exist (kids, nc, typ), _) -> - let fresh_kids = List.map (fun kid -> (kid, fresh_existential ~name:(string_of_id (id_of_kid kid)) ())) kids in - let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_nexp (nvar fresh)) nc) nc fresh_kids in - let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_nexp (nvar fresh)) typ) typ fresh_kids in - Some (List.map snd fresh_kids, nc, typ) - | _ -> None - let rec update_val_spec id (typq, typ) env = begin match expand_synonyms env typ with | Typ_aux (Typ_fn (arg_typs, ret_typ, effect), l) -> @@ -769,7 +733,7 @@ end = struct forall 'n, 'n >= 2. (int('n), foo) -> bar this enforces the invariant that all things on the left of functions are 'base types' (i.e. without existentials) *) - let base_args = List.map (destruct_exist env) arg_typs in + let base_args = List.map (fun typ -> destruct_exist (expand_synonyms env typ)) arg_typs in let existential_arg typq = function | None -> typq | Some (exs, nc, _) -> @@ -959,10 +923,15 @@ end = struct | None -> typ_error (id_loc id) ("union " ^ string_of_id id ^ " not found") let add_union_id id bind env = - begin - typ_print (lazy (adding ^ "union identifier " ^ string_of_id id ^ " : " ^ string_of_bind bind)); - { env with union_ids = Bindings.add id bind env.union_ids } - end + typ_print (lazy (adding ^ "union identifier " ^ string_of_id id ^ " : " ^ string_of_bind bind)); + { env with union_ids = Bindings.add id bind env.union_ids } + + let get_union_id id env = + try + let bind = Bindings.find id env.union_ids in + List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) + with + | Not_found -> typ_error (id_loc id) ("No union constructor found for " ^ string_of_id id) let get_flow id env = try Bindings.find id env.flow with @@ -1156,21 +1125,6 @@ let default_order_error_string = let dvector_typ env n typ = vector_typ n (Env.get_default_order env) typ -let ex_counter = ref 0 - -let fresh_existential ?name:(n="") () = - let fresh = Kid_aux (Var ("'ex" ^ string_of_int !ex_counter ^ "#" ^ n), Parse_ast.Unknown) in - incr ex_counter; fresh - -let destruct_exist env typ = - match Env.expand_synonyms env typ with - | Typ_aux (Typ_exist (kids, nc, typ), _) -> - let fresh_kids = List.map (fun kid -> (kid, fresh_existential ~name:(string_of_id (id_of_kid kid)) ())) kids in - let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_nexp (nvar fresh)) nc) nc fresh_kids in - let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_nexp (nvar fresh)) typ) typ fresh_kids in - Some (List.map snd fresh_kids, nc, typ) - | _ -> None - let add_existential l kids nc env = let env = List.fold_left (fun env kid -> Env.add_typ_var l kid K_int env) env kids in Env.add_constraint nc env @@ -1185,34 +1139,8 @@ let exist_typ constr typ = let fresh_kid = fresh_existential () in mk_typ (Typ_exist ([fresh_kid], constr fresh_kid, typ fresh_kid)) -(** Destructure and canonicalise a numeric type into a list of type - variables, a constraint on those type variables, and an - N-expression that represents that numeric type in the - environment. For example: - - {'n, 'n <= 10. atom('n)} => ['n], 'n <= 10, 'n - - int => ['n], true, 'n (where x is fresh) - - atom('n) => [], true, 'n -**) -let destruct_numeric env typ = - let typ = Env.expand_synonyms env typ in - match destruct_exist env typ, typ with - | Some (kids, nc, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _)), _ when string_of_id id = "atom" -> - Some (kids, nc, nexp) - | None, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _) when string_of_id id = "atom" -> - Some ([], nc_true, nexp) - | None, Typ_aux (Typ_app (id, [A_aux (A_nexp lo, _); A_aux (A_nexp hi, _)]), _) when string_of_id id = "range" -> - let kid = fresh_existential () in - Some ([kid], nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi), nvar kid) - | None, Typ_aux (Typ_id id, _) when string_of_id id = "nat" -> - let kid = fresh_existential () in - Some ([kid], nc_lteq (nint 0) (nvar kid), nvar kid) - | None, Typ_aux (Typ_id id, _) when string_of_id id = "int" -> - let kid = fresh_existential () in - Some ([kid], nc_true, nvar kid) - | _, _ -> None - let bind_numeric l typ env = - match destruct_numeric env typ with + match destruct_numeric (Env.expand_synonyms env typ) with | Some (kids, nc, nexp) -> nexp, add_existential l kids nc env | None -> typ_error l ("Expected " ^ string_of_typ typ ^ " to be numeric") @@ -1220,15 +1148,13 @@ let bind_numeric l typ env = (** Pull an (potentially)-existentially qualified type into the global typing environment **) let bind_existential l typ env = - match destruct_numeric env typ with - | Some (kids, nc, nexp) -> atom_typ nexp, add_existential l kids nc env - | None -> match destruct_exist env typ with - | Some (kids, nc, typ) -> typ, add_existential l kids nc env - | None -> typ, env + match destruct_exist (Env.expand_synonyms env typ) with + | Some (kids, nc, typ) -> typ, add_existential l kids nc env + | None -> typ, env let destruct_range env typ = let kids, constr, (Typ_aux (typ_aux, _)) = - Util.option_default ([], nc_true, typ) (destruct_exist env typ) + Util.option_default ([], nc_true, typ) (destruct_exist (Env.expand_synonyms env typ)) in match typ_aux with | Typ_app (f, [A_aux (A_nexp n, _)]) @@ -1492,7 +1418,7 @@ let rec unify_typ l env goals (Typ_aux (aux1, _) as typ1) (Typ_aux (aux2, _) as | Typ_internal_unknown, _ | _, Typ_internal_unknown when Env.allow_unknowns env -> KBindings.empty - + | Typ_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_typ typ2) | Typ_app (range, [A_aux (A_nexp n1, _); A_aux (A_nexp n2, _)]), @@ -1528,7 +1454,8 @@ and unify_order l goals (Ord_aux (aux1, _) as ord1) (Ord_aux (aux2, _) as ord2) | _, _ -> unify_error l ("Cound not unify " ^ string_of_order ord1 ^ " and " ^ string_of_order ord2) and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) = - typ_debug (lazy ("UNIFYING NEXPS " ^ string_of_nexp nexp1 ^ " AND " ^ string_of_nexp nexp2 ^ " FOR GOALS " ^ string_of_list ", " string_of_kid (KidSet.elements goals))); + typ_debug (lazy (Util.("Unify nexp " |> magenta |> clear) ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2 + ^ " goals " ^ string_of_list ", " string_of_kid (KidSet.elements goals))); if KidSet.is_empty (KidSet.inter (nexp_frees nexp1) goals) then begin @@ -1559,19 +1486,17 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au then unify_nexp l env goals n1a (nsum nexp2 n1b) else unify_error l ("Cannot unify minus Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) | Nexp_times (n1a, n1b) -> - (* If we have SMT operations div and mod, then we can use the + (* f we have SMT operations div and mod, then we can use the property that mod(m, C) = 0 && C != 0 --> (C * n = m <--> n = m / C) - to help us unify multiplications. *) - if Env.have_smt_op (mk_id "div") env && Env.have_smt_op (mk_id "mod") env then - let valid n c = prove env (nc_eq (napp (mk_id "mod") [n; c]) (nint 0)) && prove env (nc_neq c (nint 0)) in - if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then - unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b]) - else if KidSet.is_empty (nexp_frees n1a) && valid nexp2 n1a then - unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a]) - else unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) + to help us unify multiplications and divisions. *) + let valid n c = prove env (nc_eq (napp (mk_id "mod") [n; c]) (nint 0)) && prove env (nc_neq c (nint 0)) in + if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then + unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b]) + else if KidSet.is_empty (nexp_frees n1a) && valid nexp2 n1a then + unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a]) else if KidSet.is_empty (nexp_frees n1a) then begin match nexp_aux2 with @@ -1611,7 +1536,7 @@ let subst_unifiers unifiers typ = let subst_unifiers_typ_arg unifiers typ_arg = List.fold_left (fun typ_arg (v, arg) -> typ_arg_subst v arg typ_arg) typ_arg (KBindings.bindings unifiers) - + let instantiate_quant (v, arg) (QI_aux (aux, l) as qi) = match aux with | QI_id kopt when Kid.compare (kopt_kid kopt) v = 0 -> @@ -1716,7 +1641,7 @@ let rec alpha_equivalent env typ1 typ2 = else (typ_debug (lazy "Not alpha-equivalent"); false) let unwrap_exist env typ = - match destruct_exist env typ with + match destruct_exist (Env.expand_synonyms env typ) with | Some (kids, nc, typ) -> (kids, nc, typ) | None -> ([], nc_true, typ) @@ -1725,13 +1650,51 @@ let unifier_constraint env (v, arg) = | A_aux (A_nexp nexp, _) -> Env.add_constraint (nc_eq (nvar v) nexp) env | _ -> env -let rec subtyp l env (Typ_aux (typ_aux1, _) as typ1) (Typ_aux (typ_aux2, _) as typ2) = +let canonicalize env typ = + let typ = Env.expand_synonyms env typ in + let rec canon (Typ_aux (aux, l)) = + match aux with + | Typ_var v -> Typ_aux (Typ_var v, l) + | Typ_internal_unknown -> Typ_aux (Typ_internal_unknown, l) + | Typ_id id when string_of_id id = "int" -> + exist_typ (fun _ -> nc_true) (fun v -> atom_typ (nvar v)) + | Typ_id id -> Typ_aux (Typ_id id, l) + | Typ_app (id, [A_aux (A_nexp lo, _); A_aux (A_nexp hi, _)]) when string_of_id id = "range" -> + exist_typ (fun v -> nc_and (nc_lteq lo (nvar v)) (nc_lteq (nvar v) hi)) (fun v -> atom_typ (nvar v)) + | Typ_app (id, args) -> + Typ_aux (Typ_app (id, List.map canon_arg args), l) + | Typ_tup typs -> + let typs = List.map canon typs in + let fold_exist (kids, nc, typs) typ = + match destruct_exist typ with + | Some (kids', nc', typ') -> (kids @ kids', nc_and nc nc', typs @ [typ']) + | None -> (kids, nc, typs @ [typ]) + in + let kids, nc, typs = List.fold_left fold_exist ([], nc_true, []) typs in + if kids = [] then + Typ_aux (Typ_tup typs, l) + else + Typ_aux (Typ_exist (kids, nc, Typ_aux (Typ_tup typs, l)), l) + | Typ_exist (kids, nc, typ) -> + begin match destruct_exist (canon typ) with + | Some (kids', nc', typ') -> + Typ_aux (Typ_exist (kids @ kids', nc_and nc nc', typ'), l) + | None -> Typ_aux (Typ_exist (kids, nc, typ), l) + end + | Typ_fn _ | Typ_bidir _ -> raise (Reporting.err_unreachable l __POS__ "Function type passed to Type_check.canonicalize") + and canon_arg (A_aux (aux, l)) = + A_aux ((match aux with + | A_typ typ -> A_typ (canon typ) + | arg -> arg), + l) + in + canon typ + +let rec subtyp l env typ1 typ2 = + let (Typ_aux (typ_aux1, _) as typ1) = Env.expand_synonyms env typ1 in + let (Typ_aux (typ_aux2, _) as typ2) = Env.expand_synonyms env typ2 in typ_print (lazy (("Subtype " |> Util.green |> Util.clear) ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2)); - match typ_aux1, typ_aux2 with - | Typ_tup typs1, Typ_tup typs2 when List.length typs1 = List.length typs2 -> - List.iter2 (subtyp l env) typs1 typs2 - | _, _ -> - match destruct_numeric env typ1, destruct_numeric env typ2 with + match destruct_numeric typ1, destruct_numeric typ2 with (* Ensure alpha equivalent types are always subtypes of one another - this ensures that we can always re-check inferred types. *) | _, _ when alpha_equivalent env typ1 typ2 -> () @@ -1743,27 +1706,50 @@ let rec subtyp l env (Typ_aux (typ_aux1, _) as typ1) (Typ_aux (typ_aux2, _) as t let env = add_existential l kids1 nc1 env in let env = add_typ_vars l (KidSet.elements (KidSet.inter (nexp_frees nexp2) (KidSet.of_list kids2))) env in let kids2 = KidSet.elements (KidSet.diff (KidSet.of_list kids2) (nexp_frees nexp2)) in - if not (kids2 = []) then typ_error l "Universally quantified constraint generated" else (); + if not (kids2 = []) then typ_error l ("Universally quantified constraint generated: " ^ Util.string_of_list ", " string_of_kid kids2) else (); let env = Env.add_constraint (nc_eq nexp1 nexp2) env in if prove env nc2 then () else typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env)) | _, _ -> - match destruct_exist env typ1, unwrap_exist env (Env.canonicalize env typ2) with + match destruct_exist' typ1, destruct_exist (canonicalize env typ2) with | Some (kids, nc, typ1), _ -> let env = add_existential l kids nc env in subtyp l env typ1 typ2 - | None, (kids, nc, typ2) -> + | None, Some (kids, nc, typ2) -> typ_debug (lazy "Subtype check with unification"); + let typ1 = canonicalize env typ1 in let env = add_typ_vars l kids env in let kids' = KidSet.elements (KidSet.diff (KidSet.of_list kids) (typ_frees typ2)) in if not (kids' = []) then typ_error l "Universally quantified constraint generated" else (); let unifiers = - try unify l env (tyvars_of_typ typ2) typ2 typ1 with + try unify l env (KidSet.diff (tyvars_of_typ typ2) (tyvars_of_typ typ1)) typ2 typ1 with | Unification_error (_, m) -> typ_error l m in let nc = List.fold_left (fun nc (kid, uvar) -> constraint_subst kid uvar nc) nc (KBindings.bindings unifiers) in let env = List.fold_left unifier_constraint env (KBindings.bindings unifiers) in if prove env nc then () else typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env)) + | None, None -> + match typ_aux1, typ_aux2 with + | Typ_tup typs1, Typ_tup typs2 when List.length typs1 = List.length typs2 -> + List.iter2 (subtyp l env) typs1 typs2 + + | Typ_app (id1, args1), Typ_app (id2, args2) when Id.compare id1 id2 = 0 && List.length args1 = List.length args2 -> + List.iter2 (subtyp_arg l env) args1 args2 + + | Typ_id id1, Typ_id id2 when Id.compare id1 id2 = 0 -> () + | Typ_id id1, Typ_app (id2, []) when Id.compare id1 id2 = 0 -> () + | Typ_app (id1, []), Typ_id id2 when Id.compare id1 id2 = 0 -> () + + | _, _ -> typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env)) + +and subtyp_arg l env (A_aux (aux1, _) as arg1) (A_aux (aux2, _) as arg2) = + typ_print (lazy (("Subtype arg " |> Util.green |> Util.clear) ^ string_of_typ_arg arg1 ^ " and " ^ string_of_typ_arg arg2)); + match aux1, aux2 with + | A_nexp n1, A_nexp n2 when prove env (nc_eq n1 n2) -> () + | A_typ typ1, A_typ typ2 -> subtyp l env typ1 typ2 + | A_order ord1, A_order ord2 when ord_identical ord1 ord2 -> () + | A_bool nc1, A_bool nc2 -> assert false + | _, _ -> typ_error l "Mismatched argument types in subtype check" let typ_equality l env typ1 typ2 = subtyp l env typ1 typ2; subtyp l env typ2 typ1 @@ -1928,6 +1914,38 @@ let expected_typ_of (l, tannot) = match tannot with (* Flow typing *) +type simple_numeric = + | Equal of nexp + | Constraint of (kid -> n_constraint) + | Anything + +let to_simple_numeric l kids nc (Nexp_aux (aux, _) as n) = + match aux, kids with + | Nexp_var v, [v'] when Kid.compare v v' = 0 -> + Constraint (fun subst -> constraint_subst v (arg_nexp (nvar subst)) nc) + | _, [] -> + Equal n + | _ -> + typ_error l "Numeric type is non-simple" + +let union_simple_numeric ex1 ex2 = + match ex1, ex2 with + | Equal nexp1, Equal nexp2 -> + Constraint (fun kid -> nc_or (nc_eq (nvar kid) nexp1) (nc_eq (nvar kid) nexp2)) + + | Equal nexp, Constraint c -> + Constraint (fun kid -> nc_or (nc_eq (nvar kid) nexp) (c kid)) + + | Constraint c, Equal nexp -> + Constraint (fun kid -> nc_or (c kid) (nc_eq (nvar kid) nexp)) + + | _, _ -> Anything + +let typ_of_simple_numeric = function + | Anything -> int_typ + | Equal nexp -> atom_typ nexp + | Constraint c -> exist_typ c (fun kid -> atom_typ (nvar kid)) + let rec big_int_of_nexp (Nexp_aux (nexp, _)) = match nexp with | Nexp_constant c -> Some c | Nexp_times (n1, n2) -> @@ -1977,17 +1995,17 @@ let rec assert_constraint env b (E_aux (exp_aux, _) as exp) = combine_constraint (not b) nc_or (assert_constraint env b x) (assert_constraint env b y) | E_app (op, [x; y]) when string_of_id op = "and_bool" -> combine_constraint b nc_and (assert_constraint env b x) (assert_constraint env b y) - | E_app (op, [x; y]) when string_of_id op = "gteq_atom" -> + | E_app (op, [x; y]) when string_of_id op = "gteq_int" -> option_binop nc_gteq (assert_nexp env x) (assert_nexp env y) - | E_app (op, [x; y]) when string_of_id op = "lteq_atom" -> + | E_app (op, [x; y]) when string_of_id op = "lteq_int" -> option_binop nc_lteq (assert_nexp env x) (assert_nexp env y) - | E_app (op, [x; y]) when string_of_id op = "gt_atom" -> + | E_app (op, [x; y]) when string_of_id op = "gt_int" -> option_binop nc_gt (assert_nexp env x) (assert_nexp env y) - | E_app (op, [x; y]) when string_of_id op = "lt_atom" -> + | E_app (op, [x; y]) when string_of_id op = "lt_int" -> option_binop nc_lt (assert_nexp env x) (assert_nexp env y) - | E_app (op, [x; y]) when string_of_id op = "eq_atom" -> + | E_app (op, [x; y]) when string_of_id op = "eq_int" -> option_binop nc_eq (assert_nexp env x) (assert_nexp env y) - | E_app (op, [x; y]) when string_of_id op = "neq_atom" -> + | E_app (op, [x; y]) when string_of_id op = "neq_int" -> option_binop nc_neq (assert_nexp env x) (assert_nexp env y) | _ -> None @@ -2398,13 +2416,13 @@ and type_coercion env (E_aux (_, (l, _)) as annotated_exp) typ = in begin try - typ_debug (lazy ("PERFORMING TYPE COERCION: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ)); + typ_debug (lazy ("Performing type coercion: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ)); subtyp l env (typ_of annotated_exp) typ; switch_exp_typ annotated_exp with | Type_error (_, trigger) when Env.allow_casts env -> let casts = filter_casts env (typ_of annotated_exp) typ (Env.get_casts env) in try_casts trigger [] casts - | Type_error (l, err) -> typ_error l "Subtype error" + | Type_error (l, err) -> typ_raise l err end (* type_coercion_unify env exp typ attempts to coerce exp to a type @@ -2434,7 +2452,7 @@ and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ = in begin try - typ_debug (lazy "PERFORMING COERCING UNIFICATION"); + typ_debug (lazy ("Coercing unification: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ)); let atyp, env = bind_existential l (typ_of annotated_exp) env in annotated_exp, unify l env goals typ atyp, env with @@ -2548,7 +2566,7 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) end | P_app (f, pats) when Env.is_union_constructor f env -> begin - let (typq, ctor_typ) = Env.get_val_spec f env in + let (typq, ctor_typ) = Env.get_union_id f env in let quants = quant_items typq in let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with | Typ_tup typs -> typs @@ -2563,8 +2581,8 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) let unifiers = unify l env goals ret_typ typ in let arg_typ' = subst_unifiers unifiers arg_typ in let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in - if (match quants' with [] -> false | _ -> true) - then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat) + if not (List.for_all (solve_quant env) quants') then + typ_raise l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env)) else (); let ret_typ' = subst_unifiers unifiers ret_typ in let tpats, env, guards = @@ -2580,7 +2598,7 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) | P_app (f, pats) when Env.is_mapping f env -> begin - let (typq, mapping_typ) = Env.get_val_spec f env in + let (typq, mapping_typ) = Env.get_union_id f env in let quants = quant_items typq in let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with | Typ_tup typs -> typs @@ -3094,7 +3112,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = let inferred_f = irule infer_exp env f in let inferred_t = irule infer_exp env t in let checked_step = crule check_exp env step int_typ in - match destruct_numeric env (typ_of inferred_f), destruct_numeric env (typ_of inferred_t) with + match destruct_numeric (typ_of inferred_f), destruct_numeric (typ_of inferred_t) with | Some (kids1, nc1, nexp1), Some (kids2, nc2, nexp2) -> let loop_kid = mk_kid ("loop_" ^ string_of_id v) in let env = List.fold_left (fun env kid -> Env.add_typ_var l kid K_int env) env (loop_kid :: kids1 @ kids2) in @@ -3110,8 +3128,22 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = | E_if (cond, then_branch, else_branch) -> let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in let then_branch' = irule infer_exp (add_opt_constraint (assert_constraint env true cond') env) then_branch in - let else_branch' = crule check_exp (add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env) else_branch (typ_of then_branch') in - annot_exp (E_if (cond', then_branch', else_branch')) (typ_of then_branch') + (* We don't have generic type union in Sail, but we can union simple numeric types. *) + begin match destruct_numeric (Env.expand_synonyms env (typ_of then_branch')) with + | Some (kids, nc, then_nexp) -> + let then_sn = to_simple_numeric l kids nc then_nexp in + let else_branch' = irule infer_exp (add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env) else_branch in + begin match destruct_numeric (Env.expand_synonyms env (typ_of else_branch')) with + | Some (kids, nc, else_nexp) -> + let else_sn = to_simple_numeric l kids nc else_nexp in + let typ = typ_of_simple_numeric (union_simple_numeric then_sn else_sn) in + annot_exp (E_if (cond', then_branch', else_branch')) typ + | None -> typ_error l ("Could not infer type of " ^ string_of_exp else_branch) + end + | None -> + let else_branch' = crule check_exp (add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env) else_branch (typ_of then_branch') in + annot_exp (E_if (cond', then_branch', else_branch')) (typ_of then_branch') + end | E_vector_access (v, n) -> infer_exp env (E_aux (E_app (mk_id "vector_access", [v; n]), (l, ()))) | E_vector_update (v, n, exp) -> infer_exp env (E_aux (E_app (mk_id "vector_update", [v; n; exp]), (l, ()))) | E_vector_update_subrange (v, n, m, exp) -> infer_exp env (E_aux (E_app (mk_id "vector_update_subrange", [v; n; m; exp]), (l, ()))) @@ -4163,10 +4195,6 @@ let check_val_spec env (VS_aux (vs, (l, _))) = let vs, id, typq, typ, env = match vs with | VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), ts_l) as typschm, id, ext_opt, is_cast) -> typ_print (lazy (Util.("Check val spec " |> cyan |> clear) ^ string_of_id id ^ " : " ^ string_of_typschm typschm)); - let env = match (ext_opt "smt", ext_opt "#") with - | Some op, None -> Env.add_smt_op id op env - | _, _ -> env - in let env = Env.add_extern id ext_opt env in let env = if is_cast then Env.add_cast id env else env in let typq, typ = diff --git a/src/type_check.mli b/src/type_check.mli index 52ade6fa..47b9d172 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -300,6 +300,8 @@ val prove : Env.t -> n_constraint -> bool val solve : Env.t -> nexp -> Big_int.num option +val canonicalize : Env.t -> typ -> typ + val subtype_check : Env.t -> typ -> typ -> bool val bind_pat : Env.t -> unit pat -> typ -> tannot pat * Env.t * unit Ast.exp list @@ -350,11 +352,11 @@ val destruct_atom_nexp : Env.t -> typ -> nexp option (** Safely destructure an existential type. Returns None if the type is not existential. This function will pick a fresh name for the existential to ensure that no name-clashes occur. *) -val destruct_exist : Env.t -> typ -> (kid list * n_constraint * typ) option +val destruct_exist : typ -> (kid list * n_constraint * typ) option val destruct_range : Env.t -> typ -> (kid list * n_constraint * nexp * nexp) option -val destruct_numeric : Env.t -> typ -> (kid list * n_constraint * nexp) option +val destruct_numeric : typ -> (kid list * n_constraint * nexp) option val destruct_vector : Env.t -> typ -> (nexp * order * typ) option diff --git a/test/typecheck/pass/constrained_struct/v1.expect b/test/typecheck/pass/constrained_struct/v1.expect index 5173ef0b..ab25cbc4 100644 --- a/test/typecheck/pass/constrained_struct/v1.expect +++ b/test/typecheck/pass/constrained_struct/v1.expect @@ -2,4 +2,4 @@ Type error at file "constrained_struct/v1.sail", line 10, character 19 to line 1 type MyStruct64 = [41mMyStruct[0m(65) -Could not prove (65 = 32 | 65 = 64) for type constructor MyStruct +Could not prove (65 == 32 | 65 == 64) for type constructor MyStruct diff --git a/test/typecheck/pass/constraint_ctor.sail b/test/typecheck/pass/constraint_ctor.sail new file mode 100644 index 00000000..2b4a5746 --- /dev/null +++ b/test/typecheck/pass/constraint_ctor.sail @@ -0,0 +1,20 @@ +default Order dec + +$include <flow.sail> + +union Foo = { + Foo : {'n, 'n >= 3. int('n)} +} + +function foo(Foo(x as int('x)): Foo) -> unit = { + _prove(constraint('x >= 3)); +} + +union Bar('m), 'm <= 100 = { + Bar : {'n, 'n >= 'm. int('n)} +} + +function bar(Bar(x as int('x)) : Bar(23)) -> unit = { + _prove(constraint('x >= 23)); + () +} diff --git a/test/typecheck/pass/constraint_ctor/v1.expect b/test/typecheck/pass/constraint_ctor/v1.expect new file mode 100644 index 00000000..c3886af8 --- /dev/null +++ b/test/typecheck/pass/constraint_ctor/v1.expect @@ -0,0 +1,5 @@ +Type error at file "constraint_ctor/v1.sail", line 10, character 3 to line 10, character 29 + + [41m_prove(constraint('x >= 4))[0m; + +Cannot prove 'x >= 4 diff --git a/test/typecheck/pass/constraint_ctor/v1.sail b/test/typecheck/pass/constraint_ctor/v1.sail new file mode 100644 index 00000000..20df5480 --- /dev/null +++ b/test/typecheck/pass/constraint_ctor/v1.sail @@ -0,0 +1,20 @@ +default Order dec + +$include <flow.sail> + +union Foo = { + Foo : {'n, 'n >= 3. int('n)} +} + +function foo(Foo(x as int('x)): Foo) -> unit = { + _prove(constraint('x >= 4)); +} + +union Bar('m), 'm <= 100 = { + Bar : {'n, 'n >= 'm. int('n)} +} + +function bar(Bar(x as int('x)) : Bar(23)) -> unit = { + _prove(constraint('x >= 23)); + () +} diff --git a/test/typecheck/pass/constraint_ctor/v2.expect b/test/typecheck/pass/constraint_ctor/v2.expect new file mode 100644 index 00000000..a315b3b7 --- /dev/null +++ b/test/typecheck/pass/constraint_ctor/v2.expect @@ -0,0 +1,5 @@ +Type error at file "constraint_ctor/v2.sail", line 18, character 3 to line 18, character 30 + + [41m_prove(constraint('x >= 24))[0m; + +Cannot prove 'x >= 24 diff --git a/test/typecheck/pass/constraint_ctor/v2.sail b/test/typecheck/pass/constraint_ctor/v2.sail new file mode 100644 index 00000000..76d9793d --- /dev/null +++ b/test/typecheck/pass/constraint_ctor/v2.sail @@ -0,0 +1,20 @@ +default Order dec + +$include <flow.sail> + +union Foo = { + Foo : {'n, 'n >= 3. int('n)} +} + +function foo(Foo(x as int('x)): Foo) -> unit = { + _prove(constraint('x >= 3)); +} + +union Bar('m), 'm <= 100 = { + Bar : {'n, 'n >= 'm. int('n)} +} + +function bar(Bar(x as int('x)) : Bar(23)) -> unit = { + _prove(constraint('x >= 24)); + () +} diff --git a/test/typecheck/pass/constraint_ctor/v3.expect b/test/typecheck/pass/constraint_ctor/v3.expect new file mode 100644 index 00000000..e0edd01a --- /dev/null +++ b/test/typecheck/pass/constraint_ctor/v3.expect @@ -0,0 +1,5 @@ +Type error at file "constraint_ctor/v3.sail", line 18, character 3 to line 18, character 30 + + [41m_prove(constraint('x >= 23))[0m; + +Cannot prove 'x >= 23 diff --git a/test/typecheck/pass/constraint_ctor/v3.sail b/test/typecheck/pass/constraint_ctor/v3.sail new file mode 100644 index 00000000..a8f5bd13 --- /dev/null +++ b/test/typecheck/pass/constraint_ctor/v3.sail @@ -0,0 +1,20 @@ +default Order dec + +$include <flow.sail> + +union Foo = { + Foo : {'n, 'n >= 3. int('n)} +} + +function foo(Foo(x as int('x)): Foo) -> unit = { + _prove(constraint('x >= 3)); +} + +union Bar('m), 'm <= 100 = { + Bar : {'n, 'n >= 'm. int('n)} +} + +function bar(Bar(x as int('x)) : Bar(22)) -> unit = { + _prove(constraint('x >= 23)); + () +} diff --git a/test/typecheck/pass/constraint_ctor/v4.expect b/test/typecheck/pass/constraint_ctor/v4.expect new file mode 100644 index 00000000..06eb9d22 --- /dev/null +++ b/test/typecheck/pass/constraint_ctor/v4.expect @@ -0,0 +1,5 @@ +Type error at file "constraint_ctor/v4.sail", line 17, character 34 to line 17, character 36 + +function bar(Bar(x as int('x)) : [41mBar[0m(23)) -> unit = { + +Could not prove 23 <= 22 for type constructor Bar diff --git a/test/typecheck/pass/constraint_ctor/v4.sail b/test/typecheck/pass/constraint_ctor/v4.sail new file mode 100644 index 00000000..d8dab178 --- /dev/null +++ b/test/typecheck/pass/constraint_ctor/v4.sail @@ -0,0 +1,20 @@ +default Order dec + +$include <flow.sail> + +union Foo = { + Foo : {'n, 'n >= 3. int('n)} +} + +function foo(Foo(x as int('x)): Foo) -> unit = { + _prove(constraint('x >= 3)); +} + +union Bar('m), 'm <= 22 = { + Bar : {'n, 'n >= 'm. int('n)} +} + +function bar(Bar(x as int('x)) : Bar(23)) -> unit = { + _prove(constraint('x >= 23)); + () +} diff --git a/test/typecheck/pass/exist2.sail b/test/typecheck/pass/exist2.sail index 102a1084..e518609d 100644 --- a/test/typecheck/pass/exist2.sail +++ b/test/typecheck/pass/exist2.sail @@ -39,6 +39,6 @@ overload existential = {existential_int, existential_range} let v11 : {'n, 0 == 0. atom('n)} = existential(v10) -let v12 : {'e, 0 <= 'e & 'e <= 3. atom('e)} = existential(2 : range(0, 3)) +let v12 : {'e, 0 <= 'e & 'e <= 3. atom('e)} = 2 let v13 : MyInt = existential(v10) diff --git a/test/typecheck/pass/global_type_var/v1.expect b/test/typecheck/pass/global_type_var/v1.expect index 7e3b517c..e81c467e 100644 --- a/test/typecheck/pass/global_type_var/v1.expect +++ b/test/typecheck/pass/global_type_var/v1.expect @@ -6,15 +6,15 @@ Tried performing type coercion from int(32) to int('size) on 32 Coercion failed because: int(32) is not a subtype of int('size) in context - * 'size = 'ex8# - * ('ex8# = 32 | 'ex8# = 64) - * ('ex7# = 32 | 'ex7# = 64) + * 'size == 'ex14# + * ('ex14# == 32 | 'ex14# == 64) + * ('ex13# == 32 | 'ex13# == 64) where - * 'ex7# bound at file "global_type_var/v1.sail", line 5, character 5 to line 5, character 32 + * 'ex13# bound at file "global_type_var/v1.sail", line 5, character 5 to line 5, character 32 let [41m(size as 'size) : {|32, 64|}[0m = 32 - * 'ex8# bound at file "global_type_var/v1.sail", line 5, character 6 to line 5, character 18 + * 'ex14# bound at file "global_type_var/v1.sail", line 5, character 6 to line 5, character 18 let ([41msize as 'size[0m) : {|32, 64|} = 32 diff --git a/test/typecheck/pass/global_type_var/v2.expect b/test/typecheck/pass/global_type_var/v2.expect index dc1281d2..21c4b348 100644 --- a/test/typecheck/pass/global_type_var/v2.expect +++ b/test/typecheck/pass/global_type_var/v2.expect @@ -6,15 +6,15 @@ Tried performing type coercion from int(64) to int('size) on 64 Coercion failed because: int(64) is not a subtype of int('size) in context - * 'size = 'ex8# - * ('ex8# = 32 | 'ex8# = 64) - * ('ex7# = 32 | 'ex7# = 64) + * 'size == 'ex14# + * ('ex14# == 32 | 'ex14# == 64) + * ('ex13# == 32 | 'ex13# == 64) where - * 'ex7# bound at file "global_type_var/v2.sail", line 5, character 5 to line 5, character 32 + * 'ex13# bound at file "global_type_var/v2.sail", line 5, character 5 to line 5, character 32 let [41m(size as 'size) : {|32, 64|}[0m = 32 - * 'ex8# bound at file "global_type_var/v2.sail", line 5, character 6 to line 5, character 18 + * 'ex14# bound at file "global_type_var/v2.sail", line 5, character 6 to line 5, character 18 let ([41msize as 'size[0m) : {|32, 64|} = 32 diff --git a/test/typecheck/pass/if_infer.sail b/test/typecheck/pass/if_infer.sail new file mode 100644 index 00000000..f3fec1c4 --- /dev/null +++ b/test/typecheck/pass/if_infer.sail @@ -0,0 +1,12 @@ +default Order dec + +$include <prelude.sail> + +register R : bool + +val f : unit -> {'n, 1 <= 'n <= 3. int('n)} + +function main((): unit) -> unit = { + let _ = 0b1001[if R then 0 else f()]; + () +} diff --git a/test/typecheck/pass/if_infer/v1.expect b/test/typecheck/pass/if_infer/v1.expect new file mode 100644 index 00000000..06df7dc5 --- /dev/null +++ b/test/typecheck/pass/if_infer/v1.expect @@ -0,0 +1,17 @@ +Type error at file "if_infer/v1.sail", line 10, character 11 to line 10, character 37 + + let _ = [41m0b100[if R then 0 else f()][0m; + +No overloadings for vector_access, tried: + bitvector_access: + Could not resolve quantifiers for bitvector_access (0 <= 'ex41#ex40# & ('ex41#ex40# + 1) <= 3) + + Try adding named type variables for + + + plain_vector_access: + Could not resolve quantifiers for plain_vector_access (0 <= 'ex44#ex43# & ('ex44#ex43# + 1) <= 3) + + Try adding named type variables for + + diff --git a/test/typecheck/pass/if_infer/v1.sail b/test/typecheck/pass/if_infer/v1.sail new file mode 100644 index 00000000..0938aaed --- /dev/null +++ b/test/typecheck/pass/if_infer/v1.sail @@ -0,0 +1,12 @@ +default Order dec + +$include <prelude.sail> + +register R : bool + +val f : unit -> {'n, 1 <= 'n <= 3. int('n)} + +function main((): unit) -> unit = { + let _ = 0b100[if R then 0 else f()]; + () +} diff --git a/test/typecheck/pass/if_infer/v2.expect b/test/typecheck/pass/if_infer/v2.expect new file mode 100644 index 00000000..050e90e4 --- /dev/null +++ b/test/typecheck/pass/if_infer/v2.expect @@ -0,0 +1,17 @@ +Type error at file "if_infer/v2.sail", line 10, character 11 to line 10, character 38 + + let _ = [41m0b1001[if R then 0 else f()][0m; + +No overloadings for vector_access, tried: + bitvector_access: + Could not resolve quantifiers for bitvector_access (0 <= 'ex41#ex40# & ('ex41#ex40# + 1) <= 4) + + Try adding named type variables for + + + plain_vector_access: + Could not resolve quantifiers for plain_vector_access (0 <= 'ex44#ex43# & ('ex44#ex43# + 1) <= 4) + + Try adding named type variables for + + diff --git a/test/typecheck/pass/if_infer/v2.sail b/test/typecheck/pass/if_infer/v2.sail new file mode 100644 index 00000000..a49e1ed7 --- /dev/null +++ b/test/typecheck/pass/if_infer/v2.sail @@ -0,0 +1,12 @@ +default Order dec + +$include <prelude.sail> + +register R : bool + +val f : unit -> {'n, 1 <= 'n <= 4. int('n)} + +function main((): unit) -> unit = { + let _ = 0b1001[if R then 0 else f()]; + () +} diff --git a/test/typecheck/pass/if_infer/v3.expect b/test/typecheck/pass/if_infer/v3.expect new file mode 100644 index 00000000..8b149bc8 --- /dev/null +++ b/test/typecheck/pass/if_infer/v3.expect @@ -0,0 +1,7 @@ +Type error at file "if_infer/v3.sail", line 10, character 11 to line 10, character 38 + + let _ = [41m0b1001[if R then 0 else f()][0m; + +No overloadings for vector_access, tried: + bitvector_access: Numeric type is non-simple + plain_vector_access: Numeric type is non-simple diff --git a/test/typecheck/pass/if_infer/v3.sail b/test/typecheck/pass/if_infer/v3.sail new file mode 100644 index 00000000..0c3dec21 --- /dev/null +++ b/test/typecheck/pass/if_infer/v3.sail @@ -0,0 +1,12 @@ +default Order dec + +$include <prelude.sail> + +register R : bool + +val f : unit -> {'n 'm, 'm == 3 & 1 <= 'n <= 'm. int('n)} + +function main((): unit) -> unit = { + let _ = 0b1001[if R then 0 else f()]; + () +} |
