diff options
| author | Brian Campbell | 2018-02-08 15:03:34 +0000 |
|---|---|---|
| committer | Brian Campbell | 2018-02-08 15:03:42 +0000 |
| commit | 579cd897d7873436ba6cfb3469b185bf6b321dac (patch) | |
| tree | 1ca652f246135f2ad4973acbb5051a2c79e0d411 /src | |
| parent | 45519ae89ceef4c838cdd52e2bbaa4174e63f27d (diff) | |
Add (most of) the bitvector cast insertion transformation
to help Lem go from a general type `bits('n)` to a specific type `bits(16)`
at a case split, and the other way around for a returned value.
Doesn't handle function clause patterns yet
Diffstat (limited to 'src')
| -rw-r--r-- | src/monomorphise.ml | 177 |
1 files changed, 177 insertions, 0 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 2b4821e0..5cf5181d 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -3404,6 +3404,182 @@ let mono_rewrite defs = defs end +module BitvectorSizeCasts = +struct + +(* These functions add cast functions across case splits, so that when a + bitvector size becomes known in sail, the generated Lem code contains a + function call to change mword 'n to (say) mword ty16, and vice versa. *) +let make_bitvector_cast_fns env src_typ target_typ = + let genunk = Generated Unknown in + let fresh = + let counter = ref 0 in + fun () -> + let n = !counter in + let () = counter := n+1 in + mk_id ("cast#" ^ string_of_int n) + in + let required = ref false in + let rec aux (Typ_aux (src_t,src_l) as src_typ) (Typ_aux (tar_t,tar_l) as tar_typ) = + let src_ann = Some (env,src_typ,no_effect) in + let tar_ann = Some (env,tar_typ,no_effect) in + match src_t, tar_t with + | Typ_tup typs, Typ_tup typs' -> + let ps,es = List.split (List.map2 aux typs typs') in + P_aux (P_tup ps,(Generated src_l, src_ann)), + E_aux (E_tuple es,(Generated tar_l, tar_ann)) + | Typ_app (Id_aux (Id "vector",_), + [Typ_arg_aux (Typ_arg_nexp size,_); _; + Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]), + Typ_app (Id_aux (Id "vector",_), + [Typ_arg_aux (Typ_arg_nexp size',_); _; + Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) + when Nexp.compare size size' <> 0 -> + let () = required := true in + let var = fresh () in + P_aux (P_id var,(Generated src_l,src_ann)), + E_aux + (E_cast (tar_typ, + E_aux (E_app (Id_aux (Id "bitvector_cast", genunk), + [E_aux (E_id var, (genunk, src_ann))]), (genunk, tar_ann))), + (genunk, tar_ann)) + | _ -> + let var = fresh () in + P_aux (P_id var,(Generated src_l,src_ann)), + E_aux (E_id var,(Generated src_l,tar_ann)) + in + let src_typ' = Env.base_typ_of env src_typ in + let target_typ' = Env.base_typ_of env target_typ in + let pat, e' = aux src_typ' target_typ' in + if !required + then + let src_ann = Some (env,src_typ,no_effect) in + let tar_ann = Some (env,target_typ,no_effect) in + match src_typ' with + (* Simple case with just the bitvector; don't need to pull apart value *) + | Typ_aux (Typ_app _,_) -> + (fun var exp -> + let exp_ann = Some (env,typ_of exp,effect_of exp) in + E_aux (E_let (LB_aux (LB_val (P_aux (P_typ (target_typ, P_aux (P_id var,(genunk,tar_ann))),(genunk,tar_ann)), + E_aux (E_app (Id_aux (Id "bitvector_cast",genunk), + [E_aux (E_id var,(genunk,src_ann))]),(genunk,tar_ann))),(genunk,tar_ann)), + exp),(genunk,exp_ann))), + (fun (E_aux (_,(exp_l,exp_ann)) as exp) -> + E_aux (E_cast (target_typ, + E_aux (E_app (Id_aux (Id "bitvector_cast", genunk), [exp]), (Generated exp_l,tar_ann))), + (Generated exp_l,tar_ann))) + | _ -> + (fun var exp -> + let exp_ann = Some (env,typ_of exp,effect_of exp) in + E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id var,(genunk,src_ann))),(genunk,src_ann)), + E_aux (E_let (LB_aux (LB_val (P_aux (P_id var,(genunk,tar_ann)),e'),(genunk,tar_ann)), + exp),(genunk,exp_ann))),(genunk,exp_ann))), + (fun (E_aux (_,(exp_l,exp_ann)) as exp) -> + E_aux (E_let (LB_aux (LB_val (pat, exp),(Generated exp_l,exp_ann)), e'),(Generated exp_l,tar_ann))) + else (fun _ e -> e),(fun e -> e) + +(* TODO: bound vars *) +let make_bitvector_env_casts env (kid,i) exp = + let mk_cast var typ exp = (fst (make_bitvector_cast_fns env typ (subst_src_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in + let locals = Env.get_locals env in + Bindings.fold (fun var (mut,typ) exp -> + if mut = Immutable then mk_cast var typ exp else exp) locals exp + +let make_bitvector_cast_exp typ target_typ exp = (snd (make_bitvector_cast_fns (env_of exp) typ target_typ)) exp + +let rec extract_value_from_guard var (E_aux (e,_)) = + match e with + | E_app (op, ([E_aux (E_id var',_); E_aux (E_lit (L_aux (L_num i,_)),_)] | + [E_aux (E_lit (L_aux (L_num i,_)),_); E_aux (E_id var',_)])) + when string_of_id op = "eq_atom" && Id.compare var var' == 0 -> + Some i + | E_app (op, [e1;e2]) when string_of_id op = "and_bool" -> + (match extract_value_from_guard var e1 with + | Some i -> Some i + | None -> extract_value_from_guard var e2) + | _ -> None + +let rec flatten_constraints = function + | [] -> [] + | (NC_aux (NC_and (nc1,nc2),_))::t -> flatten_constraints (nc1::nc2::t) + | h::t -> h::(flatten_constraints t) + +let fill_in_type env typ = + let constraints = Env.get_constraints env in + let constraints = flatten_constraints constraints in + let subst = Util.map_filter (function + | NC_aux (NC_equal (Nexp_aux (Nexp_var var,_), (Nexp_aux (Nexp_constant _,_) as i)),_) -> Some (var,i) + | NC_aux (NC_equal (Nexp_aux (Nexp_constant _,_) as i, Nexp_aux (Nexp_var var,_)),_) -> Some (var,i) + | _ -> None) constraints in + subst_src_typ (kbindings_from_list subst) typ + +(* TODO: top-level patterns *) +let add_bitvector_casts (Defs defs) = + let rewrite_body ret_typ exp = + let rewrite_aux (e,ann) = + match e with + | E_case (E_aux (e',ann') as exp',cases) -> begin + let env = env_of_annot ann in + let result_typ = Env.base_typ_of env (typ_of_annot ann) in + let matched_typ = Env.base_typ_of env (typ_of_annot ann') in + match e',matched_typ with + | E_sizeof (Nexp_aux (Nexp_var kid,_)), _ + | _, Typ_aux (Typ_app (Id_aux (Id "atom",_), [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_) -> + let map_case pexp = + let pat,guard,body,ann = destruct_pexp pexp in + let body = match pat, guard with + | P_aux (P_lit (L_aux (L_num i,_)),_), _ -> + let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in + make_bitvector_cast_exp src_typ result_typ + (make_bitvector_env_casts env (kid,i) body) + | P_aux (P_id var,_), Some guard -> + (match extract_value_from_guard var guard with + | Some i -> + let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in + make_bitvector_cast_exp src_typ result_typ + (make_bitvector_env_casts env (kid,i) body) + | None -> body) + | _ -> body + in + construct_pexp (pat, guard, body, ann) + in + E_aux (E_case (exp', List.map map_case cases),ann) + | _ -> E_aux (e,ann) + end + | E_return e' -> + E_aux (E_return (make_bitvector_cast_exp (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann) + | E_assign (LEXP_aux (lexp,lexp_annot),e') -> + E_aux (E_assign (LEXP_aux (lexp,lexp_annot), + make_bitvector_cast_exp (fill_in_type (env_of e') (typ_of e')) + (typ_of_annot lexp_annot) e'),ann) + | _ -> E_aux (e,ann) + in + let open Rewriter in + fold_exp + { id_exp_alg with + e_aux = rewrite_aux } exp + in + let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),fcl_ann)) = + let ret_typ = + match typ_of_annot fcl_ann with + | Typ_aux (Typ_fn (_,ret,_),_) -> ret + | Typ_aux (_,l) as typ -> + raise (Reporting_basic.err_unreachable l + ("Function clause must have function type: " ^ string_of_typ typ ^ + " is not a function type")) + in + let pat,guard,body,annot = destruct_pexp pexp in + let body = rewrite_body ret_typ body in + let pexp = construct_pexp (pat,guard,body,annot) in + FCL_aux (FCL_Funcl (id,pexp),fcl_ann) + in + let rewrite_def = function + | DEF_fundef (FD_aux (FD_function (r,t,e,fcls),fd_ann)) -> + DEF_fundef (FD_aux (FD_function (r,t,e,List.map rewrite_funcl fcls),fd_ann)) + | d -> d + in Defs (List.map rewrite_def defs) +end + type options = { auto : bool; debug_analysis : int; @@ -3450,6 +3626,7 @@ let monomorphise opts splits env defs = then () else raise (Reporting_basic.err_general Unknown "Unable to monomorphise program") in + let defs = BitvectorSizeCasts.add_bitvector_casts defs in (* TODO: currently doing this because constant propagation leaves numeric literals as int, try to avoid this later; also use final env for DEF_spec case above, because the type checker doesn't store the env at that point :( *) |
