summaryrefslogtreecommitdiff
path: root/src/pretty_print_coq.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/pretty_print_coq.ml')
-rw-r--r--src/pretty_print_coq.ml361
1 files changed, 278 insertions, 83 deletions
diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml
index ffe376e0..74e97a29 100644
--- a/src/pretty_print_coq.ml
+++ b/src/pretty_print_coq.ml
@@ -96,12 +96,14 @@ type context = {
kid_renames : kid KBindings.t; (* Plain tyvar -> tyvar renames *)
kid_id_renames : id KBindings.t; (* tyvar -> argument renames *)
bound_nexps : NexpSet.t;
+ build_ex_return : bool;
}
let empty_ctxt = {
early_ret = false;
kid_renames = KBindings.empty;
kid_id_renames = KBindings.empty;
- bound_nexps = NexpSet.empty
+ bound_nexps = NexpSet.empty;
+ build_ex_return = false;
}
let langlebar = string "<|"
@@ -135,7 +137,10 @@ let rec fix_id remove_tick name = match name with
| "GT"
| "EQ"
| "Z"
+ | "O"
+ | "S"
| "mod"
+ | "M"
-> name ^ "'"
| _ ->
if String.contains name '#' then
@@ -146,15 +151,17 @@ let rec fix_id remove_tick name = match name with
fix_id remove_tick (String.concat "__" (Util.split_on_char '^' name))
else if name.[0] = '\'' then
let var = String.sub name 1 (String.length name - 1) in
- if remove_tick then var else (var ^ "'")
+ if remove_tick then fix_id remove_tick var else (var ^ "'")
else if is_number_char(name.[0]) then
("v" ^ name ^ "'")
else name
-let doc_id (Id_aux(i,_)) =
+let string_id (Id_aux(i,_)) =
match i with
- | Id i -> string (fix_id false i)
- | DeIid x -> string (Util.zencode_string ("op " ^ x))
+ | Id i -> fix_id false i
+ | DeIid x -> Util.zencode_string ("op " ^ x)
+
+let doc_id id = string (string_id id)
let doc_id_type (Id_aux(i,_)) =
match i with
@@ -318,7 +325,7 @@ let drop_duplicate_atoms kids ty =
in aux_typ ty
(* TODO: parens *)
-let rec doc_nc ctx (NC_aux (nc,_)) =
+let rec doc_nc_prop ctx (NC_aux (nc,_)) =
match nc with
| NC_equal (ne1, ne2) -> doc_op equals (doc_nexp ctx ne1) (doc_nexp ctx ne2)
| NC_bounded_ge (ne1, ne2) -> doc_op (string ">=") (doc_nexp ctx ne1) (doc_nexp ctx ne2)
@@ -328,11 +335,27 @@ let rec doc_nc ctx (NC_aux (nc,_)) =
separate space [string "In"; doc_var_lem ctx kid;
brackets (separate (string "; ")
(List.map (fun i -> string (Nat_big_num.to_string i)) is))]
- | NC_or (nc1, nc2) -> doc_op (string "\\/") (doc_nc ctx nc1) (doc_nc ctx nc2)
- | NC_and (nc1, nc2) -> doc_op (string "/\\") (doc_nc ctx nc1) (doc_nc ctx nc2)
+ | NC_or (nc1, nc2) -> doc_op (string "\\/") (doc_nc_prop ctx nc1) (doc_nc_prop ctx nc2)
+ | NC_and (nc1, nc2) -> doc_op (string "/\\") (doc_nc_prop ctx nc1) (doc_nc_prop ctx nc2)
| NC_true -> string "True"
| NC_false -> string "False"
+(* TODO: parens *)
+let rec doc_nc_exp ctx (NC_aux (nc,_)) =
+ match nc with
+ | NC_equal (ne1, ne2) -> doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2)
+ | NC_bounded_ge (ne1, ne2) -> doc_op (string ">=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2)
+ | NC_bounded_le (ne1, ne2) -> doc_op (string "<=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2)
+ | NC_not_equal (ne1, ne2) -> string "negb" ^^ space ^^ parens (doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2))
+ | NC_set (kid, is) -> (* TODO: is this a good translation? *)
+ separate space [string "member_Z_list"; doc_var_lem ctx kid;
+ brackets (separate (string "; ")
+ (List.map (fun i -> string (Nat_big_num.to_string i)) is))]
+ | NC_or (nc1, nc2) -> doc_op (string "||") (doc_nc_exp ctx nc1) (doc_nc_exp ctx nc2)
+ | NC_and (nc1, nc2) -> doc_op (string "&&") (doc_nc_exp ctx nc1) (doc_nc_exp ctx nc2)
+ | NC_true -> string "true"
+ | NC_false -> string "false"
+
let maybe_expand_range_type (Typ_aux (typ,l) as full_typ) =
match typ with
| Typ_app(Id_aux (Id "range", _), [Typ_arg_aux(Typ_arg_nexp low,_);
@@ -347,7 +370,7 @@ let maybe_expand_range_type (Typ_aux (typ,l) as full_typ) =
let expand_range_type typ = Util.option_default typ (maybe_expand_range_type typ)
let doc_arithfact ctxt nc =
- string "ArithFact" ^^ space ^^ parens (doc_nc ctxt nc)
+ string "ArithFact" ^^ space ^^ parens (doc_nc_prop ctxt nc)
(* When making changes here, check whether they affect lem_tyvars_of_typ *)
let doc_typ, doc_atomic_typ =
@@ -381,7 +404,7 @@ let doc_typ, doc_atomic_typ =
let tpp = match elem_typ with
| Typ_aux (Typ_id (Id_aux (Id "bit",_)),_) ->
string "mword " ^^ doc_nexp ctx (nexp_simp m)
- | _ -> string "list" ^^ space ^^ typ elem_typ in
+ | _ -> string "vec" ^^ space ^^ typ elem_typ ^^ space ^^ doc_nexp ctx (nexp_simp m) in
if atyp_needed then parens tpp else tpp
| Typ_app(Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ etyp, _)]) ->
let tpp = string "register_ref regstate register_value " ^^ typ etyp in
@@ -424,7 +447,7 @@ let doc_typ, doc_atomic_typ =
List.fold_left add_tyvar tpp kids
| None ->
match nc with
- | NC_aux (NC_true,_) -> List.fold_left add_tyvar (string "Z") (List.tl kids)
+(* | NC_aux (NC_true,_) -> List.fold_left add_tyvar (string "Z") (List.tl kids)*)
| _ -> List.fold_left add_tyvar (doc_arithfact ctx nc) kids
end
and doc_typ_arg (Typ_arg_aux(t,_)) = match t with
@@ -537,9 +560,9 @@ let doc_typquant_items ctx delimit (TypQ_aux (tq,_)) =
let doc_typquant_items_separate ctx delimit (TypQ_aux (tq,_)) =
match tq with
| TypQ_tq qis ->
- separate_opt space (doc_quant_item_id ctx delimit) qis,
- separate_opt space (doc_quant_item_constr ctx delimit) qis
- | TypQ_no_forall -> empty, empty
+ Util.map_filter (doc_quant_item_id ctx delimit) qis,
+ Util.map_filter (doc_quant_item_constr ctx delimit) qis
+ | TypQ_no_forall -> [], []
let doc_typquant ctx (TypQ_aux(tq,_)) typ = match tq with
| TypQ_tq ((_ :: _) as qs) ->
@@ -687,14 +710,34 @@ let typ_id_of (Typ_aux (typ, l)) = match typ with
| Typ_app (id, _) -> id
| _ -> raise (Reporting_basic.err_unreachable l "failed to get type id")
+(* TODO: maybe Nexp_exp, division? *)
+(* Evaluation of constant nexp subexpressions, because Coq will be able to do those itself *)
+let rec nexp_const_eval (Nexp_aux (n,l) as nexp) =
+ let binop f re l n1 n2 =
+ match nexp_const_eval n1, nexp_const_eval n2 with
+ | Nexp_aux (Nexp_constant c1,_), Nexp_aux (Nexp_constant c2,_) ->
+ Nexp_aux (Nexp_constant (f c1 c2),l)
+ | n1', n2' -> Nexp_aux (re n1' n2',l)
+ in
+ let unop f re l n1 =
+ match nexp_const_eval n1 with
+ | Nexp_aux (Nexp_constant c1,_) -> Nexp_aux (Nexp_constant (f c1),l)
+ | n1' -> Nexp_aux (re n1',l)
+ in
+ match n with
+ | Nexp_times (n1,n2) -> binop Big_int.mul (fun n1 n2 -> Nexp_times (n1,n2)) l n1 n2
+ | Nexp_sum (n1,n2) -> binop Big_int.add (fun n1 n2 -> Nexp_sum (n1,n2)) l n1 n2
+ | Nexp_minus (n1,n2) -> binop Big_int.sub (fun n1 n2 -> Nexp_minus (n1,n2)) l n1 n2
+ | Nexp_neg n1 -> unop Big_int.negate (fun n -> Nexp_neg n) l n1
+ | _ -> nexp
+
(* Decide whether two nexps used in a vector size are similar; if not
a cast will be inserted *)
-let similar_nexps n1 n2 =
+let similar_nexps env n1 n2 =
let rec same_nexp_shape (Nexp_aux (n1,_)) (Nexp_aux (n2,_)) =
match n1, n2 with
- | Nexp_id _, Nexp_id _
- | Nexp_var _, Nexp_var _
- -> true
+ | Nexp_id _, Nexp_id _ -> true
+ | Nexp_var k1, Nexp_var k2 -> prove env (nc_eq (nvar k1) (nvar k2))
| Nexp_constant c1, Nexp_constant c2 -> Nat_big_num.equal c1 c2
| Nexp_app (f1,args1), Nexp_app (f2,args2) ->
Id.compare f1 f2 == 0 && List.for_all2 same_nexp_shape args1 args2
@@ -706,7 +749,48 @@ let similar_nexps n1 n2 =
| Nexp_neg n1, Nexp_neg n2
-> same_nexp_shape n1 n2
| _ -> false
- in if same_nexp_shape n1 n2 then true else false
+ in if same_nexp_shape (nexp_const_eval n1) (nexp_const_eval n2) then true else false
+
+let constraint_fns = ["Z.leb"; "Z.geb"; "Z.ltb"; "Z.gtb"; "Z.eqb"; "neq_atom"]
+
+let condition_produces_constraint exp =
+ (* Cheat a little - this isn't quite the right environment for subexpressions
+ but will have all of the relevant functions in it. *)
+ let env = env_of exp in
+ Rewriter.fold_exp
+ { (Rewriter.pure_exp_alg false (||)) with
+ Rewriter.e_app = fun (f,bs) ->
+ List.exists (fun x -> x) bs ||
+ (let name = if Env.is_extern f env "coq"
+ then Env.get_extern f env "coq"
+ else string_id f in
+ List.exists (fun id -> String.compare name id == 0) constraint_fns)
+ } exp
+
+(* For most functions whose return types are non-trivial atoms we return a
+ dependent pair with a proof that the result is the expected integer. This
+ is redundant for basic arithmetic functions and functions which we unfold
+ in the constraint solver. *)
+let no_Z_proof_fns = ["Z.add"; "Z.sub"; "Z.opp"; "Z.mul"; "length_mword"; "length"]
+
+let is_no_Z_proof_fn env id =
+ if Env.is_extern id env "coq"
+ then
+ let s = Env.get_extern id env "coq" in
+ List.exists (fun x -> String.compare x s == 0) no_Z_proof_fns
+ else false
+
+let replace_atom_return_type ret_typ =
+ (* TODO: more complex uses of atom *)
+ match ret_typ with
+ | Typ_aux (Typ_app (Id_aux (Id "atom",_), [Typ_arg_aux (Typ_arg_nexp nexp,_)]),l) ->
+ let kid = mk_kid "_retval" in (* TODO: collision avoidance *)
+ true, Typ_aux (Typ_exist ([kid], nc_eq (nvar kid) nexp, atom_typ (nvar kid)),Generated l)
+ | Typ_aux (Typ_id (Id_aux (Id "nat",_)),l) ->
+ let kid = mk_kid "_retval" in
+ true, Typ_aux (Typ_exist ([kid], nc_gteq (nvar kid) (nconstant Nat_big_num.zero), atom_typ (nvar kid)),Generated l)
+ | _ -> false, ret_typ
+
let prefix_recordtype = true
let report = Reporting_basic.err_unreachable
@@ -815,7 +899,7 @@ let doc_exp_lem, doc_let_lem =
| Id_aux (Id "foreach", _) ->
begin
match args with
- | [exp1; exp2; exp3; ord_exp; vartuple; body] ->
+ | [from_exp; to_exp; step_exp; ord_exp; vartuple; body] ->
let loopvar, body = match body with
| E_aux (E_let (LB_aux (LB_val (_, _), _),
E_aux (E_let (LB_aux (LB_val (_, _), _),
@@ -826,13 +910,13 @@ let doc_exp_lem, doc_let_lem =
| (P_aux (P_id id, _))), _), _),
body), _), _), _)), _)), _) -> id, body
| _ -> raise (Reporting_basic.err_unreachable l ("Unable to find loop variable in " ^ string_of_exp body)) in
- let step = match ord_exp with
- | E_aux (E_lit (L_aux (L_false, _)), _) ->
- parens (separate space [string "integerNegate"; expY exp3])
- | _ -> expY exp3
+ let dir = match ord_exp with
+ | E_aux (E_lit (L_aux (L_false, _)), _) -> "_down"
+ | E_aux (E_lit (L_aux (L_true, _)), _) -> "_up"
+ | _ -> raise (Reporting_basic.err_unreachable l ("Unexpected loop direction " ^ string_of_exp ord_exp))
in
- let combinator = if effectful (effect_of body) then "foreachM" else "foreach" in
- let indices_pp = parens (separate space [string "index_list"; expY exp1; expY exp2; step]) in
+ let combinator = if effectful (effect_of body) then "foreach_ZM" else "foreach_Z" in
+ let combinator = combinator ^ dir in
let used_vars_body = find_e_ids body in
let body_lambda =
(* Work around indentation issues in Lem when translating
@@ -840,18 +924,20 @@ let doc_exp_lem, doc_let_lem =
match fst (uncast_exp vartuple) with
| E_aux (E_tuple _, _)
when not (IdSet.mem (mk_id "varstup") used_vars_body)->
- separate space [string "fun"; doc_id loopvar; string "varstup"; bigarrow]
+ separate space [string "fun"; doc_id loopvar; string "_"; string "varstup"; bigarrow]
^^ break 1 ^^
- separate space [string "let"; expY vartuple; string ":= varstup in"]
+ separate space [string "let"; squote ^^ expY vartuple; string ":= varstup in"]
| E_aux (E_lit (L_aux (L_unit, _)), _)
when not (IdSet.mem (mk_id "unit_var") used_vars_body) ->
- separate space [string "fun"; doc_id loopvar; string "unit_var"; bigarrow]
+ separate space [string "fun"; doc_id loopvar; string "_"; string "unit_var"; bigarrow]
| _ ->
- separate space [string "fun"; doc_id loopvar; expY vartuple; bigarrow]
+ separate space [string "fun"; doc_id loopvar; string "_"; expY vartuple; bigarrow]
in
parens (
(prefix 2 1)
- ((separate space) [string combinator; indices_pp; expY vartuple])
+ ((separate space) [string combinator;
+ expY from_exp; expY to_exp; expY step_exp;
+ expY vartuple])
(parens
(prefix 2 1 (group body_lambda) (expN body))
)
@@ -879,7 +965,7 @@ let doc_exp_lem, doc_let_lem =
| E_aux (E_tuple _, _)
when not (IdSet.mem (mk_id "varstup") used_vars_body)->
separate space [string "fun varstup"; bigarrow] ^^ break 1 ^^
- separate space [string "let"; expY varstuple; string ":= varstup in"]
+ separate space [string "let"; squote ^^ expY varstuple; string ":= varstup in"]
| E_aux (E_lit (L_aux (L_unit, _)), _)
when not (IdSet.mem (mk_id "unit_var") used_vars_body) ->
separate space [string "fun unit_var"; bigarrow]
@@ -946,17 +1032,35 @@ let doc_exp_lem, doc_let_lem =
(* TODO: more sophisticated check *)
match destruct_exist env arg_ty, destruct_exist env typ_from_fn with
| Some _, None -> parens (string "projT1 " ^^ arg_pp)
+ (* Usually existentials have already been built elsewhere, but this
+ is useful for (e.g.) ranges *)
+ | None, Some _ -> parens (string "build_ex " ^^ arg_pp)
| _, _ -> arg_pp
in
let epp = hang 2 (flow (break 1) (call :: List.map2 doc_arg args arg_typs)) in
- (* Unpack existential result *)
- let inst = instantiation_of full_exp in
+ (* Decide whether to unpack an existential result, pack one, or cast.
+ To do this we compare the expected type stored in the checked expression
+ with the inferred type. *)
+ let inst =
+ match instantiation_of_without_type full_exp with
+ | x -> x
+ (* Not all function applications can be inferred, so try falling back to the
+ type inferred when we know the target type.
+ TODO: there are probably some edge cases where this won't pick up a need
+ to cast. *)
+ | exception _ -> instantiation_of full_exp
+ in
let inst = KBindings.fold (fun k u m -> KBindings.add (orig_kid k) u m) inst KBindings.empty in
- let ret_typ_inst = subst_unifiers inst ret_typ in
+ let ret_typ_inst =
+ subst_unifiers inst ret_typ
+ in
let unpack,build_ex,autocast =
let ann_typ = Env.expand_synonyms env (typ_of_annot (l,annot)) in
let ann_typ = expand_range_type ann_typ in
let ret_typ_inst = expand_range_type (Env.expand_synonyms env ret_typ_inst) in
+ let ret_typ_inst =
+ if is_no_Z_proof_fn env f then ret_typ_inst
+ else snd (replace_atom_return_type ret_typ_inst) in
let unpack, build_ex, in_typ, out_typ =
match ret_typ_inst, ann_typ with
| Typ_aux (Typ_exist (_,_,t1),_), Typ_aux (Typ_exist (_,_,t2),_) ->
@@ -968,10 +1072,12 @@ let doc_exp_lem, doc_let_lem =
| t1, t2 -> false,false,t1,t2
in
let autocast =
- match destruct_vector env in_typ, destruct_vector env out_typ with
- | Some (n1,_,t1), Some (n2,_,t2)
- when is_bit_typ t1 && is_bit_typ t2 ->
- not (similar_nexps n1 n2)
+ (* Avoid using helper functions which simplify the nexps *)
+ is_bitvector_typ in_typ && is_bitvector_typ out_typ &&
+ match in_typ, out_typ with
+ | Typ_aux (Typ_app (_,[Typ_arg_aux (Typ_arg_nexp n1,_);_;_]),_),
+ Typ_aux (Typ_app (_,[Typ_arg_aux (Typ_arg_nexp n2,_);_;_]),_) ->
+ not (similar_nexps env n1 n2)
| _ -> false
in unpack,build_ex,autocast
in
@@ -1048,13 +1154,33 @@ let doc_exp_lem, doc_let_lem =
(doc_fexp ctxt recordtyp) fexps)) in
if aexp_needed then parens epp else epp
| E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) ->
- let recordtyp = match annot with
+ let recordtyp, env = match annot with
| Some (env, Typ_aux (Typ_id tid,_), _)
| Some (env, Typ_aux (Typ_app (tid, _), _), _)
when Env.is_record tid env ->
- tid
+ tid, env
| _ -> raise (report l ("cannot get record type from annot " ^ string_of_annot annot ^ " of exp " ^ string_of_exp full_exp)) in
- enclose_record_update (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps))
+ if List.length fexps > 1 then
+ let _,fields = Env.get_record recordtyp env in
+ let var, let_pp =
+ match e with
+ | E_aux (E_id id,_) -> id, empty
+ | _ -> let v = mk_id "_record" in (* TODO: collision avoid *)
+ v, separate space [string "let "; doc_id v; coloneq; top_exp ctxt true e; string "in"] ^^ break 1
+ in
+ let doc_field (_,id) =
+ match List.find (fun (FE_aux (FE_Fexp (id',_),_)) -> Id.compare id id' == 0) fexps with
+ | fexp -> doc_fexp ctxt recordtyp fexp
+ | exception Not_found ->
+ let fname =
+ if prefix_recordtype && string_of_id recordtyp <> "regstate"
+ then (string (string_of_id recordtyp ^ "_")) ^^ doc_id id
+ else doc_id id in
+ doc_op coloneq fname (doc_id var ^^ dot ^^ parens fname)
+ in let_pp ^^ enclose_record (align (separate_map (semi_sp ^^ break 1)
+ doc_field fields))
+ else
+ enclose_record_update (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps))
| E_vector exps ->
let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in
let start, (len, order, etyp) =
@@ -1079,7 +1205,9 @@ let doc_exp_lem, doc_let_lem =
if is_bit_typ etyp then
let bepp = string "vec_of_bits" ^^ space ^^ align epp in
(align (group (prefix 0 1 bepp (doc_tannot_lem ctxt (env_of full_exp) false t))), true)
- else (epp,aexp_needed) in
+ else
+ let vepp = string "vec_of_list_len" ^^ space ^^ align epp in
+ (vepp,aexp_needed) in
if aexp_needed then parens (align epp) else epp
| E_vector_update(v,e1,e2) ->
raise (Reporting_basic.err_unreachable l
@@ -1100,7 +1228,8 @@ let doc_exp_lem, doc_let_lem =
if effectful (effect_of e) then
let try_catch = if ctxt.early_ret then "try_catchR" else "try_catch" in
let epp =
- group ((separate space [string try_catch; expY e; string "(function "]) ^/^
+ (* TODO capture avoidance for __catch_val *)
+ group ((separate space [string try_catch; expY e; string "(fun __catch_val => match __catch_val with "]) ^/^
(separate_map (break 1) (doc_case ctxt exc_typ) pexps) ^/^
(string "end)")) in
if aexp_needed then parens (align epp) else align epp
@@ -1119,24 +1248,37 @@ let doc_exp_lem, doc_let_lem =
| E_var(lexp, eq_exp, in_exp) ->
raise (report l "E_vars should have been removed before pretty-printing")
| E_internal_plet (pat,e1,e2) ->
- let epp =
- let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in
- let middle =
- match fst (untyp_pat pat) with
- | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) ->
- string ">>"
- | P_aux (P_id id,_) ->
- separate space [string ">>= fun"; doc_id id; bigarrow]
- | P_aux (P_typ (typ, P_aux (P_id id,_)),_) ->
- separate space [string ">>= fun"; doc_id id; colon; doc_typ ctxt typ; bigarrow]
- | _ ->
- separate space [string ">>= fun"; squote ^^ doc_pat ctxt true (pat, typ_of e1); bigarrow]
- in
- infix 0 1 middle (expV b e1) (expN e2)
- in
- if aexp_needed then parens (align epp) else epp
+ begin
+ match pat, e1 with
+ | (P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _)),
+ (E_aux (E_assert (assert_e1,assert_e2),_)) ->
+ let epp = liftR (separate space [string "assert_exp'"; expY assert_e1; expY assert_e2]) in
+ let epp = infix 0 1 (string ">>= fun _ =>") epp (expN e2) in
+ if aexp_needed then parens (align epp) else align epp
+ | _ ->
+ let epp =
+ let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in
+ let middle =
+ match fst (untyp_pat pat) with
+ | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) ->
+ string ">>"
+ | P_aux (P_id id,_) ->
+ separate space [string ">>= fun"; doc_id id; bigarrow]
+ | P_aux (P_typ (typ, P_aux (P_id id,_)),_) ->
+ separate space [string ">>= fun"; doc_id id; colon; doc_typ ctxt typ; bigarrow]
+ | _ ->
+ separate space [string ">>= fun"; squote ^^ doc_pat ctxt true (pat, typ_of e1); bigarrow]
+ in
+ infix 0 1 middle (expV b e1) (expN e2)
+ in
+ if aexp_needed then parens (align epp) else epp
+ end
| E_internal_return (e1) ->
- wrap_parens (align (separate space [string "returnm"; expY e1]))
+ let e1pp = expY e1 in
+ let valpp = if ctxt.build_ex_return
+ then parens (string "build_ex" ^^ space ^^ e1pp)
+ else e1pp in
+ wrap_parens (align (separate space [string "returnm"; valpp]))
| E_sizeof nexp ->
(match nexp_simp nexp with
| Nexp_aux (Nexp_constant i, _) -> doc_lit (L_aux (L_num i, l))
@@ -1153,7 +1295,7 @@ let doc_exp_lem, doc_let_lem =
parens (doc_typ ctxt (typ_of full_exp));
parens (doc_typ ctxt (typ_of r))] in
align (parens (string "early_return" ^//^ expV true r ^//^ ta))
- | E_constraint _ -> string "true"
+ | E_constraint nc -> wrap_parens (doc_nc_exp ctxt nc)
| E_comment _ | E_comment_struc _ -> empty
| E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _
| E_internal_exp_user _ | E_internal_value _ ->
@@ -1168,7 +1310,9 @@ let doc_exp_lem, doc_let_lem =
| _ -> prefix 2 1 (string "else") (top_exp ctxt false e)
in
(prefix 2 1
- (soft_surround 2 1 if_pp (string "sumbool_of_bool" ^^ space ^^ parens (top_exp ctxt true c)) (string "then"))
+ (soft_surround 2 1 if_pp
+ ((if condition_produces_constraint c then string "sumbool_of_bool" ^^ space else empty)
+ ^^ parens (top_exp ctxt true c)) (string "then"))
(top_exp ctxt false t)) ^^
break 1 ^^
else_pp
@@ -1404,17 +1548,28 @@ let demote_as_pattern i (P_aux (_,p_annot) as pat,typ) =
E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)),p_annot),e),e_ann)
else (pat,typ), fun e -> e
-(* Ideally we'd remove the duplication between type variables and atom
- arguments, but for now we just add an equality constraint. *)
+(* Add equality constraints between arguments and nexps, except in the case
+ that they've been merged. *)
-let atom_constraint ctxt (pat, typ) =
+let rec atom_constraint ctxt (pat, typ) =
let typ = Env.base_typ_of (pat_env_of pat) typ in
match pat, typ with
| P_aux (P_id id, _),
Typ_aux (Typ_app (Id_aux (Id "atom",_),
- [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_) ->
+ [Typ_arg_aux (Typ_arg_nexp nexp,_)]),_) ->
+ (match nexp with
+ (* When the kid is mapped to the id, we don't need a constraint *)
+ | Nexp_aux (Nexp_var kid,_)
+ when (try Id.compare (KBindings.find kid ctxt.kid_id_renames) id == 0 with _ -> false) ->
+ None
+ | _ ->
+ Some (bquote ^^ braces (string "ArithFact" ^^ space ^^
+ parens (doc_op equals (doc_id id) (doc_nexp ctxt nexp)))))
+ | P_aux (P_id id, _),
+ Typ_aux (Typ_id (Id_aux (Id "nat",_)),_) ->
Some (bquote ^^ braces (string "ArithFact" ^^ space ^^
- parens (doc_op equals (doc_id id) (doc_var_lem ctxt kid))))
+ parens (doc_op (string ">=") (doc_id id) (string "0"))))
+ | P_aux (P_typ (_,p),_), _ -> atom_constraint ctxt (p, typ)
| _ -> None
let all_ids pexp =
@@ -1485,14 +1640,13 @@ let merge_kids_atoms pats =
let gone,map,_ = List.fold_left try_eliminate (KidSet.empty, KBindings.empty, KidSet.empty) pats in
gone,map
-let doc_binder ctxt (P_aux (p,ann) as pat, typ) =
- let env = env_of_annot ann in
- let exp_typ = Env.expand_synonyms env typ in
- match p with
- | P_id id
- | P_typ (_,P_aux (P_id id,_)) when Util.is_none (is_auto_decomposed_exist env exp_typ) ->
- parens (separate space [doc_id id; colon; doc_typ ctxt typ])
- | _ -> squote ^^ parens (separate space [doc_pat ctxt true (pat, exp_typ); colon; doc_typ ctxt typ])
+let merge_var_patterns map pats =
+ let map,pats = List.fold_left (fun (map,pats) (pat, typ) ->
+ match pat with
+ | P_aux (P_var (P_aux (P_id id,_), TP_aux (TP_var kid,_)),ann) ->
+ KBindings.add kid id map, (P_aux (P_id id,ann), typ) :: pats
+ | _ -> map, (pat,typ)::pats) (map,[]) pats
+ in map, List.rev pats
let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) =
let (tq,typ) = Env.get_val_spec_orig id (env_of_annot annot) in
@@ -1500,38 +1654,70 @@ let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) =
| Typ_aux (Typ_fn (arg_typ, ret_typ, eff),_) -> arg_typ, ret_typ, eff
| _ -> failwith ("Function " ^ string_of_id id ^ " does not have function type")
in
+ let build_ex, ret_typ = replace_atom_return_type ret_typ in
let ids_to_avoid = all_ids pexp in
let kids_used = tyvars_of_typquant tq in
let pat,guard,exp,(l,_) = destruct_pexp pexp in
let pats, bind = untuple_args_pat arg_typ pat in
let pats, binds = List.split (Util.list_mapi demote_as_pattern pats) in
let eliminated_kids, kid_to_arg_rename = merge_kids_atoms pats in
+ let kid_to_arg_rename, pats = merge_var_patterns kid_to_arg_rename pats in
let kids_used = KidSet.diff kids_used eliminated_kids in
let ctxt =
{ early_ret = contains_early_return exp;
kid_renames = mk_kid_renames ids_to_avoid kids_used;
kid_id_renames = kid_to_arg_rename;
- bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ) } in
+ bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ);
+ build_ex_return = effectful eff && build_ex;
+ } in
(* Put the constraints after pattern matching so that any type variable that's
been replaced by one of the term-level arguments is bound. *)
let quantspp, constrspp = doc_typquant_items_separate ctxt braces tq in
let exp = List.fold_left (fun body f -> f body) (bind exp) binds in
- let patspp = separate_map space (doc_binder ctxt) pats in
- let atom_constr_pp = separate_opt space (atom_constraint ctxt) pats in
+ let used_a_pattern = ref false in
+ let doc_binder (P_aux (p,ann) as pat, typ) =
+ let env = env_of_annot ann in
+ let exp_typ = Env.expand_synonyms env typ in
+ match p with
+ | P_id id
+ | P_typ (_,P_aux (P_id id,_)) when Util.is_none (is_auto_decomposed_exist env exp_typ) ->
+ parens (separate space [doc_id id; colon; doc_typ ctxt typ])
+ | _ ->
+ (used_a_pattern := true;
+ squote ^^ parens (separate space [doc_pat ctxt true (pat, exp_typ); colon; doc_typ ctxt typ]))
+ in
+ let patspp = separate_map space doc_binder pats in
+ let atom_constrs = Util.map_filter (atom_constraint ctxt) pats in
+ let atom_constr_pp = separate space atom_constrs in
let retpp =
if effectful eff
then string "M" ^^ space ^^ parens (doc_typ ctxt ret_typ)
else doc_typ ctxt ret_typ
in
+ let idpp = doc_id id in
+ (* Work around Coq bug 7975 about pattern binders followed by implicit arguments *)
+ let implicitargs =
+ if !used_a_pattern && List.length constrspp + List.length atom_constrs > 0 then
+ break 1 ^^ separate space
+ ([string "Arguments"; idpp;] @
+ List.map (fun _ -> string "{_}") quantspp @
+ List.map (fun _ -> string "_") pats @
+ List.map (fun _ -> string "{_}") constrspp @
+ List.map (fun _ -> string "{_}") atom_constrs)
+ ^^ dot
+ else empty
+ in
let _ = match guard with
| None -> ()
| _ ->
raise (Reporting_basic.err_unreachable l
"guarded pattern expression should have been rewritten before pretty-printing") in
+ let bodypp = doc_fun_body ctxt exp in
+ let bodypp = if effectful eff || not build_ex then bodypp else string "build_ex" ^^ parens bodypp in
group (prefix 3 1
- (separate space [doc_id id; quantspp; patspp; constrspp; atom_constr_pp] ^/^
- colon ^^ space ^^ retpp ^^ coloneq)
- (doc_fun_body ctxt exp ^^ dot))
+ (separate space ([idpp] @ quantspp @ [patspp] @ constrspp @ [atom_constr_pp]) ^/^
+ separate space [colon; retpp; coloneq])
+ (bodypp ^^ dot)) ^^ implicitargs
let get_id = function
| [] -> failwith "FD_function with empty list"
@@ -1658,9 +1844,16 @@ let doc_axiom_typschm typ_env (TypSchm_aux (TypSchm_ts (tqs,typ),l) as ts) =
| _ -> parens (underscore ^^ string " : " ^^ doc_typ empty_ctxt typ)
in
let arg_typs_pp = separate space (List.map doc_typ' typs) in
+ let _, ret_ty = replace_atom_return_type ret_ty in
let ret_typ_pp = doc_typ empty_ctxt ret_ty in
+ let ret_typ_pp =
+ if effectful eff
+ then string "M" ^^ space ^^ parens ret_typ_pp
+ else ret_typ_pp
+ in
let tyvars_pp, constrs_pp = doc_typquant_items_separate empty_ctxt braces tqs in
- string "forall" ^/^ tyvars_pp ^/^ arg_typs_pp ^/^ constrs_pp ^^ comma ^/^ ret_typ_pp
+ string "forall" ^/^ separate space tyvars_pp ^/^
+ arg_typs_pp ^/^ separate space constrs_pp ^^ comma ^/^ ret_typ_pp
| _ -> doc_typschm empty_ctxt true ts
let doc_val_spec unimplemented (VS_aux (VS_val_spec(tys,id,_,_),ann)) =
@@ -1800,6 +1993,8 @@ try
(fun lib -> separate space [string "Require Import";string lib] ^^ dot) defs_modules;hardline;
string "Import ListNotations.";
hardline;
+ string "Open Scope string."; hardline;
+ string "Open Scope bool."; hardline;
(* Put the body into a Section so that we can define some values with
Let to put them into the local context, where tactics can see them *)
string "Section Content.";