summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/monomorphise.ml14
-rw-r--r--src/monomorphise.mli2
-rw-r--r--src/pretty_print_lem.ml20
-rw-r--r--src/rewrites.ml2
4 files changed, 22 insertions, 16 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 9efacb7a..4eea717a 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -3479,7 +3479,7 @@ let rec extract (E_aux (e,_)) =
provide some way for the Lem pretty printer to know what to use; currently
we just use two names for the cast, bitvector_cast_in and bitvector_cast_out,
to let the pretty printer know whether to use the top-level environment. *)
-let add_bitvector_casts (Defs defs) =
+let add_bitvector_casts global_env (Defs defs) =
let rewrite_body id quant_kids top_env ret_typ exp =
let rewrite_aux (e,ann) =
match e with
@@ -3568,9 +3568,8 @@ let add_bitvector_casts (Defs defs) =
e_aux = rewrite_aux } exp
in
let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),((l,_) as fcl_ann))) =
- let fcl_env = env_of_annot fcl_ann in
- let (tq,typ) = Env.get_val_spec_orig id fcl_env in
- let fun_env = add_typquant l tq fcl_env in
+ let (tq,typ) = Env.get_val_spec_orig id global_env in
+ let fun_env = List.fold_right (Env.add_typ_var l) (quant_kopts tq) global_env in
let quant_kids = List.map kopt_kid (List.filter is_int_kopt (quant_kopts tq)) in
let ret_typ =
match typ with
@@ -3582,6 +3581,13 @@ let add_bitvector_casts (Defs defs) =
in
let pat,guard,body,annot = destruct_pexp pexp in
let body = rewrite_body id quant_kids fun_env ret_typ body in
+ (* Cast function arguments, if necessary *)
+ let add_constraint insts = function
+ | NC_aux (NC_equal (Nexp_aux (Nexp_var kid,_), nexp), _) -> KBindings.add kid nexp insts
+ | _ -> insts
+ in
+ let insts = List.fold_left add_constraint KBindings.empty (Env.get_constraints (env_of body)) in
+ let body = make_bitvector_env_casts fun_env (env_of body) quant_kids insts body in
(* Also add a cast around the entire function clause body, if necessary *)
let body =
make_bitvector_cast_exp "bitvector_cast_out" fun_env quant_kids (fill_in_type (env_of body) (typ_of body)) ret_typ body
diff --git a/src/monomorphise.mli b/src/monomorphise.mli
index 39d89461..b4eb7ead 100644
--- a/src/monomorphise.mli
+++ b/src/monomorphise.mli
@@ -69,7 +69,7 @@ val mono_rewrites : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
val rewrite_toplevel_nexps : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
(* Add casts across case splits *)
-val add_bitvector_casts : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
+val add_bitvector_casts : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
(* Replace atom arguments which are fixed by a type parameter for a size with a singleton type *)
val rewrite_atoms_to_singletons : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 0994a821..581872ee 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -1367,7 +1367,7 @@ let doc_fun_body_lem ctxt exp =
then align (string "catch_early_return" ^//^ parens (doc_exp))
else doc_exp
-let doc_funcl_lem (FCL_aux(FCL_Funcl(id, pexp), annot)) =
+let doc_funcl_lem type_env (FCL_aux(FCL_Funcl(id, pexp), annot)) =
let typ = typ_of_annot annot in
let arg_typs = match typ with
| Typ_aux (Typ_fn (arg_typs, typ_ret, _), _) -> arg_typs
@@ -1377,7 +1377,7 @@ let doc_funcl_lem (FCL_aux(FCL_Funcl(id, pexp), annot)) =
let ctxt =
{ early_ret = contains_early_return exp;
bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ);
- top_env = env_of_annot annot } in
+ top_env = type_env } in
let pats, bind = untuple_args_pat pat arg_typs in
let patspp = separate_map space (doc_pat_lem ctxt true) pats in
let _ = match guard with
@@ -1398,16 +1398,16 @@ module StringSet = Set.Make(String)
(* Strictly speaking, Lem doesn't support multiple clauses for a single function
joined by "and", although it has worked for Isabelle before. However, all
the funcls should have been merged by the merge_funcls rewrite now. *)
-let doc_fundef_rhs_lem (FD_aux(FD_function(r, typa, efa, funcls),fannot) as fd) =
- separate_map (hardline ^^ string "and ") doc_funcl_lem funcls
+let doc_fundef_rhs_lem type_env (FD_aux(FD_function(r, typa, efa, funcls),fannot) as fd) =
+ separate_map (hardline ^^ string "and ") (doc_funcl_lem type_env) funcls
-let doc_mutrec_lem = function
+let doc_mutrec_lem type_env = function
| [] -> failwith "DEF_internal_mutrec with empty function list"
| fundefs ->
string "let rec " ^^
- separate_map (hardline ^^ string "and ") doc_fundef_rhs_lem fundefs
+ separate_map (hardline ^^ string "and ") (doc_fundef_rhs_lem type_env) fundefs
-let rec doc_fundef_lem (FD_aux(FD_function(r, typa, efa, fcls),fannot) as fd) =
+let rec doc_fundef_lem type_env (FD_aux(FD_function(r, typa, efa, fcls),fannot) as fd) =
match fcls with
| [] -> failwith "FD_function with empty function list"
| FCL_aux (FCL_Funcl(id, pexp),annot) :: _
@@ -1422,7 +1422,7 @@ let rec doc_fundef_lem (FD_aux(FD_function(r, typa, efa, fcls),fannot) as fd) =
pexp
in
let doc_rec = if is_funcl_rec then [string "rec"] else [] in
- separate space ([string "let"] @ doc_rec @ [doc_fundef_rhs_lem fd])
+ separate space ([string "let"] @ doc_rec @ [doc_fundef_rhs_lem type_env fd])
| _ -> empty
@@ -1523,8 +1523,8 @@ let rec doc_def_lem type_env def =
| DEF_reg_dec dec -> group (doc_dec_lem dec)
| DEF_default df -> empty
- | DEF_fundef fdef -> group (doc_fundef_lem fdef) ^/^ hardline
- | DEF_internal_mutrec fundefs -> doc_mutrec_lem fundefs ^/^ hardline
+ | DEF_fundef fdef -> group (doc_fundef_lem type_env fdef) ^/^ hardline
+ | DEF_internal_mutrec fundefs -> doc_mutrec_lem type_env fundefs ^/^ hardline
| DEF_val (LB_aux (LB_val (pat, _), _) as lbind) ->
group (doc_let_lem empty_ctxt lbind) ^/^ hardline
| DEF_scattered sdef -> failwith "doc_def_lem: shoulnd't have DEF_scattered at this point"
diff --git a/src/rewrites.ml b/src/rewrites.ml
index abb3e4ed..a3238cce 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -4873,7 +4873,7 @@ let all_rewrites = [
("toplevel_nexps", Basic_rewriter rewrite_toplevel_nexps);
("monomorphise", String_rewriter (fun target -> Basic_rewriter (monomorphise target)));
("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons));
- ("add_bitvector_casts", Basic_rewriter (fun _ -> Monomorphise.add_bitvector_casts));
+ ("add_bitvector_casts", Basic_rewriter Monomorphise.add_bitvector_casts);
("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons));
("remove_impossible_int_cases", Basic_rewriter Constant_propagation.remove_impossible_int_cases);
("const_prop_mutrec", String_rewriter (fun target -> Basic_rewriter (Constant_propagation_mutrec.rewrite_defs target)));