aboutsummaryrefslogtreecommitdiff
path: root/kernel
diff options
context:
space:
mode:
Diffstat (limited to 'kernel')
-rw-r--r--kernel/cClosure.ml8
-rw-r--r--kernel/context.ml9
-rw-r--r--kernel/context.mli3
-rw-r--r--kernel/entries.ml9
-rw-r--r--kernel/environ.ml5
-rw-r--r--kernel/environ.mli2
-rw-r--r--kernel/indTyping.ml11
-rw-r--r--kernel/inferCumulativity.ml109
-rw-r--r--kernel/inferCumulativity.mli4
-rw-r--r--kernel/reduction.ml49
-rw-r--r--kernel/type_errors.ml5
-rw-r--r--kernel/type_errors.mli3
12 files changed, 147 insertions, 70 deletions
diff --git a/kernel/cClosure.ml b/kernel/cClosure.ml
index 174125fc57..17feeb9b5a 100644
--- a/kernel/cClosure.ml
+++ b/kernel/cClosure.ml
@@ -1098,14 +1098,8 @@ module FNativeEntries =
let defined_array = ref false
- let farray = ref dummy
-
let init_array retro =
- match retro.Retroknowledge.retro_array with
- | Some c ->
- defined_array := true;
- farray := { mark = mark Norm KnownR; term = FFlex (ConstKey (Univ.in_punivs c)) }
- | None -> defined_array := false
+ defined_array := Option.has_some retro.Retroknowledge.retro_array
let init env =
current_retro := env.retroknowledge;
diff --git a/kernel/context.ml b/kernel/context.ml
index 6a99f201f3..ab66898b59 100644
--- a/kernel/context.ml
+++ b/kernel/context.ml
@@ -365,6 +365,15 @@ struct
let ty' = f ty in
if v == v' && ty == ty' then decl else LocalDef (id, v', ty')
+ let map_constr_het f = function
+ | LocalAssum (id, ty) ->
+ let ty' = f ty in
+ LocalAssum (id, ty')
+ | LocalDef (id, v, ty) ->
+ let v' = f v in
+ let ty' = f ty in
+ LocalDef (id, v', ty')
+
(** Perform a given action on all terms in a given declaration. *)
let iter_constr f = function
| LocalAssum (_, ty) -> f ty
diff --git a/kernel/context.mli b/kernel/context.mli
index 76c4461760..29309daf34 100644
--- a/kernel/context.mli
+++ b/kernel/context.mli
@@ -231,6 +231,9 @@ sig
(** Map all terms in a given declaration. *)
val map_constr : ('c -> 'c) -> ('c, 'c) pt -> ('c, 'c) pt
+ (** Map all terms, with an heterogeneous function. *)
+ val map_constr_het : ('a -> 'b) -> ('a, 'a) pt -> ('b, 'b) pt
+
(** Perform a given action on all terms in a given declaration. *)
val iter_constr : ('c -> unit) -> ('c, 'c) pt -> unit
diff --git a/kernel/entries.ml b/kernel/entries.ml
index ae64112e33..1bfc740017 100644
--- a/kernel/entries.ml
+++ b/kernel/entries.ml
@@ -20,6 +20,8 @@ type universes_entry =
| Monomorphic_entry of Univ.ContextSet.t
| Polymorphic_entry of Name.t array * Univ.UContext.t
+type variance_entry = Univ.Variance.t option array
+
type 'a in_universes_entry = 'a * universes_entry
(** {6 Declaration of inductive types. } *)
@@ -50,9 +52,10 @@ type mutual_inductive_entry = {
mind_entry_inds : one_inductive_entry list;
mind_entry_universes : universes_entry;
mind_entry_template : bool; (* Use template polymorphism *)
- mind_entry_cumulative : bool;
- (* universe constraints and the constraints for subtyping of
- inductive types in the block. *)
+ mind_entry_variance : variance_entry option;
+ (* [None] if non-cumulative, otherwise associates each universe of
+ the entry to [None] if to be inferred or [Some v] if to be
+ checked. *)
mind_entry_private : bool option;
}
diff --git a/kernel/environ.ml b/kernel/environ.ml
index 5914e66fc3..69edb1498c 100644
--- a/kernel/environ.ml
+++ b/kernel/environ.ml
@@ -568,6 +568,11 @@ let is_primitive env c =
| Declarations.Primitive _ -> true
| _ -> false
+let is_array_type env c =
+ match env.retroknowledge.Retroknowledge.retro_array with
+ | None -> false
+ | Some c' -> Constant.CanOrd.equal c c'
+
let polymorphic_constant cst env =
Declareops.constant_is_polymorphic (lookup_constant cst env)
diff --git a/kernel/environ.mli b/kernel/environ.mli
index 6d0ca93707..6a8ddce835 100644
--- a/kernel/environ.mli
+++ b/kernel/environ.mli
@@ -249,6 +249,8 @@ val constant_opt_value_in : env -> Constant.t puniverses -> constr option
val is_primitive : env -> Constant.t -> bool
+val is_array_type : env -> Constant.t -> bool
+
(** {6 Primitive projections} *)
(** Checks that the number of parameters is correct. *)
diff --git a/kernel/indTyping.ml b/kernel/indTyping.ml
index b2520b780f..33ee8c325a 100644
--- a/kernel/indTyping.ml
+++ b/kernel/indTyping.ml
@@ -369,15 +369,20 @@ let typecheck_inductive env ~sec_univs (mie:mutual_inductive_entry) =
data, Some None
in
- let variance = if not mie.mind_entry_cumulative then None
- else match mie.mind_entry_universes with
+ let variance = match mie.mind_entry_variance with
+ | None -> None
+ | Some variances ->
+ match mie.mind_entry_universes with
| Monomorphic_entry _ ->
CErrors.user_err Pp.(str "Inductive cannot be both monomorphic and universe cumulative.")
| Polymorphic_entry (_,uctx) ->
let univs = Instance.to_array @@ UContext.instance uctx in
+ let univs = Array.map2 (fun a b -> a,b) univs variances in
let univs = match sec_univs with
| None -> univs
- | Some sec_univs -> Array.append sec_univs univs
+ | Some sec_univs ->
+ let sec_univs = Array.map (fun u -> u, None) sec_univs in
+ Array.append sec_univs univs
in
let variances = InferCumulativity.infer_inductive ~env_params univs mie.mind_entry_inds in
Some variances
diff --git a/kernel/inferCumulativity.ml b/kernel/inferCumulativity.ml
index 8191a5b0f3..d02f92ef26 100644
--- a/kernel/inferCumulativity.ml
+++ b/kernel/inferCumulativity.ml
@@ -15,30 +15,82 @@ open Univ
open Variance
open Util
-type inferred = IrrelevantI | CovariantI
-
-(** Throughout this module we modify a map [variances] from local
- universes to [inferred]. It starts as a trivial mapping to
- [Irrelevant] and every time we encounter a local universe we
- restrict it accordingly.
- [Invariant] universes are removed from the map.
-*)
exception TrivialVariance
-let maybe_trivial variances =
- if LMap.is_empty variances then raise TrivialVariance
- else variances
+(** Not the same as Type_errors.BadVariance because we don't have the env where we raise. *)
+exception BadVariance of Level.t * Variance.t * Variance.t
+(* some ocaml bug is triggered if we make this an inline record *)
-let infer_level_eq u variances =
- maybe_trivial (LMap.remove u variances)
+module Inf : sig
+ type variances
+ val infer_level_eq : Level.t -> variances -> variances
+ val infer_level_leq : Level.t -> variances -> variances
+ val start : (Level.t * Variance.t option) array -> variances
+ val finish : variances -> Variance.t array
+end = struct
+ type inferred = IrrelevantI | CovariantI
+ type mode = Check | Infer
-let infer_level_leq u variances =
- (* can only set Irrelevant -> Covariant so nontrivial *)
- LMap.update u (function
- | None -> None
- | Some CovariantI as x -> x
- | Some IrrelevantI -> Some CovariantI)
- variances
+ (**
+ Each local universe is either in the [univs] map or is Invariant.
+
+ If [univs] is empty all universes are Invariant and there is nothing more to do,
+ so we stop by raising [TrivialVariance]. The [soft] check comes before that.
+ *)
+ type variances = {
+ orig_array : (Level.t * Variance.t option) array;
+ univs : (mode * inferred) LMap.t;
+ }
+
+ let to_variance = function
+ | IrrelevantI -> Irrelevant
+ | CovariantI -> Covariant
+
+ let to_variance_opt o = Option.cata to_variance Invariant o
+
+ let infer_level_eq u variances =
+ match LMap.find_opt u variances.univs with
+ | None -> variances
+ | Some (Check, expected) ->
+ let expected = to_variance expected in
+ raise (BadVariance (u, expected, Invariant))
+ | Some (Infer, _) ->
+ let univs = LMap.remove u variances.univs in
+ if LMap.is_empty univs then raise TrivialVariance;
+ {variances with univs}
+
+ let infer_level_leq u variances =
+ (* can only set Irrelevant -> Covariant so no TrivialVariance *)
+ let univs =
+ LMap.update u (function
+ | None -> None
+ | Some (_,CovariantI) as x -> x
+ | Some (Infer,IrrelevantI) -> Some (Infer,CovariantI)
+ | Some (Check,IrrelevantI) ->
+ raise (BadVariance (u, Irrelevant, Covariant)))
+ variances.univs
+ in
+ if univs == variances.univs then variances else {variances with univs}
+
+ let start us =
+ let univs = Array.fold_left (fun univs (u,variance) ->
+ match variance with
+ | None -> LMap.add u (Infer,IrrelevantI) univs
+ | Some Invariant -> univs
+ | Some Covariant -> LMap.add u (Check,CovariantI) univs
+ | Some Irrelevant -> LMap.add u (Check,IrrelevantI) univs)
+ LMap.empty us
+ in
+ if LMap.is_empty univs then raise TrivialVariance;
+ {univs; orig_array=us}
+
+ let finish variances =
+ Array.map
+ (fun (u,_check) -> to_variance_opt (Option.map snd (LMap.find_opt u variances.univs)))
+ variances.orig_array
+
+end
+open Inf
let infer_generic_instance_eq variances u =
Array.fold_left (fun variances u -> infer_level_eq u variances)
@@ -204,11 +256,7 @@ let infer_arity_constructor is_arity env variances arcn =
open Entries
let infer_inductive_core env univs entries =
- if Array.is_empty univs then raise TrivialVariance;
- let variances =
- Array.fold_left (fun variances u -> LMap.add u IrrelevantI variances)
- LMap.empty univs
- in
+ let variances = Inf.start univs in
let variances = List.fold_left (fun variances entry ->
let variances = infer_arity_constructor true
env variances entry.mind_entry_arity
@@ -218,12 +266,11 @@ let infer_inductive_core env univs entries =
variances
entries
in
- Array.map (fun u -> match LMap.find u variances with
- | exception Not_found -> Invariant
- | IrrelevantI -> Irrelevant
- | CovariantI -> Covariant)
- univs
+ Inf.finish variances
let infer_inductive ~env_params univs entries =
try infer_inductive_core env_params univs entries
- with TrivialVariance -> Array.make (Array.length univs) Invariant
+ with
+ | TrivialVariance -> Array.make (Array.length univs) Invariant
+ | BadVariance (lev, expected, actual) ->
+ Type_errors.error_bad_variance env_params ~lev ~expected ~actual
diff --git a/kernel/inferCumulativity.mli b/kernel/inferCumulativity.mli
index db5539a0ff..99d8f0c98d 100644
--- a/kernel/inferCumulativity.mli
+++ b/kernel/inferCumulativity.mli
@@ -12,8 +12,8 @@ val infer_inductive
: env_params:Environ.env
(** Environment containing the polymorphic universes and the
parameters. *)
- -> Univ.Level.t array
- (** Universes whose cumulativity we want to infer. *)
+ -> (Univ.Level.t * Univ.Variance.t option) array
+ (** Universes whose cumulativity we want to infer or check. *)
-> Entries.one_inductive_entry list
(** The inductive block data we want to infer cumulativity for.
NB: we ignore the template bool and the names, only the terms
diff --git a/kernel/reduction.ml b/kernel/reduction.ml
index c891b885c4..cf40263f61 100644
--- a/kernel/reduction.ml
+++ b/kernel/reduction.ml
@@ -280,11 +280,12 @@ let convert_constructors ctor nargs u1 u2 (s, check) =
convert_constructors_gen (check.compare_instances ~flex:false) check.compare_cumul_instances
ctor nargs u1 u2 s, check
-let conv_table_key infos k1 k2 cuniv =
+let conv_table_key infos ~nargs k1 k2 cuniv =
if k1 == k2 then cuniv else
match k1, k2 with
| ConstKey (cst, u), ConstKey (cst', u') when Constant.CanOrd.equal cst cst' ->
if Univ.Instance.equal u u' then cuniv
+ else if Int.equal nargs 1 && is_array_type (info_env infos) cst then cuniv
else
let flex = evaluable_constant cst (info_env infos)
&& RedFlags.red_set (info_flags infos) (RedFlags.fCONST cst)
@@ -304,6 +305,11 @@ let unfold_ref_with_args infos tab fl v =
Some (a, (Zupdate a::(Zprimitive(op,c,rargs,nargs)::v)))
| Undef _ | OpaqueDef _ | Primitive _ -> None
+let same_args_size sk1 sk2 =
+ let n = CClosure.stack_args_size sk1 in
+ if Int.equal n (CClosure.stack_args_size sk2) then n
+ else raise NotConvertible
+
type conv_tab = {
cnv_inf : clos_infos;
lft_tab : clos_tab;
@@ -408,7 +414,8 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
(* 2 constants, 2 local defined vars or 2 defined rels *)
| (FFlex fl1, FFlex fl2) ->
(try
- let cuniv = conv_table_key infos.cnv_inf fl1 fl2 cuniv in
+ let nargs = same_args_size v1 v2 in
+ let cuniv = conv_table_key infos.cnv_inf ~nargs fl1 fl2 cuniv in
convert_stacks l2r infos lft1 lft2 v1 v2 cuniv
with NotConvertible | Univ.UniverseInconsistency _ ->
let r1 = unfold_ref_with_args infos.cnv_inf infos.lft_tab fl1 v1 in
@@ -577,17 +584,14 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
convert_stacks l2r infos lft1 lft2 v1 v2 cuniv
else
let mind = Environ.lookup_mind (fst ind1) (info_env infos.cnv_inf) in
- let nargs = CClosure.stack_args_size v1 in
- if not (Int.equal nargs (CClosure.stack_args_size v2))
- then raise NotConvertible
- else
- match convert_inductives cv_pb (mind, snd ind1) nargs u1 u2 cuniv with
- | cuniv -> convert_stacks l2r infos lft1 lft2 v1 v2 cuniv
- | exception MustExpand ->
- let env = info_env infos.cnv_inf in
- let hd1 = eta_expand_ind env pind1 in
- let hd2 = eta_expand_ind env pind2 in
- eqappr cv_pb l2r infos (lft1,(hd1,v1)) (lft2,(hd2,v2)) cuniv
+ let nargs = same_args_size v1 v2 in
+ match convert_inductives cv_pb (mind, snd ind1) nargs u1 u2 cuniv with
+ | cuniv -> convert_stacks l2r infos lft1 lft2 v1 v2 cuniv
+ | exception MustExpand ->
+ let env = info_env infos.cnv_inf in
+ let hd1 = eta_expand_ind env pind1 in
+ let hd2 = eta_expand_ind env pind2 in
+ eqappr cv_pb l2r infos (lft1,(hd1,v1)) (lft2,(hd2,v2)) cuniv
else raise NotConvertible
| (FConstruct ((ind1,j1),u1 as pctor1), FConstruct ((ind2,j2),u2 as pctor2)) ->
@@ -597,17 +601,14 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
convert_stacks l2r infos lft1 lft2 v1 v2 cuniv
else
let mind = Environ.lookup_mind (fst ind1) (info_env infos.cnv_inf) in
- let nargs = CClosure.stack_args_size v1 in
- if not (Int.equal nargs (CClosure.stack_args_size v2))
- then raise NotConvertible
- else
- match convert_constructors (mind, snd ind1, j1) nargs u1 u2 cuniv with
- | cuniv -> convert_stacks l2r infos lft1 lft2 v1 v2 cuniv
- | exception MustExpand ->
- let env = info_env infos.cnv_inf in
- let hd1 = eta_expand_constructor env pctor1 in
- let hd2 = eta_expand_constructor env pctor2 in
- eqappr cv_pb l2r infos (lft1,(hd1,v1)) (lft2,(hd2,v2)) cuniv
+ let nargs = same_args_size v1 v2 in
+ match convert_constructors (mind, snd ind1, j1) nargs u1 u2 cuniv with
+ | cuniv -> convert_stacks l2r infos lft1 lft2 v1 v2 cuniv
+ | exception MustExpand ->
+ let env = info_env infos.cnv_inf in
+ let hd1 = eta_expand_constructor env pctor1 in
+ let hd2 = eta_expand_constructor env pctor2 in
+ eqappr cv_pb l2r infos (lft1,(hd1,v1)) (lft2,(hd2,v2)) cuniv
else raise NotConvertible
(* Eta expansion of records *)
diff --git a/kernel/type_errors.ml b/kernel/type_errors.ml
index ae5c4b6880..bcb7aa88ca 100644
--- a/kernel/type_errors.ml
+++ b/kernel/type_errors.ml
@@ -69,6 +69,7 @@ type ('constr, 'types) ptype_error =
| DisallowedSProp
| BadRelevance
| BadInvert
+ | BadVariance of { lev : Level.t; expected : Variance.t; actual : Variance.t }
type type_error = (constr, types) ptype_error
@@ -163,6 +164,9 @@ let error_bad_relevance env =
let error_bad_invert env =
raise (TypeError (env, BadInvert))
+let error_bad_variance env ~lev ~expected ~actual =
+ raise (TypeError (env, BadVariance {lev;expected;actual}))
+
let map_pguard_error f = function
| NotEnoughAbstractionInFixBody -> NotEnoughAbstractionInFixBody
| RecursionNotOnInductiveType c -> RecursionNotOnInductiveType (f c)
@@ -207,3 +211,4 @@ let map_ptype_error f = function
| DisallowedSProp -> DisallowedSProp
| BadRelevance -> BadRelevance
| BadInvert -> BadInvert
+| BadVariance u -> BadVariance u
diff --git a/kernel/type_errors.mli b/kernel/type_errors.mli
index b1f7eb8a34..bcdcab9db7 100644
--- a/kernel/type_errors.mli
+++ b/kernel/type_errors.mli
@@ -70,6 +70,7 @@ type ('constr, 'types) ptype_error =
| DisallowedSProp
| BadRelevance
| BadInvert
+ | BadVariance of { lev : Level.t; expected : Variance.t; actual : Variance.t }
type type_error = (constr, types) ptype_error
@@ -146,5 +147,7 @@ val error_bad_relevance : env -> 'a
val error_bad_invert : env -> 'a
+val error_bad_variance : env -> lev:Level.t -> expected:Variance.t -> actual:Variance.t -> 'a
+
val map_pguard_error : ('c -> 'd) -> 'c pguard_error -> 'd pguard_error
val map_ptype_error : ('c -> 'd) -> ('c, 'c) ptype_error -> ('d, 'd) ptype_error