diff options
Diffstat (limited to 'interp/notation.ml')
| -rw-r--r-- | interp/notation.ml | 180 |
1 files changed, 143 insertions, 37 deletions
diff --git a/interp/notation.ml b/interp/notation.ml index 8d05fab63c..286ece6cb6 100644 --- a/interp/notation.ml +++ b/interp/notation.ml @@ -345,11 +345,23 @@ let also_cases_notation_rule_eq (also_cases1,rule1) (also_cases2,rule2) = (* No need in principle to compare also_cases as it is inferred *) also_cases1 = also_cases2 && notation_rule_eq rule1 rule2 +let adjust_application c1 c2 = + match c1, c2 with + | NApp (t1, a1), (NList (_,_,NApp (_, a2),_,_) | NApp (_, a2)) when List.length a1 >= List.length a2 -> + NApp (t1, List.firstn (List.length a2) a1) + | NApp (t1, a1), _ -> + t1 + | _ -> c1 + +let strictly_finer_interpretation_than (_,(_,(vars1,c1),_)) (_,(_,(vars2,c2),_)) = + let c1 = adjust_application c1 c2 in + Notation_ops.strictly_finer_notation_constr (List.map fst vars1, List.map fst vars2) c1 c2 + let keymap_add key interp map = let old = try KeyMap.find key map with Not_found -> [] in - (* In case of re-import, no need to keep the previous copy *) - let old = try List.remove_first (also_cases_notation_rule_eq interp) old with Not_found -> old in - KeyMap.add key (interp :: old) map + (* strictly finer interpretation are kept in front *) + let strictly_finer, rest = List.partition (fun c -> strictly_finer_interpretation_than c interp) old in + KeyMap.add key (strictly_finer @ interp :: rest) map let keymap_remove key interp map = let old = try KeyMap.find key map with Not_found -> [] in @@ -391,6 +403,10 @@ let notation_constr_key = function (* Rem: NApp(NRef ref,[]) stands for @ref *) | NBinderList (_,_,NApp (NRef ref,args),_,_) -> RefKey (canonical_gr ref), AppBoundedNotation (List.length args) | NRef ref -> RefKey(canonical_gr ref), NotAppNotation + | NApp (NList (_,_,NApp (NRef ref,args),_,_), args') -> + RefKey (canonical_gr ref), AppBoundedNotation (List.length args + List.length args') + | NApp (NList (_,_,NApp (_,args),_,_), args') -> + Oth, AppBoundedNotation (List.length args + List.length args') | NApp (_,args) -> Oth, AppBoundedNotation (List.length args) | NList (_,_,NApp (NVar x,_),_,_) when x = Notation_ops.ldots_var -> Oth, AppUnboundedNotation | _ -> Oth, NotAppNotation @@ -1415,12 +1431,12 @@ let check_parsing_override (scopt,ntn) data = function | OnlyParsingData (_,old_data) -> let overridden = not (interpretation_eq data.not_interp old_data.not_interp) in warn_override_if_needed (scopt,ntn) overridden data old_data; - None, not overridden + None | ParsingAndPrintingData (_,on_printing,old_data) -> let overridden = not (interpretation_eq data.not_interp old_data.not_interp) in warn_override_if_needed (scopt,ntn) overridden data old_data; - (if on_printing then Some old_data.not_interp else None), not overridden - | NoParsingData -> None, false + if on_printing then Some old_data.not_interp else None + | NoParsingData -> None let check_printing_override (scopt,ntn) data parsingdata printingdata = let parsing_update = match parsingdata with @@ -1449,15 +1465,15 @@ let update_notation_data (scopt,ntn) use data table = try NotationMap.find ntn table with Not_found -> (NoParsingData, []) in match use with | OnlyParsing -> - let printing_update, exists = check_parsing_override (scopt,ntn) data parsingdata in - NotationMap.add ntn (OnlyParsingData (true,data), printingdata) table, printing_update, exists + let printing_update = check_parsing_override (scopt,ntn) data parsingdata in + NotationMap.add ntn (OnlyParsingData (true,data), printingdata) table, printing_update | ParsingAndPrinting -> - let printing_update, exists = check_parsing_override (scopt,ntn) data parsingdata in - NotationMap.add ntn (ParsingAndPrintingData (true,true,data), printingdata) table, printing_update, exists + let printing_update = check_parsing_override (scopt,ntn) data parsingdata in + NotationMap.add ntn (ParsingAndPrintingData (true,true,data), printingdata) table, printing_update | OnlyPrinting -> let parsingdata, exists = check_printing_override (scopt,ntn) data parsingdata printingdata in let printingdata = if exists then printingdata else (true,data) :: printingdata in - NotationMap.add ntn (parsingdata, printingdata) table, None, exists + NotationMap.add ntn (parsingdata, printingdata) table, None let rec find_interpretation ntn find = function | [] -> raise Not_found @@ -1730,23 +1746,22 @@ let declare_notation (scopt,ntn) pat df ~use ~also_in_cases_pattern coe deprecat not_location = df; not_deprecation = deprecation; } in - let notation_update,printing_update, exists = update_notation_data (scopt,ntn) use notdata sc.notations in - if not exists then - let sc = { sc with notations = notation_update } in - scope_map := String.Map.add scope sc !scope_map; + let notation_update,printing_update = update_notation_data (scopt,ntn) use notdata sc.notations in + let sc = { sc with notations = notation_update } in + scope_map := String.Map.add scope sc !scope_map; (* Update the uninterpretation cache *) begin match printing_update with | Some pat -> remove_uninterpretation (NotationRule (scopt,ntn)) also_in_cases_pattern pat | None -> () end; - if not exists && use <> OnlyParsing then declare_uninterpretation ~also_in_cases_pattern (NotationRule (scopt,ntn)) pat; + if use <> OnlyParsing then declare_uninterpretation ~also_in_cases_pattern (NotationRule (scopt,ntn)) pat; (* Register visibility of lonely notations *) - if not exists then begin match scopt with + begin match scopt with | LastLonelyNotation -> scope_stack := LonelyNotationItem ntn :: !scope_stack | NotationInScope _ -> () end; (* Declare a possible coercion *) - if not exists then begin match coe with + begin match coe with | Some (IsEntryCoercion entry) -> let (_,level,_) = level_of_notation ntn in let level = match fst ntn with @@ -2035,12 +2050,12 @@ type symbol = | Break of int let rec symbol_eq s1 s2 = match s1, s2 with -| Terminal s1, Terminal s2 -> String.equal s1 s2 -| NonTerminal id1, NonTerminal id2 -> Id.equal id1 id2 -| SProdList (id1, l1), SProdList (id2, l2) -> - Id.equal id1 id2 && List.equal symbol_eq l1 l2 -| Break i1, Break i2 -> Int.equal i1 i2 -| _ -> false + | Terminal s1, Terminal s2 -> String.equal s1 s2 + | NonTerminal id1, NonTerminal id2 -> Id.equal id1 id2 + | SProdList (id1, l1), SProdList (id2, l2) -> + Id.equal id1 id2 && List.equal symbol_eq l1 l2 + | Break i1, Break i2 -> Int.equal i1 i2 + | _ -> false let rec string_of_symbol = function | NonTerminal _ -> ["_"] @@ -2202,23 +2217,114 @@ let rec raw_analyze_notation_tokens = function | WhiteSpace n :: sl -> Break n :: raw_analyze_notation_tokens sl -let decompose_raw_notation ntn = raw_analyze_notation_tokens (split_notation_string ntn) - -let possible_notations ntn = +let rec raw_analyze_anonymous_notation_tokens = function + | [] -> [] + | String ".." :: sl -> NonTerminal Notation_ops.ldots_var :: raw_analyze_anonymous_notation_tokens sl + | String "_" :: sl -> NonTerminal (Id.of_string "dummy") :: raw_analyze_anonymous_notation_tokens sl + | String s :: sl -> + Terminal (String.drop_simple_quotes s) :: raw_analyze_anonymous_notation_tokens sl + | WhiteSpace n :: sl -> raw_analyze_anonymous_notation_tokens sl + +(* Interpret notations with a recursive component *) + +let out_nt = function NonTerminal x -> x | _ -> assert false + +let msg_expected_form_of_recursive_notation = + "In the notation, the special symbol \"..\" must occur in\na configuration of the form \"x symbs .. symbs y\"." + +let rec find_pattern nt xl = function + | Break n as x :: l, Break n' :: l' when Int.equal n n' -> + find_pattern nt (x::xl) (l,l') + | Terminal s as x :: l, Terminal s' :: l' when String.equal s s' -> + find_pattern nt (x::xl) (l,l') + | [], NonTerminal x' :: l' -> + (out_nt nt,x',List.rev xl),l' + | _, Break s :: _ | Break s :: _, _ -> + user_err Pp.(str ("A break occurs on one side of \"..\" but not on the other side.")) + | _, Terminal s :: _ | Terminal s :: _, _ -> + user_err ~hdr:"Metasyntax.find_pattern" + (str "The token \"" ++ str s ++ str "\" occurs on one side of \"..\" but not on the other side.") + | _, [] -> + user_err Pp.(str msg_expected_form_of_recursive_notation) + | ((SProdList _ | NonTerminal _) :: _), _ | _, (SProdList _ :: _) -> + anomaly (Pp.str "Only Terminal or Break expected on left, non-SProdList on right.") + +let rec interp_list_parser hd = function + | [] -> [], List.rev hd + | NonTerminal id :: tl when Id.equal id Notation_ops.ldots_var -> + if List.is_empty hd then user_err Pp.(str msg_expected_form_of_recursive_notation); + let hd = List.rev hd in + let ((x,y,sl),tl') = find_pattern (List.hd hd) [] (List.tl hd,tl) in + let xyl,tl'' = interp_list_parser [] tl' in + (* We remember each pair of variable denoting a recursive part to *) + (* remove the second copy of it afterwards *) + (x,y)::xyl, SProdList (x,sl) :: tl'' + | (Terminal _ | Break _) as s :: tl -> + if List.is_empty hd then + let yl,tl' = interp_list_parser [] tl in + yl, s :: tl' + else + interp_list_parser (s::hd) tl + | NonTerminal _ as x :: tl -> + let xyl,tl' = interp_list_parser [x] tl in + xyl, List.rev_append hd tl' + | SProdList _ :: _ -> anomaly (Pp.str "Unexpected SProdList in interp_list_parser.") + +let get_notation_vars l = + List.map_filter (function NonTerminal id | SProdList (id,_) -> Some id | _ -> None) l + +let decompose_raw_notation ntn = + let l = split_notation_string ntn in + let l = raw_analyze_notation_tokens l in + let recvars,l = interp_list_parser [] l in + let vars = get_notation_vars l in + recvars, vars, l + +let interpret_notation_string ntn = (* We collect the possible interpretations of a notation string depending on whether it is in "x 'U' y" or "_ U _" format *) let toks = split_notation_string ntn in - if List.exists (function String "_" -> true | _ -> false) toks then - (* Only "_ U _" format *) - [ntn] - else - let _,ntn' = make_notation_key None (raw_analyze_notation_tokens toks) in - if String.equal ntn ntn' then (* Only symbols *) [ntn] else [ntn;ntn'] + let toks = + if + List.exists (function String "_" -> true | _ -> false) toks || + List.for_all (function String id -> Id.is_valid id | _ -> false) toks + then + (* Only "_ U _" format *) + raw_analyze_anonymous_notation_tokens toks + else + (* Includes the case of only a subset of tokens or an "x 'U' y"-style format *) + raw_analyze_notation_tokens toks + in + let _,toks = interp_list_parser [] toks in + let _,ntn' = make_notation_key None toks in + ntn' + +(* Tell if a non-recursive notation is an instance of a recursive one *) +let is_approximation ntn ntn' = + let rec aux toks1 toks2 = match (toks1, toks2) with + | Terminal s1 :: toks1, Terminal s2 :: toks2 -> String.equal s1 s2 && aux toks1 toks2 + | NonTerminal _ :: toks1, NonTerminal _ :: toks2 -> aux toks1 toks2 + | SProdList (_,l1) :: toks1, SProdList (_, l2) :: toks2 -> aux l1 l2 && aux toks1 toks2 + | NonTerminal _ :: toks1, SProdList (_,l2) :: toks2 -> aux' toks1 l2 l2 toks2 || aux toks1 toks2 + | [], [] -> true + | (Break _ :: _, _) | (_, Break _ :: _) -> assert false + | (Terminal _ | NonTerminal _ | SProdList _) :: _, _ -> false + | [], _ -> false + and aux' toks1 l2 l2full toks2 = match (toks1, l2) with + | Terminal s1 :: toks1, Terminal s2 :: l2 when String.equal s1 s2 -> aux' toks1 l2 l2full toks2 + | NonTerminal _ :: toks1, [] -> aux' toks1 l2full l2full toks2 || aux toks1 toks2 + | _ -> false + in + let _,toks = interp_list_parser [] (raw_analyze_anonymous_notation_tokens (split_notation_string ntn)) in + let _,toks' = interp_list_parser [] (raw_analyze_anonymous_notation_tokens (split_notation_string ntn')) in + aux toks toks' let browse_notation strict ntn map = - let ntns = possible_notations ntn in - let find (from,ntn' as fullntn') ntn = - if String.contains ntn ' ' then String.equal ntn ntn' + let ntn = interpret_notation_string ntn in + let find (from,ntn' as fullntn') = + if String.contains ntn ' ' then + if String.string_contains ~where:ntn' ~what:".." then is_approximation ntn ntn' + else String.equal ntn ntn' else let _,toks = decompose_notation_key fullntn' in let get_terminals = function Terminal ntn -> Some ntn | _ -> None in @@ -2230,7 +2336,7 @@ let browse_notation strict ntn map = String.Map.fold (fun scope_name sc -> NotationMap.fold (fun ntn data l -> - if List.exists (find ntn) ntns + if find ntn then List.map (fun d -> (ntn,scope_name,d)) (extract_notation_data data) @ l else l) sc.notations) map [] in |
