diff options
| author | Alasdair Armstrong | 2017-06-24 19:06:22 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2017-06-24 19:06:22 +0100 |
| commit | 7691516974eaaa7c41179818e1d76d073c72cc18 (patch) | |
| tree | d64c7305733b3729d4d6fe3eac37f09297c22efc | |
| parent | 98a20e197ef086bd294e157f4eaf75f9f025ff69 (diff) | |
Added implicit casting
Added support for implicit casting to the bi-directional type
checker. The casts can be any user-specified function, and in princple
don't have to be hardcoded. This allows us to typecheck definitions such as
function bit[64] rGPR idx = {
if idx == 0 then 0 else GPR[idx]
}
in the MIPS spec, which involves lots of casting from integers to
bitvectors, as well as casting from a named register to it's value
(implicit dereferencing).
| -rw-r--r-- | src/type_check_new.ml | 243 | ||||
| -rw-r--r-- | test/typecheck/pass/vector_access.sail | 3 | ||||
| -rw-r--r-- | test/typecheck/pass/vector_append.sail | 3 |
3 files changed, 173 insertions, 76 deletions
diff --git a/src/type_check_new.ml b/src/type_check_new.ml index 5d74e13e..98670108 100644 --- a/src/type_check_new.ml +++ b/src/type_check_new.ml @@ -45,7 +45,7 @@ open Ast open Util open Big_int -let debug = ref 2 +let debug = ref 1 let depth = ref 0 let rec indent n = match n with @@ -57,9 +57,53 @@ let typ_debug m = if !debug > 1 then prerr_endline (indent !depth ^ m) else () let typ_print m = if !debug > 0 then prerr_endline (indent !depth ^ m) else () exception Type_error of l * string;; - + let typ_error l m = raise (Type_error (l, m)) +let rec map_exp_annot f (E_aux (exp, annot)) = E_aux (map_exp_annot_aux f exp, f annot) +and map_exp_annot_aux f = function + | E_block xs -> E_block (List.map (map_exp_annot f) xs) + | E_nondet xs -> E_nondet (List.map (map_exp_annot f) xs) + | E_id id -> E_id id + | E_lit lit -> E_lit lit + | E_cast (typ, exp) -> E_cast (typ, map_exp_annot f exp) + | E_app (id, xs) -> E_app (id, List.map (map_exp_annot f) xs) + | E_app_infix (x, op, y) -> E_app_infix (map_exp_annot f x, op, map_exp_annot f y) + | E_tuple xs -> E_tuple (List.map (map_exp_annot f) xs) + | E_if (cond, t, e) -> E_if (map_exp_annot f cond, map_exp_annot f t, map_exp_annot f e) + | E_for (v, e1, e2, e3, o, e4) -> E_for (v, map_exp_annot f e1, map_exp_annot f e2, map_exp_annot f e3, o, map_exp_annot f e4) + | E_list xs -> E_list (List.map (map_exp_annot f) xs) + | E_case (exp, cases) -> E_case (map_exp_annot f exp, List.map (map_pexp_annot f) cases) + | E_let (letbind, exp) -> E_let (map_letbind_annot f letbind, map_exp_annot f exp) + | E_assign (lexp, exp) -> E_assign (map_lexp_annot f lexp, map_exp_annot f exp) + | E_sizeof nexp -> E_sizeof nexp + | E_exit exp -> E_exit (map_exp_annot f exp) + | E_return exp -> E_return (map_exp_annot f exp) + | _ -> typ_error Parse_ast.Unknown "Unimplemented: Cannot map annot in exp" +and map_pexp_annot f (Pat_aux (Pat_exp (pat, exp), annot)) = Pat_aux (Pat_exp (map_pat_annot f pat, map_exp_annot f exp), f annot) +and map_pat_annot f (P_aux (pat, annot)) = P_aux (map_pat_annot_aux f pat, f annot) +and map_pat_annot_aux f = function + | P_lit lit -> P_lit lit + | P_wild -> P_wild + | P_as (pat, id) -> P_as (map_pat_annot f pat, id) + | P_typ (typ, pat) -> P_typ (typ, map_pat_annot f pat) + | P_id id -> P_id id + | P_app (id, pats) -> P_app (id, List.map (map_pat_annot f) pats) + | P_tup pats -> P_tup (List.map (map_pat_annot f) pats) + | P_list pats -> P_list (List.map (map_pat_annot f) pats) + | _ -> typ_error Parse_ast.Unknown "Unimplemented: Cannot map annot in pat" +and map_letbind_annot f (LB_aux (lb, annot)) = LB_aux (map_letbind_annot_aux f lb, f annot) +and map_letbind_annot_aux f = function + | LB_val_explicit (typschm, pat, exp) -> LB_val_explicit (typschm, map_pat_annot f pat, map_exp_annot f exp) + | LB_val_implicit (pat, exp) -> LB_val_implicit (map_pat_annot f pat, map_exp_annot f exp) +and map_lexp_annot f (LEXP_aux (lexp, annot)) = LEXP_aux (map_lexp_annot_aux f lexp, f annot) +and map_lexp_annot_aux f = function + | LEXP_id id -> LEXP_id id + | LEXP_memory (id, exps) -> LEXP_memory (id, List.map (map_exp_annot f) exps) + | LEXP_cast (typ, id) -> LEXP_cast (typ, id) + | LEXP_tup lexps -> LEXP_tup (List.map (map_lexp_annot f) lexps) + | _ -> typ_error Parse_ast.Unknown "Unimplemented: Cannot map annot in lexp" + let string_of_id = function | Id_aux (Id v, _) -> v | Id_aux (DeIid v, _) -> v @@ -235,11 +279,13 @@ module Kid = struct end let unaux_nexp (Nexp_aux (nexp, _)) = nexp - let unaux_order (Ord_aux (ord, _)) = ord - let unaux_typ (Typ_aux (typ, _)) = typ +let mk_typ typ = Typ_aux (typ, Parse_ast.Unknown) +let mk_typ_arg arg = Typ_arg_aux (arg, Parse_ast.Unknown) +let mk_id str = Id_aux (Id str, Parse_ast.Unknown) + let rec nexp_subst sv subst (Nexp_aux (nexp, l)) = Nexp_aux (nexp_subst_aux sv subst nexp, l) and nexp_subst_aux sv subst = function | Nexp_id v -> Nexp_id v @@ -260,7 +306,7 @@ and nc_subst_nexp_aux l sv subst = function if compare kid sv = 0 then typ_error l ("Cannot substitute " ^ string_of_kid sv ^ " into set constraint " ^ string_of_n_constraint (NC_aux (set_nc, l))) else set_nc - + let rec typ_subst_nexp sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_nexp_aux sv subst typ, l) and typ_subst_nexp_aux sv subst = function | Typ_wild -> Typ_wild @@ -283,9 +329,9 @@ and typ_subst_typ_aux sv subst = function | Typ_var kid -> if Kid.compare kid sv = 0 then subst else Typ_var kid | Typ_fn (typ1, typ2, effs) -> Typ_fn (typ_subst_typ sv subst typ1, typ_subst_typ sv subst typ2, effs) | Typ_tup typs -> Typ_tup (List.map (typ_subst_typ sv subst) typs) - | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_nexp sv subst) args) -and typ_subst_arg_nexp sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_nexp_aux sv subst arg, l) -and typ_subst_arg_nexp_aux sv subst = function + | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_typ sv subst) args) +and typ_subst_arg_typ sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_typ_aux sv subst arg, l) +and typ_subst_arg_typ_aux sv subst = function | Typ_arg_nexp nexp -> Typ_arg_nexp nexp | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_typ sv subst typ) | Typ_arg_order ord -> Typ_arg_order ord @@ -305,9 +351,9 @@ and typ_subst_order_aux sv subst = function | Typ_var kid -> Typ_var kid | Typ_fn (typ1, typ2, effs) -> Typ_fn (typ_subst_order sv subst typ1, typ_subst_order sv subst typ2, effs) | Typ_tup typs -> Typ_tup (List.map (typ_subst_order sv subst) typs) - | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_nexp sv subst) args) -and typ_subst_arg_nexp sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_nexp_aux sv subst arg, l) -and typ_subst_arg_nexp_aux sv subst = function + | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_order sv subst) args) +and typ_subst_arg_order sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_order_aux sv subst arg, l) +and typ_subst_arg_order_aux sv subst = function | Typ_arg_nexp nexp -> Typ_arg_nexp nexp | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_order sv subst typ) | Typ_arg_order ord -> Typ_arg_order (order_subst sv subst ord) @@ -338,7 +384,7 @@ let quant_item_subst_kid_aux sv subst = function let rec pow2 = function | 0 -> 1 | n -> 2 * pow2 (n - 1) - + let rec nexp_simp (Nexp_aux (nexp, l)) = Nexp_aux (nexp_simp_aux nexp, l) and nexp_simp_aux = function | Nexp_sum (n1, n2) -> @@ -373,7 +419,7 @@ and nexp_simp_aux = function | _ -> Nexp_exp n end | nexp -> nexp - + let quant_item_subst_kid sv subst (QI_aux (quant, l)) = QI_aux (quant_item_subst_kid_aux sv subst quant, l) let typquant_subst_kid_aux sv subst = function @@ -424,6 +470,9 @@ module Env : sig val get_default_order : t -> order val set_default_order_inc : t -> t val set_default_order_dec : t -> t + val get_casts : t -> id list + val allow_casts : t -> bool + val no_casts : t -> t val lookup_id : id -> t -> lvar val fresh_kid : t -> kid val expand_synonyms : t -> typ -> typ @@ -437,6 +486,8 @@ end = struct typ_vars : base_kind_aux KBindings.t; typ_synonyms : (typ_arg list -> typ) Bindings.t; overloads : (id list) Bindings.t; + casts : id list; + allow_casts : bool; constraints : n_constraint list; default_order : order option; ret_typ : typ option @@ -450,6 +501,8 @@ end = struct typ_vars = KBindings.empty; typ_synonyms = Bindings.empty; overloads = Bindings.empty; + casts = [mk_id "cast_vec_to_range"; mk_id "cast_01_to_vec"; mk_id "reg_deref"]; + allow_casts = true; constraints = []; default_order = None; ret_typ = None; @@ -518,7 +571,9 @@ end = struct let add_overloads id ids env = typ_print ("Adding overloads for " ^ string_of_id id ^ " [" ^ string_of_list ", " string_of_id ids ^ "]"); { env with overloads = Bindings.add id ids env.overloads } - + + let get_casts env = env.casts + let check_index_range cmp f t (BF_aux (ir, l)) = match ir with | BF_single n -> @@ -625,6 +680,10 @@ end = struct let add_ret_typ typ env = { env with ret_typ = Some typ } + let allow_casts env = env.allow_casts + + let no_casts env = { env with allow_casts = false } + let add_typ_synonym id synonym env = if Bindings.mem id env.typ_synonyms then typ_error (id_loc id) ("Type synonym " ^ string_of_id id ^ " already exists") @@ -692,10 +751,6 @@ let add_typquant (quant : typquant) (env : Env.t) : Env.t = | TypQ_aux (TypQ_no_forall, _) -> env | TypQ_aux (TypQ_tq quants, _) -> List.fold_left add_quant_item env quants -let mk_typ typ = Typ_aux (typ, Parse_ast.Unknown) -let mk_typ_arg arg = Typ_arg_aux (arg, Parse_ast.Unknown) -let mk_id str = Id_aux (Id str, Parse_ast.Unknown) - let nconstant c = Nexp_aux (Nexp_constant c, Parse_ast.Unknown) let nminus n1 n2 = Nexp_aux (Nexp_minus (n1, n2), Parse_ast.Unknown) let nsum n1 n2 = Nexp_aux (Nexp_sum (n1, n2), Parse_ast.Unknown) @@ -983,6 +1038,26 @@ let unify_order l (Ord_aux (ord_aux1, _) as ord1) (Ord_aux (ord_aux2, _) as ord2 | Ord_dec, Ord_dec -> KBindings.empty | _, _ -> unify_error l (string_of_order ord1 ^ " cannot be unified with " ^ string_of_order ord2) +let subst_unifiers unifiers typ = + let subst_unifier typ (kid, uvar) = + match uvar with + | U_nexp nexp -> typ_subst_nexp kid (unaux_nexp nexp) typ + | U_order ord -> typ_subst_order kid (unaux_order ord) typ + | U_typ subst -> typ_subst_typ kid (unaux_typ subst) typ + | _ -> typ_error Parse_ast.Unknown "Cannot subst unifier" + in + List.fold_left subst_unifier typ (KBindings.bindings unifiers) + +let subst_args_unifiers unifiers typ_args = + let subst_unifier typ_args (kid, uvar) = + match uvar with + | U_nexp nexp -> List.map (typ_subst_arg_nexp kid (unaux_nexp nexp)) typ_args + | U_order ord -> List.map (typ_subst_arg_order kid (unaux_order ord)) typ_args + | U_typ subst -> List.map (typ_subst_arg_typ kid (unaux_typ subst)) typ_args + | _ -> typ_error Parse_ast.Unknown "Cannot subst unifier" + in + List.fold_left subst_unifier typ_args (KBindings.bindings unifiers) + let unify l env typ1 typ2 = let merge_unifiers l kid uvar1 uvar2 = match uvar1, uvar2 with @@ -1009,16 +1084,29 @@ let unify l env typ1 typ2 = | Invalid_argument _ -> unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2 ^ " tuple type is of different length") end - | Typ_app (f1, args1), Typ_app (f2, args2) -> + | Typ_app (f1, args1), Typ_app (f2, args2) when Id.compare f1 f2 = 0 -> + unify_typ_arg_list 0 KBindings.empty [] [] args1 args2 + | _, _ -> unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2) + + and unify_typ_arg_list unified acc uargs1 uargs2 args1 args2 = + match args1, args2 with + | [], [] when unified = 0 && List.length uargs1 > 0 -> unify_error l ("Could not unify arg lists") (*FIXME improve error *) + | [], [] when unified > 0 && List.length uargs1 > 0 -> unify_typ_arg_list 0 acc [] [] uargs1 uargs2 + | [], [] when List.length uargs1 = 0 -> acc + | (a1 :: a1s), (a2 :: a2s) -> begin - if Id.compare f1 f2 = 0 - then - try List.fold_left (KBindings.merge (merge_unifiers l)) KBindings.empty (List.map2 (unify_typ_args l) args1 args2) with - | Invalid_argument _ -> unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2 - ^ " functions applied to different number of arguments") - else unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2) + try + let unifiers = unify_typ_args l a1 a2 in + let a1s = subst_args_unifiers unifiers a1s in + let a2s = subst_args_unifiers unifiers a2s in + let uargs1 = subst_args_unifiers unifiers uargs1 in + let uargs2 = subst_args_unifiers unifiers uargs2 in + unify_typ_arg_list (unified + 1) (KBindings.merge (merge_unifiers l) unifiers acc) uargs1 uargs2 a1s a2s + with + | Unification_error _ -> unify_typ_arg_list unified acc (a1 :: uargs1) (a2 :: uargs2) a1s a2s end - | _, _ -> unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2) + | _, _ -> unify_error l "Cannot unify type lists of different length" + and unify_typ_args l (Typ_arg_aux (typ_arg_aux1, _) as typ_arg1) (Typ_arg_aux (typ_arg_aux2, _) as typ_arg2) = match typ_arg_aux1, typ_arg_aux2 with | Typ_arg_nexp n1, Typ_arg_nexp n2 -> @@ -1042,7 +1130,7 @@ unifying [|'n - 'l + 1:'n|] against [|0:31|] for example we can only unify the first argument if we do the second first *) - + let infer_lit env (L_aux (lit_aux, l) as lit) = match lit_aux with | L_unit -> mk_typ (Typ_id (mk_id "unit")) @@ -1213,16 +1301,6 @@ let rec instantiate_quants quants kid uvar = match quants with | _ -> (QI_aux (QI_const nc, l)) :: instantiate_quants quants kid uvar end -let subst_unifiers unifiers typ = - let subst_unifier typ (kid, uvar) = - match uvar with - | U_nexp nexp -> typ_subst_nexp kid (unaux_nexp nexp) typ - | U_order ord -> typ_subst_order kid (unaux_order ord) typ - | U_typ subst -> typ_subst_typ kid (unaux_typ subst) typ - | _ -> typ_error Parse_ast.Unknown "Cannot subst unifier" - in - List.fold_left subst_unifier typ (KBindings.bindings unifiers) - let destructure_vec_typ l typ = match typ with | Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp n1, _); @@ -1246,7 +1324,7 @@ let crule r env exp typ = decr depth; checked_exp with | Type_error (l, m) -> decr depth; typ_error l m - + let irule r env exp = incr depth; try @@ -1257,7 +1335,7 @@ let irule r env exp = with | Type_error (l, m) -> decr depth; typ_error l m -let rec check_exp env (E_aux (exp_aux, (l, annot)) as exp : 'a exp) (Typ_aux (typ_aux, _) as typ) : tannot exp = +let rec check_exp env (E_aux (exp_aux, (l, annot)) as exp) (Typ_aux (typ_aux, _) as typ) = let annot_exp exp typ = E_aux (exp, (l, Some (env, typ))) in match (exp_aux, typ_aux) with | E_block exps, _ -> @@ -1290,17 +1368,17 @@ let rec check_exp env (E_aux (exp_aux, (l, annot)) as exp : 'a exp) (Typ_aux (ty let tpat, env = bind_pat env pat (typ_of inferred_bind) in annot_exp (E_let (LB_aux (LB_val_implicit (tpat, inferred_bind), (let_loc, None)), crule check_exp env exp typ)) typ end - | E_app_infix (x, op, y), _ when List.length (Env.get_overloads op env) > 0 -> - let rec try_overload ops = - match ops with - | [] -> typ_error l ("No valid overloading for " ^ string_of_exp exp) - | (op :: ops) -> begin - typ_print ("Overload: " ^ string_of_id op); - try crule check_exp env (E_aux (E_app (op, [x; y]), (l, annot))) typ with - | Type_error _ -> try_overload ops + | E_app_infix (x, op, y), _ when List.length (Env.get_overloads op env) > 0 -> check_exp env (E_aux (E_app (op, [x; y]), (l, annot))) typ + | E_app (f, xs), _ when List.length (Env.get_overloads f env) > 0 -> + let rec try_overload m1 = function + | [] -> typ_error l (m1 ^ "\nNo valid overloading for " ^ string_of_exp exp) + | (f :: fs) -> begin + typ_print ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"); + try crule check_exp env (E_aux (E_app (f, xs), (l, annot))) typ with + | Type_error (_, m2) -> try_overload (m1 ^ "\nand " ^ m2) fs end in - try_overload (Env.get_overloads op env) + try_overload "Overloading error" (Env.get_overloads f env) | E_app (f, xs), _ -> let inferred_exp = infer_funapp l env f xs (Some typ) in (subtyp l env (typ_of inferred_exp) typ; inferred_exp) @@ -1329,15 +1407,30 @@ let rec check_exp env (E_aux (exp_aux, (l, annot)) as exp : 'a exp) (Typ_aux (ty let rtyp = Env.get_register reg env in subtyp l env rtyp typ; annot_exp (E_id reg) typ (* CHECK: is this subtyp the correct way around? *) | _, _ -> - let inferred_exp = irule infer_exp env exp - in (subtyp l env (typ_of inferred_exp) typ; inferred_exp) + let inferred_exp = irule infer_exp env exp in + let strip_annots exp = map_exp_annot (fun (l, _) -> (l, annot)) exp in + let rec try_casts m = function + | [] -> typ_error l ("No valid casts:\n" ^ m) + | (cast :: casts) -> begin + typ_print ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp inferred_exp ^ " to " ^ string_of_typ typ); + try crule check_exp (Env.no_casts env) (strip_annots (annot_exp (E_app (cast, [inferred_exp])) typ)) typ with + | Type_error (l, m) -> try_casts m casts + end + in + begin + try + subtyp l env (typ_of inferred_exp) typ; inferred_exp + with + | Type_error (_, m) when Env.allow_casts env -> try_casts "" (Env.get_casts env) + | Type_error (l, m) -> typ_error l ("Subtype error " ^ m) + end and bind_assignment env lexp (E_aux (_, (l, _)) as exp) = let inferred_exp = irule infer_exp env exp in let tlexp, env' = bind_lexp env lexp (typ_of inferred_exp) in E_aux (E_assign (tlexp, inferred_exp), (l, Some (env, mk_typ (Typ_id (mk_id "unit"))))), env' -and infer_exp env (E_aux (exp_aux, (l, annot)) as exp : 'a exp) : tannot exp = +and infer_exp env (E_aux (exp_aux, (l, annot)) as exp : 'a exp) = let annot_exp exp typ = E_aux (exp, (l, Some (env, typ))) in match exp_aux with | E_id v -> @@ -1363,21 +1456,21 @@ and infer_exp env (E_aux (exp_aux, (l, annot)) as exp : 'a exp) : tannot exp = | E_cast (typ, exp) -> let checked_exp = crule check_exp env exp typ in annot_exp (E_cast (typ, checked_exp)) typ - | E_app_infix (x, op, y) when List.length (Env.get_overloads op env) > 0 -> - let rec try_overload ops = - match ops with - | [] -> typ_error l ("No valid overloading for " ^ string_of_exp exp) - | (op :: ops) -> begin - typ_print ("Overload: " ^ string_of_id op); - try irule infer_exp env (E_aux (E_app (op, [x; y]), (l, annot))) with - | Type_error _ -> try_overload ops + | E_app_infix (x, op, y) when List.length (Env.get_overloads op env) > 0 -> infer_exp env (E_aux (E_app (op, [x; y]), (l, annot))) + | E_app (f, xs) when List.length (Env.get_overloads f env) > 0 -> + let rec try_overload m1 = function + | [] -> typ_error l (m1 ^ "\nNo valid overloading for " ^ string_of_exp exp) + | (f :: fs) -> begin + typ_print ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"); + try irule infer_exp env (E_aux (E_app (f, xs), (l, annot))) with + | Type_error (_, m2) -> try_overload (m1 ^ "\nand " ^ m2) fs end in - try_overload (Env.get_overloads op env) + try_overload "Overloading error" (Env.get_overloads f env) | E_app (f, xs) -> infer_funapp l env f xs None - | E_vector_access (v, n) -> infer_funapp l env (mk_id "vector_access") [v; n] None - | E_vector_append (v1, v2) -> infer_funapp l env (mk_id "vector_append") [v1; v2] None - | E_vector_subrange (v, n, m) -> infer_funapp l env (mk_id "vector_subrange") [v; n; m] None + | E_vector_access (v, n) -> infer_exp env (E_aux (E_app (mk_id "vector_access", [v; n]), (l, annot))) + | E_vector_append (v1, v2) -> infer_exp env (E_aux (E_app (mk_id "vector_append", [v1; v2]), (l, annot))) + | E_vector_subrange (v, n, m) -> infer_exp env (E_aux (E_app (mk_id "vector_subrange", [v; n; m]), (l, annot))) | E_vector [] -> typ_error l "Cannot infer type of empty vector" | E_vector ((item :: items) as vec) -> let inferred_item = irule infer_exp env item in @@ -1401,6 +1494,10 @@ and infer_exp env (E_aux (exp_aux, (l, annot)) as exp : 'a exp) : tannot exp = and infer_funapp l env f xs ret_ctx_typ = let annot_exp exp typ = E_aux (exp, (l, Some (env, typ))) in + let rec number n = function + | [] -> [] + | (x :: xs) -> (n, x) :: number (n + 1) xs + in let solve_quant = function | QI_aux (QI_id _, _) -> false | QI_aux (QI_const nc, _) -> prove env nc @@ -1409,13 +1506,13 @@ and infer_funapp l env f xs ret_ctx_typ = match typs, args with | (utyps, []), (uargs, []) -> begin - let iuargs = List.map2 (fun utyp uarg -> crule check_exp env uarg utyp) utyps uargs in + let iuargs = List.map2 (fun utyp (n, uarg) -> (n, crule check_exp env uarg utyp)) utyps uargs in if List.for_all solve_quant quants then (iuargs, ret_typ) else typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants ^ " not resolved during function application of " ^ string_of_id f) end - | (utyps, (typ :: typs)), (uargs, (arg :: args)) -> + | (utyps, (typ :: typs)), (uargs, ((n, arg) :: args)) -> begin typ_debug ("INSTANTIATE: " ^ string_of_exp arg ^ " with " ^ string_of_typ typ ^ " NF " ^ string_of_tnf (normalize_typ env typ)); let iarg = irule infer_exp env arg in @@ -1428,12 +1525,11 @@ and infer_funapp l env f xs ret_ctx_typ = let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in let ret_typ' = subst_unifiers unifiers ret_typ in let (iargs, ret_typ'') = instantiate quants' (utyps', typs') ret_typ' (uargs, args) in - (iarg :: iargs, ret_typ'') + ((n, iarg) :: iargs, ret_typ'') with | Unification_error (l, str) -> typ_debug ("Unification error: " ^ str); - let (iargs, ret_typ') = instantiate quants (typ :: utyps, typs) ret_typ (arg :: uargs, args) in - (iarg :: iargs, ret_typ') + instantiate quants (typ :: utyps, typs) ret_typ ((n, arg) :: uargs, args) end | (_, []), _ -> typ_error l ("Function " ^ string_of_id f ^ " applied to too many arguments") | _, (_, []) -> typ_error l ("Function " ^ string_of_id f ^ " not applied to enough arguments") @@ -1460,12 +1556,14 @@ and infer_funapp l env f xs ret_ctx_typ = match f_typ with | Typ_aux (Typ_fn (Typ_aux (Typ_tup typ_args, _), typ_ret, effs), _) -> let (quants, typ_args, typ_ret) = instantiate_ret (quant_items typq) typ_args typ_ret in - let (xs, typ_ret) = instantiate quants ([], typ_args) typ_ret ([], xs) in - annot_exp (E_app (f, xs)) typ_ret + let (xs_instantiated, typ_ret) = instantiate quants ([], typ_args) typ_ret ([], number 0 xs) in + let xs_reordered = List.map snd (List.sort (fun (n, _) (m, _) -> compare n m) xs_instantiated) in + annot_exp (E_app (f, xs_reordered)) typ_ret | Typ_aux (Typ_fn (typ_arg, typ_ret, effs), _) -> let (quants, typ_args, typ_ret) = instantiate_ret (quant_items typq) [typ_arg] typ_ret in - let (xs, typ_ret) = instantiate quants ([], typ_args) typ_ret ([], xs) in - annot_exp (E_app (f, xs)) typ_ret + let (xs_instantiated, typ_ret) = instantiate quants ([], typ_args) typ_ret ([], number 0 xs) in + let xs_reordered = List.map snd (List.sort (fun (n, _) (m, _) -> compare n m) xs_instantiated) in + annot_exp (E_app (f, xs_reordered)) typ_ret | _ -> typ_error l (string_of_id f ^ " is not a function") let check_letdef env (LB_aux (letbind, (l, _))) = @@ -1586,3 +1684,4 @@ let initial_env = |> Env.add_overloads (mk_id "^^") [mk_id "duplicate"; mk_id "duplicate_bits"] |> Env.add_overloads (mk_id "!=") [mk_id "neq_vec"] |> Env.add_overloads (mk_id "==") [mk_id "vec_eq_01_left"; mk_id "vec_eq_01_right"] + |> Env.add_overloads (mk_id "vector_access") [mk_id "vector_access_inc"; mk_id "vector_access_dec"] diff --git a/test/typecheck/pass/vector_access.sail b/test/typecheck/pass/vector_access.sail index 9346f5fe..c7c1b502 100644 --- a/test/typecheck/pass/vector_access.sail +++ b/test/typecheck/pass/vector_access.sail @@ -1,5 +1,6 @@ -val forall Nat 'n, Nat 'l, Order 'o, Type 'a, 'l >= 0. (vector<'n,'l,'o,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access +val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,dec,'a>, [|'n - 'l + 1:'n|]) -> 'a effect pure vector_access_dec +val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,inc,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access_inc default Order inc diff --git a/test/typecheck/pass/vector_append.sail b/test/typecheck/pass/vector_append.sail index d2f1ca47..af83c44d 100644 --- a/test/typecheck/pass/vector_append.sail +++ b/test/typecheck/pass/vector_append.sail @@ -1,5 +1,4 @@ -val forall Nat 'n, Nat 'l, Order 'o, Type 'a, 'l >= 0. (vector<'n,'l,'o,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access val forall Nat 'n1, Nat 'l1, Nat 'n2, Nat 'l2, Order 'o, Type 'a, 'l1 >= 0, 'l2 >= 0. (vector<'n1,'l1,'o,'a>, vector<'n2,'l2,'o,'a>) -> vector<'n1,'l1 + 'l2,'o,'a> effect pure vector_append @@ -10,8 +9,6 @@ val (bit[4], bit[4]) -> bit[8] effect pure test function bit[8] test (v1, v2) = { - z := vector_access(v1, 3); - z := v1[0]; zv := vector_append(v1, v2); zv := v1 : v2; zv |
