summaryrefslogtreecommitdiff
path: root/src/rewrites.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewrites.ml')
-rw-r--r--src/rewrites.ml52
1 files changed, 44 insertions, 8 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 6e98abb0..fbaf1234 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -174,14 +174,50 @@ let find_updated_vars exp =
fst (fold_exp
{ (compute_exp_alg IdSet.empty IdSet.union) with lEXP_aux = lEXP_aux } exp)
+let lookup_equal_kids env =
+ let get_eq_kids kid eqs = match KBindings.find_opt kid eqs with
+ | Some kids -> kids
+ | None -> KidSet.singleton kid
+ in
+ let add_eq_kids kid1 kid2 eqs =
+ let kids = KidSet.union (get_eq_kids kid2 eqs) (get_eq_kids kid1 eqs) in
+ eqs
+ |> KBindings.add kid1 kids
+ |> KBindings.add kid2 kids
+ in
+ let add_nc eqs = function
+ | NC_aux (NC_equal (Nexp_aux (Nexp_var kid1, _), Nexp_aux (Nexp_var kid2, _)), _) ->
+ add_eq_kids kid1 kid2 eqs
+ | _ -> eqs
+ in
+ List.fold_left add_nc KBindings.empty (Env.get_constraints env)
+
+let lookup_constant_kid env kid =
+ match KBindings.find_opt kid (lookup_equal_kids env) with
+ | Some kids ->
+ let check_nc const nc = match const, nc with
+ | None, NC_aux (NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant i, _)), _)
+ when KidSet.mem kid kids ->
+ Some i
+ | _, _ -> const
+ in
+ List.fold_left check_nc None (Env.get_constraints env)
+ | None -> None
+
let rec rewrite_nexp_ids env (Nexp_aux (nexp, l) as nexp_aux) = match nexp with
-| Nexp_id id -> rewrite_nexp_ids env (Env.get_num_def id env)
-| Nexp_times (nexp1, nexp2) -> Nexp_aux (Nexp_times (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
-| Nexp_sum (nexp1, nexp2) -> Nexp_aux (Nexp_sum (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
-| Nexp_minus (nexp1, nexp2) -> Nexp_aux (Nexp_minus (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
-| Nexp_exp nexp -> Nexp_aux (Nexp_exp (rewrite_nexp_ids env nexp), l)
-| Nexp_neg nexp -> Nexp_aux (Nexp_neg (rewrite_nexp_ids env nexp), l)
-| _ -> nexp_aux
+ | Nexp_id id -> rewrite_nexp_ids env (Env.get_num_def id env)
+ | Nexp_var kid ->
+ begin
+ match lookup_constant_kid env kid with
+ | Some i -> nconstant i
+ | None -> nexp_aux
+ end
+ | Nexp_times (nexp1, nexp2) -> Nexp_aux (Nexp_times (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
+ | Nexp_sum (nexp1, nexp2) -> Nexp_aux (Nexp_sum (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
+ | Nexp_minus (nexp1, nexp2) -> Nexp_aux (Nexp_minus (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l)
+ | Nexp_exp nexp -> Nexp_aux (Nexp_exp (rewrite_nexp_ids env nexp), l)
+ | Nexp_neg nexp -> Nexp_aux (Nexp_neg (rewrite_nexp_ids env nexp), l)
+ | _ -> nexp_aux
let rewrite_defs_nexp_ids, rewrite_typ_nexp_ids =
let rec rewrite_typ env (Typ_aux (typ, l) as typ_aux) = match typ with
@@ -3097,6 +3133,7 @@ let rewrite_defs_lem = [
("guarded_pats", rewrite_defs_guarded_pats);
("bitvector_exps", rewrite_bitvector_exps);
(* ("register_ref_writes", rewrite_register_ref_writes); *)
+ ("nexp_ids", rewrite_defs_nexp_ids);
("fix_val_specs", rewrite_fix_val_specs);
("split_execute", rewrite_split_fun_constr_pats "execute");
("recheck_defs", recheck_defs);
@@ -3107,7 +3144,6 @@ let rewrite_defs_lem = [
("trivial_sizeof", rewrite_trivial_sizeof);
("sizeof", rewrite_sizeof);
("early_return", rewrite_defs_early_return);
- ("nexp_ids", rewrite_defs_nexp_ids);
("fix_val_specs", rewrite_fix_val_specs);
("remove_blocks", rewrite_defs_remove_blocks);
("letbind_effects", rewrite_defs_letbind_effects);