diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/constant_propagation_mutrec.ml | 42 | ||||
| -rw-r--r-- | src/monomorphise.ml | 31 | ||||
| -rw-r--r-- | src/sail.ml | 3 | ||||
| -rw-r--r-- | src/state.ml | 108 |
4 files changed, 121 insertions, 63 deletions
diff --git a/src/constant_propagation_mutrec.ml b/src/constant_propagation_mutrec.ml index 6cc6d28c..03d8e154 100644 --- a/src/constant_propagation_mutrec.ml +++ b/src/constant_propagation_mutrec.ml @@ -97,7 +97,8 @@ let generate_fun_id id args = that will be propagated in *) let generate_val_spec env id args l annot = match Env.get_val_spec_orig id env with - | tq, Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) -> + | tq, (Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) as fn_typ) -> + (* Get instantiation of type variables at call site *) let orig_ksubst (kid, typ_arg) = match typ_arg with | A_aux ((A_nexp _ | A_bool _), _) -> (orig_kid kid, typ_arg) @@ -110,21 +111,38 @@ let generate_val_spec env id args l annot = |> List.map orig_ksubst |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty in + (* Apply instantiation to original function type. Also collect the + type variables in the new type together their kinds for the new + val spec. *) + let kopts_of_typ env typ = + tyvars_of_typ typ |> KidSet.elements + |> List.map (fun kid -> mk_kopt (Env.get_typ_var kid env) kid) + |> KOptSet.of_list + in let ret_typ' = KBindings.fold typ_subst ksubsts ret_typ in - let arg_typs' = - List.map (KBindings.fold typ_subst ksubsts) arg_typs - |> List.map2 (fun arg typ -> if is_const_exp arg then [] else [typ]) args - |> List.concat - |> function [] -> [unit_typ] | typs -> typs + let (arg_typs', kopts') = + List.fold_right2 (fun arg typ (arg_typs', kopts') -> + if is_const_exp arg then + (arg_typs', kopts') + else + let typ' = KBindings.fold typ_subst ksubsts typ in + let arg_kopts = kopts_of_typ (env_of arg) typ' in + (typ' :: arg_typs', KOptSet.union arg_kopts kopts')) + args arg_typs ([], kopts_of_typ (env_of_tannot annot) ret_typ') in + let arg_typs' = if arg_typs' = [] then [unit_typ] else arg_typs' in let typ' = mk_typ (Typ_fn (arg_typs', ret_typ', eff)) in - let tyvars = tyvars_of_typ typ' in - let tq' = - quant_items tq |> - List.filter (fun qi -> KidSet.subset (tyvars_of_quant_item qi) tyvars) |> - mk_typquant + (* Construct new val spec *) + let constraints' = + quant_split tq |> snd + |> List.map (KBindings.fold constraint_subst ksubsts) + |> List.filter (fun nc -> KidSet.subset (tyvars_of_constraint nc) (tyvars_of_typ typ')) + in + let quant_items' = + List.map mk_qi_kopt (KOptSet.elements kopts') @ + List.map mk_qi_nc constraints' in - let typschm = mk_typschm tq' typ' in + let typschm = mk_typschm (mk_typquant quant_items') typ' in mk_val_spec (VS_val_spec (typschm, generate_fun_id id args, [], false)), ksubsts | _, Typ_aux (_, l) -> diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 0785c3cd..645f6f72 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -3263,15 +3263,15 @@ let ids_in_exp exp = lEXP_cast = (fun (_,id) -> IdSet.singleton id) } exp -let make_bitvector_env_casts env quant_kids (kid,i) exp = - let mk_cast var typ exp = (make_bitvector_cast_let "bitvector_cast_in" env env quant_kids typ (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ)) var exp in +let make_bitvector_env_casts top_env env quant_kids insts exp = + let mk_cast var typ exp = (make_bitvector_cast_let "bitvector_cast_in" env top_env quant_kids typ (subst_kids_typ insts typ)) var exp in let mk_assign_in var typ = - make_bitvector_cast_assign "bitvector_cast_in" env env quant_kids typ - (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ) var + make_bitvector_cast_assign "bitvector_cast_in" env top_env quant_kids typ + (subst_kids_typ insts typ) var in let mk_assign_out var typ = - make_bitvector_cast_assign "bitvector_cast_out" env env quant_kids - (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ) typ var + make_bitvector_cast_assign "bitvector_cast_out" top_env env quant_kids + (subst_kids_typ insts typ) typ var in let locals = Env.get_locals env in let used_ids = ids_in_exp exp in @@ -3405,13 +3405,13 @@ let add_bitvector_casts (Defs defs) = (* We used to just substitute kid, but fill_in_type also catches other kids defined by it *) let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ - (make_bitvector_env_casts env quant_kids (kid,i) body) + (make_bitvector_env_casts env (env_of body) quant_kids (KBindings.singleton kid (nconstant i)) body) | P_aux (P_id var,_), Some guard -> (match extract_value_from_guard var guard with | Some i -> let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ - (make_bitvector_env_casts env quant_kids (kid,i) body) + (make_bitvector_env_casts env (env_of body) quant_kids (KBindings.singleton kid (nconstant i)) body) | None -> body) | _ -> body @@ -3425,10 +3425,9 @@ let add_bitvector_casts (Defs defs) = let env = env_of_annot ann in let result_typ = Env.base_typ_of env (typ_of_annot ann) in let insts = extract e1 in - let e2' = List.fold_left (fun body inst -> - make_bitvector_env_casts env quant_kids inst body) e2 insts in let insts = List.fold_left (fun insts (kid,i) -> KBindings.add kid (nconstant i) insts) KBindings.empty insts in + let e2' = make_bitvector_env_casts env (env_of e2) quant_kids insts e2 in let src_typ = subst_kids_typ insts result_typ in let e2' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ e2' in (* Ask the type checker if only one value remains for any of kids in @@ -3437,13 +3436,10 @@ let add_bitvector_casts (Defs defs) = let insts3 = KBindings.fold (fun kid _ i3 -> match Type_check.solve_unique env3 (nvar kid) with | None -> i3 - | Some c -> (kid, c)::i3) - insts [] + | Some c -> KBindings.add kid (nconstant c) i3) + insts KBindings.empty in - let e3' = List.fold_left (fun body inst -> - make_bitvector_env_casts env quant_kids inst body) e3 insts3 in - let insts3 = List.fold_left (fun insts (kid,i) -> - KBindings.add kid (nconstant i) insts) KBindings.empty insts3 in + let e3' = make_bitvector_env_casts env (env_of e3) quant_kids insts3 e3 in let src_typ3 = subst_kids_typ insts3 result_typ in let e3' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ3 result_typ e3' in E_aux (E_if (e1,e2',e3'), ann) @@ -3462,10 +3458,9 @@ let add_bitvector_casts (Defs defs) = let t' = aux t in let et = E_aux (E_block t',ann) in let env = env_of h in - let et = List.fold_left (fun body inst -> - make_bitvector_env_casts env quant_kids inst body) et insts in let insts = List.fold_left (fun insts (kid,i) -> KBindings.add kid (nconstant i) insts) KBindings.empty insts in + let et = make_bitvector_env_casts env (env_of et) quant_kids insts et in let src_typ = subst_kids_typ insts result_typ in let et = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ et in diff --git a/src/sail.ml b/src/sail.ml index 4324d650..b355483d 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -299,6 +299,9 @@ let options = Arg.align ([ ( "-undefined_gen", Arg.Set Initial_check.opt_undefined_gen, " generate undefined_type functions for types in the specification"); + ( "-grouped_regstate", + Arg.Set State.opt_type_grouped_regstate, + " group registers with same type together in generated register state record"); ( "-enum_casts", Arg.Set Initial_check.opt_enum_casts, " allow enumerations to be automatically casted to numeric range types"); diff --git a/src/state.ml b/src/state.ml index 9d79fef0..5af39cd2 100644 --- a/src/state.ml +++ b/src/state.ml @@ -58,6 +58,8 @@ open PPrint open Pretty_print_common open Pretty_print_sail +let opt_type_grouped_regstate = ref false + let defs_of_string = ast_of_def_string let is_defined defs name = IdSet.mem (mk_id name) (ids_of_defs (Defs defs)) @@ -78,11 +80,48 @@ let find_registers defs = | _ -> acc ) [] defs -let generate_regstate = function - | [] -> ["type regstate = unit"] +let generate_register_id_enum = function + | [] -> ["type register_id = unit"] | registers -> - let reg (typ, id) = Printf.sprintf "%s : %s" (string_of_id id) (to_string (doc_typ typ)) in - ["struct regstate = { " ^ (String.concat ", " (List.map reg registers)) ^ " }"] + let reg (typ, id) = string_of_id id in + ["type register_id = " ^ String.concat " | " (List.map reg registers)] + +let rec id_of_regtyp builtins mwords (Typ_aux (t, l) as typ) = match t with + | Typ_id id -> id + | Typ_app (id, args) -> + let name_arg (A_aux (targ, _)) = match targ with + | A_typ targ -> string_of_id (id_of_regtyp builtins mwords targ) + | A_nexp nexp when is_nexp_constant (nexp_simp nexp) -> + string_of_nexp (nexp_simp nexp) + | A_order (Ord_aux (Ord_inc, _)) -> "inc" + | A_order (Ord_aux (Ord_dec, _)) -> "dec" + | _ -> + raise (Reporting.err_typ l "Unsupported register type") + in + if IdSet.mem id builtins && not (mwords && is_bitvector_typ typ) then id else + append_id id (String.concat "_" ("" :: List.map name_arg args)) + | _ -> raise (Reporting.err_typ l "Unsupported register type") + +let regstate_field typ = append_id (id_of_regtyp IdSet.empty false typ) "_reg" + +let generate_regstate registers = + let regstate_def = + if registers = [] then + TD_abbrev (mk_id "regstate", mk_typquant [], mk_typ_arg (A_typ unit_typ)) + else + let fields = + if !opt_type_grouped_regstate then + List.map + (fun (typ, id) -> + (function_typ [string_typ] typ no_effect, + regstate_field typ)) + registers + |> List.sort_uniq (fun (typ1, id1) (typ2, id2) -> Id.compare id1 id2) + else registers + in + TD_record (mk_id "regstate", mk_typquant [], fields, false) + in + Defs [DEF_type (TD_aux (regstate_def, (Unknown, ())))] let generate_initial_regstate defs = let registers = find_registers defs in @@ -181,27 +220,15 @@ let generate_initial_regstate defs = | _ -> inits) ([], Bindings.empty) defs in let init_reg (typ, id) = string_of_id id ^ " = " ^ lookup_init_val init_vals typ in - init_defs @ - ["let initial_regstate : regstate = struct { " ^ (String.concat ", " (List.map init_reg registers)) ^ " }"] + List.map defs_of_string + (init_defs @ + ["let initial_regstate : regstate = struct { " ^ + (String.concat ", " (List.map init_reg registers)) ^ + " }"]) with | _ -> [] (* Do not generate an initial register state if anything goes wrong *) -let rec regval_constr_id mwords (Typ_aux (t, l) as typ) = match t with - | Typ_id id -> id - | Typ_app (id, args) -> - let name_arg (A_aux (targ, _)) = match targ with - | A_typ targ -> string_of_id (regval_constr_id mwords targ) - | A_nexp nexp when is_nexp_constant (nexp_simp nexp) -> - string_of_nexp (nexp_simp nexp) - | A_order (Ord_aux (Ord_inc, _)) -> "inc" - | A_order (Ord_aux (Ord_dec, _)) -> "dec" - | _ -> - raise (Reporting.err_typ l "Unsupported register type") - in - let builtins = IdSet.of_list (List.map mk_id ["vector"; "bitvector"; "list"; "option"]) in - if IdSet.mem id builtins && not (mwords && is_bitvector_typ typ) then id else - append_id id (String.concat "_" ("" :: List.map name_arg args)) - | _ -> raise (Reporting.err_typ l "Unsupported register type") +let regval_constr_id = id_of_regtyp (IdSet.of_list (List.map mk_id ["vector"; "bitvector"; "list"; "option"])) let register_base_types mwords typs = let rec add_base_typs typs (Typ_aux (t, _) as typ) = @@ -211,8 +238,8 @@ let register_base_types mwords typs = when IdSet.mem id builtins && not (mwords && is_bitvector_typ typ) -> let add_typ_arg base_typs (A_aux (targ, _)) = match targ with - | A_typ typ -> add_base_typs typs typ - | _ -> typs + | A_typ typ -> add_base_typs base_typs typ + | _ -> base_typs in List.fold_left add_typ_arg typs args | _ -> Bindings.add (regval_constr_id mwords typ) typ typs @@ -227,9 +254,10 @@ let generate_regval_typ typs = "Regval_list : list(register_value), " ^ "Regval_option : option(register_value)" in - ["union register_value = { " ^ - (String.concat ", " (builtins :: List.map constr (Bindings.bindings typs))) ^ - " }"] + [defs_of_string + ("union register_value = { " ^ + (String.concat ", " (builtins :: List.map constr (Bindings.bindings typs))) ^ + " }")] let add_regval_conv id typ (Defs defs) = let id = string_of_id id in @@ -307,12 +335,20 @@ let register_refs_lem mwords registers = in let register_ref (typ, id) = let idd = string (string_of_id id) in + let (read_from, write_to) = + if !opt_type_grouped_regstate then + let field_idd = string (string_of_id (regstate_field typ)) in + (field_idd ^^ space ^^ dquotes idd, + doc_op equals field_idd (string "(fun reg -> if reg = \"" ^^ idd ^^ string "\" then v else s." ^^ field_idd ^^ string " reg)")) + else + (idd, doc_op equals idd (string "v")) + in (* let field = if prefix_recordtype then string "regstate_" ^^ idd else idd in *) let of_regval, regval_of = regval_convs_lem mwords typ in concat [string "let "; idd; string "_ref = <|"; hardline; string " name = \""; idd; string "\";"; hardline; - string " read_from = (fun s -> s."; idd; string ");"; hardline; - string " write_to = (fun v s -> (<| s with "; idd; string " = v |>));"; hardline; + string " read_from = (fun s -> s."; read_from; string ");"; hardline; + string " write_to = (fun v s -> (<| s with "; write_to; string " |>));"; hardline; string " of_regval = "; string of_regval; string ";"; hardline; string " regval_of = "; string regval_of; string " |>"; hardline] in @@ -504,14 +540,20 @@ let generate_regstate_defs mwords defs = let regtyps = register_base_types mwords (List.map fst registers) in let option_typ = if is_defined defs "option" then [] else - ["union option ('a : Type) = {None : unit, Some : 'a}"] + [defs_of_string "union option ('a : Type) = {None : unit, Some : 'a}"] in let regval_typ = if is_defined defs "register_value" then [] else generate_regval_typ regtyps in - let regstate_typ = if is_defined defs "regstate" then [] else generate_regstate registers in - let initregstate = if is_defined defs "initial_regstate" then [] else generate_initial_regstate defs in + let regstate_typ = if is_defined defs "regstate" then [] else [generate_regstate registers] in + let initregstate = + (* Don't create initial regstate if it is already defined or if we generated + a regstate record with registers grouped per type; the latter would + require record fields storing functions, which is not supported in + Sail. *) + if is_defined defs "initial_regstate" || !opt_type_grouped_regstate then [] else + generate_initial_regstate defs + in let defs = option_typ @ regval_typ @ regstate_typ @ initregstate - |> List.map defs_of_string |> concat_ast |> Bindings.fold add_regval_conv regtyps in |
