diff options
| author | Pierre Roux | 2020-09-03 13:25:00 +0200 |
|---|---|---|
| committer | Pierre Roux | 2020-11-05 00:20:19 +0100 |
| commit | e728a1ef0f8b5fdc4b1815a7d0349c67db15f9b4 (patch) | |
| tree | 2a809813e374246465eb693bf444bffab25fd13c | |
| parent | 036117fa4992debb42e8346a48f6259f504793d3 (diff) | |
[numeral notation] Add support for parameterized inductives
| -rw-r--r-- | doc/sphinx/user-extensions/syntax-extensions.rst | 28 | ||||
| -rw-r--r-- | interp/notation.ml | 27 | ||||
| -rw-r--r-- | interp/notation.mli | 4 | ||||
| -rw-r--r-- | plugins/syntax/numeral.ml | 66 | ||||
| -rw-r--r-- | test-suite/output/NumberNotations.out | 52 | ||||
| -rw-r--r-- | test-suite/output/NumberNotations.v | 127 |
6 files changed, 284 insertions, 20 deletions
diff --git a/doc/sphinx/user-extensions/syntax-extensions.rst b/doc/sphinx/user-extensions/syntax-extensions.rst index 4c6d300b13..60fbd68687 100644 --- a/doc/sphinx/user-extensions/syntax-extensions.rst +++ b/doc/sphinx/user-extensions/syntax-extensions.rst @@ -1608,6 +1608,12 @@ Number notations function application, constructors, inductive type families, sorts, and primitive integers) will be considered for printing. + .. note:: + Number notations for parameterized inductive types can be + added by declaring an :ref:`abbreviation <Abbreviations>` + for the inductive which instantiates all parameters. See + example below. + :n:`via @qualid__ind mapping [ {+, @qualid__constant => @qualid__constructor } ]` When using this option, :n:`@qualid__type` no longer needs to be an inductive type and is instead mapped to the @@ -1847,6 +1853,24 @@ Number notations Check 2 : Fin.t 2. + .. example:: Number Notation with a parameterized inductive type + + .. coqtop:: in reset + + Definition of_uint u : list unit := + let fix f n := match n with O => nil | S n => cons tt (f n) end in + f (Nat.of_num_uint u). + Definition to_uint (l : list unit) := Nat.to_num_uint (length l). + + The parameter :g:`unit` for the parameterized inductive type + :g:`list` is given through an :ref:`abbreviation + <Abbreviations>`. + + .. coqtop:: in + + Notation list_unit := (list unit) (only parsing). + Number Notation list_unit of_uint to_uint : nat_scope. + .. _string-notations: String notations @@ -1917,8 +1941,8 @@ The following errors apply to both string and number notations: .. exn:: @type is not an inductive type. - String and number notations can only be declared for inductive types with no - arguments. Declare numeral notations for non-inductive types using :n:`@number_via`. + String and number notations can only be declared for inductive types. + Declare number notations for non-inductive types using :n:`@number_via`. .. exn:: Cannot interpret in @scope_name because @qualid could not be found in the current environment. diff --git a/interp/notation.ml b/interp/notation.ml index 0f149c5f50..1839e287d7 100644 --- a/interp/notation.ml +++ b/interp/notation.ml @@ -554,8 +554,10 @@ type 'target conversion_kind = 'target * option_kind argument is recursively translated according to [l_k]. [ToPostHole] introduces an additional implicit argument hole (in the reverse translation, the corresponding argument is removed). + [ToPostCheck r] behaves as [ToPostCopy] except in the reverse + translation which fails if the copied term is not [r]. When [n] is null, no translation is performed. *) -type to_post_arg = ToPostCopy | ToPostAs of int | ToPostHole +type to_post_arg = ToPostCopy | ToPostAs of int | ToPostHole | ToPostCheck of GlobRef.t type ('target, 'warning) prim_token_notation_obj = { to_kind : 'target conversion_kind; to_ty : GlobRef.t; @@ -620,11 +622,11 @@ let constr_of_globref allow_constant env sigma = function sigma,mkConstU c | _ -> raise NotAValidPrimToken -let rec constr_of_glob to_post post env sigma g = match DAst.get g with +let rec constr_of_glob allow_constant to_post post env sigma g = match DAst.get g with | Glob_term.GRef (r, _) -> let o = List.find_opt (fun (_,r',_) -> GlobRef.equal r r') post in begin match o with - | None -> constr_of_globref false env sigma r + | None -> constr_of_globref allow_constant env sigma r | Some (r, _, a) -> (* [g] is not a GApp so check that [post] does not expect any actual argument @@ -638,19 +640,26 @@ let rec constr_of_glob to_post post env sigma g = match DAst.get g with | _ -> None in begin match o with | None -> - let sigma,c = constr_of_glob to_post post env sigma gc in - let sigma,cl = List.fold_left_map (constr_of_glob to_post post env) sigma gcl in + let sigma,c = constr_of_glob allow_constant to_post post env sigma gc in + let sigma,cl = List.fold_left_map (constr_of_glob allow_constant to_post post env) sigma gcl in sigma,mkApp (c, Array.of_list cl) | Some (r, _, a) -> let sigma,c = constr_of_globref true env sigma r in let rec aux sigma a gcl = match a, gcl with | [], [] -> sigma,[] | ToPostCopy :: a, gc :: gcl -> - let sigma,c = constr_of_glob [||] [] env sigma gc in + let sigma,c = constr_of_glob allow_constant [||] [] env sigma gc in + let sigma,cl = aux sigma a gcl in + sigma, c :: cl + | ToPostCheck r :: a, gc :: gcl -> + let () = match DAst.get gc with + | Glob_term.GRef (r', _) when GlobRef.equal r r' -> () + | _ -> raise NotAValidPrimToken in + let sigma,c = constr_of_glob true [||] [] env sigma gc in let sigma,cl = aux sigma a gcl in sigma, c :: cl | ToPostAs i :: a, gc :: gcl -> - let sigma,c = constr_of_glob to_post to_post.(i) env sigma gc in + let sigma,c = constr_of_glob allow_constant to_post to_post.(i) env sigma gc in let sigma,cl = aux sigma a gcl in sigma, c :: cl | ToPostHole :: post, _ :: gcl -> aux sigma post gcl @@ -668,7 +677,7 @@ let rec constr_of_glob to_post post env sigma g = match DAst.get g with let constr_of_glob to_post env sigma (Glob_term.AnyGlobConstr g) = let post = match to_post with [||] -> [] | _ -> to_post.(0) in - constr_of_glob to_post post env sigma g + constr_of_glob false to_post post env sigma g let rec glob_of_constr token_kind ?loc env sigma c = match Constr.kind c with | App (c, ca) -> @@ -705,7 +714,7 @@ let rec postprocess token_kind ?loc ty to_post post g = let e = Evar_kinds.ImplicitArg (r, (n, None), true) in let h = DAst.make ?loc (Glob_term.GHole (e, Namegen.IntroAnonymous, None)) in h :: f (n+1) a gl - | ToPostCopy :: a, g :: gl -> g :: f (n+1) a gl + | (ToPostCopy | ToPostCheck _) :: a, g :: gl -> g :: f (n+1) a gl | ToPostAs c :: a, g :: gl -> postprocess token_kind ?loc ty to_post to_post.(c) g :: f (n+1) a gl | [], _::_ | _::_, [] -> diff --git a/interp/notation.mli b/interp/notation.mli index 012aaac8f0..acca7b262b 100644 --- a/interp/notation.mli +++ b/interp/notation.mli @@ -166,8 +166,10 @@ type 'target conversion_kind = 'target * option_kind argument is recursively translated according to [l_k]. [ToPostHole] introduces an additional implicit argument hole (in the reverse translation, the corresponding argument is removed). + [ToPostCheck r] behaves as [ToPostCopy] except in the reverse + translation which fails if the copied term is not [r]. When [n] is null, no translation is performed. *) -type to_post_arg = ToPostCopy | ToPostAs of int | ToPostHole +type to_post_arg = ToPostCopy | ToPostAs of int | ToPostHole | ToPostCheck of GlobRef.t type ('target, 'warning) prim_token_notation_obj = { to_kind : 'target conversion_kind; to_ty : GlobRef.t; diff --git a/plugins/syntax/numeral.ml b/plugins/syntax/numeral.ml index 1efe6b77d1..89d757a72a 100644 --- a/plugins/syntax/numeral.ml +++ b/plugins/syntax/numeral.ml @@ -136,6 +136,11 @@ let warn_deprecated_decimal = Decimal.int or Decimal.decimal. Use Number.uint, \ Number.int or Number.number respectively.") +let error_params ind = + CErrors.user_err + (str "Wrong number of parameters for inductive" ++ spc () + ++ Printer.pr_global (GlobRef.IndRef ind) ++ str ".") + let remapping_error ?loc ty ty' ty'' = CErrors.user_err ?loc (Printer.pr_global ty @@ -219,11 +224,43 @@ let get_type env sigma c = List.map (fun (na, a) -> na, EConstr.Unsafe.to_constr a) l, EConstr.Unsafe.to_constr t -(* [elaborate_to_post env sigma ty_name ty_ind l] builds the [to_post] +(* [elaborate_to_post_params env sigma ty_ind params] builds the + [to_post] translation (c.f., interp/notation.mli) for the numeral + notation to parse/print type [ty_ind]. This translation is the + identity ([ToPostCopy]) except that it checks ([ToPostCheck]) that + the parameters of the inductive type [ty_ind] match the ones given + in [params]. *) +let elaborate_to_post_params env sigma ty_ind params = + let to_post_for_constructor indc = + let sigma, c = match indc with + | GlobRef.ConstructRef c -> + let sigma,c = Evd.fresh_constructor_instance env sigma c in + sigma, Constr.mkConstructU c + | _ -> assert false in (* c.f. get_constructors *) + let args, t = get_type env sigma c in + let params_indc = match Constr.kind t with + | Constr.App (_, a) -> Array.to_list a | _ -> [] in + let sz = List.length args in + let a = Array.make sz ToPostCopy in + if List.length params <> List.length params_indc then error_params ty_ind; + List.iter2 (fun param param_indc -> + match param, Constr.kind param_indc with + | Some p, Constr.Rel i when i <= sz -> a.(sz - i) <- ToPostCheck p + | _ -> ()) + params params_indc; + indc, indc, Array.to_list a in + let pt_refs = get_constructors ty_ind in + let to_post_0 = List.map to_post_for_constructor pt_refs in + let to_post = + let only_copy (_, _, args) = List.for_all ((=) ToPostCopy) args in + if (List.for_all only_copy to_post_0) then [||] else [|to_post_0|] in + to_post, pt_refs + +(* [elaborate_to_post_via env sigma ty_name ty_ind l] builds the [to_post] translation (c.f., interp/notation.mli) for the number notation to parse/print type [ty_name] through the inductive [ty_ind] according to the pairs [constant, constructor] in the list [l]. *) -let elaborate_to_post env sigma ty_name ty_ind l = +let elaborate_to_post_via env sigma ty_name ty_ind l = let sigma, ty_name = locate_global_sort_inductive_or_constant sigma ty_name in let ty_ind = Constr.mkInd ty_ind in @@ -344,10 +381,21 @@ let elaborate_to_post env sigma ty_name ty_ind l = let pt_refs = List.map (fun (_, cnst, _) -> cnst) (to_post.(0)) in to_post, pt_refs -let elaborate_to_post env sigma ty_ind via = - match via with - | None -> [||], get_constructors ty_ind - | Some (ty_name, l) -> elaborate_to_post env sigma ty_name ty_ind l +let locate_global_inductive allow_params qid = + let locate_param_inductive qid = + match Nametab.locate_extended qid with + | Globnames.TrueGlobal _ -> raise Not_found + | Globnames.SynDef kn -> + match Syntax_def.search_syntactic_definition kn with + | [], Notation_term.(NApp (NRef (GlobRef.IndRef i), l)) when allow_params -> + i, + List.map (function + | Notation_term.NRef r -> Some r + | Notation_term.NHole _ -> None + | _ -> raise Not_found) l + | _ -> raise Not_found in + try locate_param_inductive qid + with Not_found -> Smartlocate.global_inductive_with_alias qid, [] let vernac_number_notation local ty f g opts scope = let rec parse_opts = function @@ -373,7 +421,7 @@ let vernac_number_notation local ty f g opts scope = let ty_name = ty in let ty, via = match via with None -> ty, via | Some (ty', a) -> ty', Some (ty, a) in - let tyc = Smartlocate.global_inductive_with_alias ty in + let tyc, params = locate_global_inductive (via = None) ty in let to_ty = Smartlocate.global_with_alias f in let of_ty = Smartlocate.global_with_alias g in let cty = mkRefC ty in @@ -437,7 +485,9 @@ let vernac_number_notation local ty f g opts scope = | _, ((DecimalInt _ | DecimalUInt _ | Decimal _), _) -> warn_deprecated_decimal () | _ -> ()); - let to_post, pt_refs = elaborate_to_post env sigma tyc via in + let to_post, pt_refs = match via with + | None -> elaborate_to_post_params env sigma tyc params + | Some (ty, l) -> elaborate_to_post_via env sigma ty tyc l in let o = { to_kind; to_ty; to_post; of_kind; of_ty; ty_name; warning = opts } in diff --git a/test-suite/output/NumberNotations.out b/test-suite/output/NumberNotations.out index 357119f74e..57206772c8 100644 --- a/test-suite/output/NumberNotations.out +++ b/test-suite/output/NumberNotations.out @@ -342,6 +342,58 @@ The term has type "Fin.t (S (S (S (S ?n))))" while it is expected to have type "Fin.t (S (S (S O)))". 0 + : list unit +1 + : list unit +2 + : list unit +2 + : list unit +0 :: 0 :: nil + : list nat +0 + : Ip nat bool +1 + : Ip nat bool +2 + : Ip nat bool +3 + : Ip nat bool +1 + : Ip nat bool +1 + : Ip nat bool +1 + : Ip nat bool +1 + : Ip nat bool +Ip0 nat nat 1 + : Ip nat nat +Ip0 bool bool 1 + : Ip bool bool +Ip1 nat nat 1 + : Ip nat nat +Ip3 1 nat nat + : Ip nat nat +Ip0 nat bool O + : Ip nat bool +Ip1 bool nat (S O) + : Ip nat bool +Ip2 nat (S (S O)) bool + : Ip nat bool +Ip3 (S (S (S O))) nat bool + : Ip nat bool +0 + : 0 = 0 +eq_refl + : 1 = 1 +0 + : 1 = 1 +2 + : extra_list_unit +cons O unit tt (cons O unit tt (nil O unit)) + : extra_list unit +0 : Set 1 : Set diff --git a/test-suite/output/NumberNotations.v b/test-suite/output/NumberNotations.v index bfcad2621a..556cf929b4 100644 --- a/test-suite/output/NumberNotations.v +++ b/test-suite/output/NumberNotations.v @@ -686,6 +686,133 @@ Unset Printing All. End Test24. +(* Test number notations for parameterized inductives *) +Module Test25. + +Definition of_uint (u : Number.uint) : list unit := + let fix f n := + match n with + | O => nil + | S n => cons tt (f n) + end in + f (Nat.of_num_uint u). + +Definition to_uint (l : list unit) : Number.uint := + let fix f n := + match n with + | nil => O + | cons tt l => S (f l) + end in + Nat.to_num_uint (f l). + +Notation listunit := (list unit) (only parsing). +Number Notation listunit of_uint to_uint : nat_scope. + +Check 0. +Check 1. +Check 2. + +Check cons tt (cons tt nil). +Check cons O (cons O nil). (* printer not called on list nat *) + +(* inductive with multiple parameters that are not the first + parameters and not in the same order for each constructor *) +Inductive Ip : Type -> Type -> Type := +| Ip0 : forall T T', nat -> Ip T T' +| Ip1 : forall T' T, nat -> Ip T T' +| Ip2 : forall T, nat -> forall T', Ip T T' +| Ip3 : nat -> forall T T', Ip T T'. + +Definition Ip_of_uint (u : Number.uint) : option (Ip nat bool) := + let f n := + match n with + | O => Some (Ip0 nat bool O) + | S O => Some (Ip1 bool nat (S O)) + | S (S O) => Some (Ip2 nat (S (S O)) bool) + | S (S (S O)) => Some (Ip3 (S (S (S O))) nat bool) + | _ => None + end in + f (Nat.of_num_uint u). + +Definition Ip_to_uint (l : Ip nat bool) : Number.uint := + let f n := + match n with + | Ip0 _ _ n => n + | Ip1 _ _ n => n + | Ip2 _ n _ => n + | Ip3 n _ _ => n + end in + Nat.to_num_uint (f l). + +Notation Ip_nat_bool := (Ip nat bool) (only parsing). +Number Notation Ip_nat_bool Ip_of_uint Ip_to_uint : nat_scope. + +Check 0. +Check 1. +Check 2. +Check 3. +Check Ip0 nat bool (S O). +Check Ip1 bool nat (S O). +Check Ip2 nat (S O) bool. +Check Ip3 (S O) nat bool. +Check Ip0 nat nat (S O). (* not printed *) +Check Ip0 bool bool (S O). (* not printed *) +Check Ip1 nat nat (S O). (* not printed *) +Check Ip3 (S O) nat nat. (* not printed *) +Set Printing All. +Check 0. +Check 1. +Check 2. +Check 3. +Unset Printing All. + +Notation eqO := (eq _ O) (only parsing). +Definition eqO_of_uint (x : Number.uint) : eqO := eq_refl O. +Definition eqO_to_uint (x : O = O) : Number.uint := + match x with + | eq_refl _ => Nat.to_num_uint O + end. +Number Notation eqO eqO_of_uint eqO_to_uint : nat_scope. + +Check 42. +Check eq_refl (S O). (* doesn't match eq _ O, printer not called *) + +Notation eq_ := (eq _ _) (only parsing). +Number Notation eq_ eqO_of_uint eqO_to_uint : nat_scope. + +Check eq_refl (S O). (* matches eq _ _, printer called *) + +Inductive extra_list : Type -> Type := +| nil (n : nat) (v : Type) : extra_list v +| cons (n : nat) (t : Type) (x : t) : extra_list t -> extra_list t. + +Definition extra_list_unit_of_uint (x : Number.uint) : extra_list unit := + let fix f n := + match n with + | O => nil O unit + | S n => cons O unit tt (f n) + end in + f (Nat.of_num_uint x). + +Definition extra_list_unit_to_uint (x : extra_list unit) : Number.uint := + let fix f T (x : extra_list T) := + match x with + | nil _ _ => O + | cons _ T _ x => S (f T x) + end in + Nat.to_num_uint (f unit x). + +Notation extra_list_unit := (extra_list unit). +Number Notation extra_list_unit + extra_list_unit_of_uint extra_list_unit_to_uint : nat_scope. + +Check 2. +Set Printing All. +Check 2. +Unset Printing All. + +End Test25. + (* Test the via ... mapping ... option with let-binders, beta-redexes, delta-redexes, etc *) Module Test26. |
