aboutsummaryrefslogtreecommitdiff
path: root/vernac/comFixpoint.ml
diff options
context:
space:
mode:
Diffstat (limited to 'vernac/comFixpoint.ml')
-rw-r--r--vernac/comFixpoint.ml25
1 files changed, 18 insertions, 7 deletions
diff --git a/vernac/comFixpoint.ml b/vernac/comFixpoint.ml
index 2aadbd224f..1912646ffd 100644
--- a/vernac/comFixpoint.ml
+++ b/vernac/comFixpoint.ml
@@ -329,16 +329,27 @@ let declare_cofixpoint ~ontop local poly ((fixnames,fixrs,fixdefs,fixtypes),pl,c
List.iter (Metasyntax.add_notation_interpretation (Global.env())) ntns;
pstate
-let extract_decreasing_argument limit = function
- | (na,CStructRec) -> na
- | (na,_) when not limit -> na
+let extract_decreasing_argument ~structonly = function { CAst.v = v } -> match v with
+ | CStructRec na -> na
+ | (CWfRec (na,_) | CMeasureRec (Some na,_,_)) when not structonly -> na
+ | CMeasureRec (None,_,_) when not structonly ->
+ user_err Pp.(str "Decreasing argument must be specificed in measure clause.")
| _ -> user_err Pp.(str
- "Only structural decreasing is supported for a non-Program Fixpoint")
+ "Well-founded induction requires Program Fixpoint or Function.")
-let extract_fixpoint_components limit l =
+let extract_fixpoint_components ~structonly l =
let fixl, ntnl = List.split l in
let fixl = List.map (fun (({CAst.v=id},pl),ann,bl,typ,def) ->
- let ann = extract_decreasing_argument limit ann in
+ (* This is a special case: if there's only one binder, we pick it as the
+ recursive argument if none is provided. *)
+ let ann = Option.map (fun ann -> match bl, ann with
+ | [CLocalAssum([{ CAst.v = Name x }],_,_)], { CAst.v = CMeasureRec(None, mes, rel); CAst.loc } ->
+ CAst.make ?loc @@ CMeasureRec(Some (CAst.make x), mes, rel)
+ | [CLocalDef({ CAst.v = Name x },_,_)], { CAst.v = CMeasureRec(None, mes, rel); CAst.loc } ->
+ CAst.make ?loc @@ CMeasureRec(Some (CAst.make x), mes, rel)
+ | _, x -> x) ann
+ in
+ let ann = Option.map (extract_decreasing_argument ~structonly) ann in
{fix_name = id; fix_annot = ann; fix_univs = pl;
fix_binders = bl; fix_body = def; fix_type = typ}) fixl in
fixl, List.flatten ntnl
@@ -356,7 +367,7 @@ let check_safe () =
flags.check_universes && flags.check_guarded
let do_fixpoint ~ontop local poly l =
- let fixl, ntns = extract_fixpoint_components true l in
+ let fixl, ntns = extract_fixpoint_components ~structonly:true l in
let (_, _, _, info as fix) = interp_fixpoint ~cofix:false fixl ntns in
let possible_indexes =
List.map compute_possible_guardness_evidences info in