diff options
| author | Alasdair Armstrong | 2019-05-17 18:38:35 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2019-05-17 18:38:35 +0100 |
| commit | a1ef7946b96d95b3192f8db496f09d4bb23b775a (patch) | |
| tree | fffb42d83bebfae64ae1be1149e8c5e660753ed1 /src | |
| parent | f0b547154b3d2ce9e4bac74b0c56f20d6db76cd2 (diff) | |
Experiment with making vector and bitvector distinct types
Only change that should be needed for 99.9% of uses is to change
vector('n, 'ord, bit) to bitvector('n, 'ord), and adding
$ifndef FEATURE_BITVECTOR_TYPE
type bitvector('n, dec) = vector('n, dec, bit)
$endif
for to support any Sail before this
Currently I have all C, Typechecking, and SMT tests passing, as well
as the RISC-V spec building OCaml and C completely unmodified.
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 17 | ||||
| -rw-r--r-- | src/ast_util.mli | 1 | ||||
| -rw-r--r-- | src/bitfield.ml | 6 | ||||
| -rw-r--r-- | src/initial_check.ml | 3 | ||||
| -rw-r--r-- | src/jib/c_backend.ml | 9 | ||||
| -rw-r--r-- | src/jib/jib_compile.ml | 10 | ||||
| -rw-r--r-- | src/jib/jib_smt.ml | 6 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 2 | ||||
| -rw-r--r-- | src/process_file.ml | 1 | ||||
| -rw-r--r-- | src/rewrites.ml | 10 | ||||
| -rw-r--r-- | src/sail_lib.ml | 5 | ||||
| -rw-r--r-- | src/type_check.ml | 155 | ||||
| -rw-r--r-- | src/type_check.mli | 1 |
13 files changed, 161 insertions, 65 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 70845468..014d50d0 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -397,6 +397,11 @@ let vector_typ n ord typ = mk_typ_arg (A_order ord); mk_typ_arg (A_typ typ)])) +let bitvector_typ n ord = + mk_typ (Typ_app (mk_id "bitvector", + [mk_typ_arg (A_nexp (nexp_simp n)); + mk_typ_arg (A_order ord)])) + let exc_typ = mk_id_typ (mk_id "exception") let nconstant c = Nexp_aux (Nexp_constant c, Parse_ast.Unknown) @@ -1264,6 +1269,8 @@ let typ_app_args_of = function let rec vector_typ_args_of typ = match typ_app_args_of typ with | ("vector", [A_nexp len; A_order ord; A_typ etyp], l) -> (nexp_simp len, ord, etyp) + | ("bitvector", [A_nexp len; A_order ord], l) -> + (nexp_simp len, ord, bit_typ) | ("register", [A_typ rtyp], _) -> vector_typ_args_of rtyp | (_, _, l) -> raise (Reporting.err_typ l @@ -1286,11 +1293,11 @@ let is_bit_typ = function | Typ_aux (Typ_id (Id_aux (Id "bit", _)), _) -> true | _ -> false -let is_bitvector_typ typ = - if is_vector_typ typ then - let (_,_,etyp) = vector_typ_args_of typ in - is_bit_typ etyp - else false +let rec is_bitvector_typ = function + | Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [_;_]), _) -> true + | Typ_aux (Typ_app (Id_aux (Id "register",_), [A_aux (A_typ rtyp,_)]), _) -> + is_bitvector_typ rtyp + | _ -> false let has_effect (Effect_aux (eff,_)) searched_for = match eff with | Effect_set effs -> diff --git a/src/ast_util.mli b/src/ast_util.mli index c8f3cc5c..ee8fdf13 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -157,6 +157,7 @@ val unit_typ : typ val string_typ : typ val real_typ : typ val vector_typ : nexp -> order -> typ -> typ +val bitvector_typ : nexp -> order -> typ val list_typ : typ -> typ val exc_typ : typ val tuple_typ : typ list -> typ diff --git a/src/bitfield.ml b/src/bitfield.ml index feda4602..5b0a73b0 100644 --- a/src/bitfield.ml +++ b/src/bitfield.ml @@ -55,7 +55,7 @@ open Ast open Ast_util let bitvec size order = - Printf.sprintf "vector(%i, %s, bit)" size (string_of_order order) + Printf.sprintf "bitvector(%i, %s)" size (string_of_order order) let rec combine = function | [] -> Defs [] @@ -65,12 +65,12 @@ let rec combine = function let newtype name size order = let chunks_64 = - Util.list_init (size / 64) (fun i -> Printf.sprintf "%s_chunk_%i : vector(64, %s, bit)" name i (string_of_order order)) + Util.list_init (size / 64) (fun i -> Printf.sprintf "%s_chunk_%i : bitvector(64, %s)" name i (string_of_order order)) in let chunks = if size mod 64 = 0 then chunks_64 else let chunk_rem = - Printf.sprintf "%s_chunk_%i : vector(%i, %s, bit)" name (List.length chunks_64) (size mod 64) (string_of_order order) + Printf.sprintf "%s_chunk_%i : bitvector(%i, %s)" name (List.length chunks_64) (size mod 64) (string_of_order order) in chunk_rem :: List.rev chunks_64 in diff --git a/src/initial_check.ml b/src/initial_check.ml index 522faab7..1f15d054 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -859,6 +859,7 @@ let initial_ctx = { ("list", [K_type]); ("register", [K_type]); ("range", [K_int; K_int]); + ("bitvector", [K_int; K_order]); ("vector", [K_int; K_order; K_type]); ("atom", [K_int]); ("implicit", [K_int]); @@ -925,7 +926,7 @@ let undefined_builtin_val_specs = extern_of_string (mk_id "undefined_range") "forall 'n 'm. (atom('n), atom('m)) -> range('n,'m) effect {undef}"; extern_of_string (mk_id "undefined_vector") "forall 'n ('a:Type) ('ord : Order). (atom('n), 'a) -> vector('n, 'ord,'a) effect {undef}"; (* Only used with lem_mwords *) - extern_of_string (mk_id "undefined_bitvector") "forall 'n. atom('n) -> vector('n, dec, bit) effect {undef}"; + extern_of_string (mk_id "undefined_bitvector") "forall 'n. atom('n) -> bitvector('n, dec) effect {undef}"; extern_of_string (mk_id "undefined_unit") "unit -> unit effect {undef}"] let generate_undefineds vs_ids (Defs defs) = diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml index b98c53c4..2c9c11ee 100644 --- a/src/jib/c_backend.ml +++ b/src/jib/c_backend.ml @@ -148,9 +148,8 @@ let rec ctyp_of_typ ctx typ = - If the length is less than 64, then use a small bits type, sbits. - If the length may be larger than 64, use a large bits type lbits. *) | Typ_app (id, [A_aux (A_nexp n, _); - A_aux (A_order ord, _); - A_aux (A_typ (Typ_aux (Typ_id vtyp_id, _)), _)]) - when string_of_id id = "vector" && string_of_id vtyp_id = "bit" -> + A_aux (A_order ord, _)]) + when string_of_id id = "bitvector" -> let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in begin match nexp_simp n with | Nexp_aux (Nexp_constant n, _) when Big_int.less_equal n (Big_int.of_int 64) -> CT_fbits (Big_int.to_int n, direction) @@ -1429,8 +1428,8 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = end | "internal_vector_update", _ -> Printf.sprintf "internal_vector_update_%s" (sgen_ctyp_name ctyp) | "internal_vector_init", _ -> Printf.sprintf "internal_vector_init_%s" (sgen_ctyp_name ctyp) - | "undefined_vector", CT_fbits _ -> "UNDEFINED(fbits)" - | "undefined_vector", CT_lbits _ -> "UNDEFINED(lbits)" + | "undefined_bitvector", CT_fbits _ -> "UNDEFINED(fbits)" + | "undefined_bitvector", CT_lbits _ -> "UNDEFINED(lbits)" | "undefined_bit", _ -> "UNDEFINED(fbits)" | "undefined_vector", _ -> Printf.sprintf "UNDEFINED(vector_%s)" (sgen_ctyp_name ctyp) | fname, _ -> fname diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index f8cb3bcd..90c0022d 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -310,10 +310,10 @@ let rec compile_aval l ctx = function begin let bitstring = List.map value_of_aval_bit avals in let len = List.length avals in - match destruct_vector ctx.tc_env typ with - | Some (_, Ord_aux (Ord_inc, _), _) -> + match destruct_bitvector ctx.tc_env typ with + | Some (_, Ord_aux (Ord_inc, _)) -> [], V_lit (VL_bits (bitstring, false), CT_fbits (len, false)), [] - | Some (_, Ord_aux (Ord_dec, _), _) -> + | Some (_, Ord_aux (Ord_dec, _)) -> [], V_lit (VL_bits (bitstring, true), CT_fbits (len, true)), [] | Some _ -> raise (Reporting.err_general l "Encountered order polymorphic bitvector literal") @@ -337,8 +337,8 @@ let rec compile_aval l ctx = function [iclear (CT_lbits true) gs] (* If we have a bitvector value, that isn't a literal then we need to set bits individually. *) - | AV_vector (avals, Typ_aux (Typ_app (id, [_; A_aux (A_order ord, _); A_aux (A_typ (Typ_aux (Typ_id bit_id, _)), _)]), _)) - when string_of_id bit_id = "bit" && string_of_id id = "vector" && List.length avals <= 64 -> + | AV_vector (avals, Typ_aux (Typ_app (id, [_; A_aux (A_order ord, _)]), _)) + when string_of_id id = "bitvector" && List.length avals <= 64 -> let len = List.length avals in let direction = match ord with | Ord_aux (Ord_inc, _) -> false diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml index 897c685a..0d70695b 100644 --- a/src/jib/jib_smt.ml +++ b/src/jib/jib_smt.ml @@ -1147,10 +1147,8 @@ let rec ctyp_of_typ ctx typ = CT_list (ctyp_of_typ ctx typ) (* Note that we have to use lbits for zero-length bitvectors because they are not allowed by SMTLIB *) - | Typ_app (id, [A_aux (A_nexp n, _); - A_aux (A_order ord, _); - A_aux (A_typ (Typ_aux (Typ_id vtyp_id, _)), _)]) - when string_of_id id = "vector" && string_of_id vtyp_id = "bit" -> + | Typ_app (id, [A_aux (A_nexp n, _); A_aux (A_order ord, _)]) + when string_of_id id = "bitvector" -> let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in begin match nexp_simp n with | Nexp_aux (Nexp_constant n, _) when Big_int.equal n Big_int.zero -> CT_lbits direction diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 89830d38..d28a2b6e 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -69,7 +69,7 @@ type context = { top_env : Env.t } let empty_ctxt = { early_ret = false; bound_nexps = NexpSet.empty; top_env = Env.empty } - + let print_to_from_interp_value = ref false let langlebar = string "<|" let ranglebar = string "|>" diff --git a/src/process_file.ml b/src/process_file.ml index ae79d5c3..1672663b 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -94,6 +94,7 @@ let default_symbols = List.fold_left (fun set str -> StringSet.add str set) StringSet.empty [ "FEATURE_IMPLICITS"; "FEATURE_CONSTANT_TYPES"; + "FEATURE_BITVECTOR_TYPE"; ] let symbols = ref default_symbols diff --git a/src/rewrites.ml b/src/rewrites.ml index c10d931d..41319b45 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -570,7 +570,7 @@ let remove_vector_concat_pat pat = let eff = effect_of_annot (snd annot) in let (l,_) = annot in let wild _ = P_aux (P_wild,(gen_loc l, mk_tannot env bit_typ eff)) in - if is_vector_typ typ then + if is_vector_typ typ || is_bitvector_typ typ then match p, vector_typ_args_of typ with | P_vector ps,_ -> acc @ ps | _, (Nexp_aux (Nexp_constant length,_),_,_) -> @@ -1990,11 +1990,13 @@ let rewrite_undefined_if_gen always_bitvector env defs = then rewrite_undefined (always_bitvector || !Pretty_print_lem.opt_mwords) env defs else defs -let rec simple_typ (Typ_aux (typ_aux, l) as typ) = Typ_aux (simple_typ_aux typ_aux, l) -and simple_typ_aux = function +let rec simple_typ (Typ_aux (typ_aux, l) as typ) = Typ_aux (simple_typ_aux l typ_aux, l) +and simple_typ_aux l = function | Typ_id id -> Typ_id id | Typ_app (id, [_; _; A_aux (A_typ typ, l)]) when Id.compare id (mk_id "vector") = 0 -> Typ_app (mk_id "list", [A_aux (A_typ (simple_typ typ), l)]) + | Typ_app (id, [_; _]) when Id.compare id (mk_id "bitvector") = 0 -> + Typ_app (mk_id "list", [A_aux (A_typ bit_typ, gen_loc l)]) | Typ_app (id, [_]) when Id.compare id (mk_id "atom") = 0 -> Typ_id (mk_id "int") | Typ_app (id, [_; _]) when Id.compare id (mk_id "range") = 0 -> @@ -2004,7 +2006,7 @@ and simple_typ_aux = function | Typ_app (id, args) -> Typ_app (id, List.concat (List.map simple_typ_arg args)) | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map simple_typ arg_typs, simple_typ ret_typ, effs) | Typ_tup typs -> Typ_tup (List.map simple_typ typs) - | Typ_exist (_, _, Typ_aux (typ, l)) -> simple_typ_aux typ + | Typ_exist (_, _, Typ_aux (typ, l)) -> simple_typ_aux l typ | typ_aux -> typ_aux and simple_typ_arg (A_aux (typ_arg_aux, l)) = match typ_arg_aux with diff --git a/src/sail_lib.ml b/src/sail_lib.ml index 13ed491b..40f5cecf 100644 --- a/src/sail_lib.ml +++ b/src/sail_lib.ml @@ -106,6 +106,11 @@ let rec undefined_vector (len, item) = then [] else item :: undefined_vector (Big_int.sub len (Big_int.of_int 1), item) +let rec undefined_bitvector len = + if Big_int.equal len Big_int.zero + then [] + else B0 :: undefined_vector (Big_int.sub len (Big_int.of_int 1), B0) + let undefined_string () = "" let undefined_unit () = () diff --git a/src/type_check.ml b/src/type_check.ml index 2be68ade..fabcd7b4 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -543,6 +543,7 @@ end = struct ("atom", [K_int]); ("implicit", [K_int]); ("vector", [K_int; K_order; K_type]); + ("bitvector", [K_int; K_order]); ("register", [K_type]); ("bit", []); ("unit", []); @@ -1330,6 +1331,7 @@ let default_order_error_string = "No default Order (if you have set a default Order, move it earlier in the specification)" let dvector_typ env n typ = vector_typ n (Env.get_default_order env) typ +let bits_typ env n = bitvector_typ n (Env.get_default_order env) let add_existential l kopts nc env = let env = List.fold_left (fun env kopt -> Env.add_typ_var l kopt env) env kopts in @@ -1388,6 +1390,15 @@ let destruct_vector env typ = in destruct_vector' (Env.expand_synonyms env typ) +let destruct_bitvector env typ = + let destruct_bitvector' = function + | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); + A_aux (A_order o, _)] + ), _) when string_of_id id = "bitvector" -> Some (nexp_simp n1, o) + | typ -> None + in + destruct_bitvector' (Env.expand_synonyms env typ) + let rec is_typ_monomorphic (Typ_aux (typ, l)) = match typ with | Typ_id _ -> true @@ -2174,7 +2185,7 @@ let rec rewrite_sizeof' env (Nexp_aux (aux, l) as nexp) = | Typ_app (id, [A_aux (A_nexp n, _)]) when string_of_id id = "atom" -> prove __POS__ env (nc_eq (nvar v) n) - | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var v', _)), _); _; _]) when string_of_id id = "vector" -> + | Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var v', _)), _); _]) when string_of_id id = "bitvector" -> Kid.compare v v' = 0 | _ -> @@ -2355,14 +2366,14 @@ let infer_lit env (L_aux (lit_aux, l) as lit) = begin match Env.get_default_order env with | Ord_aux (Ord_inc, _) | Ord_aux (Ord_dec, _) -> - dvector_typ env (nint (String.length str)) (mk_typ (Typ_id (mk_id "bit"))) + bits_typ env (nint (String.length str)) | Ord_aux (Ord_var _, _) -> typ_error env l default_order_error_string end | L_hex str -> begin match Env.get_default_order env with | Ord_aux (Ord_inc, _) | Ord_aux (Ord_dec, _) -> - dvector_typ env (nint (String.length str * 4)) (mk_typ (Typ_id (mk_id "bit"))) + bits_typ env (nint (String.length str * 4)) | Ord_aux (Ord_var _, _) -> typ_error env l default_order_error_string end | L_undef -> typ_error env l "Cannot infer the type of undefined" @@ -2409,16 +2420,41 @@ let instantiate_simple_equations = inst_from_eq quants in inst_from_eq -let destruct_vec_typ l env typ = - let destruct_vec_typ' l = function +type destructed_vector = + | Destruct_vector of nexp * order * typ + | Destruct_bitvector of nexp * order + +let destruct_any_vector_typ l env typ = + let destruct_any_vector_typ' l = function + | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); + A_aux (A_order o, _)] + ), _) when string_of_id id = "bitvector" -> Destruct_bitvector (n1, o) | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); A_aux (A_order o, _); A_aux (A_typ vtyp, _)] - ), _) when string_of_id id = "vector" -> (n1, o, vtyp) + ), _) when string_of_id id = "vector" -> Destruct_vector (n1, o, vtyp) + | typ -> typ_error env l ("Expected vector or bitvector type, got " ^ string_of_typ typ) + in + destruct_any_vector_typ' l (Env.expand_synonyms env typ) + +let destruct_vector_typ l env typ = + let destruct_vector_typ' l = function + | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); + A_aux (A_order o, _); + A_aux (A_typ vtyp, _)] + ), _) when string_of_id id = "vector" -> n1, o, vtyp | typ -> typ_error env l ("Expected vector type, got " ^ string_of_typ typ) in - destruct_vec_typ' l (Env.expand_synonyms env typ) + destruct_vector_typ' l (Env.expand_synonyms env typ) +let destruct_bitvector_typ l env typ = + let destruct_bitvector_typ' l = function + | Typ_aux (Typ_app (id, [A_aux (A_nexp n1, _); + A_aux (A_order o, _)] + ), _) when string_of_id id = "bitvector" -> n1, o + | typ -> typ_error env l ("Expected bitvector type, got " ^ string_of_typ typ) + in + destruct_bitvector_typ' l (Env.expand_synonyms env typ) let env_of_annot (l, tannot) = match tannot with | Some t -> t.env @@ -2894,7 +2930,10 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ let checked_body = crule check_exp env body typ in annot_exp (E_internal_plet (tpat, bind_exp, checked_body)) typ | E_vector vec, _ -> - let (len, ord, vtyp) = destruct_vec_typ l env typ in + let len, ord, vtyp = match destruct_any_vector_typ l env typ with + | Destruct_vector (len, ord, vtyp) -> len, ord, vtyp + | Destruct_bitvector (len, ord) -> len, ord, bit_typ + in let checked_items = List.map (fun i -> crule check_exp env i vtyp) vec in if prove __POS__ env (nc_eq (nint (List.length vec)) (nexp_simp len)) then annot_exp (E_vector checked_items) typ else typ_error env l "List length didn't match" (* FIXME: improve error message *) @@ -3327,8 +3366,9 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = let pats, env, guards = List.fold_left fold_pats ([], env, []) (pat :: pats) in let len = nexp_simp (nint (List.length pats)) in let etyp = typ_of_pat (List.hd pats) in + (* BVS TODO: Non-bitvector P_vector *) List.iter (fun pat -> typ_equality l env etyp (typ_of_pat pat)) pats; - annot_pat (P_vector pats) (dvector_typ env len etyp), env, guards + annot_pat (P_vector pats) (bits_typ env len), env, guards | P_vector_concat (pat :: pats) -> let fold_pats (pats, env, guards) pat = let inferred_pat, env, guards' = infer_pat env pat in @@ -3336,14 +3376,23 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = in let inferred_pats, env, guards = List.fold_left fold_pats ([], env, []) (pat :: pats) in - let (len, _, vtyp) = destruct_vec_typ l env (typ_of_pat (List.hd inferred_pats)) in - let fold_len len pat = - let (len', _, vtyp') = destruct_vec_typ l env (typ_of_pat pat) in - typ_equality l env vtyp vtyp'; - nsum len len' - in - let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_pats)) in - annot_pat (P_vector_concat inferred_pats) (dvector_typ env len vtyp), env, guards + begin match destruct_any_vector_typ l env (typ_of_pat (List.hd inferred_pats)) with + | Destruct_vector (len, _, vtyp) -> + let fold_len len pat = + let (len', _, vtyp') = destruct_vector_typ l env (typ_of_pat pat) in + typ_equality l env vtyp vtyp'; + nsum len len' + in + let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_pats)) in + annot_pat (P_vector_concat inferred_pats) (dvector_typ env len vtyp), env, guards + | Destruct_bitvector (len, _) -> + let fold_len len pat = + let (len', _) = destruct_bitvector_typ l env (typ_of_pat pat) in + nsum len len' + in + let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_pats)) in + annot_pat (P_vector_concat inferred_pats) (bits_typ env len), env, guards + end | P_string_append pats -> let fold_pats (pats, env, guards) pat = let inferred_pat, env, guards' = infer_pat env pat in @@ -3545,8 +3594,7 @@ and infer_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) = let inferred_v_lexp = infer_lexp env v_lexp in let (Typ_aux (v_typ_aux, _) as v_typ) = Env.expand_synonyms env (lexp_typ_of inferred_v_lexp) in match v_typ_aux with - | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _); A_aux (A_typ elem_typ, _)]) - when Id.compare id (mk_id "vector") = 0 -> + | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _)]) when Id.compare id (mk_id "bitvector") = 0 -> let inferred_exp1 = infer_exp env exp1 in let inferred_exp2 = infer_exp env exp2 in let nexp1, env = bind_numeric l (typ_of inferred_exp1) env in @@ -3554,10 +3602,10 @@ and infer_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) = begin match ord with | Ord_aux (Ord_inc, _) when !opt_no_lexp_bounds_check || prove __POS__ env (nc_lteq nexp1 nexp2) -> let len = nexp_simp (nsum (nminus nexp2 nexp1) (nint 1)) in - annot_lexp (LEXP_vector_range (inferred_v_lexp, inferred_exp1, inferred_exp2)) (vector_typ len ord elem_typ) + annot_lexp (LEXP_vector_range (inferred_v_lexp, inferred_exp1, inferred_exp2)) (bitvector_typ len ord) | Ord_aux (Ord_dec, _) when !opt_no_lexp_bounds_check || prove __POS__ env (nc_gteq nexp1 nexp2) -> let len = nexp_simp (nsum (nminus nexp1 nexp2) (nint 1)) in - annot_lexp (LEXP_vector_range (inferred_v_lexp, inferred_exp1, inferred_exp2)) (vector_typ len ord elem_typ) + annot_lexp (LEXP_vector_range (inferred_v_lexp, inferred_exp1, inferred_exp2)) (bitvector_typ len ord) | _ -> typ_error env l ("Could not infer length of vector slice assignment " ^ string_of_lexp lexp) end | _ -> typ_error env l "Cannot assign slice of non vector type" @@ -3575,12 +3623,20 @@ and infer_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) = annot_lexp (LEXP_vector (inferred_v_lexp, inferred_exp)) elem_typ else typ_error env l ("Vector assignment not provably in bounds " ^ string_of_lexp lexp) + | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _)]) + when Id.compare id (mk_id "bitvector") = 0 -> + let inferred_exp = infer_exp env exp in + let nexp, env = bind_numeric l (typ_of inferred_exp) env in + if !opt_no_lexp_bounds_check || prove __POS__ env (nc_and (nc_lteq (nint 0) nexp) (nc_lteq nexp (nexp_simp (nminus len (nint 1))))) then + annot_lexp (LEXP_vector (inferred_v_lexp, inferred_exp)) bit_typ + else + typ_error env l ("Vector assignment not provably in bounds " ^ string_of_lexp lexp) | _ -> typ_error env l "Cannot assign vector element of non vector type" end | LEXP_vector_concat [] -> typ_error env l "Cannot have empty vector concatenation l-expression" | LEXP_vector_concat (v_lexp :: v_lexps) -> begin - let sum_lengths first_ord first_elem_typ acc (Typ_aux (v_typ_aux, _) as v_typ) = + let sum_vector_lengths first_ord first_elem_typ acc (Typ_aux (v_typ_aux, _) as v_typ) = match v_typ_aux with | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _); A_aux (A_typ elem_typ, _)]) when Id.compare id (mk_id "vector") = 0 && ord_identical ord first_ord -> @@ -3588,6 +3644,13 @@ and infer_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) = nsum acc len | _ -> typ_error env l "Vector concatentation l-expression must only contain vector types of the same order" in + let sum_bitvector_lengths first_ord acc (Typ_aux (v_typ_aux, _) as v_typ) = + match v_typ_aux with + | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _)]) + when Id.compare id (mk_id "bitvector") = 0 && ord_identical ord first_ord -> + nsum acc len + | _ -> typ_error env l "Bitvector concatentation l-expression must only contain bitvector types of the same order" + in let inferred_v_lexp = infer_lexp env v_lexp in let inferred_v_lexps = List.map (infer_lexp env) v_lexps in let (Typ_aux (v_typ_aux, _) as v_typ) = Env.expand_synonyms env (lexp_typ_of inferred_v_lexp) in @@ -3595,9 +3658,13 @@ and infer_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) = match v_typ_aux with | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _); A_aux (A_typ elem_typ, _)]) when Id.compare id (mk_id "vector") = 0 -> - let len = List.fold_left (sum_lengths ord elem_typ) len v_typs in + let len = List.fold_left (sum_vector_lengths ord elem_typ) len v_typs in annot_lexp (LEXP_vector_concat (inferred_v_lexp :: inferred_v_lexps)) (vector_typ (nexp_simp len) ord elem_typ) - | _ -> typ_error env l ("Vector concatentation l-expression must only contain vector types, found " ^ string_of_typ v_typ) + | Typ_app (id, [A_aux (A_nexp len, _); A_aux (A_order ord, _)]) + when Id.compare id (mk_id "bitvector") = 0 -> + let len = List.fold_left (sum_bitvector_lengths ord) len v_typs in + annot_lexp (LEXP_vector_concat (inferred_v_lexp :: inferred_v_lexps)) (bitvector_typ (nexp_simp len) ord) + | _ -> typ_error env l ("Vector concatentation l-expression must only contain bitvector or vector types, found " ^ string_of_typ v_typ) end | LEXP_field (LEXP_aux (LEXP_id v, _), fid) -> (* FIXME: will only work for ASL *) @@ -3795,8 +3862,14 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = | E_vector ((item :: items) as vec) -> let inferred_item = irule infer_exp env item in let checked_items = List.map (fun i -> crule check_exp env i (typ_of inferred_item)) items in - let vec_typ = dvector_typ env (nint (List.length vec)) (typ_of inferred_item) in - annot_exp (E_vector (inferred_item :: checked_items)) vec_typ + begin match typ_of inferred_item with + | Typ_aux (Typ_id id, _) when string_of_id id = "bit" -> + let bitvec_typ = bits_typ env (nint (List.length vec)) in + annot_exp (E_vector (inferred_item :: checked_items)) bitvec_typ + | _ -> + let vec_typ = dvector_typ env (nint (List.length vec)) (typ_of inferred_item) in + annot_exp (E_vector (inferred_item :: checked_items)) vec_typ + end | E_assert (test, msg) -> let msg = assert_msg msg in let checked_test = crule check_exp env test bool_typ in @@ -4255,14 +4328,23 @@ and infer_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, ())) as mpat) if allow_unknown && List.exists (fun mpat -> is_unknown_type (typ_of_mpat mpat)) inferred_mpats then annot_mpat (MP_vector_concat inferred_mpats) unknown_typ, env, guards (* hack *) else - let (len, _, vtyp) = destruct_vec_typ l env (typ_of_mpat (List.hd inferred_mpats)) in - let fold_len len mpat = - let (len', _, vtyp') = destruct_vec_typ l env (typ_of_mpat mpat) in - typ_equality l env vtyp vtyp'; - nsum len len' - in - let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_mpats)) in - annot_mpat (MP_vector_concat inferred_mpats) (dvector_typ env len vtyp), env, guards + begin match destruct_any_vector_typ l env (typ_of_mpat (List.hd inferred_mpats)) with + | Destruct_vector (len, _, vtyp) -> + let fold_len len mpat = + let (len', _, vtyp') = destruct_vector_typ l env (typ_of_mpat mpat) in + typ_equality l env vtyp vtyp'; + nsum len len' + in + let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_mpats)) in + annot_mpat (MP_vector_concat inferred_mpats) (dvector_typ env len vtyp), env, guards + | Destruct_bitvector (len, _) -> + let fold_len len mpat = + let (len', _) = destruct_bitvector_typ l env (typ_of_mpat mpat) in + nsum len len' + in + let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_mpats)) in + annot_mpat (MP_vector_concat inferred_mpats) (bits_typ env len), env, guards + end | MP_string_append mpats -> let fold_pats (pats, env, guards) pat = let inferred_pat, env, guards' = infer_mpat allow_unknown other_env env pat in @@ -4988,9 +5070,8 @@ let rec check_typedef : 'a. Env.t -> 'a type_def -> (tannot def) list * Env.t = match typ with (* The type of a bitfield must be a constant-width bitvector *) | Typ_aux (Typ_app (v, [A_aux (A_nexp (Nexp_aux (Nexp_constant size, _)), _); - A_aux (A_order order, _); - A_aux (A_typ (Typ_aux (Typ_id b, _)), _)]), _) - when string_of_id v = "vector" && string_of_id b = "bit" -> + A_aux (A_order order, _)]), _) + when string_of_id v = "bitvector" -> let size = Big_int.to_int size in let eval_index_nexp env nexp = int_of_nexp_opt (nexp_simp (Env.expand_nexp_synonyms env nexp)) in diff --git a/src/type_check.mli b/src/type_check.mli index dcedcc90..711f2411 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -398,6 +398,7 @@ val destruct_range : Env.t -> typ -> (kid list * n_constraint * nexp * nexp) opt val destruct_numeric : ?name:string option -> typ -> (kid list * n_constraint * nexp) option val destruct_vector : Env.t -> typ -> (nexp * order * typ) option +val destruct_bitvector : Env.t -> typ -> (nexp * order) option (** Construct an existential type with a guaranteed fresh identifier. *) |
