aboutsummaryrefslogtreecommitdiff
path: root/interp/notation.ml
diff options
context:
space:
mode:
authorHugo Herbelin2017-11-25 17:19:49 +0100
committerHugo Herbelin2018-07-29 02:40:22 +0200
commit60daf674df3d11fa2948bbc7c9a928c09f22d099 (patch)
tree533584dd6acd3bde940529e8d3a111eca6fcbdef /interp/notation.ml
parent33d86118c7d1bfba31008b410d81c7f45dbdf092 (diff)
Adding support for custom entries in notations.
- New command "Declare Custom Entry bar". - Entries can have levels. - Printing is done using a notion of coercion between grammar entries. This typically corresponds to rules of the form 'Notation "[ x ]" := x (x custom myconstr).' but also 'Notation "{ x }" := x (in custom myconstr, x constr).'. - Rules declaring idents such as 'Notation "x" := x (in custom myconstr, x ident).' are natively recognized. - Rules declaring globals such as 'Notation "x" := x (in custom myconstr, x global).' are natively recognized. Incidentally merging ETConstr and ETConstrAsBinder. Noticed in passing that parsing binder as custom was not done as in constr. Probably some fine-tuning still to do (priority of notations, interactions between scopes and entries, ...). To be tested live further.
Diffstat (limited to 'interp/notation.ml')
-rw-r--r--interp/notation.ml231
1 files changed, 189 insertions, 42 deletions
diff --git a/interp/notation.ml b/interp/notation.ml
index 05fcd0e7f5..625d072b9f 100644
--- a/interp/notation.ml
+++ b/interp/notation.ml
@@ -39,6 +39,30 @@ open Context.Named.Declaration
expression, set this scope to be the current scope
*)
+let notation_entry_eq s1 s2 = match (s1,s2) with
+| InConstrEntry, InConstrEntry -> true
+| InCustomEntry s1, InCustomEntry s2 -> String.equal s1 s2
+| (InConstrEntry | InCustomEntry _), _ -> false
+
+let notation_entry_level_eq s1 s2 = match (s1,s2) with
+| InConstrEntrySomeLevel, InConstrEntrySomeLevel -> true
+| InCustomEntryLevel (s1,n1), InCustomEntryLevel (s2,n2) -> String.equal s1 s2 && n1 = n2
+| (InConstrEntrySomeLevel | InCustomEntryLevel _), _ -> false
+
+let notation_eq (from1,ntn1) (from2,ntn2) =
+ notation_entry_level_eq from1 from2 && String.equal ntn1 ntn2
+
+let pr_notation (from,ntn) = qstring ntn ++ match from with InConstrEntrySomeLevel -> mt () | InCustomEntryLevel (s,n) -> str " in custom " ++ str s
+
+module NotationOrd =
+ struct
+ type t = notation
+ let compare = Pervasives.compare
+ end
+
+module NotationSet = Set.Make(NotationOrd)
+module NotationMap = CMap.Make(NotationOrd)
+
(**********************************************************************)
(* Scope of symbols *)
@@ -51,7 +75,7 @@ type notation_data = {
}
type scope = {
- notations: notation_data String.Map.t;
+ notations: notation_data NotationMap.t;
delimiters: delimiters option
}
@@ -62,7 +86,7 @@ let scope_map = ref String.Map.empty
let delimiters_map = ref String.Map.empty
let empty_scope = {
- notations = String.Map.empty;
+ notations = NotationMap.empty;
delimiters = None
}
@@ -71,6 +95,9 @@ let default_scope = "" (* empty name, not available from outside *)
let init_scope_map () =
scope_map := String.Map.add default_scope empty_scope !scope_map
+(**********************************************************************)
+(* Operations on scopes *)
+
let declare_scope scope =
try let _ = String.Map.find scope !scope_map in ()
with Not_found ->
@@ -101,12 +128,12 @@ let normalize_scope sc =
(**********************************************************************)
(* The global stack of scopes *)
-type scope_elem = Scope of scope_name | SingleNotation of string
+type scope_elem = Scope of scope_name | SingleNotation of notation
type scopes = scope_elem list
let scope_eq s1 s2 = match s1, s2 with
-| Scope s1, Scope s2
-| SingleNotation s1, SingleNotation s2 -> String.equal s1 s2
+| Scope s1, Scope s2 -> String.equal s1 s2
+| SingleNotation s1, SingleNotation s2 -> notation_eq s1 s2
| Scope _, SingleNotation _
| SingleNotation _, Scope _ -> false
@@ -158,8 +185,6 @@ let push_scope sc scopes = Scope sc :: scopes
let push_scopes = List.fold_right push_scope
-type local_scopes = tmp_scope_name option * scope_name list
-
let make_current_scopes (tmp_scope,scopes) =
Option.fold_right push_scope tmp_scope (push_scopes scopes !scope_stack)
@@ -376,7 +401,7 @@ let rec find_without_delimiters find (ntn_scope,ntn) = function
end
| SingleNotation ntn' :: scopes ->
begin match ntn_scope, ntn with
- | None, Some ntn when String.equal ntn ntn' ->
+ | None, Some ntn when notation_eq ntn ntn' ->
Some (None, None)
| _ ->
find_without_delimiters find (ntn_scope,ntn) scopes
@@ -390,7 +415,7 @@ let rec find_without_delimiters find (ntn_scope,ntn) = function
let warn_notation_overridden =
CWarnings.create ~name:"notation-overridden" ~category:"parsing"
(fun (ntn,which_scope) ->
- str "Notation" ++ spc () ++ str ntn ++ spc ()
+ str "Notation" ++ spc () ++ pr_notation ntn ++ spc ()
++ strbrk "was already used" ++ which_scope ++ str ".")
let declare_notation_interpretation ntn scopt pat df ~onlyprint =
@@ -398,7 +423,7 @@ let declare_notation_interpretation ntn scopt pat df ~onlyprint =
let sc = find_scope scope in
if not onlyprint then begin
let () =
- if String.Map.mem ntn sc.notations then
+ if NotationMap.mem ntn sc.notations then
let which_scope = match scopt with
| None -> mt ()
| Some _ -> spc () ++ strbrk "in scope" ++ spc () ++ str scope in
@@ -408,7 +433,7 @@ let declare_notation_interpretation ntn scopt pat df ~onlyprint =
not_interp = pat;
not_location = df;
} in
- let sc = { sc with notations = String.Map.add ntn notdata sc.notations } in
+ let sc = { sc with notations = NotationMap.add ntn notdata sc.notations } in
scope_map := String.Map.add scope sc !scope_map
end;
begin match scopt with
@@ -425,7 +450,7 @@ let rec find_interpretation ntn find = function
| Scope scope :: scopes ->
(try let (pat,df) = find scope in pat,(df,Some scope)
with Not_found -> find_interpretation ntn find scopes)
- | SingleNotation ntn'::scopes when String.equal ntn' ntn ->
+ | SingleNotation ntn'::scopes when notation_eq ntn' ntn ->
(try let (pat,df) = find default_scope in pat,(df,None)
with Not_found ->
(* e.g. because single notation only for constr, not cases_pattern *)
@@ -434,12 +459,12 @@ let rec find_interpretation ntn find = function
find_interpretation ntn find scopes
let find_notation ntn sc =
- let n = String.Map.find ntn (find_scope sc).notations in
+ let n = NotationMap.find ntn (find_scope sc).notations in
(n.not_interp, n.not_location)
let notation_of_prim_token = function
- | Numeral (n,true) -> n
- | Numeral (n,false) -> "- "^n
+ | Numeral (n,true) -> InConstrEntrySomeLevel, n
+ | Numeral (n,false) -> InConstrEntrySomeLevel, "- "^n
| String _ -> raise Not_found
let find_prim_token check_allowed ?loc p sc =
@@ -459,13 +484,13 @@ let find_prim_token check_allowed ?loc p sc =
let interp_prim_token_gen ?loc g p local_scopes =
let scopes = make_current_scopes local_scopes in
- let p_as_ntn = try notation_of_prim_token p with Not_found -> "" in
+ let p_as_ntn = try notation_of_prim_token p with Not_found -> InConstrEntrySomeLevel,"" in
try find_interpretation p_as_ntn (find_prim_token ?loc g p) scopes
with Not_found ->
user_err ?loc ~hdr:"interp_prim_token"
((match p with
| Numeral _ ->
- str "No interpretation for numeral " ++ str (notation_of_prim_token p)
+ str "No interpretation for numeral " ++ pr_notation (notation_of_prim_token p)
| String s -> str "No interpretation for string " ++ qs s) ++ str ".")
let interp_prim_token ?loc =
@@ -490,7 +515,7 @@ let interp_notation ?loc ntn local_scopes =
try find_interpretation ntn (find_notation ntn) scopes
with Not_found ->
user_err ?loc
- (str "Unknown interpretation for notation \"" ++ str ntn ++ str "\".")
+ (str "Unknown interpretation for notation " ++ pr_notation ntn ++ str ".")
let uninterp_notations c =
List.map_append (fun key -> keymap_find key !notations_key_table)
@@ -504,9 +529,125 @@ let uninterp_ind_pattern_notations ind =
let availability_of_notation (ntn_scope,ntn) scopes =
let f scope =
- String.Map.mem ntn (String.Map.find scope !scope_map).notations in
+ NotationMap.mem ntn (String.Map.find scope !scope_map).notations in
find_without_delimiters f (ntn_scope,Some ntn) (make_current_scopes scopes)
+(* We support coercions from a custom entry at some level to an entry
+ at some level (possibly the same), and from and to the constr entry. E.g.:
+
+ Notation "[ expr ]" := expr (expr custom group at level 1).
+ Notation "( x )" := x (in custom group at level 0, x at level 1).
+ Notation "{ x }" := x (in custom group at level 0, x constr).
+
+ Supporting any level is maybe overkill in that coercions are
+ commonly from the lowest level of the source entry to the highest
+ level of the target entry. *)
+
+type entry_coercion = notation list
+
+module EntryCoercionOrd =
+ struct
+ type t = notation_entry * notation_entry
+ let compare = Pervasives.compare
+ end
+
+module EntryCoercionMap = Map.Make(EntryCoercionOrd)
+
+let entry_coercion_map = ref EntryCoercionMap.empty
+
+let level_ord lev lev' =
+ match lev, lev' with
+ | None, _ -> true
+ | _, None -> true
+ | Some n, Some n' -> n <= n'
+
+let rec search nfrom nto = function
+ | [] -> raise Not_found
+ | ((pfrom,pto),coe)::l ->
+ if level_ord pfrom nfrom && level_ord nto pto then coe else search nfrom nto l
+
+let decompose_custom_entry = function
+ | InConstrEntrySomeLevel -> InConstrEntry, None
+ | InCustomEntryLevel (s,n) -> InCustomEntry s, Some n
+
+let availability_of_entry_coercion entry entry' =
+ let entry, lev = decompose_custom_entry entry in
+ let entry', lev' = decompose_custom_entry entry' in
+ if notation_entry_eq entry entry' && level_ord lev' lev then Some []
+ else
+ try Some (search lev lev' (EntryCoercionMap.find (entry,entry') !entry_coercion_map))
+ with Not_found -> None
+
+let better_path ((lev1,lev2),path) ((lev1',lev2'),path') =
+ (* better = shorter and lower source and higher target *)
+ level_ord lev1 lev1' && level_ord lev2' lev2 && List.length path <= List.length path'
+
+let shorter_path (_,path) (_,path') =
+ List.length path <= List.length path'
+
+let rec insert_coercion_path path = function
+ | [] -> [path]
+ | path'::paths as allpaths ->
+ (* If better or equal we keep the more recent one *)
+ if better_path path path' then path::paths
+ else if better_path path' path then allpaths
+ else if shorter_path path path' then path::allpaths
+ else path'::insert_coercion_path path paths
+
+let declare_entry_coercion (entry,_ as ntn) entry' =
+ let entry, lev = decompose_custom_entry entry in
+ let entry', lev' = decompose_custom_entry entry' in
+ (* Transitive closure *)
+ let toaddleft =
+ EntryCoercionMap.fold (fun (entry'',entry''') paths l ->
+ List.fold_right (fun ((lev'',lev'''),path) l ->
+ if notation_entry_eq entry entry''' && level_ord lev lev''' &&
+ not (notation_entry_eq entry' entry'')
+ then ((entry'',entry'),((lev'',lev'),path@[ntn]))::l else l) paths l)
+ !entry_coercion_map [] in
+ let toaddright =
+ EntryCoercionMap.fold (fun (entry'',entry''') paths l ->
+ List.fold_right (fun ((lev'',lev'''),path) l ->
+ if entry' = entry'' && level_ord lev' lev'' && entry <> entry'''
+ then ((entry,entry'''),((lev,lev'''),path@[ntn]))::l else l) paths l)
+ !entry_coercion_map [] in
+ entry_coercion_map :=
+ List.fold_right (fun (pair,path) ->
+ let olds = try EntryCoercionMap.find pair !entry_coercion_map with Not_found -> [] in
+ EntryCoercionMap.add pair (insert_coercion_path path olds))
+ (((entry,entry'),((lev,lev'),[ntn]))::toaddright@toaddleft)
+ !entry_coercion_map
+
+let entry_has_global_map = ref String.Map.empty
+
+let declare_custom_entry_has_global s n =
+ try
+ let p = String.Map.find s !entry_has_global_map in
+ user_err (str "Custom entry " ++ str s ++
+ str " has already a rule for global references at level " ++ int p ++ str ".")
+ with Not_found ->
+ entry_has_global_map := String.Map.add s n !entry_has_global_map
+
+let entry_has_global = function
+ | InConstrEntrySomeLevel -> true
+ | InCustomEntryLevel (s,n) ->
+ try String.Map.find s !entry_has_global_map <= n with Not_found -> false
+
+let entry_has_ident_map = ref String.Map.empty
+
+let declare_custom_entry_has_ident s n =
+ try
+ let p = String.Map.find s !entry_has_ident_map in
+ user_err (str "Custom entry " ++ str s ++
+ str " has already a rule for global references at level " ++ int p ++ str ".")
+ with Not_found ->
+ entry_has_ident_map := String.Map.add s n !entry_has_ident_map
+
+let entry_has_ident = function
+ | InConstrEntrySomeLevel -> true
+ | InCustomEntryLevel (s,n) ->
+ try String.Map.find s !entry_has_ident_map <= n with Not_found -> false
+
let uninterp_prim_token c =
try
let (sc,numpr,_) =
@@ -565,7 +706,8 @@ let ntpe_eq t1 t2 = match t1, t2 with
| NtnTypeBinderList, NtnTypeBinderList -> true
| (NtnTypeConstr | NtnTypeBinder _ | NtnTypeConstrList | NtnTypeBinderList), _ -> false
-let var_attributes_eq (_, (sc1, tp1)) (_, (sc2, tp2)) =
+let var_attributes_eq (_, ((entry1, sc1), tp1)) (_, ((entry2, sc2), tp2)) =
+ notation_entry_level_eq entry1 entry2 &&
pair_eq (Option.equal String.equal) (List.equal String.equal) sc1 sc2 &&
ntpe_eq tp1 tp2
@@ -577,7 +719,7 @@ let exists_notation_in_scope scopt ntn onlyprint r =
let scope = match scopt with Some s -> s | None -> default_scope in
try
let sc = String.Map.find scope !scope_map in
- let n = String.Map.find ntn sc.notations in
+ let n = NotationMap.find ntn sc.notations in
interpretation_eq n.not_interp r
with Not_found -> false
@@ -793,10 +935,10 @@ let rec string_of_symbol = function
let l = List.flatten (List.map string_of_symbol l) in "_"::l@".."::l@["_"]
| Break _ -> []
-let make_notation_key symbols =
- String.concat " " (List.flatten (List.map string_of_symbol symbols))
+let make_notation_key from symbols =
+ (from,String.concat " " (List.flatten (List.map string_of_symbol symbols)))
-let decompose_notation_key s =
+let decompose_notation_key (from,s) =
let len = String.length s in
let rec decomp_ntn dirs n =
if n>=len then List.rev dirs else
@@ -811,7 +953,7 @@ let decompose_notation_key s =
| s -> Terminal (String.drop_simple_quotes s) in
decomp_ntn (tok::dirs) (pos+1)
in
- decomp_ntn [] 0
+ from, decomp_ntn [] 0
(************)
(* Printing *)
@@ -840,14 +982,14 @@ let pr_notation_info prglob ntn c =
let pr_named_scope prglob scope sc =
(if String.equal scope default_scope then
- match String.Map.cardinal sc.notations with
+ match NotationMap.cardinal sc.notations with
| 0 -> str "No lonely notation"
| n -> str "Lonely notation" ++ (if Int.equal n 1 then mt() else str"s")
else
str "Scope " ++ str scope ++ fnl () ++ pr_delimiters_info sc.delimiters)
++ fnl ()
++ pr_scope_classes scope
- ++ String.Map.fold
+ ++ NotationMap.fold
(fun ntn { not_interp = (_, r); not_location = (_, df) } strm ->
pr_notation_info prglob df r ++ fnl () ++ strm)
sc.notations (mt ())
@@ -862,11 +1004,11 @@ let pr_scopes prglob =
let rec find_default ntn = function
| [] -> None
| Scope scope :: scopes ->
- if String.Map.mem ntn (find_scope scope).notations then
+ if NotationMap.mem ntn (find_scope scope).notations then
Some scope
else find_default ntn scopes
| SingleNotation ntn' :: scopes ->
- if String.equal ntn ntn' then Some default_scope
+ if notation_eq ntn ntn' then Some default_scope
else find_default ntn scopes
let factorize_entries = function
@@ -875,7 +1017,7 @@ let factorize_entries = function
let (ntn,l_of_ntn,rest) =
List.fold_left
(fun (a',l,rest) (a,c) ->
- if String.equal a a' then (a',c::l,rest) else (a,[c],(a',l)::rest))
+ if notation_eq a a' then (a',c::l,rest) else (a,[c],(a',l)::rest))
(ntn,[c],[]) l in
(ntn,l_of_ntn)::rest
@@ -930,15 +1072,15 @@ let possible_notations ntn =
(* Only "_ U _" format *)
[ntn]
else
- let ntn' = make_notation_key (raw_analyze_notation_tokens toks) in
+ let _,ntn' = make_notation_key None (raw_analyze_notation_tokens toks) in
if String.equal ntn ntn' then (* Only symbols *) [ntn] else [ntn;ntn']
let browse_notation strict ntn map =
let ntns = possible_notations ntn in
- let find ntn' ntn =
+ let find (from,ntn' as fullntn') ntn =
if String.contains ntn ' ' then String.equal ntn ntn'
else
- let toks = decompose_notation_key ntn' in
+ let _,toks = decompose_notation_key fullntn' in
let get_terminals = function Terminal ntn -> Some ntn | _ -> None in
let trms = List.map_filter get_terminals toks in
if strict then String.List.equal [ntn] trms
@@ -947,10 +1089,10 @@ let browse_notation strict ntn map =
let l =
String.Map.fold
(fun scope_name sc ->
- String.Map.fold (fun ntn { not_interp = (_, r); not_location = df } l ->
+ NotationMap.fold (fun ntn { not_interp = (_, r); not_location = df } l ->
if List.exists (find ntn) ntns then (ntn,(scope_name,r,df))::l else l) sc.notations)
map [] in
- List.sort (fun x y -> String.compare (fst x) (fst y)) l
+ List.sort (fun x y -> String.compare (snd (fst x)) (snd (fst y))) l
let global_reference_of_notation test (ntn,(sc,c,_)) =
match c with
@@ -1011,9 +1153,9 @@ let locate_notation prglob ntn scope =
let collect_notation_in_scope scope sc known =
assert (not (String.equal scope default_scope));
- String.Map.fold
+ NotationMap.fold
(fun ntn { not_interp = (_, r); not_location = (_, df) } (l,known as acc) ->
- if String.List.mem ntn known then acc else ((df,r)::l,ntn::known))
+ if List.mem_f notation_eq ntn known then acc else ((df,r)::l,ntn::known))
sc.notations ([],known)
let collect_notations stack =
@@ -1026,10 +1168,10 @@ let collect_notations stack =
collect_notation_in_scope scope (find_scope scope) knownntn in
((scope,l)::all,knownntn)
| SingleNotation ntn ->
- if String.List.mem ntn knownntn then (all,knownntn)
+ if List.mem_f notation_eq ntn knownntn then (all,knownntn)
else
let { not_interp = (_, r); not_location = (_, df) } =
- String.Map.find ntn (find_scope default_scope).notations in
+ NotationMap.find ntn (find_scope default_scope).notations in
let all' = match all with
| (s,lonelyntn)::rest when String.equal s default_scope ->
(s,(df,r)::lonelyntn)::rest
@@ -1063,15 +1205,20 @@ let pr_visibility prglob = function
let freeze _ =
(!scope_map, !scope_stack, !arguments_scope,
- !delimiters_map, !notations_key_table, !scope_class_map)
+ !delimiters_map, !notations_key_table, !scope_class_map,
+ !entry_coercion_map, !entry_has_global_map,
+ !entry_has_ident_map)
-let unfreeze (scm,scs,asc,dlm,fkm,clsc) =
+let unfreeze (scm,scs,asc,dlm,fkm,clsc,coe,globs,ids) =
scope_map := scm;
scope_stack := scs;
delimiters_map := dlm;
arguments_scope := asc;
notations_key_table := fkm;
- scope_class_map := clsc
+ scope_class_map := clsc;
+ entry_coercion_map := coe;
+ entry_has_global_map := globs;
+ entry_has_ident_map := ids
let init () =
init_scope_map ();