diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/rewrites.ml | 6 | ||||
| -rw-r--r-- | src/type_check.ml | 166 | ||||
| -rw-r--r-- | src/type_check.mli | 6 |
3 files changed, 94 insertions, 84 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index 32ffe54a..6158422e 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1084,7 +1084,7 @@ let remove_bitvector_pat (P_aux (_, (l, _)) as pat) = ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) ; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false)) } in - let pat, env = bind_pat env + let pat, env = bind_pat_no_guard env (strip_pat ((fold_pat name_bitvector_roots pat) false)) (pat_typ_of pat) in @@ -1588,11 +1588,11 @@ let rewrite_register_ref_writes (Defs defs) = | BF_aux (BF_range (i, j), _) -> (i, j) | _ -> raise (Reporting_basic.err_unreachable l "unsupported field type") in let mk_num_exp i = mk_lit_exp (L_num i) in - let reg_pat, reg_env = bind_pat env (mk_pat (P_typ (rtyp, mk_pat (P_id (mk_id "reg"))))) rtyp in + let reg_pat, reg_env = bind_pat_no_guard env (mk_pat (P_typ (rtyp, mk_pat (P_id (mk_id "reg"))))) rtyp in let inferred_get = infer_exp reg_env (mk_exp (E_vector_subrange (mk_exp (E_id (mk_id "reg")), mk_num_exp i, mk_num_exp j))) in let ftyp = typ_of inferred_get in - let v_pat, v_env = bind_pat reg_env (mk_pat (P_typ (ftyp, mk_pat (P_id (mk_id "v"))))) ftyp in + let v_pat, v_env = bind_pat_no_guard reg_env (mk_pat (P_typ (ftyp, mk_pat (P_id (mk_id "v"))))) ftyp in let inferred_set = infer_exp v_env (mk_exp (E_vector_update_subrange (mk_exp (E_id (mk_id "reg")), mk_num_exp i, mk_num_exp j, mk_exp (E_id (mk_id "v"))))) in let set_args = P_aux (P_tup [reg_pat; v_pat], (l, Some (env, tuple_typ [rtyp; ftyp], no_effect))) in diff --git a/src/type_check.ml b/src/type_check.ml index a4b4bb39..9d235cb4 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -2054,16 +2054,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ annot_exp (E_case (inferred_exp, List.map (fun case -> check_case env inferred_typ case typ) cases)) typ | E_try (exp, cases), _ -> let checked_exp = crule check_exp env exp typ in - let check_case pat typ = match pat with - | Pat_aux (Pat_exp (pat, case), (l, _)) -> - let tpat, env = bind_pat env pat exc_typ in - Pat_aux (Pat_exp (tpat, crule check_exp env case typ), (l, None)) - | Pat_aux (Pat_when (pat, guard, case), (l, _)) -> - let tpat, env = bind_pat env pat exc_typ in - let checked_guard = check_exp env guard bool_typ in - Pat_aux (Pat_when (tpat, checked_guard, crule check_exp env case typ), (l, None)) - in - annot_exp (E_try (checked_exp, List.map (fun case -> check_case case typ) cases)) typ + annot_exp (E_try (checked_exp, List.map (fun case -> check_case env exc_typ case typ) cases)) typ | E_cons (x, xs), _ -> begin match is_list (Env.expand_synonyms env typ) with @@ -2118,11 +2109,11 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ | LB_val (P_aux (P_typ (ptyp, _), _) as pat, bind) -> Env.wf_typ env ptyp; let checked_bind = crule check_exp env bind ptyp in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in annot_exp (E_let (LB_aux (LB_val (tpat, checked_bind), (let_loc, None)), crule check_exp env exp typ)) typ | LB_val (pat, bind) -> let inferred_bind = irule infer_exp env bind in - let tpat, env = bind_pat env pat (typ_of inferred_bind) in + let tpat, env = bind_pat_no_guard env pat (typ_of inferred_bind) in annot_exp (E_let (LB_aux (LB_val (tpat, inferred_bind), (let_loc, None)), crule check_exp env exp typ)) typ end | E_app_infix (x, op, y), _ -> @@ -2181,7 +2172,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ | _ -> let inferred_bind = irule infer_exp env bind in inferred_bind, typ_of inferred_bind in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in (* Propagate constraint assertions on the lhs of monadic binds to the rhs *) let env = match bind_exp with | E_aux (E_assert (E_aux (E_constraint nc, _), _), _) -> @@ -2219,7 +2210,17 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ and check_case env pat_typ pexp typ = let pat,guard,case,((l,_) as annot) = destruct_pexp pexp in match bind_pat env pat pat_typ with - | tpat, env -> + | tpat, env, guards -> + let guard = match guard, guards with + | None, h::t -> Some (h,t) + | Some x, l -> Some (x,l) + | None, [] -> None + in + let guard = match guard with + | Some (h,t) -> + Some (List.fold_left (fun acc guard -> mk_exp (E_app_infix (acc, mk_id "&", guard))) h t) + | None -> None + in let checked_guard, env' = match guard with | None -> None, env | Some guard -> @@ -2229,16 +2230,6 @@ and check_case env pat_typ pexp typ = in let checked_case = crule check_exp env' case typ in construct_pexp (tpat, checked_guard, checked_case, (l, None)) - | exception (Type_error _ as typ_exn) -> - match pat with - | P_aux (P_lit lit, _) -> - let guard' = mk_exp (E_app_infix (mk_exp (E_id (mk_id "p#")), mk_id "==", mk_exp (E_lit lit))) in - let guard = match guard with - | None -> guard' - | Some guard -> mk_exp (E_app_infix (guard, mk_id "&", guard')) - in - check_case env pat_typ (Pat_aux (Pat_when (mk_pat (P_id (mk_id "p#")), guard, case), annot)) typ - | _ -> raise typ_exn (* type_coercion env exp typ takes a fully annoted (i.e. already type checked) expression exp, and attempts to cast (coerce) it to the @@ -2304,26 +2295,31 @@ and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ = try_casts casts end +and bind_pat_no_guard env (P_aux (_,(l,_)) as pat) typ = + match bind_pat env pat typ with + | _, _, _::_ -> typ_error l "Literal patterns not supported here" + | tpat, env, [] -> tpat, env + and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) = typ_print ("Binding " ^ string_of_pat pat ^ " to " ^ string_of_typ typ); let annot_pat pat typ = P_aux (pat, (l, Some (env, typ, no_effect))) in let switch_typ (P_aux (pat_aux, (l, Some (env, _, eff)))) typ = P_aux (pat_aux, (l, Some (env, typ, eff))) in - let bind_tuple_pat (tpats, env) pat typ = - let tpat, env = bind_pat env pat typ in tpat :: tpats, env + let bind_tuple_pat (tpats, env, guards) pat typ = + let tpat, env, guards' = bind_pat env pat typ in tpat :: tpats, env, guards' @ guards in match pat_aux with | P_id v -> begin match Env.lookup_id v env with - | Local (Immutable, _) | Unbound -> annot_pat (P_id v) typ, Env.add_local v (Immutable, typ) env + | Local (Immutable, _) | Unbound -> annot_pat (P_id v) typ, Env.add_local v (Immutable, typ) env, [] | Local (Mutable, _) | Register _ -> typ_error l ("Cannot shadow mutable local or register in switch statement pattern " ^ string_of_pat pat) - | Enum enum -> subtyp l env enum typ; annot_pat (P_id v) typ, env + | Enum enum -> subtyp l env enum typ; annot_pat (P_id v) typ, env, [] | Union (typq, ctor_typ) -> begin try let _ = unify l env ctor_typ typ in - annot_pat (P_id v) typ, env + annot_pat (P_id v) typ, env, [] with | Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m) end @@ -2336,34 +2332,34 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) let env = Env.add_typ_var kid BK_nat env in let ex_typ = typ_subst_nexp kid' (Nexp_var kid) ex_typ in let env = Env.add_constraint (nc_subst_nexp kid' (Nexp_var kid) nc) env in - let typed_pat, env = bind_pat env pat ex_typ in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat ex_typ in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | Some _, _ -> typ_error l ("Cannot bind type variable pattern against multiple argument existential") | None, Typ_aux (Typ_id id, _) when Id.compare id (mk_id "int") == 0 -> let env = Env.add_typ_var kid BK_nat env in - let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | None, Typ_aux (Typ_id id, _) when Id.compare id (mk_id "nat") == 0 -> let env = Env.add_typ_var kid BK_nat env in let env = Env.add_constraint (nc_gt (nvar kid) (nint 0)) env in - let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | None, Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp lo, _); Typ_arg_aux (Typ_arg_nexp hi, _)]), _) when Id.compare id (mk_id "range") == 0 -> let env = Env.add_typ_var kid BK_nat env in let env = Env.add_constraint (nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi)) env in - let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | None, _ -> typ_error l ("Cannot bind type variable against non existential or numeric type") end - | P_wild -> annot_pat P_wild typ, env + | P_wild -> annot_pat P_wild typ, env, [] | P_cons (hd_pat, tl_pat) -> begin match Env.expand_synonyms env typ with | Typ_aux (Typ_app (f, [Typ_arg_aux (Typ_arg_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> - let hd_pat, env = bind_pat env hd_pat ltyp in - let tl_pat, env = bind_pat env tl_pat typ in - annot_pat (P_cons (hd_pat, tl_pat)) typ, env + let hd_pat, env, hd_guards = bind_pat env hd_pat ltyp in + let tl_pat, env, tl_guards = bind_pat env tl_pat typ in + annot_pat (P_cons (hd_pat, tl_pat)) typ, env, hd_guards @ tl_guards | _ -> typ_error l "Cannot match cons pattern against non-list type" end | P_list pats -> @@ -2371,32 +2367,32 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) match Env.expand_synonyms env typ with | Typ_aux (Typ_app (f, [Typ_arg_aux (Typ_arg_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> let rec process_pats env = function - | [] -> [], env + | [] -> [], env, [] | (pat :: pats) -> - let pat', env = bind_pat env pat ltyp in - let pats', env = process_pats env pats in - pat' :: pats', env + let pat', env, guards = bind_pat env pat ltyp in + let pats', env, guards' = process_pats env pats in + pat' :: pats', env, guards @ guards' in - let pats, env = process_pats env pats in - annot_pat (P_list pats) typ, env + let pats, env, guards = process_pats env pats in + annot_pat (P_list pats) typ, env, guards | _ -> typ_error l ("Cannot match list pattern " ^ string_of_pat pat ^ " against non-list type " ^ string_of_typ typ) end | P_tup [] -> begin match Env.expand_synonyms env typ with | Typ_aux (Typ_id typ_id, _) when string_of_id typ_id = "unit" -> - annot_pat (P_tup []) typ, env + annot_pat (P_tup []) typ, env, [] | _ -> typ_error l "Cannot match unit pattern against non-unit type" end | P_tup pats -> begin match Env.expand_synonyms env typ with | Typ_aux (Typ_tup typs, _) -> - let tpats, env = - try List.fold_left2 bind_tuple_pat ([], env) pats typs with + let tpats, env, guards = + try List.fold_left2 bind_tuple_pat ([], env, []) pats typs with | Invalid_argument _ -> typ_error l "Tuple pattern and tuple type have different length" in - annot_pat (P_tup (List.rev tpats)) typ, env + annot_pat (P_tup (List.rev tpats)) typ, env, guards | _ -> typ_error l "Cannot bind tuple pattern against non tuple type" end | P_app (f, pats) when Env.is_union_constructor f env -> @@ -2420,11 +2416,11 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat) else (); let ret_typ' = subst_unifiers unifiers ret_typ in - let tpats, env = - try List.fold_left2 bind_tuple_pat ([], env) pats (untuple arg_typ') with + let tpats, env, guards = + try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with | Invalid_argument _ -> typ_error l "Union constructor pattern arguments have incorrect length" in - annot_pat (P_app (f, List.rev tpats)) typ, env + annot_pat (P_app (f, List.rev tpats)) typ, env, guards with | Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m) end @@ -2433,12 +2429,19 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) | P_app (f, _) when not (Env.is_union_constructor f env) -> typ_error l (string_of_id f ^ " is not a union constructor in pattern " ^ string_of_pat pat) | P_as (pat, id) -> - let (typed_pat, env) = bind_pat env pat typ in - annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env + let (typed_pat, env, guards) = bind_pat env pat typ in + annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env, guards | _ -> - let (inferred_pat, env) = infer_pat env pat in - subtyp l env (pat_typ_of inferred_pat) typ; - switch_typ inferred_pat typ, env + let (inferred_pat, env, guards) = infer_pat env pat in + match subtyp l env (pat_typ_of inferred_pat) typ with + | () -> switch_typ inferred_pat typ, env, guards + | exception (Type_error _ as typ_exn) -> + match pat_aux with + | P_lit lit -> + let guard = mk_exp (E_app_infix (mk_exp (E_id (mk_id "p#")), mk_id "==", mk_exp (E_lit lit))) in + let (typed_pat, env, guards) = bind_pat env (mk_pat (P_id (mk_id "p#"))) typ in + typed_pat, env, guard::guards + | _ -> raise typ_exn and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = let annot_pat pat typ = P_aux (pat, (l, Some (env, typ, no_effect))) in @@ -2450,31 +2453,32 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = typ_error l ("Cannot infer identifier in pattern " ^ string_of_pat pat ^ " - try adding a type annotation") | Local (Mutable, _) | Register _ -> typ_error l ("Cannot shadow mutable local or register in switch statement pattern " ^ string_of_pat pat) - | Enum enum -> annot_pat (P_id v) enum, env + | Enum enum -> annot_pat (P_id v) enum, env, [] end | P_typ (typ_annot, pat) -> Env.wf_typ env typ_annot; - let (typed_pat, env) = bind_pat env pat typ_annot in - annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env + let (typed_pat, env, guards) = bind_pat env pat typ_annot in + annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env, guards | P_lit lit -> - annot_pat (P_lit lit) (infer_lit env lit), env + annot_pat (P_lit lit) (infer_lit env lit), env, [] | P_vector (pat :: pats) -> - let fold_pats (pats, env) pat = - let typed_pat, env = bind_pat env pat bit_typ in - pats @ [typed_pat], env + let fold_pats (pats, env, guards) pat = + let typed_pat, env, guards' = bind_pat env pat bit_typ in + pats @ [typed_pat], env, guards' @ guards in - let ((typed_pat :: typed_pats) as pats), env = - List.fold_left fold_pats ([], env) (pat :: pats) in + let ((typed_pat :: typed_pats) as pats), env, guards = + List.fold_left fold_pats ([], env, []) (pat :: pats) in let len = nexp_simp (nint (List.length pats)) in let etyp = pat_typ_of typed_pat in List.map (fun pat -> typ_equality l env etyp (pat_typ_of pat)) pats; - annot_pat (P_vector pats) (lvector_typ env len etyp), env + annot_pat (P_vector pats) (lvector_typ env len etyp), env, guards | P_vector_concat (pat :: pats) -> - let fold_pats (pats, env) pat = - let inferred_pat, env = infer_pat env pat in - pats @ [inferred_pat], env + let fold_pats (pats, env, guards) pat = + let inferred_pat, env, guards' = infer_pat env pat in + pats @ [inferred_pat], env, guards' @ guards in - let (inferred_pat :: inferred_pats), env = List.fold_left fold_pats ([], env) (pat :: pats) in + let (inferred_pat :: inferred_pats), env, guards = + List.fold_left fold_pats ([], env, []) (pat :: pats) in let (_, len, _, vtyp) = destruct_vec_typ l env (pat_typ_of inferred_pat) in let fold_len len pat = let (_, len', _, vtyp') = destruct_vec_typ l env (pat_typ_of pat) in @@ -2482,10 +2486,12 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = nsum len len' in let len = nexp_simp (List.fold_left fold_len len inferred_pats) in - annot_pat (P_vector_concat (inferred_pat :: inferred_pats)) (lvector_typ env len vtyp), env + annot_pat (P_vector_concat (inferred_pat :: inferred_pats)) (lvector_typ env len vtyp), env, guards | P_as (pat, id) -> - let (typed_pat, env) = infer_pat env pat in - annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env + let (typed_pat, env, guards) = infer_pat env pat in + annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), + Env.add_local id (Immutable, pat_typ_of typed_pat) env, + guards | _ -> typ_error l ("Couldn't infer type of pattern " ^ string_of_pat pat) and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as exp) = @@ -2905,7 +2911,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = | _ -> let inferred_bind = irule infer_exp env bind in inferred_bind, typ_of inferred_bind in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in (* Propagate constraint assertions on the lhs of monadic binds to the rhs *) let env = match bind_exp with | E_aux (E_assert (E_aux (E_constraint nc, _), _), _) -> @@ -2923,7 +2929,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = | LB_val (pat, bind) -> let inferred_bind = irule infer_exp env bind in inferred_bind, pat, typ_of inferred_bind in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in let inferred_exp = irule infer_exp env exp in annot_exp (E_let (LB_aux (LB_val (tpat, bind_exp), (let_loc, None)), inferred_exp)) (typ_of inferred_exp) | _ -> typ_error l ("Cannot infer type of: " ^ string_of_exp exp) @@ -3368,11 +3374,11 @@ let check_letdef env (LB_aux (letbind, (l, _))) = match letbind with | LB_val (P_aux (P_typ (typ_annot, pat), _), bind) -> let checked_bind = crule check_exp env (strip_exp bind) typ_annot in - let tpat, env = bind_pat env (strip_pat pat) typ_annot in + let tpat, env = bind_pat_no_guard env (strip_pat pat) typ_annot in [DEF_val (LB_aux (LB_val (P_aux (P_typ (typ_annot, tpat), (l, Some (env, typ_annot, no_effect))), checked_bind), (l, None)))], env | LB_val (pat, bind) -> let inferred_bind = irule infer_exp env (strip_exp bind) in - let tpat, env = bind_pat env (strip_pat pat) (typ_of inferred_bind) in + let tpat, env = bind_pat_no_guard env (strip_pat pat) (typ_of inferred_bind) in [DEF_val (LB_aux (LB_val (tpat, inferred_bind), (l, None)))], env end diff --git a/src/type_check.mli b/src/type_check.mli index 5066553e..d531a2a8 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -206,7 +206,11 @@ val prove : Env.t -> n_constraint -> bool val subtype_check : Env.t -> typ -> typ -> bool -val bind_pat : Env.t -> unit pat -> typ -> tannot pat * Env.t +val bind_pat : Env.t -> unit pat -> typ -> tannot pat * Env.t * unit Ast.exp list +(* Variant that doesn't introduce new guards for literal patterns, but raises + a type error instead. This should always be safe to use on patterns that + have previously been type checked. *) +val bind_pat_no_guard : Env.t -> unit pat -> typ -> tannot pat * Env.t (* Partial functions: The expressions and patterns passed to these functions must be guaranteed to have tannots of the form Some (env, |
