summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/specialize.ml112
1 files changed, 104 insertions, 8 deletions
diff --git a/src/specialize.ml b/src/specialize.ml
index 2ebc7307..4f8a7e7e 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -75,15 +75,81 @@ let rec polymorphic_functions is_kopt (Defs defs) =
| _ :: defs -> polymorphic_functions is_kopt (Defs defs)
| [] -> IdSet.empty
-let id_of_instantiation id instantiation =
- let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ Type_check.string_of_uvar uvar in
- let str = Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) ^ "#" in
- prepend_id str id
-
let string_of_instantiation instantiation =
- let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ Type_check.string_of_uvar uvar in
+ let open Type_check in
+ let kid_names = ref KBindings.empty in
+ let kid_counter = ref 0 in
+ let kid_name kid =
+ try KBindings.find kid !kid_names with
+ | Not_found -> begin
+ let n = string_of_int !kid_counter in
+ kid_names := KBindings.add kid n !kid_names;
+ incr kid_counter;
+ n
+ end
+ in
+
+ (* We need custom string_of functions to ensure that alpha-equivalent definitions get the same name *)
+ let rec string_of_nexp = function
+ | Nexp_aux (nexp, _) -> string_of_nexp_aux nexp
+ and string_of_nexp_aux = function
+ | Nexp_id id -> string_of_id id
+ | Nexp_var kid -> kid_name kid
+ | Nexp_constant c -> Big_int.to_string c
+ | Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")"
+ | Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")"
+ | Nexp_minus (n1, n2) -> "(" ^ string_of_nexp n1 ^ " - " ^ string_of_nexp n2 ^ ")"
+ | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_nexp nexps ^ ")"
+ | Nexp_exp n -> "2 ^ " ^ string_of_nexp n
+ | Nexp_neg n -> "- " ^ string_of_nexp n
+ in
+
+ let rec string_of_typ = function
+ | Typ_aux (typ, l) -> string_of_typ_aux typ
+ and string_of_typ_aux = function
+ | Typ_id id -> string_of_id id
+ | Typ_var kid -> kid_name kid
+ | Typ_tup typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")"
+ | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")"
+ | Typ_fn (typ_arg, typ_ret, eff) ->
+ string_of_typ typ_arg ^ " -> " ^ string_of_typ typ_ret ^ " effect " ^ string_of_effect eff
+ | Typ_exist (kids, nc, typ) ->
+ "exist " ^ Util.string_of_list " " kid_name kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ
+ and string_of_typ_arg = function
+ | Typ_arg_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg
+ and string_of_typ_arg_aux = function
+ | Typ_arg_nexp n -> string_of_nexp n
+ | Typ_arg_typ typ -> string_of_typ typ
+ | Typ_arg_order o -> string_of_order o
+ and string_of_n_constraint = function
+ | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2
+ | NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2
+ | NC_aux (NC_bounded_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2
+ | NC_aux (NC_bounded_le (n1, n2), _) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2
+ | NC_aux (NC_or (nc1, nc2), _) ->
+ "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")"
+ | NC_aux (NC_and (nc1, nc2), _) ->
+ "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")"
+ | NC_aux (NC_set (kid, ns), _) ->
+ kid_name kid ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}"
+ | NC_aux (NC_true, _) -> "true"
+ | NC_aux (NC_false, _) -> "false"
+ in
+
+ let string_of_uvar = function
+ | U_nexp n -> string_of_nexp n
+ | U_order o -> string_of_order o
+ | U_effect eff -> string_of_effect eff
+ | U_typ typ -> string_of_typ typ
+ in
+
+ let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ string_of_uvar uvar in
Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation))
+let id_of_instantiation id instantiation =
+ let str = string_of_instantiation instantiation in
+ prepend_id (str ^ "#") id
+
(* Returns a list of all the instantiations of a function id in an
ast. *)
let rec instantiations_of id ast =
@@ -140,6 +206,29 @@ and typ_arg_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) =
| Typ_arg_typ typ -> typ_frees ~exs:exs typ
| Typ_arg_order ord -> KidSet.empty
+let rec typ_int_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) =
+ match typ_aux with
+ | Typ_id v -> KidSet.empty
+ | Typ_var kid -> KidSet.empty
+ | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_int_frees ~exs:exs) typs)
+ | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_int_frees ~exs:exs) args)
+ | Typ_exist (kids, nc, typ) -> typ_int_frees ~exs:(KidSet.of_list kids) typ
+ | Typ_fn (typ1, typ2, _) -> KidSet.union (typ_int_frees ~exs:exs typ1) (typ_int_frees ~exs:exs typ2)
+and typ_arg_int_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) =
+ match typ_arg_aux with
+ | Typ_arg_nexp n -> KidSet.diff (tyvars_of_nexp n) exs
+ | Typ_arg_typ typ -> KidSet.empty
+ | Typ_arg_order ord -> KidSet.empty
+
+let uvar_int_frees = function
+ | Type_check.U_nexp n -> tyvars_of_nexp n
+ | Type_check.U_typ typ -> typ_int_frees typ
+ | _ -> KidSet.empty
+
+let uvar_typ_frees = function
+ | Type_check.U_typ typ -> typ_frees typ
+ | _ -> KidSet.empty
+
let specialize_id_valspec instantiations id ast =
match split_defs (is_valspec id) ast with
| None -> failwith ("Valspec " ^ string_of_id id ^ " does not exist!")
@@ -156,12 +245,19 @@ let specialize_id_valspec instantiations id ast =
let specialize_instance instantiation =
(* Replace the polymorphic type variables in the type with their concrete instantiation. *)
let typ = Type_check.subst_unifiers instantiation typ in
- let frees = KidSet.elements (typ_frees typ) in
+
+ (* Collect any new type variables introduced by the instantiation *)
+ let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in
+ let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map uvar_typ_frees |> collect_kids in
+ let int_frees = KBindings.bindings instantiation |> List.map snd |> List.map uvar_int_frees |> collect_kids in
(* Remove type variables from the type quantifier. *)
let kopts, constraints = quant_split typq in
let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in
- let typq = mk_typquant (List.map (mk_qi_id BK_type) frees @ List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints) in
+ let typq = mk_typquant (List.map (mk_qi_id BK_type) typ_frees
+ @ List.map (mk_qi_id BK_nat) int_frees
+ @ List.map mk_qi_kopt kopts
+ @ List.map mk_qi_nc constraints) in
let typschm = mk_typschm typq typ in
let spec_id = id_of_instantiation id instantiation in