summaryrefslogtreecommitdiff
path: root/src/initial_check.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/initial_check.ml')
-rw-r--r--src/initial_check.ml66
1 files changed, 56 insertions, 10 deletions
diff --git a/src/initial_check.ml b/src/initial_check.ml
index 1b21e2be..2ded2895 100644
--- a/src/initial_check.ml
+++ b/src/initial_check.ml
@@ -75,6 +75,8 @@ let string_of_parse_id_aux = function
let string_of_parse_id (P.Id_aux (id, l)) = string_of_parse_id_aux id
+let parse_id_loc (P.Id_aux (_, l)) = l
+
let string_contains str char =
try (ignore (String.index str char); true) with
| Not_found -> false
@@ -584,21 +586,65 @@ let rec realise_union_anon_rec_types orig_union arms =
| _ ->
raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "Non union type-definition passed to realise_union_anon_rec_typs")
-let rec to_ast_typedef ctx (P.TD_aux (aux, l) : P.type_def) : unit type_def list ctx_out =
+let generate_enum_functions l ctx enum_id fns exps =
+ let get_exp i = function
+ | Some (P.E_aux (P.E_tuple exps, _)) -> List.nth exps i
+ | Some exp -> exp
+ | None -> Reporting.unreachable l __POS__ "get_exp called without expression"
+ in
+ let num_exps = function
+ | Some (P.E_aux (P.E_tuple exps, _)) -> List.length exps
+ | Some _ -> 1
+ | None -> 0
+ in
+ let num_fns = List.length fns in
+ List.iter (fun (id, exp) ->
+ let n = num_exps exp in
+ if n <> num_fns then (
+ let l = (match exp with Some (P.E_aux (_, l)) -> l | None -> parse_id_loc id) in
+ raise (Reporting.err_general l
+ (sprintf "Each enumeration clause for %s must define exactly %d expressions for the functions %s\n\
+ %s expressions have been given here"
+ (string_of_id enum_id)
+ num_fns
+ (string_of_list ", " string_of_parse_id (List.map fst fns))
+ (if n = 0 then "No" else if n > num_fns then "Too many" else "Too few")))
+ )
+ ) exps;
+ List.mapi (fun i (id, typ) ->
+ let typ = to_ast_typ ctx typ in
+ let name = mk_id (string_of_id enum_id ^ "_" ^ string_of_parse_id id) in
+ [mk_fundef [
+ mk_funcl name (mk_pat (P_id (mk_id "arg#")))
+ (mk_exp (E_case (mk_exp (E_id (mk_id "arg#")),
+ List.map (fun (id, exps) ->
+ let id = to_ast_id id in
+ let exp = to_ast_exp ctx (get_exp i exps) in
+ mk_pexp (Pat_exp (mk_pat (P_id id), exp))
+ ) exps)))
+ ];
+ mk_val_spec (VS_val_spec (mk_typschm (mk_typquant []) (function_typ [mk_id_typ enum_id] typ no_effect),
+ name,
+ [],
+ false))]
+ ) fns
+ |> List.concat
+
+let rec to_ast_typedef ctx (P.TD_aux (aux, l) : P.type_def) : unit def list ctx_out =
match aux with
| P.TD_abbrev (id, typq, kind, typ_arg) ->
let id = to_ast_id id in
let typq, typq_ctx = to_ast_typquant ctx typq in
let kind = to_ast_kind kind in
let typ_arg = to_ast_typ_arg typq_ctx typ_arg (unaux_kind kind) in
- [TD_aux (TD_abbrev (id, typq, typ_arg), (l, ()))],
+ [DEF_type (TD_aux (TD_abbrev (id, typq, typ_arg), (l, ())))],
add_constructor id typq ctx
| P.TD_record (id, typq, fields, _) ->
let id = to_ast_id id in
let typq, typq_ctx = to_ast_typquant ctx typq in
let fields = List.map (fun (atyp, id) -> to_ast_typ typq_ctx atyp, to_ast_id id) fields in
- [TD_aux (TD_record (id, typq, fields, false), (l, ()))],
+ [DEF_type (TD_aux (TD_record (id, typq, fields, false), (l, ())))],
add_constructor id typq ctx
| P.TD_variant (id, typq, arms, _) as union ->
@@ -621,20 +667,21 @@ let rec to_ast_typedef ctx (P.TD_aux (aux, l) : P.type_def) : unit type_def list
let id = to_ast_id id in
let typq, typq_ctx = to_ast_typquant ctx typq in
let arms = List.map (to_ast_type_union typq_ctx) arms in
- [TD_aux (TD_variant (id, typq, arms, false), (l, ()))] @ generated_records,
+ [DEF_type (TD_aux (TD_variant (id, typq, arms, false), (l, ())))] @ generated_records,
add_constructor id typq ctx
- | P.TD_enum (id, enums, _) ->
+ | P.TD_enum (id, fns, enums, _) ->
let id = to_ast_id id in
- let enums = List.map to_ast_id enums in
- [TD_aux (TD_enum (id, enums, false), (l, ()))],
+ let fns = generate_enum_functions l ctx id fns enums in
+ let enums = List.map (fun e -> to_ast_id (fst e)) enums in
+ fns @ [DEF_type (TD_aux (TD_enum (id, enums, false), (l, ())))],
{ ctx with type_constructors = Bindings.add id [] ctx.type_constructors }
| P.TD_bitfield (id, typ, ranges) ->
let id = to_ast_id id in
let typ = to_ast_typ ctx typ in
let ranges = List.map (fun (id, range) -> (to_ast_id id, to_ast_range range)) ranges in
- [TD_aux (TD_bitfield (id, typ, ranges), (l, ()))],
+ [DEF_type (TD_aux (TD_bitfield (id, typ, ranges), (l, ())))],
{ ctx with type_constructors = Bindings.add id [] ctx.type_constructors }
let to_ast_rec ctx (P.Rec_aux(r,l): P.rec_opt) : unit rec_opt =
@@ -788,8 +835,7 @@ let to_ast_def ctx def : unit def list ctx_out =
| P.DEF_fixity (prec, n, op) ->
[DEF_fixity (to_ast_prec prec, n, to_ast_id op)], ctx
| P.DEF_type(t_def) ->
- let tds, ctx = to_ast_typedef ctx t_def in
- List.map (fun td -> DEF_type td) tds, ctx
+ to_ast_typedef ctx t_def
| P.DEF_fundef(f_def) ->
let fd = to_ast_fundef ctx f_def in
[DEF_fundef fd], ctx