diff options
Diffstat (limited to 'src/pretty_print_coq.ml')
| -rw-r--r-- | src/pretty_print_coq.ml | 361 |
1 files changed, 278 insertions, 83 deletions
diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index ffe376e0..74e97a29 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -96,12 +96,14 @@ type context = { kid_renames : kid KBindings.t; (* Plain tyvar -> tyvar renames *) kid_id_renames : id KBindings.t; (* tyvar -> argument renames *) bound_nexps : NexpSet.t; + build_ex_return : bool; } let empty_ctxt = { early_ret = false; kid_renames = KBindings.empty; kid_id_renames = KBindings.empty; - bound_nexps = NexpSet.empty + bound_nexps = NexpSet.empty; + build_ex_return = false; } let langlebar = string "<|" @@ -135,7 +137,10 @@ let rec fix_id remove_tick name = match name with | "GT" | "EQ" | "Z" + | "O" + | "S" | "mod" + | "M" -> name ^ "'" | _ -> if String.contains name '#' then @@ -146,15 +151,17 @@ let rec fix_id remove_tick name = match name with fix_id remove_tick (String.concat "__" (Util.split_on_char '^' name)) else if name.[0] = '\'' then let var = String.sub name 1 (String.length name - 1) in - if remove_tick then var else (var ^ "'") + if remove_tick then fix_id remove_tick var else (var ^ "'") else if is_number_char(name.[0]) then ("v" ^ name ^ "'") else name -let doc_id (Id_aux(i,_)) = +let string_id (Id_aux(i,_)) = match i with - | Id i -> string (fix_id false i) - | DeIid x -> string (Util.zencode_string ("op " ^ x)) + | Id i -> fix_id false i + | DeIid x -> Util.zencode_string ("op " ^ x) + +let doc_id id = string (string_id id) let doc_id_type (Id_aux(i,_)) = match i with @@ -318,7 +325,7 @@ let drop_duplicate_atoms kids ty = in aux_typ ty (* TODO: parens *) -let rec doc_nc ctx (NC_aux (nc,_)) = +let rec doc_nc_prop ctx (NC_aux (nc,_)) = match nc with | NC_equal (ne1, ne2) -> doc_op equals (doc_nexp ctx ne1) (doc_nexp ctx ne2) | NC_bounded_ge (ne1, ne2) -> doc_op (string ">=") (doc_nexp ctx ne1) (doc_nexp ctx ne2) @@ -328,11 +335,27 @@ let rec doc_nc ctx (NC_aux (nc,_)) = separate space [string "In"; doc_var_lem ctx kid; brackets (separate (string "; ") (List.map (fun i -> string (Nat_big_num.to_string i)) is))] - | NC_or (nc1, nc2) -> doc_op (string "\\/") (doc_nc ctx nc1) (doc_nc ctx nc2) - | NC_and (nc1, nc2) -> doc_op (string "/\\") (doc_nc ctx nc1) (doc_nc ctx nc2) + | NC_or (nc1, nc2) -> doc_op (string "\\/") (doc_nc_prop ctx nc1) (doc_nc_prop ctx nc2) + | NC_and (nc1, nc2) -> doc_op (string "/\\") (doc_nc_prop ctx nc1) (doc_nc_prop ctx nc2) | NC_true -> string "True" | NC_false -> string "False" +(* TODO: parens *) +let rec doc_nc_exp ctx (NC_aux (nc,_)) = + match nc with + | NC_equal (ne1, ne2) -> doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) + | NC_bounded_ge (ne1, ne2) -> doc_op (string ">=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) + | NC_bounded_le (ne1, ne2) -> doc_op (string "<=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) + | NC_not_equal (ne1, ne2) -> string "negb" ^^ space ^^ parens (doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2)) + | NC_set (kid, is) -> (* TODO: is this a good translation? *) + separate space [string "member_Z_list"; doc_var_lem ctx kid; + brackets (separate (string "; ") + (List.map (fun i -> string (Nat_big_num.to_string i)) is))] + | NC_or (nc1, nc2) -> doc_op (string "||") (doc_nc_exp ctx nc1) (doc_nc_exp ctx nc2) + | NC_and (nc1, nc2) -> doc_op (string "&&") (doc_nc_exp ctx nc1) (doc_nc_exp ctx nc2) + | NC_true -> string "true" + | NC_false -> string "false" + let maybe_expand_range_type (Typ_aux (typ,l) as full_typ) = match typ with | Typ_app(Id_aux (Id "range", _), [Typ_arg_aux(Typ_arg_nexp low,_); @@ -347,7 +370,7 @@ let maybe_expand_range_type (Typ_aux (typ,l) as full_typ) = let expand_range_type typ = Util.option_default typ (maybe_expand_range_type typ) let doc_arithfact ctxt nc = - string "ArithFact" ^^ space ^^ parens (doc_nc ctxt nc) + string "ArithFact" ^^ space ^^ parens (doc_nc_prop ctxt nc) (* When making changes here, check whether they affect lem_tyvars_of_typ *) let doc_typ, doc_atomic_typ = @@ -381,7 +404,7 @@ let doc_typ, doc_atomic_typ = let tpp = match elem_typ with | Typ_aux (Typ_id (Id_aux (Id "bit",_)),_) -> string "mword " ^^ doc_nexp ctx (nexp_simp m) - | _ -> string "list" ^^ space ^^ typ elem_typ in + | _ -> string "vec" ^^ space ^^ typ elem_typ ^^ space ^^ doc_nexp ctx (nexp_simp m) in if atyp_needed then parens tpp else tpp | Typ_app(Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ etyp, _)]) -> let tpp = string "register_ref regstate register_value " ^^ typ etyp in @@ -424,7 +447,7 @@ let doc_typ, doc_atomic_typ = List.fold_left add_tyvar tpp kids | None -> match nc with - | NC_aux (NC_true,_) -> List.fold_left add_tyvar (string "Z") (List.tl kids) +(* | NC_aux (NC_true,_) -> List.fold_left add_tyvar (string "Z") (List.tl kids)*) | _ -> List.fold_left add_tyvar (doc_arithfact ctx nc) kids end and doc_typ_arg (Typ_arg_aux(t,_)) = match t with @@ -537,9 +560,9 @@ let doc_typquant_items ctx delimit (TypQ_aux (tq,_)) = let doc_typquant_items_separate ctx delimit (TypQ_aux (tq,_)) = match tq with | TypQ_tq qis -> - separate_opt space (doc_quant_item_id ctx delimit) qis, - separate_opt space (doc_quant_item_constr ctx delimit) qis - | TypQ_no_forall -> empty, empty + Util.map_filter (doc_quant_item_id ctx delimit) qis, + Util.map_filter (doc_quant_item_constr ctx delimit) qis + | TypQ_no_forall -> [], [] let doc_typquant ctx (TypQ_aux(tq,_)) typ = match tq with | TypQ_tq ((_ :: _) as qs) -> @@ -687,14 +710,34 @@ let typ_id_of (Typ_aux (typ, l)) = match typ with | Typ_app (id, _) -> id | _ -> raise (Reporting_basic.err_unreachable l "failed to get type id") +(* TODO: maybe Nexp_exp, division? *) +(* Evaluation of constant nexp subexpressions, because Coq will be able to do those itself *) +let rec nexp_const_eval (Nexp_aux (n,l) as nexp) = + let binop f re l n1 n2 = + match nexp_const_eval n1, nexp_const_eval n2 with + | Nexp_aux (Nexp_constant c1,_), Nexp_aux (Nexp_constant c2,_) -> + Nexp_aux (Nexp_constant (f c1 c2),l) + | n1', n2' -> Nexp_aux (re n1' n2',l) + in + let unop f re l n1 = + match nexp_const_eval n1 with + | Nexp_aux (Nexp_constant c1,_) -> Nexp_aux (Nexp_constant (f c1),l) + | n1' -> Nexp_aux (re n1',l) + in + match n with + | Nexp_times (n1,n2) -> binop Big_int.mul (fun n1 n2 -> Nexp_times (n1,n2)) l n1 n2 + | Nexp_sum (n1,n2) -> binop Big_int.add (fun n1 n2 -> Nexp_sum (n1,n2)) l n1 n2 + | Nexp_minus (n1,n2) -> binop Big_int.sub (fun n1 n2 -> Nexp_minus (n1,n2)) l n1 n2 + | Nexp_neg n1 -> unop Big_int.negate (fun n -> Nexp_neg n) l n1 + | _ -> nexp + (* Decide whether two nexps used in a vector size are similar; if not a cast will be inserted *) -let similar_nexps n1 n2 = +let similar_nexps env n1 n2 = let rec same_nexp_shape (Nexp_aux (n1,_)) (Nexp_aux (n2,_)) = match n1, n2 with - | Nexp_id _, Nexp_id _ - | Nexp_var _, Nexp_var _ - -> true + | Nexp_id _, Nexp_id _ -> true + | Nexp_var k1, Nexp_var k2 -> prove env (nc_eq (nvar k1) (nvar k2)) | Nexp_constant c1, Nexp_constant c2 -> Nat_big_num.equal c1 c2 | Nexp_app (f1,args1), Nexp_app (f2,args2) -> Id.compare f1 f2 == 0 && List.for_all2 same_nexp_shape args1 args2 @@ -706,7 +749,48 @@ let similar_nexps n1 n2 = | Nexp_neg n1, Nexp_neg n2 -> same_nexp_shape n1 n2 | _ -> false - in if same_nexp_shape n1 n2 then true else false + in if same_nexp_shape (nexp_const_eval n1) (nexp_const_eval n2) then true else false + +let constraint_fns = ["Z.leb"; "Z.geb"; "Z.ltb"; "Z.gtb"; "Z.eqb"; "neq_atom"] + +let condition_produces_constraint exp = + (* Cheat a little - this isn't quite the right environment for subexpressions + but will have all of the relevant functions in it. *) + let env = env_of exp in + Rewriter.fold_exp + { (Rewriter.pure_exp_alg false (||)) with + Rewriter.e_app = fun (f,bs) -> + List.exists (fun x -> x) bs || + (let name = if Env.is_extern f env "coq" + then Env.get_extern f env "coq" + else string_id f in + List.exists (fun id -> String.compare name id == 0) constraint_fns) + } exp + +(* For most functions whose return types are non-trivial atoms we return a + dependent pair with a proof that the result is the expected integer. This + is redundant for basic arithmetic functions and functions which we unfold + in the constraint solver. *) +let no_Z_proof_fns = ["Z.add"; "Z.sub"; "Z.opp"; "Z.mul"; "length_mword"; "length"] + +let is_no_Z_proof_fn env id = + if Env.is_extern id env "coq" + then + let s = Env.get_extern id env "coq" in + List.exists (fun x -> String.compare x s == 0) no_Z_proof_fns + else false + +let replace_atom_return_type ret_typ = + (* TODO: more complex uses of atom *) + match ret_typ with + | Typ_aux (Typ_app (Id_aux (Id "atom",_), [Typ_arg_aux (Typ_arg_nexp nexp,_)]),l) -> + let kid = mk_kid "_retval" in (* TODO: collision avoidance *) + true, Typ_aux (Typ_exist ([kid], nc_eq (nvar kid) nexp, atom_typ (nvar kid)),Generated l) + | Typ_aux (Typ_id (Id_aux (Id "nat",_)),l) -> + let kid = mk_kid "_retval" in + true, Typ_aux (Typ_exist ([kid], nc_gteq (nvar kid) (nconstant Nat_big_num.zero), atom_typ (nvar kid)),Generated l) + | _ -> false, ret_typ + let prefix_recordtype = true let report = Reporting_basic.err_unreachable @@ -815,7 +899,7 @@ let doc_exp_lem, doc_let_lem = | Id_aux (Id "foreach", _) -> begin match args with - | [exp1; exp2; exp3; ord_exp; vartuple; body] -> + | [from_exp; to_exp; step_exp; ord_exp; vartuple; body] -> let loopvar, body = match body with | E_aux (E_let (LB_aux (LB_val (_, _), _), E_aux (E_let (LB_aux (LB_val (_, _), _), @@ -826,13 +910,13 @@ let doc_exp_lem, doc_let_lem = | (P_aux (P_id id, _))), _), _), body), _), _), _)), _)), _) -> id, body | _ -> raise (Reporting_basic.err_unreachable l ("Unable to find loop variable in " ^ string_of_exp body)) in - let step = match ord_exp with - | E_aux (E_lit (L_aux (L_false, _)), _) -> - parens (separate space [string "integerNegate"; expY exp3]) - | _ -> expY exp3 + let dir = match ord_exp with + | E_aux (E_lit (L_aux (L_false, _)), _) -> "_down" + | E_aux (E_lit (L_aux (L_true, _)), _) -> "_up" + | _ -> raise (Reporting_basic.err_unreachable l ("Unexpected loop direction " ^ string_of_exp ord_exp)) in - let combinator = if effectful (effect_of body) then "foreachM" else "foreach" in - let indices_pp = parens (separate space [string "index_list"; expY exp1; expY exp2; step]) in + let combinator = if effectful (effect_of body) then "foreach_ZM" else "foreach_Z" in + let combinator = combinator ^ dir in let used_vars_body = find_e_ids body in let body_lambda = (* Work around indentation issues in Lem when translating @@ -840,18 +924,20 @@ let doc_exp_lem, doc_let_lem = match fst (uncast_exp vartuple) with | E_aux (E_tuple _, _) when not (IdSet.mem (mk_id "varstup") used_vars_body)-> - separate space [string "fun"; doc_id loopvar; string "varstup"; bigarrow] + separate space [string "fun"; doc_id loopvar; string "_"; string "varstup"; bigarrow] ^^ break 1 ^^ - separate space [string "let"; expY vartuple; string ":= varstup in"] + separate space [string "let"; squote ^^ expY vartuple; string ":= varstup in"] | E_aux (E_lit (L_aux (L_unit, _)), _) when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> - separate space [string "fun"; doc_id loopvar; string "unit_var"; bigarrow] + separate space [string "fun"; doc_id loopvar; string "_"; string "unit_var"; bigarrow] | _ -> - separate space [string "fun"; doc_id loopvar; expY vartuple; bigarrow] + separate space [string "fun"; doc_id loopvar; string "_"; expY vartuple; bigarrow] in parens ( (prefix 2 1) - ((separate space) [string combinator; indices_pp; expY vartuple]) + ((separate space) [string combinator; + expY from_exp; expY to_exp; expY step_exp; + expY vartuple]) (parens (prefix 2 1 (group body_lambda) (expN body)) ) @@ -879,7 +965,7 @@ let doc_exp_lem, doc_let_lem = | E_aux (E_tuple _, _) when not (IdSet.mem (mk_id "varstup") used_vars_body)-> separate space [string "fun varstup"; bigarrow] ^^ break 1 ^^ - separate space [string "let"; expY varstuple; string ":= varstup in"] + separate space [string "let"; squote ^^ expY varstuple; string ":= varstup in"] | E_aux (E_lit (L_aux (L_unit, _)), _) when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> separate space [string "fun unit_var"; bigarrow] @@ -946,17 +1032,35 @@ let doc_exp_lem, doc_let_lem = (* TODO: more sophisticated check *) match destruct_exist env arg_ty, destruct_exist env typ_from_fn with | Some _, None -> parens (string "projT1 " ^^ arg_pp) + (* Usually existentials have already been built elsewhere, but this + is useful for (e.g.) ranges *) + | None, Some _ -> parens (string "build_ex " ^^ arg_pp) | _, _ -> arg_pp in let epp = hang 2 (flow (break 1) (call :: List.map2 doc_arg args arg_typs)) in - (* Unpack existential result *) - let inst = instantiation_of full_exp in + (* Decide whether to unpack an existential result, pack one, or cast. + To do this we compare the expected type stored in the checked expression + with the inferred type. *) + let inst = + match instantiation_of_without_type full_exp with + | x -> x + (* Not all function applications can be inferred, so try falling back to the + type inferred when we know the target type. + TODO: there are probably some edge cases where this won't pick up a need + to cast. *) + | exception _ -> instantiation_of full_exp + in let inst = KBindings.fold (fun k u m -> KBindings.add (orig_kid k) u m) inst KBindings.empty in - let ret_typ_inst = subst_unifiers inst ret_typ in + let ret_typ_inst = + subst_unifiers inst ret_typ + in let unpack,build_ex,autocast = let ann_typ = Env.expand_synonyms env (typ_of_annot (l,annot)) in let ann_typ = expand_range_type ann_typ in let ret_typ_inst = expand_range_type (Env.expand_synonyms env ret_typ_inst) in + let ret_typ_inst = + if is_no_Z_proof_fn env f then ret_typ_inst + else snd (replace_atom_return_type ret_typ_inst) in let unpack, build_ex, in_typ, out_typ = match ret_typ_inst, ann_typ with | Typ_aux (Typ_exist (_,_,t1),_), Typ_aux (Typ_exist (_,_,t2),_) -> @@ -968,10 +1072,12 @@ let doc_exp_lem, doc_let_lem = | t1, t2 -> false,false,t1,t2 in let autocast = - match destruct_vector env in_typ, destruct_vector env out_typ with - | Some (n1,_,t1), Some (n2,_,t2) - when is_bit_typ t1 && is_bit_typ t2 -> - not (similar_nexps n1 n2) + (* Avoid using helper functions which simplify the nexps *) + is_bitvector_typ in_typ && is_bitvector_typ out_typ && + match in_typ, out_typ with + | Typ_aux (Typ_app (_,[Typ_arg_aux (Typ_arg_nexp n1,_);_;_]),_), + Typ_aux (Typ_app (_,[Typ_arg_aux (Typ_arg_nexp n2,_);_;_]),_) -> + not (similar_nexps env n1 n2) | _ -> false in unpack,build_ex,autocast in @@ -1048,13 +1154,33 @@ let doc_exp_lem, doc_let_lem = (doc_fexp ctxt recordtyp) fexps)) in if aexp_needed then parens epp else epp | E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) -> - let recordtyp = match annot with + let recordtyp, env = match annot with | Some (env, Typ_aux (Typ_id tid,_), _) | Some (env, Typ_aux (Typ_app (tid, _), _), _) when Env.is_record tid env -> - tid + tid, env | _ -> raise (report l ("cannot get record type from annot " ^ string_of_annot annot ^ " of exp " ^ string_of_exp full_exp)) in - enclose_record_update (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps)) + if List.length fexps > 1 then + let _,fields = Env.get_record recordtyp env in + let var, let_pp = + match e with + | E_aux (E_id id,_) -> id, empty + | _ -> let v = mk_id "_record" in (* TODO: collision avoid *) + v, separate space [string "let "; doc_id v; coloneq; top_exp ctxt true e; string "in"] ^^ break 1 + in + let doc_field (_,id) = + match List.find (fun (FE_aux (FE_Fexp (id',_),_)) -> Id.compare id id' == 0) fexps with + | fexp -> doc_fexp ctxt recordtyp fexp + | exception Not_found -> + let fname = + if prefix_recordtype && string_of_id recordtyp <> "regstate" + then (string (string_of_id recordtyp ^ "_")) ^^ doc_id id + else doc_id id in + doc_op coloneq fname (doc_id var ^^ dot ^^ parens fname) + in let_pp ^^ enclose_record (align (separate_map (semi_sp ^^ break 1) + doc_field fields)) + else + enclose_record_update (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps)) | E_vector exps -> let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let start, (len, order, etyp) = @@ -1079,7 +1205,9 @@ let doc_exp_lem, doc_let_lem = if is_bit_typ etyp then let bepp = string "vec_of_bits" ^^ space ^^ align epp in (align (group (prefix 0 1 bepp (doc_tannot_lem ctxt (env_of full_exp) false t))), true) - else (epp,aexp_needed) in + else + let vepp = string "vec_of_list_len" ^^ space ^^ align epp in + (vepp,aexp_needed) in if aexp_needed then parens (align epp) else epp | E_vector_update(v,e1,e2) -> raise (Reporting_basic.err_unreachable l @@ -1100,7 +1228,8 @@ let doc_exp_lem, doc_let_lem = if effectful (effect_of e) then let try_catch = if ctxt.early_ret then "try_catchR" else "try_catch" in let epp = - group ((separate space [string try_catch; expY e; string "(function "]) ^/^ + (* TODO capture avoidance for __catch_val *) + group ((separate space [string try_catch; expY e; string "(fun __catch_val => match __catch_val with "]) ^/^ (separate_map (break 1) (doc_case ctxt exc_typ) pexps) ^/^ (string "end)")) in if aexp_needed then parens (align epp) else align epp @@ -1119,24 +1248,37 @@ let doc_exp_lem, doc_let_lem = | E_var(lexp, eq_exp, in_exp) -> raise (report l "E_vars should have been removed before pretty-printing") | E_internal_plet (pat,e1,e2) -> - let epp = - let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in - let middle = - match fst (untyp_pat pat) with - | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> - string ">>" - | P_aux (P_id id,_) -> - separate space [string ">>= fun"; doc_id id; bigarrow] - | P_aux (P_typ (typ, P_aux (P_id id,_)),_) -> - separate space [string ">>= fun"; doc_id id; colon; doc_typ ctxt typ; bigarrow] - | _ -> - separate space [string ">>= fun"; squote ^^ doc_pat ctxt true (pat, typ_of e1); bigarrow] - in - infix 0 1 middle (expV b e1) (expN e2) - in - if aexp_needed then parens (align epp) else epp + begin + match pat, e1 with + | (P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _)), + (E_aux (E_assert (assert_e1,assert_e2),_)) -> + let epp = liftR (separate space [string "assert_exp'"; expY assert_e1; expY assert_e2]) in + let epp = infix 0 1 (string ">>= fun _ =>") epp (expN e2) in + if aexp_needed then parens (align epp) else align epp + | _ -> + let epp = + let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in + let middle = + match fst (untyp_pat pat) with + | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> + string ">>" + | P_aux (P_id id,_) -> + separate space [string ">>= fun"; doc_id id; bigarrow] + | P_aux (P_typ (typ, P_aux (P_id id,_)),_) -> + separate space [string ">>= fun"; doc_id id; colon; doc_typ ctxt typ; bigarrow] + | _ -> + separate space [string ">>= fun"; squote ^^ doc_pat ctxt true (pat, typ_of e1); bigarrow] + in + infix 0 1 middle (expV b e1) (expN e2) + in + if aexp_needed then parens (align epp) else epp + end | E_internal_return (e1) -> - wrap_parens (align (separate space [string "returnm"; expY e1])) + let e1pp = expY e1 in + let valpp = if ctxt.build_ex_return + then parens (string "build_ex" ^^ space ^^ e1pp) + else e1pp in + wrap_parens (align (separate space [string "returnm"; valpp])) | E_sizeof nexp -> (match nexp_simp nexp with | Nexp_aux (Nexp_constant i, _) -> doc_lit (L_aux (L_num i, l)) @@ -1153,7 +1295,7 @@ let doc_exp_lem, doc_let_lem = parens (doc_typ ctxt (typ_of full_exp)); parens (doc_typ ctxt (typ_of r))] in align (parens (string "early_return" ^//^ expV true r ^//^ ta)) - | E_constraint _ -> string "true" + | E_constraint nc -> wrap_parens (doc_nc_exp ctxt nc) | E_comment _ | E_comment_struc _ -> empty | E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _ | E_internal_exp_user _ | E_internal_value _ -> @@ -1168,7 +1310,9 @@ let doc_exp_lem, doc_let_lem = | _ -> prefix 2 1 (string "else") (top_exp ctxt false e) in (prefix 2 1 - (soft_surround 2 1 if_pp (string "sumbool_of_bool" ^^ space ^^ parens (top_exp ctxt true c)) (string "then")) + (soft_surround 2 1 if_pp + ((if condition_produces_constraint c then string "sumbool_of_bool" ^^ space else empty) + ^^ parens (top_exp ctxt true c)) (string "then")) (top_exp ctxt false t)) ^^ break 1 ^^ else_pp @@ -1404,17 +1548,28 @@ let demote_as_pattern i (P_aux (_,p_annot) as pat,typ) = E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)),p_annot),e),e_ann) else (pat,typ), fun e -> e -(* Ideally we'd remove the duplication between type variables and atom - arguments, but for now we just add an equality constraint. *) +(* Add equality constraints between arguments and nexps, except in the case + that they've been merged. *) -let atom_constraint ctxt (pat, typ) = +let rec atom_constraint ctxt (pat, typ) = let typ = Env.base_typ_of (pat_env_of pat) typ in match pat, typ with | P_aux (P_id id, _), Typ_aux (Typ_app (Id_aux (Id "atom",_), - [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_) -> + [Typ_arg_aux (Typ_arg_nexp nexp,_)]),_) -> + (match nexp with + (* When the kid is mapped to the id, we don't need a constraint *) + | Nexp_aux (Nexp_var kid,_) + when (try Id.compare (KBindings.find kid ctxt.kid_id_renames) id == 0 with _ -> false) -> + None + | _ -> + Some (bquote ^^ braces (string "ArithFact" ^^ space ^^ + parens (doc_op equals (doc_id id) (doc_nexp ctxt nexp))))) + | P_aux (P_id id, _), + Typ_aux (Typ_id (Id_aux (Id "nat",_)),_) -> Some (bquote ^^ braces (string "ArithFact" ^^ space ^^ - parens (doc_op equals (doc_id id) (doc_var_lem ctxt kid)))) + parens (doc_op (string ">=") (doc_id id) (string "0")))) + | P_aux (P_typ (_,p),_), _ -> atom_constraint ctxt (p, typ) | _ -> None let all_ids pexp = @@ -1485,14 +1640,13 @@ let merge_kids_atoms pats = let gone,map,_ = List.fold_left try_eliminate (KidSet.empty, KBindings.empty, KidSet.empty) pats in gone,map -let doc_binder ctxt (P_aux (p,ann) as pat, typ) = - let env = env_of_annot ann in - let exp_typ = Env.expand_synonyms env typ in - match p with - | P_id id - | P_typ (_,P_aux (P_id id,_)) when Util.is_none (is_auto_decomposed_exist env exp_typ) -> - parens (separate space [doc_id id; colon; doc_typ ctxt typ]) - | _ -> squote ^^ parens (separate space [doc_pat ctxt true (pat, exp_typ); colon; doc_typ ctxt typ]) +let merge_var_patterns map pats = + let map,pats = List.fold_left (fun (map,pats) (pat, typ) -> + match pat with + | P_aux (P_var (P_aux (P_id id,_), TP_aux (TP_var kid,_)),ann) -> + KBindings.add kid id map, (P_aux (P_id id,ann), typ) :: pats + | _ -> map, (pat,typ)::pats) (map,[]) pats + in map, List.rev pats let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) = let (tq,typ) = Env.get_val_spec_orig id (env_of_annot annot) in @@ -1500,38 +1654,70 @@ let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) = | Typ_aux (Typ_fn (arg_typ, ret_typ, eff),_) -> arg_typ, ret_typ, eff | _ -> failwith ("Function " ^ string_of_id id ^ " does not have function type") in + let build_ex, ret_typ = replace_atom_return_type ret_typ in let ids_to_avoid = all_ids pexp in let kids_used = tyvars_of_typquant tq in let pat,guard,exp,(l,_) = destruct_pexp pexp in let pats, bind = untuple_args_pat arg_typ pat in let pats, binds = List.split (Util.list_mapi demote_as_pattern pats) in let eliminated_kids, kid_to_arg_rename = merge_kids_atoms pats in + let kid_to_arg_rename, pats = merge_var_patterns kid_to_arg_rename pats in let kids_used = KidSet.diff kids_used eliminated_kids in let ctxt = { early_ret = contains_early_return exp; kid_renames = mk_kid_renames ids_to_avoid kids_used; kid_id_renames = kid_to_arg_rename; - bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ) } in + bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ); + build_ex_return = effectful eff && build_ex; + } in (* Put the constraints after pattern matching so that any type variable that's been replaced by one of the term-level arguments is bound. *) let quantspp, constrspp = doc_typquant_items_separate ctxt braces tq in let exp = List.fold_left (fun body f -> f body) (bind exp) binds in - let patspp = separate_map space (doc_binder ctxt) pats in - let atom_constr_pp = separate_opt space (atom_constraint ctxt) pats in + let used_a_pattern = ref false in + let doc_binder (P_aux (p,ann) as pat, typ) = + let env = env_of_annot ann in + let exp_typ = Env.expand_synonyms env typ in + match p with + | P_id id + | P_typ (_,P_aux (P_id id,_)) when Util.is_none (is_auto_decomposed_exist env exp_typ) -> + parens (separate space [doc_id id; colon; doc_typ ctxt typ]) + | _ -> + (used_a_pattern := true; + squote ^^ parens (separate space [doc_pat ctxt true (pat, exp_typ); colon; doc_typ ctxt typ])) + in + let patspp = separate_map space doc_binder pats in + let atom_constrs = Util.map_filter (atom_constraint ctxt) pats in + let atom_constr_pp = separate space atom_constrs in let retpp = if effectful eff then string "M" ^^ space ^^ parens (doc_typ ctxt ret_typ) else doc_typ ctxt ret_typ in + let idpp = doc_id id in + (* Work around Coq bug 7975 about pattern binders followed by implicit arguments *) + let implicitargs = + if !used_a_pattern && List.length constrspp + List.length atom_constrs > 0 then + break 1 ^^ separate space + ([string "Arguments"; idpp;] @ + List.map (fun _ -> string "{_}") quantspp @ + List.map (fun _ -> string "_") pats @ + List.map (fun _ -> string "{_}") constrspp @ + List.map (fun _ -> string "{_}") atom_constrs) + ^^ dot + else empty + in let _ = match guard with | None -> () | _ -> raise (Reporting_basic.err_unreachable l "guarded pattern expression should have been rewritten before pretty-printing") in + let bodypp = doc_fun_body ctxt exp in + let bodypp = if effectful eff || not build_ex then bodypp else string "build_ex" ^^ parens bodypp in group (prefix 3 1 - (separate space [doc_id id; quantspp; patspp; constrspp; atom_constr_pp] ^/^ - colon ^^ space ^^ retpp ^^ coloneq) - (doc_fun_body ctxt exp ^^ dot)) + (separate space ([idpp] @ quantspp @ [patspp] @ constrspp @ [atom_constr_pp]) ^/^ + separate space [colon; retpp; coloneq]) + (bodypp ^^ dot)) ^^ implicitargs let get_id = function | [] -> failwith "FD_function with empty list" @@ -1658,9 +1844,16 @@ let doc_axiom_typschm typ_env (TypSchm_aux (TypSchm_ts (tqs,typ),l) as ts) = | _ -> parens (underscore ^^ string " : " ^^ doc_typ empty_ctxt typ) in let arg_typs_pp = separate space (List.map doc_typ' typs) in + let _, ret_ty = replace_atom_return_type ret_ty in let ret_typ_pp = doc_typ empty_ctxt ret_ty in + let ret_typ_pp = + if effectful eff + then string "M" ^^ space ^^ parens ret_typ_pp + else ret_typ_pp + in let tyvars_pp, constrs_pp = doc_typquant_items_separate empty_ctxt braces tqs in - string "forall" ^/^ tyvars_pp ^/^ arg_typs_pp ^/^ constrs_pp ^^ comma ^/^ ret_typ_pp + string "forall" ^/^ separate space tyvars_pp ^/^ + arg_typs_pp ^/^ separate space constrs_pp ^^ comma ^/^ ret_typ_pp | _ -> doc_typschm empty_ctxt true ts let doc_val_spec unimplemented (VS_aux (VS_val_spec(tys,id,_,_),ann)) = @@ -1800,6 +1993,8 @@ try (fun lib -> separate space [string "Require Import";string lib] ^^ dot) defs_modules;hardline; string "Import ListNotations."; hardline; + string "Open Scope string."; hardline; + string "Open Scope bool."; hardline; (* Put the body into a Section so that we can define some values with Let to put them into the local context, where tactics can see them *) string "Section Content."; |
