From 1c1a121ae0434e5dc6cb05bbafa6e8c2fa3cbf35 Mon Sep 17 00:00:00 2001 From: Brian Campbell Date: Mon, 25 Jun 2018 15:44:45 +0100 Subject: Coq: automatic cast introduction --- lib/coq/Sail2_operators_mwords.v | 3 ++ lib/vector_dec.sail | 2 +- lib/vector_inc.sail | 2 +- src/pretty_print_coq.ml | 60 ++++++++++++++++++++++++++++++++-------- 4 files changed, 53 insertions(+), 14 deletions(-) diff --git a/lib/coq/Sail2_operators_mwords.v b/lib/coq/Sail2_operators_mwords.v index 25a643e7..ee98c94e 100644 --- a/lib/coq/Sail2_operators_mwords.v +++ b/lib/coq/Sail2_operators_mwords.v @@ -27,6 +27,9 @@ Qed. Definition autocast {m n} (x : mword m) `{H:ArithFact (m = n)} : mword n := cast_mword x (use_ArithFact H). +Definition autocast_m {rv e m n} (x : monad rv (mword m) e) `{H:ArithFact (m = n)} : monad rv (mword n) e := + x >>= fun x => returnm (cast_mword x (use_ArithFact H)). + Definition cast_word {m n} (x : Word.word m) (eq : m = n) : Word.word n. rewrite <- eq. exact x. diff --git a/lib/vector_dec.sail b/lib/vector_dec.sail index 1d528cf6..86bbe601 100644 --- a/lib/vector_dec.sail +++ b/lib/vector_dec.sail @@ -112,7 +112,7 @@ val vector_subrange = { c: "vector_subrange", coq: "subrange_vec_dec" } : forall ('n : Int) ('m : Int) ('o : Int), 0 <= 'o <= 'm < 'n. - (bits('n), atom('m), atom('o)) -> bits('m - ('o - 1)) + (bits('n), atom('m), atom('o)) -> bits('m - 'o + 1) val vector_update_subrange = { ocaml: "update_subrange", diff --git a/lib/vector_inc.sail b/lib/vector_inc.sail index 873d2d33..b295c92c 100644 --- a/lib/vector_inc.sail +++ b/lib/vector_inc.sail @@ -106,7 +106,7 @@ val vector_subrange = { c: "vector_subrange", coq: "subrange_vec_inc" } : forall ('n : Int) ('m : Int) ('o : Int), 0 <= 'm <= 'o < 'n. - (bits('n), atom('m), atom('o)) -> bits('o - ('m - 1)) + (bits('n), atom('m), atom('o)) -> bits('o - 'm + 1) val vector_update_subrange = { ocaml: "update_subrange", diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index 2b328ecb..5a07cb1b 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -687,6 +687,27 @@ let typ_id_of (Typ_aux (typ, l)) = match typ with | Typ_app (id, _) -> id | _ -> raise (Reporting_basic.err_unreachable l "failed to get type id") +(* Decide whether two nexps used in a vector size are similar; if not + a cast will be inserted *) +let similar_nexps n1 n2 = + let rec same_nexp_shape (Nexp_aux (n1,_)) (Nexp_aux (n2,_)) = + match n1, n2 with + | Nexp_id _, Nexp_id _ + | Nexp_var _, Nexp_var _ + -> true + | Nexp_constant c1, Nexp_constant c2 -> Nat_big_num.equal c1 c2 + | Nexp_app (f1,args1), Nexp_app (f2,args2) -> + Id.compare f1 f2 == 0 && List.for_all2 same_nexp_shape args1 args2 + | Nexp_times (n1,n2), Nexp_times (n3,n4) + | Nexp_sum (n1,n2), Nexp_sum (n3,n4) + | Nexp_minus (n1,n2), Nexp_minus (n3,n4) + -> same_nexp_shape n1 n3 && same_nexp_shape n2 n4 + | Nexp_exp n1, Nexp_exp n2 + | Nexp_neg n1, Nexp_neg n2 + -> same_nexp_shape n1 n2 + | _ -> false + in if same_nexp_shape n1 n2 then true else false + let prefix_recordtype = true let report = Reporting_basic.err_unreachable let doc_exp_lem, doc_let_lem = @@ -910,11 +931,11 @@ let doc_exp_lem, doc_let_lem = then string (Env.get_extern f env "coq"), true else doc_id f, false in let (tqs,fn_ty) = Env.get_val_spec_orig f env in - let arg_typs, ret_typ = match fn_ty with - | Typ_aux (Typ_fn (arg_typ,ret_typ,_),_) -> + let arg_typs, ret_typ, eff = match fn_ty with + | Typ_aux (Typ_fn (arg_typ,ret_typ,eff),_) -> (match arg_typ with - | Typ_aux (Typ_tup typs,_) -> typs, ret_typ - | _ -> [arg_typ], ret_typ) + | Typ_aux (Typ_tup typs,_) -> typs, ret_typ, eff + | _ -> [arg_typ], ret_typ, eff) | _ -> raise (Reporting_basic.err_unreachable l "Function not a function type") in (* Insert existential unpacking of arguments where necessary *) @@ -929,19 +950,34 @@ let doc_exp_lem, doc_let_lem = in let epp = hang 2 (flow (break 1) (call :: List.map2 doc_arg args arg_typs)) in (* Unpack existential result *) - let ret_typ_inst = subst_unifiers (instantiation_of full_exp) ret_typ in - let unpack,build_ex = + let inst = instantiation_of full_exp in + let inst = KBindings.fold (fun k u m -> KBindings.add (orig_kid k) u m) inst KBindings.empty in + let ret_typ_inst = subst_unifiers inst ret_typ in + let unpack,build_ex,autocast = let ann_typ = Env.expand_synonyms env (typ_of_annot (l,annot)) in let ann_typ = expand_range_type ann_typ in let ret_typ_inst = expand_range_type (Env.expand_synonyms env ret_typ_inst) in - match ret_typ_inst, ann_typ with - | Typ_aux (Typ_exist _,_), Typ_aux (Typ_exist _,_) -> - if alpha_equivalent env ret_typ_inst ann_typ then false,false else true,true - | Typ_aux (Typ_exist _,_), _ -> true,false - | _, Typ_aux (Typ_exist _,_) -> false,true - | _, _ -> false,false + let unpack, build_ex, in_typ, out_typ = + match ret_typ_inst, ann_typ with + | Typ_aux (Typ_exist (_,_,t1),_), Typ_aux (Typ_exist (_,_,t2),_) -> + if alpha_equivalent env ret_typ_inst ann_typ + then false,false,t1,t2 + else true,true,t1,t2 + | Typ_aux (Typ_exist (_,_,t1),_),t2 -> true,false,t1,t2 + | t1, Typ_aux (Typ_exist (_,_,t2),_) -> false,true,t1,t2 + | t1, t2 -> false,false,t1,t2 + in + let autocast = + match destruct_vector env in_typ, destruct_vector env out_typ with + | Some (n1,_,t1), Some (n2,_,t2) + when is_bit_typ t1 && is_bit_typ t2 -> + not (similar_nexps n1 n2) + | _ -> false + in unpack,build_ex,autocast in + let autocast_id = if effectful eff then "autocast_m" else "autocast" in let epp = if unpack then string "projT1" ^^ space ^^ parens epp else epp in + let epp = if autocast then string autocast_id ^^ space ^^ parens epp else epp in let epp = if build_ex then string "build_ex" ^^ space ^^ parens epp else epp in liftR (if aexp_needed then parens (align epp) else epp) end -- cgit v1.2.3