summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/monomorphise.ml65
-rw-r--r--src/pretty_print_lem.ml35
2 files changed, 71 insertions, 29 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index e9e2c6b6..ab6d9e2d 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -3912,7 +3912,7 @@ let simplify_size_nexp env quant_kids (Nexp_aux (_,l) as nexp) =
(* These functions add cast functions across case splits, so that when a
bitvector size becomes known in sail, the generated Lem code contains a
function call to change mword 'n to (say) mword ty16, and vice versa. *)
-let make_bitvector_cast_fns env quant_kids src_typ target_typ =
+let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ =
let genunk = Generated Unknown in
let fresh =
let counter = ref 0 in
@@ -3945,7 +3945,7 @@ let make_bitvector_cast_fns env quant_kids src_typ target_typ =
P_aux (P_id var,(Generated src_l,src_ann)),
E_aux
(E_cast (tar_typ',
- E_aux (E_app (Id_aux (Id "bitvector_cast", genunk),
+ E_aux (E_app (Id_aux (Id cast_name, genunk),
[E_aux (E_id var, (genunk, src_ann))]), (genunk, tar_ann))),
(genunk, tar_ann))
| _ ->
@@ -3971,12 +3971,12 @@ let make_bitvector_cast_fns env quant_kids src_typ target_typ =
(fun var exp ->
let exp_ann = mk_tannot env (typ_of exp) (effect_of exp) in
E_aux (E_let (LB_aux (LB_val (P_aux (P_typ (one_target_typ, P_aux (P_id var,(genunk,tar_ann))),(genunk,tar_ann)),
- E_aux (E_app (Id_aux (Id "bitvector_cast",genunk),
+ E_aux (E_app (Id_aux (Id cast_name,genunk),
[E_aux (E_id var,(genunk,src_ann))]),(genunk,tar_ann))),(genunk,tar_ann)),
exp),(genunk,exp_ann))),
(fun (E_aux (_,(exp_l,exp_ann)) as exp) ->
E_aux (E_cast (one_target_typ,
- E_aux (E_app (Id_aux (Id "bitvector_cast", genunk), [exp]), (Generated exp_l,tar_ann))),
+ E_aux (E_app (Id_aux (Id cast_name, genunk), [exp]), (Generated exp_l,tar_ann))),
(Generated exp_l,tar_ann)))
| _ ->
(fun var exp ->
@@ -3991,12 +3991,12 @@ let make_bitvector_cast_fns env quant_kids src_typ target_typ =
(* TODO: bound vars *)
let make_bitvector_env_casts env quant_kids (kid,i) exp =
- let mk_cast var typ exp = (fst (make_bitvector_cast_fns env quant_kids typ (subst_src_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in
+ let mk_cast var typ exp = (fst (make_bitvector_cast_fns "bitvector_cast_in" env quant_kids typ (subst_src_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in
let locals = Env.get_locals env in
Bindings.fold (fun var (mut,typ) exp ->
if mut = Immutable then mk_cast var typ exp else exp) locals exp
-let make_bitvector_cast_exp env quant_kids typ target_typ exp = (snd (make_bitvector_cast_fns env quant_kids typ target_typ)) exp
+let make_bitvector_cast_exp cast_name env quant_kids typ target_typ exp = (snd (make_bitvector_cast_fns cast_name env quant_kids typ target_typ)) exp
let rec extract_value_from_guard var (E_aux (e,_)) =
match e with
@@ -4023,6 +4023,12 @@ let fill_in_type env typ =
subst_src_typ subst typ
(* TODO: top-level patterns *)
+(* TODO: proper environment tracking for variables. Currently we pretend that
+ we can print the type of a variable in the top-level environment, but in
+ practice they might be below a case split. Note that we'd also need to
+ provide some way for the Lem pretty printer to know what to use; currently
+ we just use two names for the cast, bitvector_cast_in and bitvector_cast_out,
+ to let the pretty printer know whether to use the top-level environment. *)
let add_bitvector_casts (Defs defs) =
let rewrite_body id quant_kids top_env ret_typ exp =
let rewrite_aux (e,ann) =
@@ -4039,13 +4045,13 @@ let add_bitvector_casts (Defs defs) =
let body = match pat, guard with
| P_aux (P_lit (L_aux (L_num i,_)),_), _ ->
let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
- make_bitvector_cast_exp env quant_kids src_typ result_typ
+ make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ
(make_bitvector_env_casts env quant_kids (kid,i) body)
| P_aux (P_id var,_), Some guard ->
(match extract_value_from_guard var guard with
| Some i ->
let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
- make_bitvector_cast_exp env quant_kids src_typ result_typ
+ make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ
(make_bitvector_env_casts env quant_kids (kid,i) body)
| None -> body)
| _ ->
@@ -4056,15 +4062,46 @@ let add_bitvector_casts (Defs defs) =
E_aux (E_case (exp', List.map map_case cases),ann)
| _ -> E_aux (e,ann)
end
+ | E_if (e1,e2,e3) ->
+ let env = env_of_annot ann in
+ let result_typ = Env.base_typ_of env (typ_of_annot ann) in
+ let rec extract (E_aux (e,_)) =
+ match e with
+ | E_app (op,
+ ([E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_); y] |
+ [y; E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_)]))
+ when string_of_id op = "eq_atom" ->
+ (match destruct_atom_nexp (env_of y) (typ_of y) with
+ | Some (Nexp_aux (Nexp_constant i,_)) -> [(kid,i)]
+ | _ -> [])
+ | E_app (op, [x;y]) when string_of_id op = "and_bool" ->
+ extract x @ extract y
+ | _ -> []
+ in
+ let insts = extract e1 in
+ let e2' = List.fold_left (fun body inst ->
+ make_bitvector_env_casts env quant_kids inst body) e2 insts in
+ let insts = List.fold_left (fun insts (kid,i) ->
+ KBindings.add kid (nconstant i) insts) KBindings.empty insts in
+ let src_typ = subst_src_typ insts result_typ in
+ let e2' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ e2' in
+ E_aux (E_if (e1,e2',e3), ann)
| E_return e' ->
- E_aux (E_return (make_bitvector_cast_exp top_env quant_kids (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann)
- (* TODO: (env_of_annot ann) isn't suitable, because it contains
- constraints revealing the case splits involved; needs a more
- subtle approach *)
+ E_aux (E_return (make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann)
| E_assign (LEXP_aux (lexp,lexp_annot),e') ->
E_aux (E_assign (LEXP_aux (lexp,lexp_annot),
- make_bitvector_cast_exp (env_of_annot ann) quant_kids (fill_in_type (env_of e') (typ_of e'))
+ make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e'))
(typ_of_annot lexp_annot) e'),ann)
+ | E_id id -> begin
+ let env = env_of_annot ann in
+ match Env.lookup_id id env with
+ | Local (Mutable, vtyp) ->
+ make_bitvector_cast_exp "bitvector_cast_in" top_env quant_kids
+ (fill_in_type (env_of_annot ann) (typ_of_annot ann))
+ vtyp
+ (E_aux (e,ann))
+ | _ -> E_aux (e,ann)
+ end
| _ -> E_aux (e,ann)
in
let open Rewriter in
@@ -4089,7 +4126,7 @@ let add_bitvector_casts (Defs defs) =
let body = rewrite_body id quant_kids body_env ret_typ body in
(* Also add a cast around the entire function clause body, if necessary *)
let body =
- make_bitvector_cast_exp fcl_env quant_kids (fill_in_type body_env (typ_of body)) ret_typ body
+ make_bitvector_cast_exp "bitvector_cast_out" fcl_env quant_kids (fill_in_type body_env (typ_of body)) ret_typ body
in
let pexp = construct_pexp (pat,guard,body,annot) in
FCL_aux (FCL_Funcl (id,pexp),fcl_ann)
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 75284418..bef54f05 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -65,8 +65,9 @@ let opt_mwords = ref false
type context = {
early_ret : bool;
bound_nexps : NexpSet.t;
+ top_env : Env.t
}
-let empty_ctxt = { early_ret = false; bound_nexps = NexpSet.empty }
+let empty_ctxt = { early_ret = false; bound_nexps = NexpSet.empty; top_env = Env.empty }
let print_to_from_interp_value = ref false
let langlebar = string "<|"
@@ -328,10 +329,9 @@ let doc_typ_lem, doc_atomic_typ_lem =
| Typ_arg_order o -> empty
in typ', atomic_typ
-(* Check for variables in types that would be pretty-printed and are not
- bound in the val spec of the function. *)
+(* Check for variables in types that would be pretty-printed. *)
let contains_t_pp_var ctxt (Typ_aux (t,a) as typ) =
- NexpSet.diff (lem_nexps_of_typ typ) ctxt.bound_nexps
+ lem_nexps_of_typ typ
|> NexpSet.exists (fun nexp -> not (is_nexp_constant nexp))
let replace_typ_size ctxt env (Typ_aux (t,a)) =
@@ -341,14 +341,14 @@ let replace_typ_size ctxt env (Typ_aux (t,a)) =
let mk_typ nexp =
Some (Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp nexp,Parse_ast.Unknown);ord;typ']),a))
in
- let is_equal nexp =
- prove env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown))
- in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with
- | nexp -> mk_typ nexp
- | exception Not_found ->
- match Type_check.solve env size with
- | Some n -> mk_typ (nconstant n)
- | None -> None
+ match Type_check.solve env size with
+ | Some n -> mk_typ (nconstant n)
+ | None ->
+ let is_equal nexp =
+ prove env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown))
+ in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with
+ | nexp -> mk_typ nexp
+ | exception Not_found -> None
end
| _ -> None
@@ -760,8 +760,12 @@ let doc_exp_lem, doc_let_lem =
let env = env_of full_exp in
let t = Env.expand_synonyms env (typ_of full_exp) in
let eff = effect_of full_exp in
- if typ_needs_printed t
- then (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env (effectful eff) t))), true)
+ if typ_needs_printed t then
+ if Id.compare f (mk_id "bitvector_cast_out") <> 0
+ then (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env (effectful eff) t))), true)
+ (* TODO: coordinate with the code in monomorphise.ml to find the correct
+ typing environment to use *)
+ else (align (group (prefix 0 1 epp (doc_tannot_lem ctxt ctxt.top_env (effectful eff) t))), true)
else (epp, aexp_needed) in
liftR (if aexp_needed then parens (align taepp) else taepp)
end
@@ -1255,7 +1259,8 @@ let doc_funcl_lem (FCL_aux(FCL_Funcl(id, pexp), annot)) =
let pat,guard,exp,(l,_) = destruct_pexp pexp in
let ctxt =
{ early_ret = contains_early_return exp;
- bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ) } in
+ bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ);
+ top_env = env_of_annot annot } in
let pats, bind = untuple_args_pat pat in
let patspp = separate_map space (doc_pat_lem ctxt true) pats in
let _ = match guard with