summaryrefslogtreecommitdiff
path: root/src/rewrites.ml
diff options
context:
space:
mode:
authorThomas Bauereiss2020-04-12 20:30:17 +0100
committerThomas Bauereiss2020-04-21 14:02:39 +0100
commitba733dcb489ee6990bc0b5125cb5a99fc3e9b722 (patch)
tree68d738b854b180c7a2d717ad9be9ea7a611d9ac2 /src/rewrites.ml
parentab0fe3e7920d10d7f9ab74649df71deb47dfb97f (diff)
Add rewrite for constant-folding top-level letbindings
This will constant-fold letbindings such as let LOG2_TAG_GRANULE : int(4) = 4 let TAG_GRANULE : int = (1 << LOG2_TAG_GRANULE) which is useful for the translation to Lem if TAG_GRANULE is used in bitvector lengths.
Diffstat (limited to 'src/rewrites.ml')
-rw-r--r--src/rewrites.ml32
1 files changed, 32 insertions, 0 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 7705c75b..3779d4f2 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -4792,6 +4792,36 @@ let rec move_loop_measures (Defs defs) =
in Defs (List.rev rev_defs)
+let rewrite_toplevel_consts target type_env (Defs defs) =
+ let istate = Constant_fold.initial_state (Defs defs) type_env in
+ let subst consts exp =
+ let open Rewriter in
+ let used_ids = fold_exp { (pure_exp_alg IdSet.empty IdSet.union) with e_id = IdSet.singleton } exp in
+ let subst_ids = IdSet.filter (fun id -> Bindings.mem id consts) used_ids in
+ IdSet.fold (fun id -> subst id (Bindings.find id consts)) subst_ids exp
+ in
+ let rewrite_def (revdefs, consts) = function
+ | DEF_val (LB_aux (LB_val (pat, exp), a) as lb) ->
+ begin match unaux_pat pat with
+ | P_id id | P_typ (_, P_aux (P_id id, _)) ->
+ let exp' = Constant_fold.rewrite_exp_once target istate (subst consts exp) in
+ if Constant_fold.is_constant exp' then
+ try
+ let exp' = infer_exp (env_of exp') (strip_exp exp') in
+ let pannot = (pat_loc pat, mk_tannot (env_of_pat pat) (typ_of exp') no_effect) in
+ let pat' = P_aux (P_typ (typ_of exp', P_aux (P_id id, pannot)), pannot) in
+ let consts' = Bindings.add id exp' consts in
+ (DEF_val (LB_aux (LB_val (pat', exp'), a)) :: revdefs, consts')
+ with
+ | _ -> (DEF_val lb :: revdefs, consts)
+ else (DEF_val lb :: revdefs, consts)
+ | _ -> (DEF_val lb :: revdefs, consts)
+ end
+ | def -> (def :: revdefs, consts)
+ in
+ let (revdefs, _) = List.fold_left rewrite_def ([], Bindings.empty) defs in
+ Defs (List.rev revdefs)
+
let opt_mono_rewrites = ref false
let opt_mono_complex_nexps = ref true
@@ -4891,6 +4921,7 @@ let all_rewrites = [
("mapping_builtins", Basic_rewriter rewrite_defs_mapping_patterns);
("mono_rewrites", Basic_rewriter mono_rewrites);
("toplevel_nexps", Basic_rewriter rewrite_toplevel_nexps);
+ ("toplevel_consts", String_rewriter (fun target -> Basic_rewriter (rewrite_toplevel_consts target)));
("monomorphise", String_rewriter (fun target -> Basic_rewriter (monomorphise target)));
("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons));
("add_bitvector_casts", Basic_rewriter Monomorphise.add_bitvector_casts);
@@ -4942,6 +4973,7 @@ let rewrites_lem = [
("mono_rewrites", []);
("recheck_defs", [If_mono_arg]);
("undefined", [Bool_arg false]);
+ ("toplevel_consts", [String_arg "lem"; If_mwords_arg]);
("toplevel_nexps", [If_mono_arg]);
("monomorphise", [String_arg "lem"; If_mono_arg]);
("recheck_defs", [If_mwords_arg]);