diff options
| author | Alasdair Armstrong | 2018-03-20 17:57:20 +0000 |
|---|---|---|
| committer | Alasdair Armstrong | 2018-03-22 18:58:59 +0000 |
| commit | e33c8546e005fba30ff882b188c86ca03d0917c8 (patch) | |
| tree | cf72fc3066962718d26a76baedd2d11a7be16946 /src | |
| parent | 0860deb52e55b11e39e3470290e07f861f877483 (diff) | |
Fix C compilation for CHERI and MIPS
First, the specialisation of option types has been fixed by allowing
the specialisation of constructor return types - this essentially
means that a constructor, such as Some : 'a -> option('a) can get
specialised to int -> option(int), rather than int -> option('a). This
means that these constructors are treated like GADTs internally. Since
this only happens just before the C translation, I haven't put much
effort into making this very robust so far.
Second, there was a bug in C compilation for the typing of return
expressions in non-unit contexts, which has been fixed.
Finally support for vector literals that are non-bitvectors has been
added.
Diffstat (limited to 'src')
| -rw-r--r-- | src/c_backend.ml | 110 | ||||
| -rw-r--r-- | src/parser.mly | 2 | ||||
| -rw-r--r-- | src/specialize.ml | 51 | ||||
| -rw-r--r-- | src/type_check.ml | 7 |
4 files changed, 147 insertions, 23 deletions
diff --git a/src/c_backend.ml b/src/c_backend.ml index def81c75..398a0281 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -1072,7 +1072,9 @@ let iinit ?loc:(l=Parse_ast.Unknown) ctyp id cval = let iif ?loc:(l=Parse_ast.Unknown) cval then_instrs else_instrs ctyp = I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (instr_number (), l)) let ifuncall ?loc:(l=Parse_ast.Unknown) clexp id cvals ctyp = - I_aux (I_funcall (clexp, id, cvals, ctyp), (instr_number (), l)) + I_aux (I_funcall (clexp, false, id, cvals, ctyp), (instr_number (), l)) +let iextern ?loc:(l=Parse_ast.Unknown) clexp id cvals ctyp = + I_aux (I_funcall (clexp, true, id, cvals, ctyp), (instr_number (), l)) let icopy ?loc:(l=Parse_ast.Unknown) clexp cval = I_aux (I_copy (clexp, cval), (instr_number (), l)) let iconvert ?loc:(l=Parse_ast.Unknown) clexp ctyp1 id ctyp2 = @@ -1127,7 +1129,7 @@ let rec instr_ctyps (I_aux (instr, aux)) = | I_init (ctyp, _, cval) | I_reinit (ctyp, _, cval) -> [ctyp; cval_ctyp cval] | I_if (cval, instrs1, instrs2, ctyp) -> ctyp :: cval_ctyp cval :: List.concat (List.map instr_ctyps instrs1 @ List.map instr_ctyps instrs2) - | I_funcall (_, _, cvals, ctyp) -> + | I_funcall (_, _, _, cvals, ctyp) -> ctyp :: List.map cval_ctyp cvals | I_convert (_, ctyp1, _, ctyp2) -> [ctyp1; ctyp2] | I_copy (_, cval) -> [cval_ctyp cval] @@ -1207,7 +1209,7 @@ let rec pp_instr ?short:(short=false) (I_aux (instr, aux)) = pp_keyword "create" ^^ pp_id id ^^ string " : " ^^ pp_ctyp ctyp ^^ string " = " ^^ pp_cval cval | I_reinit (ctyp, id, cval) -> pp_keyword "recreate" ^^ pp_id id ^^ string " : " ^^ pp_ctyp ctyp ^^ string " = " ^^ pp_cval cval - | I_funcall (x, f, args, ctyp2) -> + | I_funcall (x, _, f, args, ctyp2) -> separate space [ pp_clexp x; string "="; string (string_of_id f |> Util.green |> Util.clear) ^^ parens (separate_map (string ", ") pp_cval args); string ":"; pp_ctyp ctyp2 ] @@ -1445,11 +1447,35 @@ let rec compile_aval ctx = function @ List.concat (List.mapi aval_mask (List.rev avals)), (F_id gs, ctyp), [] - (* - c_error ("Have AV_vector (small/bits): " ^ Pretty_print_sail.to_string (separate_map (comma ^^ space) pp_aval avals)) *) + + (* Compiling a vector literal that isn't a bitvector *) + | AV_vector (avals, Typ_aux (Typ_app (id, [_; Typ_arg_aux (Typ_arg_order ord, _); Typ_arg_aux (Typ_arg_typ typ, _)]), _)) + when string_of_id id = "vector" -> + let len = List.length avals in + let direction = match ord with + | Ord_aux (Ord_inc, _) -> false + | Ord_aux (Ord_dec, _) -> true + | Ord_aux (Ord_var _, _) -> c_error "Polymorphic vector direction found" + in + let vector_ctyp = CT_vector (direction, ctyp_of_typ ctx typ) in + let gs = gensym () in + let aval_set i aval = + let setup, cval, cleanup = compile_aval ctx aval in + setup + @ [iextern (CL_id gs) + (mk_id "internal_vector_update") + [(F_id gs, vector_ctyp); (F_lit (V_int (Big_int.of_int i)), CT_int64); cval] vector_ctyp] + @ cleanup + in + [idecl vector_ctyp gs; + ialloc vector_ctyp gs; + iextern (CL_id gs) (mk_id "internal_vector_init") [(F_lit (V_int (Big_int.of_int len)), CT_int64)] vector_ctyp] + @ List.concat (List.mapi aval_set avals), + (F_id gs, vector_ctyp), + [iclear vector_ctyp gs] | AV_vector _ as aval -> - c_error ("Have AV_vector: " ^ Pretty_print_sail.to_string (pp_aval aval)) + c_error ("Have AV_vector: " ^ Pretty_print_sail.to_string (pp_aval aval) ^ " which is not a vector type") | AV_list (avals, Typ_aux (typ, _)) -> let ctyp = match typ with @@ -1498,6 +1524,8 @@ let compile_funcall ctx id args typ = (Printf.sprintf "Failure when setting up function %s arguments: %s and %s." (string_of_id id) (string_of_ctyp have_ctyp) (string_of_ctyp ctyp)) in + assert (List.length arg_ctyps = List.length args); + let sargs = List.map2 setup_arg arg_ctyps args in let call = @@ -1520,7 +1548,16 @@ let rec compile_match ctx apat cval case_label = | AP_id pid, (frag, ctyp) when Env.is_union_constructor pid ctx.tc_env -> [ijump (F_op (F_field (frag, "kind"), "!=", F_lit (V_ctor_kind (string_of_id pid))), CT_bool) case_label], [] - | AP_global (pid, _), _ -> [icopy (CL_id pid) cval], [] + | AP_global (pid, typ), (frag, ctyp) -> + let global_ctyp = ctyp_of_typ ctx typ in + if ctyp_equal global_ctyp ctyp then + [icopy (CL_id pid) cval], [] + else + begin match frag with + | F_id id -> + [iconvert (CL_id pid) global_ctyp id ctyp], [] + | _ -> c_error "Cannot compile global letbinding" + end | AP_id pid, (frag, ctyp) when is_ct_enum ctyp -> begin match Env.lookup_id pid ctx.tc_env with | Unbound -> [idecl ctyp pid; icopy (CL_id pid) (frag, ctyp)], [] @@ -1591,6 +1628,7 @@ let rec compile_aexp ctx = function letb_setup @ setup, ctyp, call, cleanup @ letb_cleanup else begin + prerr_endline ("Mismatch: " ^ string_of_ctyp body_ctyp ^ " and " ^ string_of_ctyp ctyp); let gs = gensym () in letb_setup @ setup @ [idecl ctyp gs; ialloc ctyp gs; call (CL_id gs)], body_ctyp, @@ -1846,7 +1884,7 @@ let rec compile_aexp ctx = function (* Cleanup info will be re-added by fix_early_return *) let return_setup, cval, _ = compile_aval ctx aval in return_setup @ [ireturn cval], - cval_ctyp cval, + ctyp_of_typ ctx typ, (fun clexp -> icomment "unreachable after return"), [] @@ -1903,7 +1941,10 @@ let compile_type_def ctx (TD_aux (type_def, _)) = { ctx with records = Bindings.add id ctors ctx.records } | TD_variant (id, _, _, tus, _) -> - let compile_tu (Tu_aux (Tu_ty_id (typ, id), _)) = ctyp_of_typ ctx typ, id in + let compile_tu = function + | Tu_aux (Tu_ty_id (Typ_aux (Typ_fn (typ, _, _), _), id), _) -> ctyp_of_typ ctx typ, id + | Tu_aux (Tu_ty_id (typ, id), _) -> ctyp_of_typ ctx typ, id + in let ctus = List.fold_left (fun ctus (ctyp, id) -> Bindings.add id ctyp ctus) Bindings.empty (List.map compile_tu tus) in CTD_variant (id, Bindings.bindings ctus), { ctx with variants = Bindings.add id ctus ctx.variants } @@ -2046,7 +2087,7 @@ let fix_exception_block ctx instrs = @ generate_cleanup (historic @ before) @ [igoto end_block_label] @ rewrite_exception (historic @ before) after - | before, (I_aux (I_funcall (x, f, args, ctyp), _) as funcall) :: after -> + | before, (I_aux (I_funcall (x, _, f, args, ctyp), _) as funcall) :: after -> let effects = match Env.get_val_spec f ctx.tc_env with | _, Typ_aux (Typ_fn (_, _, effects), _) -> effects | exception (Type_error _) -> no_effect (* nullary union constructor, so no val spec *) @@ -2080,8 +2121,9 @@ let fix_exception ctx instrs = let instrs = List.map (map_try_block (fix_exception_block ctx)) instrs in fix_exception_block ctx instrs -let arg_pats ctx (Typ_aux (arg_typ_aux, _) as arg_typ) (P_aux (p_aux, (l, _)) as pat) = +let rec arg_pats ctx (Typ_aux (arg_typ_aux, _) as arg_typ) (P_aux (p_aux, (l, _)) as pat) = match p_aux, arg_typ_aux with + | P_typ (_, pat), _ -> arg_pats ctx arg_typ pat | P_tup pats, Typ_tup arg_typs when List.length pats = List.length arg_typs -> List.map2 (fun pat arg_typ -> pat, ctyp_of_typ ctx arg_typ) pats arg_typs | P_wild, Typ_tup arg_typs -> List.map (fun arg_typ -> pat, ctyp_of_typ ctx arg_typ) arg_typs @@ -2302,7 +2344,7 @@ let instr_deps = function | I_init (ctyp, id, cval) | I_reinit (ctyp, id, cval) -> cval_deps cval, NS.singleton (G_id id) | I_if (cval, _, _, _) -> cval_deps cval, NS.empty | I_jump (cval, label) -> cval_deps cval, NS.singleton (G_label label) - | I_funcall (clexp, _, cvals, _) -> List.fold_left NS.union NS.empty (List.map cval_deps cvals), clexp_deps clexp + | I_funcall (clexp, _, _, cvals, _) -> List.fold_left NS.union NS.empty (List.map cval_deps cvals), clexp_deps clexp | I_convert (clexp, _, id, _) -> NS.singleton (G_id id), clexp_deps clexp | I_copy (clexp, cval) -> cval_deps cval, clexp_deps clexp | I_clear (_, id) -> NS.singleton (G_id id), NS.singleton (G_id id) @@ -2439,8 +2481,8 @@ let rec instrs_rename from_id to_id = | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: instrs -> I_aux (I_if (crename cval, irename then_instrs, irename else_instrs, ctyp), aux) :: irename instrs | I_aux (I_jump (cval, label), aux) :: instrs -> I_aux (I_jump (crename cval, label), aux) :: irename instrs - | I_aux (I_funcall (clexp, id, cvals, ctyp), aux) :: instrs -> - I_aux (I_funcall (lrename clexp, rename id, List.map crename cvals, ctyp), aux) :: irename instrs + | I_aux (I_funcall (clexp, extern, id, cvals, ctyp), aux) :: instrs -> + I_aux (I_funcall (lrename clexp, extern, rename id, List.map crename cvals, ctyp), aux) :: irename instrs | I_aux (I_copy (clexp, cval), aux) :: instrs -> I_aux (I_copy (lrename clexp, crename cval), aux) :: irename instrs | I_aux (I_convert (clexp, ctyp1, id, ctyp2), aux) :: instrs -> I_aux (I_convert (lrename clexp, ctyp1, rename id, ctyp2), aux) :: irename instrs @@ -2579,7 +2621,7 @@ let upper_codegen_id id = string (upper_sgen_id id) let rec sgen_ctyp = function | CT_unit -> "unit" - | CT_bit -> "int" + | CT_bit -> "uint64_t" | CT_bool -> "bool" | CT_uint64 _ -> "uint64_t" | CT_int64 -> "int64_t" @@ -2597,7 +2639,7 @@ let rec sgen_ctyp = function let rec sgen_ctyp_name = function | CT_unit -> "unit" - | CT_bit -> "int" + | CT_bit -> "uint64_t" | CT_bool -> "bool" | CT_uint64 _ -> "uint64_t" | CT_int64 -> "int64_t" @@ -2669,9 +2711,16 @@ let rec codegen_instr fid ctx (I_aux (instr, _)) = string " { /* try */" ^^ jump 2 2 (separate_map hardline (codegen_instr fid ctx) instrs) ^^ hardline ^^ string " }" - | I_funcall (x, f, args, ctyp) -> + | I_funcall (x, extern, f, args, ctyp) -> let c_args = Util.string_of_list ", " sgen_cval args in - let fname = if Env.is_extern f ctx.tc_env "c" then Env.get_extern f ctx.tc_env "c" else sgen_id f in + let fname = + if Env.is_extern f ctx.tc_env "c" then + Env.get_extern f ctx.tc_env "c" + else if extern then + string_of_id f + else + sgen_id f + in let fname = match fname, ctyp with | "internal_pick", _ -> Printf.sprintf "pick_%s" (sgen_ctyp_name ctyp) @@ -2696,6 +2745,8 @@ let rec codegen_instr fid ctx (I_aux (instr, _)) = | "vector_update", CT_uint64 _ -> "update_uint64_t" | "vector_update", CT_bv _ -> "update_bv" | "vector_update", _ -> Printf.sprintf "vector_update_%s" (sgen_ctyp_name ctyp) + | "internal_vector_update", _ -> Printf.sprintf "internal_vector_update_%s" (sgen_ctyp_name ctyp) + | "internal_vector_init", _ -> Printf.sprintf "internal_vector_init_%s" (sgen_ctyp_name ctyp) | "undefined_vector", CT_uint64 _ -> "undefined_uint64_t" | "undefined_vector", CT_bv _ -> "undefined_bv_t" | "undefined_vector", _ -> Printf.sprintf "undefined_vector_%s" (sgen_ctyp_name ctyp) @@ -3059,6 +3110,7 @@ let codegen_list ctx ctyp = ^^ codegen_pick id ctyp ^^ twice hardline end +(* Generate functions for working with non-bit vectors of some specific type. *) let codegen_vector ctx (direction, ctyp) = let id = mk_id (string_of_ctyp (CT_vector (direction, ctyp))) in if IdSet.mem id !generated then @@ -3111,6 +3163,14 @@ let codegen_vector ctx (direction, ctyp) = ^^ string " }\n" ^^ string "}" in + let internal_vector_update = + string (Printf.sprintf "void internal_vector_update_%s(%s *rop, %s op, const int64_t n, const %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) + ^^ string (if is_stack_ctyp ctyp then + " rop->data[n] = elem;\n" + else + Printf.sprintf " set_%s((rop->data) + n, elem);\n" (sgen_ctyp_name ctyp)) + ^^ string "}" + in let vector_access = if is_stack_ctyp ctyp then string (Printf.sprintf "%s vector_access_%s(%s op, mpz_t n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id)) @@ -3123,6 +3183,12 @@ let codegen_vector ctx (direction, ctyp) = ^^ string (Printf.sprintf " set_%s(rop, op.data[m]);\n" (sgen_ctyp_name ctyp)) ^^ string "}" in + let internal_vector_init = + string (Printf.sprintf "void internal_vector_init_%s(%s *rop, const int64_t len) {\n" (sgen_id id) (sgen_id id)) + ^^ string " rop->len = len;\n" + ^^ string (Printf.sprintf " rop->data = malloc(len * sizeof(%s));\n" (sgen_ctyp ctyp)) + ^^ string "}" + in let vector_undefined = string (Printf.sprintf "void undefined_vector_%s(%s *rop, mpz_t len, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) ^^ string (Printf.sprintf " rop->len = mpz_get_ui(len);\n") @@ -3144,6 +3210,8 @@ let codegen_vector ctx (direction, ctyp) = ^^ vector_access ^^ twice hardline ^^ vector_set ^^ twice hardline ^^ vector_update ^^ twice hardline + ^^ internal_vector_update ^^ twice hardline + ^^ internal_vector_init ^^ twice hardline end let is_decl = function @@ -3178,6 +3246,12 @@ let codegen_def' ctx = function | _ -> assert false in let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx) arg_typs, ctyp_of_typ ctx ret_typ in + if (List.length arg_ctyps <> List.length args) then + c_error ~loc:(id_loc id) ("function arguments " + ^ Util.string_of_list ", " string_of_id args + ^ " matched against type " + ^ Util.string_of_list ", " string_of_ctyp arg_ctyps) + else (); let args = Util.string_of_list ", " (fun x -> x) (List.map2 (fun ctyp arg -> sgen_ctyp ctyp ^ " " ^ sgen_id arg) arg_ctyps args) in let function_header = match ret_arg with diff --git a/src/parser.mly b/src/parser.mly index 40bece90..97b6f28c 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -1164,6 +1164,8 @@ struct_fields: type_union: | id Colon typ { Tu_aux (Tu_ty_id ($3, $1), loc $startpos $endpos) } + | id Colon typ MinusGt typ + { (fun s e -> Tu_aux (Tu_ty_id (mk_typ (ATyp_fn ($3, $5, mk_typ (ATyp_set []) s e)) s e, $1), loc s e)) $startpos $endpos } type_unions: | type_union diff --git a/src/specialize.ml b/src/specialize.ml index 1b978fbc..151a185a 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -57,11 +57,34 @@ let is_typ_ord_uvar = function | Type_check.U_order _ -> true | _ -> false +let rec nexp_simp_typ (Typ_aux (typ_aux, l)) = + let typ_aux = match typ_aux with + | Typ_id v -> Typ_id v + | Typ_var kid -> Typ_var kid + | Typ_tup typs -> Typ_tup (List.map nexp_simp_typ typs) + | Typ_app (f, args) -> Typ_app (f, List.map nexp_simp_typ_arg args) + | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, nexp_simp_typ typ) + | Typ_fn (typ1, typ2, effect) -> Typ_fn (nexp_simp_typ typ1, nexp_simp_typ typ2, effect) + in + Typ_aux (typ_aux, l) +and nexp_simp_typ_arg (Typ_arg_aux (typ_arg_aux, l)) = + let typ_arg_aux = match typ_arg_aux with + | Typ_arg_nexp n -> Typ_arg_nexp (nexp_simp n) + | Typ_arg_typ typ -> Typ_arg_typ (nexp_simp_typ typ) + | Typ_arg_order ord -> Typ_arg_order ord + in + Typ_arg_aux (typ_arg_aux, l) + +let nexp_simp_uvar = function + | Type_check.U_nexp nexp -> (prerr_endline ("Simp nexp " ^ string_of_nexp nexp); Type_check.U_nexp (nexp_simp nexp)) + | Type_check.U_typ typ -> Type_check.U_typ (nexp_simp_typ typ) + | uvar -> uvar + (* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of. This part of the typechecker API is a bit ugly. *) let fix_instantiation instantiation = let instantiation = KBindings.bindings (KBindings.filter (fun _ uvar -> is_typ_ord_uvar uvar) instantiation) in - let instantiation = List.map (fun (kid, uvar) -> Type_check.orig_kid kid, uvar) instantiation in + let instantiation = List.map (fun (kid, uvar) -> Type_check.orig_kid kid, nexp_simp_uvar uvar) instantiation in List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation let rec polymorphic_functions is_kopt (Defs defs) = @@ -440,12 +463,28 @@ let rewrite_polymorphic_constructors id ast = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp); rewrite_pat = (fun _ -> fold_pat rewrite_pat)} ast +let kinded_id_arg kind_id = + let typ_arg arg = Typ_arg_aux (arg, Parse_ast.Unknown) in + match kind_id with + | KOpt_aux (KOpt_none kid, _) -> typ_arg (Typ_arg_nexp (nvar kid)) + | KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_nat, _)], _), kid), _) -> typ_arg (Typ_arg_nexp (nvar kid)) + | KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_order, _)], _), kid), _) -> + typ_arg (Typ_arg_order (Ord_aux (Ord_var kid, Parse_ast.Unknown))) + | KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_type, _)], _), kid), _) -> + typ_arg (Typ_arg_typ (mk_typ (Typ_var kid))) + | KOpt_aux (KOpt_kind (K_aux (K_kind kinds, _), kid), l) -> assert false + +let fold_union_quant quants (QI_aux (qi, l)) = + match qi with + | QI_id kind_id -> quants @ [kinded_id_arg kind_id] + | _ -> quants + let specialize_variants ((Defs defs) as ast) env = let ctors = ref [] in let specialize_variant (TD_aux (tdef_aux, annot)) ast env = match tdef_aux with - | TD_variant (id, name_scheme, typq, tus, flag) as variant -> + | TD_variant (v_id, name_scheme, typq, tus, flag) as variant -> let kopts = List.filter (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) (quant_kopts typq) in if kopts = [] then (* If non-polymorphic, then do nothing. *) @@ -455,14 +494,18 @@ let specialize_variants ((Defs defs) as ast) env = ctors := id :: !ctors; let is = instantiations_of id ast in let is = List.sort_uniq (fun i1 i2 -> String.compare (string_of_instantiation i1) (string_of_instantiation i2)) is in - List.map (fun i -> Tu_aux (Tu_ty_id (Type_check.subst_unifiers i typ, id_of_instantiation id i), annot)) is + List.map (fun i -> + let i = fix_instantiation i in + let ret_typ = app_typ v_id (List.fold_left fold_union_quant [] (quant_items typq)) in + let ret_typ = Type_check.subst_unifiers i ret_typ in + Tu_aux (Tu_ty_id (Type_check.subst_unifiers i (mk_typ (Typ_fn (typ, ret_typ, no_effect))), id_of_instantiation id i), annot)) is in (* let kopts, constraints = quant_split typq in let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in let typq = mk_typquant (List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints) in *) - TD_aux (TD_variant (id, name_scheme, typq, List.concat (List.map specialize_tu tus), flag), annot) + TD_aux (TD_variant (v_id, name_scheme, typq, List.concat (List.map specialize_tu tus), flag), annot) | _ -> assert false in diff --git a/src/type_check.ml b/src/type_check.ml index 74e7cc80..689a7338 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -1275,7 +1275,7 @@ and typ_arg_nexps (Typ_arg_aux (typ_arg_aux, l)) = | Typ_arg_nexp n -> [n] | Typ_arg_typ typ -> typ_nexps typ | Typ_arg_order ord -> [] - + let rec typ_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) = match typ_aux with | Typ_id v -> KidSet.empty @@ -3587,6 +3587,11 @@ let fold_union_quant quants (QI_aux (qi, l)) = let check_type_union env variant typq (Tu_aux (tu, l)) = let ret_typ = app_typ variant (List.fold_left fold_union_quant [] (quant_items typq)) in match tu with + | Tu_ty_id (Typ_aux (Typ_fn (arg_typ, ret_typ, _), _) as typ, v) -> + let typq = mk_typquant (List.map (mk_qi_id BK_type) (KidSet.elements (typ_frees typ))) in + env + |> Env.add_union_id v (typq, typ) + |> Env.add_val_spec v (typq, typ) | Tu_ty_id (typ, v) -> let typ' = mk_typ (Typ_fn (typ, ret_typ, no_effect)) in env |
