From 2ded4c25e532c5dfca0483c211653768ebed01a7 Mon Sep 17 00:00:00 2001 From: Gaƫtan Gilbert Date: Thu, 13 Jun 2019 15:39:43 +0200 Subject: UIP in SProp --- kernel/constr.ml | 168 ++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 124 insertions(+), 44 deletions(-) (limited to 'kernel/constr.ml') diff --git a/kernel/constr.ml b/kernel/constr.ml index 703e3616a0..d0598bdad1 100644 --- a/kernel/constr.ml +++ b/kernel/constr.ml @@ -83,6 +83,10 @@ type pconstant = Constant.t puniverses type pinductive = inductive puniverses type pconstructor = constructor puniverses +type ('constr, 'univs) case_invert = + | NoInvert + | CaseInvert of { univs : 'univs; args : 'constr array } + (* [Var] is used for named variables and [Rel] for variables as de Bruijn indices. *) type ('constr, 'types, 'sort, 'univs) kind_of_term = @@ -99,7 +103,7 @@ type ('constr, 'types, 'sort, 'univs) kind_of_term = | Const of (Constant.t * 'univs) | Ind of (inductive * 'univs) | Construct of (constructor * 'univs) - | Case of case_info * 'constr * 'constr * 'constr array + | Case of case_info * 'constr * ('constr, 'univs) case_invert * 'constr * 'constr array | Fix of ('constr, 'types) pfixpoint | CoFix of ('constr, 'types) pcofixpoint | Proj of Projection.t * 'constr @@ -189,7 +193,7 @@ let mkConstructU c = Construct c let mkConstructUi ((ind,u),i) = Construct ((ind,i),u) (* Constructs the term

Case c of c1 | c2 .. | cn end *) -let mkCase (ci, p, c, ac) = Case (ci, p, c, ac) +let mkCase (ci, p, iv, c, ac) = Case (ci, p, iv, c, ac) (* If recindxs = [|i1,...in|] funnames = [|f1,...fn|] @@ -417,7 +421,7 @@ let destConstruct c = match kind c with (* Destructs a term

Case c of lc1 | lc2 .. | lcn end *) let destCase c = match kind c with - | Case (ci,p,c,v) -> (ci,p,c,v) + | Case (ci,p,iv,c,v) -> (ci,p,iv,c,v) | _ -> raise DestKO let destProj c = match kind c with @@ -461,6 +465,11 @@ let decompose_appvect c = starting from [acc] and proceeding from left to right according to the usual representation of the constructions; it is not recursive *) +let fold_invert f acc = function + | NoInvert -> acc + | CaseInvert {univs=_;args} -> + Array.fold_left f acc args + let fold f acc c = match kind c with | (Rel _ | Meta _ | Var _ | Sort _ | Const _ | Ind _ | Construct _ | Int _ | Float _) -> acc @@ -471,7 +480,7 @@ let fold f acc c = match kind c with | App (c,l) -> Array.fold_left f (f acc c) l | Proj (_p,c) -> f acc c | Evar (_,l) -> List.fold_left f acc l - | Case (_,p,c,bl) -> Array.fold_left f (f (f acc p) c) bl + | Case (_,p,iv,c,bl) -> Array.fold_left f (f (fold_invert f (f acc p) iv) c) bl | Fix (_,(_lna,tl,bl)) -> Array.fold_left2 (fun acc t b -> f (f acc t) b) acc tl bl | CoFix (_,(_lna,tl,bl)) -> @@ -481,6 +490,11 @@ let fold f acc c = match kind c with not recursive and the order with which subterms are processed is not specified *) +let iter_invert f = function + | NoInvert -> () + | CaseInvert {univs=_; args;} -> + Array.iter f args + let iter f c = match kind c with | (Rel _ | Meta _ | Var _ | Sort _ | Const _ | Ind _ | Construct _ | Int _ | Float _) -> () @@ -491,7 +505,7 @@ let iter f c = match kind c with | App (c,l) -> f c; Array.iter f l | Proj (_p,c) -> f c | Evar (_,l) -> List.iter f l - | Case (_,p,c,bl) -> f p; f c; Array.iter f bl + | Case (_,p,iv,c,bl) -> f p; iter_invert f iv; f c; Array.iter f bl | Fix (_,(_,tl,bl)) -> Array.iter f tl; Array.iter f bl | CoFix (_,(_,tl,bl)) -> Array.iter f tl; Array.iter f bl @@ -510,7 +524,7 @@ let iter_with_binders g f n c = match kind c with | LetIn (_,b,t,c) -> f n b; f n t; f (g n) c | App (c,l) -> f n c; Array.Fun1.iter f n l | Evar (_,l) -> List.iter (fun c -> f n c) l - | Case (_,p,c,bl) -> f n p; f n c; Array.Fun1.iter f n bl + | Case (_,p,iv,c,bl) -> f n p; iter_invert (f n) iv; f n c; Array.Fun1.iter f n bl | Proj (_p,c) -> f n c | Fix (_,(_,tl,bl)) -> Array.Fun1.iter f n tl; @@ -537,7 +551,7 @@ let fold_constr_with_binders g f n acc c = | App (c,l) -> Array.fold_left (f n) (f n acc c) l | Proj (_p,c) -> f n acc c | Evar (_,l) -> List.fold_left (f n) acc l - | Case (_,p,c,bl) -> Array.fold_left (f n) (f n (f n acc p) c) bl + | Case (_,p,iv,c,bl) -> Array.fold_left (f n) (f n (fold_invert (f n) (f n acc p) iv) c) bl | Fix (_,(_,tl,bl)) -> let n' = iterate g (Array.length tl) n in let fd = Array.map2 (fun t b -> (t,b)) tl bl in @@ -623,6 +637,13 @@ let map_branches_with_full_binders g f l ci bl = let map_return_predicate_with_full_binders g f l ci p = map_under_context_with_full_binders g f l (List.length ci.ci_pp_info.ind_tags) p +let map_invert f = function + | NoInvert -> NoInvert + | CaseInvert {univs;args;} as orig -> + let args' = Array.Smart.map f args in + if args == args' then orig + else CaseInvert {univs;args=args';} + let map_gen userview f c = match kind c with | (Rel _ | Meta _ | Var _ | Sort _ | Const _ | Ind _ | Construct _ | Int _ | Float _) -> c @@ -660,18 +681,20 @@ let map_gen userview f c = match kind c with let l' = List.Smart.map f l in if l'==l then c else mkEvar (e, l') - | Case (ci,p,b,bl) when userview -> + | Case (ci,p,iv,b,bl) when userview -> let b' = f b in + let iv' = map_invert f iv in let p' = map_return_predicate f ci p in let bl' = map_branches f ci bl in - if b'==b && p'==p && bl'==bl then c - else mkCase (ci, p', b', bl') - | Case (ci,p,b,bl) -> + if b'==b && iv'==iv && p'==p && bl'==bl then c + else mkCase (ci, p', iv', b', bl') + | Case (ci,p,iv,b,bl) -> let b' = f b in + let iv' = map_invert f iv in let p' = f p in let bl' = Array.Smart.map f bl in - if b'==b && p'==p && bl'==bl then c - else mkCase (ci, p', b', bl') + if b'==b && iv'==iv && p'==p && bl'==bl then c + else mkCase (ci, p', iv', b', bl') | Fix (ln,(lna,tl,bl)) -> let tl' = Array.Smart.map f tl in let bl' = Array.Smart.map f bl in @@ -688,6 +711,13 @@ let map = map_gen false (* Like {!map} but with an accumulator. *) +let fold_map_invert f acc = function + | NoInvert -> acc, NoInvert + | CaseInvert {univs;args;} as orig -> + let acc, args' = Array.fold_left_map f acc args in + if args==args' then acc, orig + else acc, CaseInvert {univs;args=args';} + let fold_map f accu c = match kind c with | (Rel _ | Meta _ | Var _ | Sort _ | Const _ | Ind _ | Construct _ | Int _ | Float _) -> accu, c @@ -726,12 +756,13 @@ let fold_map f accu c = match kind c with let accu, l' = List.fold_left_map f accu l in if l'==l then accu, c else accu, mkEvar (e, l') - | Case (ci,p,b,bl) -> + | Case (ci,p,iv,b,bl) -> let accu, b' = f accu b in + let accu, iv' = fold_map_invert f accu iv in let accu, p' = f accu p in let accu, bl' = Array.Smart.fold_left_map f accu bl in - if b'==b && p'==p && bl'==bl then accu, c - else accu, mkCase (ci, p', b', bl') + if b'==b && iv'==iv && p'==p && bl'==bl then accu, c + else accu, mkCase (ci, p', iv', b', bl') | Fix (ln,(lna,tl,bl)) -> let accu, tl' = Array.Smart.fold_left_map f accu tl in let accu, bl' = Array.Smart.fold_left_map f accu bl in @@ -786,12 +817,13 @@ let map_with_binders g f l c0 = match kind c0 with let al' = List.Smart.map (fun c -> f l c) al in if al' == al then c0 else mkEvar (e, al') - | Case (ci, p, c, bl) -> + | Case (ci, p, iv, c, bl) -> let p' = f l p in + let iv' = map_invert (f l) iv in let c' = f l c in let bl' = Array.Fun1.Smart.map f l bl in - if p' == p && c' == c && bl' == bl then c0 - else mkCase (ci, p', c', bl') + if p' == p && iv' == iv && c' == c && bl' == bl then c0 + else mkCase (ci, p', iv', c', bl') | Fix (ln, (lna, tl, bl)) -> let tl' = Array.Fun1.Smart.map f l tl in let l' = iterate g (Array.length tl) l in @@ -836,7 +868,7 @@ let fold_with_full_binders g f n acc c = | App (c,l) -> Array.fold_left (f n) (f n acc c) l | Proj (_,c) -> f n acc c | Evar (_,l) -> List.fold_left (f n) acc l - | Case (_,p,c,bl) -> Array.fold_left (f n) (f n (f n acc p) c) bl + | Case (_,p,iv,c,bl) -> Array.fold_left (f n) (f n (fold_invert (f n) (f n acc p) iv) c) bl | Fix (_,(lna,tl,bl)) -> let n' = CArray.fold_left2_i (fun i c n t -> g (LocalAssum (n,lift i t)) c) n lna tl in let fd = Array.map2 (fun t b -> (t,b)) tl bl in @@ -847,7 +879,7 @@ let fold_with_full_binders g f n acc c = Array.fold_left (fun acc (t,b) -> f n' (f n acc t) b) acc fd -type 'univs instance_compare_fn = GlobRef.t -> int -> +type 'univs instance_compare_fn = (GlobRef.t * int) option -> 'univs -> 'univs -> bool type 'constr constr_compare_fn = int -> 'constr -> 'constr -> bool @@ -863,6 +895,14 @@ type 'constr constr_compare_fn = int -> 'constr -> 'constr -> bool optimisation that physically equal arrays are equals (hence the calls to {!Array.equal_norefl}). *) +let eq_invert eq leq_universes iv1 iv2 = + match iv1, iv2 with + | NoInvert, NoInvert -> true + | NoInvert, CaseInvert _ | CaseInvert _, NoInvert -> false + | CaseInvert {univs;args}, CaseInvert iv2 -> + leq_universes univs iv2.univs + && Array.equal eq args iv2.args + let compare_head_gen_leq_with kind1 kind2 leq_universes leq_sorts eq leq nargs t1 t2 = match kind_nocast_gen kind1 t1, kind_nocast_gen kind2 t2 with | Cast _, _ | _, Cast _ -> assert false (* kind_nocast *) @@ -884,12 +924,12 @@ let compare_head_gen_leq_with kind1 kind2 leq_universes leq_sorts eq leq nargs t | Evar (e1,l1), Evar (e2,l2) -> Evar.equal e1 e2 && List.equal (eq 0) l1 l2 | Const (c1,u1), Const (c2,u2) -> (* The args length currently isn't used but may as well pass it. *) - Constant.equal c1 c2 && leq_universes (GlobRef.ConstRef c1) nargs u1 u2 - | Ind (c1,u1), Ind (c2,u2) -> eq_ind c1 c2 && leq_universes (GlobRef.IndRef c1) nargs u1 u2 + Constant.equal c1 c2 && leq_universes (Some (GlobRef.ConstRef c1, nargs)) u1 u2 + | Ind (c1,u1), Ind (c2,u2) -> eq_ind c1 c2 && leq_universes (Some (GlobRef.IndRef c1, nargs)) u1 u2 | Construct (c1,u1), Construct (c2,u2) -> - eq_constructor c1 c2 && leq_universes (GlobRef.ConstructRef c1) nargs u1 u2 - | Case (_,p1,c1,bl1), Case (_,p2,c2,bl2) -> - eq 0 p1 p2 && eq 0 c1 c2 && Array.equal (eq 0) bl1 bl2 + eq_constructor c1 c2 && leq_universes (Some (GlobRef.ConstructRef c1, nargs)) u1 u2 + | Case (_,p1,iv1,c1,bl1), Case (_,p2,iv2,c2,bl2) -> + eq 0 p1 p2 && eq_invert (eq 0) (leq_universes None) iv1 iv2 && eq 0 c1 c2 && Array.equal (eq 0) bl1 bl2 | Fix ((ln1, i1),(_,tl1,bl1)), Fix ((ln2, i2),(_,tl2,bl2)) -> Int.equal i1 i2 && Array.equal Int.equal ln1 ln2 && Array.equal_norefl (eq 0) tl1 tl2 && Array.equal_norefl (eq 0) bl1 bl2 @@ -923,7 +963,7 @@ let compare_head_gen_with kind1 kind2 eq_universes eq_sorts eq t1 t2 = let compare_head_gen eq_universes eq_sorts eq t1 t2 = compare_head_gen_leq eq_universes eq_sorts eq eq t1 t2 -let compare_head = compare_head_gen (fun _ _ -> Univ.Instance.equal) Sorts.equal +let compare_head = compare_head_gen (fun _ -> Univ.Instance.equal) Sorts.equal (*******************************) (* alpha conversion functions *) @@ -932,14 +972,14 @@ let compare_head = compare_head_gen (fun _ _ -> Univ.Instance.equal) Sorts.equal (* alpha conversion : ignore print names and casts *) let rec eq_constr nargs m n = - (m == n) || compare_head_gen (fun _ _ -> Instance.equal) Sorts.equal eq_constr nargs m n + (m == n) || compare_head_gen (fun _ -> Instance.equal) Sorts.equal eq_constr nargs m n let equal n m = eq_constr 0 m n (* to avoid tracing a recursive fun *) let eq_constr_univs univs m n = if m == n then true else - let eq_universes _ _ = UGraph.check_eq_instances univs in + let eq_universes _ = UGraph.check_eq_instances univs in let eq_sorts s1 s2 = s1 == s2 || UGraph.check_eq univs (Sorts.univ_of_sort s1) (Sorts.univ_of_sort s2) in let rec eq_constr' nargs m n = m == n || compare_head_gen eq_universes eq_sorts eq_constr' nargs m n @@ -948,7 +988,7 @@ let eq_constr_univs univs m n = let leq_constr_univs univs m n = if m == n then true else - let eq_universes _ _ = UGraph.check_eq_instances univs in + let eq_universes _ = UGraph.check_eq_instances univs in let eq_sorts s1 s2 = s1 == s2 || UGraph.check_eq univs (Sorts.univ_of_sort s1) (Sorts.univ_of_sort s2) in let leq_sorts s1 s2 = s1 == s2 || @@ -965,7 +1005,7 @@ let eq_constr_univs_infer univs m n = if m == n then true, Constraint.empty else let cstrs = ref Constraint.empty in - let eq_universes _ _ = UGraph.check_eq_instances univs in + let eq_universes _ = UGraph.check_eq_instances univs in let eq_sorts s1 s2 = if Sorts.equal s1 s2 then true else @@ -985,7 +1025,7 @@ let leq_constr_univs_infer univs m n = if m == n then true, Constraint.empty else let cstrs = ref Constraint.empty in - let eq_universes _ _ l l' = UGraph.check_eq_instances univs l l' in + let eq_universes _ l l' = UGraph.check_eq_instances univs l l' in let eq_sorts s1 s2 = if Sorts.equal s1 s2 then true else @@ -1015,7 +1055,16 @@ let leq_constr_univs_infer univs m n = res, !cstrs let rec eq_constr_nounivs m n = - (m == n) || compare_head_gen (fun _ _ _ _ -> true) (fun _ _ -> true) (fun _ -> eq_constr_nounivs) 0 m n + (m == n) || compare_head_gen (fun _ _ _ -> true) (fun _ _ -> true) (fun _ -> eq_constr_nounivs) 0 m n + +let compare_invert f iv1 iv2 = + match iv1, iv2 with + | NoInvert, NoInvert -> 0 + | NoInvert, CaseInvert _ -> -1 + | CaseInvert _, NoInvert -> 1 + | CaseInvert iv1, CaseInvert iv2 -> + (* univs ignored deliberately *) + Array.compare f iv1.args iv2.args let constr_ord_int f t1 t2 = let (=?) f g i1 i2 j1 j2= @@ -1060,8 +1109,12 @@ let constr_ord_int f t1 t2 = | Ind _, _ -> -1 | _, Ind _ -> 1 | Construct (ct1,_u1), Construct (ct2,_u2) -> constructor_ord ct1 ct2 | Construct _, _ -> -1 | _, Construct _ -> 1 - | Case (_,p1,c1,bl1), Case (_,p2,c2,bl2) -> - ((f =? f) ==? (Array.compare f)) p1 p2 c1 c2 bl1 bl2 + | Case (_,p1,iv1,c1,bl1), Case (_,p2,iv2,c2,bl2) -> + let c = f p1 p2 in + if Int.equal c 0 then let c = compare_invert f iv1 iv2 in + if Int.equal c 0 then let c = f c1 c2 in + if Int.equal c 0 then Array.compare f bl1 bl2 + else c else c else c | Case _, _ -> -1 | _, Case _ -> 1 | Fix (ln1,(_,tl1,bl1)), Fix (ln2,(_,tl2,bl2)) -> ((fix_cmp =? (Array.compare f)) ==? (Array.compare f)) @@ -1129,6 +1182,14 @@ let array_eqeq t1 t2 = (Int.equal i (Array.length t1)) || (t1.(i) == t2.(i) && aux (i + 1)) in aux 0) +let invert_eqeq iv1 iv2 = + match iv1, iv2 with + | NoInvert, NoInvert -> true + | NoInvert, CaseInvert _ | CaseInvert _, NoInvert -> false + | CaseInvert iv1, CaseInvert iv2 -> + iv1.univs == iv2.univs + && iv1.args == iv2.args + let hasheq t1 t2 = match t1, t2 with | Rel n1, Rel n2 -> n1 == n2 @@ -1146,8 +1207,8 @@ let hasheq t1 t2 = | Const (c1,u1), Const (c2,u2) -> c1 == c2 && u1 == u2 | Ind (ind1,u1), Ind (ind2,u2) -> ind1 == ind2 && u1 == u2 | Construct (cstr1,u1), Construct (cstr2,u2) -> cstr1 == cstr2 && u1 == u2 - | Case (ci1,p1,c1,bl1), Case (ci2,p2,c2,bl2) -> - ci1 == ci2 && p1 == p2 && c1 == c2 && array_eqeq bl1 bl2 + | Case (ci1,p1,iv1,c1,bl1), Case (ci2,p2,iv2,c2,bl2) -> + ci1 == ci2 && p1 == p2 && invert_eqeq iv1 iv2 && c1 == c2 && array_eqeq bl1 bl2 | Fix ((ln1, i1),(lna1,tl1,bl1)), Fix ((ln2, i2),(lna2,tl2,bl2)) -> Int.equal i1 i2 && Array.equal Int.equal ln1 ln2 @@ -1236,12 +1297,13 @@ let hashcons (sh_sort,sh_ci,sh_construct,sh_ind,sh_con,sh_na,sh_id) = let u', hu = sh_instance u in (Construct (sh_construct c, u'), combinesmall 11 (combine (constructor_syntactic_hash c) hu)) - | Case (ci,p,c,bl) -> + | Case (ci,p,iv,c,bl) -> let p, hp = sh_rec p + and iv, hiv = sh_invert iv and c, hc = sh_rec c in let bl,hbl = hash_term_array bl in - let hbl = combine (combine hc hp) hbl in - (Case (sh_ci ci, p, c, bl), combinesmall 12 hbl) + let hbl = combine4 hc hp hiv hbl in + (Case (sh_ci ci, p, iv, c, bl), combinesmall 12 hbl) | Fix (ln,(lna,tl,bl)) -> let bl,hbl = hash_term_array bl in let tl,htl = hash_term_array tl in @@ -1271,6 +1333,13 @@ let hashcons (sh_sort,sh_ci,sh_construct,sh_ind,sh_con,sh_na,sh_id) = (t, combinesmall 18 (combine h l)) | Float f -> (t, combinesmall 19 (Float64.hash f)) + and sh_invert = function + | NoInvert -> NoInvert, 0 + | CaseInvert {univs;args;} -> + let univs, hu = sh_instance univs in + let args, ha = hash_term_array args in + CaseInvert {univs;args;}, combinesmall 1 (combine hu ha) + and sh_rec t = let (y, h) = hash_term t in (* [h] must be positive. *) @@ -1332,8 +1401,8 @@ let rec hash t = combinesmall 10 (combine (ind_hash ind) (Instance.hash u)) | Construct (c,u) -> combinesmall 11 (combine (constructor_hash c) (Instance.hash u)) - | Case (_ , p, c, bl) -> - combinesmall 12 (combine3 (hash c) (hash p) (hash_term_array bl)) + | Case (_ , p, iv, c, bl) -> + combinesmall 12 (combine4 (hash c) (hash p) (hash_invert iv) (hash_term_array bl)) | Fix (_ln ,(_, tl, bl)) -> combinesmall 13 (combine (hash_term_array bl) (hash_term_array tl)) | CoFix(_ln, (_, tl, bl)) -> @@ -1345,6 +1414,11 @@ let rec hash t = | Int i -> combinesmall 18 (Uint63.hash i) | Float f -> combinesmall 19 (Float64.hash f) +and hash_invert = function + | NoInvert -> 0 + | CaseInvert {univs;args;} -> + combinesmall 1 (combine (Instance.hash univs) (hash_term_array args)) + and hash_term_array t = Array.fold_left (fun acc t -> combine acc (hash t)) 0 t @@ -1476,9 +1550,9 @@ let rec debug_print c = | Construct (((sp,i),j),u) -> str"Constr(" ++ pr_puniverses (MutInd.print sp ++ str"," ++ int i ++ str"," ++ int j) u ++ str")" | Proj (p,c) -> str"Proj(" ++ Constant.debug_print (Projection.constant p) ++ str"," ++ bool (Projection.unfolded p) ++ debug_print c ++ str")" - | Case (_ci,p,c,bl) -> v 0 + | Case (_ci,p,iv,c,bl) -> v 0 (hv 0 (str"<"++debug_print p++str">"++ cut() ++ str"Case " ++ - debug_print c ++ str"of") ++ cut() ++ + debug_print c ++ debug_invert iv ++ str"of") ++ cut() ++ prlist_with_sep (fun _ -> brk(1,2)) debug_print (Array.to_list bl) ++ cut() ++ str"end") | Fix f -> debug_print_fix debug_print f @@ -1492,3 +1566,9 @@ let rec debug_print c = str"}") | Int i -> str"Int("++str (Uint63.to_string i) ++ str")" | Float i -> str"Float("++str (Float64.to_string i) ++ str")" + +and debug_invert = let open Pp in function + | NoInvert -> mt() + | CaseInvert {univs;args;} -> + spc() ++ str"Invert {univs=" ++ Instance.pr Level.pr univs ++ + str "; args=" ++ prlist_with_sep spc debug_print (Array.to_list args) ++ str "} " -- cgit v1.2.3