summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlasdair Armstrong2017-07-10 16:32:50 +0100
committerAlasdair Armstrong2017-07-10 16:32:50 +0100
commita88c8a7734213816165cf7afe5e70571696d8f40 (patch)
tree7a1b4b3dbb225170217a92832318776107513519 /src
parent83a2816d83e206a8b41e72ea8e9a932c0d23b2bb (diff)
Adder pattern matching for union types
Diffstat (limited to 'src')
-rw-r--r--src/type_check_new.ml78
1 files changed, 72 insertions, 6 deletions
diff --git a/src/type_check_new.ml b/src/type_check_new.ml
index 1056f710..281ead22 100644
--- a/src/type_check_new.ml
+++ b/src/type_check_new.ml
@@ -46,7 +46,7 @@ open Util
open Ast_util
open Big_int
-let debug = ref 0
+let debug = ref 1
let depth = ref 0
let rec indent n = match n with
@@ -115,6 +115,8 @@ let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown)
let nc_lt n1 n2 = nc_lteq n1 (nsum n2 (nconstant 1))
let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nconstant 1))
+let mk_lit l = E_aux (E_lit (L_aux (l, Parse_ast.Unknown)), (Parse_ast.Unknown, ()))
+
let nc_negate (NC_aux (nc, _)) =
match nc with
| NC_bounded_ge (n1, n2) -> Some (nc_lt n1 n2)
@@ -404,6 +406,9 @@ end = struct
let fresh = fresh_kid env in
(typquant_subst_kid kid fresh typq, typ_subst_kid kid fresh typ)
+ let freshen_bind env bind =
+ List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars)
+
let get_val_spec id env =
try
let bind = Bindings.find id env.top_val_specs in
@@ -646,7 +651,14 @@ end = struct
let (enum, _) = List.find (fun (enum, ctors) -> IdSet.mem id ctors) (Bindings.bindings env.enums) in
Enum (mk_typ (Typ_id enum))
with
- | Not_found -> Unbound
+ | Not_found ->
+ begin
+ try
+ let (typq, typ) = freshen_bind env (Bindings.find id env.union_ids) in
+ Union (typq, typ)
+ with
+ | Not_found -> Unbound
+ end
end
end
@@ -1270,6 +1282,7 @@ let is_typ_kid kid = function
let rec instantiate_quants quants kid uvar = match quants with
| [] -> []
| ((QI_aux (QI_id kinded_id, _) as quant) :: quants) ->
+ typ_debug ("instantiating quant " ^ string_of_quant_item quant);
begin
match uvar with
| U_nexp nexp ->
@@ -1427,6 +1440,11 @@ let rec filter_casts env from_typ to_typ casts =
end
| [] -> []
+let is_union_id id env =
+ match Env.lookup_id id env with
+ | Union (_, _) -> true
+ | _ -> false
+
let crule r env exp typ =
incr depth;
typ_print ("Check " ^ string_of_exp exp ^ " <= " ^ string_of_typ typ);
@@ -1516,12 +1534,20 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
else typ_error l "List length didn't match" (* FIXME: improve error message *)
| _ -> typ_error l "Cannot check list constant against non-constant length vector type"
end
- | E_lit (L_aux (L_undef, _) as lit), _ ->
+ | E_lit (L_aux (L_undef, _) as lit), _ ->
annot_exp_effect (E_lit lit) typ (mk_effect [BE_undef])
(* This rule allows registers of type t to be passed by name with type register<t>*)
| E_id reg, Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) when string_of_id id = "register" ->
let rtyp = Env.get_register reg env in
subtyp l env rtyp typ; annot_exp (E_id reg) typ (* CHECK: is this subtyp the correct way around? *)
+ | E_id id, _ when is_union_id id env ->
+ begin
+ match Env.lookup_id id env with
+ | Union (typq, ctor_typ) ->
+ let inferred_exp = infer_funapp' l env id (typq, mk_typ (Typ_fn (unit_typ, ctor_typ, no_effect))) [mk_lit L_unit] (Some typ) in
+ annot_exp (E_id id) (typ_of inferred_exp)
+ | _ -> assert false (* Unreachble due to guard *)
+ end
| _, _ ->
let inferred_exp = irule infer_exp env exp in
type_coercion env inferred_exp typ
@@ -1591,6 +1617,9 @@ and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ =
and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) =
let annot_pat pat typ = P_aux (pat, (l, Some (env, typ, no_effect))) in
let switch_typ (P_aux (pat_aux, (l, Some (env, _, eff)))) typ = P_aux (pat_aux, (l, Some (env, typ, eff))) in
+ let bind_tuple_pat (tpats, env) pat typ =
+ let tpat, env = bind_pat env pat typ in tpat :: tpats, env
+ in
match pat_aux with
| P_id v ->
begin
@@ -1599,15 +1628,20 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
| Local (Mutable, _) | Register _ ->
typ_error l ("Cannot shadow mutable local or register in switch statement pattern " ^ string_of_pat pat)
| Enum enum -> subtyp l env enum typ; annot_pat (P_id v) typ, env
+ | Union (typq, ctor_typ) ->
+ begin
+ try
+ let _ = unify l env ctor_typ typ in
+ annot_pat (P_id v) typ, env
+ with
+ | Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m)
+ end
end
| P_wild -> annot_pat P_wild typ, env
| P_tup pats ->
begin
match typ_aux with
| Typ_tup typs ->
- let bind_tuple_pat (tpats, env) pat typ =
- let tpat, env = bind_pat env pat typ in tpat :: tpats, env
- in
let tpats, env =
try List.fold_left2 bind_tuple_pat ([], env) pats typs with
| Invalid_argument _ -> typ_error l "Tuple pattern and tuple type have different length"
@@ -1615,6 +1649,37 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
annot_pat (P_tup (List.rev tpats)) typ, env
| _ -> typ_error l "Cannot bind tuple pattern against non tuple type"
end
+ | P_app (f, pats) ->
+ begin
+ let (typq, ctor_typ) = Env.get_val_spec f env in
+ let quants = quant_items typq in
+ let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with
+ | Typ_tup typs -> typs
+ | _ -> [typ]
+ in
+ match ctor_typ with
+ | Typ_aux (Typ_fn (arg_typ, ret_typ, _), _) ->
+ begin
+ try
+ typ_debug ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ);
+ let unifiers = unify l env ret_typ typ in
+ typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers));
+ let arg_typ' = subst_unifiers unifiers arg_typ in
+ let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
+ if (match quants' with [] -> false | _ -> true)
+ then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat)
+ else ();
+ let ret_typ' = subst_unifiers unifiers ret_typ in
+ let tpats, env =
+ try List.fold_left2 bind_tuple_pat ([], env) pats (untuple arg_typ') with
+ | Invalid_argument _ -> typ_error l "Union constructor pattern arguments have incorrect length"
+ in
+ annot_pat (P_app (f, List.rev tpats)) typ, env
+ with
+ | Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m)
+ end
+ | _ -> typ_error l ("Mal-formed constructor " ^ string_of_id f)
+ end
| _ ->
let (inferred_pat, env) = infer_pat env pat in
subtyp l env (pat_typ_of inferred_pat) typ;
@@ -1936,6 +2001,7 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
| None -> (quants, typs, ret_typ)
| Some rct ->
begin
+ typ_debug ("RCT is " ^ string_of_typ rct);
typ_debug ("INSTANTIATE RETURN:" ^ string_of_typ ret_typ);
let unifiers = try unify l env ret_typ rct with Unification_error _ -> typ_debug "UERROR"; KBindings.empty in
typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers));