summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/pretty_print_lem.ml3
-rw-r--r--src/type_check.ml92
-rw-r--r--src/type_check.mli6
3 files changed, 66 insertions, 35 deletions
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 9d169108..ea34ef3d 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -327,6 +327,9 @@ let doc_typ_lem, doc_atomic_typ_lem =
String.concat ", " (List.map string_of_kid bad) ^
" escape into Lem"))
end
+ (* AA: I think the correct thing is likely to filter out
+ non-integer kinded_id's, then use the above code. *)
+ | Typ_exist (_,_,Typ_aux(Typ_app(id,[_]),_)) when string_of_id id = "atom_bool" -> string "bool"
| Typ_exist _ -> unreachable l __POS__ "Non-integer existentials currently unsupported in Lem" (* TODO *)
| Typ_bidir _ -> unreachable l __POS__ "Lem doesn't support bidir types"
| Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown"
diff --git a/src/type_check.ml b/src/type_check.ml
index 6ddc31a7..ba7b2acb 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -220,19 +220,33 @@ and strip_kinded_id_aux = function
and strip_kind = function
| K_aux (k_aux, _) -> K_aux (k_aux, Parse_ast.Unknown)
+let rec name_pat (P_aux (aux, _)) =
+ match aux with
+ | P_id id | P_as (_, id) -> Some ("_" ^ string_of_id id)
+ | P_typ (_, pat) | P_var (pat, _) -> name_pat pat
+ | _ -> None
+
let ex_counter = ref 0
-let fresh_existential ?name:(n="") k =
- let fresh = Kid_aux (Var ("'ex" ^ string_of_int !ex_counter ^ "#" ^ n), Parse_ast.Unknown) in
+let fresh_existential k =
+ let fresh = Kid_aux (Var ("'ex" ^ string_of_int !ex_counter ^ "#"), Parse_ast.Unknown) in
incr ex_counter; mk_kopt k fresh
-let destruct_exist_plain typ =
+let named_existential k = function
+ | Some n -> mk_kopt k (mk_kid n)
+ | None -> fresh_existential k
+
+let destruct_exist_plain ?name:(name=None) typ =
match typ with
+ | Typ_aux (Typ_exist ([kopt], nc, typ), _) ->
+ let kid, fresh = kopt_kid kopt, named_existential (unaux_kind (kopt_kind kopt)) name in
+ let nc = constraint_subst kid (arg_kopt fresh) nc in
+ let typ = typ_subst kid (arg_kopt fresh) typ in
+ Some ([fresh], nc, typ)
| Typ_aux (Typ_exist (kopts, nc, typ), _) ->
+ let add_num i = match name with Some n -> Some (n ^ string_of_int i) | None -> None in
let fresh_kopts =
- List.map (fun kopt -> (kopt_kid kopt,
- fresh_existential ~name:(string_of_id (id_of_kid (kopt_kid kopt))) (unaux_kind (kopt_kind kopt))))
- kopts
+ List.mapi (fun i kopt -> (kopt_kid kopt, named_existential (unaux_kind (kopt_kind kopt)) (add_num i))) kopts
in
let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_kopt fresh) nc) nc fresh_kopts in
let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_kopt fresh) typ) typ fresh_kopts in
@@ -247,36 +261,36 @@ let destruct_exist_plain typ =
- int => ['n], true, 'n (where x is fresh)
- atom('n) => [], true, 'n
**)
-let destruct_numeric typ =
- match destruct_exist_plain typ, typ with
+let destruct_numeric ?name:(name=None) typ =
+ match destruct_exist_plain ~name:name typ, typ with
| Some (kids, nc, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _)), _ when string_of_id id = "atom" ->
Some (List.map kopt_kid kids, nc, nexp)
| None, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _) when string_of_id id = "atom" ->
Some ([], nc_true, nexp)
| None, Typ_aux (Typ_app (id, [A_aux (A_nexp lo, _); A_aux (A_nexp hi, _)]), _) when string_of_id id = "range" ->
- let kid = kopt_kid (fresh_existential K_int) in
+ let kid = kopt_kid (named_existential K_int name) in
Some ([kid], nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi), nvar kid)
| None, Typ_aux (Typ_id id, _) when string_of_id id = "nat" ->
- let kid = kopt_kid (fresh_existential K_int) in
+ let kid = kopt_kid (named_existential K_int name) in
Some ([kid], nc_lteq (nint 0) (nvar kid), nvar kid)
| None, Typ_aux (Typ_id id, _) when string_of_id id = "int" ->
- let kid = kopt_kid (fresh_existential K_int) in
+ let kid = kopt_kid (named_existential K_int name) in
Some ([kid], nc_true, nvar kid)
| _, _ -> None
-let destruct_boolean = function
+let destruct_boolean ?name:(name=None) = function
| Typ_aux (Typ_id (Id_aux (Id "bool", _)), _) ->
let kid = kopt_kid (fresh_existential K_bool) in
Some (kid, nc_var kid)
| _ -> None
-let destruct_exist typ =
- match destruct_numeric typ with
+let destruct_exist ?name:(name=None) typ =
+ match destruct_numeric ~name:name typ with
| Some (kids, nc, nexp) -> Some (List.map (mk_kopt K_int) kids, nc, atom_typ nexp)
| None ->
- match destruct_boolean typ with
+ match destruct_boolean ~name:name typ with
| Some (kid, nc) -> Some ([mk_kopt K_bool kid], nc_true, atom_bool_typ nc)
- | None -> destruct_exist_plain typ
+ | None -> destruct_exist_plain ~name:name typ
let adding = Util.("Adding " |> darkgray |> clear)
@@ -384,6 +398,7 @@ end = struct
variants : (typquant * type_union list) Bindings.t;
mappings : (typquant * typ * typ) Bindings.t;
typ_vars : (Ast.l * kind_aux) KBindings.t;
+ shadow_vars : int KBindings.t;
typ_synonyms : (t -> typ_arg list -> typ_arg) Bindings.t;
num_defs : nexp Bindings.t;
overloads : (id list) Bindings.t;
@@ -412,6 +427,7 @@ end = struct
variants = Bindings.empty;
mappings = Bindings.empty;
typ_vars = KBindings.empty;
+ shadow_vars = KBindings.empty;
typ_synonyms = Bindings.empty;
num_defs = Bindings.empty;
overloads = Bindings.empty;
@@ -1027,13 +1043,21 @@ end = struct
with
| Not_found -> Unbound
- let add_typ_var l (KOpt_aux (KOpt_kind (K_aux (k, _), kid), _) as kopt) env =
- if KBindings.mem kid env.typ_vars
- then typ_error (kid_loc kid) ("type variable " ^ string_of_kinded_id kopt ^ " is already bound")
- else
- begin
- typ_print (lazy (adding ^ "type variable " ^ string_of_kid kid ^ " : " ^ string_of_kind_aux k));
- { env with typ_vars = KBindings.add kid (l, k) env.typ_vars }
+ let add_typ_var l (KOpt_aux (KOpt_kind (K_aux (k, _), v), _)) env =
+ if KBindings.mem v env.typ_vars then begin
+ let n = match KBindings.find_opt v env.shadow_vars with Some n -> n | None -> 0 in
+ let s_l, s_k = KBindings.find v env.typ_vars in
+ let s_v = Kid_aux (Var (string_of_kid v ^ "#" ^ string_of_int n), l) in
+ typ_print (lazy (Printf.sprintf "%stype variable (shadowing %s) %s : %s" adding (string_of_kid s_v) (string_of_kid v) (string_of_kind_aux k)));
+ { env with
+ constraints = List.map (constraint_subst v (arg_kopt (mk_kopt s_k s_v))) env.constraints;
+ typ_vars = KBindings.add v (l, k) (KBindings.add s_v (s_l, s_k) env.typ_vars);
+ shadow_vars = KBindings.add v (n + 1) env.shadow_vars
+ }
+ end
+ else begin
+ typ_print (lazy (adding ^ "type variable " ^ string_of_kid v ^ " : " ^ string_of_kind_aux k));
+ { env with typ_vars = KBindings.add v (l, k) env.typ_vars }
end
let add_num_def id nexp env =
@@ -1174,8 +1198,8 @@ let bind_numeric l typ env =
(** Pull an (potentially)-existentially qualified type into the global
typing environment **)
-let bind_existential l typ env =
- match destruct_exist (Env.expand_synonyms env typ) with
+let bind_existential l name typ env =
+ match destruct_exist ~name:name (Env.expand_synonyms env typ) with
| Some (kids, nc, typ) -> typ, add_existential l kids nc env
| None -> typ, env
@@ -1439,13 +1463,17 @@ and unify_typ_arg l env goals (A_aux (aux1, _) as typ_arg1) (A_aux (aux2, _) as
| A_typ typ1, A_typ typ2 -> unify_typ l env goals typ1 typ2
| A_nexp nexp1, A_nexp nexp2 -> unify_nexp l env goals nexp1 nexp2
| A_order ord1, A_order ord2 -> unify_order l goals ord1 ord2
- | A_bool nc1, A_bool nc2 -> unify_constraint l goals nc1 nc2
+ | A_bool nc1, A_bool nc2 -> unify_constraint l env goals nc1 nc2
| _, _ -> unify_error l ("Could not unify type arguments " ^ string_of_typ_arg typ_arg1 ^ " and " ^ string_of_typ_arg typ_arg2)
-and unify_constraint l goals (NC_aux (aux1, _) as nc1) (NC_aux (aux2, _) as nc2) =
+and unify_constraint l env goals (NC_aux (aux1, _) as nc1) (NC_aux (aux2, _) as nc2) =
typ_debug (lazy (Util.("Unify constraint " |> magenta |> clear) ^ string_of_n_constraint nc1 ^ " and " ^ string_of_n_constraint nc2));
match aux1, aux2 with
| NC_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_bool nc2)
+ | NC_and (nc1a, nc2a), NC_and (nc1b, nc2b) | NC_or (nc1a, nc2a), NC_or (nc1b, nc2b) ->
+ merge_uvars l (unify_constraint l env goals nc1a nc1b) (unify_constraint l env goals nc2a nc2b)
+ | NC_app (f1, args1), NC_app (f2, args2) when Id.compare f1 f2 = 0 && List.length args1 = List.length args2 ->
+ List.fold_left (merge_uvars l) KBindings.empty (List.map2 (unify_typ_arg l env goals) args1 args2)
| _, _ -> unify_error l ("Could not unify constraints " ^ string_of_n_constraint nc1 ^ " and " ^ string_of_n_constraint nc2)
and unify_order l goals (Ord_aux (aux1, _) as ord1) (Ord_aux (aux2, _) as ord2) =
@@ -1489,7 +1517,7 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au
then unify_nexp l env goals n1a (nsum nexp2 n1b)
else unify_error l ("Cannot unify minus Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
| Nexp_times (n1a, n1b) ->
- (* f we have SMT operations div and mod, then we can use the
+ (* If we have SMT operations div and mod, then we can use the
property that
mod(m, C) = 0 && C != 0 --> (C * n = m <--> n = m / C)
@@ -2525,7 +2553,7 @@ and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ =
typ_print (lazy ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " for unification"));
try
let inferred_cast = irule infer_exp (Env.no_casts env) (strip (E_app (cast, [annotated_exp]))) in
- let ityp, env = bind_existential l (typ_of inferred_cast) env in
+ let ityp, env = bind_existential l None (typ_of inferred_cast) env in
inferred_cast, unify l env goals typ ityp, env
with
| Type_error (_, err) -> try_casts casts
@@ -2535,7 +2563,7 @@ and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ =
begin
try
typ_debug (lazy ("Coercing unification: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ));
- let atyp, env = bind_existential l (typ_of annotated_exp) env in
+ let atyp, env = bind_existential l None (typ_of annotated_exp) env in
annotated_exp, unify l env goals typ atyp, env
with
| Unification_error (_, m) when Env.allow_casts env ->
@@ -2549,7 +2577,7 @@ and bind_pat_no_guard env (P_aux (_,(l,_)) as pat) typ =
| tpat, env, [] -> tpat, env
and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) =
- let (Typ_aux (typ_aux, _) as typ), env = bind_existential l typ env in
+ let (Typ_aux (typ_aux, _) as typ), env = bind_existential l (name_pat pat) typ env in
typ_print (lazy (Util.("Binding " |> yellow |> clear) ^ string_of_pat pat ^ " to " ^ string_of_typ typ));
let annot_pat pat typ' = P_aux (pat, (l, mk_expected_tannot env typ' no_effect (Some typ))) in
let switch_typ pat typ = match pat with
@@ -3442,7 +3470,7 @@ and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ =
exp
and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, ())) as mpat) (Typ_aux (typ_aux, _) as typ) =
- let (Typ_aux (typ_aux, _) as typ), env = bind_existential l typ env in
+ let (Typ_aux (typ_aux, _) as typ), env = bind_existential l None typ env in
typ_print (lazy (Util.("Binding " |> yellow |> clear) ^ string_of_mpat mpat ^ " to " ^ string_of_typ typ));
let annot_mpat mpat typ' = MP_aux (mpat, (l, mk_expected_tannot env typ' no_effect (Some typ))) in
let switch_typ mpat typ = match mpat with
diff --git a/src/type_check.mli b/src/type_check.mli
index 81682606..c470e9c4 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -212,8 +212,8 @@ val add_typquant : Ast.l -> typquant -> Env.t -> Env.t
is not existential. This function will pick a fresh name for the
existential to ensure that no name-clashes occur. The "plain"
version does not treat numeric types as existentials. *)
-val destruct_exist_plain : typ -> (kinded_id list * n_constraint * typ) option
-val destruct_exist : typ -> (kinded_id list * n_constraint * typ) option
+val destruct_exist_plain : ?name:string option -> typ -> (kinded_id list * n_constraint * typ) option
+val destruct_exist : ?name:string option -> typ -> (kinded_id list * n_constraint * typ) option
val add_existential : Ast.l -> kinded_id list -> n_constraint -> Env.t -> Env.t
@@ -356,7 +356,7 @@ val destruct_atom_nexp : Env.t -> typ -> nexp option
val destruct_range : Env.t -> typ -> (kid list * n_constraint * nexp * nexp) option
-val destruct_numeric : typ -> (kid list * n_constraint * nexp) option
+val destruct_numeric : ?name:string option -> typ -> (kid list * n_constraint * nexp) option
val destruct_vector : Env.t -> typ -> (nexp * order * typ) option