summaryrefslogtreecommitdiff
path: root/src/specialize.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/specialize.ml')
-rw-r--r--src/specialize.ml79
1 files changed, 31 insertions, 48 deletions
diff --git a/src/specialize.ml b/src/specialize.ml
index 6e625176..0f5b939c 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -54,8 +54,8 @@ open Rewriter
open Extra_pervasives
let is_typ_ord_uvar = function
- | Type_check.U_typ _ -> true
- | Type_check.U_order _ -> true
+ | A_aux (A_nexp _, _) -> true
+ | A_aux (A_order _, _) -> true
| _ -> false
let rec nexp_simp_typ (Typ_aux (typ_aux, l)) =
@@ -71,24 +71,20 @@ let rec nexp_simp_typ (Typ_aux (typ_aux, l)) =
| Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown"
in
Typ_aux (typ_aux, l)
-and nexp_simp_typ_arg (Typ_arg_aux (typ_arg_aux, l)) =
+and nexp_simp_typ_arg (A_aux (typ_arg_aux, l)) =
let typ_arg_aux = match typ_arg_aux with
- | Typ_arg_nexp n -> Typ_arg_nexp (nexp_simp n)
- | Typ_arg_typ typ -> Typ_arg_typ (nexp_simp_typ typ)
- | Typ_arg_order ord -> Typ_arg_order ord
+ | A_nexp n -> A_nexp (nexp_simp n)
+ | A_typ typ -> A_typ (nexp_simp_typ typ)
+ | A_order ord -> A_order ord
+ | A_bool nc -> A_bool (constraint_simp nc)
in
- Typ_arg_aux (typ_arg_aux, l)
-
-let nexp_simp_uvar = function
- | Type_check.U_nexp nexp -> (prerr_endline ("Simp nexp " ^ string_of_nexp nexp); Type_check.U_nexp (nexp_simp nexp))
- | Type_check.U_typ typ -> Type_check.U_typ (nexp_simp_typ typ)
- | uvar -> uvar
+ A_aux (typ_arg_aux, l)
(* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of.
This part of the typechecker API is a bit ugly. *)
let fix_instantiation instantiation =
- let instantiation = KBindings.bindings (KBindings.filter (fun _ uvar -> is_typ_ord_uvar uvar) instantiation) in
- let instantiation = List.map (fun (kid, uvar) -> Type_check.orig_kid kid, nexp_simp_uvar uvar) instantiation in
+ let instantiation = KBindings.bindings (KBindings.filter (fun _ arg -> is_typ_ord_uvar arg) instantiation) in
+ let instantiation = List.map (fun (kid, arg) -> Type_check.orig_kid kid, nexp_simp_typ_arg arg) instantiation in
List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation
let rec polymorphic_functions is_kopt (Defs defs) =
@@ -146,11 +142,12 @@ let string_of_instantiation instantiation =
"exist " ^ Util.string_of_list " " kid_name kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ
| Typ_internal_unknown -> "UNKNOWN"
and string_of_typ_arg = function
- | Typ_arg_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg
+ | A_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg
and string_of_typ_arg_aux = function
- | Typ_arg_nexp n -> string_of_nexp n
- | Typ_arg_typ typ -> string_of_typ typ
- | Typ_arg_order o -> string_of_order o
+ | A_nexp n -> string_of_nexp n
+ | A_typ typ -> string_of_typ typ
+ | A_order o -> string_of_order o
+ | A_bool nc -> string_of_n_constraint nc
and string_of_n_constraint = function
| NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2
| NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2
@@ -166,13 +163,7 @@ let string_of_instantiation instantiation =
| NC_aux (NC_false, _) -> "false"
in
- let string_of_uvar = function
- | U_nexp n -> string_of_nexp n
- | U_order o -> string_of_order o
- | U_typ typ -> string_of_typ typ
- in
-
- let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ string_of_uvar uvar in
+ let string_of_binding (kid, arg) = string_of_kid kid ^ " => " ^ string_of_typ_arg arg in
Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation))
let id_of_instantiation id instantiation =
@@ -182,7 +173,7 @@ let id_of_instantiation id instantiation =
let rec variant_generic_typ id (Defs defs) =
match defs with
| DEF_type (TD_aux (TD_variant (id', _, typq, _, _), _)) :: _ when Id.compare id id' = 0 ->
- mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq)))
+ mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq)))
| _ :: defs -> variant_generic_typ id (Defs defs)
| [] -> failwith ("No variant with id " ^ string_of_id id)
@@ -207,9 +198,10 @@ let rec instantiations_of id ast =
begin match Type_check.typ_of_annot annot with
| Typ_aux (Typ_app (variant_id, _), _) as typ ->
let open Type_check in
- let instantiation, _, _ = unify (fst annot) (env_of_annot annot)
- (variant_generic_typ variant_id ast)
- typ
+ let instantiation = unify (fst annot) (env_of_annot annot)
+ (tyvars_of_typ (variant_generic_typ variant_id ast))
+ (variant_generic_typ variant_id ast)
+ typ
in
instantiations := fix_instantiation instantiation :: !instantiations;
pat
@@ -262,11 +254,11 @@ let rec typ_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) =
List.fold_left KidSet.union (typ_frees ~exs:exs ret_typ) (List.map (typ_frees ~exs:exs) arg_typs)
| Typ_bidir (t1, t2) -> KidSet.union (typ_frees ~exs:exs t1) (typ_frees ~exs:exs t2)
| Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown"
-and typ_arg_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) =
+and typ_arg_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) =
match typ_arg_aux with
- | Typ_arg_nexp n -> KidSet.empty
- | Typ_arg_typ typ -> typ_frees ~exs:exs typ
- | Typ_arg_order ord -> KidSet.empty
+ | A_nexp n -> KidSet.empty
+ | A_typ typ -> typ_frees ~exs:exs typ
+ | A_order ord -> KidSet.empty
let rec typ_int_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) =
match typ_aux with
@@ -279,20 +271,11 @@ let rec typ_int_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) =
List.fold_left KidSet.union (typ_int_frees ~exs:exs ret_typ) (List.map (typ_int_frees ~exs:exs) arg_typs)
| Typ_bidir (t1, t2) -> KidSet.union (typ_int_frees ~exs:exs t1) (typ_int_frees ~exs:exs t2)
| Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown"
-and typ_arg_int_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) =
+and typ_arg_int_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) =
match typ_arg_aux with
- | Typ_arg_nexp n -> KidSet.diff (tyvars_of_nexp n) exs
- | Typ_arg_typ typ -> KidSet.empty
- | Typ_arg_order ord -> KidSet.empty
-
-let uvar_int_frees = function
- | Type_check.U_nexp n -> tyvars_of_nexp n
- | Type_check.U_typ typ -> typ_int_frees typ
- | _ -> KidSet.empty
-
-let uvar_typ_frees = function
- | Type_check.U_typ typ -> typ_frees typ
- | _ -> KidSet.empty
+ | A_nexp n -> KidSet.diff (tyvars_of_nexp n) exs
+ | A_typ typ -> typ_int_frees ~exs:exs typ
+ | A_order ord -> KidSet.empty
let specialize_id_valspec instantiations id ast =
match split_defs (is_valspec id) ast with
@@ -313,8 +296,8 @@ let specialize_id_valspec instantiations id ast =
(* Collect any new type variables introduced by the instantiation *)
let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in
- let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map uvar_typ_frees |> collect_kids in
- let int_frees = KBindings.bindings instantiation |> List.map snd |> List.map uvar_int_frees |> collect_kids in
+ let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_frees |> collect_kids in
+ let int_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_int_frees |> collect_kids in
(* Remove type variables from the type quantifier. *)
let kopts, constraints = quant_split typq in