diff options
Diffstat (limited to 'src/initial_check.ml')
| -rw-r--r-- | src/initial_check.ml | 66 |
1 files changed, 56 insertions, 10 deletions
diff --git a/src/initial_check.ml b/src/initial_check.ml index 1b21e2be..2ded2895 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -75,6 +75,8 @@ let string_of_parse_id_aux = function let string_of_parse_id (P.Id_aux (id, l)) = string_of_parse_id_aux id +let parse_id_loc (P.Id_aux (_, l)) = l + let string_contains str char = try (ignore (String.index str char); true) with | Not_found -> false @@ -584,21 +586,65 @@ let rec realise_union_anon_rec_types orig_union arms = | _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "Non union type-definition passed to realise_union_anon_rec_typs") -let rec to_ast_typedef ctx (P.TD_aux (aux, l) : P.type_def) : unit type_def list ctx_out = +let generate_enum_functions l ctx enum_id fns exps = + let get_exp i = function + | Some (P.E_aux (P.E_tuple exps, _)) -> List.nth exps i + | Some exp -> exp + | None -> Reporting.unreachable l __POS__ "get_exp called without expression" + in + let num_exps = function + | Some (P.E_aux (P.E_tuple exps, _)) -> List.length exps + | Some _ -> 1 + | None -> 0 + in + let num_fns = List.length fns in + List.iter (fun (id, exp) -> + let n = num_exps exp in + if n <> num_fns then ( + let l = (match exp with Some (P.E_aux (_, l)) -> l | None -> parse_id_loc id) in + raise (Reporting.err_general l + (sprintf "Each enumeration clause for %s must define exactly %d expressions for the functions %s\n\ + %s expressions have been given here" + (string_of_id enum_id) + num_fns + (string_of_list ", " string_of_parse_id (List.map fst fns)) + (if n = 0 then "No" else if n > num_fns then "Too many" else "Too few"))) + ) + ) exps; + List.mapi (fun i (id, typ) -> + let typ = to_ast_typ ctx typ in + let name = mk_id (string_of_id enum_id ^ "_" ^ string_of_parse_id id) in + [mk_fundef [ + mk_funcl name (mk_pat (P_id (mk_id "arg#"))) + (mk_exp (E_case (mk_exp (E_id (mk_id "arg#")), + List.map (fun (id, exps) -> + let id = to_ast_id id in + let exp = to_ast_exp ctx (get_exp i exps) in + mk_pexp (Pat_exp (mk_pat (P_id id), exp)) + ) exps))) + ]; + mk_val_spec (VS_val_spec (mk_typschm (mk_typquant []) (function_typ [mk_id_typ enum_id] typ no_effect), + name, + [], + false))] + ) fns + |> List.concat + +let rec to_ast_typedef ctx (P.TD_aux (aux, l) : P.type_def) : unit def list ctx_out = match aux with | P.TD_abbrev (id, typq, kind, typ_arg) -> let id = to_ast_id id in let typq, typq_ctx = to_ast_typquant ctx typq in let kind = to_ast_kind kind in let typ_arg = to_ast_typ_arg typq_ctx typ_arg (unaux_kind kind) in - [TD_aux (TD_abbrev (id, typq, typ_arg), (l, ()))], + [DEF_type (TD_aux (TD_abbrev (id, typq, typ_arg), (l, ())))], add_constructor id typq ctx | P.TD_record (id, typq, fields, _) -> let id = to_ast_id id in let typq, typq_ctx = to_ast_typquant ctx typq in let fields = List.map (fun (atyp, id) -> to_ast_typ typq_ctx atyp, to_ast_id id) fields in - [TD_aux (TD_record (id, typq, fields, false), (l, ()))], + [DEF_type (TD_aux (TD_record (id, typq, fields, false), (l, ())))], add_constructor id typq ctx | P.TD_variant (id, typq, arms, _) as union -> @@ -621,20 +667,21 @@ let rec to_ast_typedef ctx (P.TD_aux (aux, l) : P.type_def) : unit type_def list let id = to_ast_id id in let typq, typq_ctx = to_ast_typquant ctx typq in let arms = List.map (to_ast_type_union typq_ctx) arms in - [TD_aux (TD_variant (id, typq, arms, false), (l, ()))] @ generated_records, + [DEF_type (TD_aux (TD_variant (id, typq, arms, false), (l, ())))] @ generated_records, add_constructor id typq ctx - | P.TD_enum (id, enums, _) -> + | P.TD_enum (id, fns, enums, _) -> let id = to_ast_id id in - let enums = List.map to_ast_id enums in - [TD_aux (TD_enum (id, enums, false), (l, ()))], + let fns = generate_enum_functions l ctx id fns enums in + let enums = List.map (fun e -> to_ast_id (fst e)) enums in + fns @ [DEF_type (TD_aux (TD_enum (id, enums, false), (l, ())))], { ctx with type_constructors = Bindings.add id [] ctx.type_constructors } | P.TD_bitfield (id, typ, ranges) -> let id = to_ast_id id in let typ = to_ast_typ ctx typ in let ranges = List.map (fun (id, range) -> (to_ast_id id, to_ast_range range)) ranges in - [TD_aux (TD_bitfield (id, typ, ranges), (l, ()))], + [DEF_type (TD_aux (TD_bitfield (id, typ, ranges), (l, ())))], { ctx with type_constructors = Bindings.add id [] ctx.type_constructors } let to_ast_rec ctx (P.Rec_aux(r,l): P.rec_opt) : unit rec_opt = @@ -788,8 +835,7 @@ let to_ast_def ctx def : unit def list ctx_out = | P.DEF_fixity (prec, n, op) -> [DEF_fixity (to_ast_prec prec, n, to_ast_id op)], ctx | P.DEF_type(t_def) -> - let tds, ctx = to_ast_typedef ctx t_def in - List.map (fun td -> DEF_type td) tds, ctx + to_ast_typedef ctx t_def | P.DEF_fundef(f_def) -> let fd = to_ast_fundef ctx f_def in [DEF_fundef fd], ctx |
