diff options
Diffstat (limited to 'src/ast_util.ml')
| -rw-r--r-- | src/ast_util.ml | 351 |
1 files changed, 200 insertions, 151 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index e9153f7a..a771291e 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -126,13 +126,11 @@ let mk_letbind pat exp = LB_aux (LB_val (pat, exp), no_annot) let mk_val_spec vs_aux = DEF_spec (VS_aux (vs_aux, no_annot)) -let kopt_kid (KOpt_aux (kopt_aux, _)) = - match kopt_aux with - | KOpt_none kid | KOpt_kind (_, kid) -> kid - +let kopt_kid (KOpt_aux (KOpt_kind (_, kid), _)) = kid +let kopt_kind (KOpt_aux (KOpt_kind (k, _), _)) = k + let is_nat_kopt = function | KOpt_aux (KOpt_kind (K_aux (K_int, _), _), _) -> true - | KOpt_aux (KOpt_none _, _) -> true | _ -> false let is_order_kopt = function @@ -143,6 +141,10 @@ let is_typ_kopt = function | KOpt_aux (KOpt_kind (K_aux (K_type, _), _), _) -> true | _ -> false +let is_bool_kopt = function + | KOpt_aux (KOpt_kind (K_aux (K_bool, _), _), _) -> true + | _ -> false + let string_of_kid = function | Kid_aux (Var v, _) -> v @@ -151,6 +153,27 @@ module Kid = struct let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2) end +module Kind = struct + type t = kind + let compare (K_aux (aux1, _)) (K_aux (aux2, _)) = + match aux1, aux2 with + | K_int, K_int -> 0 + | K_type, K_type -> 0 + | K_order, K_order -> 0 + | K_bool, K_bool -> 0 + | K_int, _ -> 1 | _, K_int -> -1 + | K_type, _ -> 1 | _, K_type -> -1 + | K_order, _ -> 1 | _, K_order -> -1 +end + +module KOpt = struct + type t = kinded_id + let compare kopt1 kopt2 = + let lex_ord c1 c2 = if c1 = 0 then c2 else c1 in + lex_ord (Kid.compare (kopt_kid kopt1) (kopt_kid kopt2)) + (Kind.compare (kopt_kind kopt1) (kopt_kind kopt2)) +end + module Id = struct type t = id let compare id1 id2 = @@ -198,6 +221,8 @@ module Bindings = Map.Make(Id) module IdSet = Set.Make(Id) module KBindings = Map.Make(Kid) module KidSet = Set.Make(Kid) +module KOptSet = Set.Make(KOpt) +module KOptMap = Map.Make(KOpt) module NexpSet = Set.Make(Nexp) module NexpMap = Map.Make(Nexp) @@ -315,12 +340,15 @@ let rec constraint_disj (NC_aux (nc_aux, l) as nc) = | _ -> [nc] let mk_typ typ = Typ_aux (typ, Parse_ast.Unknown) -let mk_typ_arg arg = Typ_arg_aux (arg, Parse_ast.Unknown) +let mk_typ_arg arg = A_aux (arg, Parse_ast.Unknown) let mk_kid str = Kid_aux (Var ("'" ^ str), Parse_ast.Unknown) let mk_infix_id str = Id_aux (DeIid str, Parse_ast.Unknown) let mk_id_typ id = Typ_aux (Typ_id id, Parse_ast.Unknown) +let mk_kopt kind_aux id = + KOpt_aux (KOpt_kind (K_aux (kind_aux, Parse_ast.Unknown), id), Parse_ast.Unknown) + let mk_ord ord_aux = Ord_aux (ord_aux, Parse_ast.Unknown) let unknown_typ = mk_typ Typ_internal_unknown @@ -330,23 +358,23 @@ let unit_typ = mk_id_typ (mk_id "unit") let bit_typ = mk_id_typ (mk_id "bit") let real_typ = mk_id_typ (mk_id "real") let app_typ id args = mk_typ (Typ_app (id, args)) -let register_typ typ = mk_typ (Typ_app (mk_id "register", [mk_typ_arg (Typ_arg_typ typ)])) +let register_typ typ = mk_typ (Typ_app (mk_id "register", [mk_typ_arg (A_typ typ)])) let atom_typ nexp = - mk_typ (Typ_app (mk_id "atom", [mk_typ_arg (Typ_arg_nexp (nexp_simp nexp))])) + mk_typ (Typ_app (mk_id "atom", [mk_typ_arg (A_nexp (nexp_simp nexp))])) let range_typ nexp1 nexp2 = - mk_typ (Typ_app (mk_id "range", [mk_typ_arg (Typ_arg_nexp (nexp_simp nexp1)); - mk_typ_arg (Typ_arg_nexp (nexp_simp nexp2))])) + mk_typ (Typ_app (mk_id "range", [mk_typ_arg (A_nexp (nexp_simp nexp1)); + mk_typ_arg (A_nexp (nexp_simp nexp2))])) let bool_typ = mk_id_typ (mk_id "bool") let string_typ = mk_id_typ (mk_id "string") -let list_typ typ = mk_typ (Typ_app (mk_id "list", [mk_typ_arg (Typ_arg_typ typ)])) +let list_typ typ = mk_typ (Typ_app (mk_id "list", [mk_typ_arg (A_typ typ)])) let tuple_typ typs = mk_typ (Typ_tup typs) let function_typ arg_typs ret_typ eff = mk_typ (Typ_fn (arg_typs, ret_typ, eff)) let vector_typ n ord typ = mk_typ (Typ_app (mk_id "vector", - [mk_typ_arg (Typ_arg_nexp (nexp_simp n)); - mk_typ_arg (Typ_arg_order ord); - mk_typ_arg (Typ_arg_typ typ)])) + [mk_typ_arg (A_nexp (nexp_simp n)); + mk_typ_arg (A_order ord); + mk_typ_arg (A_typ typ)])) let exc_typ = mk_id_typ (mk_id "exception") @@ -368,25 +396,30 @@ let nc_lteq n1 n2 = NC_aux (NC_bounded_le (n1, n2), Parse_ast.Unknown) let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown) let nc_lt n1 n2 = nc_lteq (nsum n1 (nint 1)) n2 let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nint 1)) -let nc_and nc1 nc2 = mk_nc (NC_and (nc1, nc2)) let nc_or nc1 nc2 = mk_nc (NC_or (nc1, nc2)) +let nc_var kid = mk_nc (NC_var kid) let nc_true = mk_nc NC_true let nc_false = mk_nc NC_false -let rec nc_negate (NC_aux (nc, l)) = - match nc with - | NC_bounded_ge (n1, n2) -> nc_lt n1 n2 - | NC_bounded_le (n1, n2) -> nc_gt n1 n2 - | NC_equal (n1, n2) -> nc_neq n1 n2 - | NC_not_equal (n1, n2) -> nc_eq n1 n2 - | NC_and (n1, n2) -> mk_nc (NC_or (nc_negate n1, nc_negate n2)) - | NC_or (n1, n2) -> mk_nc (NC_and (nc_negate n1, nc_negate n2)) - | NC_false -> mk_nc NC_true - | NC_true -> mk_nc NC_false - | NC_set (kid, []) -> nc_false - | NC_set (kid, [int]) -> nc_neq (nvar kid) (nconstant int) - | NC_set (kid, int :: ints) -> - mk_nc (NC_and (nc_neq (nvar kid) (nconstant int), nc_negate (mk_nc (NC_set (kid, ints))))) +let nc_and nc1 nc2 = + match nc1, nc2 with + | _, NC_aux (NC_true, _) -> nc1 + | NC_aux (NC_true, _), _ -> nc2 + | _, _ -> mk_nc (NC_and (nc1, nc2)) + +let arg_nexp ?loc:(l=Parse_ast.Unknown) n = A_aux (A_nexp n, l) +let arg_order ?loc:(l=Parse_ast.Unknown) ord = A_aux (A_order ord, l) +let arg_typ ?loc:(l=Parse_ast.Unknown) typ = A_aux (A_typ typ, l) +let arg_bool ?loc:(l=Parse_ast.Unknown) nc = A_aux (A_bool nc, l) + +let arg_kopt (KOpt_aux (KOpt_kind (K_aux (k, _), v), l)) = + match k with + | K_int -> arg_nexp (nvar v) + | K_order -> arg_order (Ord_aux (Ord_var v, l)) + | K_bool -> arg_bool (nc_var v) + | K_type -> arg_typ (mk_typ (Typ_var v)) + +let nc_not nc = mk_nc (NC_app (mk_id "not", [arg_bool nc])) let mk_typschm typq typ = TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown) @@ -433,10 +466,19 @@ let quant_map_items f = function | TypQ_aux (TypQ_no_forall, l) -> TypQ_aux (TypQ_no_forall, l) | TypQ_aux (TypQ_tq qis, l) -> TypQ_aux (TypQ_tq (List.map f qis), l) +let is_quant_kopt = function + | QI_aux (QI_id _, _) -> true + | _ -> false + +let is_quant_constraint = function + | QI_aux (QI_const _, _) -> true + | _ -> false + let unaux_nexp (Nexp_aux (nexp, _)) = nexp let unaux_order (Ord_aux (ord, _)) = ord let unaux_typ (Typ_aux (typ, _)) = typ let unaux_kind (K_aux (k, _)) = k +let unaux_constraint (NC_aux (nc, _)) = nc let rec map_exp_annot f (E_aux (exp, annot)) = E_aux (map_exp_annot_aux f exp, f annot) and map_exp_annot_aux f = function @@ -628,9 +670,13 @@ let string_of_kind_aux = function | K_type -> "Type" | K_int -> "Int" | K_order -> "Order" + | K_bool -> "Bool" let string_of_kind (K_aux (k, _)) = string_of_kind_aux k +let string_of_kinded_id (KOpt_aux (KOpt_kind (k, kid), _)) = + "(" ^ string_of_kid kid ^ " : " ^ string_of_kind k ^ ")" + let string_of_base_effect = function | BE_aux (beff, _) -> string_of_base_effect_aux beff @@ -665,6 +711,7 @@ and string_of_typ_aux = function | Typ_var kid -> string_of_kid kid | Typ_tup typs -> "(" ^ string_of_list ", " string_of_typ typs ^ ")" | Typ_app (id, args) when Id.compare id (mk_id "atom") = 0 -> "int(" ^ string_of_list ", " string_of_typ_arg args ^ ")" + | Typ_app (id, args) when Id.compare id (mk_id "atom_bool") = 0 -> "bool(" ^ string_of_list ", " string_of_typ_arg args ^ ")" | Typ_app (id, args) -> string_of_id id ^ "(" ^ string_of_list ", " string_of_typ_arg args ^ ")" | Typ_fn ([typ_arg], typ_ret, eff) -> string_of_typ typ_arg ^ " -> " ^ string_of_typ typ_ret ^ " effect " ^ string_of_effect eff @@ -673,15 +720,16 @@ and string_of_typ_aux = function ^ string_of_typ typ_ret ^ " effect " ^ string_of_effect eff | Typ_bidir (typ1, typ2) -> string_of_typ typ1 ^ " <-> " ^ string_of_typ typ2 | Typ_exist (kids, nc, typ) -> - "{" ^ string_of_list " " string_of_kid kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}" + "{" ^ string_of_list " " string_of_kinded_id kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}" and string_of_typ_arg = function - | Typ_arg_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg + | A_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg and string_of_typ_arg_aux = function - | Typ_arg_nexp n -> string_of_nexp n - | Typ_arg_typ typ -> string_of_typ typ - | Typ_arg_order o -> string_of_order o + | A_nexp n -> string_of_nexp n + | A_typ typ -> string_of_typ typ + | A_order o -> string_of_order o + | A_bool nc -> string_of_n_constraint nc and string_of_n_constraint = function - | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2 + | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " == " ^ string_of_nexp n2 | NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2 | NC_aux (NC_bounded_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2 | NC_aux (NC_bounded_le (n1, n2), _) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2 @@ -691,12 +739,14 @@ and string_of_n_constraint = function "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" | NC_aux (NC_set (kid, ns), _) -> string_of_kid kid ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}" + | NC_aux (NC_app (Id_aux (DeIid op, _), [arg1; arg2]), _) -> + "(" ^ string_of_typ_arg arg1 ^ " " ^ op ^ " " ^ string_of_typ_arg arg2 ^ ")" + | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ string_of_list ", " string_of_typ_arg args ^ ")" + | NC_aux (NC_var v, _) -> string_of_kid v | NC_aux (NC_true, _) -> "true" | NC_aux (NC_false, _) -> "false" -let string_of_kinded_id = function - | KOpt_aux (KOpt_none kid, _) -> string_of_kid kid - | KOpt_aux (KOpt_kind (k, kid), _) -> "(" ^ string_of_kid kid ^ " : " ^ string_of_kind k ^ ")" +let string_of_kinded_id (KOpt_aux (KOpt_kind (k, kid), _)) = "(" ^ string_of_kid kid ^ " : " ^ string_of_kind k ^ ")" let string_of_quant_item_aux = function | QI_id kopt -> string_of_kinded_id kopt @@ -807,8 +857,9 @@ and string_of_pat (P_aux (pat, l)) = | P_vector_concat pats -> string_of_list " @ " string_of_pat pats | P_vector pats -> "[" ^ string_of_list ", " string_of_pat pats ^ "]" | P_as (pat, id) -> "(" ^ string_of_pat pat ^ " as " ^ string_of_id id ^ ")" + | P_string_append [] -> "\"\"" | P_string_append pats -> string_of_list " ^ " string_of_pat pats - | _ -> "PAT" + | P_record _ -> "PAT" and string_of_mpat (MP_aux (pat, l)) = match pat with @@ -979,7 +1030,7 @@ module Typ = struct | n -> n) | Typ_tup ts1, Typ_tup ts2 -> Util.compare_list compare ts1 ts2 | Typ_exist (ks1,nc1,t1), Typ_exist (ks2,nc2,t2) -> - (match Util.compare_list Kid.compare ks1 ks2 with + (match Util.compare_list KOpt.compare ks1 ks2 with | 0 -> (match NC.compare nc1 nc2 with | 0 -> compare t1 t2 | n -> n) @@ -995,13 +1046,15 @@ module Typ = struct | Typ_bidir _, _ -> -1 | _, Typ_bidir _ -> 1 | Typ_tup _, _ -> -1 | _, Typ_tup _ -> 1 | Typ_exist _, _ -> -1 | _, Typ_exist _ -> 1 - and arg_compare (Typ_arg_aux (ta1,_)) (Typ_arg_aux (ta2,_)) = + and arg_compare (A_aux (ta1,_)) (A_aux (ta2,_)) = match ta1, ta2 with - | Typ_arg_nexp n1, Typ_arg_nexp n2 -> Nexp.compare n1 n2 - | Typ_arg_typ t1, Typ_arg_typ t2 -> compare t1 t2 - | Typ_arg_order o1, Typ_arg_order o2 -> order_compare o1 o2 - | Typ_arg_nexp _, _ -> -1 | _, Typ_arg_nexp _ -> 1 - | Typ_arg_typ _, _ -> -1 | _, Typ_arg_typ _ -> 1 + | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2 + | A_typ t1, A_typ t2 -> compare t1 t2 + | A_order o1, A_order o2 -> order_compare o1 o2 + | A_bool nc1, A_bool nc2 -> NC.compare nc1 nc2 + | A_nexp _, _ -> -1 | _, A_nexp _ -> 1 + | A_typ _, _ -> -1 | _, A_typ _ -> 1 + | A_order _, _ -> -1 | _, A_order _ -> 1 end module TypMap = Map.Make(Typ) @@ -1057,21 +1110,21 @@ let is_ref_typ (Typ_aux (typ_aux, _)) = match typ_aux with let rec is_vector_typ = function | Typ_aux (Typ_app (Id_aux (Id "vector",_), [_;_;_]), _) -> true - | Typ_aux (Typ_app (Id_aux (Id "register",_), [Typ_arg_aux (Typ_arg_typ rtyp,_)]), _) -> + | Typ_aux (Typ_app (Id_aux (Id "register",_), [A_aux (A_typ rtyp,_)]), _) -> is_vector_typ rtyp | _ -> false let typ_app_args_of = function | Typ_aux (Typ_app (Id_aux (Id c,_), targs), l) -> - (c, List.map (fun (Typ_arg_aux (a,_)) -> a) targs, l) + (c, List.map (fun (A_aux (a,_)) -> a) targs, l) | Typ_aux (_, l) as typ -> raise (Reporting.err_typ l ("typ_app_args_of called on non-app type " ^ string_of_typ typ)) let rec vector_typ_args_of typ = match typ_app_args_of typ with - | ("vector", [Typ_arg_nexp len; Typ_arg_order ord; Typ_arg_typ etyp], l) -> + | ("vector", [A_nexp len; A_order ord; A_typ etyp], l) -> (nexp_simp len, ord, etyp) - | ("register", [Typ_arg_typ rtyp], _) -> vector_typ_args_of rtyp + | ("register", [A_typ rtyp], _) -> vector_typ_args_of rtyp | (_, _, l) -> raise (Reporting.err_typ l ("vector_typ_args_of called on non-vector type " ^ string_of_typ typ)) @@ -1142,10 +1195,13 @@ let rec tyvars_of_constraint (NC_aux (nc, _)) = | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2) + | NC_app (id, args) -> + List.fold_left (fun s t -> KidSet.union s (tyvars_of_typ_arg t)) KidSet.empty args + | NC_var kid -> KidSet.singleton kid | NC_true | NC_false -> KidSet.empty -let rec tyvars_of_typ (Typ_aux (t,_)) = +and tyvars_of_typ (Typ_aux (t,_)) = match t with | Typ_internal_unknown -> KidSet.empty | Typ_id _ -> KidSet.empty @@ -1160,15 +1216,16 @@ let rec tyvars_of_typ (Typ_aux (t,_)) = KidSet.empty tas | Typ_exist (kids, nc, t) -> let s = KidSet.union (tyvars_of_typ t) (tyvars_of_constraint nc) in - List.fold_left (fun s k -> KidSet.remove k s) s kids -and tyvars_of_typ_arg (Typ_arg_aux (ta,_)) = + List.fold_left (fun s k -> KidSet.remove k s) s (List.map kopt_kid kids) +and tyvars_of_typ_arg (A_aux (ta,_)) = match ta with - | Typ_arg_nexp nexp -> tyvars_of_nexp nexp - | Typ_arg_typ typ -> tyvars_of_typ typ - | Typ_arg_order _ -> KidSet.empty + | A_nexp nexp -> tyvars_of_nexp nexp + | A_typ typ -> tyvars_of_typ typ + | A_order _ -> KidSet.empty + | A_bool nc -> tyvars_of_constraint nc let tyvars_of_quant_item (QI_aux (qi, _)) = match qi with - | QI_id (KOpt_aux ((KOpt_none kid | KOpt_kind (_, kid)), _)) -> + | QI_id (KOpt_aux (KOpt_kind (_, kid), _)) -> KidSet.singleton kid | QI_const nc -> tyvars_of_constraint nc @@ -1182,7 +1239,7 @@ let rec undefined_of_typ mwords l annot (Typ_aux (typ_aux, _) as typ) = | Typ_app (_,[size;_;_]) when mwords && is_bitvector_typ typ -> wrap (E_app (mk_id "undefined_bitvector", undefined_of_typ_args mwords l annot size)) typ - | Typ_app (atom, [Typ_arg_aux (Typ_arg_nexp i, _)]) when string_of_id atom = "atom" -> + | Typ_app (atom, [A_aux (A_nexp i, _)]) when string_of_id atom = "atom" -> wrap (E_sizeof i) typ | Typ_app (id, args) -> wrap (E_app (prepend_id "undefined_" id, @@ -1197,11 +1254,11 @@ let rec undefined_of_typ mwords l annot (Typ_aux (typ_aux, _) as typ) = case when re-writing those functions. *) wrap (E_id (prepend_id "typ_" (id_of_kid kid))) typ | Typ_internal_unknown | Typ_bidir _ | Typ_fn _ | Typ_exist _ -> assert false (* Typ_exist should be re-written *) -and undefined_of_typ_args mwords l annot (Typ_arg_aux (typ_arg_aux, _) as typ_arg) = +and undefined_of_typ_args mwords l annot (A_aux (typ_arg_aux, _) as typ_arg) = match typ_arg_aux with - | Typ_arg_nexp n -> [E_aux (E_sizeof n, (l, annot (atom_typ n)))] - | Typ_arg_typ typ -> [undefined_of_typ mwords l annot typ] - | Typ_arg_order _ -> [] + | A_nexp n -> [E_aux (E_sizeof n, (l, annot (atom_typ n)))] + | A_typ typ -> [undefined_of_typ mwords l annot typ] + | A_order _ -> [] let destruct_pexp (Pat_aux (pexp,ann)) = match pexp with @@ -1367,6 +1424,11 @@ let locate_id f (Id_aux (name, l)) = Id_aux (name, f l) let locate_kid f (Kid_aux (name, l)) = Kid_aux (name, f l) +let locate_kind f (K_aux (kind, l)) = K_aux (kind, f l) + +let locate_kinded_id f (KOpt_aux (KOpt_kind (k, kid), l)) = + KOpt_aux (KOpt_kind (locate_kind f k, locate_kid f kid), f l) + let locate_lit f (L_aux (lit, l)) = L_aux (lit, f l) let locate_base_effect f (BE_aux (base_effect, l)) = BE_aux (base_effect, f l) @@ -1407,10 +1469,12 @@ let rec locate_nc f (NC_aux (nc_aux, l)) = | NC_and (nc1, nc2) -> NC_and (locate_nc f nc1, locate_nc f nc2) | NC_true -> NC_true | NC_false -> NC_false + | NC_var v -> NC_var (locate_kid f v) + | NC_app (id, args) -> NC_app (locate_id f id, List.map (locate_typ_arg f) args) in NC_aux (nc_aux, f l) -let rec locate_typ f (Typ_aux (typ_aux, l)) = +and locate_typ f (Typ_aux (typ_aux, l)) = let typ_aux = match typ_aux with | Typ_internal_unknown -> Typ_internal_unknown | Typ_id id -> Typ_id (locate_id f id) @@ -1419,18 +1483,19 @@ let rec locate_typ f (Typ_aux (typ_aux, l)) = Typ_fn (List.map (locate_typ f) arg_typs, locate_typ f ret_typ, locate_effect f effect) | Typ_bidir (typ1, typ2) -> Typ_bidir (locate_typ f typ1, locate_typ f typ2) | Typ_tup typs -> Typ_tup (List.map (locate_typ f) typs) - | Typ_exist (kids, constr, typ) -> Typ_exist (List.map (locate_kid f) kids, locate_nc f constr, locate_typ f typ) + | Typ_exist (kopts, constr, typ) -> Typ_exist (List.map (locate_kinded_id f) kopts, locate_nc f constr, locate_typ f typ) | Typ_app (id, typ_args) -> Typ_app (locate_id f id, List.map (locate_typ_arg f) typ_args) in Typ_aux (typ_aux, f l) -and locate_typ_arg f (Typ_arg_aux (typ_arg_aux, l)) = +and locate_typ_arg f (A_aux (typ_arg_aux, l)) = let typ_arg_aux = match typ_arg_aux with - | Typ_arg_nexp nexp -> Typ_arg_nexp (locate_nexp f nexp) - | Typ_arg_typ typ -> Typ_arg_typ (locate_typ f typ) - | Typ_arg_order ord -> Typ_arg_order (locate_order f ord) + | A_nexp nexp -> A_nexp (locate_nexp f nexp) + | A_typ typ -> A_typ (locate_typ f typ) + | A_order ord -> A_order (locate_order f ord) + | A_bool nc -> A_bool (locate_nc f nc) in - Typ_arg_aux (typ_arg_aux, f l) + A_aux (typ_arg_aux, f l) let rec locate_typ_pat f (TP_aux (tp_aux, l)) = let tp_aux = match tp_aux with @@ -1547,10 +1612,26 @@ let unique l = (* 1. Substitutions *) (**************************************************************************) +let order_subst_aux sv subst = function + | Ord_var kid -> + begin match subst with + | A_aux (A_order ord, _) when Kid.compare kid sv = 0 -> + unaux_order ord + | _ -> Ord_var kid + end + | Ord_inc -> Ord_inc + | Ord_dec -> Ord_dec + +let order_subst sv subst (Ord_aux (ord, l)) = Ord_aux (order_subst_aux sv subst ord, l) + let rec nexp_subst sv subst (Nexp_aux (nexp, l)) = Nexp_aux (nexp_subst_aux sv subst nexp, l) and nexp_subst_aux sv subst = function - | Nexp_id v -> Nexp_id v - | Nexp_var kid -> if Kid.compare kid sv = 0 then subst else Nexp_var kid + | Nexp_var kid -> + begin match subst with + | A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> unaux_nexp n + | _ -> Nexp_var kid + end + | Nexp_id id -> Nexp_id id | Nexp_constant c -> Nexp_constant c | Nexp_times (nexp1, nexp2) -> Nexp_times (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2) | Nexp_sum (nexp1, nexp2) -> Nexp_sum (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2) @@ -1564,100 +1645,68 @@ let rec nexp_set_to_or l subst = function | [int] -> NC_equal (subst, nconstant int) | (int :: ints) -> NC_or (mk_nc (NC_equal (subst, nconstant int)), mk_nc (nexp_set_to_or l subst ints)) -let rec nc_subst_nexp sv subst (NC_aux (nc, l)) = NC_aux (nc_subst_nexp_aux l sv subst nc, l) -and nc_subst_nexp_aux l sv subst = function +let rec constraint_subst sv subst (NC_aux (nc, l)) = NC_aux (constraint_subst_aux l sv subst nc, l) +and constraint_subst_aux l sv subst = function | NC_equal (n1, n2) -> NC_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_ge (n1, n2) -> NC_bounded_ge (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_le (n1, n2) -> NC_bounded_le (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_not_equal (n1, n2) -> NC_not_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_set (kid, ints) as set_nc -> - if Kid.compare kid sv = 0 - then nexp_set_to_or l (mk_nexp subst) ints - else set_nc - | NC_or (nc1, nc2) -> NC_or (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2) - | NC_and (nc1, nc2) -> NC_and (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2) + begin match subst with + | A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> + nexp_set_to_or l n ints + | _ -> set_nc + end + | NC_or (nc1, nc2) -> NC_or (constraint_subst sv subst nc1, constraint_subst sv subst nc2) + | NC_and (nc1, nc2) -> NC_and (constraint_subst sv subst nc1, constraint_subst sv subst nc2) + | NC_app (id, args) -> NC_app (id, List.map (typ_arg_subst sv subst) args) + | NC_var kid -> + begin match subst with + | A_aux (A_bool nc, _) when Kid.compare kid sv = 0 -> + unaux_constraint nc + | _ -> NC_var kid + end | NC_false -> NC_false | NC_true -> NC_true -let rec typ_subst_nexp sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_nexp_aux sv subst typ, l) -and typ_subst_nexp_aux sv subst = function - | Typ_internal_unknown -> Typ_internal_unknown - | Typ_id v -> Typ_id v - | Typ_var kid -> Typ_var kid - | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_nexp sv subst) arg_typs, typ_subst_nexp sv subst ret_typ, effs) - | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_nexp sv subst typ1, typ_subst_nexp sv subst typ2) - | Typ_tup typs -> Typ_tup (List.map (typ_subst_nexp sv subst) typs) - | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_nexp sv subst) args) - | Typ_exist (kids, nc, typ) when KidSet.mem sv (KidSet.of_list kids) -> Typ_exist (kids, nc, typ) - | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc_subst_nexp sv subst nc, typ_subst_nexp sv subst typ) -and typ_subst_arg_nexp sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_nexp_aux sv subst arg, l) -and typ_subst_arg_nexp_aux sv subst = function - | Typ_arg_nexp nexp -> Typ_arg_nexp (nexp_subst sv subst nexp) - | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_nexp sv subst typ) - | Typ_arg_order ord -> Typ_arg_order ord - -let rec typ_subst_typ sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_typ_aux sv subst typ, l) -and typ_subst_typ_aux sv subst = function +and typ_subst sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_aux sv subst typ, l) +and typ_subst_aux sv subst = function | Typ_internal_unknown -> Typ_internal_unknown | Typ_id v -> Typ_id v - | Typ_var kid -> if Kid.compare kid sv = 0 then subst else Typ_var kid - | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_typ sv subst) arg_typs, typ_subst_typ sv subst ret_typ, effs) - | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_typ sv subst typ1, typ_subst_typ sv subst typ2) - | Typ_tup typs -> Typ_tup (List.map (typ_subst_typ sv subst) typs) - | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_typ sv subst) args) - | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, typ_subst_typ sv subst typ) -and typ_subst_arg_typ sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_typ_aux sv subst arg, l) -and typ_subst_arg_typ_aux sv subst = function - | Typ_arg_nexp nexp -> Typ_arg_nexp nexp - | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_typ sv subst typ) - | Typ_arg_order ord -> Typ_arg_order ord - -let order_subst_aux sv subst = function - | Ord_var kid -> if Kid.compare kid sv = 0 then subst else Ord_var kid - | Ord_inc -> Ord_inc - | Ord_dec -> Ord_dec - -let order_subst sv subst (Ord_aux (ord, l)) = Ord_aux (order_subst_aux sv subst ord, l) - -let rec typ_subst_order sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_order_aux sv subst typ, l) -and typ_subst_order_aux sv subst = function - | Typ_internal_unknown -> Typ_internal_unknown - | Typ_id v -> Typ_id v - | Typ_var kid -> Typ_var kid - | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_order sv subst) arg_typs, typ_subst_order sv subst ret_typ, effs) - | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_order sv subst typ1, typ_subst_order sv subst typ2) - | Typ_tup typs -> Typ_tup (List.map (typ_subst_order sv subst) typs) - | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_order sv subst) args) - | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, typ_subst_order sv subst typ) -and typ_subst_arg_order sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_order_aux sv subst arg, l) -and typ_subst_arg_order_aux sv subst = function - | Typ_arg_nexp nexp -> Typ_arg_nexp nexp - | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_order sv subst typ) - | Typ_arg_order ord -> Typ_arg_order (order_subst sv subst ord) - -let rec typ_subst_kid sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_kid_aux sv subst typ, l) -and typ_subst_kid_aux sv subst = function - | Typ_internal_unknown -> Typ_internal_unknown - | Typ_id v -> Typ_id v - | Typ_var kid -> if Kid.compare kid sv = 0 then Typ_var subst else Typ_var kid - | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_kid sv subst) arg_typs, typ_subst_kid sv subst ret_typ, effs) - | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_kid sv subst typ1, typ_subst_kid sv subst typ2) - | Typ_tup typs -> Typ_tup (List.map (typ_subst_kid sv subst) typs) - | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_kid sv subst) args) - | Typ_exist (kids, nc, typ) when KidSet.mem sv (KidSet.of_list kids) -> Typ_exist (kids, nc, typ) - | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc_subst_nexp sv (Nexp_var subst) nc, typ_subst_kid sv subst typ) -and typ_subst_arg_kid sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_kid_aux sv subst arg, l) -and typ_subst_arg_kid_aux sv subst = function - | Typ_arg_nexp nexp -> Typ_arg_nexp (nexp_subst sv (Nexp_var subst) nexp) - | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_kid sv subst typ) - | Typ_arg_order ord -> Typ_arg_order (order_subst sv (Ord_var subst) ord) + | Typ_var kid -> + begin match subst with + | A_aux (A_typ typ, _) when Kid.compare kid sv = 0 -> + unaux_typ typ + | _ -> Typ_var kid + end + | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst sv subst) arg_typs, typ_subst sv subst ret_typ, effs) + | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst sv subst typ1, typ_subst sv subst typ2) + | Typ_tup typs -> Typ_tup (List.map (typ_subst sv subst) typs) + | Typ_app (f, args) -> Typ_app (f, List.map (typ_arg_subst sv subst) args) + | Typ_exist (kopts, nc, typ) when KidSet.mem sv (KidSet.of_list (List.map kopt_kid kopts)) -> + Typ_exist (kopts, nc, typ) + | Typ_exist (kopts, nc, typ) -> + Typ_exist (kopts, constraint_subst sv subst nc, typ_subst sv subst typ) + +and typ_arg_subst sv subst (A_aux (arg, l)) = A_aux (typ_arg_subst_aux sv subst arg, l) +and typ_arg_subst_aux sv subst = function + | A_nexp nexp -> A_nexp (nexp_subst sv subst nexp) + | A_typ typ -> A_typ (typ_subst sv subst typ) + | A_order ord -> A_order (order_subst sv subst ord) + | A_bool nc -> A_bool (constraint_subst sv subst nc) + +let subst_kid subst sv v x = + x + |> subst sv (mk_typ_arg (A_bool (nc_var v))) + |> subst sv (mk_typ_arg (A_nexp (nvar v))) + |> subst sv (mk_typ_arg (A_order (Ord_aux (Ord_var v, Parse_ast.Unknown)))) + |> subst sv (mk_typ_arg (A_typ (mk_typ (Typ_var v)))) let quant_item_subst_kid_aux sv subst = function - | QI_id (KOpt_aux (KOpt_none kid, l)) as qid -> - if Kid.compare kid sv = 0 then QI_id (KOpt_aux (KOpt_none subst, l)) else qid | QI_id (KOpt_aux (KOpt_kind (k, kid), l)) as qid -> if Kid.compare kid sv = 0 then QI_id (KOpt_aux (KOpt_kind (k, subst), l)) else qid - | QI_const nc -> QI_const (nc_subst_nexp sv (Nexp_var subst) nc) + | QI_const nc -> + QI_const (subst_kid constraint_subst sv subst nc) let quant_item_subst_kid sv subst (QI_aux (quant, l)) = QI_aux (quant_item_subst_kid_aux sv subst quant, l) |
