diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/rewrites.ml | 128 | ||||
| -rw-r--r-- | src/type_check.ml | 28 |
2 files changed, 82 insertions, 74 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index d8cb5a5d..bc9792ef 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2043,60 +2043,74 @@ let rewrite_split_fun_constr_pats fun_name env (Defs defs) = let clauses, aux_funs = List.fold_left (fun (clauses, aux_funs) (FCL_aux (FCL_Funcl (id, pexp), fannot) as clause) -> - let pat, guard, exp, annot = destruct_pexp pexp in - match pat with - | P_aux (P_app (constr_id, args), pannot) -> - let argstup_typ = tuple_typ (List.map typ_of_pat args) in - let pannot' = swaptyp argstup_typ pannot in - let pat' = - match args with - | [arg] -> arg - | _ -> P_aux (P_tup args, pannot') - in - let pexp' = construct_pexp (pat', guard, exp, annot) in - let aux_fun_id = prepend_id (fun_name ^ "_") constr_id in - let aux_funcl = FCL_aux (FCL_Funcl (aux_fun_id, pexp'), pannot') in - begin - try - let aux_clauses = Bindings.find aux_fun_id aux_funs in - clauses, - Bindings.add aux_fun_id (aux_clauses @ [aux_funcl]) aux_funs - with Not_found -> - let argpats, argexps = List.split (List.mapi - (fun idx (P_aux (_,a) as pat) -> - let id = match pat_var pat with - | Some id -> id - | None -> mk_id ("arg" ^ string_of_int idx) - in - P_aux (P_id id, a), E_aux (E_id id, a)) - args) - in - let pexp = construct_pexp - (P_aux (P_app (constr_id, argpats), pannot), - None, - E_aux (E_app (aux_fun_id, argexps), annot), - annot) - in - clauses @ [FCL_aux (FCL_Funcl (id, pexp), fannot)], - Bindings.add aux_fun_id [aux_funcl] aux_funs - end - | _ -> clauses @ [clause], aux_funs) + let pat, guard, exp, annot = destruct_pexp pexp in + match pat with + | P_aux (P_app (constr_id, args), pannot) -> + let ctor_typq, ctor_typ = Env.get_union_id constr_id env in + let args = match args with [P_aux (P_tup args, _)] -> args | _ -> args in + let argstup_typ = tuple_typ (List.map typ_of_pat args) in + let pannot' = swaptyp argstup_typ pannot in + let pat' = + match args with + | [arg] -> arg + | _ -> P_aux (P_tup args, pannot') + in + let pexp' = construct_pexp (pat', guard, exp, annot) in + let aux_fun_id = prepend_id (fun_name ^ "_") constr_id in + let aux_funcl = FCL_aux (FCL_Funcl (aux_fun_id, pexp'), pannot') in + begin + try + let aux_clauses = Bindings.find aux_fun_id aux_funs in + clauses, + Bindings.add aux_fun_id (aux_clauses @ [(aux_funcl, ctor_typq, ctor_typ)]) aux_funs + with Not_found -> + let argpats, argexps = List.split (List.mapi + (fun idx (P_aux (_,a) as pat) -> + let id = match pat_var pat with + | Some id -> id + | None -> mk_id ("arg" ^ string_of_int idx) + in + P_aux (P_id id, a), E_aux (E_id id, a)) + args) + in + let pexp = construct_pexp + (P_aux (P_app (constr_id, argpats), pannot), + None, + E_aux (E_app (aux_fun_id, argexps), annot), + annot) + in + clauses @ [FCL_aux (FCL_Funcl (id, pexp), fannot)], + Bindings.add aux_fun_id [(aux_funcl, ctor_typq, ctor_typ)] aux_funs + end + | _ -> clauses @ [clause], aux_funs) ([], Bindings.empty) clauses in - let add_aux_def id funcls defs = - let env, args_typ, ret_typ = match funcls with - | FCL_aux (FCL_Funcl (_, pexp), _) :: _ -> + let add_aux_def id aux_funs defs = + let funcls = List.map (fun (fcl, _, _) -> fcl) aux_funs in + let env, quants, args_typ, ret_typ = match aux_funs with + | (FCL_aux (FCL_Funcl (_, pexp), _), ctor_typq, ctor_typ) :: _ -> let pat, _, exp, _ = destruct_pexp pexp in - env_of exp, typ_of_pat pat, typ_of exp + let ctor_quants args_typ = + List.filter (fun qi -> KOptSet.subset (kopts_of_quant_item qi) (kopts_of_typ args_typ)) + (quant_items ctor_typq) + in + begin match ctor_typ with + | Typ_aux (Typ_fn ([Typ_aux (Typ_exist (kopts, nc, args_typ), _)], _, _), _) -> + env_of exp, ctor_quants args_typ @ List.map mk_qi_kopt kopts @ [mk_qi_nc nc], args_typ, typ_of exp + | Typ_aux (Typ_fn ([args_typ], _, _), _) -> env_of exp, ctor_quants args_typ, args_typ, typ_of exp + | _ -> + raise (Reporting.err_unreachable l __POS__ + ("Union constructor has non-function type: " ^ string_of_typ ctor_typ)) + end | _ -> raise (Reporting.err_unreachable l __POS__ - "rewrite_split_fun_constr_pats: empty auxiliary function") + "rewrite_split_fun_constr_pats: empty auxiliary function") in let eff = List.fold_left - (fun eff (FCL_aux (FCL_Funcl (_, pexp), _)) -> - let _, _, exp, _ = destruct_pexp pexp in - union_effects eff (effect_of exp)) - no_effect funcls + (fun eff (FCL_aux (FCL_Funcl (_, pexp), _)) -> + let _, _, exp, _ = destruct_pexp pexp in + union_effects eff (effect_of exp)) + no_effect funcls in let fun_typ = (* Because we got the argument type from a pattern we need to @@ -2107,27 +2121,9 @@ let rewrite_split_fun_constr_pats fun_name env (Defs defs) = | _ -> function_typ [args_typ] ret_typ eff in - let quant_new_kopts qis = - let quant_kopts = List.fold_left KOptSet.union KOptSet.empty (List.map kopts_of_quant_item qis) in - let typ_kopts = kopts_of_typ fun_typ in - let new_kopts = KOptSet.diff typ_kopts quant_kopts in - List.map mk_qi_kopt (KOptSet.elements new_kopts) - in - let typquant = match typquant with - | TypQ_aux (TypQ_tq qis, l) -> - let qis = - List.filter - (fun qi -> KOptSet.subset (kopts_of_quant_item qi) (kopts_of_typ fun_typ)) - qis - @ quant_new_kopts qis - in - TypQ_aux (TypQ_tq qis, l) - | _ -> - TypQ_aux (TypQ_tq (List.map mk_qi_kopt (KOptSet.elements (kopts_of_typ fun_typ))), l) - in let val_spec = VS_aux (VS_val_spec - (mk_typschm typquant fun_typ, id, (fun _ -> None), false), + (mk_typschm (mk_typquant quants) fun_typ, id, (fun _ -> None), false), (Parse_ast.Unknown, empty_tannot)) in let fundef = FD_aux (FD_function (r_o, t_o, e_o, funcls), fdannot) in diff --git a/src/type_check.ml b/src/type_check.ml index c1689a82..603052b5 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -432,6 +432,7 @@ module Env : sig val get_typ_var_loc : kid -> t -> Ast.l val get_typ_vars : t -> kind_aux KBindings.t val get_typ_var_locs : t -> Ast.l KBindings.t + val add_typ_var_shadow : l -> kinded_id -> t -> t * kid option val add_typ_var : l -> kinded_id -> t -> t val get_ret_typ : t -> typ option val add_ret_typ : typ -> t -> t @@ -656,10 +657,9 @@ end = struct ^ " with " ^ Util.string_of_list ", " string_of_n_constraint env.constraints) let get_typ_synonym id env = - begin match Bindings.find_opt id env.typ_synonyms with + match Bindings.find_opt id env.typ_synonyms with | Some (typq, arg) -> mk_synonym typq arg | None -> raise Not_found - end let rec expand_constraint_synonyms env (NC_aux (aux, l) as nc) = typ_debug ~level:2 (lazy ("Expanding " ^ string_of_n_constraint nc)); @@ -1208,7 +1208,7 @@ end = struct with | Not_found -> Unbound - let add_typ_var l (KOpt_aux (KOpt_kind (K_aux (k, _), v), _)) env = + let add_typ_var_shadow l (KOpt_aux (KOpt_kind (K_aux (k, _), v), _)) env = if KBindings.mem v env.typ_vars then begin let n = match KBindings.find_opt v env.shadow_vars with Some n -> n | None -> 0 in let s_l, s_k = KBindings.find v env.typ_vars in @@ -1218,13 +1218,15 @@ end = struct constraints = List.map (constraint_subst v (arg_kopt (mk_kopt s_k s_v))) env.constraints; typ_vars = KBindings.add v (l, k) (KBindings.add s_v (s_l, s_k) env.typ_vars); shadow_vars = KBindings.add v (n + 1) env.shadow_vars - } + }, Some s_v end else begin typ_print (lazy (adding ^ "type variable " ^ string_of_kid v ^ " : " ^ string_of_kind_aux k)); - { env with typ_vars = KBindings.add v (l, k) env.typ_vars } + { env with typ_vars = KBindings.add v (l, k) env.typ_vars }, None end + let add_typ_var l kopt env = fst (add_typ_var_shadow l kopt env) + let get_constraints env = env.constraints let add_constraint constr env = @@ -3133,6 +3135,8 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) end | P_app (f, pats) when Env.is_union_constructor f env -> begin + (* Treat Ctor((p, x)) the same as Ctor(p, x) *) + let pats = match pats with [P_aux (P_tup pats, _)] -> pats | _ -> pats in let (typq, ctor_typ) = Env.get_union_id f env in let quants = quant_items typq in let untuple (Typ_aux (typ_aux, _) as typ) = match typ_aux with @@ -3152,6 +3156,7 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) typ_raise env l (Err_unresolved_quants (f, quants', Env.get_locals env, Env.get_constraints env)) else (); let ret_typ' = subst_unifiers unifiers ret_typ in + let arg_typ', env = bind_existential l None arg_typ' env in let tpats, env, guards = try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with | Invalid_argument _ -> typ_error env l "Union constructor pattern arguments have incorrect length" @@ -3325,15 +3330,20 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = | _ -> typ_error env l ("Couldn't infer type of pattern " ^ string_of_pat pat) and bind_typ_pat env (TP_aux (typ_pat_aux, l) as typ_pat) (Typ_aux (typ_aux, _) as typ) = + typ_print (lazy (Util.("Binding type pattern " |> yellow |> clear) ^ string_of_typ_pat typ_pat ^ " to " ^ string_of_typ typ)); match typ_pat_aux, typ_aux with | TP_wild, _ -> env | TP_var kid, _ -> begin match typ_nexps typ, typ_constraints typ with | [nexp], [] -> - Env.add_constraint (nc_eq (nvar kid) nexp) (Env.add_typ_var l (mk_kopt K_int kid) env) + let env, shadow = Env.add_typ_var_shadow l (mk_kopt K_int kid) env in + let nexp = match shadow with Some s_v -> nexp_subst kid (arg_nexp (nvar s_v)) nexp | None -> nexp in + Env.add_constraint (nc_eq (nvar kid) nexp) env | [], [nc] -> - Env.add_constraint (nc_and (nc_or (nc_not nc) (nc_var kid)) (nc_or nc (nc_not (nc_var kid)))) (Env.add_typ_var l (mk_kopt K_bool kid) env) + let env, shadow = Env.add_typ_var_shadow l (mk_kopt K_bool kid) env in + let nexp = match shadow with Some s_v -> constraint_subst kid (arg_bool (nc_var s_v)) nc | None -> nc in + Env.add_constraint (nc_and (nc_or (nc_not nc) (nc_var kid)) (nc_or nc (nc_not (nc_var kid)))) env | [], [] -> typ_error env l ("No numeric expressions in " ^ string_of_typ typ ^ " to bind " ^ string_of_kid kid ^ " to") | _, _ -> @@ -3346,7 +3356,9 @@ and bind_typ_pat_arg env (TP_aux (typ_pat_aux, l) as typ_pat) (A_aux (typ_arg_au match typ_pat_aux, typ_arg_aux with | TP_wild, _ -> env | TP_var kid, A_nexp nexp -> - Env.add_constraint (nc_eq (nvar kid) nexp) (Env.add_typ_var l (mk_kopt K_int kid) env) + let env, shadow = Env.add_typ_var_shadow l (mk_kopt K_int kid) env in + let nexp = match shadow with Some s_v -> nexp_subst kid (arg_nexp (nvar s_v)) nexp | None -> nexp in + Env.add_constraint (nc_eq (nvar kid) nexp) env | _, A_typ typ -> bind_typ_pat env typ_pat typ | _, A_order _ -> typ_error env l "Cannot bind type pattern against order" | _, _ -> typ_error env l ("Couldn't bind type argument " ^ string_of_typ_arg typ_arg ^ " with " ^ string_of_typ_pat typ_pat) |
