summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlasdair Armstrong2019-02-19 17:02:19 +0000
committerAlasdair Armstrong2019-02-19 17:02:19 +0000
commitfc7d360e9442ab2e945e0d2da97faaf0eefec66f (patch)
treea823d0c949dde68bdf117c836c3c2e28f9cf9088
parent3c967f9075d890b8ba0e3fa1fb990a41a36ddd80 (diff)
Refactor specialization
specialize functions now take a 'specialization' parameter that specifies how they will specialize the AST. typ_ord_specialization gives the previous behaviour, whereas int_specialization allows specializing on Int-kinded arguments. Note that this can loop forever unless the appropriate case splits are inserted beforehand, presumably by monomorphisation. rename is_nat_kopt -> is_int_kopt for consistency
-rw-r--r--src/ast_util.ml2
-rw-r--r--src/ast_util.mli2
-rw-r--r--src/initial_check.ml6
-rw-r--r--src/isail.ml6
-rw-r--r--src/monomorphise.ml2
-rw-r--r--src/ocaml_backend.ml4
-rw-r--r--src/pretty_print_coq.ml2
-rw-r--r--src/pretty_print_lem.ml2
-rw-r--r--src/pretty_print_sail.ml6
-rw-r--r--src/sail.ml2
-rw-r--r--src/specialize.ml101
-rw-r--r--src/specialize.mli18
-rw-r--r--src/type_check.ml6
13 files changed, 101 insertions, 58 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index e4287249..8942b3b1 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -129,7 +129,7 @@ let mk_val_spec vs_aux =
let kopt_kid (KOpt_aux (KOpt_kind (_, kid), _)) = kid
let kopt_kind (KOpt_aux (KOpt_kind (k, _), _)) = k
-let is_nat_kopt = function
+let is_int_kopt = function
| KOpt_aux (KOpt_kind (K_aux (K_int, _), _), _) -> true
| _ -> false
diff --git a/src/ast_util.mli b/src/ast_util.mli
index fe722f5e..823fcebb 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -109,7 +109,7 @@ val dec_ord : order
(* Utilites for working with kinded_ids *)
val kopt_kid : kinded_id -> kid
val kopt_kind : kinded_id -> kind
-val is_nat_kopt : kinded_id -> bool
+val is_int_kopt : kinded_id -> bool
val is_order_kopt : kinded_id -> bool
val is_typ_kopt : kinded_id -> bool
val is_bool_kopt : kinded_id -> bool
diff --git a/src/initial_check.ml b/src/initial_check.ml
index 07316c6d..003da64e 100644
--- a/src/initial_check.ml
+++ b/src/initial_check.ml
@@ -824,15 +824,15 @@ let val_spec_ids (Defs defs) =
IdSet.of_list (vs_ids defs)
let quant_item_param = function
- | QI_aux (QI_id kopt, _) when is_nat_kopt kopt -> [prepend_id "atom_" (id_of_kid (kopt_kid kopt))]
+ | QI_aux (QI_id kopt, _) when is_int_kopt kopt -> [prepend_id "atom_" (id_of_kid (kopt_kid kopt))]
| QI_aux (QI_id kopt, _) when is_typ_kopt kopt -> [prepend_id "typ_" (id_of_kid (kopt_kid kopt))]
| _ -> []
let quant_item_typ = function
- | QI_aux (QI_id kopt, _) when is_nat_kopt kopt -> [atom_typ (nvar (kopt_kid kopt))]
+ | QI_aux (QI_id kopt, _) when is_int_kopt kopt -> [atom_typ (nvar (kopt_kid kopt))]
| QI_aux (QI_id kopt, _) when is_typ_kopt kopt -> [mk_typ (Typ_var (kopt_kid kopt))]
| _ -> []
let quant_item_arg = function
- | QI_aux (QI_id kopt, _) when is_nat_kopt kopt -> [mk_typ_arg (A_nexp (nvar (kopt_kid kopt)))]
+ | QI_aux (QI_id kopt, _) when is_int_kopt kopt -> [mk_typ_arg (A_nexp (nvar (kopt_kid kopt)))]
| QI_aux (QI_id kopt, _) when is_typ_kopt kopt -> [mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt))))]
| _ -> []
let undefined_typschm id typq =
diff --git a/src/isail.ml b/src/isail.ml
index 4cfb2c6f..252b21b8 100644
--- a/src/isail.ml
+++ b/src/isail.ml
@@ -358,7 +358,7 @@ let handle_input' input =
List.iter print_endline commands
| ":poly" ->
let is_kopt = match arg with
- | "Int" -> is_nat_kopt
+ | "Int" -> is_int_kopt
| "Type" -> is_typ_kopt
| "Order" -> is_order_kopt
| _ -> failwith "Invalid kind"
@@ -374,7 +374,7 @@ let handle_input' input =
| Arg.Bad message | Arg.Help message -> print_endline message
end;
| ":spec" ->
- let ast, env = Specialize.specialize !Interactive.ast !Interactive.env in
+ let ast, env = Specialize.(specialize int_specialization !Interactive.ast !Interactive.env) in
Interactive.ast := ast;
Interactive.env := env;
interactive_state := initial_state !Interactive.ast Value.primops
@@ -384,7 +384,7 @@ let handle_input' input =
let open PPrint in
let open C_backend in
let ast = Process_file.rewrite_ast_c !Interactive.env !Interactive.ast in
- let ast, env = Specialize.specialize ast !Interactive.env in
+ let ast, env = Specialize.(specialize typ_ord_specialization ast !Interactive.env) in
let ctx = initial_ctx env in
interactive_bytecode := bytecode_ast ctx (List.map flatten_cdef) ast
| ":ir" ->
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index acc31456..856e36d5 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -4300,7 +4300,7 @@ let add_bitvector_casts (Defs defs) =
let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),fcl_ann)) =
let fcl_env = env_of_annot fcl_ann in
let (tq,typ) = Env.get_val_spec_orig id fcl_env in
- let quant_kids = List.map kopt_kid (List.filter is_nat_kopt (quant_kopts tq)) in
+ let quant_kids = List.map kopt_kid (List.filter is_int_kopt (quant_kopts tq)) in
let ret_typ =
match typ with
| Typ_aux (Typ_fn (_,ret,_),_) -> ret
diff --git a/src/ocaml_backend.ml b/src/ocaml_backend.ml
index 05406413..894d028f 100644
--- a/src/ocaml_backend.ml
+++ b/src/ocaml_backend.ml
@@ -744,13 +744,13 @@ let ocaml_pp_generators ctx defs orig_types required =
let gen_tyvars = List.map (fun k -> kopt_kid k |> zencode_kid)
(List.filter is_typ_kopt tquants) in
let print_quant kindedid =
- if is_nat_kopt kindedid then string "int" else
+ if is_int_kopt kindedid then string "int" else
if is_order_kopt kindedid then string "bool" else
parens (separate space [string "generators"; string "->"; zencode_kid (kopt_kid kindedid)])
in
let name = "gen_" ^ type_name id in
let make_tyarg kindedid =
- if is_nat_kopt kindedid
+ if is_int_kopt kindedid
then mk_typ_arg (A_nexp (nvar (kopt_kid kindedid)))
else if is_order_kopt kindedid
then mk_typ_arg (A_order (mk_ord (Ord_var (kopt_kid kindedid))))
diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml
index c321cb26..430eb40d 100644
--- a/src/pretty_print_coq.ml
+++ b/src/pretty_print_coq.ml
@@ -2512,7 +2512,7 @@ let doc_axiom_typschm typ_env (TypSchm_aux (TypSchm_ts (tqs,typ),l) as ts) =
let used = if is_number ret_ty then used else KidSet.union used (tyvars_of_typ ret_ty) in
let tqs = match tqs with
| TypQ_aux (TypQ_tq qs,l) -> TypQ_aux (TypQ_tq (List.filter (function
- | QI_aux (QI_id kopt,_) when is_nat_kopt kopt ->
+ | QI_aux (QI_id kopt,_) when is_int_kopt kopt ->
let kid = kopt_kid kopt in
KidSet.mem kid used && not (KidSet.mem kid args)
| _ -> true) qs),l)
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 9d472e15..1c30e06e 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -316,7 +316,7 @@ let doc_typ_lem, doc_atomic_typ_lem =
* if we add a new Typ constructor *)
let tpp = typ ty in
if atyp_needed then parens tpp else tpp
- | Typ_exist (kopts,_,ty) when List.for_all is_nat_kopt kopts -> begin
+ | Typ_exist (kopts,_,ty) when List.for_all is_int_kopt kopts -> begin
let kids = List.map kopt_kid kopts in
let tpp = typ ty in
let visible_vars = lem_tyvars_of_typ ty in
diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml
index 9a374275..27f626ea 100644
--- a/src/pretty_print_sail.ml
+++ b/src/pretty_print_sail.ml
@@ -66,7 +66,7 @@ let doc_id (Id_aux (id_aux, _)) =
let doc_kid kid = string (Ast_util.string_of_kid kid)
let doc_kopt = function
- | kopt when is_nat_kopt kopt -> doc_kid (kopt_kid kopt)
+ | kopt when is_int_kopt kopt -> doc_kid (kopt_kid kopt)
| kopt when is_typ_kopt kopt -> parens (separate space [doc_kid (kopt_kid kopt); colon; string "Type"])
| kopt when is_order_kopt kopt -> parens (separate space [doc_kid (kopt_kid kopt); colon; string "Order"])
| kopt -> parens (separate space [doc_kid (kopt_kid kopt); colon; string "Bool"])
@@ -213,7 +213,7 @@ and doc_arg_typs = function
let doc_quants quants =
let doc_qi_kopt (QI_aux (qi_aux, _)) =
match qi_aux with
- | QI_id kopt when is_nat_kopt kopt -> [parens (separate space [doc_kid (kopt_kid kopt); colon; string "Int"])]
+ | QI_id kopt when is_int_kopt kopt -> [parens (separate space [doc_kid (kopt_kid kopt); colon; string "Int"])]
| QI_id kopt when is_typ_kopt kopt -> [parens (separate space [doc_kid (kopt_kid kopt); colon; string "Type"])]
| QI_id kopt when is_bool_kopt kopt -> [parens (separate space [doc_kid (kopt_kid kopt); colon; string "Bool"])]
| QI_id kopt -> [parens (separate space [doc_kid (kopt_kid kopt); colon; string "Order"])]
@@ -234,7 +234,7 @@ let doc_quants quants =
let doc_param_quants quants =
let doc_qi_kopt (QI_aux (qi_aux, _)) =
match qi_aux with
- | QI_id kopt when is_nat_kopt kopt -> [doc_kid (kopt_kid kopt) ^^ colon ^^ space ^^ string "Int"]
+ | QI_id kopt when is_int_kopt kopt -> [doc_kid (kopt_kid kopt) ^^ colon ^^ space ^^ string "Int"]
| QI_id kopt when is_typ_kopt kopt -> [doc_kid (kopt_kid kopt) ^^ colon ^^ space ^^ string "Type"]
| QI_id kopt when is_bool_kopt kopt -> [doc_kid (kopt_kid kopt) ^^ colon ^^ space ^^ string "Bool"]
| QI_id kopt -> [doc_kid (kopt_kid kopt) ^^ colon ^^ space ^^ string "Order"]
diff --git a/src/sail.ml b/src/sail.ml
index 64ccd341..eaf96eb4 100644
--- a/src/sail.ml
+++ b/src/sail.ml
@@ -406,7 +406,7 @@ let main() =
(if !(opt_print_c)
then
let ast_c = rewrite_ast_c type_envs ast in
- let ast_c, type_envs = Specialize.specialize ast_c type_envs in
+ let ast_c, type_envs = Specialize.(specialize typ_ord_specialization ast_c type_envs) in
(* let ast_c = Spec_analysis.top_sort_defs ast_c in *)
Util.opt_warnings := true;
C_backend.compile_ast (C_backend.initial_ctx type_envs) (!opt_includes_c) ast_c
diff --git a/src/specialize.ml b/src/specialize.ml
index 9f6af6d6..eaef1231 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -52,11 +52,26 @@ open Ast
open Ast_util
open Rewriter
-let is_typ_ord_uvar = function
+let is_typ_ord_arg = function
| A_aux (A_typ _, _) -> true
| A_aux (A_order _, _) -> true
| _ -> false
+type specialization = {
+ is_polymorphic : kinded_id -> bool;
+ instantiation_filter : kid -> typ_arg -> bool
+ }
+
+let typ_ord_specialization = {
+ is_polymorphic = (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt);
+ instantiation_filter = (fun _ -> is_typ_ord_arg)
+ }
+
+let int_specialization = {
+ is_polymorphic = is_int_kopt;
+ instantiation_filter = (fun _ arg -> match arg with A_aux (A_nexp _, _) -> true | _ -> false)
+ }
+
let rec nexp_simp_typ (Typ_aux (typ_aux, l)) =
let typ_aux = match typ_aux with
| Typ_id v -> Typ_id v
@@ -81,34 +96,43 @@ and nexp_simp_typ_arg (A_aux (typ_arg_aux, l)) =
(* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of.
This part of the typechecker API is a bit ugly. *)
-let fix_instantiation instantiation =
- let instantiation = KBindings.bindings (KBindings.filter (fun _ arg -> is_typ_ord_uvar arg) instantiation) in
+let fix_instantiation spec instantiation =
+ let instantiation = KBindings.bindings (KBindings.filter spec.instantiation_filter instantiation) in
let instantiation = List.map (fun (kid, arg) -> Type_check.orig_kid kid, nexp_simp_typ_arg arg) instantiation in
List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation
+(* polymorphic_functions returns all functions that are polymorphic
+ for some set of kinded-identifiers, specified by the is_kopt
+ predicate. For example, polymorphic_functions is_int_kopt will
+ return all Int-polymorphic functions. *)
let rec polymorphic_functions is_kopt (Defs defs) =
match defs with
| DEF_spec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ) , _), id, _, externs), _)) :: defs ->
- let is_type_polymorphic = List.exists is_kopt (quant_kopts typq) in
- if is_type_polymorphic then
+ let is_polymorphic = List.exists is_kopt (quant_kopts typq) in
+ if is_polymorphic then
IdSet.add id (polymorphic_functions is_kopt (Defs defs))
else
polymorphic_functions is_kopt (Defs defs)
| _ :: defs -> polymorphic_functions is_kopt (Defs defs)
| [] -> IdSet.empty
+(* When we specialize a function, we need to generate new name. To do
+ this we take the instantiation that the new function is specialized
+ for and turn that into a string in such a way that alpha-equivalent
+ instantiations always get the same name. We then zencode that
+ string so it is a valid identifier name, and prepend it to the
+ previous function name. *)
let string_of_instantiation instantiation =
let open Type_check in
let kid_names = ref KOptMap.empty in
let kid_counter = ref 0 in
let kid_name kid =
try KOptMap.find kid !kid_names with
- | Not_found -> begin
- let n = string_of_int !kid_counter in
- kid_names := KOptMap.add kid n !kid_names;
- incr kid_counter;
- n
- end
+ | Not_found ->
+ let n = string_of_int !kid_counter in
+ kid_names := KOptMap.add kid n !kid_names;
+ incr kid_counter;
+ n
in
(* We need custom string_of functions to ensure that alpha-equivalent definitions get the same name *)
@@ -121,7 +145,7 @@ let string_of_instantiation instantiation =
| Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")"
| Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")"
| Nexp_minus (n1, n2) -> "(" ^ string_of_nexp n1 ^ " - " ^ string_of_nexp n2 ^ ")"
- | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_nexp nexps ^ ")"
+ | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_nexp nexps ^ ")"
| Nexp_exp n -> "2 ^ " ^ string_of_nexp n
| Nexp_neg n -> "- " ^ string_of_nexp n
in
@@ -132,7 +156,7 @@ let string_of_instantiation instantiation =
| Typ_id id -> string_of_id id
| Typ_var kid -> kid_name (mk_kopt K_type kid)
| Typ_tup typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")"
- | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")"
+ | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")"
| Typ_fn (arg_typs, ret_typ, eff) ->
"(" ^ Util.string_of_list ", " string_of_typ arg_typs ^ ") -> " ^ string_of_typ ret_typ ^ " effect " ^ string_of_effect eff
| Typ_bidir (t1, t2) ->
@@ -161,10 +185,10 @@ let string_of_instantiation instantiation =
| NC_aux (NC_true, _) -> "true"
| NC_aux (NC_false, _) -> "false"
| NC_aux (NC_var kid, _) -> kid_name (mk_kopt K_bool kid)
- | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")"
+ | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")"
in
- let string_of_binding (kid, arg) = string_of_kid kid ^ " => " ^ string_of_typ_arg arg in
+ let string_of_binding (kid, arg) = string_of_kid kid ^ "=>" ^ string_of_typ_arg arg in
Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation))
let id_of_instantiation id instantiation =
@@ -181,12 +205,12 @@ let rec variant_generic_typ id (Defs defs) =
(* Returns a list of all the instantiations of a function id in an
ast. Also works with union constructors, and searches for them in
patterns. *)
-let rec instantiations_of id ast =
+let rec instantiations_of spec id ast =
let instantiations = ref [] in
let inspect_exp = function
| E_aux (E_app (id', _), _) as exp when Id.compare id id' = 0 ->
- let instantiation = fix_instantiation (Type_check.instantiation_of exp) in
+ let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in
instantiations := instantiation :: !instantiations;
exp
| exp -> exp
@@ -204,7 +228,7 @@ let rec instantiations_of id ast =
(variant_generic_typ variant_id ast)
typ
in
- instantiations := fix_instantiation instantiation :: !instantiations;
+ instantiations := fix_instantiation spec instantiation :: !instantiations;
pat
| Typ_aux (Typ_id variant_id, _) -> pat
| _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type")
@@ -220,12 +244,12 @@ let rec instantiations_of id ast =
!instantiations
-let rec rewrite_polymorphic_calls id ast =
+let rec rewrite_polymorphic_calls spec id ast =
let vs_ids = Initial_check.val_spec_ids ast in
let rewrite_e_aux = function
| E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 ->
- let instantiation = fix_instantiation (Type_check.instantiation_of exp) in
+ let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in
let spec_id = id_of_instantiation id instantiation in
(* Make sure we only generate specialized calls when we've
specialized the valspec. The valspec may not be generated if
@@ -280,7 +304,7 @@ and typ_arg_int_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) =
| A_order ord -> KidSet.empty
| A_bool _ -> KidSet.empty
-let specialize_id_valspec instantiations id ast =
+let specialize_id_valspec spec instantiations id ast =
match split_defs (is_valspec id) ast with
| None -> failwith ("Valspec " ^ string_of_id id ^ " does not exist!")
| Some (pre_ast, vs, post_ast) ->
@@ -304,7 +328,8 @@ let specialize_id_valspec instantiations id ast =
(* Remove type variables from the type quantifier. *)
let kopts, constraints = quant_split typq in
- let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in
+ let constraints = List.map (fun c -> List.fold_left (fun c (v, a) -> constraint_subst v a c) c (KBindings.bindings instantiation)) constraints in
+ let kopts = List.filter (fun kopt -> not (spec.is_polymorphic kopt)) kopts in
let typq = mk_typquant (List.map (mk_qi_id K_type) typ_frees
@ List.map (mk_qi_id K_int) int_frees
@ List.map mk_qi_kopt kopts
@@ -344,7 +369,7 @@ let specialize_annotations instantiation =
rewrite_exp = (fun _ -> fold_exp rw_exp);
rewrite_pat = (fun _ -> fold_pat rw_pat)
}
-
+
let specialize_id_fundef instantiations id ast =
match split_defs (is_fundef id) ast with
| None -> ast
@@ -382,7 +407,15 @@ let specialize_id_overloads instantiations id (Defs defs) =
valspecs are then re-specialized. This process is iterated until
the whole spec is specialized. *)
-let initial_calls = (IdSet.of_list [mk_id "main"; mk_id "__SetConfig"; mk_id "__ListConfig"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"; mk_id "append_64"])
+let initial_calls = IdSet.of_list
+ [ mk_id "main";
+ mk_id "__SetConfig";
+ mk_id "__ListConfig";
+ mk_id "execute";
+ mk_id "decode";
+ mk_id "initialize_registers";
+ mk_id "append_64" (* used to construct bitvector literals in C backend *)
+ ]
let remove_unused_valspecs ?(initial_calls=initial_calls) env ast =
let calls = ref initial_calls in
@@ -426,9 +459,9 @@ let slice_defs env (Defs defs) keep_ids =
let defs = List.filter keep defs in
remove_unused_valspecs env (Defs defs) ~initial_calls:keep_ids
-let specialize_id id ast =
- let instantiations = instantiations_of id ast in
- let ast = specialize_id_valspec instantiations id ast in
+let specialize_id spec id ast =
+ let instantiations = instantiations_of spec id ast in
+ let ast = specialize_id_valspec spec instantiations id ast in
let ast = specialize_id_fundef instantiations id ast in
specialize_id_overloads instantiations id ast
@@ -450,21 +483,21 @@ let reorder_typedefs (Defs defs) =
let others = filter_typedefs defs in
Defs (List.rev !tdefs @ others)
-let specialize_ids ids ast =
- let ast = List.fold_left (fun ast id -> specialize_id id ast) ast (IdSet.elements ids) in
+let specialize_ids spec ids ast =
+ let ast = List.fold_left (fun ast id -> specialize_id spec id ast) ast (IdSet.elements ids) in
let ast = reorder_typedefs ast in
let ast, _ = Type_error.check Type_check.initial_env ast in
let ast =
- List.fold_left (fun ast id -> rewrite_polymorphic_calls id ast) ast (IdSet.elements ids)
+ List.fold_left (fun ast id -> rewrite_polymorphic_calls spec id ast) ast (IdSet.elements ids)
in
let ast, env = Type_error.check Type_check.initial_env ast in
let ast = remove_unused_valspecs env ast in
ast, env
-let rec specialize ast env =
- let ids = polymorphic_functions (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) ast in
+let rec specialize spec ast env =
+ let ids = polymorphic_functions spec.is_polymorphic ast in
if IdSet.is_empty ids then
ast, env
else
- let ast, env = specialize_ids ids ast in
- specialize ast env
+ let ast, env = specialize_ids spec ids ast in
+ specialize spec ast env
diff --git a/src/specialize.mli b/src/specialize.mli
index 28029747..269f2340 100644
--- a/src/specialize.mli
+++ b/src/specialize.mli
@@ -54,10 +54,18 @@ open Ast
open Ast_util
open Type_check
+type specialization
+
+(** Only specialize Type- and Ord- kinded polymorphism. *)
+val typ_ord_specialization : specialization
+
+(** (experimental) specialise Int-kinded definitions *)
+val int_specialization : specialization
+
(** Returns an IdSet with the function ids that have X-kinded
parameters, e.g. val f : forall ('a : X). 'a -> 'a. The first
argument specifies what X should be - it should be one of:
- [is_nat_kopt], [is_order_kopt], or [is_typ_kopt] from [Ast_util],
+ [is_int_kopt], [is_order_kopt], or [is_typ_kopt] from [Ast_util],
or some combination of those. *)
val polymorphic_functions : (kinded_id -> bool) -> 'a defs -> IdSet.t
@@ -66,11 +74,13 @@ val polymorphic_functions : (kinded_id -> bool) -> 'a defs -> IdSet.t
AST with [Type_check.initial_env]. The env parameter is the
environment to return if there is no polymorphism to remove, in
which case specialize returns the AST unmodified. *)
-val specialize : tannot defs -> Env.t -> tannot defs * Env.t
+val specialize : specialization -> tannot defs -> Env.t -> tannot defs * Env.t
-val instantiations_of : id -> tannot defs -> typ_arg KBindings.t list
+(** return all instantiations of a function id, with the
+ instantiations filtered according to the specialization. *)
+val instantiations_of : specialization -> id -> tannot defs -> typ_arg KBindings.t list
val string_of_instantiation : typ_arg KBindings.t -> string
-(* Remove all function definitions except for the given set *)
+(** Remove all function definitions except for the given set *)
val slice_defs : Env.t -> tannot defs -> IdSet.t -> tannot defs
diff --git a/src/type_check.ml b/src/type_check.ml
index 0da7f753..3d0f38a6 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -574,7 +574,7 @@ end = struct
let kopts, ncs = quant_split typq in
let rec subst_args kopts args =
match kopts, args with
- | kopt :: kopts, (A_aux (A_nexp _, _) as arg) :: args when is_nat_kopt kopt ->
+ | kopt :: kopts, (A_aux (A_nexp _, _) as arg) :: args when is_int_kopt kopt ->
List.map (constraint_subst (kopt_kid kopt) arg) (subst_args kopts args)
| kopt :: kopts, A_aux (A_typ arg, _) :: args when is_typ_kopt kopt ->
subst_args kopts args
@@ -992,7 +992,7 @@ end = struct
typ_print (lazy (adding ^ "record " ^ string_of_id id));
let rec record_typ_args = function
| [] -> []
- | ((QI_aux (QI_id kopt, _)) :: qis) when is_nat_kopt kopt ->
+ | ((QI_aux (QI_id kopt, _)) :: qis) when is_int_kopt kopt ->
mk_typ_arg (A_nexp (nvar (kopt_kid kopt))) :: record_typ_args qis
| ((QI_aux (QI_id kopt, _)) :: qis) when is_typ_kopt kopt ->
mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt)))) :: record_typ_args qis
@@ -4817,7 +4817,7 @@ let mk_synonym typq typ_arg =
let kopts = List.map snd kopts in
let rec subst_args env l kopts args =
match kopts, args with
- | kopt :: kopts, A_aux (A_nexp arg, _) :: args when is_nat_kopt kopt ->
+ | kopt :: kopts, A_aux (A_nexp arg, _) :: args when is_int_kopt kopt ->
let typ_arg, ncs = subst_args env l kopts args in
typ_arg_subst (kopt_kid kopt) (arg_nexp arg) typ_arg,
List.map (constraint_subst (kopt_kid kopt) (arg_nexp arg)) ncs