diff options
| author | Jon French | 2018-12-28 15:12:00 +0000 |
|---|---|---|
| committer | Jon French | 2018-12-28 15:12:00 +0000 |
| commit | b59fba68e535f39b6285ec7f4f693107b6e34148 (patch) | |
| tree | 3135513ac4b23f96b41f3d521990f1ce91206c99 /src/specialize.ml | |
| parent | 9f6a95882e1d3d057bcb83d098ba1b63925a4d1f (diff) | |
| parent | 2c887e7d01331d3165120695594eac7a2650ec03 (diff) | |
Merge branch 'sail2' into rmem_interpreter
Diffstat (limited to 'src/specialize.ml')
| -rw-r--r-- | src/specialize.ml | 99 |
1 files changed, 41 insertions, 58 deletions
diff --git a/src/specialize.ml b/src/specialize.ml index 4d7a997f..1ba57bd0 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -54,8 +54,8 @@ open Rewriter open Extra_pervasives let is_typ_ord_uvar = function - | Type_check.U_typ _ -> true - | Type_check.U_order _ -> true + | A_aux (A_typ _, _) -> true + | A_aux (A_order _, _) -> true | _ -> false let rec nexp_simp_typ (Typ_aux (typ_aux, l)) = @@ -71,24 +71,20 @@ let rec nexp_simp_typ (Typ_aux (typ_aux, l)) = | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" in Typ_aux (typ_aux, l) -and nexp_simp_typ_arg (Typ_arg_aux (typ_arg_aux, l)) = +and nexp_simp_typ_arg (A_aux (typ_arg_aux, l)) = let typ_arg_aux = match typ_arg_aux with - | Typ_arg_nexp n -> Typ_arg_nexp (nexp_simp n) - | Typ_arg_typ typ -> Typ_arg_typ (nexp_simp_typ typ) - | Typ_arg_order ord -> Typ_arg_order ord + | A_nexp n -> A_nexp (nexp_simp n) + | A_typ typ -> A_typ (nexp_simp_typ typ) + | A_order ord -> A_order ord + | A_bool nc -> A_bool (constraint_simp nc) in - Typ_arg_aux (typ_arg_aux, l) - -let nexp_simp_uvar = function - | Type_check.U_nexp nexp -> (prerr_endline ("Simp nexp " ^ string_of_nexp nexp); Type_check.U_nexp (nexp_simp nexp)) - | Type_check.U_typ typ -> Type_check.U_typ (nexp_simp_typ typ) - | uvar -> uvar + A_aux (typ_arg_aux, l) (* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of. This part of the typechecker API is a bit ugly. *) let fix_instantiation instantiation = - let instantiation = KBindings.bindings (KBindings.filter (fun _ uvar -> is_typ_ord_uvar uvar) instantiation) in - let instantiation = List.map (fun (kid, uvar) -> Type_check.orig_kid kid, nexp_simp_uvar uvar) instantiation in + let instantiation = KBindings.bindings (KBindings.filter (fun _ arg -> is_typ_ord_uvar arg) instantiation) in + let instantiation = List.map (fun (kid, arg) -> Type_check.orig_kid kid, nexp_simp_typ_arg arg) instantiation in List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation let rec polymorphic_functions is_kopt (Defs defs) = @@ -104,13 +100,13 @@ let rec polymorphic_functions is_kopt (Defs defs) = let string_of_instantiation instantiation = let open Type_check in - let kid_names = ref KBindings.empty in + let kid_names = ref KOptMap.empty in let kid_counter = ref 0 in let kid_name kid = - try KBindings.find kid !kid_names with + try KOptMap.find kid !kid_names with | Not_found -> begin let n = string_of_int !kid_counter in - kid_names := KBindings.add kid n !kid_names; + kid_names := KOptMap.add kid n !kid_names; incr kid_counter; n end @@ -121,7 +117,7 @@ let string_of_instantiation instantiation = | Nexp_aux (nexp, _) -> string_of_nexp_aux nexp and string_of_nexp_aux = function | Nexp_id id -> string_of_id id - | Nexp_var kid -> kid_name kid + | Nexp_var kid -> kid_name (mk_kopt K_int kid) | Nexp_constant c -> Big_int.to_string c | Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")" | Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")" @@ -135,7 +131,7 @@ let string_of_instantiation instantiation = | Typ_aux (typ, l) -> string_of_typ_aux typ and string_of_typ_aux = function | Typ_id id -> string_of_id id - | Typ_var kid -> kid_name kid + | Typ_var kid -> kid_name (mk_kopt K_type kid) | Typ_tup typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")" | Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")" | Typ_fn (arg_typs, ret_typ, eff) -> @@ -146,11 +142,12 @@ let string_of_instantiation instantiation = "exist " ^ Util.string_of_list " " kid_name kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ | Typ_internal_unknown -> "UNKNOWN" and string_of_typ_arg = function - | Typ_arg_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg + | A_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg and string_of_typ_arg_aux = function - | Typ_arg_nexp n -> string_of_nexp n - | Typ_arg_typ typ -> string_of_typ typ - | Typ_arg_order o -> string_of_order o + | A_nexp n -> string_of_nexp n + | A_typ typ -> string_of_typ typ + | A_order o -> string_of_order o + | A_bool nc -> string_of_n_constraint nc and string_of_n_constraint = function | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2 | NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2 @@ -161,18 +158,12 @@ let string_of_instantiation instantiation = | NC_aux (NC_and (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" | NC_aux (NC_set (kid, ns), _) -> - kid_name kid ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}" + kid_name (mk_kopt K_int kid) ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}" | NC_aux (NC_true, _) -> "true" | NC_aux (NC_false, _) -> "false" in - let string_of_uvar = function - | U_nexp n -> string_of_nexp n - | U_order o -> string_of_order o - | U_typ typ -> string_of_typ typ - in - - let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ string_of_uvar uvar in + let string_of_binding (kid, arg) = string_of_kid kid ^ " => " ^ string_of_typ_arg arg in Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation)) let id_of_instantiation id instantiation = @@ -182,7 +173,7 @@ let id_of_instantiation id instantiation = let rec variant_generic_typ id (Defs defs) = match defs with | DEF_type (TD_aux (TD_variant (id', _, typq, _, _), _)) :: _ when Id.compare id id' = 0 -> - mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq))) + mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq))) | _ :: defs -> variant_generic_typ id (Defs defs) | [] -> failwith ("No variant with id " ^ string_of_id id) @@ -207,9 +198,10 @@ let rec instantiations_of id ast = begin match Type_check.typ_of_annot annot with | Typ_aux (Typ_app (variant_id, _), _) as typ -> let open Type_check in - let instantiation, _, _ = unify (fst annot) (env_of_annot annot) - (variant_generic_typ variant_id ast) - typ + let instantiation = unify (fst annot) (env_of_annot annot) + (tyvars_of_typ (variant_generic_typ variant_id ast)) + (variant_generic_typ variant_id ast) + typ in instantiations := fix_instantiation instantiation :: !instantiations; pat @@ -257,16 +249,16 @@ let rec typ_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) = | Typ_var kid -> KidSet.singleton kid | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_frees ~exs:exs) typs) | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_frees ~exs:exs) args) - | Typ_exist (kids, nc, typ) -> typ_frees ~exs:(KidSet.of_list kids) typ + | Typ_exist (kopts, nc, typ) -> typ_frees ~exs:(KidSet.of_list (List.map kopt_kid kopts)) typ | Typ_fn (arg_typs, ret_typ, _) -> List.fold_left KidSet.union (typ_frees ~exs:exs ret_typ) (List.map (typ_frees ~exs:exs) arg_typs) | Typ_bidir (t1, t2) -> KidSet.union (typ_frees ~exs:exs t1) (typ_frees ~exs:exs t2) | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" -and typ_arg_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) = +and typ_arg_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) = match typ_arg_aux with - | Typ_arg_nexp n -> KidSet.empty - | Typ_arg_typ typ -> typ_frees ~exs:exs typ - | Typ_arg_order ord -> KidSet.empty + | A_nexp n -> KidSet.empty + | A_typ typ -> typ_frees ~exs:exs typ + | A_order ord -> KidSet.empty let rec typ_int_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) = match typ_aux with @@ -274,25 +266,16 @@ let rec typ_int_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) = | Typ_var kid -> KidSet.empty | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_int_frees ~exs:exs) typs) | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_int_frees ~exs:exs) args) - | Typ_exist (kids, nc, typ) -> typ_int_frees ~exs:(KidSet.of_list kids) typ + | Typ_exist (kopts, nc, typ) -> typ_int_frees ~exs:(KidSet.of_list (List.map kopt_kid kopts)) typ | Typ_fn (arg_typs, ret_typ, _) -> List.fold_left KidSet.union (typ_int_frees ~exs:exs ret_typ) (List.map (typ_int_frees ~exs:exs) arg_typs) | Typ_bidir (t1, t2) -> KidSet.union (typ_int_frees ~exs:exs t1) (typ_int_frees ~exs:exs t2) | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" -and typ_arg_int_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) = +and typ_arg_int_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) = match typ_arg_aux with - | Typ_arg_nexp n -> KidSet.diff (tyvars_of_nexp n) exs - | Typ_arg_typ typ -> KidSet.empty - | Typ_arg_order ord -> KidSet.empty - -let uvar_int_frees = function - | Type_check.U_nexp n -> tyvars_of_nexp n - | Type_check.U_typ typ -> typ_int_frees typ - | _ -> KidSet.empty - -let uvar_typ_frees = function - | Type_check.U_typ typ -> typ_frees typ - | _ -> KidSet.empty + | A_nexp n -> KidSet.diff (tyvars_of_nexp n) exs + | A_typ typ -> typ_int_frees ~exs:exs typ + | A_order ord -> KidSet.empty let specialize_id_valspec instantiations id ast = match split_defs (is_valspec id) ast with @@ -313,14 +296,14 @@ let specialize_id_valspec instantiations id ast = (* Collect any new type variables introduced by the instantiation *) let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in - let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map uvar_typ_frees |> collect_kids in - let int_frees = KBindings.bindings instantiation |> List.map snd |> List.map uvar_int_frees |> collect_kids in + let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_frees |> collect_kids in + let int_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_int_frees |> collect_kids in (* Remove type variables from the type quantifier. *) let kopts, constraints = quant_split typq in let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in - let typq = mk_typquant (List.map (mk_qi_id BK_type) typ_frees - @ List.map (mk_qi_id BK_int) int_frees + let typq = mk_typquant (List.map (mk_qi_id K_type) typ_frees + @ List.map (mk_qi_id K_int) int_frees @ List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints) in let typschm = mk_typschm typq typ in |
