summaryrefslogtreecommitdiff
path: root/src/specialize.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/specialize.ml')
-rw-r--r--src/specialize.ml47
1 files changed, 34 insertions, 13 deletions
diff --git a/src/specialize.ml b/src/specialize.ml
index 0a4e0fbb..1b978fbc 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -150,8 +150,16 @@ let id_of_instantiation id instantiation =
let str = string_of_instantiation instantiation in
prepend_id (str ^ "#") id
+let rec variant_generic_typ id (Defs defs) =
+ match defs with
+ | DEF_type (TD_aux (TD_variant (id', _, typq, _, _), _)) :: _ when Id.compare id id' = 0 ->
+ mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq)))
+ | _ :: defs -> variant_generic_typ id (Defs defs)
+ | [] -> failwith ("No variant with id " ^ string_of_id id)
+
(* Returns a list of all the instantiations of a function id in an
- ast. *)
+ ast. Also works with union constructors, and searches for them in
+ patterns. *)
let rec instantiations_of id ast =
let instantiations = ref [] in
@@ -163,8 +171,30 @@ let rec instantiations_of id ast =
| exp -> exp
in
- let rewrite_exp = { id_exp_alg with e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) } in
- let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast in
+ (* We need to to check patterns in case id is a union constructor
+ that is never called like a function. *)
+ let inspect_pat = function
+ | P_aux (P_app (id', _), annot) as pat when Id.compare id id' = 0 ->
+ begin match Type_check.typ_of_annot annot with
+ | Typ_aux (Typ_app (variant_id, _), _) as typ ->
+ let open Type_check in
+ let instantiation, _, _ = unify (fst annot) (env_of_annot annot)
+ (variant_generic_typ variant_id ast)
+ typ
+ in
+ instantiations := fix_instantiation instantiation :: !instantiations;
+ pat
+ | Typ_aux (Typ_id variant_id, _) -> pat
+ | _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type")
+ end
+ | pat -> pat
+ in
+
+ let rewrite_pat = { id_pat_alg with p_aux = (fun (pat, annot) -> inspect_pat (P_aux (pat, annot))) } in
+ let rewrite_exp = { id_exp_alg with pat_alg = rewrite_pat;
+ e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) } in
+ let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp);
+ rewrite_pat = (fun _ -> fold_pat rewrite_pat)} ast in
!instantiations
@@ -329,7 +359,6 @@ let remove_unused_valspecs env ast =
match defs with
| def :: defs when is_fundef id def -> remove_unused (Defs defs) id
| def :: defs when is_valspec id def ->
- prerr_endline ("Removing: " ^ string_of_id id);
remove_unused (Defs defs) id
| DEF_overload (overload_id, overloads) :: defs ->
begin
@@ -379,13 +408,6 @@ let specialize_ids ids ast =
(***** Specialising polymorphic variant types, e.g. option *****)
-let rec variant_generic_typ id (Defs defs) =
- match defs with
- | DEF_type (TD_aux (TD_variant (id', _, typq, _, _), _)) :: _ ->
- mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq)))
- | _ :: defs -> variant_generic_typ id (Defs defs)
- | [] -> failwith ("No variant with id " ^ string_of_id id)
-
let rewrite_polymorphic_constructors id ast =
let rewrite_e_aux = function
| E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 ->
@@ -461,6 +483,5 @@ let rec specialize ast env =
if IdSet.is_empty ids then
specialize_variants ast env
else
- (prerr_endline (Util.string_of_list ", " string_of_id (IdSet.elements ids));
let ast, env = specialize_ids ids ast in
- specialize ast env)
+ specialize ast env