summaryrefslogtreecommitdiff
path: root/src/specialize.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/specialize.ml')
-rw-r--r--src/specialize.ml222
1 files changed, 220 insertions, 2 deletions
diff --git a/src/specialize.ml b/src/specialize.ml
index 881881f4..2d32a90c 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -52,10 +52,40 @@ open Ast
open Ast_util
open Rewriter
+let zchar c =
+ let zc c = "z" ^ String.make 1 c in
+ if Char.code c <= 41 then zc (Char.chr (Char.code c + 16))
+ else if Char.code c <= 47 then zc (Char.chr (Char.code c + 23))
+ else if Char.code c <= 57 then String.make 1 c
+ else if Char.code c <= 64 then zc (Char.chr (Char.code c + 13))
+ else if Char.code c <= 90 then String.make 1 c
+ else if Char.code c <= 94 then zc (Char.chr (Char.code c - 13))
+ else if Char.code c <= 95 then "_"
+ else if Char.code c <= 96 then zc (Char.chr (Char.code c - 13))
+ else if Char.code c <= 121 then String.make 1 c
+ else if Char.code c <= 122 then "zz"
+ else if Char.code c <= 126 then zc (Char.chr (Char.code c - 39))
+ else raise (Invalid_argument "zchar")
+
+let zencode_string str = "z" ^ List.fold_left (fun s1 s2 -> s1 ^ s2) "" (List.map zchar (Util.string_to_list str))
+
+let zencode_upper_string str = "Z" ^ List.fold_left (fun s1 s2 -> s1 ^ s2) "" (List.map zchar (Util.string_to_list str))
+
+let is_typ_uvar = function
+ | Type_check.U_typ _ -> true
+ | _ -> false
+
+(* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of.
+ This part of the typechecker API is a bit ugly. *)
+let fix_instantiation instantiation =
+ let instantiation = KBindings.bindings (KBindings.filter (fun _ uvar -> is_typ_uvar uvar) instantiation) in
+ let instantiation = List.map (fun (kid, uvar) -> Type_check.orig_kid kid, uvar) instantiation in
+ List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation
+
(* Returns an IdSet with the function ids that have X-kinded
parameters, e.g. val f : forall ('a : X). 'a -> 'a. The first
argument specifies what X should be - it should be one of:
- is_nat_kopt, is_order_kopt, or is_type_kopt from Ast_util.
+ is_nat_kopt, is_order_kopt, or is_typ_kopt from Ast_util.
*)
let rec polymorphic_functions is_kopt (Defs defs) =
match defs with
@@ -68,6 +98,11 @@ let rec polymorphic_functions is_kopt (Defs defs) =
| _ :: defs -> polymorphic_functions is_kopt (Defs defs)
| [] -> IdSet.empty
+let id_of_instantiation id instantiation =
+ let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ Type_check.string_of_uvar uvar in
+ let str = zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) ^ "#" in
+ prepend_id str id
+
(* Returns a list of all the instantiations of a function id in an
ast. *)
let rec instantiations_of id ast =
@@ -75,7 +110,8 @@ let rec instantiations_of id ast =
let inspect_exp = function
| E_aux (E_app (id', _), _) as exp when Id.compare id id' = 0 ->
- instantiations := Type_check.instantiation_of exp :: !instantiations;
+ let instantiation = fix_instantiation (Type_check.instantiation_of exp) in
+ instantiations := instantiation :: !instantiations;
exp
| exp -> exp
in
@@ -84,3 +120,185 @@ let rec instantiations_of id ast =
let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast in
!instantiations
+
+let rec rewrite_polymorphic_calls id ast =
+ print_endline ("Rewriting: " ^ string_of_id id);
+ let vs_ids = Initial_check.val_spec_ids ast in
+
+ let rewrite_e_aux = function
+ | E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 ->
+ let instantiation = fix_instantiation (Type_check.instantiation_of exp) in
+ let spec_id = id_of_instantiation id instantiation in
+ (* Make sure we only generate specialized calls when we've
+ specialized the valspec. The valspec may not be generated if
+ a polymorphic function calls another polymorphic function.
+ In this case a specialization of the first may require that
+ the second needs to be specialized again, but this may not
+ have happened yet. *)
+ if IdSet.mem spec_id vs_ids then
+ E_aux (E_app (spec_id, args), annot)
+ else
+ exp
+ | exp -> exp
+ in
+
+ let rewrite_exp = { id_exp_alg with e_aux = (fun (exp, annot) -> rewrite_e_aux (E_aux (exp, annot))) } in
+ rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast
+
+let rec typ_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) =
+ match typ_aux with
+ | Typ_id v -> KidSet.empty
+ | Typ_var kid when KidSet.mem kid exs -> KidSet.empty
+ | Typ_var kid -> KidSet.singleton kid
+ | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_frees ~exs:exs) typs)
+ | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_frees ~exs:exs) args)
+ | Typ_exist (kids, nc, typ) -> typ_frees ~exs:(KidSet.of_list kids) typ
+ | Typ_fn (typ1, typ2, _) -> KidSet.union (typ_frees ~exs:exs typ1) (typ_frees ~exs:exs typ2)
+and typ_arg_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) =
+ match typ_arg_aux with
+ | Typ_arg_nexp n -> KidSet.empty
+ | Typ_arg_typ typ -> typ_frees ~exs:exs typ
+ | Typ_arg_order ord -> KidSet.empty
+
+let specialize_id_valspec instantiations id ast =
+ match split_defs (is_valspec id) ast with
+ | None -> failwith ("Valspec " ^ string_of_id id ^ " does not exist!")
+ | Some (pre_ast, vs, post_ast) ->
+ let typschm, externs, is_cast, annot = match vs with
+ | DEF_spec (VS_aux (VS_val_spec (typschm, _, externs, is_cast), annot)) -> typschm, externs, is_cast, annot
+ | _ -> assert false (* unreachable *)
+ in
+ let TypSchm_aux (TypSchm_ts (typq, typ), _) = typschm in
+
+ (* Keep track of the specialized ids to avoid generating things twice. *)
+ let spec_ids = ref IdSet.empty in
+
+ let specialize_instance instantiation =
+ (* Replace the polymorphic type variables in the type with their concrete instantiation. *)
+ let typ = Type_check.subst_unifiers instantiation typ in
+ let frees = KidSet.elements (typ_frees typ) in
+
+ (* Remove type variables from the type quantifier. *)
+ let kopts, constraints = quant_split typq in
+ let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt)) kopts in
+ let typq = mk_typquant (List.map (mk_qi_id BK_type) frees @ List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints) in
+ let typschm = mk_typschm typq typ in
+
+ let spec_id = id_of_instantiation id instantiation in
+ if IdSet.mem spec_id !spec_ids then [] else
+ begin
+ spec_ids := IdSet.add spec_id !spec_ids;
+ print_endline (string_of_id spec_id ^ " : " ^ string_of_typschm typschm);
+ [DEF_spec (VS_aux (VS_val_spec (typschm, spec_id, externs, is_cast), annot))]
+ end
+ in
+
+ let specializations = List.map specialize_instance instantiations |> List.concat in
+
+ append_ast pre_ast (append_ast (Defs (vs :: specializations)) post_ast)
+
+let specialize_id_fundef instantiations id ast =
+ match split_defs (is_fundef id) ast with
+ | None -> ast
+ | Some (pre_ast, DEF_fundef fundef, post_ast) ->
+ let fundefs =
+ List.map (fun i -> DEF_fundef (rename_fundef (id_of_instantiation id i) fundef)) instantiations
+ in
+ append_ast pre_ast (append_ast (Defs fundefs) post_ast)
+ | Some _ -> assert false (* unreachable *)
+
+let specialize_id_overloads instantiations id (Defs defs) =
+ let ids = IdSet.of_list (List.map (id_of_instantiation id) instantiations) in
+
+ let rec rewrite_overloads defs =
+ match defs with
+ | DEF_overload (overload_id, overloads) :: defs ->
+ let overloads = List.concat (List.map (fun id' -> if Id.compare id' id = 0 then IdSet.elements ids else [id']) overloads) in
+ DEF_overload (overload_id, overloads) :: rewrite_overloads defs
+ | def :: defs -> def :: rewrite_overloads defs
+ | [] -> []
+ in
+
+ Defs (rewrite_overloads defs)
+
+(* Once we've specialized a definition, it's original valspec should
+ be unused, unless another polymorphic function called it. We
+ therefore remove all unused valspecs. Remaining polymorphic
+ valspecs are then re-specialized. This process is iterated until
+ the whole spec is specialized. *)
+let remove_unused_valspecs ast =
+ let calls = ref (IdSet.singleton (mk_id "main")) in
+ let vs_ids = Initial_check.val_spec_ids ast in
+
+ let inspect_exp = function
+ | E_aux (E_app (call, _), _) as exp ->
+ calls := IdSet.add call !calls;
+ exp
+ | exp -> exp
+ in
+
+ let rewrite_exp = { id_exp_alg with e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) } in
+ let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast in
+
+ let unused = IdSet.filter (fun vs_id -> not (IdSet.mem vs_id !calls)) vs_ids in
+
+ List.iter (fun id -> print_endline (string_of_id id)) (IdSet.elements unused);
+
+ let rec remove_unused (Defs defs) id =
+ match defs with
+ | def :: defs when is_fundef id def -> remove_unused (Defs defs) id
+ | def :: defs when is_valspec id def -> remove_unused (Defs defs) id
+ | DEF_overload (overload_id, overloads) :: defs ->
+ begin
+ match List.filter (fun id' -> Id.compare id id' <> 0) overloads with
+ | [] -> remove_unused (Defs defs) id
+ | overloads -> DEF_overload (overload_id, overloads) :: remove_unused (Defs defs) id
+ end
+ | def :: defs -> def :: remove_unused (Defs defs) id
+ | [] -> []
+ in
+
+ List.fold_left (fun ast id -> Defs (remove_unused ast id)) ast (IdSet.elements unused)
+
+let specialize_id id ast =
+ print_endline ("Specializing: " ^ string_of_id id);
+ let instantiations = instantiations_of id ast in
+
+ let ast = specialize_id_valspec instantiations id ast in
+ let ast = specialize_id_fundef instantiations id ast in
+ specialize_id_overloads instantiations id ast
+
+(* When we generate specialized versions of functions, we need to
+ ensure that the types they are specialized to appear before the
+ function definitions in the AST. Therefore we pull all the type
+ definitions (and default definitions) to the start of the AST. *)
+let reorder_typedefs (Defs defs) =
+ let tdefs = ref [] in
+
+ let rec filter_typedefs = function
+ | (DEF_default _ | DEF_type _) as tdef :: defs ->
+ tdefs := tdef :: !tdefs;
+ filter_typedefs defs
+ | def :: defs -> def :: filter_typedefs defs
+ | [] -> []
+ in
+
+ let others = filter_typedefs defs in
+ Defs (List.rev !tdefs @ others)
+
+let specialize_ids ids ast =
+ let ast = List.fold_left (fun ast id -> specialize_id id ast) ast (IdSet.elements ids) in
+ let ast = reorder_typedefs ast in
+ let ast, _ = Type_check.check Type_check.initial_env ast in
+ let ast = List.fold_left (fun ast id -> rewrite_polymorphic_calls id ast) ast (IdSet.elements ids) in
+ let ast, env = Type_check.check Type_check.initial_env ast in
+ let ast = remove_unused_valspecs ast in
+ ast, env
+
+let rec specialize ast env =
+ let ids = polymorphic_functions is_typ_kopt ast in
+ if IdSet.is_empty ids then
+ ast, env
+ else
+ let ast, env = specialize_ids ids ast in
+ specialize ast env