aboutsummaryrefslogtreecommitdiff
path: root/plugins
diff options
context:
space:
mode:
authorPierre Roux2020-09-03 13:25:00 +0200
committerPierre Roux2020-11-05 00:20:19 +0100
commite728a1ef0f8b5fdc4b1815a7d0349c67db15f9b4 (patch)
tree2a809813e374246465eb693bf444bffab25fd13c /plugins
parent036117fa4992debb42e8346a48f6259f504793d3 (diff)
[numeral notation] Add support for parameterized inductives
Diffstat (limited to 'plugins')
-rw-r--r--plugins/syntax/numeral.ml66
1 files changed, 58 insertions, 8 deletions
diff --git a/plugins/syntax/numeral.ml b/plugins/syntax/numeral.ml
index 1efe6b77d1..89d757a72a 100644
--- a/plugins/syntax/numeral.ml
+++ b/plugins/syntax/numeral.ml
@@ -136,6 +136,11 @@ let warn_deprecated_decimal =
Decimal.int or Decimal.decimal. Use Number.uint, \
Number.int or Number.number respectively.")
+let error_params ind =
+ CErrors.user_err
+ (str "Wrong number of parameters for inductive" ++ spc ()
+ ++ Printer.pr_global (GlobRef.IndRef ind) ++ str ".")
+
let remapping_error ?loc ty ty' ty'' =
CErrors.user_err ?loc
(Printer.pr_global ty
@@ -219,11 +224,43 @@ let get_type env sigma c =
List.map (fun (na, a) -> na, EConstr.Unsafe.to_constr a) l,
EConstr.Unsafe.to_constr t
-(* [elaborate_to_post env sigma ty_name ty_ind l] builds the [to_post]
+(* [elaborate_to_post_params env sigma ty_ind params] builds the
+ [to_post] translation (c.f., interp/notation.mli) for the numeral
+ notation to parse/print type [ty_ind]. This translation is the
+ identity ([ToPostCopy]) except that it checks ([ToPostCheck]) that
+ the parameters of the inductive type [ty_ind] match the ones given
+ in [params]. *)
+let elaborate_to_post_params env sigma ty_ind params =
+ let to_post_for_constructor indc =
+ let sigma, c = match indc with
+ | GlobRef.ConstructRef c ->
+ let sigma,c = Evd.fresh_constructor_instance env sigma c in
+ sigma, Constr.mkConstructU c
+ | _ -> assert false in (* c.f. get_constructors *)
+ let args, t = get_type env sigma c in
+ let params_indc = match Constr.kind t with
+ | Constr.App (_, a) -> Array.to_list a | _ -> [] in
+ let sz = List.length args in
+ let a = Array.make sz ToPostCopy in
+ if List.length params <> List.length params_indc then error_params ty_ind;
+ List.iter2 (fun param param_indc ->
+ match param, Constr.kind param_indc with
+ | Some p, Constr.Rel i when i <= sz -> a.(sz - i) <- ToPostCheck p
+ | _ -> ())
+ params params_indc;
+ indc, indc, Array.to_list a in
+ let pt_refs = get_constructors ty_ind in
+ let to_post_0 = List.map to_post_for_constructor pt_refs in
+ let to_post =
+ let only_copy (_, _, args) = List.for_all ((=) ToPostCopy) args in
+ if (List.for_all only_copy to_post_0) then [||] else [|to_post_0|] in
+ to_post, pt_refs
+
+(* [elaborate_to_post_via env sigma ty_name ty_ind l] builds the [to_post]
translation (c.f., interp/notation.mli) for the number notation to
parse/print type [ty_name] through the inductive [ty_ind] according
to the pairs [constant, constructor] in the list [l]. *)
-let elaborate_to_post env sigma ty_name ty_ind l =
+let elaborate_to_post_via env sigma ty_name ty_ind l =
let sigma, ty_name =
locate_global_sort_inductive_or_constant sigma ty_name in
let ty_ind = Constr.mkInd ty_ind in
@@ -344,10 +381,21 @@ let elaborate_to_post env sigma ty_name ty_ind l =
let pt_refs = List.map (fun (_, cnst, _) -> cnst) (to_post.(0)) in
to_post, pt_refs
-let elaborate_to_post env sigma ty_ind via =
- match via with
- | None -> [||], get_constructors ty_ind
- | Some (ty_name, l) -> elaborate_to_post env sigma ty_name ty_ind l
+let locate_global_inductive allow_params qid =
+ let locate_param_inductive qid =
+ match Nametab.locate_extended qid with
+ | Globnames.TrueGlobal _ -> raise Not_found
+ | Globnames.SynDef kn ->
+ match Syntax_def.search_syntactic_definition kn with
+ | [], Notation_term.(NApp (NRef (GlobRef.IndRef i), l)) when allow_params ->
+ i,
+ List.map (function
+ | Notation_term.NRef r -> Some r
+ | Notation_term.NHole _ -> None
+ | _ -> raise Not_found) l
+ | _ -> raise Not_found in
+ try locate_param_inductive qid
+ with Not_found -> Smartlocate.global_inductive_with_alias qid, []
let vernac_number_notation local ty f g opts scope =
let rec parse_opts = function
@@ -373,7 +421,7 @@ let vernac_number_notation local ty f g opts scope =
let ty_name = ty in
let ty, via =
match via with None -> ty, via | Some (ty', a) -> ty', Some (ty, a) in
- let tyc = Smartlocate.global_inductive_with_alias ty in
+ let tyc, params = locate_global_inductive (via = None) ty in
let to_ty = Smartlocate.global_with_alias f in
let of_ty = Smartlocate.global_with_alias g in
let cty = mkRefC ty in
@@ -437,7 +485,9 @@ let vernac_number_notation local ty f g opts scope =
| _, ((DecimalInt _ | DecimalUInt _ | Decimal _), _) ->
warn_deprecated_decimal ()
| _ -> ());
- let to_post, pt_refs = elaborate_to_post env sigma tyc via in
+ let to_post, pt_refs = match via with
+ | None -> elaborate_to_post_params env sigma tyc params
+ | Some (ty, l) -> elaborate_to_post_via env sigma ty tyc l in
let o = { to_kind; to_ty; to_post; of_kind; of_ty; ty_name;
warning = opts }
in