summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlasdair Armstrong2017-07-25 16:22:21 +0100
committerAlasdair Armstrong2017-07-25 16:22:21 +0100
commite68bc2728db442d531f431aa050768e18486849d (patch)
treed9e867310f4e8342ed2c459317f5ef07af0ee722
parent3ff8009f1dc81593a972eb2050f7e1159aba718a (diff)
Add instantiation_of helper function to type_check.mli that returns
the instantiated type variables in a function application
-rw-r--r--src/type_check.ml52
-rw-r--r--src/type_check.mli11
2 files changed, 47 insertions, 16 deletions
diff --git a/src/type_check.ml b/src/type_check.ml
index ee88f746..cf4cf179 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -1285,20 +1285,20 @@ let subst_args_unifiers unifiers typ_args =
in
List.fold_left subst_unifier typ_args (KBindings.bindings unifiers)
+let merge_unifiers l kid uvar1 uvar2 =
+ match uvar1, uvar2 with
+ | Some (U_nexp n1), Some (U_nexp n2) ->
+ if nexp_identical n1 n2 then Some (U_nexp n1)
+ else unify_error l ("Multiple non-identical unifiers for " ^ string_of_kid kid
+ ^ ": " ^ string_of_nexp n1 ^ " and " ^ string_of_nexp n2)
+ | Some _, Some _ -> unify_error l "Multiple non-identical non-nexp unifiers"
+ | None, Some u2 -> Some u2
+ | Some u1, None -> Some u1
+ | None, None -> None
+
let unify l env typ1 typ2 =
typ_print ("Unify " ^ string_of_typ typ1 ^ " with " ^ string_of_typ typ2);
let goals = KidSet.inter (KidSet.diff (typ_frees typ1) (typ_frees typ2)) (typ_frees typ1) in
- let merge_unifiers l kid uvar1 uvar2 =
- match uvar1, uvar2 with
- | Some (U_nexp n1), Some (U_nexp n2) ->
- if nexp_identical n1 n2 then Some (U_nexp n1)
- else unify_error l ("Multiple non-identical unifiers for " ^ string_of_kid kid
- ^ ": " ^ string_of_nexp n1 ^ " and " ^ string_of_nexp n2)
- | Some _, Some _ -> unify_error l "Multiple non-identical non-nexp unifiers"
- | None, Some u2 -> Some u2
- | Some u1, None -> Some u1
- | None, None -> None
- in
let rec unify_typ l (Typ_aux (typ1_aux, _) as typ1) (Typ_aux (typ2_aux, _) as typ2) =
typ_debug ("UNIFYING TYPES " ^ string_of_typ typ1 ^ " AND " ^ string_of_typ typ2);
match typ1_aux, typ2_aux with
@@ -1362,6 +1362,11 @@ let unify l env typ1 typ2 =
let typ1, typ2 = Env.expand_synonyms env typ1, Env.expand_synonyms env typ2 in
unify_typ l typ1 typ2
+let merge_uvars l unifiers1 unifiers2 =
+ try KBindings.merge (merge_unifiers l) unifiers1 unifiers2
+ with
+ | Unification_error (_, m) -> typ_error l ("Could not merge unification variables: " ^ m)
+
(**************************************************************************)
(* 5. Type checking expressions *)
(**************************************************************************)
@@ -1460,8 +1465,14 @@ let typ_of_annot (l, tannot) = match tannot with
| Some (_, typ, _) -> typ
| None -> raise (Reporting_basic.err_unreachable l "no type annotation")
+let env_of_annot (l, tannot) = match tannot with
+ | Some (env, _, _) -> env
+ | None -> raise (Reporting_basic.err_unreachable l "no type annotation")
+
let typ_of (E_aux (_, (l, tannot))) = typ_of_annot (l, tannot)
+let env_of (E_aux (_, (l, tannot))) = env_of_annot (l, tannot)
+
let pat_typ_of (P_aux (_, (l, tannot))) = typ_of_annot (l, tannot)
(* Flow typing *)
@@ -1709,7 +1720,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
begin
match Env.lookup_id id env with
| Union (typq, ctor_typ) ->
- let inferred_exp = infer_funapp' l env id (typq, mk_typ (Typ_fn (unit_typ, ctor_typ, no_effect))) [mk_lit L_unit] (Some typ) in
+ let inferred_exp = fst (infer_funapp' l env id (typq, mk_typ (Typ_fn (unit_typ, ctor_typ, no_effect))) [mk_lit L_unit] (Some typ)) in
annot_exp (E_id id) (typ_of inferred_exp)
| _ -> assert false (* Unreachble due to guard *)
end
@@ -2139,7 +2150,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
(* Accessing a field of a record *)
| Typ_aux (Typ_id rectyp, _) as typ when Env.is_record rectyp env ->
begin
- let inferred_acc = infer_funapp' l (Env.no_casts env) field (Env.get_accessor field env) [strip_exp inferred_exp] None in
+ let inferred_acc, _ = infer_funapp' l (Env.no_casts env) field (Env.get_accessor field env) [strip_exp inferred_exp] None in
match inferred_acc with
| E_aux (E_app (field, [inferred_exp]) ,_) -> annot_exp (E_field (inferred_exp, field)) (typ_of inferred_acc)
| _ -> assert false (* Unreachable *)
@@ -2217,10 +2228,17 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
annot_exp (E_assert (checked_test, checked_msg)) unit_typ
| _ -> typ_error l ("Cannot infer type of: " ^ string_of_exp exp)
-and infer_funapp l env f xs ret_ctx_typ = infer_funapp' l env f (Env.get_val_spec f env) xs ret_ctx_typ
+and infer_funapp l env f xs ret_ctx_typ = fst (infer_funapp' l env f (Env.get_val_spec f env) xs ret_ctx_typ)
+
+and instantiation_of (E_aux (exp_aux, (l, _)) as exp) =
+ let env = env_of exp in
+ match exp_aux with
+ | E_app (f, xs) -> snd (infer_funapp' l (Env.no_casts env) f (Env.get_val_spec f env) (List.map strip_exp xs) (Some (typ_of exp)))
+ | _ -> invalid_arg ("instantiation_of expected application, got " ^ string_of_exp exp)
and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
let annot_exp exp typ eff = E_aux (exp, (l, Some (env, typ, eff))) in
+ let all_unifiers = ref KBindings.empty in
let rec number n = function
| [] -> []
| (x :: xs) -> (n, x) :: number (n + 1) xs
@@ -2255,6 +2273,7 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
try
let iarg, unifiers = type_coercion_unify env iarg typ in
typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers));
+ all_unifiers := merge_uvars l !all_unifiers unifiers;
let utyps' = List.map (subst_unifiers unifiers) utyps in
let typs' = List.map (subst_unifiers unifiers) typs in
let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
@@ -2278,6 +2297,7 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
typ_debug ("INSTANTIATE RETURN:" ^ string_of_typ ret_typ);
let unifiers = try unify l env ret_typ rct with Unification_error _ -> typ_debug "UERROR"; KBindings.empty in
typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers));
+ all_unifiers := merge_uvars l !all_unifiers unifiers;
let typs' = List.map (subst_unifiers unifiers) typs in
let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
let ret_typ' = subst_unifiers unifiers ret_typ in
@@ -2299,8 +2319,8 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
| _ -> typ_error l (string_of_typ f_typ ^ " is not a function type")
in
match ret_ctx_typ with
- | None -> exp
- | Some rct -> type_coercion env exp rct
+ | None -> exp, !all_unifiers
+ | Some rct -> type_coercion env exp rct, !all_unifiers
(**************************************************************************)
(* 6. Effect system *)
diff --git a/src/type_check.mli b/src/type_check.mli
index 0c943f6f..5a237b2f 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -189,11 +189,22 @@ val check_exp : Env.t -> unit exp -> typ -> tannot exp
val typ_of : tannot exp -> typ
val typ_of_annot : Ast.l * tannot -> typ
+val env_of : tannot exp -> Env.t
+
val pat_typ_of : tannot pat -> typ
val effect_of : tannot exp -> effect
val effect_of_annot : tannot -> effect
+type uvar =
+ | U_nexp of nexp
+ | U_order of order
+ | U_effect of effect
+ | U_typ of typ
+
+(* Throws Invalid_argument if the argument is not a E_app expression *)
+val instantiation_of : tannot exp -> uvar KBindings.t
+
val propagate_exp_effect : tannot exp -> tannot exp
(* Fully type-check an AST