summaryrefslogtreecommitdiff
path: root/src/ast_util.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/ast_util.ml')
-rw-r--r--src/ast_util.ml63
1 files changed, 54 insertions, 9 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index bd7a51bb..02e297cb 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -128,7 +128,7 @@ let mk_val_spec vs_aux =
let kopt_kid (KOpt_aux (KOpt_kind (_, kid), _)) = kid
let kopt_kind (KOpt_aux (KOpt_kind (k, _), _)) = k
-
+
let is_nat_kopt = function
| KOpt_aux (KOpt_kind (K_aux (K_int, _), _), _) -> true
| _ -> false
@@ -144,7 +144,7 @@ let is_typ_kopt = function
let is_bool_kopt = function
| KOpt_aux (KOpt_kind (K_aux (K_bool, _), _), _) -> true
| _ -> false
-
+
let string_of_kid = function
| Kid_aux (Var v, _) -> v
@@ -153,6 +153,27 @@ module Kid = struct
let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2)
end
+module Kind = struct
+ type t = kind
+ let compare (K_aux (aux1, _)) (K_aux (aux2, _)) =
+ match aux1, aux2 with
+ | K_int, K_int -> 0
+ | K_type, K_type -> 0
+ | K_order, K_order -> 0
+ | K_bool, K_bool -> 0
+ | K_int, _ -> 1 | _, K_int -> -1
+ | K_type, _ -> 1 | _, K_type -> -1
+ | K_order, _ -> 1 | _, K_order -> -1
+end
+
+module KOpt = struct
+ type t = kinded_id
+ let compare kopt1 kopt2 =
+ let lex_ord c1 c2 = if c1 = 0 then c2 else c1 in
+ lex_ord (Kid.compare (kopt_kid kopt1) (kopt_kid kopt2))
+ (Kind.compare (kopt_kind kopt1) (kopt_kind kopt2))
+end
+
module Id = struct
type t = id
let compare id1 id2 =
@@ -200,6 +221,8 @@ module Bindings = Map.Make(Id)
module IdSet = Set.Make(Id)
module KBindings = Map.Make(Kid)
module KidSet = Set.Make(Kid)
+module KOptSet = Set.Make(KOpt)
+module KOptMap = Map.Make(KOpt)
module NexpSet = Set.Make(Nexp)
module NexpMap = Map.Make(Nexp)
@@ -389,6 +412,13 @@ let arg_order ?loc:(l=Parse_ast.Unknown) ord = A_aux (A_order ord, l)
let arg_typ ?loc:(l=Parse_ast.Unknown) typ = A_aux (A_typ typ, l)
let arg_bool ?loc:(l=Parse_ast.Unknown) nc = A_aux (A_bool nc, l)
+let arg_kopt (KOpt_aux (KOpt_kind (K_aux (k, _), v), l)) =
+ match k with
+ | K_int -> arg_nexp (nvar v)
+ | K_order -> arg_order (Ord_aux (Ord_var v, l))
+ | K_bool -> arg_bool (nc_var v)
+ | K_type -> arg_typ (mk_typ (Typ_var v))
+
let nc_not nc = mk_nc (NC_app (mk_id "not", [arg_bool nc]))
let mk_typschm typq typ = TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown)
@@ -644,6 +674,9 @@ let string_of_kind_aux = function
let string_of_kind (K_aux (k, _)) = string_of_kind_aux k
+let string_of_kinded_id (KOpt_aux (KOpt_kind (k, kid), _)) =
+ "(" ^ string_of_kid kid ^ " : " ^ string_of_kind k ^ ")"
+
let string_of_base_effect = function
| BE_aux (beff, _) -> string_of_base_effect_aux beff
@@ -687,7 +720,7 @@ and string_of_typ_aux = function
^ string_of_typ typ_ret ^ " effect " ^ string_of_effect eff
| Typ_bidir (typ1, typ2) -> string_of_typ typ1 ^ " <-> " ^ string_of_typ typ2
| Typ_exist (kids, nc, typ) ->
- "{" ^ string_of_list " " string_of_kid kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}"
+ "{" ^ string_of_list " " string_of_kinded_id kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}"
and string_of_typ_arg = function
| A_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg
and string_of_typ_arg_aux = function
@@ -995,7 +1028,7 @@ module Typ = struct
| n -> n)
| Typ_tup ts1, Typ_tup ts2 -> Util.compare_list compare ts1 ts2
| Typ_exist (ks1,nc1,t1), Typ_exist (ks2,nc2,t2) ->
- (match Util.compare_list Kid.compare ks1 ks2 with
+ (match Util.compare_list KOpt.compare ks1 ks2 with
| 0 -> (match NC.compare nc1 nc2 with
| 0 -> compare t1 t2
| n -> n)
@@ -1016,8 +1049,10 @@ module Typ = struct
| A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2
| A_typ t1, A_typ t2 -> compare t1 t2
| A_order o1, A_order o2 -> order_compare o1 o2
+ | A_bool nc1, A_bool nc2 -> NC.compare nc1 nc2
| A_nexp _, _ -> -1 | _, A_nexp _ -> 1
| A_typ _, _ -> -1 | _, A_typ _ -> 1
+ | A_order _, _ -> -1 | _, A_order _ -> 1
end
module TypMap = Map.Make(Typ)
@@ -1179,7 +1214,7 @@ and tyvars_of_typ (Typ_aux (t,_)) =
KidSet.empty tas
| Typ_exist (kids, nc, t) ->
let s = KidSet.union (tyvars_of_typ t) (tyvars_of_constraint nc) in
- List.fold_left (fun s k -> KidSet.remove k s) s kids
+ List.fold_left (fun s k -> KidSet.remove k s) s (List.map kopt_kid kids)
and tyvars_of_typ_arg (A_aux (ta,_)) =
match ta with
| A_nexp nexp -> tyvars_of_nexp nexp
@@ -1387,6 +1422,11 @@ let locate_id f (Id_aux (name, l)) = Id_aux (name, f l)
let locate_kid f (Kid_aux (name, l)) = Kid_aux (name, f l)
+let locate_kind f (K_aux (kind, l)) = K_aux (kind, f l)
+
+let locate_kinded_id f (KOpt_aux (KOpt_kind (k, kid), l)) =
+ KOpt_aux (KOpt_kind (locate_kind f k, locate_kid f kid), f l)
+
let locate_lit f (L_aux (lit, l)) = L_aux (lit, f l)
let locate_base_effect f (BE_aux (base_effect, l)) = BE_aux (base_effect, f l)
@@ -1427,10 +1467,12 @@ let rec locate_nc f (NC_aux (nc_aux, l)) =
| NC_and (nc1, nc2) -> NC_and (locate_nc f nc1, locate_nc f nc2)
| NC_true -> NC_true
| NC_false -> NC_false
+ | NC_var v -> NC_var (locate_kid f v)
+ | NC_app (id, args) -> NC_app (locate_id f id, List.map (locate_typ_arg f) args)
in
NC_aux (nc_aux, f l)
-let rec locate_typ f (Typ_aux (typ_aux, l)) =
+and locate_typ f (Typ_aux (typ_aux, l)) =
let typ_aux = match typ_aux with
| Typ_internal_unknown -> Typ_internal_unknown
| Typ_id id -> Typ_id (locate_id f id)
@@ -1439,7 +1481,7 @@ let rec locate_typ f (Typ_aux (typ_aux, l)) =
Typ_fn (List.map (locate_typ f) arg_typs, locate_typ f ret_typ, locate_effect f effect)
| Typ_bidir (typ1, typ2) -> Typ_bidir (locate_typ f typ1, locate_typ f typ2)
| Typ_tup typs -> Typ_tup (List.map (locate_typ f) typs)
- | Typ_exist (kids, constr, typ) -> Typ_exist (List.map (locate_kid f) kids, locate_nc f constr, locate_typ f typ)
+ | Typ_exist (kopts, constr, typ) -> Typ_exist (List.map (locate_kinded_id f) kopts, locate_nc f constr, locate_typ f typ)
| Typ_app (id, typ_args) -> Typ_app (locate_id f id, List.map (locate_typ_arg f) typ_args)
in
Typ_aux (typ_aux, f l)
@@ -1449,6 +1491,7 @@ and locate_typ_arg f (A_aux (typ_arg_aux, l)) =
| A_nexp nexp -> A_nexp (locate_nexp f nexp)
| A_typ typ -> A_typ (locate_typ f typ)
| A_order ord -> A_order (locate_order f ord)
+ | A_bool nc -> A_bool (locate_nc f nc)
in
A_aux (typ_arg_aux, f l)
@@ -1638,8 +1681,10 @@ and typ_subst_aux sv subst = function
| Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst sv subst typ1, typ_subst sv subst typ2)
| Typ_tup typs -> Typ_tup (List.map (typ_subst sv subst) typs)
| Typ_app (f, args) -> Typ_app (f, List.map (typ_arg_subst sv subst) args)
- | Typ_exist (kids, nc, typ) when KidSet.mem sv (KidSet.of_list kids) -> Typ_exist (kids, nc, typ)
- | Typ_exist (kids, nc, typ) -> Typ_exist (kids, constraint_subst sv subst nc, typ_subst sv subst typ)
+ | Typ_exist (kopts, nc, typ) when KidSet.mem sv (KidSet.of_list (List.map kopt_kid kopts)) ->
+ Typ_exist (kopts, nc, typ)
+ | Typ_exist (kopts, nc, typ) ->
+ Typ_exist (kopts, constraint_subst sv subst nc, typ_subst sv subst typ)
and typ_arg_subst sv subst (A_aux (arg, l)) = A_aux (typ_arg_subst_aux sv subst arg, l)
and typ_arg_subst_aux sv subst = function