aboutsummaryrefslogtreecommitdiff
path: root/interp/constrextern.ml
diff options
context:
space:
mode:
Diffstat (limited to 'interp/constrextern.ml')
-rw-r--r--interp/constrextern.ml115
1 files changed, 79 insertions, 36 deletions
diff --git a/interp/constrextern.ml b/interp/constrextern.ml
index 4f7d537d3f..9e18966b63 100644
--- a/interp/constrextern.ml
+++ b/interp/constrextern.ml
@@ -257,7 +257,7 @@ let insert_pat_delimiters ?loc p = function
let insert_pat_alias ?loc p = function
| Anonymous -> p
- | Name id -> CAst.make ?loc @@ CPatAlias (p,id)
+ | Name _ as na -> CAst.make ?loc @@ CPatAlias (p,(loc,na))
(**********************************************************************)
(* conversion of references *)
@@ -323,34 +323,35 @@ let is_zero s =
Int.equal (String.length s) i || (s.[i] == '0' && aux (i+1))
in aux 0
-let make_notation_gen loc ntn mknot mkprim destprim l =
+let make_notation_gen loc ntn mknot mkprim destprim l bl =
match ntn,List.map destprim l with
(* Special case to avoid writing "- 3" for e.g. (Z.opp 3) *)
| "- _", [Some (Numeral (p,true))] when not (is_zero p) ->
- mknot (loc,ntn,([mknot (loc,"( _ )",l)]))
+ assert (bl=[]);
+ mknot (loc,ntn,([mknot (loc,"( _ )",l,[])]),[])
| _ ->
match decompose_notation_key ntn, l with
| [Terminal "-"; Terminal x], [] when is_number x ->
mkprim (loc, Numeral (x,false))
| [Terminal x], [] when is_number x ->
mkprim (loc, Numeral (x,true))
- | _ -> mknot (loc,ntn,l)
+ | _ -> mknot (loc,ntn,l,bl)
-let make_notation loc ntn (terms,termlists,binders as subst) =
- if not (List.is_empty termlists) || not (List.is_empty binders) then
+let make_notation loc ntn (terms,termlists,binders,binderlists as subst) =
+ if not (List.is_empty termlists) || not (List.is_empty binderlists) then
CAst.make ?loc @@ CNotation (ntn,subst)
else
make_notation_gen loc ntn
- (fun (loc,ntn,l) -> CAst.make ?loc @@ CNotation (ntn,(l,[],[])))
+ (fun (loc,ntn,l,bl) -> CAst.make ?loc @@ CNotation (ntn,(l,[],bl,[])))
(fun (loc,p) -> CAst.make ?loc @@ CPrim p)
- destPrim terms
+ destPrim terms binders
let make_pat_notation ?loc ntn (terms,termlists as subst) args =
if not (List.is_empty termlists) then (CAst.make ?loc @@ CPatNotation (ntn,subst,args)) else
make_notation_gen loc ntn
- (fun (loc,ntn,l) -> CAst.make ?loc @@ CPatNotation (ntn,(l,[]),args))
+ (fun (loc,ntn,l,_) -> CAst.make ?loc @@ CPatNotation (ntn,(l,[]),args))
(fun (loc,p) -> CAst.make ?loc @@ CPatPrim p)
- destPatPrim terms
+ destPatPrim terms []
let mkPat ?loc qid l = CAst.make ?loc @@
(* Normally irrelevant test with v8 syntax, but let's do it anyway *)
@@ -533,6 +534,10 @@ let occur_name na aty =
| Name id -> occur_var_constr_expr id aty
| Anonymous -> false
+let is_gvar id c = match DAst.get c with
+| GVar id' -> Id.equal id id'
+| _ -> false
+
let is_projection nargs = function
| Some r when not !Flags.in_debugger && not !Flags.raw_print && !print_projections ->
(try
@@ -817,13 +822,11 @@ let rec extern inctx scopes vars r =
| GProd (na,bk,t,c) ->
let t = extern_typ scopes vars t in
- let (idl,c) = factorize_prod scopes (add_vname vars na) na bk t c in
- CProdN ([(Loc.tag na)::idl,Default bk,t],c)
+ factorize_prod scopes (add_vname vars na) na bk t c
| GLambda (na,bk,t,c) ->
let t = extern_typ scopes vars t in
- let (idl,c) = factorize_lambda inctx scopes (add_vname vars na) na bk t c in
- CLambdaN ([(Loc.tag na)::idl,Default bk,t],c)
+ factorize_lambda inctx scopes (add_vname vars na) na bk t c
| GCases (sty,rtntypopt,tml,eqns) ->
let vars' =
@@ -919,24 +922,60 @@ and extern_typ (_,scopes) =
and sub_extern inctx (_,scopes) = extern inctx (None,scopes)
and factorize_prod scopes vars na bk aty c =
- let c = extern_typ scopes vars c in
- match na, c with
- | Name id, { CAst.loc ; v = CProdN ([nal,Default bk',ty],c) }
- when binding_kind_eq bk bk' && constr_expr_eq aty ty
- && not (occur_var_constr_expr id ty) (* avoid na in ty escapes scope *) ->
- nal,c
- | _ ->
- [],c
+ let store, get = set_temporary_memory () in
+ match na, DAst.get c with
+ | Name id, GCases (LetPatternStyle, None, [(e,(Anonymous,None))],(_::_ as eqns))
+ when is_gvar id e && List.length (store (factorize_eqns eqns)) = 1 ->
+ (match get () with
+ | [(_,(ids,disj_of_patl,b))] ->
+ let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in
+ let disjpat = if occur_glob_constr id b then List.map (set_pat_alias id) disjpat else disjpat in
+ let b = extern_typ scopes vars b in
+ let p = mkCPatOr (List.map (extern_cases_pattern_in_scope scopes vars) disjpat) in
+ let binder = CLocalPattern (c.loc,(p,None)) in
+ (match b.v with
+ | CProdN (bl,b) -> CProdN (binder::bl,b)
+ | _ -> CProdN ([binder],b))
+ | _ -> assert false)
+ | _, _ ->
+ let c = extern_typ scopes vars c in
+ match na, c.v with
+ | Name id, CProdN (CLocalAssum(nal,Default bk',ty)::bl,b)
+ when binding_kind_eq bk bk' && constr_expr_eq aty ty
+ && not (occur_var_constr_expr id ty) (* avoid na in ty escapes scope *) ->
+ CProdN (CLocalAssum(Loc.tag na::nal,Default bk,aty)::bl,b)
+ | _, CProdN (bl,b) ->
+ CProdN (CLocalAssum([Loc.tag na],Default bk,aty)::bl,b)
+ | _, _ ->
+ CProdN ([CLocalAssum([Loc.tag na],Default bk,aty)],c)
and factorize_lambda inctx scopes vars na bk aty c =
- let c = sub_extern inctx scopes vars c in
- match c with
- | { CAst.loc; v = CLambdaN ([nal,Default bk',ty],c) }
- when binding_kind_eq bk bk' && constr_expr_eq aty ty
- && not (occur_name na ty) (* avoid na in ty escapes scope *) ->
- nal,c
- | _ ->
- [],c
+ let store, get = set_temporary_memory () in
+ match na, DAst.get c with
+ | Name id, GCases (LetPatternStyle, None, [(e,(Anonymous,None))],(_::_ as eqns))
+ when is_gvar id e && List.length (store (factorize_eqns eqns)) = 1 ->
+ (match get () with
+ | [(_,(ids,disj_of_patl,b))] ->
+ let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in
+ let disjpat = if occur_glob_constr id b then List.map (set_pat_alias id) disjpat else disjpat in
+ let b = sub_extern inctx scopes vars b in
+ let p = mkCPatOr (List.map (extern_cases_pattern_in_scope scopes vars) disjpat) in
+ let binder = CLocalPattern (c.loc,(p,None)) in
+ (match b.v with
+ | CLambdaN (bl,b) -> CLambdaN (binder::bl,b)
+ | _ -> CLambdaN ([binder],b))
+ | _ -> assert false)
+ | _, _ ->
+ let c = sub_extern inctx scopes vars c in
+ match c.v with
+ | CLambdaN (CLocalAssum(nal,Default bk',ty)::bl,b)
+ when binding_kind_eq bk bk' && constr_expr_eq aty ty
+ && not (occur_name na ty) (* avoid na in ty escapes scope *) ->
+ CLambdaN (CLocalAssum(Loc.tag na::nal,Default bk,aty)::bl,b)
+ | CLambdaN (bl,b) ->
+ CLambdaN (CLocalAssum([Loc.tag na],Default bk,aty)::bl,b)
+ | _ ->
+ CLambdaN ([CLocalAssum([Loc.tag na],Default bk,aty)],c)
and extern_local_binder scopes vars = function
[] -> ([],[],[])
@@ -965,7 +1004,7 @@ and extern_local_binder scopes vars = function
| GLocalPattern ((p,_),_,bk,ty) ->
let ty =
if !Flags.raw_print then Some (extern_typ scopes vars ty) else None in
- let p = extern_cases_pattern vars p in
+ let p = mkCPatOr (List.map (extern_cases_pattern vars) p) in
let (assums,ids,l) = extern_local_binder scopes vars l in
(assums,ids, CLocalPattern(Loc.tag @@ (p,ty)) :: l)
@@ -1014,7 +1053,7 @@ and extern_notation (tmp_scope,scopes as allscopes) vars t = function
| _, None -> t, [], [], []
| _ -> raise No_match in
(* Try matching ... *)
- let terms,termlists,binders =
+ let terms,termlists,binders,binderlists =
match_notation_constr !print_universes t pat in
(* Try availability of interpretation ... *)
let e =
@@ -1035,11 +1074,15 @@ and extern_notation (tmp_scope,scopes as allscopes) vars t = function
List.map (fun (c,(scopt,scl)) ->
List.map (extern true (scopt,scl@scopes') vars) c)
termlists in
- let bll =
- List.map (fun (bl,(scopt,scl)) ->
- pi3 (extern_local_binder (scopt,scl@scopes') vars bl))
+ let bl =
+ List.map (fun (bl,(scopt,scl)) ->
+ mkCPatOr (List.map (extern_cases_pattern_in_scope (scopt,scl@scopes') vars) bl))
binders in
- insert_delimiters (make_notation loc ntn (l,ll,bll)) key)
+ let bll =
+ List.map (fun (bl,(scopt,scl)) ->
+ pi3 (extern_local_binder (scopt,scl@scopes') vars bl))
+ binderlists in
+ insert_delimiters (make_notation loc ntn (l,ll,bl,bll)) key)
| SynDefRule kn ->
let l =
List.map (fun (c,(scopt,scl)) ->