aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--interp/constrintern.ml44
-rw-r--r--test-suite/bugs/closed/4932.v40
2 files changed, 74 insertions, 10 deletions
diff --git a/interp/constrintern.ml b/interp/constrintern.ml
index d3d7af87a3..1fcb49af2e 100644
--- a/interp/constrintern.ml
+++ b/interp/constrintern.ml
@@ -604,6 +604,23 @@ let rec subordinate_letins intern letins = function
| [] ->
letins,[]
+let terms_of_binders bl =
+ let rec term_of_pat = function
+ | PatVar (loc,Name id) -> CRef (Ident (loc,id), None)
+ | PatVar (loc,Anonymous) -> error "Cannot turn \"_\" into a term."
+ | PatCstr (loc,c,l,_) ->
+ let r = Qualid (loc,qualid_of_path (path_of_global (ConstructRef c))) in
+ let hole = CHole (loc,None,Misctypes.IntroAnonymous,None) in
+ let params = List.make (Inductiveops.inductive_nparams (fst c)) hole in
+ CAppExpl (loc,(None,r,None),params @ List.map term_of_pat l) in
+ let rec extract_variables = function
+ | BDRawDef (loc,(Name id,_,None,_))::l -> CRef (Ident (loc,id), None) :: extract_variables l
+ | BDRawDef (loc,(Name id,_,Some _,_))::l -> extract_variables l
+ | BDRawDef (loc,(Anonymous,_,_,_))::l -> error "Cannot turn \"_\" into a term."
+ | BDPattern (loc,(u,_),lvar,env,tyc) :: l -> term_of_pat u :: extract_variables l
+ | [] -> [] in
+ extract_variables bl
+
let instantiate_notation_constr loc intern ntnvars subst infos c =
let (terms,termlists,binders) = subst in
(* when called while defining a notation, avoid capturing the private binders
@@ -615,17 +632,24 @@ let instantiate_notation_constr loc intern ntnvars subst infos c =
| NVar id when Id.equal id ldots_var -> Option.get terminopt
| NVar id -> subst_var subst' (renaming, env) id
| NList (x,_,iter,terminator,lassoc) ->
- (try
+ let l,(scopt,subscopes) =
(* All elements of the list are in scopes (scopt,subscopes) *)
- let (l,(scopt,subscopes)) = Id.Map.find x termlists in
- let termin = aux (terms,None,None) subinfos terminator in
- let fold a t =
- let nterms = Id.Map.add x (a, (scopt, subscopes)) terms in
- aux (nterms,None,Some t) subinfos iter
- in
- List.fold_right fold (if lassoc then List.rev l else l) termin
- with Not_found ->
- anomaly (Pp.str "Inconsistent substitution of recursive notation"))
+ try
+ let l,scopes = Id.Map.find x termlists in
+ (if lassoc then List.rev l else l),scopes
+ with Not_found ->
+ try
+ let (bl,(scopt,subscopes)) = Id.Map.find x binders in
+ let env,bl' = List.fold_left (intern_local_binder_aux intern ntnvars) (env,[]) bl in
+ terms_of_binders (if lassoc then bl' else List.rev bl'),(None,[])
+ with Not_found ->
+ anomaly (Pp.str "Inconsistent substitution of recursive notation") in
+ let termin = aux (terms,None,None) subinfos terminator in
+ let fold a t =
+ let nterms = Id.Map.add x (a, (scopt, subscopes)) terms in
+ aux (nterms,None,Some t) subinfos iter
+ in
+ List.fold_right fold l termin
| NHole (knd, naming, arg) ->
let knd = match knd with
| Evar_kinds.BinderType (Name id as na) ->
diff --git a/test-suite/bugs/closed/4932.v b/test-suite/bugs/closed/4932.v
new file mode 100644
index 0000000000..df200cdce0
--- /dev/null
+++ b/test-suite/bugs/closed/4932.v
@@ -0,0 +1,40 @@
+(* Testing recursive notations with binders seen as terms *)
+
+Inductive ftele : Type :=
+| fb {T:Type} : T -> ftele
+| fr {T} : (T -> ftele) -> ftele.
+
+Fixpoint args ftele : Type :=
+ match ftele with
+ | fb _ => unit
+ | fr f => sigT (fun t => args (f t))
+ end.
+
+Definition fpack := sigT args.
+Definition pack fp fa : fpack := existT _ fp fa.
+
+Notation "'tele' x .. z := b" :=
+ (
+ (fun x => ..
+ (fun z =>
+ pack
+ (fr (fun x => .. ( fr (fun z => fb b) ) .. ) )
+ (existT _ x .. (existT _ z tt) .. )
+ ) ..
+ )
+ ) (at level 85, x binder, z binder).
+
+Check fun '((y,z):nat*nat) => pack (fr (fun '((y,z):nat*nat) => fb tt))
+ (existT _ (y,z) tt).
+
+Example test := tele (t : Type) := tt.
+Check test nat.
+
+Example test2 := tele (t : Type) (x:t) := tt.
+Check test2 nat 0.
+
+Check tele (t : Type) (y:=0) (x:t) := tt.
+Check (tele (t : Type) (y:=0) (x:t) := tt) nat 0.
+
+Check tele (t : Type) '((y,z):nat*nat) (x:t) := tt.
+Check (tele (t : Type) '((y,z):nat*nat) (x:t) := tt) nat (1,2) 3.