summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlasdair Armstrong2018-12-10 20:39:16 +0000
committerAlasdair Armstrong2018-12-10 20:45:05 +0000
commit5bc5f5dee8921f8d24260dae54177e00c291fcb1 (patch)
tree89bbd7a947e8063bdbaac4abf364f6cccd2c3fdf
parentd8f0854ca9d80d3af8d6a4aaec778643eda9421c (diff)
Various changes:
* Improve type inference for numeric if statements (if_infer test) * Correctly handle constraints for existentially quantified constructors (constraint_ctor test) * Canonicalise all numeric types in function arguments, which triggers some weird edge cases between parametric polymorphism and subtyping of numeric arguments * Because of this eq_int, eq_range, and eq_atom etc become identical * Avoid duplicating destruct_exist in Env * Handle some odd subtyping cases better
-rwxr-xr-xaarch64/prelude.sail4
-rw-r--r--lib/flow.sail37
-rw-r--r--src/ast_util.ml9
-rw-r--r--src/c_backend.ml2
-rw-r--r--src/initial_check.ml5
-rw-r--r--src/initial_check.mli1
-rw-r--r--src/isail.ml3
-rw-r--r--src/monomorphise.ml8
-rw-r--r--src/parser.mly6
-rw-r--r--src/pretty_print_coq.ml16
-rw-r--r--src/rewrites.ml2
-rw-r--r--src/type_check.ml386
-rw-r--r--src/type_check.mli6
-rw-r--r--test/typecheck/pass/constrained_struct/v1.expect2
-rw-r--r--test/typecheck/pass/constraint_ctor.sail20
-rw-r--r--test/typecheck/pass/constraint_ctor/v1.expect5
-rw-r--r--test/typecheck/pass/constraint_ctor/v1.sail20
-rw-r--r--test/typecheck/pass/constraint_ctor/v2.expect5
-rw-r--r--test/typecheck/pass/constraint_ctor/v2.sail20
-rw-r--r--test/typecheck/pass/constraint_ctor/v3.expect5
-rw-r--r--test/typecheck/pass/constraint_ctor/v3.sail20
-rw-r--r--test/typecheck/pass/constraint_ctor/v4.expect5
-rw-r--r--test/typecheck/pass/constraint_ctor/v4.sail20
-rw-r--r--test/typecheck/pass/exist2.sail2
-rw-r--r--test/typecheck/pass/global_type_var/v1.expect10
-rw-r--r--test/typecheck/pass/global_type_var/v2.expect10
-rw-r--r--test/typecheck/pass/if_infer.sail12
-rw-r--r--test/typecheck/pass/if_infer/v1.expect17
-rw-r--r--test/typecheck/pass/if_infer/v1.sail12
-rw-r--r--test/typecheck/pass/if_infer/v2.expect17
-rw-r--r--test/typecheck/pass/if_infer/v2.sail12
-rw-r--r--test/typecheck/pass/if_infer/v3.expect7
-rw-r--r--test/typecheck/pass/if_infer/v3.sail12
33 files changed, 476 insertions, 242 deletions
diff --git a/aarch64/prelude.sail b/aarch64/prelude.sail
index 8cd18fac..505ca7b6 100755
--- a/aarch64/prelude.sail
+++ b/aarch64/prelude.sail
@@ -143,12 +143,12 @@ val UInt = {
interpreter: "uint",
c: "sail_unsigned",
coq: "uint"
-} : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1)
+} : forall 'n. bits('n) -> {'m, 0 <= 'm <= 2 ^ 'n - 1. int('m)}
val SInt = {
c: "sail_signed",
_: "sint"
-} : forall 'n. bits('n) -> range(- (2 ^ ('n - 1)), 2 ^ ('n - 1) - 1)
+} : forall 'n. bits('n) -> {'m, (- (2 ^ ('n - 1))) <= 'm <= 2 ^ ('n - 1) - 1. int('m)}
val hex_slice = "hex_slice" : forall 'n 'm. (string, atom('n), atom('m)) -> bits('n - 'm) effect {escape}
diff --git a/lib/flow.sail b/lib/flow.sail
index cdc6b2fd..b9653828 100644
--- a/lib/flow.sail
+++ b/lib/flow.sail
@@ -20,34 +20,9 @@ val not_bool = {coq: "negb", _: "not"} : bool -> bool
or_bool that are not shown here. */
val and_bool = {coq: "andb", _: "and_bool"} : (bool, bool) -> bool
val or_bool = {coq: "orb", _: "or_bool"} : (bool, bool) -> bool
-
-val eq_atom = {ocaml: "eq_int", lem: "eq", c: "eq_int", coq: "Z.eqb"} : forall 'n 'm. (atom('n), atom('m)) -> bool
-
-val neq_atom = {lem: "neq", coq: "neq_atom"} : forall 'n 'm. (atom('n), atom('m)) -> bool
-
-function neq_atom (x, y) = not_bool(eq_atom(x, y))
-
-val lteq_atom = {coq: "Z.leb", _: "lteq"} : forall 'n 'm. (atom('n), atom('m)) -> bool
-val gteq_atom = {coq: "Z.geb", _: "gteq"} : forall 'n 'm. (atom('n), atom('m)) -> bool
-val lt_atom = {coq: "Z.ltb", _: "lt"} : forall 'n 'm. (atom('n), atom('m)) -> bool
-val gt_atom = {coq: "Z.gtb", _: "gt"} : forall 'n 'm. (atom('n), atom('m)) -> bool
-
-val lt_range_atom = {coq: "ltb_range_l", _: "lt"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool
-val lteq_range_atom = {coq: "leb_range_l", _: "lteq"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool
-val gt_range_atom = {coq: "gtb_range_l", _: "gt"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool
-val gteq_range_atom = {coq: "geb_range_l", _: "gteq"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool
-val lt_atom_range = {coq: "ltb_range_r", _: "lt"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool
-val lteq_atom_range = {coq: "leb_range_r", _: "lteq"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool
-val gt_atom_range = {coq: "gtb_range_r", _: "gt"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool
-val gteq_atom_range = {coq: "geb_range_r", _: "gteq"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool
-
-val eq_range = {ocaml: "eq_int", lem: "eq", c: "eq_int", coq: "eq_range"} : forall 'n 'm 'o 'p. (range('n, 'm), range('o, 'p)) -> bool
val eq_int = {ocaml: "eq_int", lem: "eq", c: "eq_int", coq: "Z.eqb"} : (int, int) -> bool
val eq_bool = {ocaml: "eq_bool", lem: "eq", c: "eq_bool", coq: "Bool.eqb"} : (bool, bool) -> bool
-val neq_range = {lem: "neq"} : forall 'n 'm 'o 'p. (range('n, 'm), range('o, 'p)) -> bool
-function neq_range (x, y) = not_bool(eq_range(x, y))
-
val neq_int = {lem: "neq"} : (int, int) -> bool
function neq_int (x, y) = not_bool(eq_int(x, y))
@@ -59,15 +34,15 @@ val gteq_int = {coq: "Z.geb", _:"gteq"} : (int, int) -> bool
val lt_int = {coq: "Z.ltb", _:"lt"} : (int, int) -> bool
val gt_int = {coq: "Z.gtb", _:"gt"} : (int, int) -> bool
-overload operator == = {eq_atom, eq_range, eq_int, eq_bit, eq_bool, eq_unit}
-overload operator != = {neq_atom, neq_range, neq_int, neq_bool}
+overload operator == = {eq_int, eq_bit, eq_bool, eq_unit}
+overload operator != = {neq_int, neq_bool}
overload operator | = {or_bool}
overload operator & = {and_bool}
-overload operator <= = {lteq_atom, lteq_range_atom, lteq_atom_range, lteq_int}
-overload operator < = {lt_atom, lt_range_atom, lt_atom_range, lt_int}
-overload operator >= = {gteq_atom, gteq_range_atom, gteq_atom_range, gteq_int}
-overload operator > = {gt_atom, gt_range_atom, gt_atom_range, gt_int}
+overload operator <= = {lteq_int}
+overload operator < = {lt_int}
+overload operator >= = {gteq_int}
+overload operator > = {gt_int}
$ifdef TEST
diff --git a/src/ast_util.ml b/src/ast_util.ml
index f6b8317d..46afe599 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -368,12 +368,17 @@ let nc_lteq n1 n2 = NC_aux (NC_bounded_le (n1, n2), Parse_ast.Unknown)
let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown)
let nc_lt n1 n2 = nc_lteq (nsum n1 (nint 1)) n2
let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nint 1))
-let nc_and nc1 nc2 = mk_nc (NC_and (nc1, nc2))
let nc_or nc1 nc2 = mk_nc (NC_or (nc1, nc2))
let nc_var kid = mk_nc (NC_var kid)
let nc_true = mk_nc NC_true
let nc_false = mk_nc NC_false
+let nc_and nc1 nc2 =
+ match nc1, nc2 with
+ | _, NC_aux (NC_true, _) -> nc1
+ | NC_aux (NC_true, _), _ -> nc2
+ | _, _ -> mk_nc (NC_and (nc1, nc2))
+
let arg_nexp ?loc:(l=Parse_ast.Unknown) n = A_aux (A_nexp n, l)
let arg_order ?loc:(l=Parse_ast.Unknown) ord = A_aux (A_order ord, l)
let arg_typ ?loc:(l=Parse_ast.Unknown) typ = A_aux (A_typ typ, l)
@@ -685,7 +690,7 @@ and string_of_typ_arg_aux = function
| A_order o -> string_of_order o
| A_bool nc -> string_of_n_constraint nc
and string_of_n_constraint = function
- | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2
+ | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " == " ^ string_of_nexp n2
| NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2
| NC_aux (NC_bounded_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2
| NC_aux (NC_bounded_le (n1, n2), _) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2
diff --git a/src/c_backend.ml b/src/c_backend.ml
index 535a0b67..95ab51df 100644
--- a/src/c_backend.ml
+++ b/src/c_backend.ml
@@ -194,7 +194,7 @@ let rec ctyp_of_typ ctx typ =
ensure that we don't cause any type variable clashes in
local_env, and that we can optimize the existential based upon
it's constraints. *)
- begin match destruct_exist ctx.local_env typ with
+ begin match destruct_exist (Env.expand_synonyms ctx.local_env typ) with
| Some (kids, nc, typ) ->
let env = add_existential l kids nc ctx.local_env in
ctyp_of_typ { ctx with local_env = env } typ
diff --git a/src/initial_check.ml b/src/initial_check.ml
index 0f1af63d..44f36892 100644
--- a/src/initial_check.ml
+++ b/src/initial_check.ml
@@ -777,6 +777,11 @@ let typschm_of_string str =
let typschm, _ = to_ast_typschm initial_ctx typschm in
typschm
+let typ_of_string str =
+ let typ = Parser.typ_eof Lexer.token (Lexing.from_string str) in
+ let typ = to_ast_typ initial_ctx typ in
+ typ
+
let extern_of_string id str = mk_val_spec (VS_val_spec (typschm_of_string str, id, (fun _ -> Some (string_of_id id)), false))
let val_spec_of_string id str = mk_val_spec (VS_val_spec (typschm_of_string str, id, (fun _ -> None), false))
diff --git a/src/initial_check.mli b/src/initial_check.mli
index 32def316..25187e4c 100644
--- a/src/initial_check.mli
+++ b/src/initial_check.mli
@@ -91,3 +91,4 @@ val extern_of_string : id -> string -> unit def
val val_spec_of_string : id -> string -> unit def
val exp_of_string : string -> unit exp
+val typ_of_string : string -> typ
diff --git a/src/isail.ml b/src/isail.ml
index 195e5940..18c59e0b 100644
--- a/src/isail.ml
+++ b/src/isail.ml
@@ -270,6 +270,9 @@ let handle_input' input =
let exp = Type_check.infer_exp !interactive_env exp in
pretty_sail stdout (doc_typ (Type_check.typ_of exp));
print_newline ()
+ | ":canon" ->
+ let typ = Initial_check.typ_of_string arg in
+ print_endline (string_of_typ (Type_check.canonicalize !interactive_env typ))
| ":v" | ":verbose" ->
Type_check.opt_tc_debug := (!Type_check.opt_tc_debug + 1) mod 3;
print_endline ("Verbosity: " ^ string_of_int !Type_check.opt_tc_debug)
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 74ef8376..113db3a2 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -522,7 +522,7 @@ let refine_constructor refinements l env id args =
(* A constructor should always have a single argument. *)
| Typ_aux (Typ_fn ([constr_ty],_,_),_) -> begin
let arg_ty = typ_of_args args in
- match Type_check.destruct_exist env constr_ty with
+ match Type_check.destruct_exist (Type_check.Env.expand_synonyms env constr_ty) with
| None -> None
| Some (kids,nc,constr_ty) ->
let bindings = Type_check.unify l env (tyvars_of_typ constr_ty) constr_ty arg_ty in
@@ -728,7 +728,7 @@ let fabricate_nexp l tannot =
match destruct_tannot tannot with
| None -> nint 32
| Some (env,typ,_) ->
- match Type_check.destruct_exist env typ with
+ match Type_check.destruct_exist (Type_check.Env.expand_synonyms env typ) with
| None -> nint 32
| Some (kids,nc,typ') -> fabricate_nexp_exist env l typ kids nc typ'
@@ -745,7 +745,7 @@ let atom_typ_kid kid = function
let reduce_cast typ exp l annot =
let env = env_of_annot (l,annot) in
let typ' = Env.base_typ_of env typ in
- match exp, destruct_exist env typ' with
+ match exp, destruct_exist (Env.expand_synonyms env typ') with
| E_aux (E_lit (L_aux (L_num n,_)),_), Some ([kid],nc,typ'') when atom_typ_kid kid typ'' ->
let nc_env = Env.add_typ_var l kid K_int env in
let nc_env = Env.add_constraint (nc_eq (nvar kid) (nconstant n)) nc_env in
@@ -3182,7 +3182,7 @@ let rec analyse_exp fn_id env assigns (E_aux (e,(l,annot)) as exp) =
| Some (tenv,typ,_) ->
let typ = Env.base_typ_of tenv typ in
let env, tenv, typ =
- match destruct_exist tenv typ with
+ match destruct_exist (Env.expand_synonyms tenv typ) with
| None -> env, tenv, typ
| Some (kids, nc, typ) ->
{ env with kid_deps =
diff --git a/src/parser.mly b/src/parser.mly
index fa36591c..83e6936d 100644
--- a/src/parser.mly
+++ b/src/parser.mly
@@ -212,9 +212,11 @@ let rec desugar_rchain chain s e =
%start file
%start typschm_eof
+%start typ_eof
%start exp_eof
%start def_eof
%type <Parse_ast.typschm> typschm_eof
+%type <Parse_ast.atyp> typ_eof
%type <Parse_ast.exp> exp_eof
%type <Parse_ast.def> def_eof
%type <Parse_ast.defs> file
@@ -349,6 +351,10 @@ tyarg:
| Lparen typ_list Rparen
{ [], $2 }
+typ_eof:
+ | typ Eof
+ { $1 }
+
typ:
| typ0
{ $1 }
diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml
index 025156cc..f00a93b7 100644
--- a/src/pretty_print_coq.ml
+++ b/src/pretty_print_coq.ml
@@ -667,7 +667,7 @@ let is_ctor env id = match Env.lookup_id id env with
let is_auto_decomposed_exist env typ =
let typ = expand_range_type typ in
- match destruct_exist env typ with
+ match destruct_exist (Env.expand_synonyms env typ) with
| Some (_, _, typ') -> Some typ'
| _ -> None
@@ -905,7 +905,7 @@ let doc_exp, doc_let =
debug ctxt (lazy (" at type " ^ string_of_typ typ))
in
let typ = expand_range_type typ in
- match destruct_exist env typ with
+ match destruct_exist typ with
| None -> epp
| Some _ ->
let epp = string "build_ex" ^/^ epp in
@@ -921,12 +921,12 @@ let doc_exp, doc_let =
| _ ->
let typ' = expand_range_type (Env.expand_synonyms (env_of exp) typ) in
let build_ex, out_typ =
- match destruct_exist env typ' with
+ match destruct_exist typ' with
| Some (_,_,t) -> true, t
| None -> false, typ'
in
let in_typ = expand_range_type (Env.expand_synonyms (env_of exp) (typ_of exp)) in
- let in_typ = match destruct_exist env in_typ with Some (_,_,t) -> t | None -> in_typ in
+ let in_typ = match destruct_exist in_typ with Some (_,_,t) -> t | None -> in_typ in
let autocast =
(* Avoid using helper functions which simplify the nexps *)
is_bitvector_typ in_typ && is_bitvector_typ out_typ &&
@@ -1528,7 +1528,7 @@ let doc_exp, doc_let =
| P_aux (P_var (P_aux (P_typ (typ, P_aux (P_id id,_)),_),_),_)
when not (is_enum (env_of e1) id) ->
let full_typ = (expand_range_type typ) in
- let binder = match destruct_exist (env_of e1) full_typ with
+ let binder = match destruct_exist (Env.expand_synonyms (env_of e1) full_typ) with
| Some _ ->
squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt typ])
| _ ->
@@ -1975,7 +1975,7 @@ let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) =
| _ -> failwith ("Function " ^ string_of_id id ^ " does not have function type")
in
let build_ex, ret_typ = replace_atom_return_type ret_typ in
- let build_ex = match destruct_exist env (expand_range_type ret_typ) with
+ let build_ex = match destruct_exist (Env.expand_synonyms env (expand_range_type ret_typ)) with
| Some _ -> true
| _ -> build_ex
in
@@ -2035,7 +2035,7 @@ let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) =
| P_typ (_,P_aux (P_id id,_))
when not (is_enum env id) -> begin
let full_typ = (expand_range_type exp_typ) in
- match destruct_exist env full_typ with
+ match destruct_exist (Env.expand_synonyms env full_typ) with
| Some ([kid], NC_aux (NC_true,_),
Typ_aux (Typ_app (Id_aux (Id "atom",_),
[A_aux (A_nexp (Nexp_aux (Nexp_var kid',_)),_)]),_))
@@ -2255,7 +2255,7 @@ let doc_val pat exp =
| None -> typpp, exp
| Some typ ->
let typ = expand_range_type (Env.expand_synonyms env typ) in
- match destruct_exist env typ with
+ match destruct_exist typ with
| None -> typpp, exp
| Some _ ->
empty, match exp with
diff --git a/src/rewrites.ml b/src/rewrites.ml
index d5601d08..d8f1af75 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -3729,7 +3729,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
in
let exp4 = rewrite_var_updates (add_vars overwrite exp4 vars) in
let ord_exp, kids, constr, lower, upper, lower_exp, upper_exp =
- match destruct_numeric env (typ_of exp1), destruct_numeric env (typ_of exp2) with
+ match destruct_numeric (Env.expand_synonyms env (typ_of exp1)), destruct_numeric (Env.expand_synonyms env (typ_of exp2)) with
| None, _ | _, None ->
raise (Reporting.err_unreachable el __POS__ "Could not determine loop bounds")
| Some (kids1, constr1, n1), Some (kids2, constr2, n2) ->
diff --git a/src/type_check.ml b/src/type_check.ml
index 42616361..459fe8d7 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -215,6 +215,52 @@ and strip_kinded_id_aux = function
and strip_kind = function
| K_aux (k_aux, _) -> K_aux (k_aux, Parse_ast.Unknown)
+let ex_counter = ref 0
+
+let fresh_existential ?name:(n="") () =
+ let fresh = Kid_aux (Var ("'ex" ^ string_of_int !ex_counter ^ "#" ^ n), Parse_ast.Unknown) in
+ incr ex_counter; fresh
+
+let destruct_exist' typ =
+ match typ with
+ | Typ_aux (Typ_exist (kids, nc, typ), _) ->
+ let fresh_kids = List.map (fun kid -> (kid, fresh_existential ~name:(string_of_id (id_of_kid kid)) ())) kids in
+ let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_nexp (nvar fresh)) nc) nc fresh_kids in
+ let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_nexp (nvar fresh)) typ) typ fresh_kids in
+ Some (List.map snd fresh_kids, nc, typ)
+ | _ -> None
+
+(** Destructure and canonicalise a numeric type into a list of type
+ variables, a constraint on those type variables, and an
+ N-expression that represents that numeric type in the
+ environment. For example:
+ - {'n, 'n <= 10. atom('n)} => ['n], 'n <= 10, 'n
+ - int => ['n], true, 'n (where x is fresh)
+ - atom('n) => [], true, 'n
+**)
+let destruct_numeric typ =
+ match destruct_exist' typ, typ with
+ | Some (kids, nc, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _)), _ when string_of_id id = "atom" ->
+ Some (kids, nc, nexp)
+ | None, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _) when string_of_id id = "atom" ->
+ Some ([], nc_true, nexp)
+ | None, Typ_aux (Typ_app (id, [A_aux (A_nexp lo, _); A_aux (A_nexp hi, _)]), _) when string_of_id id = "range" ->
+ let kid = fresh_existential () in
+ Some ([kid], nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi), nvar kid)
+ | None, Typ_aux (Typ_id id, _) when string_of_id id = "nat" ->
+ let kid = fresh_existential () in
+ Some ([kid], nc_lteq (nint 0) (nvar kid), nvar kid)
+ | None, Typ_aux (Typ_id id, _) when string_of_id id = "int" ->
+ let kid = fresh_existential () in
+ Some ([kid], nc_true, nvar kid)
+ | _, _ -> None
+
+let destruct_exist typ =
+ match destruct_numeric typ with
+ | Some (kids, nc, nexp) -> Some (kids, nc, atom_typ nexp)
+ | None -> destruct_exist' typ
+
+
let adding = Util.("Adding " |> darkgray |> clear)
(**************************************************************************)
@@ -244,6 +290,7 @@ module Env : sig
val get_variant : id -> t -> typquant * type_union list
val add_mapping : id -> typquant * typ * typ -> t -> t
val add_union_id : id -> typquant * typ -> t -> t
+ val get_union_id : id -> t -> typquant * typ
val add_flow : id -> (typ -> typ) -> t -> t
val get_flow : id -> t -> typ -> typ
val remove_flow : id -> t -> t
@@ -286,11 +333,7 @@ module Env : sig
val fresh_kid : ?kid:kid -> t -> kid
val expand_synonyms : t -> typ -> typ
val expand_constraint_synonyms : t -> n_constraint -> n_constraint
- val canonicalize : t -> typ -> typ
val base_typ_of : t -> typ -> typ
- val add_smt_op : id -> string -> t -> t
- val get_smt_op : id -> t -> string
- val have_smt_op : id -> t -> bool
val allow_unknowns : t -> bool
val set_allow_unknowns : bool -> t -> t
@@ -332,7 +375,6 @@ end = struct
records : (typquant * (typ * id) list) Bindings.t;
accessors : (typquant * typ) Bindings.t;
externs : (string -> string option) Bindings.t;
- smt_ops : string Bindings.t;
casts : id list;
allow_casts : bool;
allow_bindings : bool;
@@ -361,7 +403,6 @@ end = struct
records = Bindings.empty;
accessors = Bindings.empty;
externs = Bindings.empty;
- smt_ops = Bindings.empty;
casts = [];
allow_bindings = true;
allow_casts = true;
@@ -434,21 +475,6 @@ end = struct
let existing = try Bindings.find id env.overloads with Not_found -> [] in
{ env with overloads = Bindings.add id (existing @ ids) env.overloads }
- let add_smt_op id str env =
- typ_print (lazy (adding ^ "smt binding " ^ string_of_id id ^ " to " ^ str));
- { env with smt_ops = Bindings.add id str env.smt_ops }
-
- let get_smt_op (Id_aux (_, l) as id) env =
- let rec first_smt_op = function
- | id :: ids -> (try Bindings.find id env.smt_ops with Not_found -> first_smt_op ids)
- | [] -> typ_error l ("No SMT op for " ^ string_of_id id)
- in
- try Bindings.find id env.smt_ops with
- | Not_found -> first_smt_op (get_overloads id env)
-
- let have_smt_op id env =
- try ignore(get_smt_op id env); true with Type_error _ -> false
-
let rec infer_kind env id =
if Bindings.mem id builtin_typs then
Bindings.find id builtin_typs
@@ -566,53 +592,6 @@ end = struct
| A_order _ | A_typ _ | A_bool _ -> arg
| A_nexp n -> A_aux (A_nexp (f n), l)
- let canonical env typ =
- let typ = expand_synonyms env typ in
- let counter = ref 0 in
- let complex_nexps = ref KBindings.empty in
- let simplify_nexp (Nexp_aux (nexp_aux, l) as nexp) =
- match nexp_aux with
- | Nexp_constant _ -> nexp (* Check this ? *)
- | _ ->
- let kid = Kid_aux (Var ("'c#" ^ string_of_int !counter), l) in
- complex_nexps := KBindings.add kid nexp !complex_nexps;
- incr counter;
- Nexp_aux (Nexp_var kid, l)
- in
- let typ = map_nexps (fun nexp -> simplify_nexp (nexp_simp nexp)) typ in
- let existentials = KBindings.bindings !complex_nexps |> List.map fst in
- let constrs = List.fold_left (fun ncs (kid, nexp) -> nc_eq (nvar kid) nexp :: ncs) [] (KBindings.bindings !complex_nexps) in
- existentials, constrs, typ
-
- let is_canonical env typ =
- let typ = expand_synonyms env typ in
- let counter = ref 0 in
- let simplify_nexp (Nexp_aux (nexp_aux, l) as nexp) =
- match nexp_aux with
- | Nexp_constant _ -> nexp
- | _ -> (incr counter; nexp)
- in
- let typ = map_nexps simplify_nexp typ in
- not (!counter > 0)
-
- let rec canonicalize env typ =
- match typ with
- | Typ_aux (Typ_fn (arg_typs, ret_typ, effects), l) when List.for_all (is_canonical env) arg_typs ->
- Typ_aux (Typ_fn (arg_typs, canonicalize env ret_typ, effects), l)
- | Typ_aux (Typ_fn _, l) -> typ_error l ("Function type " ^ string_of_typ typ ^ " is not canonical")
- | _ ->
- let existentials, constrs, (Typ_aux (typ_aux, l) as typ) = canonical env typ in
- if existentials = [] then
- typ
- else
- let typ_aux = match typ_aux with
- | Typ_tup _ | Typ_app _ -> Typ_exist (existentials, List.fold_left nc_and (List.hd constrs) (List.tl constrs), typ)
- | Typ_exist (kids, nc, typ) -> Typ_exist (kids @ existentials, List.fold_left nc_and nc constrs, typ)
- | Typ_fn _ | Typ_bidir _ | Typ_id _ | Typ_var _ -> assert false (* These must be simple *)
- | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown"
- in
- Typ_aux (typ_aux, l)
-
(* Check if a type, order, n-expression or constraint is
well-formed. Throws a type error if the type is badly formed. *)
let rec wf_typ ?exs:(exs=KidSet.empty) env typ =
@@ -667,7 +646,6 @@ end = struct
end
| Nexp_constant _ -> ()
| Nexp_app (id, nexps) ->
- let _ = get_smt_op id env in
List.iter (fun n -> wf_nexp ~exs:exs env n) nexps
| Nexp_times (nexp1, nexp2) -> wf_nexp ~exs:exs env nexp1; wf_nexp ~exs:exs env nexp2
| Nexp_sum (nexp1, nexp2) -> wf_nexp ~exs:exs env nexp1; wf_nexp ~exs:exs env nexp2
@@ -746,20 +724,6 @@ end = struct
let ex_counter = ref 0
- (* TODO: Currently this is duplicated with destruct_exist outside of Env and deals with val spec arguments only. *)
- let fresh_existential ?name:(n="") () =
- let fresh = Kid_aux (Var ("'all" ^ string_of_int !ex_counter ^ "#" ^ n), Parse_ast.Unknown) in
- incr ex_counter; fresh
-
- let destruct_exist env typ =
- match expand_synonyms env typ with
- | Typ_aux (Typ_exist (kids, nc, typ), _) ->
- let fresh_kids = List.map (fun kid -> (kid, fresh_existential ~name:(string_of_id (id_of_kid kid)) ())) kids in
- let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_nexp (nvar fresh)) nc) nc fresh_kids in
- let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_nexp (nvar fresh)) typ) typ fresh_kids in
- Some (List.map snd fresh_kids, nc, typ)
- | _ -> None
-
let rec update_val_spec id (typq, typ) env =
begin match expand_synonyms env typ with
| Typ_aux (Typ_fn (arg_typs, ret_typ, effect), l) ->
@@ -769,7 +733,7 @@ end = struct
forall 'n, 'n >= 2. (int('n), foo) -> bar
this enforces the invariant that all things on the left of functions are 'base types' (i.e. without existentials)
*)
- let base_args = List.map (destruct_exist env) arg_typs in
+ let base_args = List.map (fun typ -> destruct_exist (expand_synonyms env typ)) arg_typs in
let existential_arg typq = function
| None -> typq
| Some (exs, nc, _) ->
@@ -959,10 +923,15 @@ end = struct
| None -> typ_error (id_loc id) ("union " ^ string_of_id id ^ " not found")
let add_union_id id bind env =
- begin
- typ_print (lazy (adding ^ "union identifier " ^ string_of_id id ^ " : " ^ string_of_bind bind));
- { env with union_ids = Bindings.add id bind env.union_ids }
- end
+ typ_print (lazy (adding ^ "union identifier " ^ string_of_id id ^ " : " ^ string_of_bind bind));
+ { env with union_ids = Bindings.add id bind env.union_ids }
+
+ let get_union_id id env =
+ try
+ let bind = Bindings.find id env.union_ids in
+ List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars)
+ with
+ | Not_found -> typ_error (id_loc id) ("No union constructor found for " ^ string_of_id id)
let get_flow id env =
try Bindings.find id env.flow with
@@ -1156,21 +1125,6 @@ let default_order_error_string =
let dvector_typ env n typ = vector_typ n (Env.get_default_order env) typ
-let ex_counter = ref 0
-
-let fresh_existential ?name:(n="") () =
- let fresh = Kid_aux (Var ("'ex" ^ string_of_int !ex_counter ^ "#" ^ n), Parse_ast.Unknown) in
- incr ex_counter; fresh
-
-let destruct_exist env typ =
- match Env.expand_synonyms env typ with
- | Typ_aux (Typ_exist (kids, nc, typ), _) ->
- let fresh_kids = List.map (fun kid -> (kid, fresh_existential ~name:(string_of_id (id_of_kid kid)) ())) kids in
- let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_nexp (nvar fresh)) nc) nc fresh_kids in
- let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_nexp (nvar fresh)) typ) typ fresh_kids in
- Some (List.map snd fresh_kids, nc, typ)
- | _ -> None
-
let add_existential l kids nc env =
let env = List.fold_left (fun env kid -> Env.add_typ_var l kid K_int env) env kids in
Env.add_constraint nc env
@@ -1185,34 +1139,8 @@ let exist_typ constr typ =
let fresh_kid = fresh_existential () in
mk_typ (Typ_exist ([fresh_kid], constr fresh_kid, typ fresh_kid))
-(** Destructure and canonicalise a numeric type into a list of type
- variables, a constraint on those type variables, and an
- N-expression that represents that numeric type in the
- environment. For example:
- - {'n, 'n <= 10. atom('n)} => ['n], 'n <= 10, 'n
- - int => ['n], true, 'n (where x is fresh)
- - atom('n) => [], true, 'n
-**)
-let destruct_numeric env typ =
- let typ = Env.expand_synonyms env typ in
- match destruct_exist env typ, typ with
- | Some (kids, nc, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _)), _ when string_of_id id = "atom" ->
- Some (kids, nc, nexp)
- | None, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _) when string_of_id id = "atom" ->
- Some ([], nc_true, nexp)
- | None, Typ_aux (Typ_app (id, [A_aux (A_nexp lo, _); A_aux (A_nexp hi, _)]), _) when string_of_id id = "range" ->
- let kid = fresh_existential () in
- Some ([kid], nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi), nvar kid)
- | None, Typ_aux (Typ_id id, _) when string_of_id id = "nat" ->
- let kid = fresh_existential () in
- Some ([kid], nc_lteq (nint 0) (nvar kid), nvar kid)
- | None, Typ_aux (Typ_id id, _) when string_of_id id = "int" ->
- let kid = fresh_existential () in
- Some ([kid], nc_true, nvar kid)
- | _, _ -> None
-
let bind_numeric l typ env =
- match destruct_numeric env typ with
+ match destruct_numeric (Env.expand_synonyms env typ) with
| Some (kids, nc, nexp) ->
nexp, add_existential l kids nc env
| None -> typ_error l ("Expected " ^ string_of_typ typ ^ " to be numeric")
@@ -1220,15 +1148,13 @@ let bind_numeric l typ env =
(** Pull an (potentially)-existentially qualified type into the global
typing environment **)
let bind_existential l typ env =
- match destruct_numeric env typ with
- | Some (kids, nc, nexp) -> atom_typ nexp, add_existential l kids nc env
- | None -> match destruct_exist env typ with
- | Some (kids, nc, typ) -> typ, add_existential l kids nc env
- | None -> typ, env
+ match destruct_exist (Env.expand_synonyms env typ) with
+ | Some (kids, nc, typ) -> typ, add_existential l kids nc env
+ | None -> typ, env
let destruct_range env typ =
let kids, constr, (Typ_aux (typ_aux, _)) =
- Util.option_default ([], nc_true, typ) (destruct_exist env typ)
+ Util.option_default ([], nc_true, typ) (destruct_exist (Env.expand_synonyms env typ))
in
match typ_aux with
| Typ_app (f, [A_aux (A_nexp n, _)])
@@ -1492,7 +1418,7 @@ let rec unify_typ l env goals (Typ_aux (aux1, _) as typ1) (Typ_aux (aux2, _) as
| Typ_internal_unknown, _ | _, Typ_internal_unknown
when Env.allow_unknowns env ->
KBindings.empty
-
+
| Typ_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_typ typ2)
| Typ_app (range, [A_aux (A_nexp n1, _); A_aux (A_nexp n2, _)]),
@@ -1528,7 +1454,8 @@ and unify_order l goals (Ord_aux (aux1, _) as ord1) (Ord_aux (aux2, _) as ord2)
| _, _ -> unify_error l ("Cound not unify " ^ string_of_order ord1 ^ " and " ^ string_of_order ord2)
and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) =
- typ_debug (lazy ("UNIFYING NEXPS " ^ string_of_nexp nexp1 ^ " AND " ^ string_of_nexp nexp2 ^ " FOR GOALS " ^ string_of_list ", " string_of_kid (KidSet.elements goals)));
+ typ_debug (lazy (Util.("Unify nexp " |> magenta |> clear) ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2
+ ^ " goals " ^ string_of_list ", " string_of_kid (KidSet.elements goals)));
if KidSet.is_empty (KidSet.inter (nexp_frees nexp1) goals)
then
begin
@@ -1559,19 +1486,17 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au
then unify_nexp l env goals n1a (nsum nexp2 n1b)
else unify_error l ("Cannot unify minus Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
| Nexp_times (n1a, n1b) ->
- (* If we have SMT operations div and mod, then we can use the
+ (* f we have SMT operations div and mod, then we can use the
property that
mod(m, C) = 0 && C != 0 --> (C * n = m <--> n = m / C)
- to help us unify multiplications. *)
- if Env.have_smt_op (mk_id "div") env && Env.have_smt_op (mk_id "mod") env then
- let valid n c = prove env (nc_eq (napp (mk_id "mod") [n; c]) (nint 0)) && prove env (nc_neq c (nint 0)) in
- if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then
- unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b])
- else if KidSet.is_empty (nexp_frees n1a) && valid nexp2 n1a then
- unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a])
- else unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
+ to help us unify multiplications and divisions. *)
+ let valid n c = prove env (nc_eq (napp (mk_id "mod") [n; c]) (nint 0)) && prove env (nc_neq c (nint 0)) in
+ if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then
+ unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b])
+ else if KidSet.is_empty (nexp_frees n1a) && valid nexp2 n1a then
+ unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a])
else if KidSet.is_empty (nexp_frees n1a) then
begin
match nexp_aux2 with
@@ -1611,7 +1536,7 @@ let subst_unifiers unifiers typ =
let subst_unifiers_typ_arg unifiers typ_arg =
List.fold_left (fun typ_arg (v, arg) -> typ_arg_subst v arg typ_arg) typ_arg (KBindings.bindings unifiers)
-
+
let instantiate_quant (v, arg) (QI_aux (aux, l) as qi) =
match aux with
| QI_id kopt when Kid.compare (kopt_kid kopt) v = 0 ->
@@ -1716,7 +1641,7 @@ let rec alpha_equivalent env typ1 typ2 =
else (typ_debug (lazy "Not alpha-equivalent"); false)
let unwrap_exist env typ =
- match destruct_exist env typ with
+ match destruct_exist (Env.expand_synonyms env typ) with
| Some (kids, nc, typ) -> (kids, nc, typ)
| None -> ([], nc_true, typ)
@@ -1725,13 +1650,51 @@ let unifier_constraint env (v, arg) =
| A_aux (A_nexp nexp, _) -> Env.add_constraint (nc_eq (nvar v) nexp) env
| _ -> env
-let rec subtyp l env (Typ_aux (typ_aux1, _) as typ1) (Typ_aux (typ_aux2, _) as typ2) =
+let canonicalize env typ =
+ let typ = Env.expand_synonyms env typ in
+ let rec canon (Typ_aux (aux, l)) =
+ match aux with
+ | Typ_var v -> Typ_aux (Typ_var v, l)
+ | Typ_internal_unknown -> Typ_aux (Typ_internal_unknown, l)
+ | Typ_id id when string_of_id id = "int" ->
+ exist_typ (fun _ -> nc_true) (fun v -> atom_typ (nvar v))
+ | Typ_id id -> Typ_aux (Typ_id id, l)
+ | Typ_app (id, [A_aux (A_nexp lo, _); A_aux (A_nexp hi, _)]) when string_of_id id = "range" ->
+ exist_typ (fun v -> nc_and (nc_lteq lo (nvar v)) (nc_lteq (nvar v) hi)) (fun v -> atom_typ (nvar v))
+ | Typ_app (id, args) ->
+ Typ_aux (Typ_app (id, List.map canon_arg args), l)
+ | Typ_tup typs ->
+ let typs = List.map canon typs in
+ let fold_exist (kids, nc, typs) typ =
+ match destruct_exist typ with
+ | Some (kids', nc', typ') -> (kids @ kids', nc_and nc nc', typs @ [typ'])
+ | None -> (kids, nc, typs @ [typ])
+ in
+ let kids, nc, typs = List.fold_left fold_exist ([], nc_true, []) typs in
+ if kids = [] then
+ Typ_aux (Typ_tup typs, l)
+ else
+ Typ_aux (Typ_exist (kids, nc, Typ_aux (Typ_tup typs, l)), l)
+ | Typ_exist (kids, nc, typ) ->
+ begin match destruct_exist (canon typ) with
+ | Some (kids', nc', typ') ->
+ Typ_aux (Typ_exist (kids @ kids', nc_and nc nc', typ'), l)
+ | None -> Typ_aux (Typ_exist (kids, nc, typ), l)
+ end
+ | Typ_fn _ | Typ_bidir _ -> raise (Reporting.err_unreachable l __POS__ "Function type passed to Type_check.canonicalize")
+ and canon_arg (A_aux (aux, l)) =
+ A_aux ((match aux with
+ | A_typ typ -> A_typ (canon typ)
+ | arg -> arg),
+ l)
+ in
+ canon typ
+
+let rec subtyp l env typ1 typ2 =
+ let (Typ_aux (typ_aux1, _) as typ1) = Env.expand_synonyms env typ1 in
+ let (Typ_aux (typ_aux2, _) as typ2) = Env.expand_synonyms env typ2 in
typ_print (lazy (("Subtype " |> Util.green |> Util.clear) ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2));
- match typ_aux1, typ_aux2 with
- | Typ_tup typs1, Typ_tup typs2 when List.length typs1 = List.length typs2 ->
- List.iter2 (subtyp l env) typs1 typs2
- | _, _ ->
- match destruct_numeric env typ1, destruct_numeric env typ2 with
+ match destruct_numeric typ1, destruct_numeric typ2 with
(* Ensure alpha equivalent types are always subtypes of one another
- this ensures that we can always re-check inferred types. *)
| _, _ when alpha_equivalent env typ1 typ2 -> ()
@@ -1743,27 +1706,50 @@ let rec subtyp l env (Typ_aux (typ_aux1, _) as typ1) (Typ_aux (typ_aux2, _) as t
let env = add_existential l kids1 nc1 env in
let env = add_typ_vars l (KidSet.elements (KidSet.inter (nexp_frees nexp2) (KidSet.of_list kids2))) env in
let kids2 = KidSet.elements (KidSet.diff (KidSet.of_list kids2) (nexp_frees nexp2)) in
- if not (kids2 = []) then typ_error l "Universally quantified constraint generated" else ();
+ if not (kids2 = []) then typ_error l ("Universally quantified constraint generated: " ^ Util.string_of_list ", " string_of_kid kids2) else ();
let env = Env.add_constraint (nc_eq nexp1 nexp2) env in
if prove env nc2 then ()
else typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env))
| _, _ ->
- match destruct_exist env typ1, unwrap_exist env (Env.canonicalize env typ2) with
+ match destruct_exist' typ1, destruct_exist (canonicalize env typ2) with
| Some (kids, nc, typ1), _ ->
let env = add_existential l kids nc env in subtyp l env typ1 typ2
- | None, (kids, nc, typ2) ->
+ | None, Some (kids, nc, typ2) ->
typ_debug (lazy "Subtype check with unification");
+ let typ1 = canonicalize env typ1 in
let env = add_typ_vars l kids env in
let kids' = KidSet.elements (KidSet.diff (KidSet.of_list kids) (typ_frees typ2)) in
if not (kids' = []) then typ_error l "Universally quantified constraint generated" else ();
let unifiers =
- try unify l env (tyvars_of_typ typ2) typ2 typ1 with
+ try unify l env (KidSet.diff (tyvars_of_typ typ2) (tyvars_of_typ typ1)) typ2 typ1 with
| Unification_error (_, m) -> typ_error l m
in
let nc = List.fold_left (fun nc (kid, uvar) -> constraint_subst kid uvar nc) nc (KBindings.bindings unifiers) in
let env = List.fold_left unifier_constraint env (KBindings.bindings unifiers) in
if prove env nc then ()
else typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env))
+ | None, None ->
+ match typ_aux1, typ_aux2 with
+ | Typ_tup typs1, Typ_tup typs2 when List.length typs1 = List.length typs2 ->
+ List.iter2 (subtyp l env) typs1 typs2
+
+ | Typ_app (id1, args1), Typ_app (id2, args2) when Id.compare id1 id2 = 0 && List.length args1 = List.length args2 ->
+ List.iter2 (subtyp_arg l env) args1 args2
+
+ | Typ_id id1, Typ_id id2 when Id.compare id1 id2 = 0 -> ()
+ | Typ_id id1, Typ_app (id2, []) when Id.compare id1 id2 = 0 -> ()
+ | Typ_app (id1, []), Typ_id id2 when Id.compare id1 id2 = 0 -> ()
+
+ | _, _ -> typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env))
+
+and subtyp_arg l env (A_aux (aux1, _) as arg1) (A_aux (aux2, _) as arg2) =
+ typ_print (lazy (("Subtype arg " |> Util.green |> Util.clear) ^ string_of_typ_arg arg1 ^ " and " ^ string_of_typ_arg arg2));
+ match aux1, aux2 with
+ | A_nexp n1, A_nexp n2 when prove env (nc_eq n1 n2) -> ()
+ | A_typ typ1, A_typ typ2 -> subtyp l env typ1 typ2
+ | A_order ord1, A_order ord2 when ord_identical ord1 ord2 -> ()
+ | A_bool nc1, A_bool nc2 -> assert false
+ | _, _ -> typ_error l "Mismatched argument types in subtype check"
let typ_equality l env typ1 typ2 =
subtyp l env typ1 typ2; subtyp l env typ2 typ1
@@ -1928,6 +1914,38 @@ let expected_typ_of (l, tannot) = match tannot with
(* Flow typing *)
+type simple_numeric =
+ | Equal of nexp
+ | Constraint of (kid -> n_constraint)
+ | Anything
+
+let to_simple_numeric l kids nc (Nexp_aux (aux, _) as n) =
+ match aux, kids with
+ | Nexp_var v, [v'] when Kid.compare v v' = 0 ->
+ Constraint (fun subst -> constraint_subst v (arg_nexp (nvar subst)) nc)
+ | _, [] ->
+ Equal n
+ | _ ->
+ typ_error l "Numeric type is non-simple"
+
+let union_simple_numeric ex1 ex2 =
+ match ex1, ex2 with
+ | Equal nexp1, Equal nexp2 ->
+ Constraint (fun kid -> nc_or (nc_eq (nvar kid) nexp1) (nc_eq (nvar kid) nexp2))
+
+ | Equal nexp, Constraint c ->
+ Constraint (fun kid -> nc_or (nc_eq (nvar kid) nexp) (c kid))
+
+ | Constraint c, Equal nexp ->
+ Constraint (fun kid -> nc_or (c kid) (nc_eq (nvar kid) nexp))
+
+ | _, _ -> Anything
+
+let typ_of_simple_numeric = function
+ | Anything -> int_typ
+ | Equal nexp -> atom_typ nexp
+ | Constraint c -> exist_typ c (fun kid -> atom_typ (nvar kid))
+
let rec big_int_of_nexp (Nexp_aux (nexp, _)) = match nexp with
| Nexp_constant c -> Some c
| Nexp_times (n1, n2) ->
@@ -1977,17 +1995,17 @@ let rec assert_constraint env b (E_aux (exp_aux, _) as exp) =
combine_constraint (not b) nc_or (assert_constraint env b x) (assert_constraint env b y)
| E_app (op, [x; y]) when string_of_id op = "and_bool" ->
combine_constraint b nc_and (assert_constraint env b x) (assert_constraint env b y)
- | E_app (op, [x; y]) when string_of_id op = "gteq_atom" ->
+ | E_app (op, [x; y]) when string_of_id op = "gteq_int" ->
option_binop nc_gteq (assert_nexp env x) (assert_nexp env y)
- | E_app (op, [x; y]) when string_of_id op = "lteq_atom" ->
+ | E_app (op, [x; y]) when string_of_id op = "lteq_int" ->
option_binop nc_lteq (assert_nexp env x) (assert_nexp env y)
- | E_app (op, [x; y]) when string_of_id op = "gt_atom" ->
+ | E_app (op, [x; y]) when string_of_id op = "gt_int" ->
option_binop nc_gt (assert_nexp env x) (assert_nexp env y)
- | E_app (op, [x; y]) when string_of_id op = "lt_atom" ->
+ | E_app (op, [x; y]) when string_of_id op = "lt_int" ->
option_binop nc_lt (assert_nexp env x) (assert_nexp env y)
- | E_app (op, [x; y]) when string_of_id op = "eq_atom" ->
+ | E_app (op, [x; y]) when string_of_id op = "eq_int" ->
option_binop nc_eq (assert_nexp env x) (assert_nexp env y)
- | E_app (op, [x; y]) when string_of_id op = "neq_atom" ->
+ | E_app (op, [x; y]) when string_of_id op = "neq_int" ->
option_binop nc_neq (assert_nexp env x) (assert_nexp env y)
| _ ->
None
@@ -2398,13 +2416,13 @@ and type_coercion env (E_aux (_, (l, _)) as annotated_exp) typ =
in
begin
try
- typ_debug (lazy ("PERFORMING TYPE COERCION: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ));
+ typ_debug (lazy ("Performing type coercion: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ));
subtyp l env (typ_of annotated_exp) typ; switch_exp_typ annotated_exp
with
| Type_error (_, trigger) when Env.allow_casts env ->
let casts = filter_casts env (typ_of annotated_exp) typ (Env.get_casts env) in
try_casts trigger [] casts
- | Type_error (l, err) -> typ_error l "Subtype error"
+ | Type_error (l, err) -> typ_raise l err
end
(* type_coercion_unify env exp typ attempts to coerce exp to a type
@@ -2434,7 +2452,7 @@ and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ =
in
begin
try
- typ_debug (lazy "PERFORMING COERCING UNIFICATION");
+ typ_debug (lazy ("Coercing unification: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ));
let atyp, env = bind_existential l (typ_of annotated_exp) env in
annotated_exp, unify l env goals typ atyp, env
with
@@ -2548,7 +2566,7 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
end
| P_app (f, pats) when Env.is_union_constructor f env ->
begin
- let (typq, ctor_typ) = Env.get_val_spec f env in
+ let (typq, ctor_typ) = Env.get_union_id f env in
let quants = quant_items typq in
let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with
| Typ_tup typs -> typs
@@ -2563,8 +2581,8 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
let unifiers = unify l env goals ret_typ typ in
let arg_typ' = subst_unifiers unifiers arg_typ in
let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
- if (match quants' with [] -> false | _ -> true)
- then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat)
+ if not (List.for_all (solve_quant env) quants') then
+ typ_raise l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env))
else ();
let ret_typ' = subst_unifiers unifiers ret_typ in
let tpats, env, guards =
@@ -2580,7 +2598,7 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
| P_app (f, pats) when Env.is_mapping f env ->
begin
- let (typq, mapping_typ) = Env.get_val_spec f env in
+ let (typq, mapping_typ) = Env.get_union_id f env in
let quants = quant_items typq in
let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with
| Typ_tup typs -> typs
@@ -3094,7 +3112,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
let inferred_f = irule infer_exp env f in
let inferred_t = irule infer_exp env t in
let checked_step = crule check_exp env step int_typ in
- match destruct_numeric env (typ_of inferred_f), destruct_numeric env (typ_of inferred_t) with
+ match destruct_numeric (typ_of inferred_f), destruct_numeric (typ_of inferred_t) with
| Some (kids1, nc1, nexp1), Some (kids2, nc2, nexp2) ->
let loop_kid = mk_kid ("loop_" ^ string_of_id v) in
let env = List.fold_left (fun env kid -> Env.add_typ_var l kid K_int env) env (loop_kid :: kids1 @ kids2) in
@@ -3110,8 +3128,22 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
| E_if (cond, then_branch, else_branch) ->
let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in
let then_branch' = irule infer_exp (add_opt_constraint (assert_constraint env true cond') env) then_branch in
- let else_branch' = crule check_exp (add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env) else_branch (typ_of then_branch') in
- annot_exp (E_if (cond', then_branch', else_branch')) (typ_of then_branch')
+ (* We don't have generic type union in Sail, but we can union simple numeric types. *)
+ begin match destruct_numeric (Env.expand_synonyms env (typ_of then_branch')) with
+ | Some (kids, nc, then_nexp) ->
+ let then_sn = to_simple_numeric l kids nc then_nexp in
+ let else_branch' = irule infer_exp (add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env) else_branch in
+ begin match destruct_numeric (Env.expand_synonyms env (typ_of else_branch')) with
+ | Some (kids, nc, else_nexp) ->
+ let else_sn = to_simple_numeric l kids nc else_nexp in
+ let typ = typ_of_simple_numeric (union_simple_numeric then_sn else_sn) in
+ annot_exp (E_if (cond', then_branch', else_branch')) typ
+ | None -> typ_error l ("Could not infer type of " ^ string_of_exp else_branch)
+ end
+ | None ->
+ let else_branch' = crule check_exp (add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env) else_branch (typ_of then_branch') in
+ annot_exp (E_if (cond', then_branch', else_branch')) (typ_of then_branch')
+ end
| E_vector_access (v, n) -> infer_exp env (E_aux (E_app (mk_id "vector_access", [v; n]), (l, ())))
| E_vector_update (v, n, exp) -> infer_exp env (E_aux (E_app (mk_id "vector_update", [v; n; exp]), (l, ())))
| E_vector_update_subrange (v, n, m, exp) -> infer_exp env (E_aux (E_app (mk_id "vector_update_subrange", [v; n; m; exp]), (l, ())))
@@ -4163,10 +4195,6 @@ let check_val_spec env (VS_aux (vs, (l, _))) =
let vs, id, typq, typ, env = match vs with
| VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), ts_l) as typschm, id, ext_opt, is_cast) ->
typ_print (lazy (Util.("Check val spec " |> cyan |> clear) ^ string_of_id id ^ " : " ^ string_of_typschm typschm));
- let env = match (ext_opt "smt", ext_opt "#") with
- | Some op, None -> Env.add_smt_op id op env
- | _, _ -> env
- in
let env = Env.add_extern id ext_opt env in
let env = if is_cast then Env.add_cast id env else env in
let typq, typ =
diff --git a/src/type_check.mli b/src/type_check.mli
index 52ade6fa..47b9d172 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -300,6 +300,8 @@ val prove : Env.t -> n_constraint -> bool
val solve : Env.t -> nexp -> Big_int.num option
+val canonicalize : Env.t -> typ -> typ
+
val subtype_check : Env.t -> typ -> typ -> bool
val bind_pat : Env.t -> unit pat -> typ -> tannot pat * Env.t * unit Ast.exp list
@@ -350,11 +352,11 @@ val destruct_atom_nexp : Env.t -> typ -> nexp option
(** Safely destructure an existential type. Returns None if the type
is not existential. This function will pick a fresh name for the
existential to ensure that no name-clashes occur. *)
-val destruct_exist : Env.t -> typ -> (kid list * n_constraint * typ) option
+val destruct_exist : typ -> (kid list * n_constraint * typ) option
val destruct_range : Env.t -> typ -> (kid list * n_constraint * nexp * nexp) option
-val destruct_numeric : Env.t -> typ -> (kid list * n_constraint * nexp) option
+val destruct_numeric : typ -> (kid list * n_constraint * nexp) option
val destruct_vector : Env.t -> typ -> (nexp * order * typ) option
diff --git a/test/typecheck/pass/constrained_struct/v1.expect b/test/typecheck/pass/constrained_struct/v1.expect
index 5173ef0b..ab25cbc4 100644
--- a/test/typecheck/pass/constrained_struct/v1.expect
+++ b/test/typecheck/pass/constrained_struct/v1.expect
@@ -2,4 +2,4 @@ Type error at file "constrained_struct/v1.sail", line 10, character 19 to line 1
type MyStruct64 = MyStruct(65)
-Could not prove (65 = 32 | 65 = 64) for type constructor MyStruct
+Could not prove (65 == 32 | 65 == 64) for type constructor MyStruct
diff --git a/test/typecheck/pass/constraint_ctor.sail b/test/typecheck/pass/constraint_ctor.sail
new file mode 100644
index 00000000..2b4a5746
--- /dev/null
+++ b/test/typecheck/pass/constraint_ctor.sail
@@ -0,0 +1,20 @@
+default Order dec
+
+$include <flow.sail>
+
+union Foo = {
+ Foo : {'n, 'n >= 3. int('n)}
+}
+
+function foo(Foo(x as int('x)): Foo) -> unit = {
+ _prove(constraint('x >= 3));
+}
+
+union Bar('m), 'm <= 100 = {
+ Bar : {'n, 'n >= 'm. int('n)}
+}
+
+function bar(Bar(x as int('x)) : Bar(23)) -> unit = {
+ _prove(constraint('x >= 23));
+ ()
+}
diff --git a/test/typecheck/pass/constraint_ctor/v1.expect b/test/typecheck/pass/constraint_ctor/v1.expect
new file mode 100644
index 00000000..c3886af8
--- /dev/null
+++ b/test/typecheck/pass/constraint_ctor/v1.expect
@@ -0,0 +1,5 @@
+Type error at file "constraint_ctor/v1.sail", line 10, character 3 to line 10, character 29
+
+ _prove(constraint('x >= 4));
+
+Cannot prove 'x >= 4
diff --git a/test/typecheck/pass/constraint_ctor/v1.sail b/test/typecheck/pass/constraint_ctor/v1.sail
new file mode 100644
index 00000000..20df5480
--- /dev/null
+++ b/test/typecheck/pass/constraint_ctor/v1.sail
@@ -0,0 +1,20 @@
+default Order dec
+
+$include <flow.sail>
+
+union Foo = {
+ Foo : {'n, 'n >= 3. int('n)}
+}
+
+function foo(Foo(x as int('x)): Foo) -> unit = {
+ _prove(constraint('x >= 4));
+}
+
+union Bar('m), 'm <= 100 = {
+ Bar : {'n, 'n >= 'm. int('n)}
+}
+
+function bar(Bar(x as int('x)) : Bar(23)) -> unit = {
+ _prove(constraint('x >= 23));
+ ()
+}
diff --git a/test/typecheck/pass/constraint_ctor/v2.expect b/test/typecheck/pass/constraint_ctor/v2.expect
new file mode 100644
index 00000000..a315b3b7
--- /dev/null
+++ b/test/typecheck/pass/constraint_ctor/v2.expect
@@ -0,0 +1,5 @@
+Type error at file "constraint_ctor/v2.sail", line 18, character 3 to line 18, character 30
+
+ _prove(constraint('x >= 24));
+
+Cannot prove 'x >= 24
diff --git a/test/typecheck/pass/constraint_ctor/v2.sail b/test/typecheck/pass/constraint_ctor/v2.sail
new file mode 100644
index 00000000..76d9793d
--- /dev/null
+++ b/test/typecheck/pass/constraint_ctor/v2.sail
@@ -0,0 +1,20 @@
+default Order dec
+
+$include <flow.sail>
+
+union Foo = {
+ Foo : {'n, 'n >= 3. int('n)}
+}
+
+function foo(Foo(x as int('x)): Foo) -> unit = {
+ _prove(constraint('x >= 3));
+}
+
+union Bar('m), 'm <= 100 = {
+ Bar : {'n, 'n >= 'm. int('n)}
+}
+
+function bar(Bar(x as int('x)) : Bar(23)) -> unit = {
+ _prove(constraint('x >= 24));
+ ()
+}
diff --git a/test/typecheck/pass/constraint_ctor/v3.expect b/test/typecheck/pass/constraint_ctor/v3.expect
new file mode 100644
index 00000000..e0edd01a
--- /dev/null
+++ b/test/typecheck/pass/constraint_ctor/v3.expect
@@ -0,0 +1,5 @@
+Type error at file "constraint_ctor/v3.sail", line 18, character 3 to line 18, character 30
+
+ _prove(constraint('x >= 23));
+
+Cannot prove 'x >= 23
diff --git a/test/typecheck/pass/constraint_ctor/v3.sail b/test/typecheck/pass/constraint_ctor/v3.sail
new file mode 100644
index 00000000..a8f5bd13
--- /dev/null
+++ b/test/typecheck/pass/constraint_ctor/v3.sail
@@ -0,0 +1,20 @@
+default Order dec
+
+$include <flow.sail>
+
+union Foo = {
+ Foo : {'n, 'n >= 3. int('n)}
+}
+
+function foo(Foo(x as int('x)): Foo) -> unit = {
+ _prove(constraint('x >= 3));
+}
+
+union Bar('m), 'm <= 100 = {
+ Bar : {'n, 'n >= 'm. int('n)}
+}
+
+function bar(Bar(x as int('x)) : Bar(22)) -> unit = {
+ _prove(constraint('x >= 23));
+ ()
+}
diff --git a/test/typecheck/pass/constraint_ctor/v4.expect b/test/typecheck/pass/constraint_ctor/v4.expect
new file mode 100644
index 00000000..06eb9d22
--- /dev/null
+++ b/test/typecheck/pass/constraint_ctor/v4.expect
@@ -0,0 +1,5 @@
+Type error at file "constraint_ctor/v4.sail", line 17, character 34 to line 17, character 36
+
+function bar(Bar(x as int('x)) : Bar(23)) -> unit = {
+
+Could not prove 23 <= 22 for type constructor Bar
diff --git a/test/typecheck/pass/constraint_ctor/v4.sail b/test/typecheck/pass/constraint_ctor/v4.sail
new file mode 100644
index 00000000..d8dab178
--- /dev/null
+++ b/test/typecheck/pass/constraint_ctor/v4.sail
@@ -0,0 +1,20 @@
+default Order dec
+
+$include <flow.sail>
+
+union Foo = {
+ Foo : {'n, 'n >= 3. int('n)}
+}
+
+function foo(Foo(x as int('x)): Foo) -> unit = {
+ _prove(constraint('x >= 3));
+}
+
+union Bar('m), 'm <= 22 = {
+ Bar : {'n, 'n >= 'm. int('n)}
+}
+
+function bar(Bar(x as int('x)) : Bar(23)) -> unit = {
+ _prove(constraint('x >= 23));
+ ()
+}
diff --git a/test/typecheck/pass/exist2.sail b/test/typecheck/pass/exist2.sail
index 102a1084..e518609d 100644
--- a/test/typecheck/pass/exist2.sail
+++ b/test/typecheck/pass/exist2.sail
@@ -39,6 +39,6 @@ overload existential = {existential_int, existential_range}
let v11 : {'n, 0 == 0. atom('n)} = existential(v10)
-let v12 : {'e, 0 <= 'e & 'e <= 3. atom('e)} = existential(2 : range(0, 3))
+let v12 : {'e, 0 <= 'e & 'e <= 3. atom('e)} = 2
let v13 : MyInt = existential(v10)
diff --git a/test/typecheck/pass/global_type_var/v1.expect b/test/typecheck/pass/global_type_var/v1.expect
index 7e3b517c..e81c467e 100644
--- a/test/typecheck/pass/global_type_var/v1.expect
+++ b/test/typecheck/pass/global_type_var/v1.expect
@@ -6,15 +6,15 @@ Tried performing type coercion from int(32) to int('size) on 32
Coercion failed because:
int(32) is not a subtype of int('size)
in context
- * 'size = 'ex8#
- * ('ex8# = 32 | 'ex8# = 64)
- * ('ex7# = 32 | 'ex7# = 64)
+ * 'size == 'ex14#
+ * ('ex14# == 32 | 'ex14# == 64)
+ * ('ex13# == 32 | 'ex13# == 64)
where
- * 'ex7# bound at file "global_type_var/v1.sail", line 5, character 5 to line 5, character 32
+ * 'ex13# bound at file "global_type_var/v1.sail", line 5, character 5 to line 5, character 32
let (size as 'size) : {|32, 64|} = 32
- * 'ex8# bound at file "global_type_var/v1.sail", line 5, character 6 to line 5, character 18
+ * 'ex14# bound at file "global_type_var/v1.sail", line 5, character 6 to line 5, character 18
let (size as 'size) : {|32, 64|} = 32
diff --git a/test/typecheck/pass/global_type_var/v2.expect b/test/typecheck/pass/global_type_var/v2.expect
index dc1281d2..21c4b348 100644
--- a/test/typecheck/pass/global_type_var/v2.expect
+++ b/test/typecheck/pass/global_type_var/v2.expect
@@ -6,15 +6,15 @@ Tried performing type coercion from int(64) to int('size) on 64
Coercion failed because:
int(64) is not a subtype of int('size)
in context
- * 'size = 'ex8#
- * ('ex8# = 32 | 'ex8# = 64)
- * ('ex7# = 32 | 'ex7# = 64)
+ * 'size == 'ex14#
+ * ('ex14# == 32 | 'ex14# == 64)
+ * ('ex13# == 32 | 'ex13# == 64)
where
- * 'ex7# bound at file "global_type_var/v2.sail", line 5, character 5 to line 5, character 32
+ * 'ex13# bound at file "global_type_var/v2.sail", line 5, character 5 to line 5, character 32
let (size as 'size) : {|32, 64|} = 32
- * 'ex8# bound at file "global_type_var/v2.sail", line 5, character 6 to line 5, character 18
+ * 'ex14# bound at file "global_type_var/v2.sail", line 5, character 6 to line 5, character 18
let (size as 'size) : {|32, 64|} = 32
diff --git a/test/typecheck/pass/if_infer.sail b/test/typecheck/pass/if_infer.sail
new file mode 100644
index 00000000..f3fec1c4
--- /dev/null
+++ b/test/typecheck/pass/if_infer.sail
@@ -0,0 +1,12 @@
+default Order dec
+
+$include <prelude.sail>
+
+register R : bool
+
+val f : unit -> {'n, 1 <= 'n <= 3. int('n)}
+
+function main((): unit) -> unit = {
+ let _ = 0b1001[if R then 0 else f()];
+ ()
+}
diff --git a/test/typecheck/pass/if_infer/v1.expect b/test/typecheck/pass/if_infer/v1.expect
new file mode 100644
index 00000000..06df7dc5
--- /dev/null
+++ b/test/typecheck/pass/if_infer/v1.expect
@@ -0,0 +1,17 @@
+Type error at file "if_infer/v1.sail", line 10, character 11 to line 10, character 37
+
+ let _ = 0b100[if R then 0 else f()];
+
+No overloadings for vector_access, tried:
+ bitvector_access:
+ Could not resolve quantifiers for bitvector_access (0 <= 'ex41#ex40# & ('ex41#ex40# + 1) <= 3)
+
+ Try adding named type variables for
+
+
+ plain_vector_access:
+ Could not resolve quantifiers for plain_vector_access (0 <= 'ex44#ex43# & ('ex44#ex43# + 1) <= 3)
+
+ Try adding named type variables for
+
+
diff --git a/test/typecheck/pass/if_infer/v1.sail b/test/typecheck/pass/if_infer/v1.sail
new file mode 100644
index 00000000..0938aaed
--- /dev/null
+++ b/test/typecheck/pass/if_infer/v1.sail
@@ -0,0 +1,12 @@
+default Order dec
+
+$include <prelude.sail>
+
+register R : bool
+
+val f : unit -> {'n, 1 <= 'n <= 3. int('n)}
+
+function main((): unit) -> unit = {
+ let _ = 0b100[if R then 0 else f()];
+ ()
+}
diff --git a/test/typecheck/pass/if_infer/v2.expect b/test/typecheck/pass/if_infer/v2.expect
new file mode 100644
index 00000000..050e90e4
--- /dev/null
+++ b/test/typecheck/pass/if_infer/v2.expect
@@ -0,0 +1,17 @@
+Type error at file "if_infer/v2.sail", line 10, character 11 to line 10, character 38
+
+ let _ = 0b1001[if R then 0 else f()];
+
+No overloadings for vector_access, tried:
+ bitvector_access:
+ Could not resolve quantifiers for bitvector_access (0 <= 'ex41#ex40# & ('ex41#ex40# + 1) <= 4)
+
+ Try adding named type variables for
+
+
+ plain_vector_access:
+ Could not resolve quantifiers for plain_vector_access (0 <= 'ex44#ex43# & ('ex44#ex43# + 1) <= 4)
+
+ Try adding named type variables for
+
+
diff --git a/test/typecheck/pass/if_infer/v2.sail b/test/typecheck/pass/if_infer/v2.sail
new file mode 100644
index 00000000..a49e1ed7
--- /dev/null
+++ b/test/typecheck/pass/if_infer/v2.sail
@@ -0,0 +1,12 @@
+default Order dec
+
+$include <prelude.sail>
+
+register R : bool
+
+val f : unit -> {'n, 1 <= 'n <= 4. int('n)}
+
+function main((): unit) -> unit = {
+ let _ = 0b1001[if R then 0 else f()];
+ ()
+}
diff --git a/test/typecheck/pass/if_infer/v3.expect b/test/typecheck/pass/if_infer/v3.expect
new file mode 100644
index 00000000..8b149bc8
--- /dev/null
+++ b/test/typecheck/pass/if_infer/v3.expect
@@ -0,0 +1,7 @@
+Type error at file "if_infer/v3.sail", line 10, character 11 to line 10, character 38
+
+ let _ = 0b1001[if R then 0 else f()];
+
+No overloadings for vector_access, tried:
+ bitvector_access: Numeric type is non-simple
+ plain_vector_access: Numeric type is non-simple
diff --git a/test/typecheck/pass/if_infer/v3.sail b/test/typecheck/pass/if_infer/v3.sail
new file mode 100644
index 00000000..0c3dec21
--- /dev/null
+++ b/test/typecheck/pass/if_infer/v3.sail
@@ -0,0 +1,12 @@
+default Order dec
+
+$include <prelude.sail>
+
+register R : bool
+
+val f : unit -> {'n 'm, 'm == 3 & 1 <= 'n <= 'm. int('n)}
+
+function main((): unit) -> unit = {
+ let _ = 0b1001[if R then 0 else f()];
+ ()
+}