summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBrian Campbell2018-08-07 11:02:09 +0100
committerBrian Campbell2018-08-07 11:14:20 +0100
commitf9282ab5dec29d7ec99d473d013d32b41a0b8dbc (patch)
tree97908cf43a4eb858a6be57e7b370eba503b6cf03 /src
parent6538b63c944e32692447423829f8f4e91428b473 (diff)
Improve cast introduction for Lem
Handles mutable variables and conditionals (there are still some corner cases that don't appear in Aarch64 to do). The pretty printer is now back to preferring to use concrete types, but has a special case for casts to print more general types.
Diffstat (limited to 'src')
-rw-r--r--src/monomorphise.ml65
-rw-r--r--src/pretty_print_lem.ml35
2 files changed, 71 insertions, 29 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index e9e2c6b6..ab6d9e2d 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -3912,7 +3912,7 @@ let simplify_size_nexp env quant_kids (Nexp_aux (_,l) as nexp) =
(* These functions add cast functions across case splits, so that when a
bitvector size becomes known in sail, the generated Lem code contains a
function call to change mword 'n to (say) mword ty16, and vice versa. *)
-let make_bitvector_cast_fns env quant_kids src_typ target_typ =
+let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ =
let genunk = Generated Unknown in
let fresh =
let counter = ref 0 in
@@ -3945,7 +3945,7 @@ let make_bitvector_cast_fns env quant_kids src_typ target_typ =
P_aux (P_id var,(Generated src_l,src_ann)),
E_aux
(E_cast (tar_typ',
- E_aux (E_app (Id_aux (Id "bitvector_cast", genunk),
+ E_aux (E_app (Id_aux (Id cast_name, genunk),
[E_aux (E_id var, (genunk, src_ann))]), (genunk, tar_ann))),
(genunk, tar_ann))
| _ ->
@@ -3971,12 +3971,12 @@ let make_bitvector_cast_fns env quant_kids src_typ target_typ =
(fun var exp ->
let exp_ann = mk_tannot env (typ_of exp) (effect_of exp) in
E_aux (E_let (LB_aux (LB_val (P_aux (P_typ (one_target_typ, P_aux (P_id var,(genunk,tar_ann))),(genunk,tar_ann)),
- E_aux (E_app (Id_aux (Id "bitvector_cast",genunk),
+ E_aux (E_app (Id_aux (Id cast_name,genunk),
[E_aux (E_id var,(genunk,src_ann))]),(genunk,tar_ann))),(genunk,tar_ann)),
exp),(genunk,exp_ann))),
(fun (E_aux (_,(exp_l,exp_ann)) as exp) ->
E_aux (E_cast (one_target_typ,
- E_aux (E_app (Id_aux (Id "bitvector_cast", genunk), [exp]), (Generated exp_l,tar_ann))),
+ E_aux (E_app (Id_aux (Id cast_name, genunk), [exp]), (Generated exp_l,tar_ann))),
(Generated exp_l,tar_ann)))
| _ ->
(fun var exp ->
@@ -3991,12 +3991,12 @@ let make_bitvector_cast_fns env quant_kids src_typ target_typ =
(* TODO: bound vars *)
let make_bitvector_env_casts env quant_kids (kid,i) exp =
- let mk_cast var typ exp = (fst (make_bitvector_cast_fns env quant_kids typ (subst_src_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in
+ let mk_cast var typ exp = (fst (make_bitvector_cast_fns "bitvector_cast_in" env quant_kids typ (subst_src_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in
let locals = Env.get_locals env in
Bindings.fold (fun var (mut,typ) exp ->
if mut = Immutable then mk_cast var typ exp else exp) locals exp
-let make_bitvector_cast_exp env quant_kids typ target_typ exp = (snd (make_bitvector_cast_fns env quant_kids typ target_typ)) exp
+let make_bitvector_cast_exp cast_name env quant_kids typ target_typ exp = (snd (make_bitvector_cast_fns cast_name env quant_kids typ target_typ)) exp
let rec extract_value_from_guard var (E_aux (e,_)) =
match e with
@@ -4023,6 +4023,12 @@ let fill_in_type env typ =
subst_src_typ subst typ
(* TODO: top-level patterns *)
+(* TODO: proper environment tracking for variables. Currently we pretend that
+ we can print the type of a variable in the top-level environment, but in
+ practice they might be below a case split. Note that we'd also need to
+ provide some way for the Lem pretty printer to know what to use; currently
+ we just use two names for the cast, bitvector_cast_in and bitvector_cast_out,
+ to let the pretty printer know whether to use the top-level environment. *)
let add_bitvector_casts (Defs defs) =
let rewrite_body id quant_kids top_env ret_typ exp =
let rewrite_aux (e,ann) =
@@ -4039,13 +4045,13 @@ let add_bitvector_casts (Defs defs) =
let body = match pat, guard with
| P_aux (P_lit (L_aux (L_num i,_)),_), _ ->
let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
- make_bitvector_cast_exp env quant_kids src_typ result_typ
+ make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ
(make_bitvector_env_casts env quant_kids (kid,i) body)
| P_aux (P_id var,_), Some guard ->
(match extract_value_from_guard var guard with
| Some i ->
let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
- make_bitvector_cast_exp env quant_kids src_typ result_typ
+ make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ
(make_bitvector_env_casts env quant_kids (kid,i) body)
| None -> body)
| _ ->
@@ -4056,15 +4062,46 @@ let add_bitvector_casts (Defs defs) =
E_aux (E_case (exp', List.map map_case cases),ann)
| _ -> E_aux (e,ann)
end
+ | E_if (e1,e2,e3) ->
+ let env = env_of_annot ann in
+ let result_typ = Env.base_typ_of env (typ_of_annot ann) in
+ let rec extract (E_aux (e,_)) =
+ match e with
+ | E_app (op,
+ ([E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_); y] |
+ [y; E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_)]))
+ when string_of_id op = "eq_atom" ->
+ (match destruct_atom_nexp (env_of y) (typ_of y) with
+ | Some (Nexp_aux (Nexp_constant i,_)) -> [(kid,i)]
+ | _ -> [])
+ | E_app (op, [x;y]) when string_of_id op = "and_bool" ->
+ extract x @ extract y
+ | _ -> []
+ in
+ let insts = extract e1 in
+ let e2' = List.fold_left (fun body inst ->
+ make_bitvector_env_casts env quant_kids inst body) e2 insts in
+ let insts = List.fold_left (fun insts (kid,i) ->
+ KBindings.add kid (nconstant i) insts) KBindings.empty insts in
+ let src_typ = subst_src_typ insts result_typ in
+ let e2' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ e2' in
+ E_aux (E_if (e1,e2',e3), ann)
| E_return e' ->
- E_aux (E_return (make_bitvector_cast_exp top_env quant_kids (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann)
- (* TODO: (env_of_annot ann) isn't suitable, because it contains
- constraints revealing the case splits involved; needs a more
- subtle approach *)
+ E_aux (E_return (make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann)
| E_assign (LEXP_aux (lexp,lexp_annot),e') ->
E_aux (E_assign (LEXP_aux (lexp,lexp_annot),
- make_bitvector_cast_exp (env_of_annot ann) quant_kids (fill_in_type (env_of e') (typ_of e'))
+ make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e'))
(typ_of_annot lexp_annot) e'),ann)
+ | E_id id -> begin
+ let env = env_of_annot ann in
+ match Env.lookup_id id env with
+ | Local (Mutable, vtyp) ->
+ make_bitvector_cast_exp "bitvector_cast_in" top_env quant_kids
+ (fill_in_type (env_of_annot ann) (typ_of_annot ann))
+ vtyp
+ (E_aux (e,ann))
+ | _ -> E_aux (e,ann)
+ end
| _ -> E_aux (e,ann)
in
let open Rewriter in
@@ -4089,7 +4126,7 @@ let add_bitvector_casts (Defs defs) =
let body = rewrite_body id quant_kids body_env ret_typ body in
(* Also add a cast around the entire function clause body, if necessary *)
let body =
- make_bitvector_cast_exp fcl_env quant_kids (fill_in_type body_env (typ_of body)) ret_typ body
+ make_bitvector_cast_exp "bitvector_cast_out" fcl_env quant_kids (fill_in_type body_env (typ_of body)) ret_typ body
in
let pexp = construct_pexp (pat,guard,body,annot) in
FCL_aux (FCL_Funcl (id,pexp),fcl_ann)
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 75284418..bef54f05 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -65,8 +65,9 @@ let opt_mwords = ref false
type context = {
early_ret : bool;
bound_nexps : NexpSet.t;
+ top_env : Env.t
}
-let empty_ctxt = { early_ret = false; bound_nexps = NexpSet.empty }
+let empty_ctxt = { early_ret = false; bound_nexps = NexpSet.empty; top_env = Env.empty }
let print_to_from_interp_value = ref false
let langlebar = string "<|"
@@ -328,10 +329,9 @@ let doc_typ_lem, doc_atomic_typ_lem =
| Typ_arg_order o -> empty
in typ', atomic_typ
-(* Check for variables in types that would be pretty-printed and are not
- bound in the val spec of the function. *)
+(* Check for variables in types that would be pretty-printed. *)
let contains_t_pp_var ctxt (Typ_aux (t,a) as typ) =
- NexpSet.diff (lem_nexps_of_typ typ) ctxt.bound_nexps
+ lem_nexps_of_typ typ
|> NexpSet.exists (fun nexp -> not (is_nexp_constant nexp))
let replace_typ_size ctxt env (Typ_aux (t,a)) =
@@ -341,14 +341,14 @@ let replace_typ_size ctxt env (Typ_aux (t,a)) =
let mk_typ nexp =
Some (Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp nexp,Parse_ast.Unknown);ord;typ']),a))
in
- let is_equal nexp =
- prove env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown))
- in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with
- | nexp -> mk_typ nexp
- | exception Not_found ->
- match Type_check.solve env size with
- | Some n -> mk_typ (nconstant n)
- | None -> None
+ match Type_check.solve env size with
+ | Some n -> mk_typ (nconstant n)
+ | None ->
+ let is_equal nexp =
+ prove env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown))
+ in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with
+ | nexp -> mk_typ nexp
+ | exception Not_found -> None
end
| _ -> None
@@ -760,8 +760,12 @@ let doc_exp_lem, doc_let_lem =
let env = env_of full_exp in
let t = Env.expand_synonyms env (typ_of full_exp) in
let eff = effect_of full_exp in
- if typ_needs_printed t
- then (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env (effectful eff) t))), true)
+ if typ_needs_printed t then
+ if Id.compare f (mk_id "bitvector_cast_out") <> 0
+ then (align (group (prefix 0 1 epp (doc_tannot_lem ctxt env (effectful eff) t))), true)
+ (* TODO: coordinate with the code in monomorphise.ml to find the correct
+ typing environment to use *)
+ else (align (group (prefix 0 1 epp (doc_tannot_lem ctxt ctxt.top_env (effectful eff) t))), true)
else (epp, aexp_needed) in
liftR (if aexp_needed then parens (align taepp) else taepp)
end
@@ -1255,7 +1259,8 @@ let doc_funcl_lem (FCL_aux(FCL_Funcl(id, pexp), annot)) =
let pat,guard,exp,(l,_) = destruct_pexp pexp in
let ctxt =
{ early_ret = contains_early_return exp;
- bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ) } in
+ bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ);
+ top_env = env_of_annot annot } in
let pats, bind = untuple_args_pat pat in
let patspp = separate_map space (doc_pat_lem ctxt true) pats in
let _ = match guard with