aboutsummaryrefslogtreecommitdiff
path: root/interp/notation_ops.ml
diff options
context:
space:
mode:
Diffstat (limited to 'interp/notation_ops.ml')
-rw-r--r--interp/notation_ops.ml164
1 files changed, 85 insertions, 79 deletions
diff --git a/interp/notation_ops.ml b/interp/notation_ops.ml
index f30a874426..265ca58ed9 100644
--- a/interp/notation_ops.ml
+++ b/interp/notation_ops.ml
@@ -90,9 +90,11 @@ let rec eq_notation_constr (vars1,vars2 as vars) t1 t2 = match t1, t2 with
(eq_notation_constr vars) t1 t2 && cast_type_eq (eq_notation_constr vars) c1 c2
| NInt i1, NInt i2 ->
Uint63.equal i1 i2
+| NFloat f1, NFloat f2 ->
+ Float64.equal f1 f2
| (NRef _ | NVar _ | NApp _ | NHole _ | NList _ | NLambda _ | NProd _
| NBinderList _ | NLetIn _ | NCases _ | NLetTuple _ | NIf _
- | NRec _ | NSort _ | NCast _ | NInt _), _ -> false
+ | NRec _ | NSort _ | NCast _ | NInt _ | NFloat _), _ -> false
(**********************************************************************)
(* Re-interpret a notation as a glob_constr, taking care of binders *)
@@ -135,13 +137,13 @@ let rec subst_glob_vars l gc = DAst.map (function
| GVar id as r -> (try DAst.get (Id.List.assoc id l) with Not_found -> r)
| GProd (Name id,bk,t,c) ->
let id =
- try match DAst.get (Id.List.assoc id l) with GVar id' -> id' | _ -> id
- with Not_found -> id in
+ try match DAst.get (Id.List.assoc id l) with GVar id' -> id' | _ -> id
+ with Not_found -> id in
GProd (Name id,bk,subst_glob_vars l t,subst_glob_vars l c)
| GLambda (Name id,bk,t,c) ->
let id =
- try match DAst.get (Id.List.assoc id l) with GVar id' -> id' | _ -> id
- with Not_found -> id in
+ try match DAst.get (Id.List.assoc id l) with GVar id' -> id' | _ -> id
+ with Not_found -> id in
GLambda (Name id,bk,subst_glob_vars l t,subst_glob_vars l c)
| GHole (x,naming,arg) -> GHole (subst_binder_type_vars l x,naming,arg)
| _ -> DAst.get (map_glob_constr (subst_glob_vars l) gc) (* assume: id is not binding *)
@@ -188,10 +190,10 @@ let glob_constr_of_notation_constr_with_binders ?loc g f e nc =
| Some (disjpat,_id) -> DAst.get (apply_cases_pattern_term ?loc disjpat (f e b) (f e' c)))
| NCases (sty,rtntypopt,tml,eqnl) ->
let e',tml' = List.fold_right (fun (tm,(na,t)) (e',tml') ->
- let e',t' = match t with
- | None -> e',None
- | Some (ind,nal) ->
- let e',nal' = List.fold_right (fun na (e',nal) ->
+ let e',t' = match t with
+ | None -> e',None
+ | Some (ind,nal) ->
+ let e',nal' = List.fold_right (fun na (e',nal) ->
let e',na' = protect g e' na in
e',na'::nal) nal (e',[]) in
e',Some (CAst.make ?loc (ind,nal')) in
@@ -214,7 +216,7 @@ let glob_constr_of_notation_constr_with_binders ?loc g f e nc =
| NRec (fk,idl,dll,tl,bl) ->
let e,dll = Array.fold_left_map (List.fold_left_map (fun e (na,oc,b) ->
let e,na = protect g e na in
- (e,(na,Explicit,Option.map (f e) oc,f e b)))) e dll in
+ (e,(na,Explicit,Option.map (f e) oc,f e b)))) e dll in
let e',idl = Array.fold_left_map (to_id (protect g)) e idl in
GRec (fk,idl,dll,Array.map (f e) tl,Array.map (f e') bl)
| NCast (c,k) -> GCast (f e c,map_cast_type (f e) k)
@@ -222,6 +224,7 @@ let glob_constr_of_notation_constr_with_binders ?loc g f e nc =
| NHole (x, naming, arg) -> GHole (x, naming, arg)
| NRef x -> GRef (x,None)
| NInt i -> GInt i
+ | NFloat f -> GFloat f
let glob_constr_of_notation_constr ?loc x =
let rec aux () x =
@@ -359,8 +362,8 @@ let compare_recursive_parts recvars found f f' (iterator,subc) =
if aux iterator subc then
match !diff with
| None ->
- let loc1 = loc_of_glob_constr iterator in
- let loc2 = loc_of_glob_constr (Option.get !terminator) in
+ let loc1 = loc_of_glob_constr iterator in
+ let loc2 = loc_of_glob_constr (Option.get !terminator) in
(* Here, we would need a loc made of several parts ... *)
user_err ?loc:(subtract_loc loc1 loc2)
(str "Both ends of the recursive pattern are the same.")
@@ -397,15 +400,15 @@ let notation_constr_and_vars_of_glob_constr recvars a =
| GApp (t, [_]) ->
begin match DAst.get t with
| GVar f when Id.equal f ldots_var ->
- (* Fall on the second part of the recursive pattern w/o having
- found the first part *)
+ (* Fall on the second part of the recursive pattern w/o having
+ found the first part *)
let loc = t.CAst.loc in
- user_err ?loc
- (str "Cannot find where the recursive pattern starts.")
+ user_err ?loc
+ (str "Cannot find where the recursive pattern starts.")
| _ -> aux' c
end
| _c ->
- aux' c
+ aux' c
and aux' x = DAst.with_val (function
| GVar id -> if not (Id.equal id ldots_var) then add_id found id; NVar id
| GApp (g,args) -> NApp (aux g, List.map aux args)
@@ -416,8 +419,8 @@ let notation_constr_and_vars_of_glob_constr recvars a =
let f {CAst.v=(idl,pat,rhs)} = List.iter (add_id found) idl; (pat,aux rhs) in
NCases (sty,Option.map aux rtntypopt,
List.map (fun (tm,(na,x)) ->
- add_name found na;
- Option.iter
+ add_name found na;
+ Option.iter
(fun {CAst.v=(_,nl)} -> List.iter (add_name found) nl) x;
(aux tm,(na,Option.map (fun {CAst.v=(ind,nal)} -> (ind,nal)) x))) tml,
List.map f eqnl)
@@ -431,13 +434,14 @@ let notation_constr_and_vars_of_glob_constr recvars a =
| GRec (fk,idl,dll,tl,bl) ->
Array.iter (add_id found) idl;
let dll = Array.map (List.map (fun (na,bk,oc,b) ->
- if bk != Explicit then
- user_err Pp.(str "Binders marked as implicit not allowed in notations.");
- add_name found na; (na,Option.map aux oc,aux b))) dll in
+ if bk != Explicit then
+ user_err Pp.(str "Binders marked as implicit not allowed in notations.");
+ add_name found na; (na,Option.map aux oc,aux b))) dll in
NRec (fk,idl,dll,Array.map aux tl,Array.map aux bl)
| GCast (c,k) -> NCast (aux c,map_cast_type aux k)
| GSort s -> NSort s
| GInt i -> NInt i
+ | GFloat f -> NFloat f
| GHole (w,naming,arg) ->
if arg != None then has_ltac := true;
NHole (w, naming, arg)
@@ -461,7 +465,7 @@ let check_variables_and_reversibility nenv
let check_recvar x =
if Id.List.mem x found then
user_err (Id.print x ++
- strbrk " should only be used in the recursive part of a pattern.") in
+ strbrk " should only be used in the recursive part of a pattern.") in
let check (x, y) = check_recvar x; check_recvar y in
let () = List.iter check foundrec in
let () = List.iter check foundrecbinding in
@@ -472,7 +476,7 @@ let check_variables_and_reversibility nenv
Id.List.mem_assoc_sym x foundrec ||
Id.List.mem_assoc_sym x foundrecbinding
then
- user_err Pp.(str
+ user_err Pp.(str
(Id.to_string x ^
" should not be bound in a recursive pattern of the right-hand side."))
else injective := x :: !injective
@@ -480,19 +484,19 @@ let check_variables_and_reversibility nenv
let check_pair s x y where =
if not (mem_recursive_pair (x,y) where) then
user_err (strbrk "in the right-hand side, " ++ Id.print x ++
- str " and " ++ Id.print y ++ strbrk " should appear in " ++ str s ++
- str " position as part of a recursive pattern.") in
+ str " and " ++ Id.print y ++ strbrk " should appear in " ++ str s ++
+ str " position as part of a recursive pattern.") in
let check_type x typ =
match typ with
| NtnInternTypeAny ->
- begin
- try check_pair "term" x (Id.Map.find x recvars) foundrec
- with Not_found -> check_bound x
- end
+ begin
+ try check_pair "term" x (Id.Map.find x recvars) foundrec
+ with Not_found -> check_bound x
+ end
| NtnInternTypeOnlyBinder ->
- begin
- try check_pair "binding" x (Id.Map.find x recvars) foundrecbinding
- with Not_found -> check_bound x
+ begin
+ try check_pair "binding" x (Id.Map.find x recvars) foundrecbinding
+ with Not_found -> check_bound x
end in
Id.Map.iter check_type vars;
List.rev !injective
@@ -543,49 +547,49 @@ let rec subst_notation_constr subst bound raw =
| NApp (r,rl) ->
let r' = subst_notation_constr subst bound r
and rl' = List.Smart.map (subst_notation_constr subst bound) rl in
- if r' == r && rl' == rl then raw else
- NApp(r',rl')
+ if r' == r && rl' == rl then raw else
+ NApp(r',rl')
| NList (id1,id2,r1,r2,b) ->
let r1' = subst_notation_constr subst bound r1
and r2' = subst_notation_constr subst bound r2 in
- if r1' == r1 && r2' == r2 then raw else
- NList (id1,id2,r1',r2',b)
+ if r1' == r1 && r2' == r2 then raw else
+ NList (id1,id2,r1',r2',b)
| NLambda (n,r1,r2) ->
let r1' = subst_notation_constr subst bound r1
and r2' = subst_notation_constr subst bound r2 in
- if r1' == r1 && r2' == r2 then raw else
- NLambda (n,r1',r2')
+ if r1' == r1 && r2' == r2 then raw else
+ NLambda (n,r1',r2')
| NProd (n,r1,r2) ->
let r1' = subst_notation_constr subst bound r1
and r2' = subst_notation_constr subst bound r2 in
- if r1' == r1 && r2' == r2 then raw else
- NProd (n,r1',r2')
+ if r1' == r1 && r2' == r2 then raw else
+ NProd (n,r1',r2')
| NBinderList (id1,id2,r1,r2,b) ->
let r1' = subst_notation_constr subst bound r1
and r2' = subst_notation_constr subst bound r2 in
- if r1' == r1 && r2' == r2 then raw else
+ if r1' == r1 && r2' == r2 then raw else
NBinderList (id1,id2,r1',r2',b)
| NLetIn (n,r1,t,r2) ->
let r1' = subst_notation_constr subst bound r1 in
let t' = Option.Smart.map (subst_notation_constr subst bound) t in
let r2' = subst_notation_constr subst bound r2 in
- if r1' == r1 && t == t' && r2' == r2 then raw else
- NLetIn (n,r1',t',r2')
+ if r1' == r1 && t == t' && r2' == r2 then raw else
+ NLetIn (n,r1',t',r2')
| NCases (sty,rtntypopt,rl,branches) ->
let rtntypopt' = Option.Smart.map (subst_notation_constr subst bound) rtntypopt
and rl' = List.Smart.map
(fun (a,(n,signopt) as x) ->
- let a' = subst_notation_constr subst bound a in
- let signopt' = Option.map (fun ((indkn,i),nal as z) ->
- let indkn' = subst_mind subst indkn in
- if indkn == indkn' then z else ((indkn',i),nal)) signopt in
- if a' == a && signopt' == signopt then x else (a',(n,signopt')))
+ let a' = subst_notation_constr subst bound a in
+ let signopt' = Option.map (fun ((indkn,i),nal as z) ->
+ let indkn' = subst_mind subst indkn in
+ if indkn == indkn' then z else ((indkn',i),nal)) signopt in
+ if a' == a && signopt' == signopt then x else (a',(n,signopt')))
rl
and branches' = List.Smart.map
(fun (cpl,r as branch) ->
@@ -603,30 +607,31 @@ let rec subst_notation_constr subst bound raw =
let po' = Option.Smart.map (subst_notation_constr subst bound) po
and b' = subst_notation_constr subst bound b
and c' = subst_notation_constr subst bound c in
- if po' == po && b' == b && c' == c then raw else
- NLetTuple (nal,(na,po'),b',c')
+ if po' == po && b' == b && c' == c then raw else
+ NLetTuple (nal,(na,po'),b',c')
| NIf (c,(na,po),b1,b2) ->
let po' = Option.Smart.map (subst_notation_constr subst bound) po
and b1' = subst_notation_constr subst bound b1
and b2' = subst_notation_constr subst bound b2
and c' = subst_notation_constr subst bound c in
- if po' == po && b1' == b1 && b2' == b2 && c' == c then raw else
- NIf (c',(na,po'),b1',b2')
+ if po' == po && b1' == b1 && b2' == b2 && c' == c then raw else
+ NIf (c',(na,po'),b1',b2')
| NRec (fk,idl,dll,tl,bl) ->
let dll' =
Array.Smart.map (List.Smart.map (fun (na,oc,b as x) ->
let oc' = Option.Smart.map (subst_notation_constr subst bound) oc in
- let b' = subst_notation_constr subst bound b in
- if oc' == oc && b' == b then x else (na,oc',b'))) dll in
+ let b' = subst_notation_constr subst bound b in
+ if oc' == oc && b' == b then x else (na,oc',b'))) dll in
let tl' = Array.Smart.map (subst_notation_constr subst bound) tl in
let bl' = Array.Smart.map (subst_notation_constr subst bound) bl in
if dll' == dll && tl' == tl && bl' == bl then raw else
- NRec (fk,idl,dll',tl',bl')
+ NRec (fk,idl,dll',tl',bl')
| NSort _ -> raw
| NInt _ -> raw
+ | NFloat _ -> raw
| NHole (knd, naming, solve) ->
let nknd = match knd with
@@ -655,7 +660,7 @@ let abstract_return_type_context pi mklam tml rtno =
Option.map (fun rtn ->
let nal =
List.flatten (List.map (fun (_,(na,t)) ->
- match t with Some x -> (pi x)@[na] | None -> [na]) tml) in
+ match t with Some x -> (pi x)@[na] | None -> [na]) tml) in
List.fold_right mklam nal rtn)
rtno
@@ -1126,11 +1131,11 @@ let rec match_ inner u alp metas sigma a1 a2 =
| GApp (f1,l1), NApp (f2,l2) ->
let n1 = List.length l1 and n2 = List.length l2 in
let f1,l1,f2,l2 =
- if n1 < n2 then
- let l21,l22 = List.chop (n2-n1) l2 in f1,l1, NApp (f2,l21), l22
- else if n1 > n2 then
- let l11,l12 = List.chop (n1-n2) l1 in DAst.make ?loc @@ GApp (f1,l11),l12, f2,l2
- else f1,l1, f2, l2 in
+ if n1 < n2 then
+ let l21,l22 = List.chop (n2-n1) l2 in f1,l1, NApp (f2,l21), l22
+ else if n1 > n2 then
+ let l11,l12 = List.chop (n1-n2) l1 in DAst.make ?loc @@ GApp (f1,l11),l12, f2,l2
+ else f1,l1, f2, l2 in
let may_use_eta = does_not_come_from_already_eta_expanded_var f1 in
List.fold_left2 (match_ may_use_eta u alp metas)
(match_hd u alp metas sigma f1 f2) l1 l2
@@ -1149,8 +1154,8 @@ let rec match_ inner u alp metas sigma a1 a2 =
let rtno1' = abstract_return_type_context_glob_constr tml1 rtno1 in
let rtno2' = abstract_return_type_context_notation_constr tml2 rtno2 in
let sigma =
- try Option.fold_left2 (match_in u alp metas) sigma rtno1' rtno2'
- with Option.Heterogeneous -> raise No_match
+ try Option.fold_left2 (match_in u alp metas) sigma rtno1' rtno2'
+ with Option.Heterogeneous -> raise No_match
in
let sigma = List.fold_left2
(fun s (tm1,_) (tm2,_) ->
@@ -1168,24 +1173,24 @@ let rec match_ inner u alp metas sigma a1 a2 =
let sigma = match_opt (match_binders u alp metas na1 na2) sigma to1 to2 in
let sigma = match_in u alp metas sigma b1 b2 in
let (alp,sigma) =
- List.fold_left2 (match_names metas) (alp,sigma) nal1 nal2 in
+ List.fold_left2 (match_names metas) (alp,sigma) nal1 nal2 in
match_in u alp metas sigma c1 c2
| GIf (a1,(na1,to1),b1,c1), NIf (a2,(na2,to2),b2,c2) ->
let sigma = match_opt (match_binders u alp metas na1 na2) sigma to1 to2 in
List.fold_left2 (match_in u alp metas) sigma [a1;b1;c1] [a2;b2;c2]
| GRec (fk1,idl1,dll1,tl1,bl1), NRec (fk2,idl2,dll2,tl2,bl2)
when match_fix_kind fk1 fk2 && Int.equal (Array.length idl1) (Array.length idl2) &&
- Array.for_all2 (fun l1 l2 -> Int.equal (List.length l1) (List.length l2)) dll1 dll2
- ->
+ Array.for_all2 (fun l1 l2 -> Int.equal (List.length l1) (List.length l2)) dll1 dll2
+ ->
let alp,sigma = Array.fold_left2
- (List.fold_left2 (fun (alp,sigma) (na1,_,oc1,b1) (na2,oc2,b2) ->
- let sigma =
- match_in u alp metas
+ (List.fold_left2 (fun (alp,sigma) (na1,_,oc1,b1) (na2,oc2,b2) ->
+ let sigma =
+ match_in u alp metas
(match_opt (match_in u alp metas) sigma oc1 oc2) b1 b2
- in match_names metas (alp,sigma) na1 na2)) (alp,sigma) dll1 dll2 in
+ in match_names metas (alp,sigma) na1 na2)) (alp,sigma) dll1 dll2 in
let sigma = Array.fold_left2 (match_in u alp metas) sigma tl1 tl2 in
let alp,sigma = Array.fold_right2 (fun id1 id2 alsig ->
- match_names metas alsig (Name id1) (Name id2)) idl1 idl2 (alp,sigma) in
+ match_names metas alsig (Name id1) (Name id2)) idl1 idl2 (alp,sigma) in
Array.fold_left2 (match_in u alp metas) sigma bl1 bl2
| GCast(t1, c1), NCast(t2, c2) ->
match_cast (match_in u alp metas) (match_in u alp metas sigma t1 t2) c1 c2
@@ -1196,6 +1201,7 @@ let rec match_ inner u alp metas sigma a1 a2 =
| GSort s1, NSort s2 when glob_sort_eq s1 s2 -> sigma
| GInt i1, NInt i2 when Uint63.equal i1 i2 -> sigma
+ | GFloat f1, NFloat f2 when Float64.equal f1 f2 -> sigma
| GPatVar _, NHole _ -> (*Don't hide Metas, they bind in ltac*) raise No_match
| a, NHole _ -> sigma
@@ -1219,11 +1225,11 @@ let rec match_ inner u alp metas sigma a1 a2 =
bind_bindinglist_env alp sigma id [DAst.make @@ GLocalAssum (Name id',Explicit,t1)]
else
match_names metas (alp,sigma) (Name id') na in
- match_in u alp metas sigma (mkGApp a1 (DAst.make @@ GVar id')) b2
+ match_in u alp metas sigma (mkGApp a1 [DAst.make @@ GVar id']) b2
| (GRef _ | GVar _ | GEvar _ | GPatVar _ | GApp _ | GLambda _ | GProd _
| GLetIn _ | GCases _ | GLetTuple _ | GIf _ | GRec _ | GSort _ | GHole _
- | GCast _ | GInt _ ), _ -> raise No_match
+ | GCast _ | GInt _ | GFloat _), _ -> raise No_match
and match_in u = match_ true u
@@ -1345,9 +1351,9 @@ let rec match_cases_pattern metas (terms,termlists,(),() as sigma) a1 a2 =
let le2 = List.length l2 in
if Int.equal le2 0 (* Special case of a notation for a @Cstr *) || le2 > List.length l1
then
- raise No_match
+ raise No_match
else
- let l1',more_args = Util.List.chop le2 l1 in
+ let l1',more_args = Util.List.chop le2 l1 in
(List.fold_left2 (match_cases_pattern_no_more_args metas) sigma l1' l2),(le2,more_args)
| r1, NList (x,y,iter,termin,revert) ->
(match_cases_pattern_list (match_cases_pattern_no_more_args)
@@ -1368,10 +1374,10 @@ let match_ind_pattern metas sigma ind pats a2 =
let le2 = List.length l2 in
if Int.equal le2 0 (* Special case of a notation for a @Cstr *) || le2 > List.length pats
then
- raise No_match
+ raise No_match
else
- let l1',more_args = Util.List.chop le2 pats in
- (List.fold_left2 (match_cases_pattern_no_more_args metas) sigma l1' l2),(le2,more_args)
+ let l1',more_args = Util.List.chop le2 pats in
+ (List.fold_left2 (match_cases_pattern_no_more_args metas) sigma l1' l2),(le2,more_args)
|_ -> raise No_match
let reorder_canonically_substitution terms termlists metas =