diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/specialize.ml | 112 |
1 files changed, 104 insertions, 8 deletions
diff --git a/src/specialize.ml b/src/specialize.ml index 2ebc7307..4f8a7e7e 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -75,15 +75,81 @@ let rec polymorphic_functions is_kopt (Defs defs) = | _ :: defs -> polymorphic_functions is_kopt (Defs defs) | [] -> IdSet.empty -let id_of_instantiation id instantiation = - let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ Type_check.string_of_uvar uvar in - let str = Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) ^ "#" in - prepend_id str id - let string_of_instantiation instantiation = - let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ Type_check.string_of_uvar uvar in + let open Type_check in + let kid_names = ref KBindings.empty in + let kid_counter = ref 0 in + let kid_name kid = + try KBindings.find kid !kid_names with + | Not_found -> begin + let n = string_of_int !kid_counter in + kid_names := KBindings.add kid n !kid_names; + incr kid_counter; + n + end + in + + (* We need custom string_of functions to ensure that alpha-equivalent definitions get the same name *) + let rec string_of_nexp = function + | Nexp_aux (nexp, _) -> string_of_nexp_aux nexp + and string_of_nexp_aux = function + | Nexp_id id -> string_of_id id + | Nexp_var kid -> kid_name kid + | Nexp_constant c -> Big_int.to_string c + | Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")" + | Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")" + | Nexp_minus (n1, n2) -> "(" ^ string_of_nexp n1 ^ " - " ^ string_of_nexp n2 ^ ")" + | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_nexp nexps ^ ")" + | Nexp_exp n -> "2 ^ " ^ string_of_nexp n + | Nexp_neg n -> "- " ^ string_of_nexp n + in + + let rec string_of_typ = function + | Typ_aux (typ, l) -> string_of_typ_aux typ + and string_of_typ_aux = function + | Typ_id id -> string_of_id id + | Typ_var kid -> kid_name kid + | Typ_tup typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")" + | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")" + | Typ_fn (typ_arg, typ_ret, eff) -> + string_of_typ typ_arg ^ " -> " ^ string_of_typ typ_ret ^ " effect " ^ string_of_effect eff + | Typ_exist (kids, nc, typ) -> + "exist " ^ Util.string_of_list " " kid_name kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ + and string_of_typ_arg = function + | Typ_arg_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg + and string_of_typ_arg_aux = function + | Typ_arg_nexp n -> string_of_nexp n + | Typ_arg_typ typ -> string_of_typ typ + | Typ_arg_order o -> string_of_order o + and string_of_n_constraint = function + | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2 + | NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2 + | NC_aux (NC_bounded_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2 + | NC_aux (NC_bounded_le (n1, n2), _) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2 + | NC_aux (NC_or (nc1, nc2), _) -> + "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")" + | NC_aux (NC_and (nc1, nc2), _) -> + "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" + | NC_aux (NC_set (kid, ns), _) -> + kid_name kid ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}" + | NC_aux (NC_true, _) -> "true" + | NC_aux (NC_false, _) -> "false" + in + + let string_of_uvar = function + | U_nexp n -> string_of_nexp n + | U_order o -> string_of_order o + | U_effect eff -> string_of_effect eff + | U_typ typ -> string_of_typ typ + in + + let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ string_of_uvar uvar in Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) +let id_of_instantiation id instantiation = + let str = string_of_instantiation instantiation in + prepend_id (str ^ "#") id + (* Returns a list of all the instantiations of a function id in an ast. *) let rec instantiations_of id ast = @@ -140,6 +206,29 @@ and typ_arg_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) = | Typ_arg_typ typ -> typ_frees ~exs:exs typ | Typ_arg_order ord -> KidSet.empty +let rec typ_int_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) = + match typ_aux with + | Typ_id v -> KidSet.empty + | Typ_var kid -> KidSet.empty + | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_int_frees ~exs:exs) typs) + | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_int_frees ~exs:exs) args) + | Typ_exist (kids, nc, typ) -> typ_int_frees ~exs:(KidSet.of_list kids) typ + | Typ_fn (typ1, typ2, _) -> KidSet.union (typ_int_frees ~exs:exs typ1) (typ_int_frees ~exs:exs typ2) +and typ_arg_int_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) = + match typ_arg_aux with + | Typ_arg_nexp n -> KidSet.diff (tyvars_of_nexp n) exs + | Typ_arg_typ typ -> KidSet.empty + | Typ_arg_order ord -> KidSet.empty + +let uvar_int_frees = function + | Type_check.U_nexp n -> tyvars_of_nexp n + | Type_check.U_typ typ -> typ_int_frees typ + | _ -> KidSet.empty + +let uvar_typ_frees = function + | Type_check.U_typ typ -> typ_frees typ + | _ -> KidSet.empty + let specialize_id_valspec instantiations id ast = match split_defs (is_valspec id) ast with | None -> failwith ("Valspec " ^ string_of_id id ^ " does not exist!") @@ -156,12 +245,19 @@ let specialize_id_valspec instantiations id ast = let specialize_instance instantiation = (* Replace the polymorphic type variables in the type with their concrete instantiation. *) let typ = Type_check.subst_unifiers instantiation typ in - let frees = KidSet.elements (typ_frees typ) in + + (* Collect any new type variables introduced by the instantiation *) + let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in + let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map uvar_typ_frees |> collect_kids in + let int_frees = KBindings.bindings instantiation |> List.map snd |> List.map uvar_int_frees |> collect_kids in (* Remove type variables from the type quantifier. *) let kopts, constraints = quant_split typq in let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in - let typq = mk_typquant (List.map (mk_qi_id BK_type) frees @ List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints) in + let typq = mk_typquant (List.map (mk_qi_id BK_type) typ_frees + @ List.map (mk_qi_id BK_nat) int_frees + @ List.map mk_qi_kopt kopts + @ List.map mk_qi_nc constraints) in let typschm = mk_typschm typq typ in let spec_id = id_of_instantiation id instantiation in |
