aboutsummaryrefslogtreecommitdiff
path: root/interp
diff options
context:
space:
mode:
Diffstat (limited to 'interp')
-rw-r--r--interp/constrexpr.ml2
-rw-r--r--interp/constrexpr_ops.ml100
-rw-r--r--interp/constrextern.ml9
-rw-r--r--interp/constrintern.ml99
-rw-r--r--interp/notation_ops.ml31
-rw-r--r--interp/notation_ops.mli2
6 files changed, 151 insertions, 92 deletions
diff --git a/interp/constrexpr.ml b/interp/constrexpr.ml
index a5ff5df7cf..b3f06faa1c 100644
--- a/interp/constrexpr.ml
+++ b/interp/constrexpr.ml
@@ -147,7 +147,7 @@ and recursion_order_expr = recursion_order_expr_r CAst.t
and local_binder_expr =
| CLocalAssum of lname list * binder_kind * constr_expr
| CLocalDef of lname * constr_expr * constr_expr option
- | CLocalPattern of (cases_pattern_expr * constr_expr option) CAst.t
+ | CLocalPattern of cases_pattern_expr
and constr_notation_substitution =
constr_expr list * (* for constr subterms *)
diff --git a/interp/constrexpr_ops.ml b/interp/constrexpr_ops.ml
index fe107c3580..a60dc11b57 100644
--- a/interp/constrexpr_ops.ml
+++ b/interp/constrexpr_ops.ml
@@ -271,39 +271,37 @@ let is_constructor id =
(Nametab.locate_extended (qualid_of_ident id)))
with Not_found -> false
-let rec cases_pattern_fold_names f a pt = match CAst.(pt.v) with
+let rec cases_pattern_fold_names f h nacc pt = match CAst.(pt.v) with
| CPatRecord l ->
- List.fold_left (fun acc (r, cp) -> cases_pattern_fold_names f acc cp) a l
- | CPatAlias (pat,{CAst.v=na}) -> Name.fold_right f na (cases_pattern_fold_names f a pat)
+ List.fold_left (fun nacc (r, cp) -> cases_pattern_fold_names f h nacc cp) nacc l
+ | CPatAlias (pat,{CAst.v=na}) -> Name.fold_right (fun na (n,acc) -> (f na n,acc)) na (cases_pattern_fold_names f h nacc pat)
| CPatOr (patl) ->
- List.fold_left (cases_pattern_fold_names f) a patl
+ List.fold_left (cases_pattern_fold_names f h) nacc patl
| CPatCstr (_,patl1,patl2) ->
- List.fold_left (cases_pattern_fold_names f)
- (Option.fold_left (List.fold_left (cases_pattern_fold_names f)) a patl1) patl2
+ List.fold_left (cases_pattern_fold_names f h)
+ (Option.fold_left (List.fold_left (cases_pattern_fold_names f h)) nacc patl1) patl2
| CPatNotation (_,_,(patl,patll),patl') ->
- List.fold_left (cases_pattern_fold_names f)
- (List.fold_left (cases_pattern_fold_names f) a (patl@List.flatten patll)) patl'
- | CPatDelimiters (_,pat) -> cases_pattern_fold_names f a pat
+ List.fold_left (cases_pattern_fold_names f h)
+ (List.fold_left (cases_pattern_fold_names f h) nacc (patl@List.flatten patll)) patl'
+ | CPatDelimiters (_,pat) -> cases_pattern_fold_names f h nacc pat
| CPatAtom (Some qid)
when qualid_is_ident qid && not (is_constructor @@ qualid_basename qid) ->
- f (qualid_basename qid) a
- | CPatPrim _ | CPatAtom _ -> a
- | CPatCast ({CAst.loc},_) ->
- CErrors.user_err ?loc ~hdr:"cases_pattern_fold_names"
- (Pp.strbrk "Casts are not supported here.")
-
-let ids_of_pattern =
- cases_pattern_fold_names Id.Set.add Id.Set.empty
-
-let ids_of_pattern_list =
- List.fold_left
- (List.fold_left (cases_pattern_fold_names Id.Set.add))
- Id.Set.empty
+ let (n, acc) = nacc in
+ (f (qualid_basename qid) n, acc)
+ | CPatPrim _ | CPatAtom _ -> nacc
+ | CPatCast (p,t) ->
+ let (n, acc) = nacc in
+ cases_pattern_fold_names f h (n, h acc t) p
+
+let ids_of_pattern_list p =
+ fst (List.fold_left
+ (List.fold_left (cases_pattern_fold_names Id.Set.add (fun () _ -> ())))
+ (Id.Set.empty,()) p)
let ids_of_cases_tomatch tms =
List.fold_right
(fun (_, ona, indnal) l ->
- Option.fold_right (fun t ids -> cases_pattern_fold_names Id.Set.add ids t)
+ Option.fold_right (fun t ids -> fst (cases_pattern_fold_names Id.Set.add (fun () _ -> ()) (ids,()) t))
indnal
(Option.fold_right (CAst.with_val (Name.fold_right Id.Set.add)) ona l))
tms Id.Set.empty
@@ -315,9 +313,9 @@ let rec fold_local_binders g f n acc b = let open CAst in function
f n (fold_local_binders g f n' acc b l) t
| CLocalDef ( { v = na },c,t)::l ->
Option.fold_left (f n) (f n (fold_local_binders g f (Name.fold_right g na n) acc b l) c) t
- | CLocalPattern { v = pat,t }::l ->
- let acc = fold_local_binders g f (cases_pattern_fold_names g n pat) acc b l in
- Option.fold_left (f n) acc t
+ | CLocalPattern pat :: l ->
+ let n, acc = cases_pattern_fold_names g (f n) (n,acc) pat in
+ fold_local_binders g f n acc b l
| [] ->
f n acc b
@@ -381,10 +379,42 @@ let names_of_constr_expr c =
let occur_var_constr_expr id c = Id.Set.mem id (free_vars_of_constr_expr c)
+let rec fold_map_cases_pattern f h acc (CAst.{v=pt;loc} as p) = match pt with
+ | CPatRecord l ->
+ let acc, l = List.fold_left_map (fun acc (r, cp) -> let acc, cp = fold_map_cases_pattern f h acc cp in acc, (r, cp)) acc l in
+ acc, CAst.make ?loc (CPatRecord l)
+ | CPatAlias (pat,({CAst.v=na} as lna)) ->
+ let acc, p = fold_map_cases_pattern f h acc pat in
+ let acc = Name.fold_right f na acc in
+ acc, CAst.make ?loc (CPatAlias (pat,lna))
+ | CPatOr patl ->
+ let acc, patl = List.fold_left_map (fold_map_cases_pattern f h) acc patl in
+ acc, CAst.make ?loc (CPatOr patl)
+ | CPatCstr (c,patl1,patl2) ->
+ let acc, patl1 = Option.fold_left_map (List.fold_left_map (fold_map_cases_pattern f h)) acc patl1 in
+ let acc, patl2 = List.fold_left_map (fold_map_cases_pattern f h) acc patl2 in
+ acc, CAst.make ?loc (CPatCstr (c,patl1,patl2))
+ | CPatNotation (sc,ntn,(patl,patll),patl') ->
+ let acc, patl = List.fold_left_map (fold_map_cases_pattern f h) acc patl in
+ let acc, patll = List.fold_left_map (List.fold_left_map (fold_map_cases_pattern f h)) acc patll in
+ let acc, patl' = List.fold_left_map (fold_map_cases_pattern f h) acc patl' in
+ acc, CAst.make ?loc (CPatNotation (sc,ntn,(patl,patll),patl'))
+ | CPatDelimiters (d,pat) ->
+ let acc, p = fold_map_cases_pattern f h acc pat in
+ acc, CAst.make ?loc (CPatDelimiters (d,pat))
+ | CPatAtom (Some qid)
+ when qualid_is_ident qid && not (is_constructor @@ qualid_basename qid) ->
+ f (qualid_basename qid) acc, p
+ | CPatPrim _ | CPatAtom _ -> (acc,p)
+ | CPatCast (pat,t) ->
+ let acc, pat = fold_map_cases_pattern f h acc pat in
+ let t = h acc t in
+ acc, CAst.make ?loc (CPatCast (pat,t))
+
(* Used in correctness and interface *)
let map_binder g e nal = List.fold_right (CAst.with_val (Name.fold_right g)) nal e
-let map_local_binders f g e bl =
+let fold_map_local_binders f g e bl =
(* TODO: avoid variable capture in [t] by some [na] in [List.tl nal] *)
let open CAst in
let h (e,bl) = function
@@ -392,9 +422,9 @@ let map_local_binders f g e bl =
(map_binder g e nal, CLocalAssum(nal,k,f e ty)::bl)
| CLocalDef( { loc ; v = na } as cna ,c,ty) ->
(Name.fold_right g na e, CLocalDef(cna,f e c,Option.map (f e) ty)::bl)
- | CLocalPattern { loc; v = pat,t } ->
- let ids = ids_of_pattern pat in
- (Id.Set.fold g ids e, CLocalPattern (make ?loc (pat,Option.map (f e) t))::bl) in
+ | CLocalPattern pat ->
+ let e, pat = fold_map_cases_pattern g f e pat in
+ (e, CLocalPattern pat::bl) in
let (e,rbl) = List.fold_left h (e,[]) bl in
(e, List.rev rbl)
@@ -403,16 +433,16 @@ let map_constr_expr_with_binders g f e = CAst.map (function
| CApp ((p,a),l) ->
CApp ((p,f e a),List.map (fun (a,i) -> (f e a,i)) l)
| CProdN (bl,b) ->
- let (e,bl) = map_local_binders f g e bl in CProdN (bl,f e b)
+ let (e,bl) = fold_map_local_binders f g e bl in CProdN (bl,f e b)
| CLambdaN (bl,b) ->
- let (e,bl) = map_local_binders f g e bl in CLambdaN (bl,f e b)
+ let (e,bl) = fold_map_local_binders f g e bl in CLambdaN (bl,f e b)
| CLetIn (na,a,t,b) ->
CLetIn (na,f e a,Option.map (f e) t,f (Name.fold_right g (na.CAst.v) e) b)
| CCast (a,c) -> CCast (f e a, Glob_ops.map_cast_type (f e) c)
| CNotation (inscope,n,(l,ll,bl,bll)) ->
(* This is an approximation because we don't know what binds what *)
CNotation (inscope,n,(List.map (f e) l,List.map (List.map (f e)) ll, bl,
- List.map (fun bl -> snd (map_local_binders f g e bl)) bll))
+ List.map (fun bl -> snd (fold_map_local_binders f g e bl)) bll))
| CGeneralization (b,a,c) -> CGeneralization (b,a,f e c)
| CDelimiters (s,a) -> CDelimiters (s,f e a)
| CHole _ | CEvar _ | CPatVar _ | CSort _
@@ -434,7 +464,7 @@ let map_constr_expr_with_binders g f e = CAst.map (function
CIf (f e c,(ona,Option.map (f e') po),f e b1,f e b2)
| CFix (id,dl) ->
CFix (id,List.map (fun (id,n,bl,t,d) ->
- let (e',bl') = map_local_binders f g e bl in
+ let (e',bl') = fold_map_local_binders f g e bl in
let t' = f e' t in
(* Note: fix names should be inserted before the arguments... *)
let e'' = List.fold_left (fun e ({ CAst.v = id },_,_,_,_) -> g id e) e' dl in
@@ -442,7 +472,7 @@ let map_constr_expr_with_binders g f e = CAst.map (function
(id,n,bl',t',d')) dl)
| CCoFix (id,dl) ->
CCoFix (id,List.map (fun (id,bl,t,d) ->
- let (e',bl') = map_local_binders f g e bl in
+ let (e',bl') = fold_map_local_binders f g e bl in
let t' = f e' t in
let e'' = List.fold_left (fun e ({ CAst.v = id },_,_,_) -> g id e) e' dl in
let d' = f e'' d in
diff --git a/interp/constrextern.ml b/interp/constrextern.ml
index aa3a458989..cf88036f73 100644
--- a/interp/constrextern.ml
+++ b/interp/constrextern.ml
@@ -1126,7 +1126,7 @@ and factorize_prod ?impargs scopes vars na bk t c =
let disjpat = if occur_glob_constr id b then List.map (set_pat_alias id) disjpat else disjpat in
let b = extern_typ scopes vars b in
let p = mkCPatOr (List.map (extern_cases_pattern_in_scope scopes vars) disjpat) in
- let binder = CLocalPattern (make ?loc:c.loc (p,None)) in
+ let binder = CLocalPattern p in
(match b.v with
| CProdN (bl,b) -> CProdN (binder::bl,b)
| _ -> CProdN ([binder],b))
@@ -1167,7 +1167,7 @@ and factorize_lambda inctx scopes vars na bk t c =
let disjpat = if occur_glob_constr id b then List.map (set_pat_alias id) disjpat else disjpat in
let b = sub_extern inctx scopes vars b in
let p = mkCPatOr (List.map (extern_cases_pattern_in_scope scopes vars) disjpat) in
- let binder = CLocalPattern (make ?loc:c.loc (p,None)) in
+ let binder = CLocalPattern p in
(match b.v with
| CLambdaN (bl,b) -> CLambdaN (binder::bl,b)
| _ -> CLambdaN ([binder],b))
@@ -1219,7 +1219,10 @@ and extern_local_binder scopes vars = function
if !Flags.raw_print then Some (extern_typ scopes vars ty) else None in
let p = mkCPatOr (List.map (extern_cases_pattern vars) p) in
let (assums,ids,l) = extern_local_binder scopes vars l in
- (assums,ids, CLocalPattern(CAst.make @@ (p,ty)) :: l)
+ let p = match ty with
+ | None -> p
+ | Some ty -> CAst.make @@ (CPatCast (p,ty)) in
+ (assums,ids, CLocalPattern p :: l)
and extern_eqn inctx scopes vars {CAst.loc;v=(ids,pll,c)} =
let pll = List.map (List.map (extern_cases_pattern_in_scope scopes vars)) pll in
diff --git a/interp/constrintern.ml b/interp/constrintern.ml
index cb2c5b5f4c..1a922eb9a4 100644
--- a/interp/constrintern.ml
+++ b/interp/constrintern.ml
@@ -586,7 +586,10 @@ let intern_letin_binder intern ntnvars env (({loc;v=na} as locna),def,ty) =
(push_name_env ntnvars impls env locna,
(na,Explicit,term,ty))
-let intern_cases_pattern_as_binder ?loc test_kind ntnvars env p =
+let intern_cases_pattern_as_binder intern test_kind ntnvars env bk (CAst.{v=p;loc} as pv) =
+ let p,t = match p with
+ | CPatCast (p, t) -> (p, Some t)
+ | _ -> (pv, None) in
let il,disjpat =
let (il, subst_disjpat) = !intern_cases_pattern_fwd test_kind ntnvars (env_for_pattern (reset_tmp_scope env)) p in
let substl,disjpat = List.split subst_disjpat in
@@ -594,12 +597,17 @@ let intern_cases_pattern_as_binder ?loc test_kind ntnvars env p =
user_err ?loc (str "Unsupported nested \"as\" clause.");
il,disjpat
in
- let env = List.fold_right (fun {loc;v=id} env -> push_name_env ntnvars [] env (make ?loc @@ Name id)) il env in
let na = alias_of_pat (List.hd disjpat) in
+ let env = List.fold_right (fun {loc;v=id} env -> push_name_env ntnvars [] env (make ?loc @@ Name id)) il env in
let ienv = Name.fold_right Id.Set.remove na env.ids in
let id = Namegen.next_name_away_with_default "pat" na ienv in
let na = make ?loc @@ Name id in
- env,((disjpat,il),id),na
+ let t = match t with
+ | Some t -> t
+ | None -> CAst.make ?loc @@ CHole(Some (Evar_kinds.BinderType na.v),IntroAnonymous,None) in
+ let _, bl' = intern_assumption intern ntnvars env [na] (Default bk) t in
+ let {v=(_,bk,t)} = List.hd bl' in
+ env,((disjpat,il),id),na,bk,t
let intern_local_binder_aux intern ntnvars (env,bl) = function
| CLocalAssum(nal,bk,ty) ->
@@ -609,17 +617,9 @@ let intern_local_binder_aux intern ntnvars (env,bl) = function
| CLocalDef( {loc; v=na} as locna,def,ty) ->
let env,(na,bk,def,ty) = intern_letin_binder intern ntnvars env (locna,def,ty) in
env, (DAst.make ?loc @@ GLocalDef (na,bk,def,ty)) :: bl
- | CLocalPattern {loc;v=(p,ty)} ->
- let tyc =
- match ty with
- | Some ty -> ty
- | None -> CAst.make ?loc @@ CHole(None,IntroAnonymous,None)
- in
- let env, ((disjpat,il),id),na = intern_cases_pattern_as_binder ?loc test_kind_tolerant ntnvars env p in
- let bk = Default Explicit in
- let _, bl' = intern_assumption intern ntnvars env [na] bk tyc in
- let {v=(_,bk,t)} = List.hd bl' in
- (env, (DAst.make ?loc @@ GLocalPattern((disjpat,List.map (fun x -> x.v) il),id,bk,t)) :: bl)
+ | CLocalPattern p ->
+ let env, ((disjpat,il),id),na,bk,t = intern_cases_pattern_as_binder intern test_kind_tolerant ntnvars env Explicit p in
+ (env, (DAst.make ?loc:p.CAst.loc @@ GLocalPattern((disjpat,List.map (fun x -> x.v) il),id,bk,t)) :: bl)
let intern_generalization intern env ntnvars loc bk ak c =
let c = intern {env with unb = true} c in
@@ -705,7 +705,7 @@ let is_patvar c =
let is_patvar_store store pat =
match DAst.get pat with
- | PatVar na -> ignore(store na); true
+ | PatVar na -> ignore(store (CAst.make ?loc:pat.loc na)); true
| _ -> false
let out_patvar = CAst.map_with_loc (fun ?loc -> function
@@ -714,19 +714,38 @@ let out_patvar = CAst.map_with_loc (fun ?loc -> function
| CPatAtom None -> Anonymous
| _ -> assert false)
-let traverse_binder intern_pat ntnvars (terms,_,binders,_ as subst) avoid (renaming,env) = function
- | Anonymous -> (renaming,env), None, Anonymous, Explicit
+let canonize_type = function
+ | None -> None
+ | Some t as t' ->
+ match DAst.get t with
+ | GHole (Evar_kinds.BinderType _,IntroAnonymous,None) -> None
+ | _ -> t'
+
+let set_type ty1 ty2 =
+ match canonize_type ty1, canonize_type ty2 with
+ (* Not a meta-binding binder, we use the type given in the notation *)
+ | _, None -> ty1
+ (* A meta-binding binder meta-bound to a possibly-typed pattern *)
+ (* the binder is supposed to come w/o an explicit type in the notation *)
+ | None, Some _ -> ty2
+ | Some ty1, Some t2 ->
+ (* An explicitly typed meta-binding binder, not supposed to be a pattern; checked in interp_notation *)
+ user_err ?loc:t2.CAst.loc Pp.(str "Unexpected type constraint in notation already providing a type constraint.")
+
+let traverse_binder intern_pat ntnvars (terms,_,binders,_ as subst) avoid (renaming,env) na ty =
+ match na with
+ | Anonymous -> (renaming,env), None, Anonymous, Explicit, set_type ty None
| Name id ->
let store,get = set_temporary_memory () in
let test_kind = test_kind_tolerant in
try
(* We instantiate binder name with patterns which may be parsed as terms *)
let pat = coerce_to_cases_pattern_expr (fst (Id.Map.find id terms)) in
- let env,((disjpat,ids),id),na = intern_pat test_kind ntnvars env pat in
+ let env,((disjpat,ids),id),na,bk,t = intern_pat test_kind ntnvars env Explicit pat in
let pat, na = match disjpat with
| [pat] when is_patvar_store store pat -> let na = get () in None, na
- | _ -> Some ((List.map (fun x -> x.v) ids,disjpat),id), na.v in
- (renaming,env), pat, na, Explicit
+ | _ -> Some ((List.map (fun x -> x.v) ids,disjpat),id), na in
+ (renaming,env), pat, na.v, bk, set_type ty (Some t)
with Not_found ->
try
(* Trying to associate a pattern *)
@@ -736,15 +755,16 @@ let traverse_binder intern_pat ntnvars (terms,_,binders,_ as subst) avoid (renam
(* Do not try to interpret a variable as a constructor *)
let na = out_patvar pat in
let env = push_name_env ntnvars [] env na in
- (renaming,env), None, na.v, bk
+ let ty' = DAst.make @@ GHole (Evar_kinds.BinderType na.CAst.v,IntroAnonymous,None) in
+ (renaming,env), None, na.v, bk, set_type ty (Some ty')
else
(* Interpret as a pattern *)
- let env,((disjpat,ids),id),na = intern_pat test_kind ntnvars env pat in
+ let env,((disjpat,ids),id),na,bk,t = intern_pat test_kind ntnvars env bk pat in
let pat, na =
match disjpat with
| [pat] when is_patvar_store store pat -> let na = get () in None, na
- | _ -> Some ((List.map (fun x -> x.v) ids,disjpat),id), na.v in
- (renaming,env), pat, na, bk
+ | _ -> Some ((List.map (fun x -> x.v) ids,disjpat),id), na in
+ (renaming,env), pat, na.v, bk, set_type ty (Some t)
with Not_found ->
(* Binders not bound in the notation do not capture variables *)
(* outside the notation (i.e. in the substitution) *)
@@ -752,7 +772,7 @@ let traverse_binder intern_pat ntnvars (terms,_,binders,_ as subst) avoid (renam
let renaming' =
if Id.equal id id' then renaming else Id.Map.add id id' renaming
in
- (renaming',env), None, Name id', Explicit
+ (renaming',env), None, Name id', Explicit, set_type ty None
type binder_action =
| AddLetIn of lname * constr_expr * constr_expr option
@@ -877,12 +897,13 @@ let instantiate_notation_constr loc intern intern_pat ntnvars subst infos c =
Id.Map.add id (gc, None) map
with Nametab.GlobalizationError _ -> map
in
- let mk_env' ((c,_bk), (onlyident,scopes)) =
- let nenv = set_env_scopes env scopes in
+ let mk_env' ((c,_bk), (onlyident,(tmp_scope,subscopes))) =
+ let nenv = {env with tmp_scope; scopes = subscopes @ env.scopes} in
let test_kind =
if onlyident then test_kind_ident_in_notation
else test_kind_pattern_in_notation in
- let _,((disjpat,_),_),_ = intern_pat test_kind ntnvars nenv c in
+ let _,((disjpat,_),_),_,_,_ty = intern_pat test_kind ntnvars nenv Explicit c in
+ (* TODO: use cast? *)
match disjpat with
| [pat] -> (glob_constr_of_cases_pattern (Global.env()) pat, None)
| _ -> error_cannot_coerce_disjunctive_pattern_term ?loc:c.loc ()
@@ -913,17 +934,6 @@ let instantiate_notation_constr loc intern intern_pat ntnvars subst infos c =
| NLambda (Name id,NHole _,c') when option_mem_assoc id binderopt ->
let binder = snd (Option.get binderopt) in
expand_binders ?loc mkGLambda [binder] (aux subst' (renaming,env) c')
- (* Two special cases to keep binder name synchronous with BinderType *)
- | NProd (na,NHole(Evar_kinds.BinderType na',naming,arg),c')
- when Name.equal na na' ->
- let subinfos,disjpat,na,bk = traverse_binder intern_pat ntnvars subst avoid subinfos na in
- let ty = DAst.make ?loc @@ GHole (Evar_kinds.BinderType na,naming,arg) in
- DAst.make ?loc @@ GProd (na,bk,ty,Option.fold_right apply_cases_pattern disjpat (aux subst' subinfos c'))
- | NLambda (na,NHole(Evar_kinds.BinderType na',naming,arg),c')
- when Name.equal na na' ->
- let subinfos,disjpat,na,bk = traverse_binder intern_pat ntnvars subst avoid subinfos na in
- let ty = DAst.make ?loc @@ GHole (Evar_kinds.BinderType na,naming,arg) in
- DAst.make ?loc @@ GLambda (na,bk,ty,Option.fold_right apply_cases_pattern disjpat (aux subst' subinfos c'))
| t ->
glob_constr_of_notation_constr_with_binders ?loc
(traverse_binder intern_pat ntnvars subst avoid) (aux subst') ~h:binder_status_fun subinfos t
@@ -935,12 +945,13 @@ let instantiate_notation_constr loc intern intern_pat ntnvars subst infos c =
intern (set_env_scopes env scopes) a
with Not_found ->
try
- let (pat,_bk),(onlyident,scopes) = Id.Map.find id binders in
- let nenv = set_env_scopes env scopes in
+ let (pat,bk),(onlyident,scopes) = Id.Map.find id binders in
+ let env = set_env_scopes env scopes in
let test_kind =
if onlyident then test_kind_ident_in_notation
else test_kind_pattern_in_notation in
- let env,((disjpat,ids),id),na = intern_pat test_kind ntnvars nenv pat in
+ let env,((disjpat,ids),id),na,bk,_ty = intern_pat test_kind ntnvars env bk pat in
+ (* TODO: use cast? *)
match disjpat with
| [pat] -> glob_constr_of_cases_pattern (Global.env()) pat
| _ -> user_err Pp.(str "Cannot turn a disjunctive pattern into a term.")
@@ -1030,7 +1041,7 @@ let intern_notation intern env ntnvars loc ntn fullargs =
(* Dispatch parsing substitution to an interpretation substitution *)
let subst = split_by_type ids fullargs in
(* Instantiate the notation *)
- instantiate_notation_constr loc intern intern_cases_pattern_as_binder ntnvars subst (Id.Map.empty, env) c
+ instantiate_notation_constr loc intern (intern_cases_pattern_as_binder intern) ntnvars subst (Id.Map.empty, env) c
(**********************************************************************)
(* Discriminating between bound variables and global references *)
@@ -1158,7 +1169,7 @@ let intern_qualid ?(no_secvar=false) qid intern env ntnvars us args =
check_no_explicitation args1;
let subst = split_by_type ids (List.map fst args1,[],[],[]) in
let infos = (Id.Map.empty, env) in
- let c = instantiate_notation_constr loc intern intern_cases_pattern_as_binder ntnvars subst infos c in
+ let c = instantiate_notation_constr loc intern (intern_cases_pattern_as_binder intern) ntnvars subst infos c in
let loc = c.loc in
let err () =
user_err ?loc (str "Notation " ++ pr_qualid qid
diff --git a/interp/notation_ops.ml b/interp/notation_ops.ml
index 61f93aa969..338a77de3d 100644
--- a/interp/notation_ops.ml
+++ b/interp/notation_ops.ml
@@ -276,12 +276,21 @@ let test_implicit_argument_mark bk =
if not (Glob_ops.binding_kind_eq bk Explicit) then
user_err (Pp.str "Unexpected implicit argument mark.")
+let test_pattern_cast = function
+ | None -> ()
+ | Some t -> user_err ?loc:t.CAst.loc (Pp.str "Unsupported pattern cast.")
+
let protect g e na =
- let e',disjpat,na,bk = g e na in
+ let e',disjpat,na,bk,t = g e na None in
if disjpat <> None then user_err (Pp.str "Unsupported substitution of an arbitrary pattern.");
test_implicit_argument_mark bk;
+ test_pattern_cast t;
e',na
+let set_anonymous_type na = function
+ | None -> DAst.make @@ GHole (Evar_kinds.BinderType na, IntroAnonymous, None)
+ | Some t -> t
+
let apply_cases_pattern_term ?loc (ids,disjpat) tm c =
let eqns = List.map (fun pat -> (CAst.make ?loc (ids,[pat],c))) disjpat in
DAst.make ?loc @@ GCases (Constr.LetPatternStyle, None, [tm,(Anonymous,None)], eqns)
@@ -307,16 +316,21 @@ let glob_constr_of_notation_constr_with_binders ?loc g f ?(h=default_binder_stat
DAst.get (subst_glob_vars outerl it)
| NLambda (na,ty,c) ->
let e = h.switch_lambda e in
- let e',disjpat,na,bk = g e na in GLambda (na,bk,f (h.restart_prod e) ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c))
+ let ty = Some (f (h.restart_prod e) ty) in
+ let e',disjpat,na',bk,ty = g e na ty in
+ GLambda (na',bk,set_anonymous_type na ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c))
| NProd (na,ty,c) ->
let e = h.switch_prod e in
- let e',disjpat,na,bk = g e na in GProd (na,bk,f (h.restart_prod e) ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c))
+ let ty = f (h.restart_prod e) ty in
+ let e',disjpat,na',bk,ty = g e na (Some ty) in
+ GProd (na',bk,set_anonymous_type na ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c))
| NLetIn (na,b,t,c) ->
- let e',disjpat,na,bk = g e na in
+ let t = Option.map (f (h.restart_prod e)) t in
+ let e',disjpat,na,bk,t = g e na t in
test_implicit_argument_mark bk;
(match disjpat with
- | None -> GLetIn (na,f (h.restart_lambda e) b,Option.map (f (h.restart_prod e)) t,f e' c)
- | Some (disjpat,_id) -> DAst.get (apply_cases_pattern_term ?loc disjpat (f e b) (f e' c)))
+ | None -> GLetIn (na,f (h.restart_lambda e) b,t,f e' c)
+ | Some (disjpat,_id) -> test_pattern_cast t; DAst.get (apply_cases_pattern_term ?loc disjpat (f e b) (f e' c)))
| NCases (sty,rtntypopt,tml,eqnl) ->
let e = h.no e in
let e',tml' = List.fold_right (fun (tm,(na,t)) (e',tml') ->
@@ -330,8 +344,9 @@ let glob_constr_of_notation_constr_with_binders ?loc g f ?(h=default_binder_stat
let e',na' = protect g e' na in
(e',(f e tm,(na',t'))::tml')) tml (e,[]) in
let fold (idl,e) na =
- let (e,disjpat,na,bk) = g e na in
+ let (e,disjpat,na,bk,t) = g e na None in
test_implicit_argument_mark bk;
+ test_pattern_cast t;
((Name.cons na idl,e),disjpat,na) in
let eqnl' = List.map (fun (patl,rhs) ->
let ((idl,e),patl) =
@@ -365,7 +380,7 @@ let glob_constr_of_notation_constr_with_binders ?loc g f ?(h=default_binder_stat
let glob_constr_of_notation_constr ?loc x =
let rec aux () x =
- glob_constr_of_notation_constr_with_binders ?loc (fun () id -> ((),None,id,Explicit)) aux () x
+ glob_constr_of_notation_constr_with_binders ?loc (fun () id t -> ((),None,id,Explicit,t)) aux () x
in aux () x
(******************************************************************************)
diff --git a/interp/notation_ops.mli b/interp/notation_ops.mli
index 3e8fdd8254..e7a0429b35 100644
--- a/interp/notation_ops.mli
+++ b/interp/notation_ops.mli
@@ -53,7 +53,7 @@ val apply_cases_pattern : ?loc:Loc.t ->
(Id.t list * cases_pattern_disjunction) * Id.t -> glob_constr -> glob_constr
val glob_constr_of_notation_constr_with_binders : ?loc:Loc.t ->
- ('a -> Name.t -> 'a * ((Id.t list * cases_pattern_disjunction) * Id.t) option * Name.t * Glob_term.binding_kind) ->
+ ('a -> Name.t -> glob_constr option -> 'a * ((Id.t list * cases_pattern_disjunction) * Id.t) option * Name.t * Glob_term.binding_kind * glob_constr option) ->
('a -> notation_constr -> glob_constr) -> ?h:'a binder_status_fun ->
'a -> notation_constr -> glob_constr