summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/constant_propagation_mutrec.ml42
-rw-r--r--src/monomorphise.ml31
-rw-r--r--src/sail.ml3
-rw-r--r--src/state.ml108
-rw-r--r--test/mono/castreq.sail31
5 files changed, 152 insertions, 63 deletions
diff --git a/src/constant_propagation_mutrec.ml b/src/constant_propagation_mutrec.ml
index 6cc6d28c..03d8e154 100644
--- a/src/constant_propagation_mutrec.ml
+++ b/src/constant_propagation_mutrec.ml
@@ -97,7 +97,8 @@ let generate_fun_id id args =
that will be propagated in *)
let generate_val_spec env id args l annot =
match Env.get_val_spec_orig id env with
- | tq, Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) ->
+ | tq, (Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) as fn_typ) ->
+ (* Get instantiation of type variables at call site *)
let orig_ksubst (kid, typ_arg) =
match typ_arg with
| A_aux ((A_nexp _ | A_bool _), _) -> (orig_kid kid, typ_arg)
@@ -110,21 +111,38 @@ let generate_val_spec env id args l annot =
|> List.map orig_ksubst
|> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty
in
+ (* Apply instantiation to original function type. Also collect the
+ type variables in the new type together their kinds for the new
+ val spec. *)
+ let kopts_of_typ env typ =
+ tyvars_of_typ typ |> KidSet.elements
+ |> List.map (fun kid -> mk_kopt (Env.get_typ_var kid env) kid)
+ |> KOptSet.of_list
+ in
let ret_typ' = KBindings.fold typ_subst ksubsts ret_typ in
- let arg_typs' =
- List.map (KBindings.fold typ_subst ksubsts) arg_typs
- |> List.map2 (fun arg typ -> if is_const_exp arg then [] else [typ]) args
- |> List.concat
- |> function [] -> [unit_typ] | typs -> typs
+ let (arg_typs', kopts') =
+ List.fold_right2 (fun arg typ (arg_typs', kopts') ->
+ if is_const_exp arg then
+ (arg_typs', kopts')
+ else
+ let typ' = KBindings.fold typ_subst ksubsts typ in
+ let arg_kopts = kopts_of_typ (env_of arg) typ' in
+ (typ' :: arg_typs', KOptSet.union arg_kopts kopts'))
+ args arg_typs ([], kopts_of_typ (env_of_tannot annot) ret_typ')
in
+ let arg_typs' = if arg_typs' = [] then [unit_typ] else arg_typs' in
let typ' = mk_typ (Typ_fn (arg_typs', ret_typ', eff)) in
- let tyvars = tyvars_of_typ typ' in
- let tq' =
- quant_items tq |>
- List.filter (fun qi -> KidSet.subset (tyvars_of_quant_item qi) tyvars) |>
- mk_typquant
+ (* Construct new val spec *)
+ let constraints' =
+ quant_split tq |> snd
+ |> List.map (KBindings.fold constraint_subst ksubsts)
+ |> List.filter (fun nc -> KidSet.subset (tyvars_of_constraint nc) (tyvars_of_typ typ'))
+ in
+ let quant_items' =
+ List.map mk_qi_kopt (KOptSet.elements kopts') @
+ List.map mk_qi_nc constraints'
in
- let typschm = mk_typschm tq' typ' in
+ let typschm = mk_typschm (mk_typquant quant_items') typ' in
mk_val_spec (VS_val_spec (typschm, generate_fun_id id args, [], false)),
ksubsts
| _, Typ_aux (_, l) ->
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 0785c3cd..645f6f72 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -3263,15 +3263,15 @@ let ids_in_exp exp =
lEXP_cast = (fun (_,id) -> IdSet.singleton id)
} exp
-let make_bitvector_env_casts env quant_kids (kid,i) exp =
- let mk_cast var typ exp = (make_bitvector_cast_let "bitvector_cast_in" env env quant_kids typ (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ)) var exp in
+let make_bitvector_env_casts top_env env quant_kids insts exp =
+ let mk_cast var typ exp = (make_bitvector_cast_let "bitvector_cast_in" env top_env quant_kids typ (subst_kids_typ insts typ)) var exp in
let mk_assign_in var typ =
- make_bitvector_cast_assign "bitvector_cast_in" env env quant_kids typ
- (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ) var
+ make_bitvector_cast_assign "bitvector_cast_in" env top_env quant_kids typ
+ (subst_kids_typ insts typ) var
in
let mk_assign_out var typ =
- make_bitvector_cast_assign "bitvector_cast_out" env env quant_kids
- (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ) typ var
+ make_bitvector_cast_assign "bitvector_cast_out" top_env env quant_kids
+ (subst_kids_typ insts typ) typ var
in
let locals = Env.get_locals env in
let used_ids = ids_in_exp exp in
@@ -3405,13 +3405,13 @@ let add_bitvector_casts (Defs defs) =
(* We used to just substitute kid, but fill_in_type also catches other kids defined by it *)
let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in
make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ
- (make_bitvector_env_casts env quant_kids (kid,i) body)
+ (make_bitvector_env_casts env (env_of body) quant_kids (KBindings.singleton kid (nconstant i)) body)
| P_aux (P_id var,_), Some guard ->
(match extract_value_from_guard var guard with
| Some i ->
let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in
make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ
- (make_bitvector_env_casts env quant_kids (kid,i) body)
+ (make_bitvector_env_casts env (env_of body) quant_kids (KBindings.singleton kid (nconstant i)) body)
| None -> body)
| _ ->
body
@@ -3425,10 +3425,9 @@ let add_bitvector_casts (Defs defs) =
let env = env_of_annot ann in
let result_typ = Env.base_typ_of env (typ_of_annot ann) in
let insts = extract e1 in
- let e2' = List.fold_left (fun body inst ->
- make_bitvector_env_casts env quant_kids inst body) e2 insts in
let insts = List.fold_left (fun insts (kid,i) ->
KBindings.add kid (nconstant i) insts) KBindings.empty insts in
+ let e2' = make_bitvector_env_casts env (env_of e2) quant_kids insts e2 in
let src_typ = subst_kids_typ insts result_typ in
let e2' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ e2' in
(* Ask the type checker if only one value remains for any of kids in
@@ -3437,13 +3436,10 @@ let add_bitvector_casts (Defs defs) =
let insts3 = KBindings.fold (fun kid _ i3 ->
match Type_check.solve_unique env3 (nvar kid) with
| None -> i3
- | Some c -> (kid, c)::i3)
- insts []
+ | Some c -> KBindings.add kid (nconstant c) i3)
+ insts KBindings.empty
in
- let e3' = List.fold_left (fun body inst ->
- make_bitvector_env_casts env quant_kids inst body) e3 insts3 in
- let insts3 = List.fold_left (fun insts (kid,i) ->
- KBindings.add kid (nconstant i) insts) KBindings.empty insts3 in
+ let e3' = make_bitvector_env_casts env (env_of e3) quant_kids insts3 e3 in
let src_typ3 = subst_kids_typ insts3 result_typ in
let e3' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ3 result_typ e3' in
E_aux (E_if (e1,e2',e3'), ann)
@@ -3462,10 +3458,9 @@ let add_bitvector_casts (Defs defs) =
let t' = aux t in
let et = E_aux (E_block t',ann) in
let env = env_of h in
- let et = List.fold_left (fun body inst ->
- make_bitvector_env_casts env quant_kids inst body) et insts in
let insts = List.fold_left (fun insts (kid,i) ->
KBindings.add kid (nconstant i) insts) KBindings.empty insts in
+ let et = make_bitvector_env_casts env (env_of et) quant_kids insts et in
let src_typ = subst_kids_typ insts result_typ in
let et = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ et in
diff --git a/src/sail.ml b/src/sail.ml
index 4324d650..b355483d 100644
--- a/src/sail.ml
+++ b/src/sail.ml
@@ -299,6 +299,9 @@ let options = Arg.align ([
( "-undefined_gen",
Arg.Set Initial_check.opt_undefined_gen,
" generate undefined_type functions for types in the specification");
+ ( "-grouped_regstate",
+ Arg.Set State.opt_type_grouped_regstate,
+ " group registers with same type together in generated register state record");
( "-enum_casts",
Arg.Set Initial_check.opt_enum_casts,
" allow enumerations to be automatically casted to numeric range types");
diff --git a/src/state.ml b/src/state.ml
index 9d79fef0..5af39cd2 100644
--- a/src/state.ml
+++ b/src/state.ml
@@ -58,6 +58,8 @@ open PPrint
open Pretty_print_common
open Pretty_print_sail
+let opt_type_grouped_regstate = ref false
+
let defs_of_string = ast_of_def_string
let is_defined defs name = IdSet.mem (mk_id name) (ids_of_defs (Defs defs))
@@ -78,11 +80,48 @@ let find_registers defs =
| _ -> acc
) [] defs
-let generate_regstate = function
- | [] -> ["type regstate = unit"]
+let generate_register_id_enum = function
+ | [] -> ["type register_id = unit"]
| registers ->
- let reg (typ, id) = Printf.sprintf "%s : %s" (string_of_id id) (to_string (doc_typ typ)) in
- ["struct regstate = { " ^ (String.concat ", " (List.map reg registers)) ^ " }"]
+ let reg (typ, id) = string_of_id id in
+ ["type register_id = " ^ String.concat " | " (List.map reg registers)]
+
+let rec id_of_regtyp builtins mwords (Typ_aux (t, l) as typ) = match t with
+ | Typ_id id -> id
+ | Typ_app (id, args) ->
+ let name_arg (A_aux (targ, _)) = match targ with
+ | A_typ targ -> string_of_id (id_of_regtyp builtins mwords targ)
+ | A_nexp nexp when is_nexp_constant (nexp_simp nexp) ->
+ string_of_nexp (nexp_simp nexp)
+ | A_order (Ord_aux (Ord_inc, _)) -> "inc"
+ | A_order (Ord_aux (Ord_dec, _)) -> "dec"
+ | _ ->
+ raise (Reporting.err_typ l "Unsupported register type")
+ in
+ if IdSet.mem id builtins && not (mwords && is_bitvector_typ typ) then id else
+ append_id id (String.concat "_" ("" :: List.map name_arg args))
+ | _ -> raise (Reporting.err_typ l "Unsupported register type")
+
+let regstate_field typ = append_id (id_of_regtyp IdSet.empty false typ) "_reg"
+
+let generate_regstate registers =
+ let regstate_def =
+ if registers = [] then
+ TD_abbrev (mk_id "regstate", mk_typquant [], mk_typ_arg (A_typ unit_typ))
+ else
+ let fields =
+ if !opt_type_grouped_regstate then
+ List.map
+ (fun (typ, id) ->
+ (function_typ [string_typ] typ no_effect,
+ regstate_field typ))
+ registers
+ |> List.sort_uniq (fun (typ1, id1) (typ2, id2) -> Id.compare id1 id2)
+ else registers
+ in
+ TD_record (mk_id "regstate", mk_typquant [], fields, false)
+ in
+ Defs [DEF_type (TD_aux (regstate_def, (Unknown, ())))]
let generate_initial_regstate defs =
let registers = find_registers defs in
@@ -181,27 +220,15 @@ let generate_initial_regstate defs =
| _ -> inits) ([], Bindings.empty) defs
in
let init_reg (typ, id) = string_of_id id ^ " = " ^ lookup_init_val init_vals typ in
- init_defs @
- ["let initial_regstate : regstate = struct { " ^ (String.concat ", " (List.map init_reg registers)) ^ " }"]
+ List.map defs_of_string
+ (init_defs @
+ ["let initial_regstate : regstate = struct { " ^
+ (String.concat ", " (List.map init_reg registers)) ^
+ " }"])
with
| _ -> [] (* Do not generate an initial register state if anything goes wrong *)
-let rec regval_constr_id mwords (Typ_aux (t, l) as typ) = match t with
- | Typ_id id -> id
- | Typ_app (id, args) ->
- let name_arg (A_aux (targ, _)) = match targ with
- | A_typ targ -> string_of_id (regval_constr_id mwords targ)
- | A_nexp nexp when is_nexp_constant (nexp_simp nexp) ->
- string_of_nexp (nexp_simp nexp)
- | A_order (Ord_aux (Ord_inc, _)) -> "inc"
- | A_order (Ord_aux (Ord_dec, _)) -> "dec"
- | _ ->
- raise (Reporting.err_typ l "Unsupported register type")
- in
- let builtins = IdSet.of_list (List.map mk_id ["vector"; "bitvector"; "list"; "option"]) in
- if IdSet.mem id builtins && not (mwords && is_bitvector_typ typ) then id else
- append_id id (String.concat "_" ("" :: List.map name_arg args))
- | _ -> raise (Reporting.err_typ l "Unsupported register type")
+let regval_constr_id = id_of_regtyp (IdSet.of_list (List.map mk_id ["vector"; "bitvector"; "list"; "option"]))
let register_base_types mwords typs =
let rec add_base_typs typs (Typ_aux (t, _) as typ) =
@@ -211,8 +238,8 @@ let register_base_types mwords typs =
when IdSet.mem id builtins && not (mwords && is_bitvector_typ typ) ->
let add_typ_arg base_typs (A_aux (targ, _)) =
match targ with
- | A_typ typ -> add_base_typs typs typ
- | _ -> typs
+ | A_typ typ -> add_base_typs base_typs typ
+ | _ -> base_typs
in
List.fold_left add_typ_arg typs args
| _ -> Bindings.add (regval_constr_id mwords typ) typ typs
@@ -227,9 +254,10 @@ let generate_regval_typ typs =
"Regval_list : list(register_value), " ^
"Regval_option : option(register_value)"
in
- ["union register_value = { " ^
- (String.concat ", " (builtins :: List.map constr (Bindings.bindings typs))) ^
- " }"]
+ [defs_of_string
+ ("union register_value = { " ^
+ (String.concat ", " (builtins :: List.map constr (Bindings.bindings typs))) ^
+ " }")]
let add_regval_conv id typ (Defs defs) =
let id = string_of_id id in
@@ -307,12 +335,20 @@ let register_refs_lem mwords registers =
in
let register_ref (typ, id) =
let idd = string (string_of_id id) in
+ let (read_from, write_to) =
+ if !opt_type_grouped_regstate then
+ let field_idd = string (string_of_id (regstate_field typ)) in
+ (field_idd ^^ space ^^ dquotes idd,
+ doc_op equals field_idd (string "(fun reg -> if reg = \"" ^^ idd ^^ string "\" then v else s." ^^ field_idd ^^ string " reg)"))
+ else
+ (idd, doc_op equals idd (string "v"))
+ in
(* let field = if prefix_recordtype then string "regstate_" ^^ idd else idd in *)
let of_regval, regval_of = regval_convs_lem mwords typ in
concat [string "let "; idd; string "_ref = <|"; hardline;
string " name = \""; idd; string "\";"; hardline;
- string " read_from = (fun s -> s."; idd; string ");"; hardline;
- string " write_to = (fun v s -> (<| s with "; idd; string " = v |>));"; hardline;
+ string " read_from = (fun s -> s."; read_from; string ");"; hardline;
+ string " write_to = (fun v s -> (<| s with "; write_to; string " |>));"; hardline;
string " of_regval = "; string of_regval; string ";"; hardline;
string " regval_of = "; string regval_of; string " |>"; hardline]
in
@@ -504,14 +540,20 @@ let generate_regstate_defs mwords defs =
let regtyps = register_base_types mwords (List.map fst registers) in
let option_typ =
if is_defined defs "option" then [] else
- ["union option ('a : Type) = {None : unit, Some : 'a}"]
+ [defs_of_string "union option ('a : Type) = {None : unit, Some : 'a}"]
in
let regval_typ = if is_defined defs "register_value" then [] else generate_regval_typ regtyps in
- let regstate_typ = if is_defined defs "regstate" then [] else generate_regstate registers in
- let initregstate = if is_defined defs "initial_regstate" then [] else generate_initial_regstate defs in
+ let regstate_typ = if is_defined defs "regstate" then [] else [generate_regstate registers] in
+ let initregstate =
+ (* Don't create initial regstate if it is already defined or if we generated
+ a regstate record with registers grouped per type; the latter would
+ require record fields storing functions, which is not supported in
+ Sail. *)
+ if is_defined defs "initial_regstate" || !opt_type_grouped_regstate then [] else
+ generate_initial_regstate defs
+ in
let defs =
option_typ @ regval_typ @ regstate_typ @ initregstate
- |> List.map defs_of_string
|> concat_ast
|> Bindings.fold add_regval_conv regtyps
in
diff --git a/test/mono/castreq.sail b/test/mono/castreq.sail
index b1df7010..75791bfd 100644
--- a/test/mono/castreq.sail
+++ b/test/mono/castreq.sail
@@ -93,6 +93,33 @@ function assign3(x) = {
y
}
+/* Test that matching on a variable which happens to fix a bitvector variable's
+ size updates the environment properly. */
+
+val assign4 : forall 'm, 'm in {1,2}. (implicit('m),bits(8*'m)) -> bits(8*'m)
+
+function assign4(m,x) = {
+ y : bits(8*'m) = x;
+ match m {
+ 1 => y = y + 0x01,
+ 2 => y[7..0] = 0x89
+ };
+ y
+}
+
+/* The same as assign4, except with a distinct type variable. */
+
+val assign5 : forall 'm 'n, 'm in {1,2} & 'n == 8 * 'm. (implicit('m),bits('n)) -> bits('n)
+
+function assign5(m,x) = {
+ y : bits('n) = x;
+ match m {
+ 1 => y = y + 0x01,
+ 2 => y[7..0] = 0x89
+ };
+ y
+}
+
/* Adding casts for top-level pattern matches */
val foo2 : forall 'm 'n, 'm in {8,16} & 'n in {32,64}. (atom('n), bits('m)) -> bits('n) effect pure
@@ -140,6 +167,10 @@ function run () = {
assert(assign2(0x1234) == 0x00001234);
assert(assign3(0x12) == 0x13);
assert(assign3(0x1234) == 0x1289);
+ assert(assign4(0x12) == 0x13);
+ assert(assign4(0x1234) == 0x1289);
+ assert(assign5(0x12) == 0x13);
+ assert(assign5(0x1234) == 0x1289);
assert(foo2(32,0x12) == 0x00120012);
assert(foo2(64,0x12) == 0x0012001200120012);
assert(foo3(4,0x12) == 0x00120012);