summaryrefslogtreecommitdiff
path: root/src/specialize.ml
diff options
context:
space:
mode:
authorJon French2019-02-25 12:10:30 +0000
committerJon French2019-02-25 12:10:30 +0000
commit915d75f9c49fa2c2a9d47d189e4224cee16582c9 (patch)
tree77a93e682796977898af0b56e0a61d7689db112e /src/specialize.ml
parenta8a5308e4981b3d09fb2bf0c59d592ef6ae4417e (diff)
parent38656b50ad24df6a29f3a84e50adfcf409131fb0 (diff)
Merge branch 'sail2' into rmem_interpreter
Diffstat (limited to 'src/specialize.ml')
-rw-r--r--src/specialize.ml200
1 files changed, 152 insertions, 48 deletions
diff --git a/src/specialize.ml b/src/specialize.ml
index 00357557..591a415a 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -52,11 +52,26 @@ open Ast
open Ast_util
open Rewriter
-let is_typ_ord_uvar = function
+let is_typ_ord_arg = function
| A_aux (A_typ _, _) -> true
| A_aux (A_order _, _) -> true
| _ -> false
+type specialization = {
+ is_polymorphic : kinded_id -> bool;
+ instantiation_filter : kid -> typ_arg -> bool
+ }
+
+let typ_ord_specialization = {
+ is_polymorphic = (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt);
+ instantiation_filter = (fun _ -> is_typ_ord_arg)
+ }
+
+let int_specialization = {
+ is_polymorphic = is_int_kopt;
+ instantiation_filter = (fun _ arg -> match arg with A_aux (A_nexp _, _) -> true | _ -> false)
+ }
+
let rec nexp_simp_typ (Typ_aux (typ_aux, l)) =
let typ_aux = match typ_aux with
| Typ_id v -> Typ_id v
@@ -81,34 +96,43 @@ and nexp_simp_typ_arg (A_aux (typ_arg_aux, l)) =
(* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of.
This part of the typechecker API is a bit ugly. *)
-let fix_instantiation instantiation =
- let instantiation = KBindings.bindings (KBindings.filter (fun _ arg -> is_typ_ord_uvar arg) instantiation) in
+let fix_instantiation spec instantiation =
+ let instantiation = KBindings.bindings (KBindings.filter spec.instantiation_filter instantiation) in
let instantiation = List.map (fun (kid, arg) -> Type_check.orig_kid kid, nexp_simp_typ_arg arg) instantiation in
List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation
+(* polymorphic_functions returns all functions that are polymorphic
+ for some set of kinded-identifiers, specified by the is_kopt
+ predicate. For example, polymorphic_functions is_int_kopt will
+ return all Int-polymorphic functions. *)
let rec polymorphic_functions is_kopt (Defs defs) =
match defs with
| DEF_spec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ) , _), id, _, externs), _)) :: defs ->
- let is_type_polymorphic = List.exists is_kopt (quant_kopts typq) in
- if is_type_polymorphic then
+ let is_polymorphic = List.exists is_kopt (quant_kopts typq) in
+ if is_polymorphic then
IdSet.add id (polymorphic_functions is_kopt (Defs defs))
else
polymorphic_functions is_kopt (Defs defs)
| _ :: defs -> polymorphic_functions is_kopt (Defs defs)
| [] -> IdSet.empty
+(* When we specialize a function, we need to generate new name. To do
+ this we take the instantiation that the new function is specialized
+ for and turn that into a string in such a way that alpha-equivalent
+ instantiations always get the same name. We then zencode that
+ string so it is a valid identifier name, and prepend it to the
+ previous function name. *)
let string_of_instantiation instantiation =
let open Type_check in
let kid_names = ref KOptMap.empty in
let kid_counter = ref 0 in
let kid_name kid =
try KOptMap.find kid !kid_names with
- | Not_found -> begin
- let n = string_of_int !kid_counter in
- kid_names := KOptMap.add kid n !kid_names;
- incr kid_counter;
- n
- end
+ | Not_found ->
+ let n = string_of_int !kid_counter in
+ kid_names := KOptMap.add kid n !kid_names;
+ incr kid_counter;
+ n
in
(* We need custom string_of functions to ensure that alpha-equivalent definitions get the same name *)
@@ -121,7 +145,7 @@ let string_of_instantiation instantiation =
| Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")"
| Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")"
| Nexp_minus (n1, n2) -> "(" ^ string_of_nexp n1 ^ " - " ^ string_of_nexp n2 ^ ")"
- | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_nexp nexps ^ ")"
+ | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_nexp nexps ^ ")"
| Nexp_exp n -> "2 ^ " ^ string_of_nexp n
| Nexp_neg n -> "- " ^ string_of_nexp n
in
@@ -132,7 +156,7 @@ let string_of_instantiation instantiation =
| Typ_id id -> string_of_id id
| Typ_var kid -> kid_name (mk_kopt K_type kid)
| Typ_tup typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")"
- | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")"
+ | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")"
| Typ_fn (arg_typs, ret_typ, eff) ->
"(" ^ Util.string_of_list ", " string_of_typ arg_typs ^ ") -> " ^ string_of_typ ret_typ ^ " effect " ^ string_of_effect eff
| Typ_bidir (t1, t2) ->
@@ -160,9 +184,11 @@ let string_of_instantiation instantiation =
kid_name (mk_kopt K_int kid) ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}"
| NC_aux (NC_true, _) -> "true"
| NC_aux (NC_false, _) -> "false"
+ | NC_aux (NC_var kid, _) -> kid_name (mk_kopt K_bool kid)
+ | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")"
in
- let string_of_binding (kid, arg) = string_of_kid kid ^ " => " ^ string_of_typ_arg arg in
+ let string_of_binding (kid, arg) = string_of_kid kid ^ "=>" ^ string_of_typ_arg arg in
Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation))
let id_of_instantiation id instantiation =
@@ -179,12 +205,12 @@ let rec variant_generic_typ id (Defs defs) =
(* Returns a list of all the instantiations of a function id in an
ast. Also works with union constructors, and searches for them in
patterns. *)
-let rec instantiations_of id ast =
+let rec instantiations_of spec id ast =
let instantiations = ref [] in
let inspect_exp = function
| E_aux (E_app (id', _), _) as exp when Id.compare id id' = 0 ->
- let instantiation = fix_instantiation (Type_check.instantiation_of exp) in
+ let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in
instantiations := instantiation :: !instantiations;
exp
| exp -> exp
@@ -202,7 +228,7 @@ let rec instantiations_of id ast =
(variant_generic_typ variant_id ast)
typ
in
- instantiations := fix_instantiation instantiation :: !instantiations;
+ instantiations := fix_instantiation spec instantiation :: !instantiations;
pat
| Typ_aux (Typ_id variant_id, _) -> pat
| _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type")
@@ -218,12 +244,12 @@ let rec instantiations_of id ast =
!instantiations
-let rec rewrite_polymorphic_calls id ast =
+let rec rewrite_polymorphic_calls spec id ast =
let vs_ids = Initial_check.val_spec_ids ast in
let rewrite_e_aux = function
| E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 ->
- let instantiation = fix_instantiation (Type_check.instantiation_of exp) in
+ let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in
let spec_id = id_of_instantiation id instantiation in
(* Make sure we only generate specialized calls when we've
specialized the valspec. The valspec may not be generated if
@@ -278,13 +304,61 @@ and typ_arg_int_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) =
| A_order ord -> KidSet.empty
| A_bool _ -> KidSet.empty
-let specialize_id_valspec instantiations id ast =
+(* Implicit arguments have restrictions that won't hold
+ post-specialisation, but we can just remove them and turn them into
+ regular arguments. *)
+let rec remove_implicit (Typ_aux (aux, l) as t) =
+ match aux with
+ | Typ_internal_unknown -> Typ_aux (Typ_internal_unknown, l)
+ | Typ_tup typs -> Typ_aux (Typ_tup (List.map remove_implicit typs), l)
+ | Typ_fn (arg_typs, ret_typ, effs) -> Typ_aux (Typ_fn (List.map remove_implicit arg_typs, remove_implicit ret_typ, effs), l)
+ | Typ_bidir (typ1, typ2) -> Typ_aux (Typ_bidir (remove_implicit typ1, remove_implicit typ2), l)
+ | Typ_app (Id_aux (Id "implicit", _), args) -> Typ_aux (Typ_app (mk_id "atom", List.map remove_implicit_arg args), l)
+ | Typ_app (id, args) -> Typ_aux (Typ_app (id, List.map remove_implicit_arg args), l)
+ | Typ_id id -> Typ_aux (Typ_id id, l)
+ | Typ_exist (kopts, nc, typ) -> Typ_aux (Typ_exist (kopts, nc, remove_implicit typ), l)
+ | Typ_var v -> Typ_aux (Typ_var v, l)
+and remove_implicit_arg (A_aux (aux, l)) =
+ match aux with
+ | A_typ typ -> A_aux (A_typ (remove_implicit typ), l)
+ | arg -> A_aux (arg, l)
+
+let kopt_arg = function
+ | KOpt_aux (KOpt_kind (K_aux (K_int, _), kid), _) -> arg_nexp (nvar kid)
+ | KOpt_aux (KOpt_kind (K_aux (K_type,_), kid), _) -> arg_typ (mk_typ (Typ_var kid))
+ | _ -> failwith "oh no"
+
+(* For numeric type arguments we have to be careful not to run into a
+ situation where we have an instantiation like
+
+ 'n => 'm, 'm => 8
+
+ and end up re-writing 'n to 8. This function turns an instantition
+ like the above into two,
+
+ 'n => 'i#m, 'm => 8 and 'i#m => 'm
+
+ so we can do the substitution in two steps. *)
+let safe_instantiation instantiation =
+ let args =
+ List.map (fun (_, arg) -> kopts_of_typ_arg arg) (KBindings.bindings instantiation)
+ |> List.fold_left KOptSet.union KOptSet.empty
+ |> KOptSet.elements
+ in
+ List.fold_left (fun (i, r) v -> KBindings.map (fun arg -> subst_kid typ_arg_subst (kopt_kid v) (prepend_kid "i#" (kopt_kid v)) arg) i,
+ KBindings.add (prepend_kid "i#" (kopt_kid v)) (kopt_arg v) r)
+ (instantiation, KBindings.empty) args
+
+let instantiate_constraints instantiation ncs =
+ List.map (fun c -> List.fold_left (fun c (v, a) -> constraint_subst v a c) c (KBindings.bindings instantiation)) ncs
+
+let specialize_id_valspec spec instantiations id ast =
match split_defs (is_valspec id) ast with
- | None -> failwith ("Valspec " ^ string_of_id id ^ " does not exist!")
+ | None -> Reporting.unreachable (id_loc id) __POS__ ("Valspec " ^ string_of_id id ^ " does not exist!")
| Some (pre_ast, vs, post_ast) ->
let typschm, externs, is_cast, annot = match vs with
| DEF_spec (VS_aux (VS_val_spec (typschm, _, externs, is_cast), annot)) -> typschm, externs, is_cast, annot
- | _ -> assert false (* unreachable *)
+ | _ -> Reporting.unreachable (id_loc id) __POS__ "val-spec is not actually a val-spec"
in
let TypSchm_aux (TypSchm_ts (typq, typ), _) = typschm in
@@ -292,8 +366,9 @@ let specialize_id_valspec instantiations id ast =
let spec_ids = ref IdSet.empty in
let specialize_instance instantiation =
+ let safe_instantiation, reverse = safe_instantiation instantiation in
(* Replace the polymorphic type variables in the type with their concrete instantiation. *)
- let typ = Type_check.subst_unifiers instantiation typ in
+ let typ = remove_implicit (Type_check.subst_unifiers reverse (Type_check.subst_unifiers safe_instantiation typ)) in
(* Collect any new type variables introduced by the instantiation *)
let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in
@@ -302,11 +377,17 @@ let specialize_id_valspec instantiations id ast =
(* Remove type variables from the type quantifier. *)
let kopts, constraints = quant_split typq in
- let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in
- let typq = mk_typquant (List.map (mk_qi_id K_type) typ_frees
- @ List.map (mk_qi_id K_int) int_frees
- @ List.map mk_qi_kopt kopts
- @ List.map mk_qi_nc constraints) in
+ let constraints = instantiate_constraints safe_instantiation constraints in
+ let constraints = instantiate_constraints reverse constraints in
+ let kopts = List.filter (fun kopt -> not (spec.is_polymorphic kopt)) kopts in
+ let typq =
+ if List.length (typ_frees @ int_frees) = 0 && List.length kopts = 0 then
+ mk_typquant []
+ else
+ mk_typquant (List.map (mk_qi_id K_type) typ_frees
+ @ List.map (mk_qi_id K_int) int_frees
+ @ List.map mk_qi_kopt kopts
+ @ List.map mk_qi_nc constraints) in
let typschm = mk_typschm typq typ in
let spec_id = id_of_instantiation id instantiation in
@@ -324,8 +405,9 @@ let specialize_id_valspec instantiations id ast =
(* When we specialize a function definition we also need to specialize
all the types that appear as annotations within the function
- body. *)
-let specialize_annotations instantiation =
+ body. Also remove any type-annotation from the fundef itself,
+ because at this point we have that as a separate valspec.*)
+let specialize_annotations instantiation fdef =
let open Type_check in
let rw_pat = {
id_pat_alg with
@@ -337,12 +419,21 @@ let specialize_annotations instantiation =
lEXP_cast = (fun (typ, lexp) -> LEXP_cast (subst_unifiers instantiation typ, lexp));
pat_alg = rw_pat
} in
- rewrite_fun {
- rewriters_base with
- rewrite_exp = (fun _ -> fold_exp rw_exp);
- rewrite_pat = (fun _ -> fold_pat rw_pat)
- }
-
+ let fdef =
+ rewrite_fun {
+ rewriters_base with
+ rewrite_exp = (fun _ -> fold_exp rw_exp);
+ rewrite_pat = (fun _ -> fold_pat rw_pat)
+ } fdef
+ in
+ match fdef with
+ | FD_aux (FD_function (rec_opt, _, eff_opt, funcls), annot) ->
+ FD_aux (FD_function (rec_opt,
+ Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown),
+ eff_opt,
+ funcls),
+ annot)
+
let specialize_id_fundef instantiations id ast =
match split_defs (is_fundef id) ast with
| None -> ast
@@ -380,7 +471,15 @@ let specialize_id_overloads instantiations id (Defs defs) =
valspecs are then re-specialized. This process is iterated until
the whole spec is specialized. *)
-let initial_calls = (IdSet.of_list [mk_id "main"; mk_id "__SetConfig"; mk_id "__ListConfig"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"; mk_id "append_64"])
+let initial_calls = IdSet.of_list
+ [ mk_id "main";
+ mk_id "__SetConfig";
+ mk_id "__ListConfig";
+ mk_id "execute";
+ mk_id "decode";
+ mk_id "initialize_registers";
+ mk_id "append_64" (* used to construct bitvector literals in C backend *)
+ ]
let remove_unused_valspecs ?(initial_calls=initial_calls) env ast =
let calls = ref initial_calls in
@@ -424,9 +523,9 @@ let slice_defs env (Defs defs) keep_ids =
let defs = List.filter keep defs in
remove_unused_valspecs env (Defs defs) ~initial_calls:keep_ids
-let specialize_id id ast =
- let instantiations = instantiations_of id ast in
- let ast = specialize_id_valspec instantiations id ast in
+let specialize_id spec id ast =
+ let instantiations = instantiations_of spec id ast in
+ let ast = specialize_id_valspec spec instantiations id ast in
let ast = specialize_id_fundef instantiations id ast in
specialize_id_overloads instantiations id ast
@@ -448,21 +547,26 @@ let reorder_typedefs (Defs defs) =
let others = filter_typedefs defs in
Defs (List.rev !tdefs @ others)
-let specialize_ids ids ast =
- let ast = List.fold_left (fun ast id -> specialize_id id ast) ast (IdSet.elements ids) in
+let specialize_ids spec ids ast =
+ let ast = List.fold_left (fun ast id -> specialize_id spec id ast) ast (IdSet.elements ids) in
let ast = reorder_typedefs ast in
let ast, _ = Type_error.check Type_check.initial_env ast in
let ast =
- List.fold_left (fun ast id -> rewrite_polymorphic_calls id ast) ast (IdSet.elements ids)
+ List.fold_left (fun ast id -> rewrite_polymorphic_calls spec id ast) ast (IdSet.elements ids)
in
let ast, env = Type_error.check Type_check.initial_env ast in
let ast = remove_unused_valspecs env ast in
ast, env
-let rec specialize ast env =
- let ids = polymorphic_functions (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) ast in
- if IdSet.is_empty ids then
+let rec specialize' n spec ast env =
+ if n = 0 then
ast, env
else
- let ast, env = specialize_ids ids ast in
- specialize ast env
+ let ids = polymorphic_functions spec.is_polymorphic ast in
+ if IdSet.is_empty ids then
+ ast, env
+ else
+ let ast, env = specialize_ids spec ids ast in
+ specialize' (n - 1) spec ast env
+
+let specialize = specialize' (-1)