diff options
| author | Pierre Roux | 2020-09-03 13:25:00 +0200 |
|---|---|---|
| committer | Pierre Roux | 2020-11-05 00:20:19 +0100 |
| commit | e728a1ef0f8b5fdc4b1815a7d0349c67db15f9b4 (patch) | |
| tree | 2a809813e374246465eb693bf444bffab25fd13c /plugins/syntax | |
| parent | 036117fa4992debb42e8346a48f6259f504793d3 (diff) | |
[numeral notation] Add support for parameterized inductives
Diffstat (limited to 'plugins/syntax')
| -rw-r--r-- | plugins/syntax/numeral.ml | 66 |
1 files changed, 58 insertions, 8 deletions
diff --git a/plugins/syntax/numeral.ml b/plugins/syntax/numeral.ml index 1efe6b77d1..89d757a72a 100644 --- a/plugins/syntax/numeral.ml +++ b/plugins/syntax/numeral.ml @@ -136,6 +136,11 @@ let warn_deprecated_decimal = Decimal.int or Decimal.decimal. Use Number.uint, \ Number.int or Number.number respectively.") +let error_params ind = + CErrors.user_err + (str "Wrong number of parameters for inductive" ++ spc () + ++ Printer.pr_global (GlobRef.IndRef ind) ++ str ".") + let remapping_error ?loc ty ty' ty'' = CErrors.user_err ?loc (Printer.pr_global ty @@ -219,11 +224,43 @@ let get_type env sigma c = List.map (fun (na, a) -> na, EConstr.Unsafe.to_constr a) l, EConstr.Unsafe.to_constr t -(* [elaborate_to_post env sigma ty_name ty_ind l] builds the [to_post] +(* [elaborate_to_post_params env sigma ty_ind params] builds the + [to_post] translation (c.f., interp/notation.mli) for the numeral + notation to parse/print type [ty_ind]. This translation is the + identity ([ToPostCopy]) except that it checks ([ToPostCheck]) that + the parameters of the inductive type [ty_ind] match the ones given + in [params]. *) +let elaborate_to_post_params env sigma ty_ind params = + let to_post_for_constructor indc = + let sigma, c = match indc with + | GlobRef.ConstructRef c -> + let sigma,c = Evd.fresh_constructor_instance env sigma c in + sigma, Constr.mkConstructU c + | _ -> assert false in (* c.f. get_constructors *) + let args, t = get_type env sigma c in + let params_indc = match Constr.kind t with + | Constr.App (_, a) -> Array.to_list a | _ -> [] in + let sz = List.length args in + let a = Array.make sz ToPostCopy in + if List.length params <> List.length params_indc then error_params ty_ind; + List.iter2 (fun param param_indc -> + match param, Constr.kind param_indc with + | Some p, Constr.Rel i when i <= sz -> a.(sz - i) <- ToPostCheck p + | _ -> ()) + params params_indc; + indc, indc, Array.to_list a in + let pt_refs = get_constructors ty_ind in + let to_post_0 = List.map to_post_for_constructor pt_refs in + let to_post = + let only_copy (_, _, args) = List.for_all ((=) ToPostCopy) args in + if (List.for_all only_copy to_post_0) then [||] else [|to_post_0|] in + to_post, pt_refs + +(* [elaborate_to_post_via env sigma ty_name ty_ind l] builds the [to_post] translation (c.f., interp/notation.mli) for the number notation to parse/print type [ty_name] through the inductive [ty_ind] according to the pairs [constant, constructor] in the list [l]. *) -let elaborate_to_post env sigma ty_name ty_ind l = +let elaborate_to_post_via env sigma ty_name ty_ind l = let sigma, ty_name = locate_global_sort_inductive_or_constant sigma ty_name in let ty_ind = Constr.mkInd ty_ind in @@ -344,10 +381,21 @@ let elaborate_to_post env sigma ty_name ty_ind l = let pt_refs = List.map (fun (_, cnst, _) -> cnst) (to_post.(0)) in to_post, pt_refs -let elaborate_to_post env sigma ty_ind via = - match via with - | None -> [||], get_constructors ty_ind - | Some (ty_name, l) -> elaborate_to_post env sigma ty_name ty_ind l +let locate_global_inductive allow_params qid = + let locate_param_inductive qid = + match Nametab.locate_extended qid with + | Globnames.TrueGlobal _ -> raise Not_found + | Globnames.SynDef kn -> + match Syntax_def.search_syntactic_definition kn with + | [], Notation_term.(NApp (NRef (GlobRef.IndRef i), l)) when allow_params -> + i, + List.map (function + | Notation_term.NRef r -> Some r + | Notation_term.NHole _ -> None + | _ -> raise Not_found) l + | _ -> raise Not_found in + try locate_param_inductive qid + with Not_found -> Smartlocate.global_inductive_with_alias qid, [] let vernac_number_notation local ty f g opts scope = let rec parse_opts = function @@ -373,7 +421,7 @@ let vernac_number_notation local ty f g opts scope = let ty_name = ty in let ty, via = match via with None -> ty, via | Some (ty', a) -> ty', Some (ty, a) in - let tyc = Smartlocate.global_inductive_with_alias ty in + let tyc, params = locate_global_inductive (via = None) ty in let to_ty = Smartlocate.global_with_alias f in let of_ty = Smartlocate.global_with_alias g in let cty = mkRefC ty in @@ -437,7 +485,9 @@ let vernac_number_notation local ty f g opts scope = | _, ((DecimalInt _ | DecimalUInt _ | Decimal _), _) -> warn_deprecated_decimal () | _ -> ()); - let to_post, pt_refs = elaborate_to_post env sigma tyc via in + let to_post, pt_refs = match via with + | None -> elaborate_to_post_params env sigma tyc params + | Some (ty, l) -> elaborate_to_post_via env sigma ty tyc l in let o = { to_kind; to_ty; to_post; of_kind; of_ty; ty_name; warning = opts } in |
