summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlasdair Armstrong2017-07-06 19:01:09 +0100
committerAlasdair Armstrong2017-07-06 19:01:09 +0100
commit205e09e36baaf8cf2aa794e84d8e13daf8c4c4b7 (patch)
tree55a04a38c4e932f17a12621e9d96b6f2d0a0a6e9 /src
parent4bb28c48b92a469b8a7eeae5ae6e32418c8936ae (diff)
Testing new typechecker on MIPS spec
Also: - Added support for foreach loops - Started work on type unions - Flow typing can now generate constraints, in addition to restricting range-typed variables - Various bugfixes - Better unification for nexps with multiplication
Diffstat (limited to 'src')
-rw-r--r--src/ast_util.ml6
-rw-r--r--src/type_check_new.ml194
-rw-r--r--src/type_check_new.mli2
3 files changed, 170 insertions, 32 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 59e19e0b..a84df58b 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -270,6 +270,12 @@ let rec string_of_exp (E_aux (exp, _)) =
| E_if (cond, then_branch, else_branch) ->
"if " ^ string_of_exp cond ^ " then " ^ string_of_exp then_branch ^ " else " ^ string_of_exp else_branch
| E_field (exp, id) -> string_of_exp exp ^ "." ^ string_of_id id
+ | E_for (id, f, t, u, ord, body) ->
+ "foreach ("
+ ^ string_of_id id ^ " from " ^ string_of_exp f ^ " to " ^ string_of_exp t
+ ^ " by " ^ string_of_exp u ^ " order " ^ string_of_order ord
+ ^ ") { "
+ ^ string_of_exp body
| _ -> "INTERNAL"
and string_of_pexp (Pat_aux (Pat_exp (pat, exp), _)) = string_of_pat pat ^ " -> " ^ string_of_exp exp
and string_of_pat (P_aux (pat, l)) =
diff --git a/src/type_check_new.ml b/src/type_check_new.ml
index 6a1bc967..12ec1c28 100644
--- a/src/type_check_new.ml
+++ b/src/type_check_new.ml
@@ -84,6 +84,7 @@ let int_typ = mk_id_typ (mk_id "int")
let nat_typ = mk_id_typ (mk_id "nat")
let unit_typ = mk_id_typ (mk_id "unit")
let bit_typ = mk_id_typ (mk_id "bit")
+let app_typ id args = mk_typ (Typ_app (id, args))
let atom_typ nexp = mk_typ (Typ_app (mk_id "atom", [mk_typ_arg (Typ_arg_nexp nexp)]))
let range_typ nexp1 nexp2 = mk_typ (Typ_app (mk_id "range", [mk_typ_arg (Typ_arg_nexp nexp1); mk_typ_arg (Typ_arg_nexp nexp2)]))
let bool_typ = mk_id_typ (mk_id "bool")
@@ -96,11 +97,30 @@ let vector_typ n m ord typ =
mk_typ_arg (Typ_arg_order ord);
mk_typ_arg (Typ_arg_typ typ)]))
+let is_range (Typ_aux (typ_aux, _)) =
+ match typ_aux with
+ | Typ_app (f, [Typ_arg_aux (Typ_arg_nexp n, _)])
+ when string_of_id f = "atom" -> Some (n, n)
+ | Typ_app (f, [Typ_arg_aux (Typ_arg_nexp n1, _); Typ_arg_aux (Typ_arg_nexp n2, _)])
+ when string_of_id f = "range" -> Some (n1, n2)
+ | _ -> None
+
let nconstant c = Nexp_aux (Nexp_constant c, Parse_ast.Unknown)
let nminus n1 n2 = Nexp_aux (Nexp_minus (n1, n2), Parse_ast.Unknown)
let nsum n1 n2 = Nexp_aux (Nexp_sum (n1, n2), Parse_ast.Unknown)
let nvar kid = Nexp_aux (Nexp_var kid, Parse_ast.Unknown)
+let nc_lteq n1 n2 = NC_aux (NC_bounded_le (n1, n2), Parse_ast.Unknown)
+let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown)
+let nc_lt n1 n2 = nc_lteq n1 (nsum n2 (nconstant 1))
+let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nconstant 1))
+
+let nc_negate (NC_aux (nc, _)) =
+ match nc with
+ | NC_bounded_ge (n1, n2) -> Some (nc_lt n1 n2)
+ | NC_bounded_le (n1, n2) -> Some (nc_gt n1 n2)
+ | _ -> None
+
(* Utilities for constructing effect sets *)
let mk_effect effs =
@@ -253,6 +273,7 @@ and nexp_simp_aux = function
let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in
match n1_simp, n2_simp with
| Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (c1 + c2)
+ | _, Nexp_neg n2 -> Nexp_minus (n1, n2)
| _, _ -> Nexp_sum (n1, n2)
end
| Nexp_times (n1, n2) ->
@@ -288,7 +309,7 @@ let typquant_subst_kid sv subst (TypQ_aux (typq, l)) = TypQ_aux (typquant_subst_
type mut = Immutable | Mutable
-type lvar = Register of typ | Enum of typ | Local of mut * typ | Unbound
+type lvar = Register of typ | Enum of typ | Local of mut * typ | Union of typquant * typ | Unbound
module Env : sig
type t
@@ -298,6 +319,8 @@ module Env : sig
val is_record : id -> t -> bool
val get_accessor : id -> t -> typquant * typ
val add_local : id -> mut * typ -> t -> t
+ val add_variant : id -> typquant * type_union list -> t -> t
+ val add_union_id : id -> typquant * typ -> t -> t
val add_flow : id -> (typ -> typ) -> t -> t
val get_flow : id -> t -> typ -> typ
val get_register : id -> t -> typ
@@ -332,8 +355,10 @@ end = struct
type t =
{ top_val_specs : (typquant * typ) Bindings.t;
locals : (mut * typ) Bindings.t;
+ union_ids : (typquant * typ) Bindings.t;
registers : typ Bindings.t;
regtyps : (int * int * (index_range * id) list) Bindings.t;
+ variants : (typquant * type_union list) Bindings.t;
typ_vars : base_kind_aux KBindings.t;
typ_synonyms : (typ_arg list -> typ) Bindings.t;
overloads : (id list) Bindings.t;
@@ -351,8 +376,10 @@ end = struct
let empty =
{ top_val_specs = Bindings.empty;
locals = Bindings.empty;
+ union_ids = Bindings.empty;
registers = Bindings.empty;
regtyps = Bindings.empty;
+ variants = Bindings.empty;
typ_vars = KBindings.empty;
typ_synonyms = Bindings.empty;
overloads = Bindings.empty;
@@ -403,6 +430,7 @@ end = struct
(* FIXME: Add an IdSet for builtin types *)
let bound_typ_id env id =
Bindings.mem id env.typ_synonyms
+ || Bindings.mem id env.variants
|| Bindings.mem id env.records
|| Bindings.mem id env.regtyps
|| Bindings.mem id env.enums
@@ -517,6 +545,18 @@ end = struct
{ env with locals = Bindings.add id mtyp env.locals }
end
+ let add_variant id variant env =
+ begin
+ typ_print ("Adding variant " ^ string_of_id id);
+ { env with variants = Bindings.add id variant env.variants }
+ end
+
+ let add_union_id id bind env =
+ begin
+ typ_print ("Adding union identifier binding " ^ string_of_id id ^ " :: " ^ string_of_bind bind);
+ { env with union_ids = Bindings.add id bind env.union_ids }
+ end
+
let get_flow id env =
try Bindings.find id env.flow with
| Not_found -> fun typ -> typ
@@ -632,7 +672,7 @@ end = struct
let add_constraint (NC_aux (_, l) as constr) env =
wf_constraint env constr;
begin
- typ_debug ("Adding constraint " ^ string_of_n_constraint constr);
+ typ_print ("Adding constraint " ^ string_of_n_constraint constr);
{ env with constraints = constr :: env.constraints }
end
@@ -1029,7 +1069,24 @@ let rec unify_nexps l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (ne
if KidSet.is_empty (nexp_frees n1b)
then unify_nexps l env goals n1a (nsum nexp2 n1b)
else unify_error l ("Cannot unify minus Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
-
+ | Nexp_times (n1a, n1b) ->
+ if KidSet.is_empty (nexp_frees n1a)
+ then
+ begin
+ match nexp_aux2 with
+ | Nexp_times (n2a, n2b) when prove env (NC_aux (NC_fixed (n1a, n2a), Parse_ast.Unknown)) ->
+ unify_nexps l env goals n1b n2b
+ | _ -> unify_error l ("Cannot unify Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
+ end
+ else if KidSet.is_empty (nexp_frees n1b)
+ then
+ begin
+ match nexp_aux2 with
+ | Nexp_times (n2a, n2b) when prove env (NC_aux (NC_fixed (n1b, n2b), Parse_ast.Unknown)) ->
+ unify_nexps l env goals n1a n2a
+ | _ -> unify_error l ("Cannot unify Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
+ end
+ else unify_error l ("Cannot unify Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
| _ -> unify_error l ("Cannot unify Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
let string_of_uvar = function
@@ -1068,10 +1125,7 @@ let subst_args_unifiers unifiers typ_args =
let unify l env typ1 typ2 =
typ_print ("Unify " ^ string_of_typ typ1 ^ " with " ^ string_of_typ typ2);
- if not (KidSet.is_empty (KidSet.inter (typ_frees typ1) (typ_frees typ2)))
- then unify_error l "Can only unify types with disjoint type variables"
- else ();
- let goals = typ_frees typ1 in
+ let goals = KidSet.inter (KidSet.diff (typ_frees typ1) (typ_frees typ2)) (typ_frees typ1) in
let merge_unifiers l kid uvar1 uvar2 =
match uvar1, uvar2 with
| Some (U_nexp n1), Some (U_nexp n2) ->
@@ -1186,8 +1240,8 @@ let quant_items : typquant -> quant_item list = function
| TypQ_aux (TypQ_no_forall, _) -> []
let is_nat_kid kid = function
- | KOpt_aux (KOpt_none kid', _) -> Kid.compare kid kid' = 0
| KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_nat, _)], _), kid'), _) -> Kid.compare kid kid' = 0
+ | KOpt_aux (KOpt_none kid', _) -> Kid.compare kid kid' = 0
| _ -> false
let is_order_kid kid = function
@@ -1251,6 +1305,14 @@ let destructure_atom (Typ_aux (typ_aux, _)) =
when string_of_id f = "range" && c1 = c2 -> c1
| _ -> assert false
+let destructure_atom_nexp (Typ_aux (typ_aux, _)) =
+ match typ_aux with
+ | Typ_app (f, [Typ_arg_aux (Typ_arg_nexp n, _)])
+ when string_of_id f = "atom" -> n
+ | Typ_app (f, [Typ_arg_aux (Typ_arg_nexp n, _); Typ_arg_aux (Typ_arg_nexp _, _)])
+ when string_of_id f = "range" -> n
+ | _ -> assert false
+
let restrict_range_upper c1 (Typ_aux (typ_aux, l) as typ) =
match typ_aux with
| Typ_app (f, [Typ_arg_aux (Typ_arg_nexp nexp, _); Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_constant c2, _)), _)])
@@ -1275,23 +1337,27 @@ let apply_flow_constraint = function
let rec infer_flow env (E_aux (exp_aux, (l, _))) =
match exp_aux with
+ | E_app (f, [x; y]) when string_of_id f = "lteq_atom_atom" ->
+ let n1 = destructure_atom_nexp (typ_of x) in
+ let n2 = destructure_atom_nexp (typ_of y) in
+ [], [nc_lteq n1 n2]
| E_app (f, [E_aux (E_id v, _); y]) when string_of_id f = "lt_range_atom" ->
let kid = Env.fresh_kid env in
let c = destructure_atom (typ_of y) in
- [(v, Flow_lteq (c - 1))]
+ [(v, Flow_lteq (c - 1))], []
| E_app (f, [E_aux (E_id v, _); y]) when string_of_id f = "lteq_range_atom" ->
let kid = Env.fresh_kid env in
let c = destructure_atom (typ_of y) in
- [(v, Flow_lteq c)]
+ [(v, Flow_lteq c)], []
| E_app (f, [E_aux (E_id v, _); y]) when string_of_id f = "gt_range_atom" ->
let kid = Env.fresh_kid env in
let c = destructure_atom (typ_of y) in
- [(v, Flow_gteq (c + 1))]
+ [(v, Flow_gteq (c + 1))], []
| E_app (f, [E_aux (E_id v, _); y]) when string_of_id f = "gteq_range_atom" ->
let kid = Env.fresh_kid env in
let c = destructure_atom (typ_of y) in
- [(v, Flow_gteq c)]
- | _ -> []
+ [(v, Flow_gteq c)], []
+ | _ -> [], []
let rec add_flows b flows env =
match flows with
@@ -1299,6 +1365,18 @@ let rec add_flows b flows env =
| (id, flow) :: flows when b -> add_flows true flows (Env.add_flow id (fst (apply_flow_constraint flow)) env)
| (id, flow) :: flows -> add_flows false flows (Env.add_flow id (snd (apply_flow_constraint flow)) env)
+let neg_constraints = function
+ | [constr] ->
+ begin
+ match nc_negate constr with
+ | Some constr -> [constr]
+ | None -> []
+ end
+ | _ -> []
+
+let rec add_constraints constrs env =
+ List.fold_left (fun env constr -> Env.add_constraint constr env) env constrs
+
let crule r env exp typ =
incr depth;
typ_print ("Check " ^ string_of_exp exp ^ " <= " ^ string_of_typ typ);
@@ -1357,23 +1435,23 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
end
| E_app_infix (x, op, y), _ when List.length (Env.get_overloads (deinfix op) env) > 0 -> check_exp env (E_aux (E_app (deinfix op, [x; y]), (l, ()))) typ
| E_app (f, xs), _ when List.length (Env.get_overloads f env) > 0 ->
- let rec try_overload m1 = function
- | [] -> typ_error l (m1 ^ "\nNo valid overloading for " ^ string_of_exp exp)
+ let rec try_overload = function
+ | [] -> typ_error l ("No valid overloading for " ^ string_of_exp exp)
| (f :: fs) -> begin
typ_print ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")");
try crule check_exp env (E_aux (E_app (f, xs), (l, ()))) typ with
- | Type_error (_, m2) -> try_overload (m1 ^ "\nand " ^ m2) fs
+ | Type_error (_, m2) -> try_overload fs
end
in
- try_overload "Overloading error" (Env.get_overloads f env)
+ try_overload (Env.get_overloads f env)
| E_app (f, xs), _ ->
let inferred_exp = infer_funapp l env f xs (Some typ) in
type_coercion env inferred_exp typ
| E_if (cond, then_branch, else_branch), _ ->
let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in
- let flows = infer_flow env cond' in
- let then_branch' = crule check_exp (add_flows true flows env) then_branch typ in
- let else_branch' = crule check_exp (add_flows false flows env) else_branch typ in
+ let flows, constrs = infer_flow env cond' in
+ let then_branch' = crule check_exp (add_constraints constrs (add_flows true flows env)) then_branch typ in
+ let else_branch' = crule check_exp (add_constraints (neg_constraints constrs) (add_flows false flows env)) else_branch typ in
annot_exp (E_if (cond', then_branch', else_branch')) typ
| E_exit exp, _ ->
let checked_exp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in
@@ -1388,7 +1466,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
else typ_error l "List length didn't match" (* FIXME: improve error message *)
| _ -> typ_error l "Cannot check list constant against non-constant length vector type"
end
- | E_lit (L_aux (L_undef, _) as lit), _ ->
+ | E_lit (L_aux (L_undef, _) as lit), _ ->
annot_exp_effect (E_lit lit) typ (mk_effect [BE_undef])
(* This rule allows registers of type t to be passed by name with type register<t>*)
| E_id reg, Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) when string_of_id id = "register" ->
@@ -1629,7 +1707,6 @@ and bind_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) typ =
| _ -> typ_error l ("Unhandled l-expression")
and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
- typ_print ("Inferring " ^ string_of_exp exp);
let annot_exp_effect exp typ eff = E_aux (exp, (l, Some (env, typ, eff))) in
let annot_exp exp typ = annot_exp_effect exp typ no_effect in
match exp_aux with
@@ -1693,21 +1770,38 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
annot_exp (E_cast (typ, checked_exp)) typ
| E_app_infix (x, op, y) when List.length (Env.get_overloads (deinfix op) env) > 0 -> infer_exp env (E_aux (E_app (deinfix op, [x; y]), (l, ())))
| E_app (f, xs) when List.length (Env.get_overloads f env) > 0 ->
- let rec try_overload m1 = function
- | [] -> typ_error l (m1 ^ "\nNo valid overloading for " ^ string_of_exp exp)
+ let rec try_overload = function
+ | [] -> typ_error l ("No valid overloading for " ^ string_of_exp exp)
| (f :: fs) -> begin
typ_print ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")");
try irule infer_exp env (E_aux (E_app (f, xs), (l, ()))) with
- | Type_error (_, m2) -> try_overload (m1 ^ "\nand " ^ m2) fs
+ | Type_error (_, m2) -> try_overload fs
end
in
- try_overload "Overloading error" (Env.get_overloads f env)
+ try_overload (Env.get_overloads f env)
| E_app (f, xs) -> infer_funapp l env f xs None
+ | E_for (v, f, t, step, ord, body) ->
+ begin
+ let f, t = match ord with
+ | Ord_aux (Ord_inc, _) -> f, t
+ | Ord_aux (Ord_dec, _) -> t, f (* reverse direction for downto loop *)
+ in
+ let inferred_f = irule infer_exp env f in
+ let inferred_t = irule infer_exp env t in
+ let checked_step = crule check_exp env step int_typ in
+ match is_range (typ_of inferred_f), is_range (typ_of inferred_t) with
+ | None, _ -> typ_error l ("Type of " ^ string_of_exp f ^ " in foreach must be a range")
+ | _, None -> typ_error l ("Type of " ^ string_of_exp t ^ " in foreach must be a range")
+ | Some (l1, l2), Some (u1, u2) when prove env (nc_lteq l2 u1) ->
+ let checked_body = crule check_exp (Env.add_local v (Immutable, range_typ l1 u2) env) body unit_typ in
+ annot_exp (E_for (v, inferred_f, inferred_t, checked_step, ord, checked_body)) unit_typ
+ | _, _ -> typ_error l "Ranges in foreach overlap"
+ end
| E_if (cond, then_branch, else_branch) ->
let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in
- let flows = infer_flow env cond' in
- let then_branch' = irule infer_exp (add_flows true flows env) then_branch in
- let else_branch' = crule check_exp (add_flows false flows env) else_branch (typ_of then_branch') in
+ let flows, constrs = infer_flow env cond' in
+ let then_branch' = irule infer_exp (add_constraints constrs (add_flows true flows env)) then_branch in
+ let else_branch' = crule check_exp (add_constraints (neg_constraints constrs) (add_flows false flows env)) else_branch (typ_of then_branch') in
annot_exp (E_if (cond', then_branch', else_branch')) (typ_of then_branch')
| E_vector_access (v, n) -> infer_exp env (E_aux (E_app (mk_id "vector_access", [v; n]), (l, ())))
| E_vector_append (v1, v2) -> infer_exp env (E_aux (E_app (mk_id "vector_append", [v1; v2]), (l, ())))
@@ -1870,6 +1964,9 @@ and propagate_exp_effect_aux = function
| E_app (id, xs) ->
let propagated_xs = List.map propagate_exp_effect xs in
E_app (id, propagated_xs), collect_effects propagated_xs
+ | E_vector xs ->
+ let propagated_xs = List.map propagate_exp_effect xs in
+ E_vector propagated_xs, collect_effects propagated_xs
| E_tuple xs ->
let propagated_xs = List.map propagate_exp_effect xs in
E_tuple propagated_xs, collect_effects propagated_xs
@@ -1883,6 +1980,13 @@ and propagate_exp_effect_aux = function
let propagated_cases = List.map propagate_pexp_effect cases in
let case_eff = List.fold_left union_effects no_effect (List.map snd propagated_cases) in
E_case (propagated_exp, List.map fst propagated_cases), union_effects (effect_of propagated_exp) case_eff
+ | E_for (v, f, t, step, ord, body) ->
+ let propagated_f = propagate_exp_effect f in
+ let propagated_t = propagate_exp_effect t in
+ let propagated_step = propagate_exp_effect step in
+ let propagated_body = propagate_exp_effect body in
+ E_for (v, propagated_f, propagated_t, propagated_step, ord, propagated_body),
+ collect_effects [propagated_f; propagated_t; propagated_step; propagated_body]
| E_let (letbind, exp) ->
let propagated_lb, eff = propagate_letbind_effect letbind in
let propagated_exp = propagate_exp_effect exp in
@@ -1905,7 +2009,8 @@ and propagate_exp_effect_aux = function
| E_field (exp, id) ->
let propagated_exp = propagate_exp_effect exp in
E_field (propagated_exp, id), effect_of propagated_exp
- | exp_aux -> typ_error Parse_ast.Unknown ("Unimplemented: Cannot propagate effect in expression")
+ | exp_aux -> typ_error Parse_ast.Unknown ("Unimplemented: Cannot propagate effect in expression "
+ ^ string_of_exp (E_aux (exp_aux, (Parse_ast.Unknown, None))))
and propagate_pexp_effect (Pat_aux (Pat_exp (pat, exp), (l, annot))) =
let propagated_pat = propagate_pat_effect pat in
@@ -2101,6 +2206,27 @@ let check_register env id base top ranges =
|> Env.add_cast (mk_id ("cast_" ^ string_of_id id))
| _, _ -> typ_error (id_loc id) "Num expressions in register type declaration do not evaluate to constants"
+let kinded_id_arg kind_id =
+ let typ_arg arg = Typ_arg_aux (arg, Parse_ast.Unknown) in
+ match kind_id with
+ | KOpt_aux (KOpt_none kid, _) -> typ_arg (Typ_arg_nexp (nvar kid))
+ | KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_nat, _)], _), kid), _) -> typ_arg (Typ_arg_nexp (nvar kid))
+ | KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_order, _)], _), kid), _) ->
+ typ_arg (Typ_arg_order (Ord_aux (Ord_var kid, Parse_ast.Unknown)))
+ | KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_type, _)], _), kid), _) ->
+ typ_arg (Typ_arg_typ (mk_typ (Typ_var kid)))
+
+let fold_union_quant quants (QI_aux (qi, l)) =
+ match qi with
+ | QI_id kind_id -> quants @ [kinded_id_arg kind_id]
+ | _ -> quants
+
+let check_type_union env variant typq (Tu_aux (tu, l)) =
+ let ret_typ = app_typ variant (List.fold_left fold_union_quant [] (quant_items typq)) in
+ match tu with
+ | Tu_id v -> Env.add_union_id v (typq, ret_typ) env
+ | Tu_ty_id (typ, v) -> Env.add_val_spec v (typq, mk_typ (Typ_fn (typ, ret_typ, no_effect))) env
+
let check_typedef env (TD_aux (tdef, (l, _))) =
let td_err () = raise (Reporting_basic.err_unreachable Parse_ast.Unknown "Unimplemented Typedef") in
match tdef with
@@ -2108,7 +2234,13 @@ let check_typedef env (TD_aux (tdef, (l, _))) =
DEF_type (TD_aux (tdef, (l, None))), Env.add_typ_synonym id (fun _ -> typ) env
| TD_record(id, nmscm, typq, fields, _) ->
DEF_type (TD_aux (tdef, (l, None))), Env.add_record id typq fields env
- | TD_variant(id, nmscm, typq, arms, _) -> td_err ()
+ | TD_variant(id, nmscm, typq, arms, _) ->
+ let env =
+ env
+ |> Env.add_variant id (typq, arms)
+ |> (fun env -> List.fold_left (fun env tu -> check_type_union env id typq tu) env arms)
+ in
+ DEF_type (TD_aux (tdef, (l, None))), env
| TD_enum(id, nmscm, ids, _) ->
DEF_type (TD_aux (tdef, (l, None))), Env.add_enum id ids env
| TD_register(id, base, top, ranges) -> DEF_type (TD_aux (tdef, (l, None))), check_register env id base top ranges
diff --git a/src/type_check_new.mli b/src/type_check_new.mli
index 6c67d84b..b3749cbb 100644
--- a/src/type_check_new.mli
+++ b/src/type_check_new.mli
@@ -47,7 +47,7 @@ exception Type_error of l * string;;
type mut = Immutable | Mutable
-type lvar = Register of typ | Enum of typ | Local of mut * typ | Unbound
+type lvar = Register of typ | Enum of typ | Local of mut * typ | Union of typquant * typ | Unbound
module Env : sig
(* Env.t is the type of environments *)