summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlasdair Armstrong2017-06-23 20:54:33 +0100
committerAlasdair Armstrong2017-06-23 20:54:33 +0100
commit98a20e197ef086bd294e157f4eaf75f9f025ff69 (patch)
tree323d22795394370cb578c02446039b46c71e3a29 /src
parent454084884c7a0bbe6c00ea46349962e8d5228118 (diff)
Added support for overloaded operators
Diffstat (limited to 'src')
-rw-r--r--src/type_check_new.ml315
1 files changed, 210 insertions, 105 deletions
diff --git a/src/type_check_new.ml b/src/type_check_new.ml
index 25759d87..5d74e13e 100644
--- a/src/type_check_new.ml
+++ b/src/type_check_new.ml
@@ -45,7 +45,7 @@ open Ast
open Util
open Big_int
-let debug = ref 1
+let debug = ref 2
let depth = ref 0
let rec indent n = match n with
@@ -53,10 +53,12 @@ let rec indent n = match n with
| n -> "| " ^ indent (n - 1)
let typ_debug m = if !debug > 1 then prerr_endline (indent !depth ^ m) else ()
-
+
let typ_print m = if !debug > 0 then prerr_endline (indent !depth ^ m) else ()
-
-let typ_error l m = raise (Reporting_basic.err_typ l m)
+
+exception Type_error of l * string;;
+
+let typ_error l m = raise (Type_error (l, m))
let string_of_id = function
| Id_aux (Id v, _) -> v
@@ -64,29 +66,29 @@ let string_of_id = function
let string_of_kid = function
| Kid_aux (Var v, _) -> v
-
+
let id_loc = function
| Id_aux (_, l) -> l
let kid_loc = function
| Kid_aux (_, l) -> l
-
+
let string_of_base_effect_aux = function
- | BE_rreg -> "rreg"
- | BE_wreg -> "wreg"
- | BE_rmem -> "rmem"
- | BE_rmemt -> "rmemt"
- | BE_wmem -> "wmem"
- | BE_eamem -> "eamem"
- | BE_exmem -> "exmem"
- | BE_wmv -> "wmv"
- | BE_wmvt -> "wmvt"
- | BE_barr -> "barr"
- | BE_depend -> "depend"
- | BE_undef -> "undef"
- | BE_unspec -> "unspec"
- | BE_nondet -> "nondet"
- | BE_escape -> "escape"
+ | BE_rreg -> "rreg"
+ | BE_wreg -> "wreg"
+ | BE_rmem -> "rmem"
+ | BE_rmemt -> "rmemt"
+ | BE_wmem -> "wmem"
+ | BE_eamem -> "eamem"
+ | BE_exmem -> "exmem"
+ | BE_wmv -> "wmv"
+ | BE_wmvt -> "wmvt"
+ | BE_barr -> "barr"
+ | BE_depend -> "depend"
+ | BE_undef -> "undef"
+ | BE_unspec -> "unspec"
+ | BE_nondet -> "nondet"
+ | BE_escape -> "escape"
| BE_lset -> "lset"
| BE_lret -> "lret"
@@ -97,9 +99,9 @@ let string_of_base_kind_aux = function
| BK_effect -> "Effect"
let string_of_base_kind (BK_aux (bk, _)) = string_of_base_kind_aux bk
-
+
let string_of_kind (K_aux (K_kind bks, _)) = string_of_list " -> " string_of_base_kind bks
-
+
let string_of_base_effect = function
| BE_aux (beff, _) -> string_of_base_effect_aux beff
@@ -108,13 +110,13 @@ let string_of_effect = function
typ_debug "kid effect occured"; string_of_kid kid
| Effect_aux (Effect_set [], _) -> "pure"
| Effect_aux (Effect_set beffs, _) ->
- "{" ^ string_of_list ", " string_of_base_effect beffs ^ "}"
+ "{" ^ string_of_list ", " string_of_base_effect beffs ^ "}"
let string_of_order = function
| Ord_aux (Ord_var kid, _) -> string_of_kid kid
| Ord_aux (Ord_inc, _) -> "inc"
| Ord_aux (Ord_dec, _) -> "dec"
-
+
let rec string_of_nexp = function
| Nexp_aux (nexp, _) -> string_of_nexp_aux nexp
and string_of_nexp_aux = function
@@ -126,7 +128,7 @@ and string_of_nexp_aux = function
| Nexp_minus (n1, n2) -> "(" ^ string_of_nexp n1 ^ " - " ^ string_of_nexp n2 ^ ")"
| Nexp_exp n -> "2 ^ " ^ string_of_nexp n
| Nexp_neg n -> "- " ^ string_of_nexp n
-
+
let rec string_of_typ = function
| Typ_aux (typ, l) -> string_of_typ_aux typ
and string_of_typ_aux = function
@@ -140,7 +142,7 @@ and string_of_typ_aux = function
and string_of_typ_arg = function
| Typ_arg_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg
and string_of_typ_arg_aux = function
- | Typ_arg_nexp n -> string_of_nexp n
+ | Typ_arg_nexp n -> string_of_nexp n
| Typ_arg_typ typ -> string_of_typ typ
| Typ_arg_order o -> string_of_order o
| Typ_arg_effect eff -> string_of_effect eff
@@ -151,31 +153,31 @@ let string_of_n_constraint = function
| NC_aux (NC_bounded_le (n1, n2), _) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2
| NC_aux (NC_nat_set_bounded (kid, ns), _) ->
string_of_kid kid ^ " IN {" ^ string_of_list ", " string_of_int ns ^ "}"
-
+
let string_of_quant_item_aux = function
| QI_id (KOpt_aux (KOpt_none kid, _)) -> string_of_kid kid
| QI_id (KOpt_aux (KOpt_kind (k, kid), _)) -> string_of_kind k ^ " " ^ string_of_kid kid
| QI_const constr -> string_of_n_constraint constr
-
+
let string_of_quant_item = function
| QI_aux (qi, _) -> string_of_quant_item_aux qi
-
+
let string_of_typquant_aux = function
| TypQ_tq quants -> "forall " ^ string_of_list ", " string_of_quant_item quants
| TypQ_no_forall -> ""
-
+
let string_of_typquant = function
| TypQ_aux (quant, _) -> string_of_typquant_aux quant
let string_of_typschm (TypSchm_aux (TypSchm_ts (quant, typ), _)) =
string_of_typquant quant ^ ". " ^ string_of_typ typ
-
+
let string_of_bind (typquant, typ) = string_of_typquant typquant ^ ". " ^ string_of_typ typ
let string_of_lit (L_aux (lit, _)) =
match lit with
| L_unit -> "()"
- | L_zero -> "bitzero" (* FIXME: Check *)
+ | L_zero -> "bitzero"
| L_one -> "bitone"
| L_true -> "true"
| L_false -> "false"
@@ -184,15 +186,16 @@ let string_of_lit (L_aux (lit, _)) =
| L_bin n -> "0b" ^ n
| L_undef -> "undefined"
| L_string str -> "\"" ^ str ^ "\""
-
+
let rec string_of_exp (E_aux (exp, _)) =
match exp with
| E_block exps -> "{ " ^ string_of_list "; " string_of_exp exps ^ " }"
| E_id v -> string_of_id v
- | E_sizeof nexp -> "sizeof " ^ string_of_nexp nexp
+ | E_sizeof nexp -> "sizeof " ^ string_of_nexp nexp
| E_lit lit -> string_of_lit lit
| E_return exp -> "return " ^ string_of_exp exp
| E_app (f, args) -> string_of_id f ^ "(" ^ string_of_list ", " string_of_exp args ^ ")"
+ | E_app_infix (x, op, y) -> "(" ^ string_of_exp x ^ " " ^ string_of_id op ^ " " ^ string_of_exp y ^ ")"
| E_tuple exps -> "(" ^ string_of_list ", " string_of_exp exps ^ ")"
| E_case (exp, cases) ->
"switch " ^ string_of_exp exp ^ " { case " ^ string_of_list " case " string_of_pexp cases ^ "}"
@@ -200,6 +203,10 @@ let rec string_of_exp (E_aux (exp, _)) =
| E_assign (lexp, bind) -> string_of_lexp lexp ^ " := " ^ string_of_exp bind
| E_cast (typ, exp) -> "(" ^ string_of_typ typ ^ ") " ^ string_of_exp exp
| E_vector vec -> "[" ^ string_of_list ", " string_of_exp vec ^ "]"
+ | E_vector_access (v, n) -> string_of_exp v ^ "[" ^ string_of_exp n ^ "]"
+ | E_vector_subrange (v, n1, n2) -> string_of_exp v ^ "[" ^ string_of_exp n1 ^ " .. " ^ string_of_exp n2 ^ "]"
+ | E_if (cond, then_branch, else_branch) ->
+ "if " ^ string_of_exp cond ^ " then " ^ string_of_exp then_branch ^ " else " ^ string_of_exp else_branch
| _ -> "INTERNAL"
and string_of_pexp (Pat_aux (Pat_exp (pat, exp), _)) = string_of_pat pat ^ " -> " ^ string_of_exp exp
and string_of_pat (P_aux (pat, l)) =
@@ -226,13 +233,13 @@ module Kid = struct
type t = kid
let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2)
end
-
+
let unaux_nexp (Nexp_aux (nexp, _)) = nexp
let unaux_order (Ord_aux (ord, _)) = ord
let unaux_typ (Typ_aux (typ, _)) = typ
-
+
let rec nexp_subst sv subst (Nexp_aux (nexp, l)) = Nexp_aux (nexp_subst_aux sv subst nexp, l)
and nexp_subst_aux sv subst = function
| Nexp_id v -> Nexp_id v
@@ -253,7 +260,7 @@ and nc_subst_nexp_aux l sv subst = function
if compare kid sv = 0
then typ_error l ("Cannot substitute " ^ string_of_kid sv ^ " into set constraint " ^ string_of_n_constraint (NC_aux (set_nc, l)))
else set_nc
-
+
let rec typ_subst_nexp sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_nexp_aux sv subst typ, l)
and typ_subst_nexp_aux sv subst = function
| Typ_wild -> Typ_wild
@@ -283,14 +290,14 @@ and typ_subst_arg_nexp_aux sv subst = function
| Typ_arg_typ typ -> Typ_arg_typ (typ_subst_typ sv subst typ)
| Typ_arg_order ord -> Typ_arg_order ord
| Typ_arg_effect eff -> Typ_arg_effect eff
-
+
let order_subst_aux sv subst = function
| Ord_var kid -> if Kid.compare kid sv = 0 then subst else Ord_var kid
| Ord_inc -> Ord_inc
| Ord_dec -> Ord_dec
-
-let order_subst sv subst (Ord_aux (ord, l)) = Ord_aux (order_subst_aux sv subst ord, l)
-
+
+let order_subst sv subst (Ord_aux (ord, l)) = Ord_aux (order_subst_aux sv subst ord, l)
+
let rec typ_subst_order sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_order_aux sv subst typ, l)
and typ_subst_order_aux sv subst = function
| Typ_wild -> Typ_wild
@@ -327,15 +334,54 @@ let quant_item_subst_kid_aux sv subst = function
| QI_id (KOpt_aux (KOpt_kind (k, kid), l)) as qid ->
if Kid.compare kid sv = 0 then QI_id (KOpt_aux (KOpt_kind (k, subst), l)) else qid
| QI_const nc -> QI_const (nc_subst_nexp sv (Nexp_var subst) nc)
-
-let quant_item_subst_kid sv subst (QI_aux (quant, l)) = QI_aux (quant_item_subst_kid_aux sv subst quant, l)
-
+
+let rec pow2 = function
+ | 0 -> 1
+ | n -> 2 * pow2 (n - 1)
+
+let rec nexp_simp (Nexp_aux (nexp, l)) = Nexp_aux (nexp_simp_aux nexp, l)
+and nexp_simp_aux = function
+ | Nexp_sum (n1, n2) ->
+ begin
+ let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in
+ let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in
+ match n1_simp, n2_simp with
+ | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (c1 + c2)
+ | _, _ -> Nexp_sum (n1, n2)
+ end
+ | Nexp_times (n1, n2) ->
+ begin
+ let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in
+ let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in
+ match n1_simp, n2_simp with
+ | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (c1 * c2)
+ | _, _ -> Nexp_times (n1, n2)
+ end
+ | Nexp_minus (n1, n2) ->
+ begin
+ let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in
+ let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in
+ match n1_simp, n2_simp with
+ | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (c1 - c2)
+ | _, _ -> Nexp_minus (n1, n2)
+ end
+ | Nexp_exp n ->
+ begin
+ let (Nexp_aux (n_simp, _) as n) = nexp_simp n in
+ match n_simp with
+ | Nexp_constant c -> Nexp_constant (pow2 c)
+ | _ -> Nexp_exp n
+ end
+ | nexp -> nexp
+
+let quant_item_subst_kid sv subst (QI_aux (quant, l)) = QI_aux (quant_item_subst_kid_aux sv subst quant, l)
+
let typquant_subst_kid_aux sv subst = function
| TypQ_tq quants -> TypQ_tq (List.map (quant_item_subst_kid sv subst) quants)
| TypQ_no_forall -> TypQ_no_forall
-
+
let typquant_subst_kid sv subst (TypQ_aux (typq, l)) = TypQ_aux (typquant_subst_kid_aux sv subst typq, l)
-
+
module Id = struct
type t = id
let compare id1 id2 =
@@ -354,7 +400,7 @@ module KidSet = Set.Make(Kid)
type mut = Immutable | Mutable
type lvar = Register of typ | Local of mut * typ | Unbound
-
+
module Env : sig
type t
val get_val_spec : id -> t -> typquant * typ
@@ -373,13 +419,15 @@ module Env : sig
val add_ret_typ : typ -> t -> t
val add_typ_synonym : id -> (typ_arg list -> typ) -> t -> t
val get_typ_synonym : id -> t -> typ_arg list -> typ
+ val add_overloads : id -> id list -> t -> t
+ val get_overloads : id -> t -> id list
val get_default_order : t -> order
val set_default_order_inc : t -> t
val set_default_order_dec : t -> t
val lookup_id : id -> t -> lvar
val fresh_kid : t -> kid
val expand_synonyms : t -> typ -> typ
- val empty : t
+ val empty : t
end = struct
type t =
{ top_val_specs : (typquant * typ) Bindings.t;
@@ -388,11 +436,12 @@ end = struct
regtyps : (int * int * (index_range * id) list) Bindings.t;
typ_vars : base_kind_aux KBindings.t;
typ_synonyms : (typ_arg list -> typ) Bindings.t;
+ overloads : (id list) Bindings.t;
constraints : n_constraint list;
default_order : order option;
ret_typ : typ option
}
-
+
let empty =
{ top_val_specs = Bindings.empty;
locals = Bindings.empty;
@@ -400,6 +449,7 @@ end = struct
regtyps = Bindings.empty;
typ_vars = KBindings.empty;
typ_synonyms = Bindings.empty;
+ overloads = Bindings.empty;
constraints = [];
default_order = None;
ret_typ = None;
@@ -414,7 +464,7 @@ end = struct
let freshen_kid env kid (typq, typ) =
let fresh = fresh_kid env in
(typquant_subst_kid kid fresh typq, typ_subst_kid kid fresh typ)
-
+
let get_val_spec id env =
try
let bind = Bindings.find id env.top_val_specs in
@@ -432,7 +482,7 @@ end = struct
begin
typ_debug ("Adding val spec binding " ^ string_of_id id ^ " :: " ^ string_of_bind bind);
{ env with top_val_specs = Bindings.add id bind env.top_val_specs }
- end
+ end
let get_local id env =
try Bindings.find id env.locals with
@@ -446,11 +496,11 @@ end = struct
| Immutable -> false
with
| Not_found -> typ_error (id_loc id) ("No local binding found for " ^ string_of_id id)
-
+
let string_of_mtyp (mut, typ) = match mut with
| Immutable -> string_of_typ typ
| Mutable -> "ref<" ^ string_of_typ typ ^ ">"
-
+
let add_local id mtyp env =
begin
typ_print ("Adding local binding " ^ string_of_id id ^ " :: " ^ string_of_mtyp mtyp);
@@ -461,6 +511,14 @@ end = struct
try Bindings.find id env.registers with
| Not_found -> typ_error (id_loc id) ("No register binding found for " ^ string_of_id id)
+ let get_overloads id env =
+ try Bindings.find id env.overloads with
+ | Not_found -> []
+
+ let add_overloads id ids env =
+ typ_print ("Adding overloads for " ^ string_of_id id ^ " [" ^ string_of_list ", " string_of_id ids ^ "]");
+ { env with overloads = Bindings.add id ids env.overloads }
+
let check_index_range cmp f t (BF_aux (ir, l)) =
match ir with
| BF_single n ->
@@ -483,7 +541,7 @@ end = struct
let base' = check_index_range cmp base top range in
check_index_ranges (IdSet.add id ids) cmp base' top ranges
end
-
+
let add_register id typ env =
if Bindings.mem id env.registers
then typ_error (id_loc id) ("Register " ^ string_of_id id ^ " is already bound")
@@ -504,7 +562,7 @@ end = struct
else check_index_ranges IdSet.empty (fun x y -> x < y) (base - 1) (top + 1) ranges;
{ env with regtyps = Bindings.add id (base, top, ranges) env.regtyps }
end
-
+
let lookup_id id env =
try
let (mut, typ) = Bindings.find id env.locals in
@@ -515,7 +573,7 @@ end = struct
try Register (Bindings.find id env.registers) with
| Not_found -> Unbound
end
-
+
let get_typ_var kid env =
try KBindings.find kid env.typ_vars with
| Not_found -> typ_error (kid_loc kid) ("No kind identifier " ^ string_of_kid kid)
@@ -553,9 +611,9 @@ end = struct
| NC_bounded_ge (n1, n2) -> wf_nexp env n1; wf_nexp env n2
| NC_bounded_le (n1, n2) -> wf_nexp env n1; wf_nexp env n2
| NC_nat_set_bounded (kid, ints) -> () (* MAYBE: We could demand that ints are all unique here *)
-
+
let get_constraints env = env.constraints
-
+
let add_constraint (NC_aux (_, l) as constr) env =
wf_constraint env constr;
begin
@@ -616,11 +674,11 @@ end = struct
let set_default_order_inc = set_default_order Ord_inc
let set_default_order_dec = set_default_order Ord_dec
-
+
end
type tannot = (Env.t * typ) option
-
+
let add_typquant (quant : typquant) (env : Env.t) : Env.t =
let rec add_quant_item env = function
| QI_aux (qi, _) -> add_quant_item_aux env qi
@@ -637,16 +695,16 @@ let add_typquant (quant : typquant) (env : Env.t) : Env.t =
let mk_typ typ = Typ_aux (typ, Parse_ast.Unknown)
let mk_typ_arg arg = Typ_arg_aux (arg, Parse_ast.Unknown)
let mk_id str = Id_aux (Id str, Parse_ast.Unknown)
-
+
let nconstant c = Nexp_aux (Nexp_constant c, Parse_ast.Unknown)
let nminus n1 n2 = Nexp_aux (Nexp_minus (n1, n2), Parse_ast.Unknown)
-let nsum n1 n2 = Nexp_aux (Nexp_sum (n1, n2), Parse_ast.Unknown)
+let nsum n1 n2 = Nexp_aux (Nexp_sum (n1, n2), Parse_ast.Unknown)
let nvar kid = Nexp_aux (Nexp_var kid, Parse_ast.Unknown)
type index_sort =
| IS_int
| IS_prop of kid * (nexp * nexp) list
-
+
type tnf =
| Tnf_wild
| Tnf_id of id
@@ -670,11 +728,11 @@ let rec string_of_tnf = function
| Tnf_index_sort (IS_prop (kid, props)) ->
"{" ^ string_of_kid kid ^ " | " ^ string_of_list " & " (fun (n1, n2) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2) props ^ "}"
and string_of_tnf_arg = function
- | Tnf_arg_nexp n -> string_of_nexp n
+ | Tnf_arg_nexp n -> string_of_nexp n
| Tnf_arg_typ tnf -> string_of_tnf tnf
| Tnf_arg_order o -> string_of_order o
| Tnf_arg_effect eff -> string_of_effect eff
-
+
let rec normalize_typ env (Typ_aux (typ, l)) =
match typ with
| Typ_wild -> Tnf_wild
@@ -733,7 +791,7 @@ let order_eq (Ord_aux (ord_aux1, _)) (Ord_aux (ord_aux2, _)) =
| Ord_dec, Ord_dec -> true
| Ord_var kid1, Ord_var kid2 -> Kid.compare kid1 kid2 = 0
| _, _ -> false
-
+
let rec props_subst sv subst props =
match props with
| [] -> []
@@ -749,7 +807,7 @@ let rec nexp_constraint var_of (Nexp_aux (nexp, l)) =
| Nexp_minus (nexp1, nexp2) -> Constraint.sub (nexp_constraint var_of nexp1) (nexp_constraint var_of nexp2)
| Nexp_exp nexp -> Constraint.pow2 (nexp_constraint var_of nexp)
| Nexp_neg nexp -> Constraint.sub (Constraint.constant (big_int_of_int 0)) (nexp_constraint var_of nexp)
-
+
let nc_constraint var_of (NC_aux (nc, _)) =
match nc with
| NC_fixed (nexp1, nexp2) -> Constraint.eq (nexp_constraint var_of nexp1) (nexp_constraint var_of nexp2)
@@ -786,7 +844,7 @@ let prove env nc =
| Constraint.Unsat _ -> typ_debug "unsat"; true
| Constraint.Unknown [] -> typ_debug "sat"; false
| Constraint.Unknown _ -> typ_debug "unknown"; false
-
+
let rec subtyp_tnf env tnf1 tnf2 =
typ_print ("Subset " ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_tnf tnf1 ^ " " ^ string_of_tnf tnf2);
let module Bindings = Map.Make(Kid) in
@@ -844,7 +902,7 @@ and tnf_args_eq env arg1 arg2 =
| Tnf_arg_order ord1, Tnf_arg_order ord2 -> order_eq ord1 ord2
| Tnf_arg_typ tnf1, Tnf_arg_typ tnf2 -> subtyp_tnf env tnf1 tnf2 && subtyp_tnf env tnf2 tnf1
| _, _ -> assert false
-
+
let subtyp l env typ1 typ2 =
if subtyp_tnf env (normalize_typ env typ1) (normalize_typ env typ2)
then ()
@@ -878,7 +936,7 @@ let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) =
exception Unification_error of l * string;;
let unify_error l str = raise (Unification_error (l, str))
-
+
let rec unify_nexps l (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) =
typ_debug ("UNIFYING NEXPS " ^ string_of_nexp nexp1 ^ " AND " ^ string_of_nexp nexp2);
match nexp_aux1 with
@@ -902,7 +960,7 @@ let rec unify_nexps l (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _
if KidSet.is_empty (nexp_frees n1b)
then unify_nexps l n1a (nsum nexp2 n1b)
else unify_error l ("Cannot unify minus Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
-
+
| _ -> unify_error l ("Cannot unify Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
type uvar =
@@ -924,7 +982,7 @@ let unify_order l (Ord_aux (ord_aux1, _) as ord1) (Ord_aux (ord_aux2, _) as ord2
| Ord_inc, Ord_inc -> KBindings.empty
| Ord_dec, Ord_dec -> KBindings.empty
| _, _ -> unify_error l (string_of_order ord1 ^ " cannot be unified with " ^ string_of_order ord2)
-
+
let unify l env typ1 typ2 =
let merge_unifiers l kid uvar1 uvar2 =
match uvar1, uvar2 with
@@ -960,12 +1018,12 @@ let unify l env typ1 typ2 =
^ " functions applied to different number of arguments")
else unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2)
end
- | _, _ -> unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2)
+ | _, _ -> unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2)
and unify_typ_args l (Typ_arg_aux (typ_arg_aux1, _) as typ_arg1) (Typ_arg_aux (typ_arg_aux2, _) as typ_arg2) =
match typ_arg_aux1, typ_arg_aux2 with
| Typ_arg_nexp n1, Typ_arg_nexp n2 ->
begin
- match unify_nexps l n1 n2 with
+ match unify_nexps l (nexp_simp n1) (nexp_simp n2) with
| Some (kid, unifier) -> KBindings.singleton kid (U_nexp unifier)
| None -> KBindings.empty
end
@@ -977,6 +1035,14 @@ let unify l env typ1 typ2 =
let typ1, typ2 = Env.expand_synonyms env typ1, Env.expand_synonyms env typ2 in
unify_typ l typ1 typ2
+(* FIXME: we need to unify lists of typ args better, consider:
+
+unifying [|'n - 'l + 1:'n|] against [|0:31|] for example
+
+we can only unify the first argument if we do the second first
+
+*)
+
let infer_lit env (L_aux (lit_aux, l) as lit) =
match lit_aux with
| L_unit -> mk_typ (Typ_id (mk_id "unit"))
@@ -1019,7 +1085,7 @@ let infer_lit env (L_aux (lit_aux, l) as lit) =
mk_typ_arg (Typ_arg_typ (mk_typ (Typ_id (mk_id "bit"))))]))
end
| L_undef -> typ_error l "Cannot infer the type of undefined"
-
+
let rec bind_pat env (P_aux (pat_aux, (l, _)) as pat) (Typ_aux (typ_aux, _) as typ) =
let annot_pat pat typ = P_aux (pat, (l, Some (env, typ))) in
match pat_aux with
@@ -1102,11 +1168,11 @@ and bind_lexp env (LEXP_aux (lexp_aux, (l, _))) typ =
| _ -> typ_error l "Cannot bind tuple l-expression against non tuple type"
end
| _ -> typ_error l ("Unhandled l-expression")
-
+
let quant_items : typquant -> quant_item list = function
| TypQ_aux (TypQ_tq qis, _) -> qis
| TypQ_aux (TypQ_no_forall, _) -> []
-
+
let is_nat_kid kid = function
| KOpt_aux (KOpt_none kid', _) -> Kid.compare kid kid' = 0
| KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_nat, _)], _), kid'), _) -> Kid.compare kid kid' = 0
@@ -1119,7 +1185,7 @@ let is_order_kid kid = function
let is_typ_kid kid = function
| KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_type, _)], _), kid'), _) -> Kid.compare kid kid' = 0
| _ -> false
-
+
let rec instantiate_quants quants kid uvar = match quants with
| [] -> []
| ((QI_aux (QI_id kinded_id, _) as quant) :: quants) ->
@@ -1136,7 +1202,7 @@ let rec instantiate_quants quants kid uvar = match quants with
| U_typ typ ->
if is_typ_kid kid kinded_id
then instantiate_quants quants kid uvar
- else quant :: instantiate_quants quants kid uvar
+ else quant :: instantiate_quants quants kid uvar
| _ -> typ_error Parse_ast.Unknown "Cannot instantiate quantifier"
end
| ((QI_aux (QI_const nc, l)) :: quants) ->
@@ -1167,25 +1233,31 @@ let destructure_vec_typ l typ =
if string_of_id id = "vector" then (n1, n2, o, vtyp)
else typ_error l ("Expected vector type, got " ^ string_of_typ typ)
| _ -> typ_error l ("Expected vector type, got " ^ string_of_typ typ)
-
+
let typ_of (E_aux (_, (_, tannot))) = match tannot with
| Some (_, typ) -> typ
| None -> assert false
let crule r env exp typ =
incr depth;
- typ_print ("Check " ^ string_of_exp exp ^ " <= " ^ string_of_typ typ);
- let checked_exp = r env exp typ in
- decr depth; checked_exp
-
+ typ_print ("Check " ^ string_of_exp exp ^ " <= " ^ string_of_typ typ);
+ try
+ let checked_exp = r env exp typ in
+ decr depth; checked_exp
+ with
+ | Type_error (l, m) -> decr depth; typ_error l m
+
let irule r env exp =
incr depth;
- let inferred_exp = r env exp in
- typ_print ("Infer " ^ string_of_exp exp ^ " => " ^ string_of_typ (typ_of inferred_exp));
- decr depth;
- inferred_exp
-
-let rec check_exp env (E_aux (exp_aux, (l, _)) as exp : 'a exp) (Typ_aux (typ_aux, _) as typ) : tannot exp =
+ try
+ let inferred_exp = r env exp in
+ typ_print ("Infer " ^ string_of_exp exp ^ " => " ^ string_of_typ (typ_of inferred_exp));
+ decr depth;
+ inferred_exp
+ with
+ | Type_error (l, m) -> decr depth; typ_error l m
+
+let rec check_exp env (E_aux (exp_aux, (l, annot)) as exp : 'a exp) (Typ_aux (typ_aux, _) as typ) : tannot exp =
let annot_exp exp typ = E_aux (exp, (l, Some (env, typ))) in
match (exp_aux, typ_aux) with
| E_block exps, _ ->
@@ -1218,6 +1290,17 @@ let rec check_exp env (E_aux (exp_aux, (l, _)) as exp : 'a exp) (Typ_aux (typ_au
let tpat, env = bind_pat env pat (typ_of inferred_bind) in
annot_exp (E_let (LB_aux (LB_val_implicit (tpat, inferred_bind), (let_loc, None)), crule check_exp env exp typ)) typ
end
+ | E_app_infix (x, op, y), _ when List.length (Env.get_overloads op env) > 0 ->
+ let rec try_overload ops =
+ match ops with
+ | [] -> typ_error l ("No valid overloading for " ^ string_of_exp exp)
+ | (op :: ops) -> begin
+ typ_print ("Overload: " ^ string_of_id op);
+ try crule check_exp env (E_aux (E_app (op, [x; y]), (l, annot))) typ with
+ | Type_error _ -> try_overload ops
+ end
+ in
+ try_overload (Env.get_overloads op env)
| E_app (f, xs), _ ->
let inferred_exp = infer_funapp l env f xs (Some typ)
in (subtyp l env (typ_of inferred_exp) typ; inferred_exp)
@@ -1241,7 +1324,11 @@ let rec check_exp env (E_aux (exp_aux, (l, _)) as exp : 'a exp) (Typ_aux (typ_au
end
| E_lit (L_aux (L_undef, _) as lit), _ ->
annot_exp (E_lit lit) typ
- | _, _ ->
+ (* This rule allows registers of type t to be passed by name with type register<t>*)
+ | E_id reg, Typ_app (id, [Typ_arg_aux (Typ_arg_typ typ, _)]) when string_of_id id = "register" ->
+ let rtyp = Env.get_register reg env in
+ subtyp l env rtyp typ; annot_exp (E_id reg) typ (* CHECK: is this subtyp the correct way around? *)
+ | _, _ ->
let inferred_exp = irule infer_exp env exp
in (subtyp l env (typ_of inferred_exp) typ; inferred_exp)
@@ -1249,8 +1336,8 @@ and bind_assignment env lexp (E_aux (_, (l, _)) as exp) =
let inferred_exp = irule infer_exp env exp in
let tlexp, env' = bind_lexp env lexp (typ_of inferred_exp) in
E_aux (E_assign (tlexp, inferred_exp), (l, Some (env, mk_typ (Typ_id (mk_id "unit"))))), env'
-
-and infer_exp env (E_aux (exp_aux, (l, _)) as exp : 'a exp) : tannot exp =
+
+and infer_exp env (E_aux (exp_aux, (l, annot)) as exp : 'a exp) : tannot exp =
let annot_exp exp typ = E_aux (exp, (l, Some (env, typ))) in
match exp_aux with
| E_id v ->
@@ -1276,6 +1363,17 @@ and infer_exp env (E_aux (exp_aux, (l, _)) as exp : 'a exp) : tannot exp =
| E_cast (typ, exp) ->
let checked_exp = crule check_exp env exp typ in
annot_exp (E_cast (typ, checked_exp)) typ
+ | E_app_infix (x, op, y) when List.length (Env.get_overloads op env) > 0 ->
+ let rec try_overload ops =
+ match ops with
+ | [] -> typ_error l ("No valid overloading for " ^ string_of_exp exp)
+ | (op :: ops) -> begin
+ typ_print ("Overload: " ^ string_of_id op);
+ try irule infer_exp env (E_aux (E_app (op, [x; y]), (l, annot))) with
+ | Type_error _ -> try_overload ops
+ end
+ in
+ try_overload (Env.get_overloads op env)
| E_app (f, xs) -> infer_funapp l env f xs None
| E_vector_access (v, n) -> infer_funapp l env (mk_id "vector_access") [v; n] None
| E_vector_append (v1, v2) -> infer_funapp l env (mk_id "vector_append") [v1; v2] None
@@ -1300,7 +1398,7 @@ and infer_exp env (E_aux (exp_aux, (l, _)) as exp : 'a exp) : tannot exp =
in
annot_exp (E_vector (inferred_item :: checked_items)) vec_typ
| _ -> typ_error l "Unimplemented"
-
+
and infer_funapp l env f xs ret_ctx_typ =
let annot_exp exp typ = E_aux (exp, (l, Some (env, typ))) in
let solve_quant = function
@@ -1383,7 +1481,7 @@ let check_letdef env (LB_aux (letbind, (l, _))) =
let tpat, env = bind_pat env pat (typ_of inferred_bind) in
DEF_val (LB_aux (LB_val_implicit (tpat, inferred_bind), (l, None))), env
end
-
+
let check_funcl env (FCL_aux (FCL_Funcl (id, pat, exp), (l, _))) typ =
match typ with
| Typ_aux (Typ_fn (typ_arg, typ_ret, eff), _) ->
@@ -1394,19 +1492,19 @@ let check_funcl env (FCL_aux (FCL_Funcl (id, pat, exp), (l, _))) typ =
FCL_aux (FCL_Funcl (id, typed_pat, exp), (l, Some (env, typ)))
end
| _ -> typ_error l ("Function clause must have function type: " ^ string_of_typ typ ^ " is not a function type")
-
+
let check_fundef env (FD_aux (FD_function (recopt, tannotopt, effectopt, funcls), (l, _)) as fd_aux : 'a fundef) : tannot def =
let (Typ_annot_opt_aux (Typ_annot_opt_some (annot_quant, annot_typ1), _)) = tannotopt in
let id =
match (List.fold_right
(fun (FCL_aux (FCL_Funcl (id, _, _), _)) id' ->
match id' with
- | Some id' -> if string_of_id id' = string_of_id id then Some id'
- else typ_error l ("Function declaration expects all definitions to have the same name, "
+ | Some id' -> if string_of_id id' = string_of_id id then Some id'
+ else typ_error l ("Function declaration expects all definitions to have the same name, "
^ string_of_id id ^ " differs from other definitions of " ^ string_of_id id')
| None -> Some id) funcls None)
with
- | Some id -> id
+ | Some id -> id
| None -> typ_error l "funcl list is empty"
in
typ_print ("\nChecking function " ^ string_of_id id);
@@ -1422,7 +1520,7 @@ let check_fundef env (FD_aux (FD_function (recopt, tannotopt, effectopt, funcls)
the difference is irrelevant for the typechecker. *)
let check_val_spec env (VS_aux (vs, (l, _))) =
let (id, quants, typ) = match vs with
- | VS_val_spec (TypSchm_aux (TypSchm_ts (quants, typ), _), id) -> (id, quants, typ)
+ | VS_val_spec (TypSchm_aux (TypSchm_ts (quants, typ), _), id) -> (id, quants, typ)
| VS_extern_no_rename (TypSchm_aux (TypSchm_ts (quants, typ), _), id) -> (id, quants, typ)
| VS_extern_spec (TypSchm_aux (TypSchm_ts (quants, typ), _), id, _) -> (id, quants, typ) in
DEF_spec (VS_aux (vs, (l, None))), Env.add_val_spec id (quants, typ) env
@@ -1440,7 +1538,7 @@ let check_register env id base top ranges =
match base, top with
| Nexp_aux (Nexp_constant basec, _), Nexp_aux (Nexp_constant topc, _) -> Env.add_regtyp id basec topc ranges env
| _, _ -> typ_error (id_loc id) "Num expressions in register type declaration do not evaluate to constants"
-
+
let check_typedef env (TD_aux (tdef, (l, _))) =
let td_err () = raise (Reporting_basic.err_unreachable Parse_ast.Unknown "Unimplemented Typedef") in
match tdef with
@@ -1450,7 +1548,7 @@ let check_typedef env (TD_aux (tdef, (l, _))) =
| TD_variant(id, nmscm, typq, arms, _) -> td_err ()
| TD_enum(id, nmscm, ids, _) -> td_err ()
| TD_register(id, base, top, ranges) -> DEF_type (TD_aux (tdef, (l, None))), check_register env id base top ranges
-
+
let rec check_def env def =
let cd_err () = raise (Reporting_basic.err_unreachable Parse_ast.Unknown "Unimplemented Case") in
match def with
@@ -1470,14 +1568,21 @@ let rec check_def env def =
let def, env = check_def env def
in DEF_comm (DC_comm_struct def), env
-let rec check env (Defs defs) =
+let rec check' env (Defs defs) =
match defs with
| [] -> (Defs []), env
| def :: defs ->
let (def, env) = check_def env def in
- let (Defs defs, env) = check env (Defs defs) in
+ let (Defs defs, env) = check' env (Defs defs) in
(Defs (def::defs)), env
+let check env defs =
+ try check' env defs with
+ | Type_error (l, m) -> raise (Reporting_basic.err_typ l m)
+
let initial_env =
Env.empty
|> Env.add_typ_synonym (mk_id "atom") (fun args -> mk_typ (Typ_app (mk_id "range", args @ args)))
+ |> Env.add_overloads (mk_id "^^") [mk_id "duplicate"; mk_id "duplicate_bits"]
+ |> Env.add_overloads (mk_id "!=") [mk_id "neq_vec"]
+ |> Env.add_overloads (mk_id "==") [mk_id "vec_eq_01_left"; mk_id "vec_eq_01_right"]