diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast.ml | 1 | ||||
| -rw-r--r-- | src/initial_check.ml | 2 | ||||
| -rw-r--r-- | src/lexer.mll | 1 | ||||
| -rw-r--r-- | src/parse_ast.ml | 1 | ||||
| -rw-r--r-- | src/parser.mly | 4 | ||||
| -rw-r--r-- | src/type_check.ml | 1 | ||||
| -rw-r--r-- | src/type_check_new.ml | 98 |
7 files changed, 67 insertions, 41 deletions
@@ -570,6 +570,7 @@ and 'a def = (* Top-level definition *) | DEF_fundef of 'a fundef (* function definition *) | DEF_val of 'a letbind (* value definition *) | DEF_spec of 'a val_spec (* top-level type constraint *) + | DEF_overload of id * id list (* operator overload specification *) | DEF_default of 'a default_spec (* default kind and type assumptions *) | DEF_scattered of 'a scattered_def (* scattered function and type definition *) | DEF_reg_dec of 'a dec_spec (* register declaration *) diff --git a/src/initial_check.ml b/src/initial_check.ml index 8ee3da1b..b6d4f863 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -847,6 +847,8 @@ let to_ast_dec (names,k_env,def_ord) (Parse_ast.DEC_aux(regdec,l)) = let to_ast_def (names, k_env, def_ord) partial_defs def : def_progress envs_out * (id * partial_def) list = let envs = (names,k_env,def_ord) in match def with + | Parse_ast.DEF_overload(id,ids) -> + ((Finished(DEF_overload(to_ast_id id, List.map to_ast_id ids))),envs),partial_defs | Parse_ast.DEF_kind(k_def) -> let kd,envs = to_ast_kdef envs k_def in ((Finished(DEF_kind(kd))),envs),partial_defs diff --git a/src/lexer.mll b/src/lexer.mll index 2cedcf42..99965e20 100644 --- a/src/lexer.mll +++ b/src/lexer.mll @@ -81,6 +81,7 @@ let kw_table = ("forall", (fun _ -> Forall)); ("foreach", (fun _ -> Foreach)); ("function", (fun x -> Function_)); + ("overload", (fun _ -> Overload)); ("if", (fun x -> If_)); ("in", (fun x -> In)); ("inc", (fun _ -> Inc)); diff --git a/src/parse_ast.ml b/src/parse_ast.ml index c6af651b..cfb09bf5 100644 --- a/src/parse_ast.ml +++ b/src/parse_ast.ml @@ -488,6 +488,7 @@ def = (* Top-level definition *) | DEF_type of type_def (* type definition *) | DEF_fundef of fundef (* function definition *) | DEF_val of letbind (* value definition *) + | DEF_overload of id * id list (* operator overload specifications *) | DEF_spec of val_spec (* top-level type constraint *) | DEF_default of default_typing_spec (* default kind and type assumptions *) | DEF_scattered of scattered_def (* scattered definition *) diff --git a/src/parser.mly b/src/parser.mly index a544c906..6cf954ef 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -129,7 +129,7 @@ let make_vector_sugar order_set is_inc typ typ1 = /*Terminals with no content*/ %token And Alias As Assert Bitzero Bitone Bits By Case Clause Const Dec Def Default Deinfix Effect EFFECT End -%token Enumerate Else Exit Extern False Forall Foreach Function_ If_ In IN Inc Let_ Member Nat NatNum Order Cast +%token Enumerate Else Exit Extern False Forall Foreach Overload Function_ If_ In IN Inc Let_ Member Nat NatNum Order Cast %token Pure Rec Register Return Scattered Sizeof Struct Switch Then True TwoStarStar Type TYPE Typedef %token Undefined Union With Val %token Barr Depend Rreg Wreg Rmem Rmemt Wmem Wmv Wmvt Eamem Exmem Undef Unspec Nondet Escape @@ -1274,6 +1274,8 @@ def: { dloc (DEF_spec($1)) } | default_typ { dloc (DEF_default($1)) } + | Overload id Lsquare enum_body Rsquare + { dloc (DEF_overload($2,$4)) } | Register typ id { dloc (DEF_reg_dec(DEC_aux(DEC_reg($2,$3),loc ()))) } | Register Alias id Eq exp diff --git a/src/type_check.ml b/src/type_check.ml index c68e60ae..bc6d67a8 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -2528,6 +2528,7 @@ let check_def envs def = let rec check envs (Defs defs) = match defs with | [] -> (Defs []),envs + | (DEF_overload (_, _)::defs) -> check envs (Defs defs) | def::defs -> let (def, envs) = check_def envs def in let (Defs defs, envs) = check envs (Defs defs) in (Defs (def::defs)), envs diff --git a/src/type_check_new.ml b/src/type_check_new.ml index ea460ca8..cc232eb8 100644 --- a/src/type_check_new.ml +++ b/src/type_check_new.ml @@ -115,8 +115,12 @@ and map_lexp_annot_aux f = function let string_of_id = function | Id_aux (Id v, _) -> v - | Id_aux (DeIid v, _) -> v + | Id_aux (DeIid v, _) -> "(deinfix " ^ v ^ ")" +let deinfix = function + | Id_aux (Id v, l) -> Id_aux (DeIid v, l) + | Id_aux (DeIid v, l) -> Id_aux (DeIid v, l) + let string_of_kid = function | Kid_aux (Var v, _) -> v @@ -303,18 +307,19 @@ let unaux_typ (Typ_aux (typ, _)) = typ let mk_typ typ = Typ_aux (typ, Parse_ast.Unknown) let mk_typ_arg arg = Typ_arg_aux (arg, Parse_ast.Unknown) let mk_id str = Id_aux (Id str, Parse_ast.Unknown) +let mk_infix_id str = Id_aux (DeIid str, Parse_ast.Unknown) + +let mk_id_typ id = Typ_aux (Typ_id id, Parse_ast.Unknown) -let mk_id_typ id = Typ_aux (Typ_id id, Parse_ast.Unknown) - let unit_typ = mk_id_typ (mk_id "unit") let bit_typ = mk_id_typ (mk_id "bit") let atom_typ nexp = mk_typ (Typ_app (mk_id "atom", [mk_typ_arg (Typ_arg_nexp nexp)])) let range_typ nexp1 nexp2 = mk_typ (Typ_app (mk_id "range", [mk_typ_arg (Typ_arg_nexp nexp1); mk_typ_arg (Typ_arg_nexp nexp2)])) let bool_typ = mk_id_typ (mk_id "bool") let string_typ = mk_id_typ (mk_id "string") - + let no_effects = Effect_aux (Effect_set [], Parse_ast.Unknown) - + let rec nexp_subst sv subst (Nexp_aux (nexp, l)) = Nexp_aux (nexp_subst_aux sv subst nexp, l) and nexp_subst_aux sv subst = function | Nexp_id v -> Nexp_id v @@ -795,9 +800,9 @@ end = struct let no_casts env = { env with allow_casts = false } let add_cast cast env = - typ_print ("Adding cast " ^ string_of_id cast); + typ_print ("Adding cast " ^ string_of_id cast); { env with casts = cast :: env.casts } - + let add_typ_synonym id synonym env = if Bindings.mem id env.typ_synonyms then typ_error (id_loc id) ("Type synonym " ^ string_of_id id ^ " already exists") @@ -1240,14 +1245,6 @@ let unify l env typ1 typ2 = let typ1, typ2 = Env.expand_synonyms env typ1, Env.expand_synonyms env typ2 in unify_typ l typ1 typ2 -(* FIXME: we need to unify lists of typ args better, consider: - -unifying [|'n - 'l + 1:'n|] against [|0:31|] for example - -we can only unify the first argument if we do the second first - -*) - let vector_typ n m ord typ = mk_typ (Typ_app (mk_id "vector", [mk_typ_arg (Typ_arg_nexp n); @@ -1334,16 +1331,13 @@ let rec instantiate_quants quants kid uvar = match quants with | _ -> (QI_aux (QI_const nc, l)) :: instantiate_quants quants kid uvar end -let destructure_vec_typ l typ = - match typ with +let destructure_vec_typ l = function | Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp n1, _); Typ_arg_aux (Typ_arg_nexp n2, _); Typ_arg_aux (Typ_arg_order o, _); Typ_arg_aux (Typ_arg_typ vtyp, _)] - ), _) -> - if string_of_id id = "vector" then (n1, n2, o, vtyp) - else typ_error l ("Expected vector type, got " ^ string_of_typ typ) - | _ -> typ_error l ("Expected vector type, got " ^ string_of_typ typ) + ), _) when string_of_id id = "vector" -> (n1, n2, o, vtyp) + | typ -> typ_error l ("Expected vector type, got " ^ string_of_typ typ) let typ_of (E_aux (_, (_, tannot))) = match tannot with | Some (_, typ) -> typ @@ -1370,7 +1364,7 @@ let irule r env exp = let strip_exp : 'a exp -> unit exp = function exp -> map_exp_annot (fun (l, _) -> (l, ())) exp let strip_pat : 'a pat -> unit pat = function pat -> map_pat_annot (fun (l, _) -> (l, ())) pat - + let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ_aux, _) as typ) : tannot exp = let annot_exp exp typ = E_aux (exp, (l, Some (env, typ))) in match (exp_aux, typ_aux) with @@ -1404,7 +1398,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ let tpat, env = bind_pat env pat (typ_of inferred_bind) in annot_exp (E_let (LB_aux (LB_val_implicit (tpat, inferred_bind), (let_loc, None)), crule check_exp env exp typ)) typ end - | E_app_infix (x, op, y), _ when List.length (Env.get_overloads op env) > 0 -> check_exp env (E_aux (E_app (op, [x; y]), (l, ()))) typ + | E_app_infix (x, op, y), _ when List.length (Env.get_overloads (deinfix op) env) > 0 -> check_exp env (E_aux (E_app (deinfix op, [x; y]), (l, ()))) typ | E_app (f, xs), _ when List.length (Env.get_overloads f env) > 0 -> let rec try_overload m1 = function | [] -> typ_error l (m1 ^ "\nNo valid overloading for " ^ string_of_exp exp) @@ -1446,6 +1440,12 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ let inferred_exp = irule infer_exp env exp in type_coercion env inferred_exp typ +(* type_coercion env exp typ takes a fully annoted (i.e. already type + checked) expression exp, and attempts to cast (coerce) it to the + type typ by inserting a coercion function that transforms the + annotated expression into the correct type. Returns an annoted + expression consisting of a type coercion function applied to exp, + or throws a type error if the coercion cannot be performed. *) and type_coercion env (E_aux (_, (l, _)) as annotated_exp) typ = let strip exp_aux = strip_exp (E_aux (exp_aux, (Parse_ast.Unknown, None))) in let rec try_casts m = function @@ -1453,7 +1453,7 @@ and type_coercion env (E_aux (_, (l, _)) as annotated_exp) typ = | (cast :: casts) -> begin typ_print ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " to " ^ string_of_typ typ); try crule check_exp (Env.no_casts env) (strip (E_app (cast, [annotated_exp]))) typ with - | Type_error (l, m) -> try_casts m casts + | Type_error (_, m) -> try_casts m casts end in begin @@ -1464,7 +1464,34 @@ and type_coercion env (E_aux (_, (l, _)) as annotated_exp) typ = | Type_error (_, m) when Env.allow_casts env -> try_casts "" (Env.get_casts env) | Type_error (l, m) -> typ_error l ("Subtype error " ^ m) end - + +(* type_coercion_unify env exp typ attempts to coerce exp to a type + exp_typ in the same way as type_coercion, except it is only + required that exp_typ unifies with typ. Returns the annotated + coercion as with type_coercion and also a set of unifiers, or + throws a unification error *) +and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ = + let strip exp_aux = strip_exp (E_aux (exp_aux, (Parse_ast.Unknown, None))) in + let rec try_casts m = function + | [] -> unify_error l ("No valid casts resulted in unification:\n" ^ m) + | (cast :: casts) -> begin + typ_print ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " for unification"); + try + let annotated_exp = irule infer_exp (Env.no_casts env) (strip (E_app (cast, [annotated_exp]))) in + annotated_exp, unify l env typ (typ_of annotated_exp) + with + | Type_error (_, m) -> try_casts m casts + | Unification_error (_, m) -> try_casts m casts + end + in + begin + try + typ_debug "PERFORMING COERCING UNIFICATION"; + annotated_exp, unify l env typ (typ_of annotated_exp) + with + | Unification_error (_, m) when Env.allow_casts env -> try_casts "" (Env.get_casts env) + end + and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) = let annot_pat pat typ = P_aux (pat, (l, Some (env, typ))) in match pat_aux with @@ -1537,7 +1564,7 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as let inferred_exp = irule infer_exp env exp in let tlexp, env' = bind_lexp env lexp (typ_of inferred_exp) in annot_assign tlexp inferred_exp, env' - + and bind_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) typ = let annot_lexp lexp typ = LEXP_aux (lexp, (l, Some (env, typ))) in match lexp_aux with @@ -1587,7 +1614,7 @@ and bind_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) typ = annot_lexp (LEXP_tup tlexps) typ, env | _ -> typ_error l "Cannot bind tuple l-expression against non tuple type" end - (* Not sure about this case... can the left lexp be anything other than an identifier? *) + (* Not sure about this case... can the left lexp be anything other than an identifier? *) | LEXP_vector (LEXP_aux (LEXP_id v, _), exp) -> begin let is_immutable, vtyp = match Env.lookup_id v env with @@ -1656,7 +1683,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = | E_cast (typ, exp) -> let checked_exp = crule check_exp env exp typ in annot_exp (E_cast (typ, checked_exp)) typ - | E_app_infix (x, op, y) when List.length (Env.get_overloads op env) > 0 -> infer_exp env (E_aux (E_app (op, [x; y]), (l, ()))) + | E_app_infix (x, op, y) when List.length (Env.get_overloads (deinfix op) env) > 0 -> infer_exp env (E_aux (E_app (deinfix op, [x; y]), (l, ()))) | E_app (f, xs) when List.length (Env.get_overloads f env) > 0 -> let rec try_overload m1 = function | [] -> typ_error l (m1 ^ "\nNo valid overloading for " ^ string_of_exp exp) @@ -1725,7 +1752,7 @@ and infer_funapp l env f xs ret_ctx_typ = let iarg = irule infer_exp env arg in typ_debug ("INFER: " ^ string_of_exp arg ^ " type " ^ string_of_typ (typ_of iarg) ^ " NF " ^ string_of_tnf (normalize_typ env (typ_of iarg))); try - let unifiers = unify l env typ (typ_of iarg) in + let iarg, unifiers = type_coercion_unify env iarg typ in typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)); let utyps' = List.map (subst_unifiers unifiers) utyps in let typs' = List.map (subst_unifiers unifiers) typs in @@ -1767,7 +1794,7 @@ and infer_funapp l env f xs ret_ctx_typ = let xs_reordered = List.map snd (List.sort (fun (n, _) (m, _) -> compare n m) xs_instantiated) in annot_exp (E_app (f, xs_reordered)) typ_ret | _ -> typ_error l (string_of_id f ^ " is not a function") - + let check_letdef env (LB_aux (letbind, (l, _))) = begin match letbind with @@ -1888,6 +1915,7 @@ let rec check_def env def = | DEF_val letdef -> check_letdef env letdef | DEF_spec vs -> check_val_spec env vs | DEF_default default -> check_default env default + | DEF_overload (id, ids) -> DEF_overload (id, ids), Env.add_overloads id ids env | DEF_reg_dec (DEC_aux (DEC_reg (typ, id), (l, _))) -> DEF_reg_dec (DEC_aux (DEC_reg (typ, id), (l, None))), Env.add_register id typ env | DEF_reg_dec (DEC_aux (DEC_alias (id, aspec), (l, annot))) -> cd_err () @@ -1913,13 +1941,3 @@ let check env defs = let initial_env = Env.empty |> Env.add_typ_synonym (mk_id "atom") (fun args -> mk_typ (Typ_app (mk_id "range", args @ args))) - |> Env.add_overloads (mk_id "^^") [mk_id "duplicate"; mk_id "duplicate_bits"] - |> Env.add_overloads (mk_id "!=") [mk_id "neq_vec"; mk_id "neq_anything"] - |> Env.add_overloads (mk_id "==") [mk_id "vec_eq_01_left"; mk_id "vec_eq_01_right"; mk_id "eq_anything"] - |> Env.add_overloads (mk_id "mask") [mk_id "mask_inc"; mk_id "mask_dec"] - |> Env.add_overloads (mk_id "~") [mk_id "not"] - |> Env.add_overloads (mk_id "|") [mk_id "bool_or"] - |> Env.add_overloads (mk_id "&") [mk_id "bool_and"] - |> Env.add_overloads (mk_id "+") [mk_id "bv_add"] - |> Env.add_overloads (mk_id "-") [mk_id "sub_exact"; mk_id "sub_range"; mk_id "sub_bv"] - |> Env.add_overloads (mk_id "vector_access") [mk_id "vector_access_inc"; mk_id "vector_access_dec"] |
