aboutsummaryrefslogtreecommitdiff
path: root/plugins/syntax
diff options
context:
space:
mode:
authorPierre Roux2020-09-03 13:12:00 +0200
committerPierre Roux2020-11-04 20:53:47 +0100
commit11a8997dd8fa83537607272692a3baf10dab342a (patch)
tree2b88f003ab19f264d94f29806c28b48258800d28 /plugins/syntax
parentdfcb15141a19db4f1cc61c14d1cdad0275009356 (diff)
[numeral notation] Adding the via ... using ... option
This enables numeral notations for non inductive types by pre/postprocessing them to a given proxy inductive type. For instance, this should enable the use of numeral notations for R.
Diffstat (limited to 'plugins/syntax')
-rw-r--r--plugins/syntax/g_numeral.mlg58
-rw-r--r--plugins/syntax/numeral.ml233
-rw-r--r--plugins/syntax/numeral.mli11
3 files changed, 285 insertions, 17 deletions
diff --git a/plugins/syntax/g_numeral.mlg b/plugins/syntax/g_numeral.mlg
index 48e262c3ef..e60ae45b01 100644
--- a/plugins/syntax/g_numeral.mlg
+++ b/plugins/syntax/g_numeral.mlg
@@ -19,33 +19,71 @@ open Names
open Stdarg
open Pcoq.Prim
-let pr_number_modifier = function
+let pr_number_after = function
| Nop -> mt ()
- | Warning n -> str "(warning after " ++ NumTok.UnsignedNat.print n ++ str ")"
- | Abstract n -> str "(abstract after " ++ NumTok.UnsignedNat.print n ++ str ")"
+ | Warning n -> str "warning after " ++ NumTok.UnsignedNat.print n
+ | Abstract n -> str "abstract after " ++ NumTok.UnsignedNat.print n
+
+let pr_deprecated_number_modifier m = str "(" ++ pr_number_after m ++ str ")"
let warn_deprecated_numeral_notation =
CWarnings.create ~name:"numeral-notation" ~category:"deprecated"
(fun () ->
strbrk "Numeral Notation is deprecated, please use Number Notation instead.")
+let pr_number_mapping (n, n') =
+ Libnames.pr_qualid n ++ spc () ++ str "=>" ++ spc () ++ Libnames.pr_qualid n'
+
+let pr_number_via (n, l) =
+ str "via " ++ Libnames.pr_qualid n ++ str " mapping ["
+ ++ prlist_with_sep pr_comma pr_number_mapping l ++ str "]"
+
+let pr_number_modifier = function
+ | After a -> pr_number_after a
+ | Via nl -> pr_number_via nl
+
+let pr_number_options l =
+ str "(" ++ prlist_with_sep pr_comma pr_number_modifier l ++ str ")"
+
}
-VERNAC ARGUMENT EXTEND number_modifier
- PRINTED BY { pr_number_modifier }
+VERNAC ARGUMENT EXTEND deprecated_number_modifier
+ PRINTED BY { pr_deprecated_number_modifier }
| [ ] -> { Nop }
| [ "(" "warning" "after" bignat(waft) ")" ] -> { Warning (NumTok.UnsignedNat.of_string waft) }
| [ "(" "abstract" "after" bignat(n) ")" ] -> { Abstract (NumTok.UnsignedNat.of_string n) }
END
+VERNAC ARGUMENT EXTEND number_mapping
+ PRINTED BY { pr_number_mapping }
+| [ reference(n) "=>" reference(n') ] -> { n, n' }
+END
+
+VERNAC ARGUMENT EXTEND number_via
+ PRINTED BY { pr_number_via }
+| [ "via" reference(n) "mapping" "[" ne_number_mapping_list_sep(l, ",") "]" ] -> { n, l }
+END
+
+VERNAC ARGUMENT EXTEND number_modifier
+ PRINTED BY { pr_number_modifier }
+| [ "warning" "after" bignat(waft) ] -> { After (Warning (NumTok.UnsignedNat.of_string waft)) }
+| [ "abstract" "after" bignat(n) ] -> { After (Abstract (NumTok.UnsignedNat.of_string n)) }
+| [ number_via(v) ] -> { Via v }
+END
+
+VERNAC ARGUMENT EXTEND number_options
+ PRINTED BY { pr_number_options }
+| [ "(" ne_number_modifier_list_sep(l, ",") ")" ] -> { l }
+END
+
VERNAC COMMAND EXTEND NumberNotation CLASSIFIED AS SIDEFF
- | #[ locality = Attributes.locality; ] [ "Number" "Notation" reference(ty) reference(f) reference(g) ":"
- ident(sc) number_modifier(o) ] ->
+ | #[ locality = Attributes.locality; ] [ "Number" "Notation" reference(ty) reference(f) reference(g) number_options_opt(nl) ":"
+ ident(sc) ] ->
- { vernac_number_notation (Locality.make_module_locality locality) ty f g (Id.to_string sc) o }
+ { vernac_number_notation (Locality.make_module_locality locality) ty f g (Option.default [] nl) (Id.to_string sc) }
| #[ locality = Attributes.locality; ] [ "Numeral" "Notation" reference(ty) reference(f) reference(g) ":"
- ident(sc) number_modifier(o) ] ->
+ ident(sc) deprecated_number_modifier(o) ] ->
{ warn_deprecated_numeral_notation ();
- vernac_number_notation (Locality.make_module_locality locality) ty f g (Id.to_string sc) o }
+ vernac_number_notation (Locality.make_module_locality locality) ty f g [After o] (Id.to_string sc) }
END
diff --git a/plugins/syntax/numeral.ml b/plugins/syntax/numeral.ml
index ad90a9a982..316ca456a4 100644
--- a/plugins/syntax/numeral.ml
+++ b/plugins/syntax/numeral.ml
@@ -16,8 +16,16 @@ open Constrexpr
open Constrexpr_ops
open Notation
+module CSet = CSet.Make (Constr)
+module CMap = CMap.Make (Constr)
+
(** * Number notation *)
+type number_string_via = qualid * (qualid * qualid) list
+type number_option =
+ | After of numnot_option
+ | Via of number_string_via
+
let warn_abstract_large_num_no_op =
CWarnings.create ~name:"abstract-large-number-no-op" ~category:"numbers"
(fun f ->
@@ -128,12 +136,228 @@ let warn_deprecated_decimal =
Decimal.int or Decimal.decimal. Use Number.uint, \
Number.int or Number.number respectively.")
-let vernac_number_notation local ty f g scope opts =
+let remapping_error ?loc ty ty' ty'' =
+ CErrors.user_err ?loc
+ (Printer.pr_global ty
+ ++ str " was already mapped to" ++ spc () ++ Printer.pr_global ty'
+ ++ str " and cannot be remapped to" ++ spc () ++ Printer.pr_global ty''
+ ++ str ".")
+
+let error_missing c =
+ CErrors.user_err
+ (str "Missing mapping for constructor " ++ Printer.pr_global c ++ str ".")
+
+let pr_constr env sigma c =
+ let c = Constrextern.extern_constr env sigma (EConstr.of_constr c) in
+ Ppconstr.pr_constr_expr env sigma c
+
+let warn_via_remapping =
+ CWarnings.create ~name:"via-type-remapping" ~category:"numbers"
+ (fun (env, sigma, ty, ty', ty'') ->
+ let constr = pr_constr env sigma in
+ constr ty ++ str " was already mapped to" ++ spc () ++ constr ty'
+ ++ str ", mapping it also to" ++ spc () ++ constr ty''
+ ++ str " might yield ill typed terms when using the notation.")
+
+let warn_via_type_mismatch =
+ CWarnings.create ~name:"via-type-mismatch" ~category:"numbers"
+ (fun (env, sigma, g, g', exp, actual) ->
+ let constr = pr_constr env sigma in
+ str "Type of" ++ spc() ++ Printer.pr_global g
+ ++ str " seems incompatible with the type of" ++ spc ()
+ ++ Printer.pr_global g' ++ str "." ++ spc ()
+ ++ str "Expected type is: " ++ constr exp ++ spc ()
+ ++ str "instead of " ++ constr actual ++ str "." ++ spc ()
+ ++ str "This might yield ill typed terms when using the notation.")
+
+let multiple_via_error () =
+ CErrors.user_err (Pp.str "Multiple 'via' options.")
+
+let multiple_after_error () =
+ CErrors.user_err (Pp.str "Multiple 'warning after' or 'abstract after' options.")
+
+let via_abstract_error () =
+ CErrors.user_err (Pp.str "'via' and 'abstract' cannot be used together.")
+
+let locate_global_sort_inductive_or_constant sigma qid =
+ let locate_sort qid =
+ match Nametab.locate_extended qid with
+ | Globnames.TrueGlobal _ -> raise Not_found
+ | Globnames.SynDef kn ->
+ match Syntax_def.search_syntactic_definition kn with
+ | [], Notation_term.NSort r ->
+ let sigma,c = Evd.fresh_sort_in_family sigma (Glob_ops.glob_sort_family r) in
+ sigma,Constr.mkSort c
+ | _ -> raise Not_found in
+ try locate_sort qid
+ with Not_found ->
+ match Smartlocate.global_with_alias qid with
+ | GlobRef.IndRef i -> sigma, Constr.mkInd i
+ | _ -> sigma, Constr.mkConst (Smartlocate.global_constant_with_alias qid)
+
+let locate_global_constructor_inductive_or_constant qid =
+ let g = Smartlocate.global_with_alias qid in
+ match g with
+ | GlobRef.ConstructRef c -> g, Constr.mkConstruct c
+ | GlobRef.IndRef i -> g, Constr.mkInd i
+ | _ -> g, Constr.mkConst (Smartlocate.global_constant_with_alias qid)
+
+(* [get_type env sigma c] retrieves the type of [c] and returns a pair
+ [l, t] such that [c : l_0 -> ... -> l_n -> t]. *)
+let get_type env sigma c =
+ (* inspired from [compute_implicit_names] in "interp/impargs.ml" *)
+ let rec aux env acc t =
+ let t = Reductionops.whd_all env sigma t in
+ match EConstr.kind sigma t with
+ | Constr.Prod (na, a, b) ->
+ let a = Reductionops.whd_all env sigma a in
+ let rel = Context.Rel.Declaration.LocalAssum (na, a) in
+ aux (EConstr.push_rel rel env) ((na, a) :: acc) b
+ | _ -> List.rev acc, t in
+ let t = Retyping.get_type_of env sigma (EConstr.of_constr c) in
+ let l, t = aux env [] t in
+ List.map (fun (na, a) -> na, EConstr.Unsafe.to_constr a) l,
+ EConstr.Unsafe.to_constr t
+
+(* [elaborate_to_post env sigma ty_name ty_ind l] builds the [to_post]
+ translation (c.f., interp/notation.mli) for the number notation to
+ parse/print type [ty_name] through the inductive [ty_ind] according
+ to the pairs [constant, constructor] in the list [l]. *)
+let elaborate_to_post env sigma ty_name ty_ind l =
+ let sigma, ty_name =
+ locate_global_sort_inductive_or_constant sigma ty_name in
+ let ty_ind = Constr.mkInd ty_ind in
+ (* Retrieve constants and constructors mappings and their type.
+ For each constant [cnst] and inductive constructor [indc] in [l], retrieve:
+ * its location: [lcnst] and [lindc]
+ * its GlobRef: [cnst] and [indc]
+ * its type: [tcnst] and [tindc] (decomposed in product by [get_type] above) *)
+ let l =
+ let read (cnst, indc) =
+ let lcnst, lindc = cnst.CAst.loc, indc.CAst.loc in
+ let cnst, ccnst = locate_global_constructor_inductive_or_constant cnst in
+ let indc, cindc =
+ let indc = Smartlocate.global_constructor_with_alias indc in
+ GlobRef.ConstructRef indc, Constr.mkConstruct indc in
+ let get_type_wo_params c =
+ (* ignore parameters of inductive types *)
+ let rm_params c = match Constr.kind c with
+ | Constr.App (c, _) when Constr.isInd c -> c
+ | _ -> c in
+ let lc, tc = get_type env sigma c in
+ List.map (fun (n, c) -> n, rm_params c) lc, rm_params tc in
+ let tcnst, tindc = get_type_wo_params ccnst, get_type_wo_params cindc in
+ lcnst, cnst, tcnst, lindc, indc, tindc in
+ List.map read l in
+ let eq_indc indc (_, _, _, _, indc', _) = GlobRef.equal indc indc' in
+ (* Collect all inductive types involved.
+ That is [ty_ind] and all final codomains of [tindc] above. *)
+ let inds =
+ List.fold_left (fun s (_, _, _, _, _, tindc) -> CSet.add (snd tindc) s)
+ (CSet.singleton ty_ind) l in
+ (* And for each inductive, retrieve its constructors. *)
+ let constructors =
+ CSet.fold (fun ind m ->
+ let inductive, _ = Constr.destInd ind in
+ CMap.add ind (get_constructors inductive) m)
+ inds CMap.empty in
+ (* Error if one [constructor] in some inductive in [inds]
+ doesn't appear exactly once in [l] *)
+ let _ = (* check_for duplicate constructor and error *)
+ List.fold_left (fun already_seen (_, cnst, _, loc, indc, _) ->
+ try
+ let cnst' = List.assoc_f GlobRef.equal indc already_seen in
+ remapping_error ?loc indc cnst' cnst
+ with Not_found -> (indc, cnst) :: already_seen)
+ [] l in
+ let () = (* check for missing constructor and error *)
+ CMap.iter (fun _ ->
+ List.iter (fun cstr ->
+ if not (List.exists (eq_indc cstr) l) then error_missing cstr))
+ constructors in
+ (* Perform some checks on types and warn if they look strange.
+ These checks are neither sound nor complete, so we only warn. *)
+ let () =
+ (* associate inductives to types, and check that this mapping is one to one
+ and maps [ty_ind] to [ty_name] *)
+ let ind2ty, ty2ind =
+ let add loc ckey cval m =
+ match CMap.find_opt ckey m with
+ | None -> CMap.add ckey cval m
+ | Some old_cval ->
+ if not (Constr.equal old_cval cval) then
+ warn_via_remapping ?loc (env, sigma, ckey, old_cval, cval);
+ m in
+ List.fold_left
+ (fun (ind2ty, ty2ind) (lcnst, _, (_, tcnst), lindc, _, (_, tindc)) ->
+ add lcnst tindc tcnst ind2ty, add lindc tcnst tindc ty2ind)
+ CMap.(singleton ty_ind ty_name, singleton ty_name ty_ind) l in
+ (* check that type of constants and constructors mapped in [l]
+ match modulo [ind2ty] *)
+ let replace m (l, t) =
+ let apply_m c = try CMap.find c m with Not_found -> c in
+ List.fold_right (fun (na, a) b -> Constr.mkProd (na, (apply_m a), b))
+ l (apply_m t) in
+ List.iter (fun (_, cnst, tcnst, loc, indc, tindc) ->
+ let tcnst' = replace CMap.empty tcnst in
+ if not (Constr.equal tcnst' (replace ind2ty tindc)) then
+ let actual = replace CMap.empty tindc in
+ let expected = replace ty2ind tcnst in
+ warn_via_type_mismatch ?loc (env, sigma, indc, cnst, expected, actual))
+ l in
+ (* Associate an index to each inductive, starting from 0 for [ty_ind]. *)
+ let ind2num, num2ind, nb_ind =
+ CMap.fold (fun ind _ (ind2num, num2ind, i) ->
+ CMap.add ind i ind2num, Int.Map.add i ind num2ind, i + 1)
+ (CMap.remove ty_ind constructors)
+ (CMap.singleton ty_ind 0, Int.Map.singleton 0 ty_ind, 1) in
+ (* Finally elaborate [to_post] *)
+ let to_post =
+ let rec map_prod = function
+ | [] -> []
+ | (_, a) :: b ->
+ let t = match CMap.find_opt a ind2num with
+ | Some i -> ToPostAs i
+ | None -> ToPostCopy in
+ t :: map_prod b in
+ Array.init nb_ind (fun i ->
+ List.map (fun indc ->
+ let _, cnst, _, _, _, tindc = List.find (eq_indc indc) l in
+ indc, cnst, map_prod (fst tindc))
+ (CMap.find (Int.Map.find i num2ind) constructors)) in
+ (* and use constants mapped to constructors of [ty_ind] as triggers. *)
+ let pt_refs = List.map (fun (_, cnst, _) -> cnst) (to_post.(0)) in
+ to_post, pt_refs
+
+let elaborate_to_post env sigma ty_ind via =
+ match via with
+ | None -> [||], get_constructors ty_ind
+ | Some (ty_name, l) -> elaborate_to_post env sigma ty_name ty_ind l
+
+let vernac_number_notation local ty f g opts scope =
+ let rec parse_opts = function
+ | [] -> None, Nop
+ | h :: opts ->
+ let via, opts = parse_opts opts in
+ let via = match h, via with
+ | Via _, Some _ -> multiple_via_error ()
+ | Via v, None -> Some v
+ | _ -> via in
+ let opts = match h, opts with
+ | After _, (Warning _ | Abstract _) -> multiple_after_error ()
+ | After a, Nop -> a
+ | _ -> opts in
+ via, opts in
+ let via, opts = parse_opts opts in
+ (match via, opts with Some _, Abstract _ -> via_abstract_error () | _ -> ());
let env = Global.env () in
let sigma = Evd.from_env env in
let num_ty = locate_number () in
let z_pos_ty = locate_z () in
let int63_ty = locate_int63 () in
+ let ty_name = ty in
+ let ty, via =
+ match via with None -> ty, via | Some (ty', a) -> ty', Some (ty, a) in
let tyc = Smartlocate.global_inductive_with_alias ty in
let to_ty = Smartlocate.global_with_alias f in
let of_ty = Smartlocate.global_with_alias g in
@@ -143,7 +367,6 @@ let vernac_number_notation local ty f g scope opts =
mkProdC ([CAst.make Anonymous],Default Glob_term.Explicit, x, y)
in
let opt r = app (mkRefC (q_option ())) r in
- let constructors = get_constructors tyc in
(* Check the type of f *)
let to_kind =
match num_ty with
@@ -199,8 +422,8 @@ let vernac_number_notation local ty f g scope opts =
| _, ((DecimalInt _ | DecimalUInt _ | Decimal _), _) ->
warn_deprecated_decimal ()
| _ -> ());
- let o = { to_kind; to_ty; to_post = [||]; of_kind; of_ty;
- ty_name = ty;
+ let to_post, pt_refs = elaborate_to_post env sigma tyc via in
+ let o = { to_kind; to_ty; to_post; of_kind; of_ty; ty_name;
warning = opts }
in
(match opts, to_kind with
@@ -211,7 +434,7 @@ let vernac_number_notation local ty f g scope opts =
pt_scope = scope;
pt_interp_info = NumberNotation o;
pt_required = Nametab.path_of_global (GlobRef.IndRef tyc),[];
- pt_refs = constructors;
+ pt_refs;
pt_in_match = true }
in
enable_prim_token_interpretation i
diff --git a/plugins/syntax/numeral.mli b/plugins/syntax/numeral.mli
index d5fe42b0b4..1f6896d549 100644
--- a/plugins/syntax/numeral.mli
+++ b/plugins/syntax/numeral.mli
@@ -14,6 +14,13 @@ open Notation
(** * Number notation *)
+type number_string_via = qualid * (qualid * qualid) list
+type number_option =
+ | After of numnot_option
+ | Via of number_string_via
+
val vernac_number_notation : locality_flag ->
- qualid -> qualid -> qualid ->
- Notation_term.scope_name -> numnot_option -> unit
+ qualid ->
+ qualid -> qualid ->
+ number_option list ->
+ Notation_term.scope_name -> unit