summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBrian Campbell2018-02-08 15:03:34 +0000
committerBrian Campbell2018-02-08 15:03:42 +0000
commit579cd897d7873436ba6cfb3469b185bf6b321dac (patch)
tree1ca652f246135f2ad4973acbb5051a2c79e0d411 /src
parent45519ae89ceef4c838cdd52e2bbaa4174e63f27d (diff)
Add (most of) the bitvector cast insertion transformation
to help Lem go from a general type `bits('n)` to a specific type `bits(16)` at a case split, and the other way around for a returned value. Doesn't handle function clause patterns yet
Diffstat (limited to 'src')
-rw-r--r--src/monomorphise.ml177
1 files changed, 177 insertions, 0 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 2b4821e0..5cf5181d 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -3404,6 +3404,182 @@ let mono_rewrite defs =
defs
end
+module BitvectorSizeCasts =
+struct
+
+(* These functions add cast functions across case splits, so that when a
+ bitvector size becomes known in sail, the generated Lem code contains a
+ function call to change mword 'n to (say) mword ty16, and vice versa. *)
+let make_bitvector_cast_fns env src_typ target_typ =
+ let genunk = Generated Unknown in
+ let fresh =
+ let counter = ref 0 in
+ fun () ->
+ let n = !counter in
+ let () = counter := n+1 in
+ mk_id ("cast#" ^ string_of_int n)
+ in
+ let required = ref false in
+ let rec aux (Typ_aux (src_t,src_l) as src_typ) (Typ_aux (tar_t,tar_l) as tar_typ) =
+ let src_ann = Some (env,src_typ,no_effect) in
+ let tar_ann = Some (env,tar_typ,no_effect) in
+ match src_t, tar_t with
+ | Typ_tup typs, Typ_tup typs' ->
+ let ps,es = List.split (List.map2 aux typs typs') in
+ P_aux (P_tup ps,(Generated src_l, src_ann)),
+ E_aux (E_tuple es,(Generated tar_l, tar_ann))
+ | Typ_app (Id_aux (Id "vector",_),
+ [Typ_arg_aux (Typ_arg_nexp size,_); _;
+ Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]),
+ Typ_app (Id_aux (Id "vector",_),
+ [Typ_arg_aux (Typ_arg_nexp size',_); _;
+ Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)])
+ when Nexp.compare size size' <> 0 ->
+ let () = required := true in
+ let var = fresh () in
+ P_aux (P_id var,(Generated src_l,src_ann)),
+ E_aux
+ (E_cast (tar_typ,
+ E_aux (E_app (Id_aux (Id "bitvector_cast", genunk),
+ [E_aux (E_id var, (genunk, src_ann))]), (genunk, tar_ann))),
+ (genunk, tar_ann))
+ | _ ->
+ let var = fresh () in
+ P_aux (P_id var,(Generated src_l,src_ann)),
+ E_aux (E_id var,(Generated src_l,tar_ann))
+ in
+ let src_typ' = Env.base_typ_of env src_typ in
+ let target_typ' = Env.base_typ_of env target_typ in
+ let pat, e' = aux src_typ' target_typ' in
+ if !required
+ then
+ let src_ann = Some (env,src_typ,no_effect) in
+ let tar_ann = Some (env,target_typ,no_effect) in
+ match src_typ' with
+ (* Simple case with just the bitvector; don't need to pull apart value *)
+ | Typ_aux (Typ_app _,_) ->
+ (fun var exp ->
+ let exp_ann = Some (env,typ_of exp,effect_of exp) in
+ E_aux (E_let (LB_aux (LB_val (P_aux (P_typ (target_typ, P_aux (P_id var,(genunk,tar_ann))),(genunk,tar_ann)),
+ E_aux (E_app (Id_aux (Id "bitvector_cast",genunk),
+ [E_aux (E_id var,(genunk,src_ann))]),(genunk,tar_ann))),(genunk,tar_ann)),
+ exp),(genunk,exp_ann))),
+ (fun (E_aux (_,(exp_l,exp_ann)) as exp) ->
+ E_aux (E_cast (target_typ,
+ E_aux (E_app (Id_aux (Id "bitvector_cast", genunk), [exp]), (Generated exp_l,tar_ann))),
+ (Generated exp_l,tar_ann)))
+ | _ ->
+ (fun var exp ->
+ let exp_ann = Some (env,typ_of exp,effect_of exp) in
+ E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id var,(genunk,src_ann))),(genunk,src_ann)),
+ E_aux (E_let (LB_aux (LB_val (P_aux (P_id var,(genunk,tar_ann)),e'),(genunk,tar_ann)),
+ exp),(genunk,exp_ann))),(genunk,exp_ann))),
+ (fun (E_aux (_,(exp_l,exp_ann)) as exp) ->
+ E_aux (E_let (LB_aux (LB_val (pat, exp),(Generated exp_l,exp_ann)), e'),(Generated exp_l,tar_ann)))
+ else (fun _ e -> e),(fun e -> e)
+
+(* TODO: bound vars *)
+let make_bitvector_env_casts env (kid,i) exp =
+ let mk_cast var typ exp = (fst (make_bitvector_cast_fns env typ (subst_src_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in
+ let locals = Env.get_locals env in
+ Bindings.fold (fun var (mut,typ) exp ->
+ if mut = Immutable then mk_cast var typ exp else exp) locals exp
+
+let make_bitvector_cast_exp typ target_typ exp = (snd (make_bitvector_cast_fns (env_of exp) typ target_typ)) exp
+
+let rec extract_value_from_guard var (E_aux (e,_)) =
+ match e with
+ | E_app (op, ([E_aux (E_id var',_); E_aux (E_lit (L_aux (L_num i,_)),_)] |
+ [E_aux (E_lit (L_aux (L_num i,_)),_); E_aux (E_id var',_)]))
+ when string_of_id op = "eq_atom" && Id.compare var var' == 0 ->
+ Some i
+ | E_app (op, [e1;e2]) when string_of_id op = "and_bool" ->
+ (match extract_value_from_guard var e1 with
+ | Some i -> Some i
+ | None -> extract_value_from_guard var e2)
+ | _ -> None
+
+let rec flatten_constraints = function
+ | [] -> []
+ | (NC_aux (NC_and (nc1,nc2),_))::t -> flatten_constraints (nc1::nc2::t)
+ | h::t -> h::(flatten_constraints t)
+
+let fill_in_type env typ =
+ let constraints = Env.get_constraints env in
+ let constraints = flatten_constraints constraints in
+ let subst = Util.map_filter (function
+ | NC_aux (NC_equal (Nexp_aux (Nexp_var var,_), (Nexp_aux (Nexp_constant _,_) as i)),_) -> Some (var,i)
+ | NC_aux (NC_equal (Nexp_aux (Nexp_constant _,_) as i, Nexp_aux (Nexp_var var,_)),_) -> Some (var,i)
+ | _ -> None) constraints in
+ subst_src_typ (kbindings_from_list subst) typ
+
+(* TODO: top-level patterns *)
+let add_bitvector_casts (Defs defs) =
+ let rewrite_body ret_typ exp =
+ let rewrite_aux (e,ann) =
+ match e with
+ | E_case (E_aux (e',ann') as exp',cases) -> begin
+ let env = env_of_annot ann in
+ let result_typ = Env.base_typ_of env (typ_of_annot ann) in
+ let matched_typ = Env.base_typ_of env (typ_of_annot ann') in
+ match e',matched_typ with
+ | E_sizeof (Nexp_aux (Nexp_var kid,_)), _
+ | _, Typ_aux (Typ_app (Id_aux (Id "atom",_), [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_) ->
+ let map_case pexp =
+ let pat,guard,body,ann = destruct_pexp pexp in
+ let body = match pat, guard with
+ | P_aux (P_lit (L_aux (L_num i,_)),_), _ ->
+ let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
+ make_bitvector_cast_exp src_typ result_typ
+ (make_bitvector_env_casts env (kid,i) body)
+ | P_aux (P_id var,_), Some guard ->
+ (match extract_value_from_guard var guard with
+ | Some i ->
+ let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
+ make_bitvector_cast_exp src_typ result_typ
+ (make_bitvector_env_casts env (kid,i) body)
+ | None -> body)
+ | _ -> body
+ in
+ construct_pexp (pat, guard, body, ann)
+ in
+ E_aux (E_case (exp', List.map map_case cases),ann)
+ | _ -> E_aux (e,ann)
+ end
+ | E_return e' ->
+ E_aux (E_return (make_bitvector_cast_exp (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann)
+ | E_assign (LEXP_aux (lexp,lexp_annot),e') ->
+ E_aux (E_assign (LEXP_aux (lexp,lexp_annot),
+ make_bitvector_cast_exp (fill_in_type (env_of e') (typ_of e'))
+ (typ_of_annot lexp_annot) e'),ann)
+ | _ -> E_aux (e,ann)
+ in
+ let open Rewriter in
+ fold_exp
+ { id_exp_alg with
+ e_aux = rewrite_aux } exp
+ in
+ let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),fcl_ann)) =
+ let ret_typ =
+ match typ_of_annot fcl_ann with
+ | Typ_aux (Typ_fn (_,ret,_),_) -> ret
+ | Typ_aux (_,l) as typ ->
+ raise (Reporting_basic.err_unreachable l
+ ("Function clause must have function type: " ^ string_of_typ typ ^
+ " is not a function type"))
+ in
+ let pat,guard,body,annot = destruct_pexp pexp in
+ let body = rewrite_body ret_typ body in
+ let pexp = construct_pexp (pat,guard,body,annot) in
+ FCL_aux (FCL_Funcl (id,pexp),fcl_ann)
+ in
+ let rewrite_def = function
+ | DEF_fundef (FD_aux (FD_function (r,t,e,fcls),fd_ann)) ->
+ DEF_fundef (FD_aux (FD_function (r,t,e,List.map rewrite_funcl fcls),fd_ann))
+ | d -> d
+ in Defs (List.map rewrite_def defs)
+end
+
type options = {
auto : bool;
debug_analysis : int;
@@ -3450,6 +3626,7 @@ let monomorphise opts splits env defs =
then ()
else raise (Reporting_basic.err_general Unknown "Unable to monomorphise program")
in
+ let defs = BitvectorSizeCasts.add_bitvector_casts defs in
(* TODO: currently doing this because constant propagation leaves numeric literals as
int, try to avoid this later; also use final env for DEF_spec case above, because the
type checker doesn't store the env at that point :( *)