aboutsummaryrefslogtreecommitdiff
path: root/plugins
diff options
context:
space:
mode:
authorPierre Roux2020-09-03 13:23:00 +0200
committerPierre Roux2020-11-05 00:20:19 +0100
commit0520decfdc94d52a2f8658b9cf6a730e6d333f8f (patch)
tree56130ed8dafe578760221bbc6e7d7d835ac4791c /plugins
parent9082af80f5bb70ff2b75117f9e5cc3165b1c8b42 (diff)
[numeral notation] Handle implicit arguments
Diffstat (limited to 'plugins')
-rw-r--r--plugins/syntax/g_numeral.mlg12
-rw-r--r--plugins/syntax/numeral.ml51
-rw-r--r--plugins/syntax/numeral.mli2
3 files changed, 43 insertions, 22 deletions
diff --git a/plugins/syntax/g_numeral.mlg b/plugins/syntax/g_numeral.mlg
index e60ae45b01..a3cc786a4a 100644
--- a/plugins/syntax/g_numeral.mlg
+++ b/plugins/syntax/g_numeral.mlg
@@ -31,8 +31,13 @@ let warn_deprecated_numeral_notation =
(fun () ->
strbrk "Numeral Notation is deprecated, please use Number Notation instead.")
-let pr_number_mapping (n, n') =
- Libnames.pr_qualid n ++ spc () ++ str "=>" ++ spc () ++ Libnames.pr_qualid n'
+let pr_number_mapping (b, n, n') =
+ if b then
+ str "[" ++ Libnames.pr_qualid n ++ str "]" ++ spc () ++ str "=>" ++ spc ()
+ ++ Libnames.pr_qualid n'
+ else
+ Libnames.pr_qualid n ++ spc () ++ str "=>" ++ spc ()
+ ++ Libnames.pr_qualid n'
let pr_number_via (n, l) =
str "via " ++ Libnames.pr_qualid n ++ str " mapping ["
@@ -56,7 +61,8 @@ END
VERNAC ARGUMENT EXTEND number_mapping
PRINTED BY { pr_number_mapping }
-| [ reference(n) "=>" reference(n') ] -> { n, n' }
+| [ reference(n) "=>" reference(n') ] -> { false, n, n' }
+| [ "[" reference(n) "]" "=>" reference(n') ] -> { true, n, n' }
END
VERNAC ARGUMENT EXTEND number_via
diff --git a/plugins/syntax/numeral.ml b/plugins/syntax/numeral.ml
index 316ca456a4..1efe6b77d1 100644
--- a/plugins/syntax/numeral.ml
+++ b/plugins/syntax/numeral.ml
@@ -21,7 +21,7 @@ module CMap = CMap.Make (Constr)
(** * Number notation *)
-type number_string_via = qualid * (qualid * qualid) list
+type number_string_via = qualid * (bool * qualid * qualid) list
type number_option =
| After of numnot_option
| Via of number_string_via
@@ -231,9 +231,10 @@ let elaborate_to_post env sigma ty_name ty_ind l =
For each constant [cnst] and inductive constructor [indc] in [l], retrieve:
* its location: [lcnst] and [lindc]
* its GlobRef: [cnst] and [indc]
- * its type: [tcnst] and [tindc] (decomposed in product by [get_type] above) *)
+ * its type: [tcnst] and [tindc] (decomposed in product by [get_type] above)
+ * [impls] are the implicit arguments of [cnst] *)
let l =
- let read (cnst, indc) =
+ let read (consider_implicits, cnst, indc) =
let lcnst, lindc = cnst.CAst.loc, indc.CAst.loc in
let cnst, ccnst = locate_global_constructor_inductive_or_constant cnst in
let indc, cindc =
@@ -247,13 +248,16 @@ let elaborate_to_post env sigma ty_name ty_ind l =
let lc, tc = get_type env sigma c in
List.map (fun (n, c) -> n, rm_params c) lc, rm_params tc in
let tcnst, tindc = get_type_wo_params ccnst, get_type_wo_params cindc in
- lcnst, cnst, tcnst, lindc, indc, tindc in
+ let impls =
+ if not consider_implicits then [] else
+ Impargs.(select_stronger_impargs (implicits_of_global cnst)) in
+ lcnst, cnst, tcnst, lindc, indc, tindc, impls in
List.map read l in
- let eq_indc indc (_, _, _, _, indc', _) = GlobRef.equal indc indc' in
+ let eq_indc indc (_, _, _, _, indc', _, _) = GlobRef.equal indc indc' in
(* Collect all inductive types involved.
That is [ty_ind] and all final codomains of [tindc] above. *)
let inds =
- List.fold_left (fun s (_, _, _, _, _, tindc) -> CSet.add (snd tindc) s)
+ List.fold_left (fun s (_, _, _, _, _, tindc, _) -> CSet.add (snd tindc) s)
(CSet.singleton ty_ind) l in
(* And for each inductive, retrieve its constructors. *)
let constructors =
@@ -264,7 +268,7 @@ let elaborate_to_post env sigma ty_name ty_ind l =
(* Error if one [constructor] in some inductive in [inds]
doesn't appear exactly once in [l] *)
let _ = (* check_for duplicate constructor and error *)
- List.fold_left (fun already_seen (_, cnst, _, loc, indc, _) ->
+ List.fold_left (fun already_seen (_, cnst, _, loc, indc, _, _) ->
try
let cnst' = List.assoc_f GlobRef.equal indc already_seen in
remapping_error ?loc indc cnst' cnst
@@ -289,16 +293,23 @@ let elaborate_to_post env sigma ty_name ty_ind l =
warn_via_remapping ?loc (env, sigma, ckey, old_cval, cval);
m in
List.fold_left
- (fun (ind2ty, ty2ind) (lcnst, _, (_, tcnst), lindc, _, (_, tindc)) ->
+ (fun (ind2ty, ty2ind) (lcnst, _, (_, tcnst), lindc, _, (_, tindc), _) ->
add lcnst tindc tcnst ind2ty, add lindc tcnst tindc ty2ind)
CMap.(singleton ty_ind ty_name, singleton ty_name ty_ind) l in
(* check that type of constants and constructors mapped in [l]
match modulo [ind2ty] *)
+ let rm_impls impls (l, t) =
+ let rec aux impls l = match impls, l with
+ | Some _ :: impls, _ :: b -> aux impls b
+ | None :: impls, (n, a) :: b -> (n, a) :: aux impls b
+ | _ -> l in
+ aux impls l, t in
let replace m (l, t) =
let apply_m c = try CMap.find c m with Not_found -> c in
List.fold_right (fun (na, a) b -> Constr.mkProd (na, (apply_m a), b))
l (apply_m t) in
- List.iter (fun (_, cnst, tcnst, loc, indc, tindc) ->
+ List.iter (fun (_, cnst, tcnst, loc, indc, tindc, impls) ->
+ let tcnst = rm_impls impls tcnst in
let tcnst' = replace CMap.empty tcnst in
if not (Constr.equal tcnst' (replace ind2ty tindc)) then
let actual = replace CMap.empty tindc in
@@ -313,17 +324,21 @@ let elaborate_to_post env sigma ty_name ty_ind l =
(CMap.singleton ty_ind 0, Int.Map.singleton 0 ty_ind, 1) in
(* Finally elaborate [to_post] *)
let to_post =
- let rec map_prod = function
- | [] -> []
- | (_, a) :: b ->
- let t = match CMap.find_opt a ind2num with
- | Some i -> ToPostAs i
- | None -> ToPostCopy in
- t :: map_prod b in
+ let rec map_prod impls tindc = match impls with
+ | Some _ :: impls -> ToPostHole :: map_prod impls tindc
+ | _ ->
+ match tindc with
+ | [] -> []
+ | (_, a) :: b ->
+ let t = match CMap.find_opt a ind2num with
+ | Some i -> ToPostAs i
+ | None -> ToPostCopy in
+ let impls = match impls with [] -> [] | _ :: t -> t in
+ t :: map_prod impls b in
Array.init nb_ind (fun i ->
List.map (fun indc ->
- let _, cnst, _, _, _, tindc = List.find (eq_indc indc) l in
- indc, cnst, map_prod (fst tindc))
+ let _, cnst, _, _, _, tindc, impls = List.find (eq_indc indc) l in
+ indc, cnst, map_prod impls (fst tindc))
(CMap.find (Int.Map.find i num2ind) constructors)) in
(* and use constants mapped to constructors of [ty_ind] as triggers. *)
let pt_refs = List.map (fun (_, cnst, _) -> cnst) (to_post.(0)) in
diff --git a/plugins/syntax/numeral.mli b/plugins/syntax/numeral.mli
index 1f6896d549..5a13d1068b 100644
--- a/plugins/syntax/numeral.mli
+++ b/plugins/syntax/numeral.mli
@@ -14,7 +14,7 @@ open Notation
(** * Number notation *)
-type number_string_via = qualid * (qualid * qualid) list
+type number_string_via = qualid * (bool * qualid * qualid) list
type number_option =
| After of numnot_option
| Via of number_string_via