summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlasdair Armstrong2017-10-26 18:41:42 +0100
committerAlasdair Armstrong2017-10-26 18:41:42 +0100
commit1d38bcff2ce300f880d2ab045678bb07b2fc67a8 (patch)
tree53696e50e6c728cc11d1ee49972842623bd63e6b /src
parent68d109416999f31bf0674516e69d56ea9995be0d (diff)
parentc59cfa97be7eb21e86948e9b90ca8f4926cb5815 (diff)
Merge branch 'experiments' of https://bitbucket.org/Peter_Sewell/sail into experiments
Diffstat (limited to 'src')
-rw-r--r--src/ast_util.ml37
-rw-r--r--src/ast_util.mli9
-rw-r--r--src/pretty_print_lem.ml118
-rw-r--r--src/rewriter.ml222
-rw-r--r--src/type_check.ml22
-rw-r--r--src/type_check.mli1
6 files changed, 272 insertions, 137 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index b0f4c052..4a887898 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -592,11 +592,36 @@ module Id = struct
| Id_aux (DeIid _, _), Id_aux (Id _, _) -> 1
end
+module Nexp = struct
+ type t = nexp
+ let rec compare (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) =
+ let lex_ord (c1, c2) = if c1 = 0 then c2 else c1 in
+ match nexp1, nexp2 with
+ | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2
+ | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2
+ | Nexp_constant c1, Nexp_constant c2 -> Pervasives.compare c1 c2
+ | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b)
+ | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b)
+ | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) ->
+ lex_ord (compare n1a n2a, compare n1b n2b)
+ | Nexp_exp n1, Nexp_exp n2 -> compare n1 n2
+ | Nexp_neg n1, Nexp_neg n2 -> compare n1 n2
+ | Nexp_constant _, _ -> -1 | _, Nexp_constant _ -> 1
+ | Nexp_id _, _ -> -1 | _, Nexp_id _ -> 1
+ | Nexp_var _, _ -> -1 | _, Nexp_var _ -> 1
+ | Nexp_neg _, _ -> -1 | _, Nexp_neg _ -> 1
+ | Nexp_exp _, _ -> -1 | _, Nexp_exp _ -> 1
+ | Nexp_minus _, _ -> -1 | _, Nexp_minus _ -> 1
+ | Nexp_sum _, _ -> -1 | _, Nexp_sum _ -> 1
+ | Nexp_times _, _ -> -1 | _, Nexp_times _ -> 1
+end
+
module BESet = Set.Make(BE)
module Bindings = Map.Make(Id)
module IdSet = Set.Make(Id)
module KBindings = Map.Make(Kid)
module KidSet = Set.Make(Kid)
+module NexpSet = Set.Make(Nexp)
let rec nexp_frees (Nexp_aux (nexp, l)) =
match nexp with
@@ -609,17 +634,7 @@ let rec nexp_frees (Nexp_aux (nexp, l)) =
| Nexp_exp n -> nexp_frees n
| Nexp_neg n -> nexp_frees n
-let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) =
- match nexp1, nexp2 with
- | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 = 0
- | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 = 0
- | Nexp_constant c1, Nexp_constant c2 -> c1 = c2
- | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b
- | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b
- | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b
- | Nexp_exp n1, Nexp_exp n2 -> nexp_identical n1 n2
- | Nexp_neg n1, Nexp_neg n2 -> nexp_identical n1 n2
- | _, _ -> false
+let rec nexp_identical nexp1 nexp2 = (Nexp.compare nexp1 nexp2 = 0)
let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with
| Nexp_id _ | Nexp_var _ -> false
diff --git a/src/ast_util.mli b/src/ast_util.mli
index ef367e4e..d497a687 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -187,6 +187,11 @@ module Kid : sig
val compare : kid -> kid -> int
end
+module Nexp : sig
+ type t = nexp
+ val compare : nexp -> nexp -> int
+end
+
module BE : sig
type t = base_effect
val compare : base_effect -> base_effect -> int
@@ -196,6 +201,10 @@ module IdSet : sig
include Set.S with type elt = id
end
+module NexpSet : sig
+ include Set.S with type elt = nexp
+end
+
module BESet : sig
include Set.S with type elt = base_effect
end
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 23fc8287..a0a4878b 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -152,47 +152,80 @@ let is_regtyp (Typ_aux (typ, _)) env = match typ with
let doc_nexp_lem (Nexp_aux (nexp, l) as full_nexp) = match nexp with
| Nexp_constant i -> string ("ty" ^ string_of_int i)
- | Nexp_var v -> string (string_of_kid v)
- | _ -> raise (Reporting_basic.err_unreachable l
- ("cannot pretty-print non-atomic nexp \"" ^ string_of_nexp full_nexp ^ "\""))
+ | Nexp_var v -> string (string_of_kid (orig_kid v))
+ | _ ->
+ let rec mangle_nexp (Nexp_aux (nexp, _)) = begin
+ match nexp with
+ | Nexp_id id -> string_of_id id
+ | Nexp_var kid -> string_of_id (id_of_kid (orig_kid kid))
+ | Nexp_constant i -> Pretty_print_lem_ast.lemnum string_of_int i
+ | Nexp_times (n1, n2) -> mangle_nexp n1 ^ "_times_" ^ mangle_nexp n2
+ | Nexp_sum (n1, n2) -> mangle_nexp n1 ^ "_plus_" ^ mangle_nexp n2
+ | Nexp_minus (n1, n2) -> mangle_nexp n1 ^ "_minus_" ^ mangle_nexp n2
+ | Nexp_exp n -> "exp_" ^ mangle_nexp n
+ | Nexp_neg n -> "neg_" ^ mangle_nexp n
+ end in
+ string ("'" ^ mangle_nexp full_nexp)
+ (* raise (Reporting_basic.err_unreachable l
+ ("cannot pretty-print non-atomic nexp \"" ^ string_of_nexp full_nexp ^ "\"")) *)
+
+(* Rewrite mangled names of type variables to the original names *)
+let rec orig_nexp (Nexp_aux (nexp, l)) =
+ let rewrap nexp = Nexp_aux (nexp, l) in
+ match nexp with
+ | Nexp_var kid -> rewrap (Nexp_var (orig_kid kid))
+ | Nexp_times (n1, n2) -> rewrap (Nexp_times (orig_nexp n1, orig_nexp n2))
+ | Nexp_sum (n1, n2) -> rewrap (Nexp_sum (orig_nexp n1, orig_nexp n2))
+ | Nexp_minus (n1, n2) -> rewrap (Nexp_minus (orig_nexp n1, orig_nexp n2))
+ | Nexp_exp n -> rewrap (Nexp_exp (orig_nexp n))
+ | Nexp_neg n -> rewrap (Nexp_neg (orig_nexp n))
+ | _ -> rewrap nexp
(* Returns the set of type variables that will appear in the Lem output,
which may be smaller than those in the Sail type. May need to be
updated with doc_typ_lem *)
-let rec lem_tyvars_of_typ sequential mwords (Typ_aux (t,_)) =
- let trec = lem_tyvars_of_typ sequential mwords in
+let rec lem_nexps_of_typ sequential mwords (Typ_aux (t,_)) =
+ let trec = lem_nexps_of_typ sequential mwords in
match t with
| Typ_wild
- | Typ_id _ -> KidSet.empty
- | Typ_var kid -> KidSet.singleton kid
- | Typ_fn (t1,t2,_) -> KidSet.union (trec t1) (trec t2)
+ | Typ_id _ -> NexpSet.empty
+ | Typ_var kid -> NexpSet.singleton (orig_nexp (nvar kid))
+ | Typ_fn (t1,t2,_) -> NexpSet.union (trec t1) (trec t2)
| Typ_tup ts ->
- List.fold_left (fun s t -> KidSet.union s (trec t))
- KidSet.empty ts
+ List.fold_left (fun s t -> NexpSet.union s (trec t))
+ NexpSet.empty ts
| Typ_app(Id_aux (Id "vector", _), [
Typ_arg_aux (Typ_arg_nexp n, _);
Typ_arg_aux (Typ_arg_nexp m, _);
Typ_arg_aux (Typ_arg_order ord, _);
Typ_arg_aux (Typ_arg_typ elem_typ, _)]) ->
- KidSet.union
- (if mwords then tyvars_of_nexp (simplify_nexp m) else KidSet.empty)
- (trec elem_typ)
+ let m = simplify_nexp m in
+ if mwords && is_bit_typ elem_typ && not (is_nexp_constant m) then
+ NexpSet.singleton (orig_nexp m)
+ else trec elem_typ
+ (* NexpSet.union
+ (if mwords then tyvars_of_nexp (simplify_nexp m) else NexpSet.empty)
+ (trec elem_typ) *)
| Typ_app(Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ etyp, _)]) ->
- if sequential then trec etyp else KidSet.empty
+ if sequential then trec etyp else NexpSet.empty
| Typ_app(Id_aux (Id "range", _),_)
| Typ_app(Id_aux (Id "implicit", _),_)
- | Typ_app(Id_aux (Id "atom", _), _) -> KidSet.empty
+ | Typ_app(Id_aux (Id "atom", _), _) -> NexpSet.empty
| Typ_app (_,tas) ->
- List.fold_left (fun s ta -> KidSet.union s (lem_tyvars_of_typ_arg sequential mwords ta))
- KidSet.empty tas
+ List.fold_left (fun s ta -> NexpSet.union s (lem_nexps_of_typ_arg sequential mwords ta))
+ NexpSet.empty tas
| Typ_exist (kids,_,t) ->
let s = trec t in
- List.fold_left (fun s k -> KidSet.remove k s) s kids
-and lem_tyvars_of_typ_arg sequential mwords (Typ_arg_aux (ta,_)) =
+ List.fold_left (fun s k -> NexpSet.remove k s) s (List.map nvar kids)
+and lem_nexps_of_typ_arg sequential mwords (Typ_arg_aux (ta,_)) =
match ta with
- | Typ_arg_nexp nexp -> tyvars_of_nexp nexp
- | Typ_arg_typ typ -> lem_tyvars_of_typ sequential mwords typ
- | Typ_arg_order _ -> KidSet.empty
+ | Typ_arg_nexp nexp -> NexpSet.singleton (orig_nexp (simplify_nexp nexp))
+ | Typ_arg_typ typ -> lem_nexps_of_typ sequential mwords typ
+ | Typ_arg_order _ -> NexpSet.empty
+
+let lem_tyvars_of_typ sequential mwords typ =
+ NexpSet.fold (fun nexp ks -> KidSet.union ks (tyvars_of_nexp nexp))
+ (lem_nexps_of_typ sequential mwords typ) KidSet.empty
(* When making changes here, check whether they affect lem_tyvars_of_typ *)
let doc_typ_lem, doc_atomic_typ_lem =
@@ -364,7 +397,9 @@ let doc_quant_item vars_included (QI_aux (qi, _)) = match qi with
| QI_id (KOpt_aux (KOpt_kind (_, kid), _)) ->
(match vars_included with
None -> doc_var kid
- | Some set when KidSet.mem kid set -> doc_var kid
+ | Some set -> (*when KidSet.mem kid set -> doc_var kid*)
+ let nexps = NexpSet.filter (fun nexp -> KidSet.mem (orig_kid kid) (nexp_frees nexp)) set in
+ separate_map space doc_nexp_lem (NexpSet.elements nexps)
| _ -> empty)
| _ -> empty
@@ -381,39 +416,38 @@ let doc_typquant_lem (TypQ_aux(tq,_)) vars_included typ = match tq with
machine words. Often these will be unnecessary, but this simple
approach will do for now. *)
-let rec typeclass_vars (Typ_aux(t,_)) = match t with
+let rec typeclass_nexps (Typ_aux(t,_)) = match t with
| Typ_wild
| Typ_id _
| Typ_var _
- -> []
-| Typ_fn (t1,t2,_) -> typeclass_vars t1 @ typeclass_vars t2
-| Typ_tup ts -> List.concat (List.map typeclass_vars ts)
+ -> NexpSet.empty
+| Typ_fn (t1,t2,_) -> NexpSet.union (typeclass_nexps t1) (typeclass_nexps t2)
+| Typ_tup ts -> List.fold_left NexpSet.union NexpSet.empty (List.map typeclass_nexps ts)
| Typ_app (Id_aux (Id "vector",_),
- [_;Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var v,_)),_);
+ [_;Typ_arg_aux (Typ_arg_nexp size_nexp,_);
_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
- [v]
-| Typ_app _ -> []
-| Typ_exist (kids,_,t) -> [] (* todo *)
+ if is_nexp_constant (simplify_nexp size_nexp) then NexpSet.empty else
+ NexpSet.singleton (orig_nexp size_nexp)
+| Typ_app _ -> NexpSet.empty
+| Typ_exist (kids,_,t) -> NexpSet.empty (* todo *)
let doc_typclasses_lem mwords t =
if mwords then
- let vars = typeclass_vars t in
- let vars = List.sort_uniq Kid.compare vars in
- match vars with
- | [] -> (empty, KidSet.empty)
- | _ -> (separate_map comma_sp (fun var -> string "Size " ^^ doc_var var) vars ^^ string " => ", KidSet.of_list vars)
- else (empty, KidSet.empty)
+ let nexps = typeclass_nexps t in
+ if NexpSet.is_empty nexps then (empty, NexpSet.empty) else
+ (separate_map comma_sp (fun nexp -> string "Size " ^^ doc_nexp_lem nexp) (NexpSet.elements nexps) ^^ string " => ", nexps)
+ else (empty, NexpSet.empty)
let doc_typschm_lem sequential mwords quants (TypSchm_aux(TypSchm_ts(tq,t),_)) =
let pt = doc_typ_lem sequential mwords t in
if quants
then
- let tyvars_used = lem_tyvars_of_typ sequential mwords t in
- let ptyc, tyvars_sizes = doc_typclasses_lem mwords t in
- let tyvars_to_include = KidSet.union tyvars_used tyvars_sizes in
- if KidSet.is_empty tyvars_to_include
+ let nexps_used = lem_nexps_of_typ sequential mwords t in
+ let ptyc, nexps_sizes = doc_typclasses_lem mwords t in
+ let nexps_to_include = NexpSet.union nexps_used nexps_sizes in
+ if NexpSet.is_empty nexps_to_include
then pt
- else doc_typquant_lem tq (Some tyvars_to_include) (ptyc ^^ pt)
+ else doc_typquant_lem tq (Some nexps_to_include) (ptyc ^^ pt)
else pt
let is_ctor env id = match Env.lookup_id id env with
diff --git a/src/rewriter.ml b/src/rewriter.ml
index bcfb731a..5329b01d 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -996,6 +996,46 @@ let compute_exp_alg bot join =
; pat_alg = compute_pat_alg bot join
}
+let rec rewrite_nexp_ids env (Nexp_aux (nexp, l) as nexp_aux) = match nexp with
+| Nexp_id id -> rewrite_nexp_ids env (Env.get_num_def id env)
+| Nexp_times (nexp1, nexp2) -> Nexp_aux (Nexp_times (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
+| Nexp_sum (nexp1, nexp2) -> Nexp_aux (Nexp_sum (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
+| Nexp_minus (nexp1, nexp2) -> Nexp_aux (Nexp_minus (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
+| Nexp_exp nexp -> Nexp_aux (Nexp_exp (rewrite_nexp_ids env nexp), l)
+| Nexp_neg nexp -> Nexp_aux (Nexp_neg (rewrite_nexp_ids env nexp), l)
+| _ -> nexp_aux
+
+let rewrite_defs_nexp_ids, rewrite_typ_nexp_ids =
+ let rec rewrite_typ env (Typ_aux (typ, l) as typ_aux) = match typ with
+ | Typ_fn (arg_t, ret_t, eff) ->
+ Typ_aux (Typ_fn (rewrite_typ env arg_t, rewrite_typ env ret_t, eff), l)
+ | Typ_tup ts ->
+ Typ_aux (Typ_tup (List.map (rewrite_typ env) ts), l)
+ | Typ_exist (kids, c, typ) ->
+ Typ_aux (Typ_exist (kids, c, rewrite_typ env typ), l)
+ | Typ_app (id, targs) ->
+ Typ_aux (Typ_app (id, List.map (rewrite_typ_arg env) targs), l)
+ | _ -> typ_aux
+ and rewrite_typ_arg env (Typ_arg_aux (targ, l) as targ_aux) = match targ with
+ | Typ_arg_nexp nexp ->
+ Typ_arg_aux (Typ_arg_nexp (rewrite_nexp_ids env nexp), l)
+ | Typ_arg_typ typ ->
+ Typ_arg_aux (Typ_arg_typ (rewrite_typ env typ), l)
+ | Typ_arg_order ord ->
+ Typ_arg_aux (Typ_arg_order ord, l)
+ in
+
+ let rewrite_annot = function
+ | (l, Some (env, typ, eff)) -> (l, Some (env, rewrite_typ env typ, eff))
+ | (l, None) -> (l, None)
+ in
+
+ rewrite_defs_base {
+ rewriters_base with rewrite_exp = (fun _ -> map_exp_annot rewrite_annot)
+ },
+ rewrite_typ
+
+
(* Re-write trivial sizeof expressions - trivial meaning that the
value of the sizeof can be directly inferred from the type
variables in scope. *)
@@ -1023,14 +1063,6 @@ let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp =
| Nexp_neg nexp -> mk_exp (E_app (mk_id "negate_range", [split_nexp nexp]))
| _ -> mk_exp (E_sizeof nexp)
in
- let rec rewrite_nexp_ids env (Nexp_aux (nexp, l) as nexp_aux) = match nexp with
- | Nexp_id id -> rewrite_nexp_ids env (Env.get_num_def id env)
- | Nexp_times (nexp1, nexp2) -> Nexp_aux (Nexp_times (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
- | Nexp_sum (nexp1, nexp2) -> Nexp_aux (Nexp_sum (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
- | Nexp_minus (nexp1, nexp2) -> Nexp_aux (Nexp_minus (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
- | Nexp_exp nexp -> Nexp_aux (Nexp_exp (rewrite_nexp_ids env nexp), l)
- | Nexp_neg nexp -> Nexp_aux (Nexp_neg (rewrite_nexp_ids env nexp), l)
- | _ -> nexp_aux in
let rec rewrite_e_aux split_sizeof (E_aux (e_aux, (l, _)) as orig_exp) =
let env = env_of orig_exp in
match e_aux with
@@ -1049,7 +1081,7 @@ let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp =
|> List.concat
in
match exps with
- | (exp :: _) -> exp
+ | (exp :: _) -> check_exp env (strip_exp exp) (typ_of exp)
| [] when split_sizeof ->
fold_exp (rewrite_e_sizeof false) (check_exp env (split_nexp nexp) (typ_of orig_exp))
| [] -> orig_exp
@@ -1318,8 +1350,8 @@ let rewrite_sizeof (Defs defs) =
| _ -> Typ_aux (Typ_tup (kid_typs @ [vtyp_arg]), vl)
end in
Typ_aux (Typ_fn (vtyp_arg', vtyp_ret, declared_eff), vl)
- | _ -> raise (Reporting_basic.err_typ l
- "val spec with non-function type") in
+ | _ ->
+ raise (Reporting_basic.err_typ l "val spec with non-function type") in
TypSchm_aux (TypSchm_ts (tq, typ'), l)
else ts in
match def with
@@ -2401,40 +2433,59 @@ let rewrite_defs_early_return =
(* Propagate effects of functions, if effect checking and propagation
have not been performed already by the type checker. *)
-let rewrite_fix_fun_effs (Defs defs) =
- let e_aux fun_effs (exp, (l, annot)) =
+let rewrite_fix_val_specs (Defs defs) =
+ let find_vs env val_specs id =
+ try Bindings.find id val_specs with
+ | Not_found ->
+ begin
+ try Env.get_val_spec id env with
+ | _ ->
+ raise (Reporting_basic.err_unreachable (Parse_ast.Unknown)
+ ("No val spec found for " ^ string_of_id id))
+ end
+ in
+
+ let add_eff_to_vs eff = function
+ | (tq, Typ_aux (Typ_fn (args_t, ret_t, eff'), a)) ->
+ (tq, Typ_aux (Typ_fn (args_t, ret_t, union_effects eff eff'), a))
+ | vs -> vs
+ in
+
+ let eff_of_vs = function
+ | (tq, Typ_aux (Typ_fn (args_t, ret_t, eff), a)) -> eff
+ | _ -> no_effect
+ in
+
+ let e_aux val_specs (exp, (l, annot)) =
match fix_eff_exp (E_aux (exp, (l, annot))) with
| E_aux (E_app_infix (_, f, _) as exp, (l, Some (env, typ, eff)))
- | E_aux (E_app (f, _) as exp, (l, Some (env, typ, eff)))
- when Bindings.mem f fun_effs ->
- let eff' = Bindings.find f fun_effs in
- let env =
- try
- match Env.get_val_spec f env with
- | (tq, Typ_aux (Typ_fn (args_t, ret_t, eff), a)) ->
- Env.update_val_spec f (tq, Typ_aux (Typ_fn (args_t, ret_t, union_effects eff eff'), a)) env
- | _ -> env
- with
- | _ -> env in
- E_aux (exp, (l, Some (env, typ, union_effects eff eff')))
- | e_aux -> e_aux in
-
- let rewrite_exp fun_effs = fold_exp { id_exp_alg with e_aux = e_aux fun_effs } in
-
- let rewrite_funcl (fun_effs, funcls) (FCL_aux (FCL_Funcl (id, pat, exp), (l, annot))) =
- let exp = propagate_exp_effect (rewrite_exp fun_effs exp) in
- let fun_eff =
- try union_effects (effect_of exp) (Bindings.find id fun_effs)
- with Not_found -> (effect_of exp) in
- let annot =
- match annot with
- | Some (env, typ, eff) -> Some (env, typ, union_effects eff fun_eff)
- | None -> None in
- (Bindings.add id fun_eff fun_effs,
- funcls @ [FCL_aux (FCL_Funcl (id, pat, exp), (l, annot))]) in
-
- let rewrite_fundef (fun_effs, FD_aux (FD_function (recopt, tannotopt, effopt, funcls), a)) =
- let (fun_effs, funcls) = List.fold_left rewrite_funcl (fun_effs, []) funcls in
+ | E_aux (E_app (f, _) as exp, (l, Some (env, typ, eff))) ->
+ let vs = find_vs env val_specs f in
+ let env = Env.update_val_spec f vs env in
+ E_aux (exp, (l, Some (env, typ, union_effects eff (eff_of_vs vs))))
+ | e_aux -> e_aux
+ in
+
+ let rewrite_exp val_specs = fold_exp { id_exp_alg with e_aux = e_aux val_specs } in
+
+ let rewrite_funcl (val_specs, funcls) (FCL_aux (FCL_Funcl (id, pat, exp), (l, annot))) =
+ let exp = propagate_exp_effect (rewrite_exp val_specs exp) in
+ let vs, eff = match find_vs (env_of_annot (l, annot)) val_specs id with
+ | (tq, Typ_aux (Typ_fn (args_t, ret_t, eff), a)) ->
+ let eff' = union_effects eff (effect_of exp) in
+ let args_t' = rewrite_typ_nexp_ids (env_of exp) (pat_typ_of pat) in
+ let ret_t' = rewrite_typ_nexp_ids (env_of exp) (typ_of exp) in
+ (tq, Typ_aux (Typ_fn (args_t', ret_t', eff'), a)), eff'
+ in
+ let annot = add_effect_annot annot eff in
+ (Bindings.add id vs val_specs,
+ funcls @ [FCL_aux (FCL_Funcl (id, pat, exp), (l, annot))])
+ in
+
+ let rewrite_fundef (val_specs, FD_aux (FD_function (recopt, tannotopt, effopt, funcls), a)) =
+ let (val_specs, funcls) = List.fold_left rewrite_funcl (val_specs, []) funcls in
+ (* Repeat once to cross-propagate effects between clauses *)
+ let (val_specs, funcls) = List.fold_left rewrite_funcl (val_specs, []) funcls in
let is_funcl_rec (FCL_aux (FCL_Funcl (id, _, exp), _)) =
fst (fold_exp
{ (compute_exp_alg false (||) ) with
@@ -2444,39 +2495,65 @@ let rewrite_fix_fun_effs (Defs defs) =
E_app (f, es)));
e_app_infix = (fun ((r1,e1), f, (r2,e2)) ->
(r1 || r2 || (string_of_id f = string_of_id id),
- E_app_infix (e1, f, e2))) } exp) in
- let is_rec = List.exists is_funcl_rec funcls in
- (* Repeat once for recursive functions:
- propagates union of effects to all clauses *)
- let recopt, (fun_effs, funcls) =
- if is_rec then
- Rec_aux (Rec_rec, Parse_ast.Unknown),
- List.fold_left rewrite_funcl (fun_effs, []) funcls
- else recopt, (fun_effs, funcls) in
- (fun_effs, FD_aux (FD_function (recopt, tannotopt, effopt, funcls), a)) in
-
- let rec rewrite_fundefs (fun_effs, fundefs) =
+ E_app_infix (e1, f, e2))) }
+ exp)
+ in
+ let recopt =
+ if List.exists is_funcl_rec funcls then
+ Rec_aux (Rec_rec, Parse_ast.Unknown)
+ else recopt
+ in
+ (val_specs, FD_aux (FD_function (recopt, tannotopt, effopt, funcls), a)) in
+
+ let rec rewrite_fundefs (val_specs, fundefs) =
match fundefs with
| fundef :: fundefs ->
- let (fun_effs, fundef) = rewrite_fundef (fun_effs, fundef) in
- let (fun_effs, fundefs) = rewrite_fundefs (fun_effs, fundefs) in
- (fun_effs, fundef :: fundefs)
- | [] -> (fun_effs, []) in
-
- let rewrite_def (fun_effs, defs) = function
- | DEF_fundef fundef ->
- let (fun_effs, fundef) = rewrite_fundef (fun_effs, fundef) in
- (fun_effs, defs @ [DEF_fundef fundef])
- | DEF_internal_mutrec fundefs ->
- let (fun_effs, fundefs) = rewrite_fundefs (fun_effs, fundefs) in
- (fun_effs, defs @ [DEF_internal_mutrec fundefs])
- | DEF_val (LB_aux (LB_val (pat, exp), a)) ->
- (fun_effs, defs @ [DEF_val (LB_aux (LB_val (pat, rewrite_exp fun_effs exp), a))])
- | def -> (fun_effs, defs @ [def]) in
+ let (val_specs, fundef) = rewrite_fundef (val_specs, fundef) in
+ let (val_specs, fundefs) = rewrite_fundefs (val_specs, fundefs) in
+ (val_specs, fundef :: fundefs)
+ | [] -> (val_specs, []) in
+
+ let rewrite_def (val_specs, defs) = function
+ | DEF_fundef fundef ->
+ let (val_specs, fundef) = rewrite_fundef (val_specs, fundef) in
+ (val_specs, defs @ [DEF_fundef fundef])
+ | DEF_internal_mutrec fundefs ->
+ let (val_specs, fundefs) = rewrite_fundefs (val_specs, fundefs) in
+ (val_specs, defs @ [DEF_internal_mutrec fundefs])
+ | DEF_val (LB_aux (LB_val (pat, exp), a)) ->
+ (val_specs, defs @ [DEF_val (LB_aux (LB_val (pat, rewrite_exp val_specs exp), a))])
+ | DEF_spec (VS_aux (VS_val_spec (typschm, id, ext_opt, is_cast), a)) ->
+ let typschm, val_specs =
+ if Bindings.mem id val_specs then begin
+ let (tq, typ) = Bindings.find id val_specs in
+ TypSchm_aux (TypSchm_ts (tq, typ), Parse_ast.Unknown), val_specs
+ end else begin
+ let (TypSchm_aux (TypSchm_ts (tq, typ), _)) = typschm in
+ typschm, Bindings.add id (tq, typ) val_specs
+ end
+ in
+ (val_specs, defs @ [DEF_spec (VS_aux (VS_val_spec (typschm, id, ext_opt, is_cast), a))])
+ | def -> (val_specs, defs @ [def])
+ in
+
+ let rewrite_val_specs val_specs = function
+ | DEF_spec (VS_aux (VS_val_spec (typschm, id, ext_opt, is_cast), a))
+ when Bindings.mem id val_specs ->
+ let typschm = match typschm with
+ | TypSchm_aux (TypSchm_ts (tq, typ), l) ->
+ let (tq, typ) = Bindings.find id val_specs in
+ TypSchm_aux (TypSchm_ts (tq, typ), l)
+ in
+ DEF_spec (VS_aux (VS_val_spec (typschm, id, ext_opt, is_cast), a))
+ | def -> def
+ in
+
+ let (val_specs, defs) = List.fold_left rewrite_def (Bindings.empty, []) defs in
+ let defs = List.map (rewrite_val_specs val_specs) defs in
(* if !Type_check.opt_no_effects
then *)
- Defs (snd (List.fold_left rewrite_def (Bindings.empty, []) defs))
+ Defs defs
(* else Defs defs *)
(* Turn constraints into numeric expressions with sizeof *)
@@ -3540,7 +3617,8 @@ let rewrite_defs_lem = [
("guarded_pats", rewrite_defs_guarded_pats);
(* ("recheck_defs", recheck_defs); *)
("early_return", rewrite_defs_early_return);
- ("fix_fun_effs", rewrite_fix_fun_effs);
+ ("nexp_ids", rewrite_defs_nexp_ids);
+ ("fix_val_specs", rewrite_fix_val_specs);
("exp_lift_assign", rewrite_defs_exp_lift_assign);
("remove_blocks", rewrite_defs_remove_blocks);
("letbind_effects", rewrite_defs_letbind_effects);
diff --git a/src/type_check.ml b/src/type_check.ml
index 270b2cf4..751e70e5 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -2858,24 +2858,22 @@ let effect_of_annot = function
let effect_of (E_aux (exp, (l, annot))) = effect_of_annot annot
-let add_effect (E_aux (exp, (l, annot))) eff1 =
- match annot with
- | Some (env, typ, eff2) -> E_aux (exp, (l, Some (env, typ, union_effects eff1 eff2)))
- | None -> assert false
+let add_effect_annot annot eff = match annot with
+ | Some (env, typ, eff') -> Some (env, typ, union_effects eff eff')
+ | None -> None
+
+let add_effect (E_aux (exp, (l, annot))) eff =
+ E_aux (exp, (l, add_effect_annot annot eff))
let effect_of_lexp (LEXP_aux (exp, (l, annot))) = effect_of_annot annot
-let add_effect_lexp (LEXP_aux (lexp, (l, annot))) eff1 =
- match annot with
- | Some (env, typ, eff2) -> LEXP_aux (lexp, (l, Some (env, typ, union_effects eff1 eff2)))
- | None -> assert false
+let add_effect_lexp (LEXP_aux (lexp, (l, annot))) eff =
+ LEXP_aux (lexp, (l, add_effect_annot annot eff))
let effect_of_pat (P_aux (exp, (l, annot))) = effect_of_annot annot
-let add_effect_pat (P_aux (pat, (l, annot))) eff1 =
- match annot with
- | Some (env, typ, eff2) -> P_aux (pat, (l, Some (env, typ, union_effects eff1 eff2)))
- | None -> assert false
+let add_effect_pat (P_aux (pat, (l, annot))) eff =
+ P_aux (pat, (l, add_effect_annot annot eff))
let collect_effects xs = List.fold_left union_effects no_effect (List.map effect_of xs)
diff --git a/src/type_check.mli b/src/type_check.mli
index ff9eb74e..f3e5e861 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -217,6 +217,7 @@ val pat_env_of : tannot pat -> Env.t
val effect_of : tannot exp -> effect
val effect_of_annot : tannot -> effect
+val add_effect_annot : tannot -> effect -> tannot
val destruct_atom_nexp : Env.t -> typ -> nexp option