diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/pretty_print_lem.ml | 3 | ||||
| -rw-r--r-- | src/type_check.ml | 92 | ||||
| -rw-r--r-- | src/type_check.mli | 6 |
3 files changed, 66 insertions, 35 deletions
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 9d169108..ea34ef3d 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -327,6 +327,9 @@ let doc_typ_lem, doc_atomic_typ_lem = String.concat ", " (List.map string_of_kid bad) ^ " escape into Lem")) end + (* AA: I think the correct thing is likely to filter out + non-integer kinded_id's, then use the above code. *) + | Typ_exist (_,_,Typ_aux(Typ_app(id,[_]),_)) when string_of_id id = "atom_bool" -> string "bool" | Typ_exist _ -> unreachable l __POS__ "Non-integer existentials currently unsupported in Lem" (* TODO *) | Typ_bidir _ -> unreachable l __POS__ "Lem doesn't support bidir types" | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" diff --git a/src/type_check.ml b/src/type_check.ml index 6ddc31a7..ba7b2acb 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -220,19 +220,33 @@ and strip_kinded_id_aux = function and strip_kind = function | K_aux (k_aux, _) -> K_aux (k_aux, Parse_ast.Unknown) +let rec name_pat (P_aux (aux, _)) = + match aux with + | P_id id | P_as (_, id) -> Some ("_" ^ string_of_id id) + | P_typ (_, pat) | P_var (pat, _) -> name_pat pat + | _ -> None + let ex_counter = ref 0 -let fresh_existential ?name:(n="") k = - let fresh = Kid_aux (Var ("'ex" ^ string_of_int !ex_counter ^ "#" ^ n), Parse_ast.Unknown) in +let fresh_existential k = + let fresh = Kid_aux (Var ("'ex" ^ string_of_int !ex_counter ^ "#"), Parse_ast.Unknown) in incr ex_counter; mk_kopt k fresh -let destruct_exist_plain typ = +let named_existential k = function + | Some n -> mk_kopt k (mk_kid n) + | None -> fresh_existential k + +let destruct_exist_plain ?name:(name=None) typ = match typ with + | Typ_aux (Typ_exist ([kopt], nc, typ), _) -> + let kid, fresh = kopt_kid kopt, named_existential (unaux_kind (kopt_kind kopt)) name in + let nc = constraint_subst kid (arg_kopt fresh) nc in + let typ = typ_subst kid (arg_kopt fresh) typ in + Some ([fresh], nc, typ) | Typ_aux (Typ_exist (kopts, nc, typ), _) -> + let add_num i = match name with Some n -> Some (n ^ string_of_int i) | None -> None in let fresh_kopts = - List.map (fun kopt -> (kopt_kid kopt, - fresh_existential ~name:(string_of_id (id_of_kid (kopt_kid kopt))) (unaux_kind (kopt_kind kopt)))) - kopts + List.mapi (fun i kopt -> (kopt_kid kopt, named_existential (unaux_kind (kopt_kind kopt)) (add_num i))) kopts in let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_kopt fresh) nc) nc fresh_kopts in let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_kopt fresh) typ) typ fresh_kopts in @@ -247,36 +261,36 @@ let destruct_exist_plain typ = - int => ['n], true, 'n (where x is fresh) - atom('n) => [], true, 'n **) -let destruct_numeric typ = - match destruct_exist_plain typ, typ with +let destruct_numeric ?name:(name=None) typ = + match destruct_exist_plain ~name:name typ, typ with | Some (kids, nc, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _)), _ when string_of_id id = "atom" -> Some (List.map kopt_kid kids, nc, nexp) | None, Typ_aux (Typ_app (id, [A_aux (A_nexp nexp, _)]), _) when string_of_id id = "atom" -> Some ([], nc_true, nexp) | None, Typ_aux (Typ_app (id, [A_aux (A_nexp lo, _); A_aux (A_nexp hi, _)]), _) when string_of_id id = "range" -> - let kid = kopt_kid (fresh_existential K_int) in + let kid = kopt_kid (named_existential K_int name) in Some ([kid], nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi), nvar kid) | None, Typ_aux (Typ_id id, _) when string_of_id id = "nat" -> - let kid = kopt_kid (fresh_existential K_int) in + let kid = kopt_kid (named_existential K_int name) in Some ([kid], nc_lteq (nint 0) (nvar kid), nvar kid) | None, Typ_aux (Typ_id id, _) when string_of_id id = "int" -> - let kid = kopt_kid (fresh_existential K_int) in + let kid = kopt_kid (named_existential K_int name) in Some ([kid], nc_true, nvar kid) | _, _ -> None -let destruct_boolean = function +let destruct_boolean ?name:(name=None) = function | Typ_aux (Typ_id (Id_aux (Id "bool", _)), _) -> let kid = kopt_kid (fresh_existential K_bool) in Some (kid, nc_var kid) | _ -> None -let destruct_exist typ = - match destruct_numeric typ with +let destruct_exist ?name:(name=None) typ = + match destruct_numeric ~name:name typ with | Some (kids, nc, nexp) -> Some (List.map (mk_kopt K_int) kids, nc, atom_typ nexp) | None -> - match destruct_boolean typ with + match destruct_boolean ~name:name typ with | Some (kid, nc) -> Some ([mk_kopt K_bool kid], nc_true, atom_bool_typ nc) - | None -> destruct_exist_plain typ + | None -> destruct_exist_plain ~name:name typ let adding = Util.("Adding " |> darkgray |> clear) @@ -384,6 +398,7 @@ end = struct variants : (typquant * type_union list) Bindings.t; mappings : (typquant * typ * typ) Bindings.t; typ_vars : (Ast.l * kind_aux) KBindings.t; + shadow_vars : int KBindings.t; typ_synonyms : (t -> typ_arg list -> typ_arg) Bindings.t; num_defs : nexp Bindings.t; overloads : (id list) Bindings.t; @@ -412,6 +427,7 @@ end = struct variants = Bindings.empty; mappings = Bindings.empty; typ_vars = KBindings.empty; + shadow_vars = KBindings.empty; typ_synonyms = Bindings.empty; num_defs = Bindings.empty; overloads = Bindings.empty; @@ -1027,13 +1043,21 @@ end = struct with | Not_found -> Unbound - let add_typ_var l (KOpt_aux (KOpt_kind (K_aux (k, _), kid), _) as kopt) env = - if KBindings.mem kid env.typ_vars - then typ_error (kid_loc kid) ("type variable " ^ string_of_kinded_id kopt ^ " is already bound") - else - begin - typ_print (lazy (adding ^ "type variable " ^ string_of_kid kid ^ " : " ^ string_of_kind_aux k)); - { env with typ_vars = KBindings.add kid (l, k) env.typ_vars } + let add_typ_var l (KOpt_aux (KOpt_kind (K_aux (k, _), v), _)) env = + if KBindings.mem v env.typ_vars then begin + let n = match KBindings.find_opt v env.shadow_vars with Some n -> n | None -> 0 in + let s_l, s_k = KBindings.find v env.typ_vars in + let s_v = Kid_aux (Var (string_of_kid v ^ "#" ^ string_of_int n), l) in + typ_print (lazy (Printf.sprintf "%stype variable (shadowing %s) %s : %s" adding (string_of_kid s_v) (string_of_kid v) (string_of_kind_aux k))); + { env with + constraints = List.map (constraint_subst v (arg_kopt (mk_kopt s_k s_v))) env.constraints; + typ_vars = KBindings.add v (l, k) (KBindings.add s_v (s_l, s_k) env.typ_vars); + shadow_vars = KBindings.add v (n + 1) env.shadow_vars + } + end + else begin + typ_print (lazy (adding ^ "type variable " ^ string_of_kid v ^ " : " ^ string_of_kind_aux k)); + { env with typ_vars = KBindings.add v (l, k) env.typ_vars } end let add_num_def id nexp env = @@ -1174,8 +1198,8 @@ let bind_numeric l typ env = (** Pull an (potentially)-existentially qualified type into the global typing environment **) -let bind_existential l typ env = - match destruct_exist (Env.expand_synonyms env typ) with +let bind_existential l name typ env = + match destruct_exist ~name:name (Env.expand_synonyms env typ) with | Some (kids, nc, typ) -> typ, add_existential l kids nc env | None -> typ, env @@ -1439,13 +1463,17 @@ and unify_typ_arg l env goals (A_aux (aux1, _) as typ_arg1) (A_aux (aux2, _) as | A_typ typ1, A_typ typ2 -> unify_typ l env goals typ1 typ2 | A_nexp nexp1, A_nexp nexp2 -> unify_nexp l env goals nexp1 nexp2 | A_order ord1, A_order ord2 -> unify_order l goals ord1 ord2 - | A_bool nc1, A_bool nc2 -> unify_constraint l goals nc1 nc2 + | A_bool nc1, A_bool nc2 -> unify_constraint l env goals nc1 nc2 | _, _ -> unify_error l ("Could not unify type arguments " ^ string_of_typ_arg typ_arg1 ^ " and " ^ string_of_typ_arg typ_arg2) -and unify_constraint l goals (NC_aux (aux1, _) as nc1) (NC_aux (aux2, _) as nc2) = +and unify_constraint l env goals (NC_aux (aux1, _) as nc1) (NC_aux (aux2, _) as nc2) = typ_debug (lazy (Util.("Unify constraint " |> magenta |> clear) ^ string_of_n_constraint nc1 ^ " and " ^ string_of_n_constraint nc2)); match aux1, aux2 with | NC_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_bool nc2) + | NC_and (nc1a, nc2a), NC_and (nc1b, nc2b) | NC_or (nc1a, nc2a), NC_or (nc1b, nc2b) -> + merge_uvars l (unify_constraint l env goals nc1a nc1b) (unify_constraint l env goals nc2a nc2b) + | NC_app (f1, args1), NC_app (f2, args2) when Id.compare f1 f2 = 0 && List.length args1 = List.length args2 -> + List.fold_left (merge_uvars l) KBindings.empty (List.map2 (unify_typ_arg l env goals) args1 args2) | _, _ -> unify_error l ("Could not unify constraints " ^ string_of_n_constraint nc1 ^ " and " ^ string_of_n_constraint nc2) and unify_order l goals (Ord_aux (aux1, _) as ord1) (Ord_aux (aux2, _) as ord2) = @@ -1489,7 +1517,7 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au then unify_nexp l env goals n1a (nsum nexp2 n1b) else unify_error l ("Cannot unify minus Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) | Nexp_times (n1a, n1b) -> - (* f we have SMT operations div and mod, then we can use the + (* If we have SMT operations div and mod, then we can use the property that mod(m, C) = 0 && C != 0 --> (C * n = m <--> n = m / C) @@ -2525,7 +2553,7 @@ and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ = typ_print (lazy ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " for unification")); try let inferred_cast = irule infer_exp (Env.no_casts env) (strip (E_app (cast, [annotated_exp]))) in - let ityp, env = bind_existential l (typ_of inferred_cast) env in + let ityp, env = bind_existential l None (typ_of inferred_cast) env in inferred_cast, unify l env goals typ ityp, env with | Type_error (_, err) -> try_casts casts @@ -2535,7 +2563,7 @@ and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ = begin try typ_debug (lazy ("Coercing unification: from " ^ string_of_typ (typ_of annotated_exp) ^ " to " ^ string_of_typ typ)); - let atyp, env = bind_existential l (typ_of annotated_exp) env in + let atyp, env = bind_existential l None (typ_of annotated_exp) env in annotated_exp, unify l env goals typ atyp, env with | Unification_error (_, m) when Env.allow_casts env -> @@ -2549,7 +2577,7 @@ and bind_pat_no_guard env (P_aux (_,(l,_)) as pat) typ = | tpat, env, [] -> tpat, env and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) = - let (Typ_aux (typ_aux, _) as typ), env = bind_existential l typ env in + let (Typ_aux (typ_aux, _) as typ), env = bind_existential l (name_pat pat) typ env in typ_print (lazy (Util.("Binding " |> yellow |> clear) ^ string_of_pat pat ^ " to " ^ string_of_typ typ)); let annot_pat pat typ' = P_aux (pat, (l, mk_expected_tannot env typ' no_effect (Some typ))) in let switch_typ pat typ = match pat with @@ -3442,7 +3470,7 @@ and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ = exp and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, ())) as mpat) (Typ_aux (typ_aux, _) as typ) = - let (Typ_aux (typ_aux, _) as typ), env = bind_existential l typ env in + let (Typ_aux (typ_aux, _) as typ), env = bind_existential l None typ env in typ_print (lazy (Util.("Binding " |> yellow |> clear) ^ string_of_mpat mpat ^ " to " ^ string_of_typ typ)); let annot_mpat mpat typ' = MP_aux (mpat, (l, mk_expected_tannot env typ' no_effect (Some typ))) in let switch_typ mpat typ = match mpat with diff --git a/src/type_check.mli b/src/type_check.mli index 81682606..c470e9c4 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -212,8 +212,8 @@ val add_typquant : Ast.l -> typquant -> Env.t -> Env.t is not existential. This function will pick a fresh name for the existential to ensure that no name-clashes occur. The "plain" version does not treat numeric types as existentials. *) -val destruct_exist_plain : typ -> (kinded_id list * n_constraint * typ) option -val destruct_exist : typ -> (kinded_id list * n_constraint * typ) option +val destruct_exist_plain : ?name:string option -> typ -> (kinded_id list * n_constraint * typ) option +val destruct_exist : ?name:string option -> typ -> (kinded_id list * n_constraint * typ) option val add_existential : Ast.l -> kinded_id list -> n_constraint -> Env.t -> Env.t @@ -356,7 +356,7 @@ val destruct_atom_nexp : Env.t -> typ -> nexp option val destruct_range : Env.t -> typ -> (kid list * n_constraint * nexp * nexp) option -val destruct_numeric : typ -> (kid list * n_constraint * nexp) option +val destruct_numeric : ?name:string option -> typ -> (kid list * n_constraint * nexp) option val destruct_vector : Env.t -> typ -> (nexp * order * typ) option |
