diff options
Diffstat (limited to 'interp')
| -rw-r--r-- | interp/constrexpr.ml | 3 | ||||
| -rw-r--r-- | interp/constrintern.ml | 30 | ||||
| -rw-r--r-- | interp/constrintern.mli | 5 | ||||
| -rw-r--r-- | interp/notation.ml | 180 | ||||
| -rw-r--r-- | interp/notation.mli | 6 | ||||
| -rw-r--r-- | interp/notation_ops.ml | 271 | ||||
| -rw-r--r-- | interp/notation_ops.mli | 5 |
7 files changed, 372 insertions, 128 deletions
diff --git a/interp/constrexpr.ml b/interp/constrexpr.ml index 235310660b..977cbbccf2 100644 --- a/interp/constrexpr.ml +++ b/interp/constrexpr.ml @@ -15,8 +15,11 @@ open Libnames (** [constr_expr] is the abstract syntax tree produced by the parser *) type universe_decl_expr = (lident list, Glob_term.glob_constraint list) UState.gen_universe_decl +type cumul_univ_decl_expr = + ((lident * Univ.Variance.t option) list, Glob_term.glob_constraint list) UState.gen_universe_decl type ident_decl = lident * universe_decl_expr option +type cumul_ident_decl = lident * cumul_univ_decl_expr option type name_decl = lname * universe_decl_expr option type notation_with_optional_scope = LastLonelyNotation | NotationInScope of string diff --git a/interp/constrintern.ml b/interp/constrintern.ml index c4b5bf8984..c7ed066f7e 100644 --- a/interp/constrintern.ml +++ b/interp/constrintern.ml @@ -2413,8 +2413,9 @@ let internalize globalenv env pattern_mode (_, ntnvars as lvar) c = and intern_args env subscopes = function | [] -> [] | a::args -> - let (enva,subscopes) = apply_scope_env env subscopes in - (intern_no_implicit enva a) :: (intern_args env subscopes args) + let (enva,subscopes) = apply_scope_env env subscopes in + let a = intern_no_implicit enva a in + a :: (intern_args env subscopes args) in intern env c @@ -2652,13 +2653,34 @@ let interp_univ_decl env decl = let binders : lident list = decl.univdecl_instance in let evd = Evd.from_env ~binders env in let evd, cstrs = interp_univ_constraints env evd decl.univdecl_constraints in - let decl = { univdecl_instance = binders; + let decl = { + univdecl_instance = binders; univdecl_extensible_instance = decl.univdecl_extensible_instance; univdecl_constraints = cstrs; - univdecl_extensible_constraints = decl.univdecl_extensible_constraints } + univdecl_extensible_constraints = decl.univdecl_extensible_constraints; + } in evd, decl +let interp_cumul_univ_decl env decl = + let open UState in + let binders = List.map fst decl.univdecl_instance in + let variances = Array.map_of_list snd decl.univdecl_instance in + let evd = Evd.from_ctx (UState.from_env ~binders env) in + let evd, cstrs = interp_univ_constraints env evd decl.univdecl_constraints in + let decl = { + univdecl_instance = binders; + univdecl_extensible_instance = decl.univdecl_extensible_instance; + univdecl_constraints = cstrs; + univdecl_extensible_constraints = decl.univdecl_extensible_constraints; + } + in + evd, decl, variances + let interp_univ_decl_opt env l = match l with | None -> Evd.from_env env, UState.default_univ_decl | Some decl -> interp_univ_decl env decl + +let interp_cumul_univ_decl_opt env = function + | None -> Evd.from_env env, UState.default_univ_decl, [| |] + | Some decl -> interp_cumul_univ_decl env decl diff --git a/interp/constrintern.mli b/interp/constrintern.mli index 9037ed5414..0de6c3e89d 100644 --- a/interp/constrintern.mli +++ b/interp/constrintern.mli @@ -204,3 +204,8 @@ val interp_univ_decl : Environ.env -> universe_decl_expr -> val interp_univ_decl_opt : Environ.env -> universe_decl_expr option -> Evd.evar_map * UState.universe_decl + +val interp_cumul_univ_decl_opt : Environ.env -> cumul_univ_decl_expr option -> + Evd.evar_map * UState.universe_decl * Entries.variance_entry +(** BEWARE the variance entry needs to be adjusted by + [ComInductive.variance_of_entry] if the instance is extensible. *) diff --git a/interp/notation.ml b/interp/notation.ml index 8d05fab63c..286ece6cb6 100644 --- a/interp/notation.ml +++ b/interp/notation.ml @@ -345,11 +345,23 @@ let also_cases_notation_rule_eq (also_cases1,rule1) (also_cases2,rule2) = (* No need in principle to compare also_cases as it is inferred *) also_cases1 = also_cases2 && notation_rule_eq rule1 rule2 +let adjust_application c1 c2 = + match c1, c2 with + | NApp (t1, a1), (NList (_,_,NApp (_, a2),_,_) | NApp (_, a2)) when List.length a1 >= List.length a2 -> + NApp (t1, List.firstn (List.length a2) a1) + | NApp (t1, a1), _ -> + t1 + | _ -> c1 + +let strictly_finer_interpretation_than (_,(_,(vars1,c1),_)) (_,(_,(vars2,c2),_)) = + let c1 = adjust_application c1 c2 in + Notation_ops.strictly_finer_notation_constr (List.map fst vars1, List.map fst vars2) c1 c2 + let keymap_add key interp map = let old = try KeyMap.find key map with Not_found -> [] in - (* In case of re-import, no need to keep the previous copy *) - let old = try List.remove_first (also_cases_notation_rule_eq interp) old with Not_found -> old in - KeyMap.add key (interp :: old) map + (* strictly finer interpretation are kept in front *) + let strictly_finer, rest = List.partition (fun c -> strictly_finer_interpretation_than c interp) old in + KeyMap.add key (strictly_finer @ interp :: rest) map let keymap_remove key interp map = let old = try KeyMap.find key map with Not_found -> [] in @@ -391,6 +403,10 @@ let notation_constr_key = function (* Rem: NApp(NRef ref,[]) stands for @ref *) | NBinderList (_,_,NApp (NRef ref,args),_,_) -> RefKey (canonical_gr ref), AppBoundedNotation (List.length args) | NRef ref -> RefKey(canonical_gr ref), NotAppNotation + | NApp (NList (_,_,NApp (NRef ref,args),_,_), args') -> + RefKey (canonical_gr ref), AppBoundedNotation (List.length args + List.length args') + | NApp (NList (_,_,NApp (_,args),_,_), args') -> + Oth, AppBoundedNotation (List.length args + List.length args') | NApp (_,args) -> Oth, AppBoundedNotation (List.length args) | NList (_,_,NApp (NVar x,_),_,_) when x = Notation_ops.ldots_var -> Oth, AppUnboundedNotation | _ -> Oth, NotAppNotation @@ -1415,12 +1431,12 @@ let check_parsing_override (scopt,ntn) data = function | OnlyParsingData (_,old_data) -> let overridden = not (interpretation_eq data.not_interp old_data.not_interp) in warn_override_if_needed (scopt,ntn) overridden data old_data; - None, not overridden + None | ParsingAndPrintingData (_,on_printing,old_data) -> let overridden = not (interpretation_eq data.not_interp old_data.not_interp) in warn_override_if_needed (scopt,ntn) overridden data old_data; - (if on_printing then Some old_data.not_interp else None), not overridden - | NoParsingData -> None, false + if on_printing then Some old_data.not_interp else None + | NoParsingData -> None let check_printing_override (scopt,ntn) data parsingdata printingdata = let parsing_update = match parsingdata with @@ -1449,15 +1465,15 @@ let update_notation_data (scopt,ntn) use data table = try NotationMap.find ntn table with Not_found -> (NoParsingData, []) in match use with | OnlyParsing -> - let printing_update, exists = check_parsing_override (scopt,ntn) data parsingdata in - NotationMap.add ntn (OnlyParsingData (true,data), printingdata) table, printing_update, exists + let printing_update = check_parsing_override (scopt,ntn) data parsingdata in + NotationMap.add ntn (OnlyParsingData (true,data), printingdata) table, printing_update | ParsingAndPrinting -> - let printing_update, exists = check_parsing_override (scopt,ntn) data parsingdata in - NotationMap.add ntn (ParsingAndPrintingData (true,true,data), printingdata) table, printing_update, exists + let printing_update = check_parsing_override (scopt,ntn) data parsingdata in + NotationMap.add ntn (ParsingAndPrintingData (true,true,data), printingdata) table, printing_update | OnlyPrinting -> let parsingdata, exists = check_printing_override (scopt,ntn) data parsingdata printingdata in let printingdata = if exists then printingdata else (true,data) :: printingdata in - NotationMap.add ntn (parsingdata, printingdata) table, None, exists + NotationMap.add ntn (parsingdata, printingdata) table, None let rec find_interpretation ntn find = function | [] -> raise Not_found @@ -1730,23 +1746,22 @@ let declare_notation (scopt,ntn) pat df ~use ~also_in_cases_pattern coe deprecat not_location = df; not_deprecation = deprecation; } in - let notation_update,printing_update, exists = update_notation_data (scopt,ntn) use notdata sc.notations in - if not exists then - let sc = { sc with notations = notation_update } in - scope_map := String.Map.add scope sc !scope_map; + let notation_update,printing_update = update_notation_data (scopt,ntn) use notdata sc.notations in + let sc = { sc with notations = notation_update } in + scope_map := String.Map.add scope sc !scope_map; (* Update the uninterpretation cache *) begin match printing_update with | Some pat -> remove_uninterpretation (NotationRule (scopt,ntn)) also_in_cases_pattern pat | None -> () end; - if not exists && use <> OnlyParsing then declare_uninterpretation ~also_in_cases_pattern (NotationRule (scopt,ntn)) pat; + if use <> OnlyParsing then declare_uninterpretation ~also_in_cases_pattern (NotationRule (scopt,ntn)) pat; (* Register visibility of lonely notations *) - if not exists then begin match scopt with + begin match scopt with | LastLonelyNotation -> scope_stack := LonelyNotationItem ntn :: !scope_stack | NotationInScope _ -> () end; (* Declare a possible coercion *) - if not exists then begin match coe with + begin match coe with | Some (IsEntryCoercion entry) -> let (_,level,_) = level_of_notation ntn in let level = match fst ntn with @@ -2035,12 +2050,12 @@ type symbol = | Break of int let rec symbol_eq s1 s2 = match s1, s2 with -| Terminal s1, Terminal s2 -> String.equal s1 s2 -| NonTerminal id1, NonTerminal id2 -> Id.equal id1 id2 -| SProdList (id1, l1), SProdList (id2, l2) -> - Id.equal id1 id2 && List.equal symbol_eq l1 l2 -| Break i1, Break i2 -> Int.equal i1 i2 -| _ -> false + | Terminal s1, Terminal s2 -> String.equal s1 s2 + | NonTerminal id1, NonTerminal id2 -> Id.equal id1 id2 + | SProdList (id1, l1), SProdList (id2, l2) -> + Id.equal id1 id2 && List.equal symbol_eq l1 l2 + | Break i1, Break i2 -> Int.equal i1 i2 + | _ -> false let rec string_of_symbol = function | NonTerminal _ -> ["_"] @@ -2202,23 +2217,114 @@ let rec raw_analyze_notation_tokens = function | WhiteSpace n :: sl -> Break n :: raw_analyze_notation_tokens sl -let decompose_raw_notation ntn = raw_analyze_notation_tokens (split_notation_string ntn) - -let possible_notations ntn = +let rec raw_analyze_anonymous_notation_tokens = function + | [] -> [] + | String ".." :: sl -> NonTerminal Notation_ops.ldots_var :: raw_analyze_anonymous_notation_tokens sl + | String "_" :: sl -> NonTerminal (Id.of_string "dummy") :: raw_analyze_anonymous_notation_tokens sl + | String s :: sl -> + Terminal (String.drop_simple_quotes s) :: raw_analyze_anonymous_notation_tokens sl + | WhiteSpace n :: sl -> raw_analyze_anonymous_notation_tokens sl + +(* Interpret notations with a recursive component *) + +let out_nt = function NonTerminal x -> x | _ -> assert false + +let msg_expected_form_of_recursive_notation = + "In the notation, the special symbol \"..\" must occur in\na configuration of the form \"x symbs .. symbs y\"." + +let rec find_pattern nt xl = function + | Break n as x :: l, Break n' :: l' when Int.equal n n' -> + find_pattern nt (x::xl) (l,l') + | Terminal s as x :: l, Terminal s' :: l' when String.equal s s' -> + find_pattern nt (x::xl) (l,l') + | [], NonTerminal x' :: l' -> + (out_nt nt,x',List.rev xl),l' + | _, Break s :: _ | Break s :: _, _ -> + user_err Pp.(str ("A break occurs on one side of \"..\" but not on the other side.")) + | _, Terminal s :: _ | Terminal s :: _, _ -> + user_err ~hdr:"Metasyntax.find_pattern" + (str "The token \"" ++ str s ++ str "\" occurs on one side of \"..\" but not on the other side.") + | _, [] -> + user_err Pp.(str msg_expected_form_of_recursive_notation) + | ((SProdList _ | NonTerminal _) :: _), _ | _, (SProdList _ :: _) -> + anomaly (Pp.str "Only Terminal or Break expected on left, non-SProdList on right.") + +let rec interp_list_parser hd = function + | [] -> [], List.rev hd + | NonTerminal id :: tl when Id.equal id Notation_ops.ldots_var -> + if List.is_empty hd then user_err Pp.(str msg_expected_form_of_recursive_notation); + let hd = List.rev hd in + let ((x,y,sl),tl') = find_pattern (List.hd hd) [] (List.tl hd,tl) in + let xyl,tl'' = interp_list_parser [] tl' in + (* We remember each pair of variable denoting a recursive part to *) + (* remove the second copy of it afterwards *) + (x,y)::xyl, SProdList (x,sl) :: tl'' + | (Terminal _ | Break _) as s :: tl -> + if List.is_empty hd then + let yl,tl' = interp_list_parser [] tl in + yl, s :: tl' + else + interp_list_parser (s::hd) tl + | NonTerminal _ as x :: tl -> + let xyl,tl' = interp_list_parser [x] tl in + xyl, List.rev_append hd tl' + | SProdList _ :: _ -> anomaly (Pp.str "Unexpected SProdList in interp_list_parser.") + +let get_notation_vars l = + List.map_filter (function NonTerminal id | SProdList (id,_) -> Some id | _ -> None) l + +let decompose_raw_notation ntn = + let l = split_notation_string ntn in + let l = raw_analyze_notation_tokens l in + let recvars,l = interp_list_parser [] l in + let vars = get_notation_vars l in + recvars, vars, l + +let interpret_notation_string ntn = (* We collect the possible interpretations of a notation string depending on whether it is in "x 'U' y" or "_ U _" format *) let toks = split_notation_string ntn in - if List.exists (function String "_" -> true | _ -> false) toks then - (* Only "_ U _" format *) - [ntn] - else - let _,ntn' = make_notation_key None (raw_analyze_notation_tokens toks) in - if String.equal ntn ntn' then (* Only symbols *) [ntn] else [ntn;ntn'] + let toks = + if + List.exists (function String "_" -> true | _ -> false) toks || + List.for_all (function String id -> Id.is_valid id | _ -> false) toks + then + (* Only "_ U _" format *) + raw_analyze_anonymous_notation_tokens toks + else + (* Includes the case of only a subset of tokens or an "x 'U' y"-style format *) + raw_analyze_notation_tokens toks + in + let _,toks = interp_list_parser [] toks in + let _,ntn' = make_notation_key None toks in + ntn' + +(* Tell if a non-recursive notation is an instance of a recursive one *) +let is_approximation ntn ntn' = + let rec aux toks1 toks2 = match (toks1, toks2) with + | Terminal s1 :: toks1, Terminal s2 :: toks2 -> String.equal s1 s2 && aux toks1 toks2 + | NonTerminal _ :: toks1, NonTerminal _ :: toks2 -> aux toks1 toks2 + | SProdList (_,l1) :: toks1, SProdList (_, l2) :: toks2 -> aux l1 l2 && aux toks1 toks2 + | NonTerminal _ :: toks1, SProdList (_,l2) :: toks2 -> aux' toks1 l2 l2 toks2 || aux toks1 toks2 + | [], [] -> true + | (Break _ :: _, _) | (_, Break _ :: _) -> assert false + | (Terminal _ | NonTerminal _ | SProdList _) :: _, _ -> false + | [], _ -> false + and aux' toks1 l2 l2full toks2 = match (toks1, l2) with + | Terminal s1 :: toks1, Terminal s2 :: l2 when String.equal s1 s2 -> aux' toks1 l2 l2full toks2 + | NonTerminal _ :: toks1, [] -> aux' toks1 l2full l2full toks2 || aux toks1 toks2 + | _ -> false + in + let _,toks = interp_list_parser [] (raw_analyze_anonymous_notation_tokens (split_notation_string ntn)) in + let _,toks' = interp_list_parser [] (raw_analyze_anonymous_notation_tokens (split_notation_string ntn')) in + aux toks toks' let browse_notation strict ntn map = - let ntns = possible_notations ntn in - let find (from,ntn' as fullntn') ntn = - if String.contains ntn ' ' then String.equal ntn ntn' + let ntn = interpret_notation_string ntn in + let find (from,ntn' as fullntn') = + if String.contains ntn ' ' then + if String.string_contains ~where:ntn' ~what:".." then is_approximation ntn ntn' + else String.equal ntn ntn' else let _,toks = decompose_notation_key fullntn' in let get_terminals = function Terminal ntn -> Some ntn | _ -> None in @@ -2230,7 +2336,7 @@ let browse_notation strict ntn map = String.Map.fold (fun scope_name sc -> NotationMap.fold (fun ntn data l -> - if List.exists (find ntn) ntns + if find ntn then List.map (fun d -> (ntn,scope_name,d)) (extract_notation_data data) @ l else l) sc.notations) map [] in diff --git a/interp/notation.mli b/interp/notation.mli index b8939ff87b..97955bf92e 100644 --- a/interp/notation.mli +++ b/interp/notation.mli @@ -334,8 +334,10 @@ val symbol_eq : symbol -> symbol -> bool val make_notation_key : notation_entry -> symbol list -> notation val decompose_notation_key : notation -> notation_entry * symbol list -(** Decompose a notation of the form "a 'U' b" *) -val decompose_raw_notation : string -> symbol list +(** Decompose a notation of the form "a 'U' b" together with the lists + of pairs of recursive variables and the list of all variables + binding in the notation *) +val decompose_raw_notation : string -> (Id.t * Id.t) list * Id.t list * symbol list (** Prints scopes (expects a pure aconstr printer) *) val pr_scope_class : scope_class -> Pp.t diff --git a/interp/notation_ops.ml b/interp/notation_ops.ml index 2e3fa0aa0e..d393dcaecb 100644 --- a/interp/notation_ops.ml +++ b/interp/notation_ops.ml @@ -24,82 +24,182 @@ open Notation_term (**********************************************************************) (* Utilities *) -(* helper for NVar, NVar case in eq_notation_constr *) -let get_var_ndx id vs = try Some (List.index Id.equal id vs) with Not_found -> None - -let rec eq_notation_constr (vars1,vars2 as vars) t1 t2 = -(vars1 == vars2 && t1 == t2) || -match t1, t2 with -| NRef gr1, NRef gr2 -> GlobRef.equal gr1 gr2 -| NVar id1, NVar id2 -> ( - match (get_var_ndx id1 vars1,get_var_ndx id2 vars2) with - | Some n,Some m -> Int.equal n m - | None ,None -> Id.equal id1 id2 - | _ -> false) -| NApp (t1, a1), NApp (t2, a2) -> - (eq_notation_constr vars) t1 t2 && List.equal (eq_notation_constr vars) a1 a2 -| NHole (_, _, _), NHole (_, _, _) -> true (* FIXME? *) -| NList (i1, j1, t1, u1, b1), NList (i2, j2, t2, u2, b2) -> - Id.equal i1 i2 && Id.equal j1 j2 && (eq_notation_constr vars) t1 t2 && - (eq_notation_constr vars) u1 u2 && b1 == b2 -| NLambda (na1, t1, u1), NLambda (na2, t2, u2) -> - Name.equal na1 na2 && (eq_notation_constr vars) t1 t2 && (eq_notation_constr vars) u1 u2 -| NProd (na1, t1, u1), NProd (na2, t2, u2) -> - Name.equal na1 na2 && (eq_notation_constr vars) t1 t2 && (eq_notation_constr vars) u1 u2 -| NBinderList (i1, j1, t1, u1, b1), NBinderList (i2, j2, t2, u2, b2) -> - Id.equal i1 i2 && Id.equal j1 j2 && (eq_notation_constr vars) t1 t2 && - (eq_notation_constr vars) u1 u2 && b1 == b2 -| NLetIn (na1, b1, t1, u1), NLetIn (na2, b2, t2, u2) -> - Name.equal na1 na2 && eq_notation_constr vars b1 b2 && - Option.equal (eq_notation_constr vars) t1 t2 && (eq_notation_constr vars) u1 u2 -| NCases (_, o1, r1, p1), NCases (_, o2, r2, p2) -> (* FIXME? *) - let eqpat (p1, t1) (p2, t2) = - List.equal cases_pattern_eq p1 p2 && - (eq_notation_constr vars) t1 t2 - in - let eqf (t1, (na1, o1)) (t2, (na2, o2)) = - let eq (i1, n1) (i2, n2) = Ind.CanOrd.equal i1 i2 && List.equal Name.equal n1 n2 in - (eq_notation_constr vars) t1 t2 && Name.equal na1 na2 && Option.equal eq o1 o2 - in - Option.equal (eq_notation_constr vars) o1 o2 && - List.equal eqf r1 r2 && - List.equal eqpat p1 p2 -| NLetTuple (nas1, (na1, o1), t1, u1), NLetTuple (nas2, (na2, o2), t2, u2) -> - List.equal Name.equal nas1 nas2 && - Name.equal na1 na2 && - Option.equal (eq_notation_constr vars) o1 o2 && - (eq_notation_constr vars) t1 t2 && - (eq_notation_constr vars) u1 u2 -| NIf (t1, (na1, o1), u1, r1), NIf (t2, (na2, o2), u2, r2) -> - (eq_notation_constr vars) t1 t2 && - Name.equal na1 na2 && - Option.equal (eq_notation_constr vars) o1 o2 && - (eq_notation_constr vars) u1 u2 && - (eq_notation_constr vars) r1 r2 -| NRec (_, ids1, ts1, us1, rs1), NRec (_, ids2, ts2, us2, rs2) -> (* FIXME? *) - let eq (na1, o1, t1) (na2, o2, t2) = - Name.equal na1 na2 && - Option.equal (eq_notation_constr vars) o1 o2 && - (eq_notation_constr vars) t1 t2 - in - Array.equal Id.equal ids1 ids2 && - Array.equal (List.equal eq) ts1 ts2 && - Array.equal (eq_notation_constr vars) us1 us2 && - Array.equal (eq_notation_constr vars) rs1 rs2 -| NSort s1, NSort s2 -> - glob_sort_eq s1 s2 -| NCast (t1, c1), NCast (t2, c2) -> - (eq_notation_constr vars) t1 t2 && cast_type_eq (eq_notation_constr vars) c1 c2 -| NInt i1, NInt i2 -> - Uint63.equal i1 i2 -| NFloat f1, NFloat f2 -> - Float64.equal f1 f2 -| NArray(t1,def1,ty1), NArray(t2,def2,ty2) -> - Array.equal (eq_notation_constr vars) t1 t2 && (eq_notation_constr vars) def1 def2 - && eq_notation_constr vars ty1 ty2 -| (NRef _ | NVar _ | NApp _ | NHole _ | NList _ | NLambda _ | NProd _ - | NBinderList _ | NLetIn _ | NCases _ | NLetTuple _ | NIf _ - | NRec _ | NSort _ | NCast _ | NInt _ | NFloat _ | NArray _), _ -> false +let ldots_var = Id.of_string ".." + +let rec alpha_var id1 id2 = function + | (i1,i2)::_ when Id.equal i1 id1 -> Id.equal i2 id2 + | (i1,i2)::_ when Id.equal i2 id2 -> Id.equal i1 id1 + | _::idl -> alpha_var id1 id2 idl + | [] -> Id.equal id1 id2 + +let cast_type_iter2 f t1 t2 = match t1, t2 with + | CastConv t1, CastConv t2 -> f t1 t2 + | CastVM t1, CastVM t2 -> f t1 t2 + | CastCoerce, CastCoerce -> () + | CastNative t1, CastNative t2 -> f t1 t2 + | (CastConv _ | CastVM _ | CastCoerce | CastNative _), _ -> raise Exit + +(* used to update the notation variable with the local variables used + in NList and NBinderList, since the iterator has its own variable *) +let replace_var i j var = j :: List.remove Id.equal i var + +(* When [lt] is [true], tell if [t1] is a strict refinement of [t2] + (this is a partial order, so returning [false] does not mean that + [t2] is finer than [t1]); when [lt] is false, tell if [t1] is the + same pattern as [t2] *) + +let compare_notation_constr lt (vars1,vars2) t1 t2 = + (* this is used to reason up to order of notation variables *) + let alphameta = ref [] in + (* this becomes true when at least one subterm is detected as strictly smaller *) + let strictly_lt = ref false in + (* this is the stack of inner of iter patterns for comparison with a + new iteration or the tail of a recursive pattern *) + let tail = ref [] in + let check_alphameta id1 id2 = + try if not (Id.equal (List.assoc id1 !alphameta) id2) then raise Exit + with Not_found -> + if (List.mem_assoc id1 !alphameta) then raise Exit; + alphameta := (id1,id2) :: !alphameta in + let check_eq_id (vars1,vars2) renaming id1 id2 = + let ismeta1 = List.mem_f Id.equal id1 vars1 in + let ismeta2 = List.mem_f Id.equal id2 vars2 in + match ismeta1, ismeta2 with + | true, true -> check_alphameta id1 id2 + | false, false -> if not (alpha_var id1 id2 renaming) then raise Exit + | false, true -> + if not lt then raise Exit + else + (* a binder which is not bound in the notation can be + considered as strictly more precise since it prevents the + notation variables in its scope to be bound by this binder; + i.e. it is strictly more precise in the sense that it + covers strictly less patterns than a notation where the + same binder is bound in the notation; this is hawever + disputable *) + strictly_lt := true + | true, false -> if not lt then raise Exit in + let check_eq_name vars renaming na1 na2 = + match na1, na2 with + | Name id1, Name id2 -> check_eq_id vars renaming id1 id2; (id1,id2)::renaming + | Anonymous, Anonymous -> renaming + | Anonymous, Name _ when lt -> renaming + | _ -> raise Exit in + let rec aux (vars1,vars2 as vars) renaming t1 t2 = match t1, t2 with + | NVar id1, NVar id2 when id1 = ldots_var && id2 = ldots_var -> () + | _, NVar id2 when lt && id2 = ldots_var -> tail := t1 :: !tail + | NVar id1, _ when lt && id1 = ldots_var -> tail := t2 :: !tail + | NVar id1, NVar id2 -> check_eq_id vars renaming id1 id2 + | NHole _, NVar id2 when lt && List.mem_f Id.equal id2 vars2 -> () + | NVar id1, NHole (_, _, _) when lt && List.mem_f Id.equal id1 vars1 -> () + | _, NVar id2 when lt && List.mem_f Id.equal id2 vars2 -> strictly_lt := true + | NRef gr1, NRef gr2 when GlobRef.equal gr1 gr2 -> () + | NHole (_, _, _), NHole (_, _, _) -> () (* FIXME? *) + | _, NHole (_, _, _) when lt -> strictly_lt := true + | NList (i1, j1, iter1, tail1, b1), NList (i2, j2, iter2, tail2, b2) + | NBinderList (i1, j1, iter1, tail1, b1), NBinderList (i2, j2, iter2, tail2, b2) -> + if b1 <> b2 then raise Exit; + let vars1 = replace_var i1 j1 vars1 in + let vars2 = replace_var i2 j2 vars2 in + check_alphameta i1 i2; aux (vars1,vars2) renaming iter1 iter2; aux vars renaming tail1 tail2; + | NBinderList (i1, j1, iter1, tail1, b1), NList (i2, j2, iter2, tail2, b2) + | NList (i1, j1, iter1, tail1, b1), NBinderList (i2, j2, iter2, tail2, b2) -> + (* They may overlap on a unique iteration of them *) + let vars1 = replace_var i1 j1 vars1 in + let vars2 = replace_var i2 j2 vars2 in + aux (vars1,vars2) renaming iter1 iter2; + aux vars renaming tail1 tail2 + | t1, NList (i2, j2, iter2, tail2, b2) + | t1, NBinderList (i2, j2, iter2, tail2, b2) when lt -> + (* checking if t1 is a finite iteration of the pattern *) + let vars2 = replace_var i2 j2 vars2 in + aux (vars1,vars2) renaming t1 iter2; + let t1 = List.hd !tail in + tail := List.tl !tail; + (* either matching a new iteration, or matching the tail *) + (try aux vars renaming t1 tail2 with Exit -> aux vars renaming t1 t2) + | NList (i1, j1, iter1, tail1, b1), t2 + | NBinderList (i1, j1, iter1, tail1, b1), t2 when lt -> + (* we see the NList as a single iteration *) + let vars1 = replace_var i1 j1 vars1 in + aux (vars1,vars2) renaming iter1 t2; + let t2 = match !tail with + | t::rest -> tail := rest; t + | _ -> (* ".." is in a discarded fine-grained position *) raise Exit in + (* it had to be a single iteration of iter1 *) + aux vars renaming tail1 t2 + | NApp (t1, a1), NApp (t2, a2) -> aux vars renaming t1 t2; List.iter2 (aux vars renaming) a1 a2 + | NLambda (na1, t1, u1), NLambda (na2, t2, u2) + | NProd (na1, t1, u1), NProd (na2, t2, u2) -> + aux vars renaming t1 t2; + let renaming = check_eq_name vars renaming na1 na2 in + aux vars renaming u1 u2 + | NLetIn (na1, b1, t1, u1), NLetIn (na2, b2, t2, u2) -> + aux vars renaming b1 b2; + Option.iter2 (aux vars renaming) t1 t2;(* TODO : subtyping? *) + let renaming = check_eq_name vars renaming na1 na2 in + aux vars renaming u1 u2 + | NCases (_, o1, r1, p1), NCases (_, o2, r2, p2) -> (* FIXME? *) + let check_pat (p1, t1) (p2, t2) = + if not (List.equal cases_pattern_eq p1 p2) then raise Exit; (* TODO: subtyping and renaming *) + aux vars renaming t1 t2 + in + let eqf renaming (t1, (na1, o1)) (t2, (na2, o2)) = + aux vars renaming t1 t2; + let renaming = check_eq_name vars renaming na1 na2 in + let eq renaming (i1, n1) (i2, n2) = + if not (Ind.CanOrd.equal i1 i2) then raise Exit; + List.fold_left2 (check_eq_name vars) renaming n1 n2 in + Option.fold_left2 eq renaming o1 o2 in + let renaming = List.fold_left2 eqf renaming r1 r2 in + Option.iter2 (aux vars renaming) o1 o2; + List.iter2 check_pat p1 p2 + | NLetTuple (nas1, (na1, o1), t1, u1), NLetTuple (nas2, (na2, o2), t2, u2) -> + aux vars renaming t1 t2; + let renaming = check_eq_name vars renaming na1 na2 in + Option.iter2 (aux vars renaming) o1 o2; + let renaming' = List.fold_left2 (check_eq_name vars) renaming nas1 nas2 in + aux vars renaming' u1 u2 + | NIf (t1, (na1, o1), u1, r1), NIf (t2, (na2, o2), u2, r2) -> + aux vars renaming t1 t2; + aux vars renaming u1 u2; + aux vars renaming r1 r2; + let renaming = check_eq_name vars renaming na1 na2 in + Option.iter2 (aux vars renaming) o1 o2 + | NRec (_, ids1, ts1, us1, rs1), NRec (_, ids2, ts2, us2, rs2) -> (* FIXME? *) + let eq renaming (na1, o1, t1) (na2, o2, t2) = + Option.iter2 (aux vars renaming) o1 o2; + aux vars renaming t1 t2; + check_eq_name vars renaming na1 na2 + in + let renaming = Array.fold_left2 (fun r id1 id2 -> check_eq_id vars r id1 id2; (id1,id2)::r) renaming ids1 ids2 in + let renamings = Array.map2 (List.fold_left2 eq renaming) ts1 ts2 in + Array.iter3 (aux vars) renamings us1 us2; + Array.iter3 (aux vars) (Array.map ((@) renaming) renamings) rs1 rs2 + | NSort s1, NSort s2 when glob_sort_eq s1 s2 -> () + | NCast (t1, c1), NCast (t2, c2) -> + aux vars renaming t1 t2; + cast_type_iter2 (aux vars renaming) c1 c2 + | NInt i1, NInt i2 when Uint63.equal i1 i2 -> () + | NFloat f1, NFloat f2 when Float64.equal f1 f2 -> () + | NArray(t1,def1,ty1), NArray(t2,def2,ty2) -> + Array.iter2 (aux vars renaming) t1 t2; + aux vars renaming def1 def2; + aux vars renaming ty1 ty2 + | (NRef _ | NVar _ | NApp _ | NHole _ | NList _ | NLambda _ | NProd _ + | NBinderList _ | NLetIn _ | NCases _ | NLetTuple _ | NIf _ + | NRec _ | NSort _ | NCast _ | NInt _ | NFloat _ | NArray _), _ -> raise Exit in + try + let _ = aux (vars1,vars2) [] t1 t2 in + if not lt then + (* Check that order of notation metavariables does not matter *) + List.iter2 check_alphameta vars1 vars2; + not lt || !strictly_lt + with Exit | (* Option.iter2: *) Option.Heterogeneous | Invalid_argument _ -> false + +let eq_notation_constr vars t1 t2 = t1 == t2 || compare_notation_constr false vars t1 t2 + +let strictly_finer_notation_constr vars t1 t2 = compare_notation_constr true vars t1 t2 (**********************************************************************) (* Re-interpret a notation as a glob_constr, taking care of binders *) @@ -154,8 +254,6 @@ let rec subst_glob_vars l gc = DAst.map (function | _ -> DAst.get (map_glob_constr (subst_glob_vars l) gc) (* assume: id is not binding *) ) gc -let ldots_var = Id.of_string ".." - type 'a binder_status_fun = { no : 'a -> 'a; restart_prod : 'a -> 'a; @@ -275,6 +373,12 @@ type found_variables = { let add_id r id = r := { !r with vars = id :: (!r).vars } let add_name r = function Anonymous -> () | Name id -> add_id r id +let mkNApp1 (g,a) = + match g with + (* Ensure flattening of nested applicative nodes *) + | NApp (g,args') -> NApp (g,args'@[a]) + | _ -> NApp (g,[a]) + let is_gvar id c = match DAst.get c with | GVar id' -> Id.equal id id' | _ -> false @@ -443,7 +547,10 @@ let notation_constr_and_vars_of_glob_constr recvars a = aux' c and aux' x = DAst.with_val (function | GVar id -> if not (Id.equal id ldots_var) then add_id found id; NVar id - | GApp (g,args) -> NApp (aux g, List.map aux args) + | GApp (g,[]) -> NApp (aux g,[]) (* Encoding @foo *) + | GApp (g,args) -> + (* Treat applicative notes as binary nodes *) + let a,args = List.sep_last args in mkNApp1 (aux (DAst.make (GApp (g, args))), aux a) | GLambda (na,bk,ty,c) -> add_name found na; NLambda (na,aux ty,aux c) | GProd (na,bk,ty,c) -> add_name found na; NProd (na,aux ty,aux c) | GLetIn (na,b,t,c) -> add_name found na; NLetIn (na,aux b,Option.map aux t, aux c) @@ -740,12 +847,6 @@ let is_bindinglist_meta id metas = exception No_match -let rec alpha_var id1 id2 = function - | (i1,i2)::_ when Id.equal i1 id1 -> Id.equal i2 id2 - | (i1,i2)::_ when Id.equal i2 id2 -> Id.equal i1 id1 - | _::idl -> alpha_var id1 id2 idl - | [] -> Id.equal id1 id2 - let alpha_rename alpmetas v = if alpmetas == [] then v else try rename_glob_vars alpmetas v with UnsoundRenaming -> raise No_match diff --git a/interp/notation_ops.mli b/interp/notation_ops.mli index 3cc6f82ec8..3182ea96d7 100644 --- a/interp/notation_ops.mli +++ b/interp/notation_ops.mli @@ -16,6 +16,11 @@ open Glob_term val eq_notation_constr : Id.t list * Id.t list -> notation_constr -> notation_constr -> bool +val strictly_finer_notation_constr : Id.t list * Id.t list -> notation_constr -> notation_constr -> bool +(** Tell if [t1] is a strict refinement of [t2] + (this is a partial order and returning [false] does not mean that + [t2] is finer than [t1]) *) + (** Substitution of kernel names in interpretation data *) val subst_interpretation : |
