diff options
| author | Alasdair Armstrong | 2017-07-10 16:32:50 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2017-07-10 16:32:50 +0100 |
| commit | a88c8a7734213816165cf7afe5e70571696d8f40 (patch) | |
| tree | 7a1b4b3dbb225170217a92832318776107513519 /src | |
| parent | 83a2816d83e206a8b41e72ea8e9a932c0d23b2bb (diff) | |
Adder pattern matching for union types
Diffstat (limited to 'src')
| -rw-r--r-- | src/type_check_new.ml | 78 |
1 files changed, 72 insertions, 6 deletions
diff --git a/src/type_check_new.ml b/src/type_check_new.ml index 1056f710..281ead22 100644 --- a/src/type_check_new.ml +++ b/src/type_check_new.ml @@ -46,7 +46,7 @@ open Util open Ast_util open Big_int -let debug = ref 0 +let debug = ref 1 let depth = ref 0 let rec indent n = match n with @@ -115,6 +115,8 @@ let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown) let nc_lt n1 n2 = nc_lteq n1 (nsum n2 (nconstant 1)) let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nconstant 1)) +let mk_lit l = E_aux (E_lit (L_aux (l, Parse_ast.Unknown)), (Parse_ast.Unknown, ())) + let nc_negate (NC_aux (nc, _)) = match nc with | NC_bounded_ge (n1, n2) -> Some (nc_lt n1 n2) @@ -404,6 +406,9 @@ end = struct let fresh = fresh_kid env in (typquant_subst_kid kid fresh typq, typ_subst_kid kid fresh typ) + let freshen_bind env bind = + List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) + let get_val_spec id env = try let bind = Bindings.find id env.top_val_specs in @@ -646,7 +651,14 @@ end = struct let (enum, _) = List.find (fun (enum, ctors) -> IdSet.mem id ctors) (Bindings.bindings env.enums) in Enum (mk_typ (Typ_id enum)) with - | Not_found -> Unbound + | Not_found -> + begin + try + let (typq, typ) = freshen_bind env (Bindings.find id env.union_ids) in + Union (typq, typ) + with + | Not_found -> Unbound + end end end @@ -1270,6 +1282,7 @@ let is_typ_kid kid = function let rec instantiate_quants quants kid uvar = match quants with | [] -> [] | ((QI_aux (QI_id kinded_id, _) as quant) :: quants) -> + typ_debug ("instantiating quant " ^ string_of_quant_item quant); begin match uvar with | U_nexp nexp -> @@ -1427,6 +1440,11 @@ let rec filter_casts env from_typ to_typ casts = end | [] -> [] +let is_union_id id env = + match Env.lookup_id id env with + | Union (_, _) -> true + | _ -> false + let crule r env exp typ = incr depth; typ_print ("Check " ^ string_of_exp exp ^ " <= " ^ string_of_typ typ); @@ -1516,12 +1534,20 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ else typ_error l "List length didn't match" (* FIXME: improve error message *) | _ -> typ_error l "Cannot check list constant against non-constant length vector type" end - | E_lit (L_aux (L_undef, _) as lit), _ -> + | E_lit (L_aux (L_undef, _) as lit), _ -> annot_exp_effect (E_lit lit) typ (mk_effect [BE_undef]) (* This rule allows registers of type t to be passed by name with type register<t>*) | E_id reg, Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) when string_of_id id = "register" -> let rtyp = Env.get_register reg env in subtyp l env rtyp typ; annot_exp (E_id reg) typ (* CHECK: is this subtyp the correct way around? *) + | E_id id, _ when is_union_id id env -> + begin + match Env.lookup_id id env with + | Union (typq, ctor_typ) -> + let inferred_exp = infer_funapp' l env id (typq, mk_typ (Typ_fn (unit_typ, ctor_typ, no_effect))) [mk_lit L_unit] (Some typ) in + annot_exp (E_id id) (typ_of inferred_exp) + | _ -> assert false (* Unreachble due to guard *) + end | _, _ -> let inferred_exp = irule infer_exp env exp in type_coercion env inferred_exp typ @@ -1591,6 +1617,9 @@ and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ = and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) = let annot_pat pat typ = P_aux (pat, (l, Some (env, typ, no_effect))) in let switch_typ (P_aux (pat_aux, (l, Some (env, _, eff)))) typ = P_aux (pat_aux, (l, Some (env, typ, eff))) in + let bind_tuple_pat (tpats, env) pat typ = + let tpat, env = bind_pat env pat typ in tpat :: tpats, env + in match pat_aux with | P_id v -> begin @@ -1599,15 +1628,20 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) | Local (Mutable, _) | Register _ -> typ_error l ("Cannot shadow mutable local or register in switch statement pattern " ^ string_of_pat pat) | Enum enum -> subtyp l env enum typ; annot_pat (P_id v) typ, env + | Union (typq, ctor_typ) -> + begin + try + let _ = unify l env ctor_typ typ in + annot_pat (P_id v) typ, env + with + | Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m) + end end | P_wild -> annot_pat P_wild typ, env | P_tup pats -> begin match typ_aux with | Typ_tup typs -> - let bind_tuple_pat (tpats, env) pat typ = - let tpat, env = bind_pat env pat typ in tpat :: tpats, env - in let tpats, env = try List.fold_left2 bind_tuple_pat ([], env) pats typs with | Invalid_argument _ -> typ_error l "Tuple pattern and tuple type have different length" @@ -1615,6 +1649,37 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) annot_pat (P_tup (List.rev tpats)) typ, env | _ -> typ_error l "Cannot bind tuple pattern against non tuple type" end + | P_app (f, pats) -> + begin + let (typq, ctor_typ) = Env.get_val_spec f env in + let quants = quant_items typq in + let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with + | Typ_tup typs -> typs + | _ -> [typ] + in + match ctor_typ with + | Typ_aux (Typ_fn (arg_typ, ret_typ, _), _) -> + begin + try + typ_debug ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ); + let unifiers = unify l env ret_typ typ in + typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)); + let arg_typ' = subst_unifiers unifiers arg_typ in + let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in + if (match quants' with [] -> false | _ -> true) + then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat) + else (); + let ret_typ' = subst_unifiers unifiers ret_typ in + let tpats, env = + try List.fold_left2 bind_tuple_pat ([], env) pats (untuple arg_typ') with + | Invalid_argument _ -> typ_error l "Union constructor pattern arguments have incorrect length" + in + annot_pat (P_app (f, List.rev tpats)) typ, env + with + | Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m) + end + | _ -> typ_error l ("Mal-formed constructor " ^ string_of_id f) + end | _ -> let (inferred_pat, env) = infer_pat env pat in subtyp l env (pat_typ_of inferred_pat) typ; @@ -1936,6 +2001,7 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ = | None -> (quants, typs, ret_typ) | Some rct -> begin + typ_debug ("RCT is " ^ string_of_typ rct); typ_debug ("INSTANTIATE RETURN:" ^ string_of_typ ret_typ); let unifiers = try unify l env ret_typ rct with Unification_error _ -> typ_debug "UERROR"; KBindings.empty in typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)); |
