summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast.ml1
-rw-r--r--src/initial_check.ml2
-rw-r--r--src/lexer.mll1
-rw-r--r--src/parse_ast.ml1
-rw-r--r--src/parser.mly4
-rw-r--r--src/type_check.ml1
-rw-r--r--src/type_check_new.ml98
7 files changed, 67 insertions, 41 deletions
diff --git a/src/ast.ml b/src/ast.ml
index edbac7d8..6710c749 100644
--- a/src/ast.ml
+++ b/src/ast.ml
@@ -570,6 +570,7 @@ and 'a def = (* Top-level definition *)
| DEF_fundef of 'a fundef (* function definition *)
| DEF_val of 'a letbind (* value definition *)
| DEF_spec of 'a val_spec (* top-level type constraint *)
+ | DEF_overload of id * id list (* operator overload specification *)
| DEF_default of 'a default_spec (* default kind and type assumptions *)
| DEF_scattered of 'a scattered_def (* scattered function and type definition *)
| DEF_reg_dec of 'a dec_spec (* register declaration *)
diff --git a/src/initial_check.ml b/src/initial_check.ml
index 8ee3da1b..b6d4f863 100644
--- a/src/initial_check.ml
+++ b/src/initial_check.ml
@@ -847,6 +847,8 @@ let to_ast_dec (names,k_env,def_ord) (Parse_ast.DEC_aux(regdec,l)) =
let to_ast_def (names, k_env, def_ord) partial_defs def : def_progress envs_out * (id * partial_def) list =
let envs = (names,k_env,def_ord) in
match def with
+ | Parse_ast.DEF_overload(id,ids) ->
+ ((Finished(DEF_overload(to_ast_id id, List.map to_ast_id ids))),envs),partial_defs
| Parse_ast.DEF_kind(k_def) ->
let kd,envs = to_ast_kdef envs k_def in
((Finished(DEF_kind(kd))),envs),partial_defs
diff --git a/src/lexer.mll b/src/lexer.mll
index 2cedcf42..99965e20 100644
--- a/src/lexer.mll
+++ b/src/lexer.mll
@@ -81,6 +81,7 @@ let kw_table =
("forall", (fun _ -> Forall));
("foreach", (fun _ -> Foreach));
("function", (fun x -> Function_));
+ ("overload", (fun _ -> Overload));
("if", (fun x -> If_));
("in", (fun x -> In));
("inc", (fun _ -> Inc));
diff --git a/src/parse_ast.ml b/src/parse_ast.ml
index c6af651b..cfb09bf5 100644
--- a/src/parse_ast.ml
+++ b/src/parse_ast.ml
@@ -488,6 +488,7 @@ def = (* Top-level definition *)
| DEF_type of type_def (* type definition *)
| DEF_fundef of fundef (* function definition *)
| DEF_val of letbind (* value definition *)
+ | DEF_overload of id * id list (* operator overload specifications *)
| DEF_spec of val_spec (* top-level type constraint *)
| DEF_default of default_typing_spec (* default kind and type assumptions *)
| DEF_scattered of scattered_def (* scattered definition *)
diff --git a/src/parser.mly b/src/parser.mly
index a544c906..6cf954ef 100644
--- a/src/parser.mly
+++ b/src/parser.mly
@@ -129,7 +129,7 @@ let make_vector_sugar order_set is_inc typ typ1 =
/*Terminals with no content*/
%token And Alias As Assert Bitzero Bitone Bits By Case Clause Const Dec Def Default Deinfix Effect EFFECT End
-%token Enumerate Else Exit Extern False Forall Foreach Function_ If_ In IN Inc Let_ Member Nat NatNum Order Cast
+%token Enumerate Else Exit Extern False Forall Foreach Overload Function_ If_ In IN Inc Let_ Member Nat NatNum Order Cast
%token Pure Rec Register Return Scattered Sizeof Struct Switch Then True TwoStarStar Type TYPE Typedef
%token Undefined Union With Val
%token Barr Depend Rreg Wreg Rmem Rmemt Wmem Wmv Wmvt Eamem Exmem Undef Unspec Nondet Escape
@@ -1274,6 +1274,8 @@ def:
{ dloc (DEF_spec($1)) }
| default_typ
{ dloc (DEF_default($1)) }
+ | Overload id Lsquare enum_body Rsquare
+ { dloc (DEF_overload($2,$4)) }
| Register typ id
{ dloc (DEF_reg_dec(DEC_aux(DEC_reg($2,$3),loc ()))) }
| Register Alias id Eq exp
diff --git a/src/type_check.ml b/src/type_check.ml
index c68e60ae..bc6d67a8 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -2528,6 +2528,7 @@ let check_def envs def =
let rec check envs (Defs defs) =
match defs with
| [] -> (Defs []),envs
+ | (DEF_overload (_, _)::defs) -> check envs (Defs defs)
| def::defs -> let (def, envs) = check_def envs def in
let (Defs defs, envs) = check envs (Defs defs) in
(Defs (def::defs)), envs
diff --git a/src/type_check_new.ml b/src/type_check_new.ml
index ea460ca8..cc232eb8 100644
--- a/src/type_check_new.ml
+++ b/src/type_check_new.ml
@@ -115,8 +115,12 @@ and map_lexp_annot_aux f = function
let string_of_id = function
| Id_aux (Id v, _) -> v
- | Id_aux (DeIid v, _) -> v
+ | Id_aux (DeIid v, _) -> "(deinfix " ^ v ^ ")"
+let deinfix = function
+ | Id_aux (Id v, l) -> Id_aux (DeIid v, l)
+ | Id_aux (DeIid v, l) -> Id_aux (DeIid v, l)
+
let string_of_kid = function
| Kid_aux (Var v, _) -> v
@@ -303,18 +307,19 @@ let unaux_typ (Typ_aux (typ, _)) = typ
let mk_typ typ = Typ_aux (typ, Parse_ast.Unknown)
let mk_typ_arg arg = Typ_arg_aux (arg, Parse_ast.Unknown)
let mk_id str = Id_aux (Id str, Parse_ast.Unknown)
+let mk_infix_id str = Id_aux (DeIid str, Parse_ast.Unknown)
+
+let mk_id_typ id = Typ_aux (Typ_id id, Parse_ast.Unknown)
-let mk_id_typ id = Typ_aux (Typ_id id, Parse_ast.Unknown)
-
let unit_typ = mk_id_typ (mk_id "unit")
let bit_typ = mk_id_typ (mk_id "bit")
let atom_typ nexp = mk_typ (Typ_app (mk_id "atom", [mk_typ_arg (Typ_arg_nexp nexp)]))
let range_typ nexp1 nexp2 = mk_typ (Typ_app (mk_id "range", [mk_typ_arg (Typ_arg_nexp nexp1); mk_typ_arg (Typ_arg_nexp nexp2)]))
let bool_typ = mk_id_typ (mk_id "bool")
let string_typ = mk_id_typ (mk_id "string")
-
+
let no_effects = Effect_aux (Effect_set [], Parse_ast.Unknown)
-
+
let rec nexp_subst sv subst (Nexp_aux (nexp, l)) = Nexp_aux (nexp_subst_aux sv subst nexp, l)
and nexp_subst_aux sv subst = function
| Nexp_id v -> Nexp_id v
@@ -795,9 +800,9 @@ end = struct
let no_casts env = { env with allow_casts = false }
let add_cast cast env =
- typ_print ("Adding cast " ^ string_of_id cast);
+ typ_print ("Adding cast " ^ string_of_id cast);
{ env with casts = cast :: env.casts }
-
+
let add_typ_synonym id synonym env =
if Bindings.mem id env.typ_synonyms
then typ_error (id_loc id) ("Type synonym " ^ string_of_id id ^ " already exists")
@@ -1240,14 +1245,6 @@ let unify l env typ1 typ2 =
let typ1, typ2 = Env.expand_synonyms env typ1, Env.expand_synonyms env typ2 in
unify_typ l typ1 typ2
-(* FIXME: we need to unify lists of typ args better, consider:
-
-unifying [|'n - 'l + 1:'n|] against [|0:31|] for example
-
-we can only unify the first argument if we do the second first
-
-*)
-
let vector_typ n m ord typ =
mk_typ (Typ_app (mk_id "vector",
[mk_typ_arg (Typ_arg_nexp n);
@@ -1334,16 +1331,13 @@ let rec instantiate_quants quants kid uvar = match quants with
| _ -> (QI_aux (QI_const nc, l)) :: instantiate_quants quants kid uvar
end
-let destructure_vec_typ l typ =
- match typ with
+let destructure_vec_typ l = function
| Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp n1, _);
Typ_arg_aux (Typ_arg_nexp n2, _);
Typ_arg_aux (Typ_arg_order o, _);
Typ_arg_aux (Typ_arg_typ vtyp, _)]
- ), _) ->
- if string_of_id id = "vector" then (n1, n2, o, vtyp)
- else typ_error l ("Expected vector type, got " ^ string_of_typ typ)
- | _ -> typ_error l ("Expected vector type, got " ^ string_of_typ typ)
+ ), _) when string_of_id id = "vector" -> (n1, n2, o, vtyp)
+ | typ -> typ_error l ("Expected vector type, got " ^ string_of_typ typ)
let typ_of (E_aux (_, (_, tannot))) = match tannot with
| Some (_, typ) -> typ
@@ -1370,7 +1364,7 @@ let irule r env exp =
let strip_exp : 'a exp -> unit exp = function exp -> map_exp_annot (fun (l, _) -> (l, ())) exp
let strip_pat : 'a pat -> unit pat = function pat -> map_pat_annot (fun (l, _) -> (l, ())) pat
-
+
let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ_aux, _) as typ) : tannot exp =
let annot_exp exp typ = E_aux (exp, (l, Some (env, typ))) in
match (exp_aux, typ_aux) with
@@ -1404,7 +1398,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
let tpat, env = bind_pat env pat (typ_of inferred_bind) in
annot_exp (E_let (LB_aux (LB_val_implicit (tpat, inferred_bind), (let_loc, None)), crule check_exp env exp typ)) typ
end
- | E_app_infix (x, op, y), _ when List.length (Env.get_overloads op env) > 0 -> check_exp env (E_aux (E_app (op, [x; y]), (l, ()))) typ
+ | E_app_infix (x, op, y), _ when List.length (Env.get_overloads (deinfix op) env) > 0 -> check_exp env (E_aux (E_app (deinfix op, [x; y]), (l, ()))) typ
| E_app (f, xs), _ when List.length (Env.get_overloads f env) > 0 ->
let rec try_overload m1 = function
| [] -> typ_error l (m1 ^ "\nNo valid overloading for " ^ string_of_exp exp)
@@ -1446,6 +1440,12 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
let inferred_exp = irule infer_exp env exp in
type_coercion env inferred_exp typ
+(* type_coercion env exp typ takes a fully annoted (i.e. already type
+ checked) expression exp, and attempts to cast (coerce) it to the
+ type typ by inserting a coercion function that transforms the
+ annotated expression into the correct type. Returns an annoted
+ expression consisting of a type coercion function applied to exp,
+ or throws a type error if the coercion cannot be performed. *)
and type_coercion env (E_aux (_, (l, _)) as annotated_exp) typ =
let strip exp_aux = strip_exp (E_aux (exp_aux, (Parse_ast.Unknown, None))) in
let rec try_casts m = function
@@ -1453,7 +1453,7 @@ and type_coercion env (E_aux (_, (l, _)) as annotated_exp) typ =
| (cast :: casts) -> begin
typ_print ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " to " ^ string_of_typ typ);
try crule check_exp (Env.no_casts env) (strip (E_app (cast, [annotated_exp]))) typ with
- | Type_error (l, m) -> try_casts m casts
+ | Type_error (_, m) -> try_casts m casts
end
in
begin
@@ -1464,7 +1464,34 @@ and type_coercion env (E_aux (_, (l, _)) as annotated_exp) typ =
| Type_error (_, m) when Env.allow_casts env -> try_casts "" (Env.get_casts env)
| Type_error (l, m) -> typ_error l ("Subtype error " ^ m)
end
-
+
+(* type_coercion_unify env exp typ attempts to coerce exp to a type
+ exp_typ in the same way as type_coercion, except it is only
+ required that exp_typ unifies with typ. Returns the annotated
+ coercion as with type_coercion and also a set of unifiers, or
+ throws a unification error *)
+and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ =
+ let strip exp_aux = strip_exp (E_aux (exp_aux, (Parse_ast.Unknown, None))) in
+ let rec try_casts m = function
+ | [] -> unify_error l ("No valid casts resulted in unification:\n" ^ m)
+ | (cast :: casts) -> begin
+ typ_print ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " for unification");
+ try
+ let annotated_exp = irule infer_exp (Env.no_casts env) (strip (E_app (cast, [annotated_exp]))) in
+ annotated_exp, unify l env typ (typ_of annotated_exp)
+ with
+ | Type_error (_, m) -> try_casts m casts
+ | Unification_error (_, m) -> try_casts m casts
+ end
+ in
+ begin
+ try
+ typ_debug "PERFORMING COERCING UNIFICATION";
+ annotated_exp, unify l env typ (typ_of annotated_exp)
+ with
+ | Unification_error (_, m) when Env.allow_casts env -> try_casts "" (Env.get_casts env)
+ end
+
and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) =
let annot_pat pat typ = P_aux (pat, (l, Some (env, typ))) in
match pat_aux with
@@ -1537,7 +1564,7 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as
let inferred_exp = irule infer_exp env exp in
let tlexp, env' = bind_lexp env lexp (typ_of inferred_exp) in
annot_assign tlexp inferred_exp, env'
-
+
and bind_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) typ =
let annot_lexp lexp typ = LEXP_aux (lexp, (l, Some (env, typ))) in
match lexp_aux with
@@ -1587,7 +1614,7 @@ and bind_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) typ =
annot_lexp (LEXP_tup tlexps) typ, env
| _ -> typ_error l "Cannot bind tuple l-expression against non tuple type"
end
- (* Not sure about this case... can the left lexp be anything other than an identifier? *)
+ (* Not sure about this case... can the left lexp be anything other than an identifier? *)
| LEXP_vector (LEXP_aux (LEXP_id v, _), exp) ->
begin
let is_immutable, vtyp = match Env.lookup_id v env with
@@ -1656,7 +1683,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
| E_cast (typ, exp) ->
let checked_exp = crule check_exp env exp typ in
annot_exp (E_cast (typ, checked_exp)) typ
- | E_app_infix (x, op, y) when List.length (Env.get_overloads op env) > 0 -> infer_exp env (E_aux (E_app (op, [x; y]), (l, ())))
+ | E_app_infix (x, op, y) when List.length (Env.get_overloads (deinfix op) env) > 0 -> infer_exp env (E_aux (E_app (deinfix op, [x; y]), (l, ())))
| E_app (f, xs) when List.length (Env.get_overloads f env) > 0 ->
let rec try_overload m1 = function
| [] -> typ_error l (m1 ^ "\nNo valid overloading for " ^ string_of_exp exp)
@@ -1725,7 +1752,7 @@ and infer_funapp l env f xs ret_ctx_typ =
let iarg = irule infer_exp env arg in
typ_debug ("INFER: " ^ string_of_exp arg ^ " type " ^ string_of_typ (typ_of iarg) ^ " NF " ^ string_of_tnf (normalize_typ env (typ_of iarg)));
try
- let unifiers = unify l env typ (typ_of iarg) in
+ let iarg, unifiers = type_coercion_unify env iarg typ in
typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers));
let utyps' = List.map (subst_unifiers unifiers) utyps in
let typs' = List.map (subst_unifiers unifiers) typs in
@@ -1767,7 +1794,7 @@ and infer_funapp l env f xs ret_ctx_typ =
let xs_reordered = List.map snd (List.sort (fun (n, _) (m, _) -> compare n m) xs_instantiated) in
annot_exp (E_app (f, xs_reordered)) typ_ret
| _ -> typ_error l (string_of_id f ^ " is not a function")
-
+
let check_letdef env (LB_aux (letbind, (l, _))) =
begin
match letbind with
@@ -1888,6 +1915,7 @@ let rec check_def env def =
| DEF_val letdef -> check_letdef env letdef
| DEF_spec vs -> check_val_spec env vs
| DEF_default default -> check_default env default
+ | DEF_overload (id, ids) -> DEF_overload (id, ids), Env.add_overloads id ids env
| DEF_reg_dec (DEC_aux (DEC_reg (typ, id), (l, _))) ->
DEF_reg_dec (DEC_aux (DEC_reg (typ, id), (l, None))), Env.add_register id typ env
| DEF_reg_dec (DEC_aux (DEC_alias (id, aspec), (l, annot))) -> cd_err ()
@@ -1913,13 +1941,3 @@ let check env defs =
let initial_env =
Env.empty
|> Env.add_typ_synonym (mk_id "atom") (fun args -> mk_typ (Typ_app (mk_id "range", args @ args)))
- |> Env.add_overloads (mk_id "^^") [mk_id "duplicate"; mk_id "duplicate_bits"]
- |> Env.add_overloads (mk_id "!=") [mk_id "neq_vec"; mk_id "neq_anything"]
- |> Env.add_overloads (mk_id "==") [mk_id "vec_eq_01_left"; mk_id "vec_eq_01_right"; mk_id "eq_anything"]
- |> Env.add_overloads (mk_id "mask") [mk_id "mask_inc"; mk_id "mask_dec"]
- |> Env.add_overloads (mk_id "~") [mk_id "not"]
- |> Env.add_overloads (mk_id "|") [mk_id "bool_or"]
- |> Env.add_overloads (mk_id "&") [mk_id "bool_and"]
- |> Env.add_overloads (mk_id "+") [mk_id "bv_add"]
- |> Env.add_overloads (mk_id "-") [mk_id "sub_exact"; mk_id "sub_range"; mk_id "sub_bv"]
- |> Env.add_overloads (mk_id "vector_access") [mk_id "vector_access_inc"; mk_id "vector_access_dec"]