summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBrian Campbell2019-01-31 11:14:52 +0000
committerBrian Campbell2019-01-31 11:14:52 +0000
commite62896cecd575134a85def6815fe552f3154ea01 (patch)
treeed97115d55af68526c3b4bcb874d9204275e2479 /src
parent57ee80b836440110b933350a3646ca5059badda0 (diff)
Monomorphisation: improve cast insertion and nexp rewriting on variants
It now pushes casts into lets and constructor applications, and so supports the case needed for RISC-V.
Diffstat (limited to 'src')
-rw-r--r--src/monomorphise.ml85
-rw-r--r--src/type_check.mli2
2 files changed, 74 insertions, 13 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 666abe86..84025595 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -3969,7 +3969,8 @@ let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ =
[A_aux (A_nexp size',l_size'); t_ord;
A_aux (A_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_) as t_bit]) -> begin
match simplify_size_nexp env quant_kids size, simplify_size_nexp env quant_kids size' with
- | Some size, Some size' when Nexp.compare size size' <> 0 ->
+ | Some size, Some size' ->
+ if Nexp.compare size size' <> 0 then
let var = fresh () in
let tar_typ' = Typ_aux (Typ_app (t_id, [A_aux (A_nexp size',l_size');t_ord;t_bit]),
tar_l) in
@@ -3980,6 +3981,10 @@ let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ =
E_aux (E_app (Id_aux (Id cast_name, genunk),
[E_aux (E_id var, (genunk, src_ann))]), (genunk, tar_ann))),
(genunk, tar_ann))
+ else
+ let var = fresh () in
+ P_aux (P_id var,(Generated src_l,src_ann)),
+ E_aux (E_id var,(Generated src_l,tar_ann))
| _ ->
let var = fresh () in
P_aux (P_id var,(Generated src_l,src_ann)),
@@ -4028,7 +4033,44 @@ let make_bitvector_env_casts env quant_kids (kid,i) exp =
Bindings.fold (fun var (mut,typ) exp ->
if mut = Immutable then mk_cast var typ exp else exp) locals exp
-let make_bitvector_cast_exp cast_name env quant_kids typ target_typ exp = (snd (make_bitvector_cast_fns cast_name env quant_kids typ target_typ)) exp
+let make_bitvector_cast_exp cast_name cast_env quant_kids typ target_typ exp =
+ let infer_arg_typ env f l typ =
+ let (typq, ctor_typ) = Env.get_union_id f env in
+ let quants = quant_items typq in
+ match Env.expand_synonyms env ctor_typ with
+ | Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) ->
+ begin
+ let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in
+ let unifiers = unify l env goals ret_typ typ in
+ let arg_typ' = subst_unifiers unifiers arg_typ in
+ arg_typ'
+ end
+ | _ -> typ_error l ("Malformed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ)
+
+ in
+ (* Push the cast down, including through constructors *)
+ let rec aux exp (typ, target_typ) =
+ let exp_env = env_of exp in
+ match exp with
+ | E_aux (E_let (lb,exp'),ann) ->
+ E_aux (E_let (lb,aux exp' (typ, target_typ)),ann)
+ | E_aux (E_tuple exps,(l,ann)) -> begin
+ match Env.expand_synonyms exp_env typ, Env.expand_synonyms exp_env target_typ with
+ | Typ_aux (Typ_tup src_typs,_), Typ_aux (Typ_tup tgt_typs,_) ->
+ E_aux (E_tuple (List.map2 aux exps (List.combine src_typs tgt_typs)),(l,ann))
+ | _ -> raise (Reporting.err_unreachable l __POS__
+ ("Attempted to insert cast on tuple on non-tuple type: " ^
+ string_of_typ typ ^ " to " ^ string_of_typ target_typ))
+ end
+ | E_aux (E_app (f,args),(l,ann)) when Env.is_union_constructor f (env_of exp) ->
+ let arg = match args with [arg] -> arg | _ -> E_aux (E_tuple args, (l,empty_tannot)) in
+ let src_arg_typ = infer_arg_typ (env_of exp) f l typ in
+ let tgt_arg_typ = infer_arg_typ (env_of exp) f l target_typ in
+ E_aux (E_app (f,[aux arg (src_arg_typ, tgt_arg_typ)]),(l,ann))
+ | _ ->
+ (snd (make_bitvector_cast_fns cast_name cast_env quant_kids typ target_typ)) exp
+ in
+ aux exp (typ, target_typ)
let rec extract_value_from_guard var (E_aux (e,_)) =
match e with
@@ -4077,13 +4119,14 @@ let add_bitvector_casts (Defs defs) =
let pat,guard,body,ann = destruct_pexp pexp in
let body = match pat, guard with
| P_aux (P_lit (L_aux (L_num i,_)),_), _ ->
- let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
+ (* We used to just substitute kid, but fill_in_type also catches other kids defined by it *)
+ let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in
make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ
(make_bitvector_env_casts env quant_kids (kid,i) body)
| P_aux (P_id var,_), Some guard ->
(match extract_value_from_guard var guard with
| Some i ->
- let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in
+ let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in
make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ
(make_bitvector_env_casts env quant_kids (kid,i) body)
| None -> body)
@@ -4251,14 +4294,13 @@ let rewrite_toplevel_nexps (Defs defs) =
let nexp_map, typ = rewrite_typ_in_spec env nexp_map typ in
(nexp_map, typ::t)) typs (nexp_map,[])
in nexp_map, Typ_aux (Typ_tup typs,ann)
- | _ ->
- let typ' = Env.base_typ_of env typ_full in
+ | _ when is_number typ_full || is_bitvector_typ typ_full -> begin
let nexp_opt =
- match destruct_atom_nexp env typ' with
+ match destruct_atom_nexp env typ_full with
| Some nexp -> Some nexp
| None ->
- if is_bitvector_typ typ' then
- let (size,_,_) = vector_typ_args_of typ' in
+ if is_bitvector_typ typ_full then
+ let (size,_,_) = vector_typ_args_of typ_full in
Some size
else None
in match nexp_opt with
@@ -4274,10 +4316,27 @@ let rewrite_toplevel_nexps (Defs defs) =
(kid, nexp)::nexp_map, kid
in
let new_nexp = nvar kid in
- (* Try to avoid expanding the original type *)
- let changed, typ = replace_nexp_in_typ env typ_full nexp new_nexp in
- if changed then nexp_map, typ
- else nexp_map, snd (replace_nexp_in_typ env typ' nexp new_nexp)
+ nexp_map, snd (replace_nexp_in_typ env typ_full nexp new_nexp)
+ end
+ | _ ->
+ let typ' = Env.base_typ_of env typ_full in
+ if Typ.compare typ_full typ' == 0 then
+ match t with
+ | Typ_app (f,args) ->
+ let in_arg nexp_map (A_aux (arg,l) as arg_full) =
+ match arg with
+ | A_typ typ ->
+ let nexp_map, typ' = rewrite_typ_in_spec env nexp_map typ in
+ nexp_map, A_aux (A_typ typ',l)
+ | A_bool _ | A_nexp _ | A_order _ -> nexp_map, arg_full
+ in
+ let nexp_map, args =
+ List.fold_right (fun arg (nexp_map,args) ->
+ let nexp_map, arg = in_arg nexp_map arg in
+ (nexp_map, arg::args)) args (nexp_map,[])
+ in nexp_map, Typ_aux (Typ_app (f,args),ann)
+ | _ -> nexp_map, typ_full
+ else rewrite_typ_in_spec env nexp_map typ'
in
let rewrite_valspec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (tqs,typ),ts_l),id,ext_opt,is_cast),ann)) =
match tqs with
diff --git a/src/type_check.mli b/src/type_check.mli
index 6a0f0410..ce5324d1 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -202,6 +202,8 @@ module Env : sig
val pattern_completeness_ctx : t -> Pattern_completeness.ctx
val builtin_typs : typquant Bindings.t
+
+ val get_union_id : id -> t -> typquant * typ
end
(** Push all the type variables and constraints from a typquant into