summaryrefslogtreecommitdiff
path: root/src/specialize.ml
diff options
context:
space:
mode:
authorAlasdair Armstrong2019-02-19 17:02:19 +0000
committerAlasdair Armstrong2019-02-19 17:02:19 +0000
commitfc7d360e9442ab2e945e0d2da97faaf0eefec66f (patch)
treea823d0c949dde68bdf117c836c3c2e28f9cf9088 /src/specialize.ml
parent3c967f9075d890b8ba0e3fa1fb990a41a36ddd80 (diff)
Refactor specialization
specialize functions now take a 'specialization' parameter that specifies how they will specialize the AST. typ_ord_specialization gives the previous behaviour, whereas int_specialization allows specializing on Int-kinded arguments. Note that this can loop forever unless the appropriate case splits are inserted beforehand, presumably by monomorphisation. rename is_nat_kopt -> is_int_kopt for consistency
Diffstat (limited to 'src/specialize.ml')
-rw-r--r--src/specialize.ml101
1 files changed, 67 insertions, 34 deletions
diff --git a/src/specialize.ml b/src/specialize.ml
index 9f6af6d6..eaef1231 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -52,11 +52,26 @@ open Ast
open Ast_util
open Rewriter
-let is_typ_ord_uvar = function
+let is_typ_ord_arg = function
| A_aux (A_typ _, _) -> true
| A_aux (A_order _, _) -> true
| _ -> false
+type specialization = {
+ is_polymorphic : kinded_id -> bool;
+ instantiation_filter : kid -> typ_arg -> bool
+ }
+
+let typ_ord_specialization = {
+ is_polymorphic = (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt);
+ instantiation_filter = (fun _ -> is_typ_ord_arg)
+ }
+
+let int_specialization = {
+ is_polymorphic = is_int_kopt;
+ instantiation_filter = (fun _ arg -> match arg with A_aux (A_nexp _, _) -> true | _ -> false)
+ }
+
let rec nexp_simp_typ (Typ_aux (typ_aux, l)) =
let typ_aux = match typ_aux with
| Typ_id v -> Typ_id v
@@ -81,34 +96,43 @@ and nexp_simp_typ_arg (A_aux (typ_arg_aux, l)) =
(* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of.
This part of the typechecker API is a bit ugly. *)
-let fix_instantiation instantiation =
- let instantiation = KBindings.bindings (KBindings.filter (fun _ arg -> is_typ_ord_uvar arg) instantiation) in
+let fix_instantiation spec instantiation =
+ let instantiation = KBindings.bindings (KBindings.filter spec.instantiation_filter instantiation) in
let instantiation = List.map (fun (kid, arg) -> Type_check.orig_kid kid, nexp_simp_typ_arg arg) instantiation in
List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation
+(* polymorphic_functions returns all functions that are polymorphic
+ for some set of kinded-identifiers, specified by the is_kopt
+ predicate. For example, polymorphic_functions is_int_kopt will
+ return all Int-polymorphic functions. *)
let rec polymorphic_functions is_kopt (Defs defs) =
match defs with
| DEF_spec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ) , _), id, _, externs), _)) :: defs ->
- let is_type_polymorphic = List.exists is_kopt (quant_kopts typq) in
- if is_type_polymorphic then
+ let is_polymorphic = List.exists is_kopt (quant_kopts typq) in
+ if is_polymorphic then
IdSet.add id (polymorphic_functions is_kopt (Defs defs))
else
polymorphic_functions is_kopt (Defs defs)
| _ :: defs -> polymorphic_functions is_kopt (Defs defs)
| [] -> IdSet.empty
+(* When we specialize a function, we need to generate new name. To do
+ this we take the instantiation that the new function is specialized
+ for and turn that into a string in such a way that alpha-equivalent
+ instantiations always get the same name. We then zencode that
+ string so it is a valid identifier name, and prepend it to the
+ previous function name. *)
let string_of_instantiation instantiation =
let open Type_check in
let kid_names = ref KOptMap.empty in
let kid_counter = ref 0 in
let kid_name kid =
try KOptMap.find kid !kid_names with
- | Not_found -> begin
- let n = string_of_int !kid_counter in
- kid_names := KOptMap.add kid n !kid_names;
- incr kid_counter;
- n
- end
+ | Not_found ->
+ let n = string_of_int !kid_counter in
+ kid_names := KOptMap.add kid n !kid_names;
+ incr kid_counter;
+ n
in
(* We need custom string_of functions to ensure that alpha-equivalent definitions get the same name *)
@@ -121,7 +145,7 @@ let string_of_instantiation instantiation =
| Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")"
| Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")"
| Nexp_minus (n1, n2) -> "(" ^ string_of_nexp n1 ^ " - " ^ string_of_nexp n2 ^ ")"
- | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_nexp nexps ^ ")"
+ | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_nexp nexps ^ ")"
| Nexp_exp n -> "2 ^ " ^ string_of_nexp n
| Nexp_neg n -> "- " ^ string_of_nexp n
in
@@ -132,7 +156,7 @@ let string_of_instantiation instantiation =
| Typ_id id -> string_of_id id
| Typ_var kid -> kid_name (mk_kopt K_type kid)
| Typ_tup typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")"
- | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")"
+ | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")"
| Typ_fn (arg_typs, ret_typ, eff) ->
"(" ^ Util.string_of_list ", " string_of_typ arg_typs ^ ") -> " ^ string_of_typ ret_typ ^ " effect " ^ string_of_effect eff
| Typ_bidir (t1, t2) ->
@@ -161,10 +185,10 @@ let string_of_instantiation instantiation =
| NC_aux (NC_true, _) -> "true"
| NC_aux (NC_false, _) -> "false"
| NC_aux (NC_var kid, _) -> kid_name (mk_kopt K_bool kid)
- | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")"
+ | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")"
in
- let string_of_binding (kid, arg) = string_of_kid kid ^ " => " ^ string_of_typ_arg arg in
+ let string_of_binding (kid, arg) = string_of_kid kid ^ "=>" ^ string_of_typ_arg arg in
Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation))
let id_of_instantiation id instantiation =
@@ -181,12 +205,12 @@ let rec variant_generic_typ id (Defs defs) =
(* Returns a list of all the instantiations of a function id in an
ast. Also works with union constructors, and searches for them in
patterns. *)
-let rec instantiations_of id ast =
+let rec instantiations_of spec id ast =
let instantiations = ref [] in
let inspect_exp = function
| E_aux (E_app (id', _), _) as exp when Id.compare id id' = 0 ->
- let instantiation = fix_instantiation (Type_check.instantiation_of exp) in
+ let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in
instantiations := instantiation :: !instantiations;
exp
| exp -> exp
@@ -204,7 +228,7 @@ let rec instantiations_of id ast =
(variant_generic_typ variant_id ast)
typ
in
- instantiations := fix_instantiation instantiation :: !instantiations;
+ instantiations := fix_instantiation spec instantiation :: !instantiations;
pat
| Typ_aux (Typ_id variant_id, _) -> pat
| _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type")
@@ -220,12 +244,12 @@ let rec instantiations_of id ast =
!instantiations
-let rec rewrite_polymorphic_calls id ast =
+let rec rewrite_polymorphic_calls spec id ast =
let vs_ids = Initial_check.val_spec_ids ast in
let rewrite_e_aux = function
| E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 ->
- let instantiation = fix_instantiation (Type_check.instantiation_of exp) in
+ let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in
let spec_id = id_of_instantiation id instantiation in
(* Make sure we only generate specialized calls when we've
specialized the valspec. The valspec may not be generated if
@@ -280,7 +304,7 @@ and typ_arg_int_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) =
| A_order ord -> KidSet.empty
| A_bool _ -> KidSet.empty
-let specialize_id_valspec instantiations id ast =
+let specialize_id_valspec spec instantiations id ast =
match split_defs (is_valspec id) ast with
| None -> failwith ("Valspec " ^ string_of_id id ^ " does not exist!")
| Some (pre_ast, vs, post_ast) ->
@@ -304,7 +328,8 @@ let specialize_id_valspec instantiations id ast =
(* Remove type variables from the type quantifier. *)
let kopts, constraints = quant_split typq in
- let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in
+ let constraints = List.map (fun c -> List.fold_left (fun c (v, a) -> constraint_subst v a c) c (KBindings.bindings instantiation)) constraints in
+ let kopts = List.filter (fun kopt -> not (spec.is_polymorphic kopt)) kopts in
let typq = mk_typquant (List.map (mk_qi_id K_type) typ_frees
@ List.map (mk_qi_id K_int) int_frees
@ List.map mk_qi_kopt kopts
@@ -344,7 +369,7 @@ let specialize_annotations instantiation =
rewrite_exp = (fun _ -> fold_exp rw_exp);
rewrite_pat = (fun _ -> fold_pat rw_pat)
}
-
+
let specialize_id_fundef instantiations id ast =
match split_defs (is_fundef id) ast with
| None -> ast
@@ -382,7 +407,15 @@ let specialize_id_overloads instantiations id (Defs defs) =
valspecs are then re-specialized. This process is iterated until
the whole spec is specialized. *)
-let initial_calls = (IdSet.of_list [mk_id "main"; mk_id "__SetConfig"; mk_id "__ListConfig"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"; mk_id "append_64"])
+let initial_calls = IdSet.of_list
+ [ mk_id "main";
+ mk_id "__SetConfig";
+ mk_id "__ListConfig";
+ mk_id "execute";
+ mk_id "decode";
+ mk_id "initialize_registers";
+ mk_id "append_64" (* used to construct bitvector literals in C backend *)
+ ]
let remove_unused_valspecs ?(initial_calls=initial_calls) env ast =
let calls = ref initial_calls in
@@ -426,9 +459,9 @@ let slice_defs env (Defs defs) keep_ids =
let defs = List.filter keep defs in
remove_unused_valspecs env (Defs defs) ~initial_calls:keep_ids
-let specialize_id id ast =
- let instantiations = instantiations_of id ast in
- let ast = specialize_id_valspec instantiations id ast in
+let specialize_id spec id ast =
+ let instantiations = instantiations_of spec id ast in
+ let ast = specialize_id_valspec spec instantiations id ast in
let ast = specialize_id_fundef instantiations id ast in
specialize_id_overloads instantiations id ast
@@ -450,21 +483,21 @@ let reorder_typedefs (Defs defs) =
let others = filter_typedefs defs in
Defs (List.rev !tdefs @ others)
-let specialize_ids ids ast =
- let ast = List.fold_left (fun ast id -> specialize_id id ast) ast (IdSet.elements ids) in
+let specialize_ids spec ids ast =
+ let ast = List.fold_left (fun ast id -> specialize_id spec id ast) ast (IdSet.elements ids) in
let ast = reorder_typedefs ast in
let ast, _ = Type_error.check Type_check.initial_env ast in
let ast =
- List.fold_left (fun ast id -> rewrite_polymorphic_calls id ast) ast (IdSet.elements ids)
+ List.fold_left (fun ast id -> rewrite_polymorphic_calls spec id ast) ast (IdSet.elements ids)
in
let ast, env = Type_error.check Type_check.initial_env ast in
let ast = remove_unused_valspecs env ast in
ast, env
-let rec specialize ast env =
- let ids = polymorphic_functions (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) ast in
+let rec specialize spec ast env =
+ let ids = polymorphic_functions spec.is_polymorphic ast in
if IdSet.is_empty ids then
ast, env
else
- let ast, env = specialize_ids ids ast in
- specialize ast env
+ let ast, env = specialize_ids spec ids ast in
+ specialize spec ast env