summaryrefslogtreecommitdiff
path: root/src/monomorphise.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/monomorphise.ml')
-rw-r--r--src/monomorphise.ml94
1 files changed, 59 insertions, 35 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 48a7ac65..3f49689b 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -3810,10 +3810,21 @@ end
module BitvectorSizeCasts =
struct
+let simplify_size_nexp env quant_kids (Nexp_aux (_,l) as nexp) =
+ match solve env nexp with
+ | Some n -> Some (nconstant n)
+ | None ->
+ let is_equal kid =
+ prove env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid,Unknown), nexp),Unknown))
+ in
+ match List.find is_equal quant_kids with
+ | kid -> Some (Nexp_aux (Nexp_var kid,Generated l))
+ | exception Not_found -> None
+
(* These functions add cast functions across case splits, so that when a
bitvector size becomes known in sail, the generated Lem code contains a
function call to change mword 'n to (say) mword ty16, and vice versa. *)
-let make_bitvector_cast_fns env src_typ target_typ =
+let make_bitvector_cast_fns env quant_kids src_typ target_typ =
let genunk = Generated Unknown in
let fresh =
let counter = ref 0 in
@@ -3822,7 +3833,7 @@ let make_bitvector_cast_fns env src_typ target_typ =
let () = counter := n+1 in
mk_id ("cast#" ^ string_of_int n)
in
- let required = ref false in
+ let at_least_one = ref None in
let rec aux (Typ_aux (src_t,src_l) as src_typ) (Typ_aux (tar_t,tar_l) as tar_typ) =
let src_ann = Some (env,src_typ,no_effect) in
let tar_ann = Some (env,tar_typ,no_effect) in
@@ -3834,18 +3845,26 @@ let make_bitvector_cast_fns env src_typ target_typ =
| Typ_app (Id_aux (Id "vector",_),
[Typ_arg_aux (Typ_arg_nexp size,_); _;
Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]),
- Typ_app (Id_aux (Id "vector",_),
- [Typ_arg_aux (Typ_arg_nexp size',_); _;
- Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)])
- when Nexp.compare size size' <> 0 ->
- let () = required := true in
- let var = fresh () in
- P_aux (P_id var,(Generated src_l,src_ann)),
- E_aux
- (E_cast (tar_typ,
- E_aux (E_app (Id_aux (Id "bitvector_cast", genunk),
- [E_aux (E_id var, (genunk, src_ann))]), (genunk, tar_ann))),
- (genunk, tar_ann))
+ Typ_app (Id_aux (Id "vector",_) as t_id,
+ [Typ_arg_aux (Typ_arg_nexp size',l_size'); t_ord;
+ Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_) as t_bit]) -> begin
+ match simplify_size_nexp env quant_kids size, simplify_size_nexp env quant_kids size' with
+ | Some size, Some size' when Nexp.compare size size' <> 0 ->
+ let var = fresh () in
+ let tar_typ' = Typ_aux (Typ_app (t_id, [Typ_arg_aux (Typ_arg_nexp size',l_size');t_ord;t_bit]),
+ tar_l) in
+ let () = at_least_one := Some tar_typ' in
+ P_aux (P_id var,(Generated src_l,src_ann)),
+ E_aux
+ (E_cast (tar_typ',
+ E_aux (E_app (Id_aux (Id "bitvector_cast", genunk),
+ [E_aux (E_id var, (genunk, src_ann))]), (genunk, tar_ann))),
+ (genunk, tar_ann))
+ | _ ->
+ let var = fresh () in
+ P_aux (P_id var,(Generated src_l,src_ann)),
+ E_aux (E_id var,(Generated src_l,tar_ann))
+ end
| _ ->
let var = fresh () in
P_aux (P_id var,(Generated src_l,src_ann)),
@@ -3854,8 +3873,8 @@ let make_bitvector_cast_fns env src_typ target_typ =
let src_typ' = Env.base_typ_of env src_typ in
let target_typ' = Env.base_typ_of env target_typ in
let pat, e' = aux src_typ' target_typ' in
- if !required
- then
+ match !at_least_one with
+ | Some one_target_typ -> begin
let src_ann = Some (env,src_typ,no_effect) in
let tar_ann = Some (env,target_typ,no_effect) in
match src_typ' with
@@ -3863,12 +3882,12 @@ let make_bitvector_cast_fns env src_typ target_typ =
| Typ_aux (Typ_app _,_) ->
(fun var exp ->
let exp_ann = Some (env,typ_of exp,effect_of exp) in
- E_aux (E_let (LB_aux (LB_val (P_aux (P_typ (target_typ, P_aux (P_id var,(genunk,tar_ann))),(genunk,tar_ann)),
+ E_aux (E_let (LB_aux (LB_val (P_aux (P_typ (one_target_typ, P_aux (P_id var,(genunk,tar_ann))),(genunk,tar_ann)),
E_aux (E_app (Id_aux (Id "bitvector_cast",genunk),
[E_aux (E_id var,(genunk,src_ann))]),(genunk,tar_ann))),(genunk,tar_ann)),
exp),(genunk,exp_ann))),
(fun (E_aux (_,(exp_l,exp_ann)) as exp) ->
- E_aux (E_cast (target_typ,
+ E_aux (E_cast (one_target_typ,
E_aux (E_app (Id_aux (Id "bitvector_cast", genunk), [exp]), (Generated exp_l,tar_ann))),
(Generated exp_l,tar_ann)))
| _ ->
@@ -3879,16 +3898,17 @@ let make_bitvector_cast_fns env src_typ target_typ =
exp),(genunk,exp_ann))),(genunk,exp_ann))),
(fun (E_aux (_,(exp_l,exp_ann)) as exp) ->
E_aux (E_let (LB_aux (LB_val (pat, exp),(Generated exp_l,exp_ann)), e'),(Generated exp_l,tar_ann)))
- else (fun _ e -> e),(fun e -> e)
+ end
+ | None -> (fun _ e -> e),(fun e -> e)
(* TODO: bound vars *)
-let make_bitvector_env_casts env (kid,i) exp =
- let mk_cast var typ exp = (fst (make_bitvector_cast_fns env typ (subst_src_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in
+let make_bitvector_env_casts env quant_kids (kid,i) exp =
+ let mk_cast var typ exp = (fst (make_bitvector_cast_fns env quant_kids typ (subst_src_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in
let locals = Env.get_locals env in
Bindings.fold (fun var (mut,typ) exp ->
if mut = Immutable then mk_cast var typ exp else exp) locals exp
-let make_bitvector_cast_exp typ target_typ exp = (snd (make_bitvector_cast_fns (env_of exp) typ target_typ)) exp
+let make_bitvector_cast_exp env quant_kids typ target_typ exp = (snd (make_bitvector_cast_fns env quant_kids typ target_typ)) exp
let rec extract_value_from_guard var (E_aux (e,_)) =
match e with
@@ -3909,14 +3929,14 @@ let fill_in_type env typ =
| BK_type
| BK_order -> subst
| BK_int ->
- match solve env (nvar kid) with
+ (match solve env (nvar kid) with
| None -> subst
- | Some n -> KBindings.add kid (nconstant n) subst) tyvars KBindings.empty in
+ | Some n -> KBindings.add kid (nconstant n) subst)) tyvars KBindings.empty in
subst_src_typ subst typ
(* TODO: top-level patterns *)
let add_bitvector_casts (Defs defs) =
- let rewrite_body ret_typ exp =
+ let rewrite_body id quant_kids top_env ret_typ exp =
let rewrite_aux (e,ann) =
match e with
| E_case (E_aux (e',ann') as exp',cases) -> begin
@@ -3931,16 +3951,17 @@ let add_bitvector_casts (Defs defs) =
let body = match pat, guard with
| P_aux (P_lit (L_aux (L_num i,_)),_), _ ->
let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
- make_bitvector_cast_exp src_typ result_typ
- (make_bitvector_env_casts env (kid,i) body)
+ make_bitvector_cast_exp env quant_kids src_typ result_typ
+ (make_bitvector_env_casts env quant_kids (kid,i) body)
| P_aux (P_id var,_), Some guard ->
(match extract_value_from_guard var guard with
| Some i ->
let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
- make_bitvector_cast_exp src_typ result_typ
- (make_bitvector_env_casts env (kid,i) body)
+ make_bitvector_cast_exp env quant_kids src_typ result_typ
+ (make_bitvector_env_casts env quant_kids (kid,i) body)
| None -> body)
- | _ -> body
+ | _ ->
+ body
in
construct_pexp (pat, guard, body, ann)
in
@@ -3948,10 +3969,10 @@ let add_bitvector_casts (Defs defs) =
| _ -> E_aux (e,ann)
end
| E_return e' ->
- E_aux (E_return (make_bitvector_cast_exp (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann)
+ E_aux (E_return (make_bitvector_cast_exp top_env quant_kids (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann)
| E_assign (LEXP_aux (lexp,lexp_annot),e') ->
E_aux (E_assign (LEXP_aux (lexp,lexp_annot),
- make_bitvector_cast_exp (fill_in_type (env_of e') (typ_of e'))
+ make_bitvector_cast_exp (env_of_annot ann) quant_kids (fill_in_type (env_of e') (typ_of e'))
(typ_of_annot lexp_annot) e'),ann)
| _ -> E_aux (e,ann)
in
@@ -3961,8 +3982,10 @@ let add_bitvector_casts (Defs defs) =
e_aux = rewrite_aux } exp
in
let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),fcl_ann)) =
+ let (tq,typ) = Env.get_val_spec_orig id (env_of_annot fcl_ann) in
+ let quant_kids = List.map kopt_kid (quant_kopts tq) in
let ret_typ =
- match typ_of_annot fcl_ann with
+ match typ with
| Typ_aux (Typ_fn (_,ret,_),_) -> ret
| Typ_aux (_,l) as typ ->
raise (Reporting_basic.err_unreachable l
@@ -3970,10 +3993,11 @@ let add_bitvector_casts (Defs defs) =
" is not a function type"))
in
let pat,guard,body,annot = destruct_pexp pexp in
- let body = rewrite_body ret_typ body in
+ let env = env_of body in
+ let body = rewrite_body id quant_kids env ret_typ body in
(* Also add a cast around the entire function clause body, if necessary *)
let body =
- make_bitvector_cast_exp (fill_in_type (env_of body) (typ_of body)) ret_typ body
+ make_bitvector_cast_exp env quant_kids (fill_in_type (env_of body) (typ_of body)) ret_typ body
in
let pexp = construct_pexp (pat,guard,body,annot) in
FCL_aux (FCL_Funcl (id,pexp),fcl_ann)