diff options
| author | Pierre Roux | 2020-09-03 13:23:00 +0200 |
|---|---|---|
| committer | Pierre Roux | 2020-11-05 00:20:19 +0100 |
| commit | 0520decfdc94d52a2f8658b9cf6a730e6d333f8f (patch) | |
| tree | 56130ed8dafe578760221bbc6e7d7d835ac4791c /plugins/syntax | |
| parent | 9082af80f5bb70ff2b75117f9e5cc3165b1c8b42 (diff) | |
[numeral notation] Handle implicit arguments
Diffstat (limited to 'plugins/syntax')
| -rw-r--r-- | plugins/syntax/g_numeral.mlg | 12 | ||||
| -rw-r--r-- | plugins/syntax/numeral.ml | 51 | ||||
| -rw-r--r-- | plugins/syntax/numeral.mli | 2 |
3 files changed, 43 insertions, 22 deletions
diff --git a/plugins/syntax/g_numeral.mlg b/plugins/syntax/g_numeral.mlg index e60ae45b01..a3cc786a4a 100644 --- a/plugins/syntax/g_numeral.mlg +++ b/plugins/syntax/g_numeral.mlg @@ -31,8 +31,13 @@ let warn_deprecated_numeral_notation = (fun () -> strbrk "Numeral Notation is deprecated, please use Number Notation instead.") -let pr_number_mapping (n, n') = - Libnames.pr_qualid n ++ spc () ++ str "=>" ++ spc () ++ Libnames.pr_qualid n' +let pr_number_mapping (b, n, n') = + if b then + str "[" ++ Libnames.pr_qualid n ++ str "]" ++ spc () ++ str "=>" ++ spc () + ++ Libnames.pr_qualid n' + else + Libnames.pr_qualid n ++ spc () ++ str "=>" ++ spc () + ++ Libnames.pr_qualid n' let pr_number_via (n, l) = str "via " ++ Libnames.pr_qualid n ++ str " mapping [" @@ -56,7 +61,8 @@ END VERNAC ARGUMENT EXTEND number_mapping PRINTED BY { pr_number_mapping } -| [ reference(n) "=>" reference(n') ] -> { n, n' } +| [ reference(n) "=>" reference(n') ] -> { false, n, n' } +| [ "[" reference(n) "]" "=>" reference(n') ] -> { true, n, n' } END VERNAC ARGUMENT EXTEND number_via diff --git a/plugins/syntax/numeral.ml b/plugins/syntax/numeral.ml index 316ca456a4..1efe6b77d1 100644 --- a/plugins/syntax/numeral.ml +++ b/plugins/syntax/numeral.ml @@ -21,7 +21,7 @@ module CMap = CMap.Make (Constr) (** * Number notation *) -type number_string_via = qualid * (qualid * qualid) list +type number_string_via = qualid * (bool * qualid * qualid) list type number_option = | After of numnot_option | Via of number_string_via @@ -231,9 +231,10 @@ let elaborate_to_post env sigma ty_name ty_ind l = For each constant [cnst] and inductive constructor [indc] in [l], retrieve: * its location: [lcnst] and [lindc] * its GlobRef: [cnst] and [indc] - * its type: [tcnst] and [tindc] (decomposed in product by [get_type] above) *) + * its type: [tcnst] and [tindc] (decomposed in product by [get_type] above) + * [impls] are the implicit arguments of [cnst] *) let l = - let read (cnst, indc) = + let read (consider_implicits, cnst, indc) = let lcnst, lindc = cnst.CAst.loc, indc.CAst.loc in let cnst, ccnst = locate_global_constructor_inductive_or_constant cnst in let indc, cindc = @@ -247,13 +248,16 @@ let elaborate_to_post env sigma ty_name ty_ind l = let lc, tc = get_type env sigma c in List.map (fun (n, c) -> n, rm_params c) lc, rm_params tc in let tcnst, tindc = get_type_wo_params ccnst, get_type_wo_params cindc in - lcnst, cnst, tcnst, lindc, indc, tindc in + let impls = + if not consider_implicits then [] else + Impargs.(select_stronger_impargs (implicits_of_global cnst)) in + lcnst, cnst, tcnst, lindc, indc, tindc, impls in List.map read l in - let eq_indc indc (_, _, _, _, indc', _) = GlobRef.equal indc indc' in + let eq_indc indc (_, _, _, _, indc', _, _) = GlobRef.equal indc indc' in (* Collect all inductive types involved. That is [ty_ind] and all final codomains of [tindc] above. *) let inds = - List.fold_left (fun s (_, _, _, _, _, tindc) -> CSet.add (snd tindc) s) + List.fold_left (fun s (_, _, _, _, _, tindc, _) -> CSet.add (snd tindc) s) (CSet.singleton ty_ind) l in (* And for each inductive, retrieve its constructors. *) let constructors = @@ -264,7 +268,7 @@ let elaborate_to_post env sigma ty_name ty_ind l = (* Error if one [constructor] in some inductive in [inds] doesn't appear exactly once in [l] *) let _ = (* check_for duplicate constructor and error *) - List.fold_left (fun already_seen (_, cnst, _, loc, indc, _) -> + List.fold_left (fun already_seen (_, cnst, _, loc, indc, _, _) -> try let cnst' = List.assoc_f GlobRef.equal indc already_seen in remapping_error ?loc indc cnst' cnst @@ -289,16 +293,23 @@ let elaborate_to_post env sigma ty_name ty_ind l = warn_via_remapping ?loc (env, sigma, ckey, old_cval, cval); m in List.fold_left - (fun (ind2ty, ty2ind) (lcnst, _, (_, tcnst), lindc, _, (_, tindc)) -> + (fun (ind2ty, ty2ind) (lcnst, _, (_, tcnst), lindc, _, (_, tindc), _) -> add lcnst tindc tcnst ind2ty, add lindc tcnst tindc ty2ind) CMap.(singleton ty_ind ty_name, singleton ty_name ty_ind) l in (* check that type of constants and constructors mapped in [l] match modulo [ind2ty] *) + let rm_impls impls (l, t) = + let rec aux impls l = match impls, l with + | Some _ :: impls, _ :: b -> aux impls b + | None :: impls, (n, a) :: b -> (n, a) :: aux impls b + | _ -> l in + aux impls l, t in let replace m (l, t) = let apply_m c = try CMap.find c m with Not_found -> c in List.fold_right (fun (na, a) b -> Constr.mkProd (na, (apply_m a), b)) l (apply_m t) in - List.iter (fun (_, cnst, tcnst, loc, indc, tindc) -> + List.iter (fun (_, cnst, tcnst, loc, indc, tindc, impls) -> + let tcnst = rm_impls impls tcnst in let tcnst' = replace CMap.empty tcnst in if not (Constr.equal tcnst' (replace ind2ty tindc)) then let actual = replace CMap.empty tindc in @@ -313,17 +324,21 @@ let elaborate_to_post env sigma ty_name ty_ind l = (CMap.singleton ty_ind 0, Int.Map.singleton 0 ty_ind, 1) in (* Finally elaborate [to_post] *) let to_post = - let rec map_prod = function - | [] -> [] - | (_, a) :: b -> - let t = match CMap.find_opt a ind2num with - | Some i -> ToPostAs i - | None -> ToPostCopy in - t :: map_prod b in + let rec map_prod impls tindc = match impls with + | Some _ :: impls -> ToPostHole :: map_prod impls tindc + | _ -> + match tindc with + | [] -> [] + | (_, a) :: b -> + let t = match CMap.find_opt a ind2num with + | Some i -> ToPostAs i + | None -> ToPostCopy in + let impls = match impls with [] -> [] | _ :: t -> t in + t :: map_prod impls b in Array.init nb_ind (fun i -> List.map (fun indc -> - let _, cnst, _, _, _, tindc = List.find (eq_indc indc) l in - indc, cnst, map_prod (fst tindc)) + let _, cnst, _, _, _, tindc, impls = List.find (eq_indc indc) l in + indc, cnst, map_prod impls (fst tindc)) (CMap.find (Int.Map.find i num2ind) constructors)) in (* and use constants mapped to constructors of [ty_ind] as triggers. *) let pt_refs = List.map (fun (_, cnst, _) -> cnst) (to_post.(0)) in diff --git a/plugins/syntax/numeral.mli b/plugins/syntax/numeral.mli index 1f6896d549..5a13d1068b 100644 --- a/plugins/syntax/numeral.mli +++ b/plugins/syntax/numeral.mli @@ -14,7 +14,7 @@ open Notation (** * Number notation *) -type number_string_via = qualid * (qualid * qualid) list +type number_string_via = qualid * (bool * qualid * qualid) list type number_option = | After of numnot_option | Via of number_string_via |
