diff options
Diffstat (limited to 'src/ast_util.ml')
| -rw-r--r-- | src/ast_util.ml | 63 |
1 files changed, 54 insertions, 9 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index bd7a51bb..02e297cb 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -128,7 +128,7 @@ let mk_val_spec vs_aux = let kopt_kid (KOpt_aux (KOpt_kind (_, kid), _)) = kid let kopt_kind (KOpt_aux (KOpt_kind (k, _), _)) = k - + let is_nat_kopt = function | KOpt_aux (KOpt_kind (K_aux (K_int, _), _), _) -> true | _ -> false @@ -144,7 +144,7 @@ let is_typ_kopt = function let is_bool_kopt = function | KOpt_aux (KOpt_kind (K_aux (K_bool, _), _), _) -> true | _ -> false - + let string_of_kid = function | Kid_aux (Var v, _) -> v @@ -153,6 +153,27 @@ module Kid = struct let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2) end +module Kind = struct + type t = kind + let compare (K_aux (aux1, _)) (K_aux (aux2, _)) = + match aux1, aux2 with + | K_int, K_int -> 0 + | K_type, K_type -> 0 + | K_order, K_order -> 0 + | K_bool, K_bool -> 0 + | K_int, _ -> 1 | _, K_int -> -1 + | K_type, _ -> 1 | _, K_type -> -1 + | K_order, _ -> 1 | _, K_order -> -1 +end + +module KOpt = struct + type t = kinded_id + let compare kopt1 kopt2 = + let lex_ord c1 c2 = if c1 = 0 then c2 else c1 in + lex_ord (Kid.compare (kopt_kid kopt1) (kopt_kid kopt2)) + (Kind.compare (kopt_kind kopt1) (kopt_kind kopt2)) +end + module Id = struct type t = id let compare id1 id2 = @@ -200,6 +221,8 @@ module Bindings = Map.Make(Id) module IdSet = Set.Make(Id) module KBindings = Map.Make(Kid) module KidSet = Set.Make(Kid) +module KOptSet = Set.Make(KOpt) +module KOptMap = Map.Make(KOpt) module NexpSet = Set.Make(Nexp) module NexpMap = Map.Make(Nexp) @@ -389,6 +412,13 @@ let arg_order ?loc:(l=Parse_ast.Unknown) ord = A_aux (A_order ord, l) let arg_typ ?loc:(l=Parse_ast.Unknown) typ = A_aux (A_typ typ, l) let arg_bool ?loc:(l=Parse_ast.Unknown) nc = A_aux (A_bool nc, l) +let arg_kopt (KOpt_aux (KOpt_kind (K_aux (k, _), v), l)) = + match k with + | K_int -> arg_nexp (nvar v) + | K_order -> arg_order (Ord_aux (Ord_var v, l)) + | K_bool -> arg_bool (nc_var v) + | K_type -> arg_typ (mk_typ (Typ_var v)) + let nc_not nc = mk_nc (NC_app (mk_id "not", [arg_bool nc])) let mk_typschm typq typ = TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown) @@ -644,6 +674,9 @@ let string_of_kind_aux = function let string_of_kind (K_aux (k, _)) = string_of_kind_aux k +let string_of_kinded_id (KOpt_aux (KOpt_kind (k, kid), _)) = + "(" ^ string_of_kid kid ^ " : " ^ string_of_kind k ^ ")" + let string_of_base_effect = function | BE_aux (beff, _) -> string_of_base_effect_aux beff @@ -687,7 +720,7 @@ and string_of_typ_aux = function ^ string_of_typ typ_ret ^ " effect " ^ string_of_effect eff | Typ_bidir (typ1, typ2) -> string_of_typ typ1 ^ " <-> " ^ string_of_typ typ2 | Typ_exist (kids, nc, typ) -> - "{" ^ string_of_list " " string_of_kid kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}" + "{" ^ string_of_list " " string_of_kinded_id kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}" and string_of_typ_arg = function | A_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg and string_of_typ_arg_aux = function @@ -995,7 +1028,7 @@ module Typ = struct | n -> n) | Typ_tup ts1, Typ_tup ts2 -> Util.compare_list compare ts1 ts2 | Typ_exist (ks1,nc1,t1), Typ_exist (ks2,nc2,t2) -> - (match Util.compare_list Kid.compare ks1 ks2 with + (match Util.compare_list KOpt.compare ks1 ks2 with | 0 -> (match NC.compare nc1 nc2 with | 0 -> compare t1 t2 | n -> n) @@ -1016,8 +1049,10 @@ module Typ = struct | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2 | A_typ t1, A_typ t2 -> compare t1 t2 | A_order o1, A_order o2 -> order_compare o1 o2 + | A_bool nc1, A_bool nc2 -> NC.compare nc1 nc2 | A_nexp _, _ -> -1 | _, A_nexp _ -> 1 | A_typ _, _ -> -1 | _, A_typ _ -> 1 + | A_order _, _ -> -1 | _, A_order _ -> 1 end module TypMap = Map.Make(Typ) @@ -1179,7 +1214,7 @@ and tyvars_of_typ (Typ_aux (t,_)) = KidSet.empty tas | Typ_exist (kids, nc, t) -> let s = KidSet.union (tyvars_of_typ t) (tyvars_of_constraint nc) in - List.fold_left (fun s k -> KidSet.remove k s) s kids + List.fold_left (fun s k -> KidSet.remove k s) s (List.map kopt_kid kids) and tyvars_of_typ_arg (A_aux (ta,_)) = match ta with | A_nexp nexp -> tyvars_of_nexp nexp @@ -1387,6 +1422,11 @@ let locate_id f (Id_aux (name, l)) = Id_aux (name, f l) let locate_kid f (Kid_aux (name, l)) = Kid_aux (name, f l) +let locate_kind f (K_aux (kind, l)) = K_aux (kind, f l) + +let locate_kinded_id f (KOpt_aux (KOpt_kind (k, kid), l)) = + KOpt_aux (KOpt_kind (locate_kind f k, locate_kid f kid), f l) + let locate_lit f (L_aux (lit, l)) = L_aux (lit, f l) let locate_base_effect f (BE_aux (base_effect, l)) = BE_aux (base_effect, f l) @@ -1427,10 +1467,12 @@ let rec locate_nc f (NC_aux (nc_aux, l)) = | NC_and (nc1, nc2) -> NC_and (locate_nc f nc1, locate_nc f nc2) | NC_true -> NC_true | NC_false -> NC_false + | NC_var v -> NC_var (locate_kid f v) + | NC_app (id, args) -> NC_app (locate_id f id, List.map (locate_typ_arg f) args) in NC_aux (nc_aux, f l) -let rec locate_typ f (Typ_aux (typ_aux, l)) = +and locate_typ f (Typ_aux (typ_aux, l)) = let typ_aux = match typ_aux with | Typ_internal_unknown -> Typ_internal_unknown | Typ_id id -> Typ_id (locate_id f id) @@ -1439,7 +1481,7 @@ let rec locate_typ f (Typ_aux (typ_aux, l)) = Typ_fn (List.map (locate_typ f) arg_typs, locate_typ f ret_typ, locate_effect f effect) | Typ_bidir (typ1, typ2) -> Typ_bidir (locate_typ f typ1, locate_typ f typ2) | Typ_tup typs -> Typ_tup (List.map (locate_typ f) typs) - | Typ_exist (kids, constr, typ) -> Typ_exist (List.map (locate_kid f) kids, locate_nc f constr, locate_typ f typ) + | Typ_exist (kopts, constr, typ) -> Typ_exist (List.map (locate_kinded_id f) kopts, locate_nc f constr, locate_typ f typ) | Typ_app (id, typ_args) -> Typ_app (locate_id f id, List.map (locate_typ_arg f) typ_args) in Typ_aux (typ_aux, f l) @@ -1449,6 +1491,7 @@ and locate_typ_arg f (A_aux (typ_arg_aux, l)) = | A_nexp nexp -> A_nexp (locate_nexp f nexp) | A_typ typ -> A_typ (locate_typ f typ) | A_order ord -> A_order (locate_order f ord) + | A_bool nc -> A_bool (locate_nc f nc) in A_aux (typ_arg_aux, f l) @@ -1638,8 +1681,10 @@ and typ_subst_aux sv subst = function | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst sv subst typ1, typ_subst sv subst typ2) | Typ_tup typs -> Typ_tup (List.map (typ_subst sv subst) typs) | Typ_app (f, args) -> Typ_app (f, List.map (typ_arg_subst sv subst) args) - | Typ_exist (kids, nc, typ) when KidSet.mem sv (KidSet.of_list kids) -> Typ_exist (kids, nc, typ) - | Typ_exist (kids, nc, typ) -> Typ_exist (kids, constraint_subst sv subst nc, typ_subst sv subst typ) + | Typ_exist (kopts, nc, typ) when KidSet.mem sv (KidSet.of_list (List.map kopt_kid kopts)) -> + Typ_exist (kopts, nc, typ) + | Typ_exist (kopts, nc, typ) -> + Typ_exist (kopts, constraint_subst sv subst nc, typ_subst sv subst typ) and typ_arg_subst sv subst (A_aux (arg, l)) = A_aux (typ_arg_subst_aux sv subst arg, l) and typ_arg_subst_aux sv subst = function |
