summaryrefslogtreecommitdiff
path: root/src/specialize.ml
diff options
context:
space:
mode:
authorThomas Bauereiss2018-12-18 15:16:36 +0000
committerThomas Bauereiss2018-12-18 15:16:36 +0000
commit1766bf5e3628b5c45290a3353bec05823661b9d3 (patch)
treecae2f596d135074399cd304bb8e3dca1330a2aa8 /src/specialize.ml
parentdf0e02bc0c8259962f25d4c175fa950391695ab6 (diff)
parent07a332c856b3ee9fe26a9cd47ea6005f9d579810 (diff)
Merge branch 'sail2' into monads
Diffstat (limited to 'src/specialize.ml')
-rw-r--r--src/specialize.ml112
1 files changed, 49 insertions, 63 deletions
diff --git a/src/specialize.ml b/src/specialize.ml
index 81c8b0b0..1ba57bd0 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -54,8 +54,8 @@ open Rewriter
open Extra_pervasives
let is_typ_ord_uvar = function
- | Type_check.U_typ _ -> true
- | Type_check.U_order _ -> true
+ | A_aux (A_typ _, _) -> true
+ | A_aux (A_order _, _) -> true
| _ -> false
let rec nexp_simp_typ (Typ_aux (typ_aux, l)) =
@@ -65,29 +65,26 @@ let rec nexp_simp_typ (Typ_aux (typ_aux, l)) =
| Typ_tup typs -> Typ_tup (List.map nexp_simp_typ typs)
| Typ_app (f, args) -> Typ_app (f, List.map nexp_simp_typ_arg args)
| Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, nexp_simp_typ typ)
- | Typ_fn (typ1, typ2, effect) -> Typ_fn (nexp_simp_typ typ1, nexp_simp_typ typ2, effect)
+ | Typ_fn (arg_typs, ret_typ, effect) ->
+ Typ_fn (List.map nexp_simp_typ arg_typs, nexp_simp_typ ret_typ, effect)
| Typ_bidir (t1, t2) -> Typ_bidir (nexp_simp_typ t1, nexp_simp_typ t2)
| Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown"
in
Typ_aux (typ_aux, l)
-and nexp_simp_typ_arg (Typ_arg_aux (typ_arg_aux, l)) =
+and nexp_simp_typ_arg (A_aux (typ_arg_aux, l)) =
let typ_arg_aux = match typ_arg_aux with
- | Typ_arg_nexp n -> Typ_arg_nexp (nexp_simp n)
- | Typ_arg_typ typ -> Typ_arg_typ (nexp_simp_typ typ)
- | Typ_arg_order ord -> Typ_arg_order ord
+ | A_nexp n -> A_nexp (nexp_simp n)
+ | A_typ typ -> A_typ (nexp_simp_typ typ)
+ | A_order ord -> A_order ord
+ | A_bool nc -> A_bool (constraint_simp nc)
in
- Typ_arg_aux (typ_arg_aux, l)
-
-let nexp_simp_uvar = function
- | Type_check.U_nexp nexp -> (prerr_endline ("Simp nexp " ^ string_of_nexp nexp); Type_check.U_nexp (nexp_simp nexp))
- | Type_check.U_typ typ -> Type_check.U_typ (nexp_simp_typ typ)
- | uvar -> uvar
+ A_aux (typ_arg_aux, l)
(* We have to be careful about whether the typechecker has renamed anything returned by instantiation_of.
This part of the typechecker API is a bit ugly. *)
let fix_instantiation instantiation =
- let instantiation = KBindings.bindings (KBindings.filter (fun _ uvar -> is_typ_ord_uvar uvar) instantiation) in
- let instantiation = List.map (fun (kid, uvar) -> Type_check.orig_kid kid, nexp_simp_uvar uvar) instantiation in
+ let instantiation = KBindings.bindings (KBindings.filter (fun _ arg -> is_typ_ord_uvar arg) instantiation) in
+ let instantiation = List.map (fun (kid, arg) -> Type_check.orig_kid kid, nexp_simp_typ_arg arg) instantiation in
List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty instantiation
let rec polymorphic_functions is_kopt (Defs defs) =
@@ -103,13 +100,13 @@ let rec polymorphic_functions is_kopt (Defs defs) =
let string_of_instantiation instantiation =
let open Type_check in
- let kid_names = ref KBindings.empty in
+ let kid_names = ref KOptMap.empty in
let kid_counter = ref 0 in
let kid_name kid =
- try KBindings.find kid !kid_names with
+ try KOptMap.find kid !kid_names with
| Not_found -> begin
let n = string_of_int !kid_counter in
- kid_names := KBindings.add kid n !kid_names;
+ kid_names := KOptMap.add kid n !kid_names;
incr kid_counter;
n
end
@@ -120,7 +117,7 @@ let string_of_instantiation instantiation =
| Nexp_aux (nexp, _) -> string_of_nexp_aux nexp
and string_of_nexp_aux = function
| Nexp_id id -> string_of_id id
- | Nexp_var kid -> kid_name kid
+ | Nexp_var kid -> kid_name (mk_kopt K_int kid)
| Nexp_constant c -> Big_int.to_string c
| Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")"
| Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")"
@@ -134,22 +131,23 @@ let string_of_instantiation instantiation =
| Typ_aux (typ, l) -> string_of_typ_aux typ
and string_of_typ_aux = function
| Typ_id id -> string_of_id id
- | Typ_var kid -> kid_name kid
+ | Typ_var kid -> kid_name (mk_kopt K_type kid)
| Typ_tup typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")"
| Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list ", " string_of_typ_arg args ^ ")"
- | Typ_fn (typ_arg, typ_ret, eff) ->
- string_of_typ typ_arg ^ " -> " ^ string_of_typ typ_ret ^ " effect " ^ string_of_effect eff
+ | Typ_fn (arg_typs, ret_typ, eff) ->
+ "(" ^ Util.string_of_list ", " string_of_typ arg_typs ^ ") -> " ^ string_of_typ ret_typ ^ " effect " ^ string_of_effect eff
| Typ_bidir (t1, t2) ->
string_of_typ t1 ^ " <-> " ^ string_of_typ t2
| Typ_exist (kids, nc, typ) ->
"exist " ^ Util.string_of_list " " kid_name kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ
| Typ_internal_unknown -> "UNKNOWN"
and string_of_typ_arg = function
- | Typ_arg_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg
+ | A_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg
and string_of_typ_arg_aux = function
- | Typ_arg_nexp n -> string_of_nexp n
- | Typ_arg_typ typ -> string_of_typ typ
- | Typ_arg_order o -> string_of_order o
+ | A_nexp n -> string_of_nexp n
+ | A_typ typ -> string_of_typ typ
+ | A_order o -> string_of_order o
+ | A_bool nc -> string_of_n_constraint nc
and string_of_n_constraint = function
| NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2
| NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2
@@ -160,18 +158,12 @@ let string_of_instantiation instantiation =
| NC_aux (NC_and (nc1, nc2), _) ->
"(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")"
| NC_aux (NC_set (kid, ns), _) ->
- kid_name kid ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}"
+ kid_name (mk_kopt K_int kid) ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}"
| NC_aux (NC_true, _) -> "true"
| NC_aux (NC_false, _) -> "false"
in
- let string_of_uvar = function
- | U_nexp n -> string_of_nexp n
- | U_order o -> string_of_order o
- | U_typ typ -> string_of_typ typ
- in
-
- let string_of_binding (kid, uvar) = string_of_kid kid ^ " => " ^ string_of_uvar uvar in
+ let string_of_binding (kid, arg) = string_of_kid kid ^ " => " ^ string_of_typ_arg arg in
Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation))
let id_of_instantiation id instantiation =
@@ -181,7 +173,7 @@ let id_of_instantiation id instantiation =
let rec variant_generic_typ id (Defs defs) =
match defs with
| DEF_type (TD_aux (TD_variant (id', _, typq, _, _), _)) :: _ when Id.compare id id' = 0 ->
- mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq)))
+ mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq)))
| _ :: defs -> variant_generic_typ id (Defs defs)
| [] -> failwith ("No variant with id " ^ string_of_id id)
@@ -206,9 +198,10 @@ let rec instantiations_of id ast =
begin match Type_check.typ_of_annot annot with
| Typ_aux (Typ_app (variant_id, _), _) as typ ->
let open Type_check in
- let instantiation, _, _ = unify (fst annot) (env_of_annot annot)
- (variant_generic_typ variant_id ast)
- typ
+ let instantiation = unify (fst annot) (env_of_annot annot)
+ (tyvars_of_typ (variant_generic_typ variant_id ast))
+ (variant_generic_typ variant_id ast)
+ typ
in
instantiations := fix_instantiation instantiation :: !instantiations;
pat
@@ -256,15 +249,16 @@ let rec typ_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) =
| Typ_var kid -> KidSet.singleton kid
| Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_frees ~exs:exs) typs)
| Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_frees ~exs:exs) args)
- | Typ_exist (kids, nc, typ) -> typ_frees ~exs:(KidSet.of_list kids) typ
- | Typ_fn (typ1, typ2, _) -> KidSet.union (typ_frees ~exs:exs typ1) (typ_frees ~exs:exs typ2)
+ | Typ_exist (kopts, nc, typ) -> typ_frees ~exs:(KidSet.of_list (List.map kopt_kid kopts)) typ
+ | Typ_fn (arg_typs, ret_typ, _) ->
+ List.fold_left KidSet.union (typ_frees ~exs:exs ret_typ) (List.map (typ_frees ~exs:exs) arg_typs)
| Typ_bidir (t1, t2) -> KidSet.union (typ_frees ~exs:exs t1) (typ_frees ~exs:exs t2)
| Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown"
-and typ_arg_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) =
+and typ_arg_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) =
match typ_arg_aux with
- | Typ_arg_nexp n -> KidSet.empty
- | Typ_arg_typ typ -> typ_frees ~exs:exs typ
- | Typ_arg_order ord -> KidSet.empty
+ | A_nexp n -> KidSet.empty
+ | A_typ typ -> typ_frees ~exs:exs typ
+ | A_order ord -> KidSet.empty
let rec typ_int_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) =
match typ_aux with
@@ -272,24 +266,16 @@ let rec typ_int_frees ?exs:(exs=KidSet.empty) (Typ_aux (typ_aux, l)) =
| Typ_var kid -> KidSet.empty
| Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_int_frees ~exs:exs) typs)
| Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_int_frees ~exs:exs) args)
- | Typ_exist (kids, nc, typ) -> typ_int_frees ~exs:(KidSet.of_list kids) typ
- | Typ_fn (typ1, typ2, _) -> KidSet.union (typ_int_frees ~exs:exs typ1) (typ_int_frees ~exs:exs typ2)
+ | Typ_exist (kopts, nc, typ) -> typ_int_frees ~exs:(KidSet.of_list (List.map kopt_kid kopts)) typ
+ | Typ_fn (arg_typs, ret_typ, _) ->
+ List.fold_left KidSet.union (typ_int_frees ~exs:exs ret_typ) (List.map (typ_int_frees ~exs:exs) arg_typs)
| Typ_bidir (t1, t2) -> KidSet.union (typ_int_frees ~exs:exs t1) (typ_int_frees ~exs:exs t2)
| Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown"
-and typ_arg_int_frees ?exs:(exs=KidSet.empty) (Typ_arg_aux (typ_arg_aux, l)) =
+and typ_arg_int_frees ?exs:(exs=KidSet.empty) (A_aux (typ_arg_aux, l)) =
match typ_arg_aux with
- | Typ_arg_nexp n -> KidSet.diff (tyvars_of_nexp n) exs
- | Typ_arg_typ typ -> KidSet.empty
- | Typ_arg_order ord -> KidSet.empty
-
-let uvar_int_frees = function
- | Type_check.U_nexp n -> tyvars_of_nexp n
- | Type_check.U_typ typ -> typ_int_frees typ
- | _ -> KidSet.empty
-
-let uvar_typ_frees = function
- | Type_check.U_typ typ -> typ_frees typ
- | _ -> KidSet.empty
+ | A_nexp n -> KidSet.diff (tyvars_of_nexp n) exs
+ | A_typ typ -> typ_int_frees ~exs:exs typ
+ | A_order ord -> KidSet.empty
let specialize_id_valspec instantiations id ast =
match split_defs (is_valspec id) ast with
@@ -310,14 +296,14 @@ let specialize_id_valspec instantiations id ast =
(* Collect any new type variables introduced by the instantiation *)
let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in
- let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map uvar_typ_frees |> collect_kids in
- let int_frees = KBindings.bindings instantiation |> List.map snd |> List.map uvar_int_frees |> collect_kids in
+ let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_frees |> collect_kids in
+ let int_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_int_frees |> collect_kids in
(* Remove type variables from the type quantifier. *)
let kopts, constraints = quant_split typq in
let kopts = List.filter (fun kopt -> not (is_typ_kopt kopt || is_order_kopt kopt)) kopts in
- let typq = mk_typquant (List.map (mk_qi_id BK_type) typ_frees
- @ List.map (mk_qi_id BK_int) int_frees
+ let typq = mk_typquant (List.map (mk_qi_id K_type) typ_frees
+ @ List.map (mk_qi_id K_int) int_frees
@ List.map mk_qi_kopt kopts
@ List.map mk_qi_nc constraints) in
let typschm = mk_typschm typq typ in