summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlasdair Armstrong2018-03-20 17:57:20 +0000
committerAlasdair Armstrong2018-03-22 18:58:59 +0000
commite33c8546e005fba30ff882b188c86ca03d0917c8 (patch)
treecf72fc3066962718d26a76baedd2d11a7be16946 /src
parent0860deb52e55b11e39e3470290e07f861f877483 (diff)
Fix C compilation for CHERI and MIPS
First, the specialisation of option types has been fixed by allowing the specialisation of constructor return types - this essentially means that a constructor, such as Some : 'a -> option('a) can get specialised to int -> option(int), rather than int -> option('a). This means that these constructors are treated like GADTs internally. Since this only happens just before the C translation, I haven't put much effort into making this very robust so far. Second, there was a bug in C compilation for the typing of return expressions in non-unit contexts, which has been fixed. Finally support for vector literals that are non-bitvectors has been added.
Diffstat (limited to 'src')
-rw-r--r--src/c_backend.ml110
-rw-r--r--src/parser.mly2
-rw-r--r--src/specialize.ml51
-rw-r--r--src/type_check.ml7
4 files changed, 147 insertions, 23 deletions
diff --git a/src/c_backend.ml b/src/c_backend.ml
index def81c75..398a0281 100644
--- a/src/c_backend.ml
+++ b/src/c_backend.ml
@@ -1072,7 +1072,9 @@ let iinit ?loc:(l=Parse_ast.Unknown) ctyp id cval =
let iif ?loc:(l=Parse_ast.Unknown) cval then_instrs else_instrs ctyp =
I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (instr_number (), l))
let ifuncall ?loc:(l=Parse_ast.Unknown) clexp id cvals ctyp =
- I_aux (I_funcall (clexp, id, cvals, ctyp), (instr_number (), l))
+ I_aux (I_funcall (clexp, false, id, cvals, ctyp), (instr_number (), l))
+let iextern ?loc:(l=Parse_ast.Unknown) clexp id cvals ctyp =
+ I_aux (I_funcall (clexp, true, id, cvals, ctyp), (instr_number (), l))
let icopy ?loc:(l=Parse_ast.Unknown) clexp cval =
I_aux (I_copy (clexp, cval), (instr_number (), l))
let iconvert ?loc:(l=Parse_ast.Unknown) clexp ctyp1 id ctyp2 =
@@ -1127,7 +1129,7 @@ let rec instr_ctyps (I_aux (instr, aux)) =
| I_init (ctyp, _, cval) | I_reinit (ctyp, _, cval) -> [ctyp; cval_ctyp cval]
| I_if (cval, instrs1, instrs2, ctyp) ->
ctyp :: cval_ctyp cval :: List.concat (List.map instr_ctyps instrs1 @ List.map instr_ctyps instrs2)
- | I_funcall (_, _, cvals, ctyp) ->
+ | I_funcall (_, _, _, cvals, ctyp) ->
ctyp :: List.map cval_ctyp cvals
| I_convert (_, ctyp1, _, ctyp2) -> [ctyp1; ctyp2]
| I_copy (_, cval) -> [cval_ctyp cval]
@@ -1207,7 +1209,7 @@ let rec pp_instr ?short:(short=false) (I_aux (instr, aux)) =
pp_keyword "create" ^^ pp_id id ^^ string " : " ^^ pp_ctyp ctyp ^^ string " = " ^^ pp_cval cval
| I_reinit (ctyp, id, cval) ->
pp_keyword "recreate" ^^ pp_id id ^^ string " : " ^^ pp_ctyp ctyp ^^ string " = " ^^ pp_cval cval
- | I_funcall (x, f, args, ctyp2) ->
+ | I_funcall (x, _, f, args, ctyp2) ->
separate space [ pp_clexp x; string "=";
string (string_of_id f |> Util.green |> Util.clear) ^^ parens (separate_map (string ", ") pp_cval args);
string ":"; pp_ctyp ctyp2 ]
@@ -1445,11 +1447,35 @@ let rec compile_aval ctx = function
@ List.concat (List.mapi aval_mask (List.rev avals)),
(F_id gs, ctyp),
[]
- (*
- c_error ("Have AV_vector (small/bits): " ^ Pretty_print_sail.to_string (separate_map (comma ^^ space) pp_aval avals)) *)
+
+ (* Compiling a vector literal that isn't a bitvector *)
+ | AV_vector (avals, Typ_aux (Typ_app (id, [_; Typ_arg_aux (Typ_arg_order ord, _); Typ_arg_aux (Typ_arg_typ typ, _)]), _))
+ when string_of_id id = "vector" ->
+ let len = List.length avals in
+ let direction = match ord with
+ | Ord_aux (Ord_inc, _) -> false
+ | Ord_aux (Ord_dec, _) -> true
+ | Ord_aux (Ord_var _, _) -> c_error "Polymorphic vector direction found"
+ in
+ let vector_ctyp = CT_vector (direction, ctyp_of_typ ctx typ) in
+ let gs = gensym () in
+ let aval_set i aval =
+ let setup, cval, cleanup = compile_aval ctx aval in
+ setup
+ @ [iextern (CL_id gs)
+ (mk_id "internal_vector_update")
+ [(F_id gs, vector_ctyp); (F_lit (V_int (Big_int.of_int i)), CT_int64); cval] vector_ctyp]
+ @ cleanup
+ in
+ [idecl vector_ctyp gs;
+ ialloc vector_ctyp gs;
+ iextern (CL_id gs) (mk_id "internal_vector_init") [(F_lit (V_int (Big_int.of_int len)), CT_int64)] vector_ctyp]
+ @ List.concat (List.mapi aval_set avals),
+ (F_id gs, vector_ctyp),
+ [iclear vector_ctyp gs]
| AV_vector _ as aval ->
- c_error ("Have AV_vector: " ^ Pretty_print_sail.to_string (pp_aval aval))
+ c_error ("Have AV_vector: " ^ Pretty_print_sail.to_string (pp_aval aval) ^ " which is not a vector type")
| AV_list (avals, Typ_aux (typ, _)) ->
let ctyp = match typ with
@@ -1498,6 +1524,8 @@ let compile_funcall ctx id args typ =
(Printf.sprintf "Failure when setting up function %s arguments: %s and %s." (string_of_id id) (string_of_ctyp have_ctyp) (string_of_ctyp ctyp))
in
+ assert (List.length arg_ctyps = List.length args);
+
let sargs = List.map2 setup_arg arg_ctyps args in
let call =
@@ -1520,7 +1548,16 @@ let rec compile_match ctx apat cval case_label =
| AP_id pid, (frag, ctyp) when Env.is_union_constructor pid ctx.tc_env ->
[ijump (F_op (F_field (frag, "kind"), "!=", F_lit (V_ctor_kind (string_of_id pid))), CT_bool) case_label],
[]
- | AP_global (pid, _), _ -> [icopy (CL_id pid) cval], []
+ | AP_global (pid, typ), (frag, ctyp) ->
+ let global_ctyp = ctyp_of_typ ctx typ in
+ if ctyp_equal global_ctyp ctyp then
+ [icopy (CL_id pid) cval], []
+ else
+ begin match frag with
+ | F_id id ->
+ [iconvert (CL_id pid) global_ctyp id ctyp], []
+ | _ -> c_error "Cannot compile global letbinding"
+ end
| AP_id pid, (frag, ctyp) when is_ct_enum ctyp ->
begin match Env.lookup_id pid ctx.tc_env with
| Unbound -> [idecl ctyp pid; icopy (CL_id pid) (frag, ctyp)], []
@@ -1591,6 +1628,7 @@ let rec compile_aexp ctx = function
letb_setup @ setup, ctyp, call, cleanup @ letb_cleanup
else
begin
+ prerr_endline ("Mismatch: " ^ string_of_ctyp body_ctyp ^ " and " ^ string_of_ctyp ctyp);
let gs = gensym () in
letb_setup @ setup @ [idecl ctyp gs; ialloc ctyp gs; call (CL_id gs)],
body_ctyp,
@@ -1846,7 +1884,7 @@ let rec compile_aexp ctx = function
(* Cleanup info will be re-added by fix_early_return *)
let return_setup, cval, _ = compile_aval ctx aval in
return_setup @ [ireturn cval],
- cval_ctyp cval,
+ ctyp_of_typ ctx typ,
(fun clexp -> icomment "unreachable after return"),
[]
@@ -1903,7 +1941,10 @@ let compile_type_def ctx (TD_aux (type_def, _)) =
{ ctx with records = Bindings.add id ctors ctx.records }
| TD_variant (id, _, _, tus, _) ->
- let compile_tu (Tu_aux (Tu_ty_id (typ, id), _)) = ctyp_of_typ ctx typ, id in
+ let compile_tu = function
+ | Tu_aux (Tu_ty_id (Typ_aux (Typ_fn (typ, _, _), _), id), _) -> ctyp_of_typ ctx typ, id
+ | Tu_aux (Tu_ty_id (typ, id), _) -> ctyp_of_typ ctx typ, id
+ in
let ctus = List.fold_left (fun ctus (ctyp, id) -> Bindings.add id ctyp ctus) Bindings.empty (List.map compile_tu tus) in
CTD_variant (id, Bindings.bindings ctus),
{ ctx with variants = Bindings.add id ctus ctx.variants }
@@ -2046,7 +2087,7 @@ let fix_exception_block ctx instrs =
@ generate_cleanup (historic @ before)
@ [igoto end_block_label]
@ rewrite_exception (historic @ before) after
- | before, (I_aux (I_funcall (x, f, args, ctyp), _) as funcall) :: after ->
+ | before, (I_aux (I_funcall (x, _, f, args, ctyp), _) as funcall) :: after ->
let effects = match Env.get_val_spec f ctx.tc_env with
| _, Typ_aux (Typ_fn (_, _, effects), _) -> effects
| exception (Type_error _) -> no_effect (* nullary union constructor, so no val spec *)
@@ -2080,8 +2121,9 @@ let fix_exception ctx instrs =
let instrs = List.map (map_try_block (fix_exception_block ctx)) instrs in
fix_exception_block ctx instrs
-let arg_pats ctx (Typ_aux (arg_typ_aux, _) as arg_typ) (P_aux (p_aux, (l, _)) as pat) =
+let rec arg_pats ctx (Typ_aux (arg_typ_aux, _) as arg_typ) (P_aux (p_aux, (l, _)) as pat) =
match p_aux, arg_typ_aux with
+ | P_typ (_, pat), _ -> arg_pats ctx arg_typ pat
| P_tup pats, Typ_tup arg_typs when List.length pats = List.length arg_typs ->
List.map2 (fun pat arg_typ -> pat, ctyp_of_typ ctx arg_typ) pats arg_typs
| P_wild, Typ_tup arg_typs -> List.map (fun arg_typ -> pat, ctyp_of_typ ctx arg_typ) arg_typs
@@ -2302,7 +2344,7 @@ let instr_deps = function
| I_init (ctyp, id, cval) | I_reinit (ctyp, id, cval) -> cval_deps cval, NS.singleton (G_id id)
| I_if (cval, _, _, _) -> cval_deps cval, NS.empty
| I_jump (cval, label) -> cval_deps cval, NS.singleton (G_label label)
- | I_funcall (clexp, _, cvals, _) -> List.fold_left NS.union NS.empty (List.map cval_deps cvals), clexp_deps clexp
+ | I_funcall (clexp, _, _, cvals, _) -> List.fold_left NS.union NS.empty (List.map cval_deps cvals), clexp_deps clexp
| I_convert (clexp, _, id, _) -> NS.singleton (G_id id), clexp_deps clexp
| I_copy (clexp, cval) -> cval_deps cval, clexp_deps clexp
| I_clear (_, id) -> NS.singleton (G_id id), NS.singleton (G_id id)
@@ -2439,8 +2481,8 @@ let rec instrs_rename from_id to_id =
| I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) :: instrs ->
I_aux (I_if (crename cval, irename then_instrs, irename else_instrs, ctyp), aux) :: irename instrs
| I_aux (I_jump (cval, label), aux) :: instrs -> I_aux (I_jump (crename cval, label), aux) :: irename instrs
- | I_aux (I_funcall (clexp, id, cvals, ctyp), aux) :: instrs ->
- I_aux (I_funcall (lrename clexp, rename id, List.map crename cvals, ctyp), aux) :: irename instrs
+ | I_aux (I_funcall (clexp, extern, id, cvals, ctyp), aux) :: instrs ->
+ I_aux (I_funcall (lrename clexp, extern, rename id, List.map crename cvals, ctyp), aux) :: irename instrs
| I_aux (I_copy (clexp, cval), aux) :: instrs -> I_aux (I_copy (lrename clexp, crename cval), aux) :: irename instrs
| I_aux (I_convert (clexp, ctyp1, id, ctyp2), aux) :: instrs ->
I_aux (I_convert (lrename clexp, ctyp1, rename id, ctyp2), aux) :: irename instrs
@@ -2579,7 +2621,7 @@ let upper_codegen_id id = string (upper_sgen_id id)
let rec sgen_ctyp = function
| CT_unit -> "unit"
- | CT_bit -> "int"
+ | CT_bit -> "uint64_t"
| CT_bool -> "bool"
| CT_uint64 _ -> "uint64_t"
| CT_int64 -> "int64_t"
@@ -2597,7 +2639,7 @@ let rec sgen_ctyp = function
let rec sgen_ctyp_name = function
| CT_unit -> "unit"
- | CT_bit -> "int"
+ | CT_bit -> "uint64_t"
| CT_bool -> "bool"
| CT_uint64 _ -> "uint64_t"
| CT_int64 -> "int64_t"
@@ -2669,9 +2711,16 @@ let rec codegen_instr fid ctx (I_aux (instr, _)) =
string " { /* try */"
^^ jump 2 2 (separate_map hardline (codegen_instr fid ctx) instrs) ^^ hardline
^^ string " }"
- | I_funcall (x, f, args, ctyp) ->
+ | I_funcall (x, extern, f, args, ctyp) ->
let c_args = Util.string_of_list ", " sgen_cval args in
- let fname = if Env.is_extern f ctx.tc_env "c" then Env.get_extern f ctx.tc_env "c" else sgen_id f in
+ let fname =
+ if Env.is_extern f ctx.tc_env "c" then
+ Env.get_extern f ctx.tc_env "c"
+ else if extern then
+ string_of_id f
+ else
+ sgen_id f
+ in
let fname =
match fname, ctyp with
| "internal_pick", _ -> Printf.sprintf "pick_%s" (sgen_ctyp_name ctyp)
@@ -2696,6 +2745,8 @@ let rec codegen_instr fid ctx (I_aux (instr, _)) =
| "vector_update", CT_uint64 _ -> "update_uint64_t"
| "vector_update", CT_bv _ -> "update_bv"
| "vector_update", _ -> Printf.sprintf "vector_update_%s" (sgen_ctyp_name ctyp)
+ | "internal_vector_update", _ -> Printf.sprintf "internal_vector_update_%s" (sgen_ctyp_name ctyp)
+ | "internal_vector_init", _ -> Printf.sprintf "internal_vector_init_%s" (sgen_ctyp_name ctyp)
| "undefined_vector", CT_uint64 _ -> "undefined_uint64_t"
| "undefined_vector", CT_bv _ -> "undefined_bv_t"
| "undefined_vector", _ -> Printf.sprintf "undefined_vector_%s" (sgen_ctyp_name ctyp)
@@ -3059,6 +3110,7 @@ let codegen_list ctx ctyp =
^^ codegen_pick id ctyp ^^ twice hardline
end
+(* Generate functions for working with non-bit vectors of some specific type. *)
let codegen_vector ctx (direction, ctyp) =
let id = mk_id (string_of_ctyp (CT_vector (direction, ctyp))) in
if IdSet.mem id !generated then
@@ -3111,6 +3163,14 @@ let codegen_vector ctx (direction, ctyp) =
^^ string " }\n"
^^ string "}"
in
+ let internal_vector_update =
+ string (Printf.sprintf "void internal_vector_update_%s(%s *rop, %s op, const int64_t n, const %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp))
+ ^^ string (if is_stack_ctyp ctyp then
+ " rop->data[n] = elem;\n"
+ else
+ Printf.sprintf " set_%s((rop->data) + n, elem);\n" (sgen_ctyp_name ctyp))
+ ^^ string "}"
+ in
let vector_access =
if is_stack_ctyp ctyp then
string (Printf.sprintf "%s vector_access_%s(%s op, mpz_t n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id))
@@ -3123,6 +3183,12 @@ let codegen_vector ctx (direction, ctyp) =
^^ string (Printf.sprintf " set_%s(rop, op.data[m]);\n" (sgen_ctyp_name ctyp))
^^ string "}"
in
+ let internal_vector_init =
+ string (Printf.sprintf "void internal_vector_init_%s(%s *rop, const int64_t len) {\n" (sgen_id id) (sgen_id id))
+ ^^ string " rop->len = len;\n"
+ ^^ string (Printf.sprintf " rop->data = malloc(len * sizeof(%s));\n" (sgen_ctyp ctyp))
+ ^^ string "}"
+ in
let vector_undefined =
string (Printf.sprintf "void undefined_vector_%s(%s *rop, mpz_t len, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp))
^^ string (Printf.sprintf " rop->len = mpz_get_ui(len);\n")
@@ -3144,6 +3210,8 @@ let codegen_vector ctx (direction, ctyp) =
^^ vector_access ^^ twice hardline
^^ vector_set ^^ twice hardline
^^ vector_update ^^ twice hardline
+ ^^ internal_vector_update ^^ twice hardline
+ ^^ internal_vector_init ^^ twice hardline
end
let is_decl = function
@@ -3178,6 +3246,12 @@ let codegen_def' ctx = function
| _ -> assert false
in
let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx) arg_typs, ctyp_of_typ ctx ret_typ in
+ if (List.length arg_ctyps <> List.length args) then
+ c_error ~loc:(id_loc id) ("function arguments "
+ ^ Util.string_of_list ", " string_of_id args
+ ^ " matched against type "
+ ^ Util.string_of_list ", " string_of_ctyp arg_ctyps)
+ else ();
let args = Util.string_of_list ", " (fun x -> x) (List.map2 (fun ctyp arg -> sgen_ctyp ctyp ^ " " ^ sgen_id arg) arg_ctyps args) in
let function_header =
match ret_arg with
diff --git a/src/parser.mly b/src/parser.mly
index 40bece90..97b6f28c 100644
--- a/src/parser.mly
+++ b/src/parser.mly
@@ -1164,6 +1164,8 @@ struct_fields:
type_union:
| id Colon typ
{ Tu_aux (Tu_ty_id ($3, $1), loc $startpos $endpos) }
+ | id Colon typ MinusGt typ
+ { (fun s e -> Tu_aux (Tu_ty_id (mk_typ (ATyp_fn ($3, $5, mk_typ (ATyp_set []) s e)) s e, $1), loc s e)) $startpos $endpos }
type_unions:
| type_union
diff --git a/src/specialize.ml b/src/specialize.ml
index 1b978fbc..151a185a 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -57,11 +57,34 @@ let is_typ_ord_uvar = function
| Type_check.U_order _ -> true
| _ -> false
+let rec nexp_simp_typ (Typ_aux (typ_aux, l)) =
+ let typ_aux = match typ_aux with
+ | Typ_id v -> Typ_id v
+ | Typ_var kid -> Typ_var kid
+ | Typ_tup typs -> Typ_tup (List.map nexp_simp_typ typs)
+ | Typ_app (f, args) -> Typ_app (f, List.map nexp_simp_typ_arg args)
+ | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, nexp_simp_typ typ)
+ | Typ_fn (typ1, typ2, effect) -> Typ_fn (nexp_simp_typ typ1, nexp_simp_typ typ2, effect)
+ in
+ Typ_aux (typ_aux, l)
+and nexp_simp_typ_arg (Typ_arg_aux (typ_arg_aux, l)) =
+ let typ_arg_aux = match typ_arg_aux with
+ | Typ_arg_nexp n -> Typ_arg_nexp (nexp_simp n)
+ | Typ_arg_typ typ -> Typ_arg_typ (nexp_simp_typ typ)
+ | Typ_arg_order ord -> Typ_arg_order ord
+ in
+ Typ_arg_aux (typ_arg_aux, l)
+
+let nexp_simp_uvar = function
+ | Type_check.U_nexp nexp -> (prerr_endline ("Simp nexp " ^ string_of_nexp nexp); Type_check.U_nexp (nexp_simp nexp))
+ | Type_check.U_typ typ -> Type_check.U_typ (nexp_simp_typ typ)
+ | uvar -> uvar
+
(* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of.
This part of the typechecker API is a bit ugly. *)
let fix_instantiation instantiation =
let instantiation = KBindings.bindings (KBindings.filter (fun _ uvar -> is_typ_ord_uvar uvar) instantiation) in
- let instantiation = List.map (fun (kid, uvar) -> Type_check.orig_kid kid, uvar) instantiation in
+ let instantiation = List.map (fun (kid, uvar) -> Type_check.orig_kid kid, nexp_simp_uvar uvar) instantiation in
List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation
let rec polymorphic_functions is_kopt (Defs defs) =
@@ -440,12 +463,28 @@ let rewrite_polymorphic_constructors id ast =
rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp);
rewrite_pat = (fun _ -> fold_pat rewrite_pat)} ast
+let kinded_id_arg kind_id =
+ let typ_arg arg = Typ_arg_aux (arg, Parse_ast.Unknown) in
+ match kind_id with
+ | KOpt_aux (KOpt_none kid, _) -> typ_arg (Typ_arg_nexp (nvar kid))
+ | KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_nat, _)], _), kid), _) -> typ_arg (Typ_arg_nexp (nvar kid))
+ | KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_order, _)], _), kid), _) ->
+ typ_arg (Typ_arg_order (Ord_aux (Ord_var kid, Parse_ast.Unknown)))
+ | KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_type, _)], _), kid), _) ->
+ typ_arg (Typ_arg_typ (mk_typ (Typ_var kid)))
+ | KOpt_aux (KOpt_kind (K_aux (K_kind kinds, _), kid), l) -> assert false
+
+let fold_union_quant quants (QI_aux (qi, l)) =
+ match qi with
+ | QI_id kind_id -> quants @ [kinded_id_arg kind_id]
+ | _ -> quants
+
let specialize_variants ((Defs defs) as ast) env =
let ctors = ref [] in
let specialize_variant (TD_aux (tdef_aux, annot)) ast env =
match tdef_aux with
- | TD_variant (id, name_scheme, typq, tus, flag) as variant ->
+ | TD_variant (v_id, name_scheme, typq, tus, flag) as variant ->
let kopts = List.filter (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) (quant_kopts typq) in
if kopts = [] then
(* If non-polymorphic, then do nothing. *)
@@ -455,14 +494,18 @@ let specialize_variants ((Defs defs) as ast) env =
ctors := id :: !ctors;
let is = instantiations_of id ast in
let is = List.sort_uniq (fun i1 i2 -> String.compare (string_of_instantiation i1) (string_of_instantiation i2)) is in
- List.map (fun i -> Tu_aux (Tu_ty_id (Type_check.subst_unifiers i typ, id_of_instantiation id i), annot)) is
+ List.map (fun i ->
+ let i = fix_instantiation i in
+ let ret_typ = app_typ v_id (List.fold_left fold_union_quant [] (quant_items typq)) in
+ let ret_typ = Type_check.subst_unifiers i ret_typ in
+ Tu_aux (Tu_ty_id (Type_check.subst_unifiers i (mk_typ (Typ_fn (typ, ret_typ, no_effect))), id_of_instantiation id i), annot)) is
in
(*
let kopts, constraints = quant_split typq in
let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in
let typq = mk_typquant (List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints) in
*)
- TD_aux (TD_variant (id, name_scheme, typq, List.concat (List.map specialize_tu tus), flag), annot)
+ TD_aux (TD_variant (v_id, name_scheme, typq, List.concat (List.map specialize_tu tus), flag), annot)
| _ -> assert false
in
diff --git a/src/type_check.ml b/src/type_check.ml
index 74e7cc80..689a7338 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -1275,7 +1275,7 @@ and typ_arg_nexps (Typ_arg_aux (typ_arg_aux, l)) =
| Typ_arg_nexp n -> [n]
| Typ_arg_typ typ -> typ_nexps typ
| Typ_arg_order ord -> []
-
+
let rec typ_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) =
match typ_aux with
| Typ_id v -> KidSet.empty
@@ -3587,6 +3587,11 @@ let fold_union_quant quants (QI_aux (qi, l)) =
let check_type_union env variant typq (Tu_aux (tu, l)) =
let ret_typ = app_typ variant (List.fold_left fold_union_quant [] (quant_items typq)) in
match tu with
+ | Tu_ty_id (Typ_aux (Typ_fn (arg_typ, ret_typ, _), _) as typ, v) ->
+ let typq = mk_typquant (List.map (mk_qi_id BK_type) (KidSet.elements (typ_frees typ))) in
+ env
+ |> Env.add_union_id v (typq, typ)
+ |> Env.add_val_spec v (typq, typ)
| Tu_ty_id (typ, v) ->
let typ' = mk_typ (Typ_fn (typ, ret_typ, no_effect)) in
env