diff options
Diffstat (limited to 'src/monomorphise.ml')
| -rw-r--r-- | src/monomorphise.ml | 94 |
1 files changed, 59 insertions, 35 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 48a7ac65..3f49689b 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -3810,10 +3810,21 @@ end module BitvectorSizeCasts = struct +let simplify_size_nexp env quant_kids (Nexp_aux (_,l) as nexp) = + match solve env nexp with + | Some n -> Some (nconstant n) + | None -> + let is_equal kid = + prove env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid,Unknown), nexp),Unknown)) + in + match List.find is_equal quant_kids with + | kid -> Some (Nexp_aux (Nexp_var kid,Generated l)) + | exception Not_found -> None + (* These functions add cast functions across case splits, so that when a bitvector size becomes known in sail, the generated Lem code contains a function call to change mword 'n to (say) mword ty16, and vice versa. *) -let make_bitvector_cast_fns env src_typ target_typ = +let make_bitvector_cast_fns env quant_kids src_typ target_typ = let genunk = Generated Unknown in let fresh = let counter = ref 0 in @@ -3822,7 +3833,7 @@ let make_bitvector_cast_fns env src_typ target_typ = let () = counter := n+1 in mk_id ("cast#" ^ string_of_int n) in - let required = ref false in + let at_least_one = ref None in let rec aux (Typ_aux (src_t,src_l) as src_typ) (Typ_aux (tar_t,tar_l) as tar_typ) = let src_ann = Some (env,src_typ,no_effect) in let tar_ann = Some (env,tar_typ,no_effect) in @@ -3834,18 +3845,26 @@ let make_bitvector_cast_fns env src_typ target_typ = | Typ_app (Id_aux (Id "vector",_), [Typ_arg_aux (Typ_arg_nexp size,_); _; Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]), - Typ_app (Id_aux (Id "vector",_), - [Typ_arg_aux (Typ_arg_nexp size',_); _; - Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) - when Nexp.compare size size' <> 0 -> - let () = required := true in - let var = fresh () in - P_aux (P_id var,(Generated src_l,src_ann)), - E_aux - (E_cast (tar_typ, - E_aux (E_app (Id_aux (Id "bitvector_cast", genunk), - [E_aux (E_id var, (genunk, src_ann))]), (genunk, tar_ann))), - (genunk, tar_ann)) + Typ_app (Id_aux (Id "vector",_) as t_id, + [Typ_arg_aux (Typ_arg_nexp size',l_size'); t_ord; + Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_) as t_bit]) -> begin + match simplify_size_nexp env quant_kids size, simplify_size_nexp env quant_kids size' with + | Some size, Some size' when Nexp.compare size size' <> 0 -> + let var = fresh () in + let tar_typ' = Typ_aux (Typ_app (t_id, [Typ_arg_aux (Typ_arg_nexp size',l_size');t_ord;t_bit]), + tar_l) in + let () = at_least_one := Some tar_typ' in + P_aux (P_id var,(Generated src_l,src_ann)), + E_aux + (E_cast (tar_typ', + E_aux (E_app (Id_aux (Id "bitvector_cast", genunk), + [E_aux (E_id var, (genunk, src_ann))]), (genunk, tar_ann))), + (genunk, tar_ann)) + | _ -> + let var = fresh () in + P_aux (P_id var,(Generated src_l,src_ann)), + E_aux (E_id var,(Generated src_l,tar_ann)) + end | _ -> let var = fresh () in P_aux (P_id var,(Generated src_l,src_ann)), @@ -3854,8 +3873,8 @@ let make_bitvector_cast_fns env src_typ target_typ = let src_typ' = Env.base_typ_of env src_typ in let target_typ' = Env.base_typ_of env target_typ in let pat, e' = aux src_typ' target_typ' in - if !required - then + match !at_least_one with + | Some one_target_typ -> begin let src_ann = Some (env,src_typ,no_effect) in let tar_ann = Some (env,target_typ,no_effect) in match src_typ' with @@ -3863,12 +3882,12 @@ let make_bitvector_cast_fns env src_typ target_typ = | Typ_aux (Typ_app _,_) -> (fun var exp -> let exp_ann = Some (env,typ_of exp,effect_of exp) in - E_aux (E_let (LB_aux (LB_val (P_aux (P_typ (target_typ, P_aux (P_id var,(genunk,tar_ann))),(genunk,tar_ann)), + E_aux (E_let (LB_aux (LB_val (P_aux (P_typ (one_target_typ, P_aux (P_id var,(genunk,tar_ann))),(genunk,tar_ann)), E_aux (E_app (Id_aux (Id "bitvector_cast",genunk), [E_aux (E_id var,(genunk,src_ann))]),(genunk,tar_ann))),(genunk,tar_ann)), exp),(genunk,exp_ann))), (fun (E_aux (_,(exp_l,exp_ann)) as exp) -> - E_aux (E_cast (target_typ, + E_aux (E_cast (one_target_typ, E_aux (E_app (Id_aux (Id "bitvector_cast", genunk), [exp]), (Generated exp_l,tar_ann))), (Generated exp_l,tar_ann))) | _ -> @@ -3879,16 +3898,17 @@ let make_bitvector_cast_fns env src_typ target_typ = exp),(genunk,exp_ann))),(genunk,exp_ann))), (fun (E_aux (_,(exp_l,exp_ann)) as exp) -> E_aux (E_let (LB_aux (LB_val (pat, exp),(Generated exp_l,exp_ann)), e'),(Generated exp_l,tar_ann))) - else (fun _ e -> e),(fun e -> e) + end + | None -> (fun _ e -> e),(fun e -> e) (* TODO: bound vars *) -let make_bitvector_env_casts env (kid,i) exp = - let mk_cast var typ exp = (fst (make_bitvector_cast_fns env typ (subst_src_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in +let make_bitvector_env_casts env quant_kids (kid,i) exp = + let mk_cast var typ exp = (fst (make_bitvector_cast_fns env quant_kids typ (subst_src_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in let locals = Env.get_locals env in Bindings.fold (fun var (mut,typ) exp -> if mut = Immutable then mk_cast var typ exp else exp) locals exp -let make_bitvector_cast_exp typ target_typ exp = (snd (make_bitvector_cast_fns (env_of exp) typ target_typ)) exp +let make_bitvector_cast_exp env quant_kids typ target_typ exp = (snd (make_bitvector_cast_fns env quant_kids typ target_typ)) exp let rec extract_value_from_guard var (E_aux (e,_)) = match e with @@ -3909,14 +3929,14 @@ let fill_in_type env typ = | BK_type | BK_order -> subst | BK_int -> - match solve env (nvar kid) with + (match solve env (nvar kid) with | None -> subst - | Some n -> KBindings.add kid (nconstant n) subst) tyvars KBindings.empty in + | Some n -> KBindings.add kid (nconstant n) subst)) tyvars KBindings.empty in subst_src_typ subst typ (* TODO: top-level patterns *) let add_bitvector_casts (Defs defs) = - let rewrite_body ret_typ exp = + let rewrite_body id quant_kids top_env ret_typ exp = let rewrite_aux (e,ann) = match e with | E_case (E_aux (e',ann') as exp',cases) -> begin @@ -3931,16 +3951,17 @@ let add_bitvector_casts (Defs defs) = let body = match pat, guard with | P_aux (P_lit (L_aux (L_num i,_)),_), _ -> let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in - make_bitvector_cast_exp src_typ result_typ - (make_bitvector_env_casts env (kid,i) body) + make_bitvector_cast_exp env quant_kids src_typ result_typ + (make_bitvector_env_casts env quant_kids (kid,i) body) | P_aux (P_id var,_), Some guard -> (match extract_value_from_guard var guard with | Some i -> let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in - make_bitvector_cast_exp src_typ result_typ - (make_bitvector_env_casts env (kid,i) body) + make_bitvector_cast_exp env quant_kids src_typ result_typ + (make_bitvector_env_casts env quant_kids (kid,i) body) | None -> body) - | _ -> body + | _ -> + body in construct_pexp (pat, guard, body, ann) in @@ -3948,10 +3969,10 @@ let add_bitvector_casts (Defs defs) = | _ -> E_aux (e,ann) end | E_return e' -> - E_aux (E_return (make_bitvector_cast_exp (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann) + E_aux (E_return (make_bitvector_cast_exp top_env quant_kids (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann) | E_assign (LEXP_aux (lexp,lexp_annot),e') -> E_aux (E_assign (LEXP_aux (lexp,lexp_annot), - make_bitvector_cast_exp (fill_in_type (env_of e') (typ_of e')) + make_bitvector_cast_exp (env_of_annot ann) quant_kids (fill_in_type (env_of e') (typ_of e')) (typ_of_annot lexp_annot) e'),ann) | _ -> E_aux (e,ann) in @@ -3961,8 +3982,10 @@ let add_bitvector_casts (Defs defs) = e_aux = rewrite_aux } exp in let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),fcl_ann)) = + let (tq,typ) = Env.get_val_spec_orig id (env_of_annot fcl_ann) in + let quant_kids = List.map kopt_kid (quant_kopts tq) in let ret_typ = - match typ_of_annot fcl_ann with + match typ with | Typ_aux (Typ_fn (_,ret,_),_) -> ret | Typ_aux (_,l) as typ -> raise (Reporting_basic.err_unreachable l @@ -3970,10 +3993,11 @@ let add_bitvector_casts (Defs defs) = " is not a function type")) in let pat,guard,body,annot = destruct_pexp pexp in - let body = rewrite_body ret_typ body in + let env = env_of body in + let body = rewrite_body id quant_kids env ret_typ body in (* Also add a cast around the entire function clause body, if necessary *) let body = - make_bitvector_cast_exp (fill_in_type (env_of body) (typ_of body)) ret_typ body + make_bitvector_cast_exp env quant_kids (fill_in_type (env_of body) (typ_of body)) ret_typ body in let pexp = construct_pexp (pat,guard,body,annot) in FCL_aux (FCL_Funcl (id,pexp),fcl_ann) |
