diff options
| author | Alasdair Armstrong | 2019-02-19 17:02:19 +0000 |
|---|---|---|
| committer | Alasdair Armstrong | 2019-02-19 17:02:19 +0000 |
| commit | fc7d360e9442ab2e945e0d2da97faaf0eefec66f (patch) | |
| tree | a823d0c949dde68bdf117c836c3c2e28f9cf9088 /src/specialize.ml | |
| parent | 3c967f9075d890b8ba0e3fa1fb990a41a36ddd80 (diff) | |
Refactor specialization
specialize functions now take a 'specialization' parameter that
specifies how they will specialize the AST. typ_ord_specialization
gives the previous behaviour, whereas int_specialization allows
specializing on Int-kinded arguments. Note that this can loop forever
unless the appropriate case splits are inserted beforehand, presumably
by monomorphisation.
rename is_nat_kopt -> is_int_kopt for consistency
Diffstat (limited to 'src/specialize.ml')
| -rw-r--r-- | src/specialize.ml | 101 |
1 files changed, 67 insertions, 34 deletions
diff --git a/src/specialize.ml b/src/specialize.ml index 9f6af6d6..eaef1231 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -52,11 +52,26 @@ open Ast open Ast_util open Rewriter -let is_typ_ord_uvar = function +let is_typ_ord_arg = function | A_aux (A_typ _, _) -> true | A_aux (A_order _, _) -> true | _ -> false +type specialization = { + is_polymorphic : kinded_id -> bool; + instantiation_filter : kid -> typ_arg -> bool + } + +let typ_ord_specialization = { + is_polymorphic = (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt); + instantiation_filter = (fun _ -> is_typ_ord_arg) + } + +let int_specialization = { + is_polymorphic = is_int_kopt; + instantiation_filter = (fun _ arg -> match arg with A_aux (A_nexp _, _) -> true | _ -> false) + } + let rec nexp_simp_typ (Typ_aux (typ_aux, l)) = let typ_aux = match typ_aux with | Typ_id v -> Typ_id v @@ -81,34 +96,43 @@ and nexp_simp_typ_arg (A_aux (typ_arg_aux, l)) = (* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of. This part of the typechecker API is a bit ugly. *) -let fix_instantiation instantiation = - let instantiation = KBindings.bindings (KBindings.filter (fun _ arg -> is_typ_ord_uvar arg) instantiation) in +let fix_instantiation spec instantiation = + let instantiation = KBindings.bindings (KBindings.filter spec.instantiation_filter instantiation) in let instantiation = List.map (fun (kid, arg) -> Type_check.orig_kid kid, nexp_simp_typ_arg arg) instantiation in List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation +(* polymorphic_functions returns all functions that are polymorphic + for some set of kinded-identifiers, specified by the is_kopt + predicate. For example, polymorphic_functions is_int_kopt will + return all Int-polymorphic functions. *) let rec polymorphic_functions is_kopt (Defs defs) = match defs with | DEF_spec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ) , _), id, _, externs), _)) :: defs -> - let is_type_polymorphic = List.exists is_kopt (quant_kopts typq) in - if is_type_polymorphic then + let is_polymorphic = List.exists is_kopt (quant_kopts typq) in + if is_polymorphic then IdSet.add id (polymorphic_functions is_kopt (Defs defs)) else polymorphic_functions is_kopt (Defs defs) | _ :: defs -> polymorphic_functions is_kopt (Defs defs) | [] -> IdSet.empty +(* When we specialize a function, we need to generate new name. To do + this we take the instantiation that the new function is specialized + for and turn that into a string in such a way that alpha-equivalent + instantiations always get the same name. We then zencode that + string so it is a valid identifier name, and prepend it to the + previous function name. *) let string_of_instantiation instantiation = let open Type_check in let kid_names = ref KOptMap.empty in let kid_counter = ref 0 in let kid_name kid = try KOptMap.find kid !kid_names with - | Not_found -> begin - let n = string_of_int !kid_counter in - kid_names := KOptMap.add kid n !kid_names; - incr kid_counter; - n - end + | Not_found -> + let n = string_of_int !kid_counter in + kid_names := KOptMap.add kid n !kid_names; + incr kid_counter; + n in (* We need custom string_of functions to ensure that alpha-equivalent definitions get the same name *) @@ -121,7 +145,7 @@ let string_of_instantiation instantiation = | Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")" | Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")" | Nexp_minus (n1, n2) -> "(" ^ string_of_nexp n1 ^ " - " ^ string_of_nexp n2 ^ ")" - | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_nexp nexps ^ ")" + | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_nexp nexps ^ ")" | Nexp_exp n -> "2 ^ " ^ string_of_nexp n | Nexp_neg n -> "- " ^ string_of_nexp n in @@ -132,7 +156,7 @@ let string_of_instantiation instantiation = | Typ_id id -> string_of_id id | Typ_var kid -> kid_name (mk_kopt K_type kid) | Typ_tup typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")" - | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")" + | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")" | Typ_fn (arg_typs, ret_typ, eff) -> "(" ^ Util.string_of_list ", " string_of_typ arg_typs ^ ") -> " ^ string_of_typ ret_typ ^ " effect " ^ string_of_effect eff | Typ_bidir (t1, t2) -> @@ -161,10 +185,10 @@ let string_of_instantiation instantiation = | NC_aux (NC_true, _) -> "true" | NC_aux (NC_false, _) -> "false" | NC_aux (NC_var kid, _) -> kid_name (mk_kopt K_bool kid) - | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")" + | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")" in - let string_of_binding (kid, arg) = string_of_kid kid ^ " => " ^ string_of_typ_arg arg in + let string_of_binding (kid, arg) = string_of_kid kid ^ "=>" ^ string_of_typ_arg arg in Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) let id_of_instantiation id instantiation = @@ -181,12 +205,12 @@ let rec variant_generic_typ id (Defs defs) = (* Returns a list of all the instantiations of a function id in an ast. Also works with union constructors, and searches for them in patterns. *) -let rec instantiations_of id ast = +let rec instantiations_of spec id ast = let instantiations = ref [] in let inspect_exp = function | E_aux (E_app (id', _), _) as exp when Id.compare id id' = 0 -> - let instantiation = fix_instantiation (Type_check.instantiation_of exp) in + let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in instantiations := instantiation :: !instantiations; exp | exp -> exp @@ -204,7 +228,7 @@ let rec instantiations_of id ast = (variant_generic_typ variant_id ast) typ in - instantiations := fix_instantiation instantiation :: !instantiations; + instantiations := fix_instantiation spec instantiation :: !instantiations; pat | Typ_aux (Typ_id variant_id, _) -> pat | _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type") @@ -220,12 +244,12 @@ let rec instantiations_of id ast = !instantiations -let rec rewrite_polymorphic_calls id ast = +let rec rewrite_polymorphic_calls spec id ast = let vs_ids = Initial_check.val_spec_ids ast in let rewrite_e_aux = function | E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 -> - let instantiation = fix_instantiation (Type_check.instantiation_of exp) in + let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in let spec_id = id_of_instantiation id instantiation in (* Make sure we only generate specialized calls when we've specialized the valspec. The valspec may not be generated if @@ -280,7 +304,7 @@ and typ_arg_int_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) = | A_order ord -> KidSet.empty | A_bool _ -> KidSet.empty -let specialize_id_valspec instantiations id ast = +let specialize_id_valspec spec instantiations id ast = match split_defs (is_valspec id) ast with | None -> failwith ("Valspec " ^ string_of_id id ^ " does not exist!") | Some (pre_ast, vs, post_ast) -> @@ -304,7 +328,8 @@ let specialize_id_valspec instantiations id ast = (* Remove type variables from the type quantifier. *) let kopts, constraints = quant_split typq in - let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in + let constraints = List.map (fun c -> List.fold_left (fun c (v, a) -> constraint_subst v a c) c (KBindings.bindings instantiation)) constraints in + let kopts = List.filter (fun kopt -> not (spec.is_polymorphic kopt)) kopts in let typq = mk_typquant (List.map (mk_qi_id K_type) typ_frees @ List.map (mk_qi_id K_int) int_frees @ List.map mk_qi_kopt kopts @@ -344,7 +369,7 @@ let specialize_annotations instantiation = rewrite_exp = (fun _ -> fold_exp rw_exp); rewrite_pat = (fun _ -> fold_pat rw_pat) } - + let specialize_id_fundef instantiations id ast = match split_defs (is_fundef id) ast with | None -> ast @@ -382,7 +407,15 @@ let specialize_id_overloads instantiations id (Defs defs) = valspecs are then re-specialized. This process is iterated until the whole spec is specialized. *) -let initial_calls = (IdSet.of_list [mk_id "main"; mk_id "__SetConfig"; mk_id "__ListConfig"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"; mk_id "append_64"]) +let initial_calls = IdSet.of_list + [ mk_id "main"; + mk_id "__SetConfig"; + mk_id "__ListConfig"; + mk_id "execute"; + mk_id "decode"; + mk_id "initialize_registers"; + mk_id "append_64" (* used to construct bitvector literals in C backend *) + ] let remove_unused_valspecs ?(initial_calls=initial_calls) env ast = let calls = ref initial_calls in @@ -426,9 +459,9 @@ let slice_defs env (Defs defs) keep_ids = let defs = List.filter keep defs in remove_unused_valspecs env (Defs defs) ~initial_calls:keep_ids -let specialize_id id ast = - let instantiations = instantiations_of id ast in - let ast = specialize_id_valspec instantiations id ast in +let specialize_id spec id ast = + let instantiations = instantiations_of spec id ast in + let ast = specialize_id_valspec spec instantiations id ast in let ast = specialize_id_fundef instantiations id ast in specialize_id_overloads instantiations id ast @@ -450,21 +483,21 @@ let reorder_typedefs (Defs defs) = let others = filter_typedefs defs in Defs (List.rev !tdefs @ others) -let specialize_ids ids ast = - let ast = List.fold_left (fun ast id -> specialize_id id ast) ast (IdSet.elements ids) in +let specialize_ids spec ids ast = + let ast = List.fold_left (fun ast id -> specialize_id spec id ast) ast (IdSet.elements ids) in let ast = reorder_typedefs ast in let ast, _ = Type_error.check Type_check.initial_env ast in let ast = - List.fold_left (fun ast id -> rewrite_polymorphic_calls id ast) ast (IdSet.elements ids) + List.fold_left (fun ast id -> rewrite_polymorphic_calls spec id ast) ast (IdSet.elements ids) in let ast, env = Type_error.check Type_check.initial_env ast in let ast = remove_unused_valspecs env ast in ast, env -let rec specialize ast env = - let ids = polymorphic_functions (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) ast in +let rec specialize spec ast env = + let ids = polymorphic_functions spec.is_polymorphic ast in if IdSet.is_empty ids then ast, env else - let ast, env = specialize_ids ids ast in - specialize ast env + let ast, env = specialize_ids spec ids ast in + specialize spec ast env |
