diff options
Diffstat (limited to 'src/specialize.ml')
| -rw-r--r-- | src/specialize.ml | 47 |
1 files changed, 34 insertions, 13 deletions
diff --git a/src/specialize.ml b/src/specialize.ml index 0a4e0fbb..1b978fbc 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -150,8 +150,16 @@ let id_of_instantiation id instantiation = let str = string_of_instantiation instantiation in prepend_id (str ^ "#") id +let rec variant_generic_typ id (Defs defs) = + match defs with + | DEF_type (TD_aux (TD_variant (id', _, typq, _, _), _)) :: _ when Id.compare id id' = 0 -> + mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq))) + | _ :: defs -> variant_generic_typ id (Defs defs) + | [] -> failwith ("No variant with id " ^ string_of_id id) + (* Returns a list of all the instantiations of a function id in an - ast. *) + ast. Also works with union constructors, and searches for them in + patterns. *) let rec instantiations_of id ast = let instantiations = ref [] in @@ -163,8 +171,30 @@ let rec instantiations_of id ast = | exp -> exp in - let rewrite_exp = { id_exp_alg with e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) } in - let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast in + (* We need to to check patterns in case id is a union constructor + that is never called like a function. *) + let inspect_pat = function + | P_aux (P_app (id', _), annot) as pat when Id.compare id id' = 0 -> + begin match Type_check.typ_of_annot annot with + | Typ_aux (Typ_app (variant_id, _), _) as typ -> + let open Type_check in + let instantiation, _, _ = unify (fst annot) (env_of_annot annot) + (variant_generic_typ variant_id ast) + typ + in + instantiations := fix_instantiation instantiation :: !instantiations; + pat + | Typ_aux (Typ_id variant_id, _) -> pat + | _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type") + end + | pat -> pat + in + + let rewrite_pat = { id_pat_alg with p_aux = (fun (pat, annot) -> inspect_pat (P_aux (pat, annot))) } in + let rewrite_exp = { id_exp_alg with pat_alg = rewrite_pat; + e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) } in + let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp); + rewrite_pat = (fun _ -> fold_pat rewrite_pat)} ast in !instantiations @@ -329,7 +359,6 @@ let remove_unused_valspecs env ast = match defs with | def :: defs when is_fundef id def -> remove_unused (Defs defs) id | def :: defs when is_valspec id def -> - prerr_endline ("Removing: " ^ string_of_id id); remove_unused (Defs defs) id | DEF_overload (overload_id, overloads) :: defs -> begin @@ -379,13 +408,6 @@ let specialize_ids ids ast = (***** Specialising polymorphic variant types, e.g. option *****) -let rec variant_generic_typ id (Defs defs) = - match defs with - | DEF_type (TD_aux (TD_variant (id', _, typq, _, _), _)) :: _ -> - mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq))) - | _ :: defs -> variant_generic_typ id (Defs defs) - | [] -> failwith ("No variant with id " ^ string_of_id id) - let rewrite_polymorphic_constructors id ast = let rewrite_e_aux = function | E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 -> @@ -461,6 +483,5 @@ let rec specialize ast env = if IdSet.is_empty ids then specialize_variants ast env else - (prerr_endline (Util.string_of_list ", " string_of_id (IdSet.elements ids)); let ast, env = specialize_ids ids ast in - specialize ast env) + specialize ast env |
