diff options
| author | Jon French | 2019-02-25 12:10:30 +0000 |
|---|---|---|
| committer | Jon French | 2019-02-25 12:10:30 +0000 |
| commit | 915d75f9c49fa2c2a9d47d189e4224cee16582c9 (patch) | |
| tree | 77a93e682796977898af0b56e0a61d7689db112e /src/specialize.ml | |
| parent | a8a5308e4981b3d09fb2bf0c59d592ef6ae4417e (diff) | |
| parent | 38656b50ad24df6a29f3a84e50adfcf409131fb0 (diff) | |
Merge branch 'sail2' into rmem_interpreter
Diffstat (limited to 'src/specialize.ml')
| -rw-r--r-- | src/specialize.ml | 200 |
1 files changed, 152 insertions, 48 deletions
diff --git a/src/specialize.ml b/src/specialize.ml index 00357557..591a415a 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -52,11 +52,26 @@ open Ast open Ast_util open Rewriter -let is_typ_ord_uvar = function +let is_typ_ord_arg = function | A_aux (A_typ _, _) -> true | A_aux (A_order _, _) -> true | _ -> false +type specialization = { + is_polymorphic : kinded_id -> bool; + instantiation_filter : kid -> typ_arg -> bool + } + +let typ_ord_specialization = { + is_polymorphic = (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt); + instantiation_filter = (fun _ -> is_typ_ord_arg) + } + +let int_specialization = { + is_polymorphic = is_int_kopt; + instantiation_filter = (fun _ arg -> match arg with A_aux (A_nexp _, _) -> true | _ -> false) + } + let rec nexp_simp_typ (Typ_aux (typ_aux, l)) = let typ_aux = match typ_aux with | Typ_id v -> Typ_id v @@ -81,34 +96,43 @@ and nexp_simp_typ_arg (A_aux (typ_arg_aux, l)) = (* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of. This part of the typechecker API is a bit ugly. *) -let fix_instantiation instantiation = - let instantiation = KBindings.bindings (KBindings.filter (fun _ arg -> is_typ_ord_uvar arg) instantiation) in +let fix_instantiation spec instantiation = + let instantiation = KBindings.bindings (KBindings.filter spec.instantiation_filter instantiation) in let instantiation = List.map (fun (kid, arg) -> Type_check.orig_kid kid, nexp_simp_typ_arg arg) instantiation in List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation +(* polymorphic_functions returns all functions that are polymorphic + for some set of kinded-identifiers, specified by the is_kopt + predicate. For example, polymorphic_functions is_int_kopt will + return all Int-polymorphic functions. *) let rec polymorphic_functions is_kopt (Defs defs) = match defs with | DEF_spec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ) , _), id, _, externs), _)) :: defs -> - let is_type_polymorphic = List.exists is_kopt (quant_kopts typq) in - if is_type_polymorphic then + let is_polymorphic = List.exists is_kopt (quant_kopts typq) in + if is_polymorphic then IdSet.add id (polymorphic_functions is_kopt (Defs defs)) else polymorphic_functions is_kopt (Defs defs) | _ :: defs -> polymorphic_functions is_kopt (Defs defs) | [] -> IdSet.empty +(* When we specialize a function, we need to generate new name. To do + this we take the instantiation that the new function is specialized + for and turn that into a string in such a way that alpha-equivalent + instantiations always get the same name. We then zencode that + string so it is a valid identifier name, and prepend it to the + previous function name. *) let string_of_instantiation instantiation = let open Type_check in let kid_names = ref KOptMap.empty in let kid_counter = ref 0 in let kid_name kid = try KOptMap.find kid !kid_names with - | Not_found -> begin - let n = string_of_int !kid_counter in - kid_names := KOptMap.add kid n !kid_names; - incr kid_counter; - n - end + | Not_found -> + let n = string_of_int !kid_counter in + kid_names := KOptMap.add kid n !kid_names; + incr kid_counter; + n in (* We need custom string_of functions to ensure that alpha-equivalent definitions get the same name *) @@ -121,7 +145,7 @@ let string_of_instantiation instantiation = | Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")" | Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")" | Nexp_minus (n1, n2) -> "(" ^ string_of_nexp n1 ^ " - " ^ string_of_nexp n2 ^ ")" - | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_nexp nexps ^ ")" + | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_nexp nexps ^ ")" | Nexp_exp n -> "2 ^ " ^ string_of_nexp n | Nexp_neg n -> "- " ^ string_of_nexp n in @@ -132,7 +156,7 @@ let string_of_instantiation instantiation = | Typ_id id -> string_of_id id | Typ_var kid -> kid_name (mk_kopt K_type kid) | Typ_tup typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")" - | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")" + | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")" | Typ_fn (arg_typs, ret_typ, eff) -> "(" ^ Util.string_of_list ", " string_of_typ arg_typs ^ ") -> " ^ string_of_typ ret_typ ^ " effect " ^ string_of_effect eff | Typ_bidir (t1, t2) -> @@ -160,9 +184,11 @@ let string_of_instantiation instantiation = kid_name (mk_kopt K_int kid) ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}" | NC_aux (NC_true, _) -> "true" | NC_aux (NC_false, _) -> "false" + | NC_aux (NC_var kid, _) -> kid_name (mk_kopt K_bool kid) + | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")" in - let string_of_binding (kid, arg) = string_of_kid kid ^ " => " ^ string_of_typ_arg arg in + let string_of_binding (kid, arg) = string_of_kid kid ^ "=>" ^ string_of_typ_arg arg in Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) let id_of_instantiation id instantiation = @@ -179,12 +205,12 @@ let rec variant_generic_typ id (Defs defs) = (* Returns a list of all the instantiations of a function id in an ast. Also works with union constructors, and searches for them in patterns. *) -let rec instantiations_of id ast = +let rec instantiations_of spec id ast = let instantiations = ref [] in let inspect_exp = function | E_aux (E_app (id', _), _) as exp when Id.compare id id' = 0 -> - let instantiation = fix_instantiation (Type_check.instantiation_of exp) in + let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in instantiations := instantiation :: !instantiations; exp | exp -> exp @@ -202,7 +228,7 @@ let rec instantiations_of id ast = (variant_generic_typ variant_id ast) typ in - instantiations := fix_instantiation instantiation :: !instantiations; + instantiations := fix_instantiation spec instantiation :: !instantiations; pat | Typ_aux (Typ_id variant_id, _) -> pat | _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type") @@ -218,12 +244,12 @@ let rec instantiations_of id ast = !instantiations -let rec rewrite_polymorphic_calls id ast = +let rec rewrite_polymorphic_calls spec id ast = let vs_ids = Initial_check.val_spec_ids ast in let rewrite_e_aux = function | E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 -> - let instantiation = fix_instantiation (Type_check.instantiation_of exp) in + let instantiation = fix_instantiation spec (Type_check.instantiation_of exp) in let spec_id = id_of_instantiation id instantiation in (* Make sure we only generate specialized calls when we've specialized the valspec. The valspec may not be generated if @@ -278,13 +304,61 @@ and typ_arg_int_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) = | A_order ord -> KidSet.empty | A_bool _ -> KidSet.empty -let specialize_id_valspec instantiations id ast = +(* Implicit arguments have restrictions that won't hold + post-specialisation, but we can just remove them and turn them into + regular arguments. *) +let rec remove_implicit (Typ_aux (aux, l) as t) = + match aux with + | Typ_internal_unknown -> Typ_aux (Typ_internal_unknown, l) + | Typ_tup typs -> Typ_aux (Typ_tup (List.map remove_implicit typs), l) + | Typ_fn (arg_typs, ret_typ, effs) -> Typ_aux (Typ_fn (List.map remove_implicit arg_typs, remove_implicit ret_typ, effs), l) + | Typ_bidir (typ1, typ2) -> Typ_aux (Typ_bidir (remove_implicit typ1, remove_implicit typ2), l) + | Typ_app (Id_aux (Id "implicit", _), args) -> Typ_aux (Typ_app (mk_id "atom", List.map remove_implicit_arg args), l) + | Typ_app (id, args) -> Typ_aux (Typ_app (id, List.map remove_implicit_arg args), l) + | Typ_id id -> Typ_aux (Typ_id id, l) + | Typ_exist (kopts, nc, typ) -> Typ_aux (Typ_exist (kopts, nc, remove_implicit typ), l) + | Typ_var v -> Typ_aux (Typ_var v, l) +and remove_implicit_arg (A_aux (aux, l)) = + match aux with + | A_typ typ -> A_aux (A_typ (remove_implicit typ), l) + | arg -> A_aux (arg, l) + +let kopt_arg = function + | KOpt_aux (KOpt_kind (K_aux (K_int, _), kid), _) -> arg_nexp (nvar kid) + | KOpt_aux (KOpt_kind (K_aux (K_type,_), kid), _) -> arg_typ (mk_typ (Typ_var kid)) + | _ -> failwith "oh no" + +(* For numeric type arguments we have to be careful not to run into a + situation where we have an instantiation like + + 'n => 'm, 'm => 8 + + and end up re-writing 'n to 8. This function turns an instantition + like the above into two, + + 'n => 'i#m, 'm => 8 and 'i#m => 'm + + so we can do the substitution in two steps. *) +let safe_instantiation instantiation = + let args = + List.map (fun (_, arg) -> kopts_of_typ_arg arg) (KBindings.bindings instantiation) + |> List.fold_left KOptSet.union KOptSet.empty + |> KOptSet.elements + in + List.fold_left (fun (i, r) v -> KBindings.map (fun arg -> subst_kid typ_arg_subst (kopt_kid v) (prepend_kid "i#" (kopt_kid v)) arg) i, + KBindings.add (prepend_kid "i#" (kopt_kid v)) (kopt_arg v) r) + (instantiation, KBindings.empty) args + +let instantiate_constraints instantiation ncs = + List.map (fun c -> List.fold_left (fun c (v, a) -> constraint_subst v a c) c (KBindings.bindings instantiation)) ncs + +let specialize_id_valspec spec instantiations id ast = match split_defs (is_valspec id) ast with - | None -> failwith ("Valspec " ^ string_of_id id ^ " does not exist!") + | None -> Reporting.unreachable (id_loc id) __POS__ ("Valspec " ^ string_of_id id ^ " does not exist!") | Some (pre_ast, vs, post_ast) -> let typschm, externs, is_cast, annot = match vs with | DEF_spec (VS_aux (VS_val_spec (typschm, _, externs, is_cast), annot)) -> typschm, externs, is_cast, annot - | _ -> assert false (* unreachable *) + | _ -> Reporting.unreachable (id_loc id) __POS__ "val-spec is not actually a val-spec" in let TypSchm_aux (TypSchm_ts (typq, typ), _) = typschm in @@ -292,8 +366,9 @@ let specialize_id_valspec instantiations id ast = let spec_ids = ref IdSet.empty in let specialize_instance instantiation = + let safe_instantiation, reverse = safe_instantiation instantiation in (* Replace the polymorphic type variables in the type with their concrete instantiation. *) - let typ = Type_check.subst_unifiers instantiation typ in + let typ = remove_implicit (Type_check.subst_unifiers reverse (Type_check.subst_unifiers safe_instantiation typ)) in (* Collect any new type variables introduced by the instantiation *) let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in @@ -302,11 +377,17 @@ let specialize_id_valspec instantiations id ast = (* Remove type variables from the type quantifier. *) let kopts, constraints = quant_split typq in - let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in - let typq = mk_typquant (List.map (mk_qi_id K_type) typ_frees - @ List.map (mk_qi_id K_int) int_frees - @ List.map mk_qi_kopt kopts - @ List.map mk_qi_nc constraints) in + let constraints = instantiate_constraints safe_instantiation constraints in + let constraints = instantiate_constraints reverse constraints in + let kopts = List.filter (fun kopt -> not (spec.is_polymorphic kopt)) kopts in + let typq = + if List.length (typ_frees @ int_frees) = 0 && List.length kopts = 0 then + mk_typquant [] + else + mk_typquant (List.map (mk_qi_id K_type) typ_frees + @ List.map (mk_qi_id K_int) int_frees + @ List.map mk_qi_kopt kopts + @ List.map mk_qi_nc constraints) in let typschm = mk_typschm typq typ in let spec_id = id_of_instantiation id instantiation in @@ -324,8 +405,9 @@ let specialize_id_valspec instantiations id ast = (* When we specialize a function definition we also need to specialize all the types that appear as annotations within the function - body. *) -let specialize_annotations instantiation = + body. Also remove any type-annotation from the fundef itself, + because at this point we have that as a separate valspec.*) +let specialize_annotations instantiation fdef = let open Type_check in let rw_pat = { id_pat_alg with @@ -337,12 +419,21 @@ let specialize_annotations instantiation = lEXP_cast = (fun (typ, lexp) -> LEXP_cast (subst_unifiers instantiation typ, lexp)); pat_alg = rw_pat } in - rewrite_fun { - rewriters_base with - rewrite_exp = (fun _ -> fold_exp rw_exp); - rewrite_pat = (fun _ -> fold_pat rw_pat) - } - + let fdef = + rewrite_fun { + rewriters_base with + rewrite_exp = (fun _ -> fold_exp rw_exp); + rewrite_pat = (fun _ -> fold_pat rw_pat) + } fdef + in + match fdef with + | FD_aux (FD_function (rec_opt, _, eff_opt, funcls), annot) -> + FD_aux (FD_function (rec_opt, + Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown), + eff_opt, + funcls), + annot) + let specialize_id_fundef instantiations id ast = match split_defs (is_fundef id) ast with | None -> ast @@ -380,7 +471,15 @@ let specialize_id_overloads instantiations id (Defs defs) = valspecs are then re-specialized. This process is iterated until the whole spec is specialized. *) -let initial_calls = (IdSet.of_list [mk_id "main"; mk_id "__SetConfig"; mk_id "__ListConfig"; mk_id "execute"; mk_id "decode"; mk_id "initialize_registers"; mk_id "append_64"]) +let initial_calls = IdSet.of_list + [ mk_id "main"; + mk_id "__SetConfig"; + mk_id "__ListConfig"; + mk_id "execute"; + mk_id "decode"; + mk_id "initialize_registers"; + mk_id "append_64" (* used to construct bitvector literals in C backend *) + ] let remove_unused_valspecs ?(initial_calls=initial_calls) env ast = let calls = ref initial_calls in @@ -424,9 +523,9 @@ let slice_defs env (Defs defs) keep_ids = let defs = List.filter keep defs in remove_unused_valspecs env (Defs defs) ~initial_calls:keep_ids -let specialize_id id ast = - let instantiations = instantiations_of id ast in - let ast = specialize_id_valspec instantiations id ast in +let specialize_id spec id ast = + let instantiations = instantiations_of spec id ast in + let ast = specialize_id_valspec spec instantiations id ast in let ast = specialize_id_fundef instantiations id ast in specialize_id_overloads instantiations id ast @@ -448,21 +547,26 @@ let reorder_typedefs (Defs defs) = let others = filter_typedefs defs in Defs (List.rev !tdefs @ others) -let specialize_ids ids ast = - let ast = List.fold_left (fun ast id -> specialize_id id ast) ast (IdSet.elements ids) in +let specialize_ids spec ids ast = + let ast = List.fold_left (fun ast id -> specialize_id spec id ast) ast (IdSet.elements ids) in let ast = reorder_typedefs ast in let ast, _ = Type_error.check Type_check.initial_env ast in let ast = - List.fold_left (fun ast id -> rewrite_polymorphic_calls id ast) ast (IdSet.elements ids) + List.fold_left (fun ast id -> rewrite_polymorphic_calls spec id ast) ast (IdSet.elements ids) in let ast, env = Type_error.check Type_check.initial_env ast in let ast = remove_unused_valspecs env ast in ast, env -let rec specialize ast env = - let ids = polymorphic_functions (fun kopt -> is_typ_kopt kopt || is_order_kopt kopt) ast in - if IdSet.is_empty ids then +let rec specialize' n spec ast env = + if n = 0 then ast, env else - let ast, env = specialize_ids ids ast in - specialize ast env + let ids = polymorphic_functions spec.is_polymorphic ast in + if IdSet.is_empty ids then + ast, env + else + let ast, env = specialize_ids spec ids ast in + specialize' (n - 1) spec ast env + +let specialize = specialize' (-1) |
