diff options
| author | BESSON Frederic | 2021-02-08 15:07:42 +0100 |
|---|---|---|
| committer | BESSON Frederic | 2021-02-10 09:59:18 +0100 |
| commit | 68c3ffa6db6139081dab196bf3617214862a52af (patch) | |
| tree | 3c9effe99a3ff771fa1e9520ed5a441f527e6c1d /plugins | |
| parent | 132b2e240e1869be5ca0cc7200aa4ab6a94f2275 (diff) | |
[micromega/nia] Improve sharing of proofs
Closes #13794
Diffstat (limited to 'plugins')
| -rw-r--r-- | plugins/micromega/certificate.ml | 70 | ||||
| -rw-r--r-- | plugins/micromega/polynomial.ml | 293 | ||||
| -rw-r--r-- | plugins/micromega/polynomial.mli | 4 |
3 files changed, 209 insertions, 158 deletions
diff --git a/plugins/micromega/certificate.ml b/plugins/micromega/certificate.ml index ed608bb1df..53aa619d10 100644 --- a/plugins/micromega/certificate.ml +++ b/plugins/micromega/certificate.ml @@ -223,6 +223,28 @@ let find_point l = let optimise v l = if use_simplex () then Simplex.optimise v l else Mfourier.Fourier.optimise v l +let output_cstr_sys o sys = + List.iter + (fun (c, wp) -> + Printf.fprintf o "%a by %a\n" output_cstr c ProofFormat.output_prf_rule wp) + sys + +let output_sys o sys = + List.iter (fun s -> Printf.fprintf o "%a\n" WithProof.output s) sys + +let tr_sys str f sys = + let sys' = f sys in + if debug then + Printf.fprintf stdout "[%s\n%a=>\n%a]\n" str output_sys sys output_sys sys'; + sys' + +let tr_cstr_sys str f sys = + let sys' = f sys in + if debug then + Printf.fprintf stdout "[%s\n%a=>\n%a]\n" str output_cstr_sys sys + output_cstr_sys sys'; + sys' + let dual_raw_certificate l = if debug then begin Printf.printf "dual_raw_certificate\n"; @@ -375,25 +397,7 @@ let elim_simple_linear_equality sys0 = in iterate_until_stable elim sys0 -let output_sys o sys = - List.iter (fun s -> Printf.fprintf o "%a\n" WithProof.output s) sys - -let subst sys = - let sys' = WithProof.subst sys in - if debug then - Printf.fprintf stdout "[subst:\n%a\n==>\n%a\n]" output_sys sys output_sys - sys'; - sys' - -let tr_sys str f sys = - let sys' = f sys in - if debug then ( - Printf.fprintf stdout "[%s\n" str; - List.iter (fun s -> Printf.fprintf stdout "%a\n" WithProof.output s) sys; - Printf.fprintf stdout "\n => \n"; - List.iter (fun s -> Printf.fprintf stdout "%a\n" WithProof.output s) sys'; - Printf.fprintf stdout "]\n" ); - sys' +let subst sys = tr_sys "subst" WithProof.subst sys (** [saturate_linear_equality sys] generate new constraints obtained by eliminating linear equalities by pivoting. @@ -489,12 +493,10 @@ let nlinear_preprocess (sys : WithProof.t list) = ISet.fold (fun i acc -> square_of_var i :: acc) collect_vars sys in let sys = sys @ all_pairs WithProof.product sys in - if debug then begin - Printf.fprintf stdout "Preprocessed\n"; - List.iter (fun s -> Printf.fprintf stdout "%a\n" WithProof.output s) sys - end; List.map (WithProof.annot "P") sys +let nlinear_preprocess = tr_sys "nlinear_preprocess" nlinear_preprocess + let nlinear_prover prfdepth sys = let sys = develop_constraints prfdepth q_spec sys in let sys1 = elim_simple_linear_equality sys in @@ -698,6 +700,15 @@ let pivot v (c1, p1) (c2, p2) = Some (xpivot cv1 cv2) else None +let pivot v c1 c2 = + let res = pivot v c1 c2 in + ( match res with + | None -> () + | Some (c, _) -> + if Vect.get v c.coeffs =/ Q.zero then () + else Printf.printf "pivot error %a\n" output_cstr c ); + res + (* op2 could be Eq ... this might happen *) let simpl_sys sys = @@ -762,6 +773,8 @@ let reduce_coprime psys = in Some (pivot_sys v (cstr, prf) ((c1, p1) :: sys)) +(*let pivot_sys v pc sys = tr_cstr_sys "pivot_sys" (pivot_sys v pc) sys*) + (** If there is an equation [eq] of the form 1.x + e = c, do a pivot over x with equation [eq] *) let reduce_unary psys = let is_unary_equation (cstr, prf) = @@ -820,6 +833,8 @@ let reduction_equations psys = [reduce_unary; reduce_coprime; reduce_var_change (*; reduce_pivot*)]) psys +let reduction_equations = tr_cstr_sys "reduction_equations" reduction_equations + (** [get_bound sys] returns upon success an interval (lb,e,ub) with proofs *) let get_bound sys = let is_small (v, i) = @@ -891,11 +906,6 @@ let check_sys sys = open ProofFormat -let output_cstr_sys sys = - (pp_list ";" (fun o (c, wp) -> - Printf.fprintf o "%a by %a" output_cstr c ProofFormat.output_prf_rule wp)) - sys - let xlia (can_enum : bool) reduction_equations sys = let rec enum_proof (id : int) (sys : prf_sys) = if debug then ( @@ -1170,7 +1180,9 @@ let nlia enum prfdepth sys = No: if a wrong equation is chosen, the proof may fail. It would only be safe if the variable is linear... *) - let sys1 = elim_simple_linear_equality sys in + let sys1 = + elim_simple_linear_equality (WithProof.subst_constant true sys) + in let sys2 = saturate_by_linear_equalities sys1 in let sys3 = nlinear_preprocess (sys1 @ sys2) in let sys4 = make_cstr_system (*sys2@*) sys3 in diff --git a/plugins/micromega/polynomial.ml b/plugins/micromega/polynomial.ml index 7b29aa15f9..024fc6dade 100644 --- a/plugins/micromega/polynomial.ml +++ b/plugins/micromega/polynomial.ml @@ -485,7 +485,7 @@ module ProofFormat = struct let rec output_proof o = function | Done -> Printf.fprintf o "." | Step (i, p, pf) -> - Printf.fprintf o "%i:= %a ; %a" i output_prf_rule p output_proof pf + Printf.fprintf o "%i:= %a\n ; %a" i output_prf_rule p output_proof pf | Split (i, v, p1, p2) -> Printf.fprintf o "%i:=%a ; { %a } { %a }" i Vect.pp v output_proof p1 output_proof p2 @@ -496,6 +496,48 @@ module ProofFormat = struct Printf.fprintf o "%i := %i = %i - %i ; %i := %i >= 0 ; %i := %i >= 0 ; %a" i x z t j z k t output_proof pr + module OrdPrfRule = struct + type t = prf_rule + + let id_of_constr = function + | Annot _ -> 0 + | Hyp _ -> 1 + | Def _ -> 2 + | Cst _ -> 3 + | Zero -> 4 + | Square _ -> 5 + | MulC _ -> 6 + | Gcd _ -> 7 + | MulPrf _ -> 8 + | AddPrf _ -> 9 + | CutPrf _ -> 10 + + let cmp_pair c1 c2 (x1, x2) (y1, y2) = + match c1 x1 y1 with 0 -> c2 x2 y2 | i -> i + + let rec compare p1 p2 = + match (p1, p2) with + | Annot (s1, p1), Annot (s2, p2) -> + if s1 = s2 then compare p1 p2 else String.compare s1 s2 + | Hyp i, Hyp j -> Int.compare i j + | Def i, Def j -> Int.compare i j + | Cst n, Cst m -> Q.compare n m + | Zero, Zero -> 0 + | Square v1, Square v2 -> Vect.compare v1 v2 + | MulC (v1, p1), MulC (v2, p2) -> + cmp_pair Vect.compare compare (v1, p1) (v2, p2) + | Gcd (b1, p1), Gcd (b2, p2) -> + cmp_pair Z.compare compare (b1, p1) (b2, p2) + | MulPrf (p1, q1), MulPrf (p2, q2) -> + cmp_pair compare compare (p1, q1) (p2, q2) + | AddPrf (p1, q1), AddPrf (p2, q2) -> + cmp_pair compare compare (p1, q1) (p2, q2) + | CutPrf p, CutPrf p' -> compare p p' + | _, _ -> Int.compare (id_of_constr p1) (id_of_constr p2) + end + + module PrfRuleMap = Map.Make (OrdPrfRule) + let rec pr_size = function | Annot (_, p) -> pr_size p | Zero | Square _ -> Q.zero @@ -537,33 +579,38 @@ module ProofFormat = struct (** [pr_rule_def_cut id pr] gives an explicit [id] to cut rules. This is because the Coq proof format only accept they as a proof-step *) - let rec pr_rule_def_cut id = function - | Annot (_, p) -> pr_rule_def_cut id p - | MulC (p, prf) -> - let bds, id', prf' = pr_rule_def_cut id prf in - (bds, id', MulC (p, prf')) - | MulPrf (p1, p2) -> - let bds1, id, p1 = pr_rule_def_cut id p1 in - let bds2, id, p2 = pr_rule_def_cut id p2 in - (bds2 @ bds1, id, MulPrf (p1, p2)) - | AddPrf (p1, p2) -> - let bds1, id, p1 = pr_rule_def_cut id p1 in - let bds2, id, p2 = pr_rule_def_cut id p2 in - (bds2 @ bds1, id, AddPrf (p1, p2)) - | CutPrf p -> - let bds, id, p = pr_rule_def_cut id p in - ((id, p) :: bds, id + 1, Def id) - | Gcd (c, p) -> - let bds, id, p = pr_rule_def_cut id p in - ((id, p) :: bds, id + 1, Def id) - | (Square _ | Cst _ | Def _ | Hyp _ | Zero) as x -> ([], id, x) + let pr_rule_def_cut m id p = + let rec pr_rule_def_cut m id = function + | Annot (_, p) -> pr_rule_def_cut m id p + | MulC (p, prf) -> + let bds, m, id', prf' = pr_rule_def_cut m id prf in + (bds, m, id', MulC (p, prf')) + | MulPrf (p1, p2) -> + let bds1, m, id, p1 = pr_rule_def_cut m id p1 in + let bds2, m, id, p2 = pr_rule_def_cut m id p2 in + (bds2 @ bds1, m, id, MulPrf (p1, p2)) + | AddPrf (p1, p2) -> + let bds1, m, id, p1 = pr_rule_def_cut m id p1 in + let bds2, m, id, p2 = pr_rule_def_cut m id p2 in + (bds2 @ bds1, m, id, AddPrf (p1, p2)) + | CutPrf p | Gcd (_, p) -> ( + let bds, m, id, p = pr_rule_def_cut m id p in + try + let id' = PrfRuleMap.find p m in + (bds, m, id, Def id') + with Not_found -> + let m = PrfRuleMap.add p id m in + ((id, p) :: bds, m, id + 1, Def id) ) + | (Square _ | Cst _ | Def _ | Hyp _ | Zero) as x -> ([], m, id, x) + in + pr_rule_def_cut m id p (* Do not define top-level cuts *) - let pr_rule_def_cut id = function + let pr_rule_def_cut m id = function | CutPrf p -> - let bds, ids, p' = pr_rule_def_cut id p in - (bds, ids, CutPrf p') - | p -> pr_rule_def_cut id p + let bds, m, ids, p' = pr_rule_def_cut m id p in + (bds, m, ids, CutPrf p') + | p -> pr_rule_def_cut m id p let rec implicit_cut p = match p with CutPrf p -> implicit_cut p | _ -> p @@ -577,6 +624,69 @@ module ProofFormat = struct | MulPrf (p1, p2) | AddPrf (p1, p2) -> ISet.union (pr_rule_collect_defs p1) (pr_rule_collect_defs p2) + let add_proof x y = + match (x, y) with Zero, p | p, Zero -> p | _ -> AddPrf (x, y) + + let rec mul_cst_proof c p = + match p with + | Annot (s, p) -> Annot (s, mul_cst_proof c p) + | MulC (v, p') -> MulC (Vect.mul c v, p') + | _ -> ( + match Q.sign c with + | 0 -> Zero (* This is likely to be a bug *) + | -1 -> + MulC (LinPoly.constant c, p) (* [p] should represent an equality *) + | 1 -> if Q.one =/ c then p else MulPrf (Cst c, p) + | _ -> assert false ) + + let sMulC v p = + let c, v' = Vect.decomp_cst v in + if Vect.is_null v' then mul_cst_proof c p else MulC (v, p) + + let mul_proof p1 p2 = + match (p1, p2) with + | Zero, _ | _, Zero -> Zero + | Cst c, p | p, Cst c -> mul_cst_proof c p + | _, _ -> MulPrf (p1, p2) + + let prf_rule_of_map m = + PrfRuleMap.fold (fun k v acc -> add_proof (sMulC v k) acc) m Zero + + let rec dev_prf_rule p = + match p with + | Annot (s, p) -> dev_prf_rule p + | Hyp _ | Def _ | Cst _ | Zero | Square _ -> + PrfRuleMap.singleton p (LinPoly.constant Q.one) + | MulC (v, p) -> + PrfRuleMap.map (fun v1 -> LinPoly.product v v1) (dev_prf_rule p) + | AddPrf (p1, p2) -> + PrfRuleMap.merge + (fun k o1 o2 -> + match (o1, o2) with + | None, None -> None + | None, Some v | Some v, None -> Some v + | Some v1, Some v2 -> Some (LinPoly.addition v1 v2)) + (dev_prf_rule p1) (dev_prf_rule p2) + | MulPrf (p1, p2) -> ( + let p1' = dev_prf_rule p1 in + let p2' = dev_prf_rule p2 in + let p1'' = prf_rule_of_map p1' in + let p2'' = prf_rule_of_map p2' in + match p1'' with + | Cst c -> PrfRuleMap.map (fun v1 -> Vect.mul c v1) p2' + | _ -> PrfRuleMap.singleton (MulPrf (p1'', p2'')) (LinPoly.constant Q.one) + ) + | Gcd (c, p) -> + PrfRuleMap.singleton + (Gcd (c, prf_rule_of_map (dev_prf_rule p))) + (LinPoly.constant Q.one) + | CutPrf p -> + PrfRuleMap.singleton + (CutPrf (prf_rule_of_map (dev_prf_rule p))) + (LinPoly.constant Q.one) + + let simplify_prf_rule p = prf_rule_of_map (dev_prf_rule p) + (** [simplify_proof p] removes proof steps that are never re-used. *) let rec simplify_proof p = match p with @@ -618,7 +728,9 @@ module ProofFormat = struct | Done -> (id, Done) | Step (i, Gcd (c, p), Done) -> normalise_proof id (Step (i, p, Done)) | Step (i, p, prf) -> - let bds, id, p' = pr_rule_def_cut id p in + let bds, m, id, p' = + pr_rule_def_cut PrfRuleMap.empty id (simplify_prf_rule p) + in let id, prf = normalise_proof id prf in let prf = List.fold_left @@ -642,8 +754,10 @@ module ProofFormat = struct (List.fold_left max 0 ids , Enum(i,p1,v,p2,prfs)) *) - let bds1, id, p1' = pr_rule_def_cut id (implicit_cut p1) in - let bds2, id, p2' = pr_rule_def_cut id (implicit_cut p2) in + let bds1, m, id, p1' = + pr_rule_def_cut PrfRuleMap.empty id (implicit_cut p1) + in + let bds2, m, id, p2' = pr_rule_def_cut m id (implicit_cut p2) in let ids, prfs = List.split (List.map (normalise_proof id) pl) in ( List.fold_left max 0 ids , List.fold_left @@ -659,104 +773,6 @@ module ProofFormat = struct (snd res); res - module OrdPrfRule = struct - type t = prf_rule - - let id_of_constr = function - | Annot _ -> 0 - | Hyp _ -> 1 - | Def _ -> 2 - | Cst _ -> 3 - | Zero -> 4 - | Square _ -> 5 - | MulC _ -> 6 - | Gcd _ -> 7 - | MulPrf _ -> 8 - | AddPrf _ -> 9 - | CutPrf _ -> 10 - - let cmp_pair c1 c2 (x1, x2) (y1, y2) = - match c1 x1 y1 with 0 -> c2 x2 y2 | i -> i - - let rec compare p1 p2 = - match (p1, p2) with - | Annot (s1, p1), Annot (s2, p2) -> - if s1 = s2 then compare p1 p2 else String.compare s1 s2 - | Hyp i, Hyp j -> Int.compare i j - | Def i, Def j -> Int.compare i j - | Cst n, Cst m -> Q.compare n m - | Zero, Zero -> 0 - | Square v1, Square v2 -> Vect.compare v1 v2 - | MulC (v1, p1), MulC (v2, p2) -> - cmp_pair Vect.compare compare (v1, p1) (v2, p2) - | Gcd (b1, p1), Gcd (b2, p2) -> - cmp_pair Z.compare compare (b1, p1) (b2, p2) - | MulPrf (p1, q1), MulPrf (p2, q2) -> - cmp_pair compare compare (p1, p2) (q1, q2) - | AddPrf (p1, q1), AddPrf (p2, q2) -> - cmp_pair compare compare (p1, p2) (q1, q2) - | CutPrf p, CutPrf p' -> compare p p' - | _, _ -> Int.compare (id_of_constr p1) (id_of_constr p2) - end - - let add_proof x y = - match (x, y) with Zero, p | p, Zero -> p | _ -> AddPrf (x, y) - - let rec mul_cst_proof c p = - match p with - | Annot (s, p) -> Annot (s, mul_cst_proof c p) - | MulC (v, p') -> MulC (Vect.mul c v, p') - | _ -> ( - match Q.sign c with - | 0 -> Zero (* This is likely to be a bug *) - | -1 -> - MulC (LinPoly.constant c, p) (* [p] should represent an equality *) - | 1 -> if Q.one =/ c then p else MulPrf (Cst c, p) - | _ -> assert false ) - - let sMulC v p = - let c, v' = Vect.decomp_cst v in - if Vect.is_null v' then mul_cst_proof c p else MulC (v, p) - - let mul_proof p1 p2 = - match (p1, p2) with - | Zero, _ | _, Zero -> Zero - | Cst c, p | p, Cst c -> mul_cst_proof c p - | _, _ -> MulPrf (p1, p2) - - module PrfRuleMap = Map.Make (OrdPrfRule) - - let prf_rule_of_map m = - PrfRuleMap.fold (fun k v acc -> add_proof (sMulC v k) acc) m Zero - - let rec dev_prf_rule p = - match p with - | Annot (s, p) -> dev_prf_rule p - | Hyp _ | Def _ | Cst _ | Zero | Square _ -> - PrfRuleMap.singleton p (LinPoly.constant Q.one) - | MulC (v, p) -> - PrfRuleMap.map (fun v1 -> LinPoly.product v v1) (dev_prf_rule p) - | AddPrf (p1, p2) -> - PrfRuleMap.merge - (fun k o1 o2 -> - match (o1, o2) with - | None, None -> None - | None, Some v | Some v, None -> Some v - | Some v1, Some v2 -> Some (LinPoly.addition v1 v2)) - (dev_prf_rule p1) (dev_prf_rule p2) - | MulPrf (p1, p2) -> ( - let p1' = dev_prf_rule p1 in - let p2' = dev_prf_rule p2 in - let p1'' = prf_rule_of_map p1' in - let p2'' = prf_rule_of_map p2' in - match p1'' with - | Cst c -> PrfRuleMap.map (fun v1 -> Vect.mul c v1) p2' - | _ -> PrfRuleMap.singleton (MulPrf (p1'', p2'')) (LinPoly.constant Q.one) - ) - | _ -> PrfRuleMap.singleton p (LinPoly.constant Q.one) - - let simplify_prf_rule p = prf_rule_of_map (dev_prf_rule p) - (* let mul_proof p1 p2 = let res = mul_proof p1 p2 in @@ -835,7 +851,8 @@ module ProofFormat = struct Printf.printf "cmpl_pol_z %s %a\n" (Printexc.to_string x) LinPoly.pp lp; raise x - let rec cmpl_proof env = function + let rec cmpl_proof env prf = + match prf with | Done -> Mc.DoneProof | Step (i, p, prf) -> ( match p with @@ -1097,15 +1114,33 @@ module WithProof = struct in List.sort cmp (List.rev_map (fun wp -> (size wp, wp)) sys) - let subst sys0 = + let iterate_pivot p sys0 = let elim sys = - let oeq, sys' = extract (is_substitution true) sys in + let oeq, sys' = extract p sys in match oeq with | None -> None | Some (v, pc) -> simplify (linear_pivot sys0 pc v) sys' in iterate_until_stable elim (List.map snd (sort sys0)) + let subst_constant is_int sys = + let is_integer q = Q.(q =/ floor q) in + let is_constant ((c, o), p) = + match o with + | Ge | Gt -> None + | Eq -> ( + Vect.Bound.( + match of_vect c with + | None -> None + | Some b -> + if (not is_int) || is_integer (b.cst // b.coeff) then + Monomial.get_var (LinPoly.MonT.retrieve b.var) + else None) ) + in + iterate_pivot is_constant sys + + let subst sys0 = iterate_pivot (is_substitution true) sys0 + let saturate_subst b sys0 = let select = is_substitution b in let gen (v, pc) ((c, op), prf) = diff --git a/plugins/micromega/polynomial.mli b/plugins/micromega/polynomial.mli index 84b5421207..81c131fe78 100644 --- a/plugins/micromega/polynomial.mli +++ b/plugins/micromega/polynomial.mli @@ -393,6 +393,10 @@ module WithProof : sig val subst : t list -> t list + (** [subst_constant b sys] performs the equivalent of the 'subst' tactic of Coq + only if there is an equation a.x = c for a,c a constant and a divides c if b= true*) + val subst_constant : bool -> t list -> t list + (** [subst1 sys] performs a single substitution *) val subst1 : t list -> t list |
