diff options
| author | Pierre Roux | 2020-09-03 13:12:00 +0200 |
|---|---|---|
| committer | Pierre Roux | 2020-11-04 20:53:47 +0100 |
| commit | 11a8997dd8fa83537607272692a3baf10dab342a (patch) | |
| tree | 2b88f003ab19f264d94f29806c28b48258800d28 /plugins/syntax | |
| parent | dfcb15141a19db4f1cc61c14d1cdad0275009356 (diff) | |
[numeral notation] Adding the via ... using ... option
This enables numeral notations for non inductive types by
pre/postprocessing them to a given proxy inductive type.
For instance, this should enable the use of numeral notations for R.
Diffstat (limited to 'plugins/syntax')
| -rw-r--r-- | plugins/syntax/g_numeral.mlg | 58 | ||||
| -rw-r--r-- | plugins/syntax/numeral.ml | 233 | ||||
| -rw-r--r-- | plugins/syntax/numeral.mli | 11 |
3 files changed, 285 insertions, 17 deletions
diff --git a/plugins/syntax/g_numeral.mlg b/plugins/syntax/g_numeral.mlg index 48e262c3ef..e60ae45b01 100644 --- a/plugins/syntax/g_numeral.mlg +++ b/plugins/syntax/g_numeral.mlg @@ -19,33 +19,71 @@ open Names open Stdarg open Pcoq.Prim -let pr_number_modifier = function +let pr_number_after = function | Nop -> mt () - | Warning n -> str "(warning after " ++ NumTok.UnsignedNat.print n ++ str ")" - | Abstract n -> str "(abstract after " ++ NumTok.UnsignedNat.print n ++ str ")" + | Warning n -> str "warning after " ++ NumTok.UnsignedNat.print n + | Abstract n -> str "abstract after " ++ NumTok.UnsignedNat.print n + +let pr_deprecated_number_modifier m = str "(" ++ pr_number_after m ++ str ")" let warn_deprecated_numeral_notation = CWarnings.create ~name:"numeral-notation" ~category:"deprecated" (fun () -> strbrk "Numeral Notation is deprecated, please use Number Notation instead.") +let pr_number_mapping (n, n') = + Libnames.pr_qualid n ++ spc () ++ str "=>" ++ spc () ++ Libnames.pr_qualid n' + +let pr_number_via (n, l) = + str "via " ++ Libnames.pr_qualid n ++ str " mapping [" + ++ prlist_with_sep pr_comma pr_number_mapping l ++ str "]" + +let pr_number_modifier = function + | After a -> pr_number_after a + | Via nl -> pr_number_via nl + +let pr_number_options l = + str "(" ++ prlist_with_sep pr_comma pr_number_modifier l ++ str ")" + } -VERNAC ARGUMENT EXTEND number_modifier - PRINTED BY { pr_number_modifier } +VERNAC ARGUMENT EXTEND deprecated_number_modifier + PRINTED BY { pr_deprecated_number_modifier } | [ ] -> { Nop } | [ "(" "warning" "after" bignat(waft) ")" ] -> { Warning (NumTok.UnsignedNat.of_string waft) } | [ "(" "abstract" "after" bignat(n) ")" ] -> { Abstract (NumTok.UnsignedNat.of_string n) } END +VERNAC ARGUMENT EXTEND number_mapping + PRINTED BY { pr_number_mapping } +| [ reference(n) "=>" reference(n') ] -> { n, n' } +END + +VERNAC ARGUMENT EXTEND number_via + PRINTED BY { pr_number_via } +| [ "via" reference(n) "mapping" "[" ne_number_mapping_list_sep(l, ",") "]" ] -> { n, l } +END + +VERNAC ARGUMENT EXTEND number_modifier + PRINTED BY { pr_number_modifier } +| [ "warning" "after" bignat(waft) ] -> { After (Warning (NumTok.UnsignedNat.of_string waft)) } +| [ "abstract" "after" bignat(n) ] -> { After (Abstract (NumTok.UnsignedNat.of_string n)) } +| [ number_via(v) ] -> { Via v } +END + +VERNAC ARGUMENT EXTEND number_options + PRINTED BY { pr_number_options } +| [ "(" ne_number_modifier_list_sep(l, ",") ")" ] -> { l } +END + VERNAC COMMAND EXTEND NumberNotation CLASSIFIED AS SIDEFF - | #[ locality = Attributes.locality; ] [ "Number" "Notation" reference(ty) reference(f) reference(g) ":" - ident(sc) number_modifier(o) ] -> + | #[ locality = Attributes.locality; ] [ "Number" "Notation" reference(ty) reference(f) reference(g) number_options_opt(nl) ":" + ident(sc) ] -> - { vernac_number_notation (Locality.make_module_locality locality) ty f g (Id.to_string sc) o } + { vernac_number_notation (Locality.make_module_locality locality) ty f g (Option.default [] nl) (Id.to_string sc) } | #[ locality = Attributes.locality; ] [ "Numeral" "Notation" reference(ty) reference(f) reference(g) ":" - ident(sc) number_modifier(o) ] -> + ident(sc) deprecated_number_modifier(o) ] -> { warn_deprecated_numeral_notation (); - vernac_number_notation (Locality.make_module_locality locality) ty f g (Id.to_string sc) o } + vernac_number_notation (Locality.make_module_locality locality) ty f g [After o] (Id.to_string sc) } END diff --git a/plugins/syntax/numeral.ml b/plugins/syntax/numeral.ml index ad90a9a982..316ca456a4 100644 --- a/plugins/syntax/numeral.ml +++ b/plugins/syntax/numeral.ml @@ -16,8 +16,16 @@ open Constrexpr open Constrexpr_ops open Notation +module CSet = CSet.Make (Constr) +module CMap = CMap.Make (Constr) + (** * Number notation *) +type number_string_via = qualid * (qualid * qualid) list +type number_option = + | After of numnot_option + | Via of number_string_via + let warn_abstract_large_num_no_op = CWarnings.create ~name:"abstract-large-number-no-op" ~category:"numbers" (fun f -> @@ -128,12 +136,228 @@ let warn_deprecated_decimal = Decimal.int or Decimal.decimal. Use Number.uint, \ Number.int or Number.number respectively.") -let vernac_number_notation local ty f g scope opts = +let remapping_error ?loc ty ty' ty'' = + CErrors.user_err ?loc + (Printer.pr_global ty + ++ str " was already mapped to" ++ spc () ++ Printer.pr_global ty' + ++ str " and cannot be remapped to" ++ spc () ++ Printer.pr_global ty'' + ++ str ".") + +let error_missing c = + CErrors.user_err + (str "Missing mapping for constructor " ++ Printer.pr_global c ++ str ".") + +let pr_constr env sigma c = + let c = Constrextern.extern_constr env sigma (EConstr.of_constr c) in + Ppconstr.pr_constr_expr env sigma c + +let warn_via_remapping = + CWarnings.create ~name:"via-type-remapping" ~category:"numbers" + (fun (env, sigma, ty, ty', ty'') -> + let constr = pr_constr env sigma in + constr ty ++ str " was already mapped to" ++ spc () ++ constr ty' + ++ str ", mapping it also to" ++ spc () ++ constr ty'' + ++ str " might yield ill typed terms when using the notation.") + +let warn_via_type_mismatch = + CWarnings.create ~name:"via-type-mismatch" ~category:"numbers" + (fun (env, sigma, g, g', exp, actual) -> + let constr = pr_constr env sigma in + str "Type of" ++ spc() ++ Printer.pr_global g + ++ str " seems incompatible with the type of" ++ spc () + ++ Printer.pr_global g' ++ str "." ++ spc () + ++ str "Expected type is: " ++ constr exp ++ spc () + ++ str "instead of " ++ constr actual ++ str "." ++ spc () + ++ str "This might yield ill typed terms when using the notation.") + +let multiple_via_error () = + CErrors.user_err (Pp.str "Multiple 'via' options.") + +let multiple_after_error () = + CErrors.user_err (Pp.str "Multiple 'warning after' or 'abstract after' options.") + +let via_abstract_error () = + CErrors.user_err (Pp.str "'via' and 'abstract' cannot be used together.") + +let locate_global_sort_inductive_or_constant sigma qid = + let locate_sort qid = + match Nametab.locate_extended qid with + | Globnames.TrueGlobal _ -> raise Not_found + | Globnames.SynDef kn -> + match Syntax_def.search_syntactic_definition kn with + | [], Notation_term.NSort r -> + let sigma,c = Evd.fresh_sort_in_family sigma (Glob_ops.glob_sort_family r) in + sigma,Constr.mkSort c + | _ -> raise Not_found in + try locate_sort qid + with Not_found -> + match Smartlocate.global_with_alias qid with + | GlobRef.IndRef i -> sigma, Constr.mkInd i + | _ -> sigma, Constr.mkConst (Smartlocate.global_constant_with_alias qid) + +let locate_global_constructor_inductive_or_constant qid = + let g = Smartlocate.global_with_alias qid in + match g with + | GlobRef.ConstructRef c -> g, Constr.mkConstruct c + | GlobRef.IndRef i -> g, Constr.mkInd i + | _ -> g, Constr.mkConst (Smartlocate.global_constant_with_alias qid) + +(* [get_type env sigma c] retrieves the type of [c] and returns a pair + [l, t] such that [c : l_0 -> ... -> l_n -> t]. *) +let get_type env sigma c = + (* inspired from [compute_implicit_names] in "interp/impargs.ml" *) + let rec aux env acc t = + let t = Reductionops.whd_all env sigma t in + match EConstr.kind sigma t with + | Constr.Prod (na, a, b) -> + let a = Reductionops.whd_all env sigma a in + let rel = Context.Rel.Declaration.LocalAssum (na, a) in + aux (EConstr.push_rel rel env) ((na, a) :: acc) b + | _ -> List.rev acc, t in + let t = Retyping.get_type_of env sigma (EConstr.of_constr c) in + let l, t = aux env [] t in + List.map (fun (na, a) -> na, EConstr.Unsafe.to_constr a) l, + EConstr.Unsafe.to_constr t + +(* [elaborate_to_post env sigma ty_name ty_ind l] builds the [to_post] + translation (c.f., interp/notation.mli) for the number notation to + parse/print type [ty_name] through the inductive [ty_ind] according + to the pairs [constant, constructor] in the list [l]. *) +let elaborate_to_post env sigma ty_name ty_ind l = + let sigma, ty_name = + locate_global_sort_inductive_or_constant sigma ty_name in + let ty_ind = Constr.mkInd ty_ind in + (* Retrieve constants and constructors mappings and their type. + For each constant [cnst] and inductive constructor [indc] in [l], retrieve: + * its location: [lcnst] and [lindc] + * its GlobRef: [cnst] and [indc] + * its type: [tcnst] and [tindc] (decomposed in product by [get_type] above) *) + let l = + let read (cnst, indc) = + let lcnst, lindc = cnst.CAst.loc, indc.CAst.loc in + let cnst, ccnst = locate_global_constructor_inductive_or_constant cnst in + let indc, cindc = + let indc = Smartlocate.global_constructor_with_alias indc in + GlobRef.ConstructRef indc, Constr.mkConstruct indc in + let get_type_wo_params c = + (* ignore parameters of inductive types *) + let rm_params c = match Constr.kind c with + | Constr.App (c, _) when Constr.isInd c -> c + | _ -> c in + let lc, tc = get_type env sigma c in + List.map (fun (n, c) -> n, rm_params c) lc, rm_params tc in + let tcnst, tindc = get_type_wo_params ccnst, get_type_wo_params cindc in + lcnst, cnst, tcnst, lindc, indc, tindc in + List.map read l in + let eq_indc indc (_, _, _, _, indc', _) = GlobRef.equal indc indc' in + (* Collect all inductive types involved. + That is [ty_ind] and all final codomains of [tindc] above. *) + let inds = + List.fold_left (fun s (_, _, _, _, _, tindc) -> CSet.add (snd tindc) s) + (CSet.singleton ty_ind) l in + (* And for each inductive, retrieve its constructors. *) + let constructors = + CSet.fold (fun ind m -> + let inductive, _ = Constr.destInd ind in + CMap.add ind (get_constructors inductive) m) + inds CMap.empty in + (* Error if one [constructor] in some inductive in [inds] + doesn't appear exactly once in [l] *) + let _ = (* check_for duplicate constructor and error *) + List.fold_left (fun already_seen (_, cnst, _, loc, indc, _) -> + try + let cnst' = List.assoc_f GlobRef.equal indc already_seen in + remapping_error ?loc indc cnst' cnst + with Not_found -> (indc, cnst) :: already_seen) + [] l in + let () = (* check for missing constructor and error *) + CMap.iter (fun _ -> + List.iter (fun cstr -> + if not (List.exists (eq_indc cstr) l) then error_missing cstr)) + constructors in + (* Perform some checks on types and warn if they look strange. + These checks are neither sound nor complete, so we only warn. *) + let () = + (* associate inductives to types, and check that this mapping is one to one + and maps [ty_ind] to [ty_name] *) + let ind2ty, ty2ind = + let add loc ckey cval m = + match CMap.find_opt ckey m with + | None -> CMap.add ckey cval m + | Some old_cval -> + if not (Constr.equal old_cval cval) then + warn_via_remapping ?loc (env, sigma, ckey, old_cval, cval); + m in + List.fold_left + (fun (ind2ty, ty2ind) (lcnst, _, (_, tcnst), lindc, _, (_, tindc)) -> + add lcnst tindc tcnst ind2ty, add lindc tcnst tindc ty2ind) + CMap.(singleton ty_ind ty_name, singleton ty_name ty_ind) l in + (* check that type of constants and constructors mapped in [l] + match modulo [ind2ty] *) + let replace m (l, t) = + let apply_m c = try CMap.find c m with Not_found -> c in + List.fold_right (fun (na, a) b -> Constr.mkProd (na, (apply_m a), b)) + l (apply_m t) in + List.iter (fun (_, cnst, tcnst, loc, indc, tindc) -> + let tcnst' = replace CMap.empty tcnst in + if not (Constr.equal tcnst' (replace ind2ty tindc)) then + let actual = replace CMap.empty tindc in + let expected = replace ty2ind tcnst in + warn_via_type_mismatch ?loc (env, sigma, indc, cnst, expected, actual)) + l in + (* Associate an index to each inductive, starting from 0 for [ty_ind]. *) + let ind2num, num2ind, nb_ind = + CMap.fold (fun ind _ (ind2num, num2ind, i) -> + CMap.add ind i ind2num, Int.Map.add i ind num2ind, i + 1) + (CMap.remove ty_ind constructors) + (CMap.singleton ty_ind 0, Int.Map.singleton 0 ty_ind, 1) in + (* Finally elaborate [to_post] *) + let to_post = + let rec map_prod = function + | [] -> [] + | (_, a) :: b -> + let t = match CMap.find_opt a ind2num with + | Some i -> ToPostAs i + | None -> ToPostCopy in + t :: map_prod b in + Array.init nb_ind (fun i -> + List.map (fun indc -> + let _, cnst, _, _, _, tindc = List.find (eq_indc indc) l in + indc, cnst, map_prod (fst tindc)) + (CMap.find (Int.Map.find i num2ind) constructors)) in + (* and use constants mapped to constructors of [ty_ind] as triggers. *) + let pt_refs = List.map (fun (_, cnst, _) -> cnst) (to_post.(0)) in + to_post, pt_refs + +let elaborate_to_post env sigma ty_ind via = + match via with + | None -> [||], get_constructors ty_ind + | Some (ty_name, l) -> elaborate_to_post env sigma ty_name ty_ind l + +let vernac_number_notation local ty f g opts scope = + let rec parse_opts = function + | [] -> None, Nop + | h :: opts -> + let via, opts = parse_opts opts in + let via = match h, via with + | Via _, Some _ -> multiple_via_error () + | Via v, None -> Some v + | _ -> via in + let opts = match h, opts with + | After _, (Warning _ | Abstract _) -> multiple_after_error () + | After a, Nop -> a + | _ -> opts in + via, opts in + let via, opts = parse_opts opts in + (match via, opts with Some _, Abstract _ -> via_abstract_error () | _ -> ()); let env = Global.env () in let sigma = Evd.from_env env in let num_ty = locate_number () in let z_pos_ty = locate_z () in let int63_ty = locate_int63 () in + let ty_name = ty in + let ty, via = + match via with None -> ty, via | Some (ty', a) -> ty', Some (ty, a) in let tyc = Smartlocate.global_inductive_with_alias ty in let to_ty = Smartlocate.global_with_alias f in let of_ty = Smartlocate.global_with_alias g in @@ -143,7 +367,6 @@ let vernac_number_notation local ty f g scope opts = mkProdC ([CAst.make Anonymous],Default Glob_term.Explicit, x, y) in let opt r = app (mkRefC (q_option ())) r in - let constructors = get_constructors tyc in (* Check the type of f *) let to_kind = match num_ty with @@ -199,8 +422,8 @@ let vernac_number_notation local ty f g scope opts = | _, ((DecimalInt _ | DecimalUInt _ | Decimal _), _) -> warn_deprecated_decimal () | _ -> ()); - let o = { to_kind; to_ty; to_post = [||]; of_kind; of_ty; - ty_name = ty; + let to_post, pt_refs = elaborate_to_post env sigma tyc via in + let o = { to_kind; to_ty; to_post; of_kind; of_ty; ty_name; warning = opts } in (match opts, to_kind with @@ -211,7 +434,7 @@ let vernac_number_notation local ty f g scope opts = pt_scope = scope; pt_interp_info = NumberNotation o; pt_required = Nametab.path_of_global (GlobRef.IndRef tyc),[]; - pt_refs = constructors; + pt_refs; pt_in_match = true } in enable_prim_token_interpretation i diff --git a/plugins/syntax/numeral.mli b/plugins/syntax/numeral.mli index d5fe42b0b4..1f6896d549 100644 --- a/plugins/syntax/numeral.mli +++ b/plugins/syntax/numeral.mli @@ -14,6 +14,13 @@ open Notation (** * Number notation *) +type number_string_via = qualid * (qualid * qualid) list +type number_option = + | After of numnot_option + | Via of number_string_via + val vernac_number_notation : locality_flag -> - qualid -> qualid -> qualid -> - Notation_term.scope_name -> numnot_option -> unit + qualid -> + qualid -> qualid -> + number_option list -> + Notation_term.scope_name -> unit |
