diff options
Diffstat (limited to 'src/rewrites.ml')
| -rw-r--r-- | src/rewrites.ml | 52 |
1 files changed, 44 insertions, 8 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index 6e98abb0..fbaf1234 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -174,14 +174,50 @@ let find_updated_vars exp = fst (fold_exp { (compute_exp_alg IdSet.empty IdSet.union) with lEXP_aux = lEXP_aux } exp) +let lookup_equal_kids env = + let get_eq_kids kid eqs = match KBindings.find_opt kid eqs with + | Some kids -> kids + | None -> KidSet.singleton kid + in + let add_eq_kids kid1 kid2 eqs = + let kids = KidSet.union (get_eq_kids kid2 eqs) (get_eq_kids kid1 eqs) in + eqs + |> KBindings.add kid1 kids + |> KBindings.add kid2 kids + in + let add_nc eqs = function + | NC_aux (NC_equal (Nexp_aux (Nexp_var kid1, _), Nexp_aux (Nexp_var kid2, _)), _) -> + add_eq_kids kid1 kid2 eqs + | _ -> eqs + in + List.fold_left add_nc KBindings.empty (Env.get_constraints env) + +let lookup_constant_kid env kid = + match KBindings.find_opt kid (lookup_equal_kids env) with + | Some kids -> + let check_nc const nc = match const, nc with + | None, NC_aux (NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant i, _)), _) + when KidSet.mem kid kids -> + Some i + | _, _ -> const + in + List.fold_left check_nc None (Env.get_constraints env) + | None -> None + let rec rewrite_nexp_ids env (Nexp_aux (nexp, l) as nexp_aux) = match nexp with -| Nexp_id id -> rewrite_nexp_ids env (Env.get_num_def id env) -| Nexp_times (nexp1, nexp2) -> Nexp_aux (Nexp_times (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) -| Nexp_sum (nexp1, nexp2) -> Nexp_aux (Nexp_sum (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) -| Nexp_minus (nexp1, nexp2) -> Nexp_aux (Nexp_minus (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) -| Nexp_exp nexp -> Nexp_aux (Nexp_exp (rewrite_nexp_ids env nexp), l) -| Nexp_neg nexp -> Nexp_aux (Nexp_neg (rewrite_nexp_ids env nexp), l) -| _ -> nexp_aux + | Nexp_id id -> rewrite_nexp_ids env (Env.get_num_def id env) + | Nexp_var kid -> + begin + match lookup_constant_kid env kid with + | Some i -> nconstant i + | None -> nexp_aux + end + | Nexp_times (nexp1, nexp2) -> Nexp_aux (Nexp_times (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) + | Nexp_sum (nexp1, nexp2) -> Nexp_aux (Nexp_sum (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) + | Nexp_minus (nexp1, nexp2) -> Nexp_aux (Nexp_minus (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) + | Nexp_exp nexp -> Nexp_aux (Nexp_exp (rewrite_nexp_ids env nexp), l) + | Nexp_neg nexp -> Nexp_aux (Nexp_neg (rewrite_nexp_ids env nexp), l) + | _ -> nexp_aux let rewrite_defs_nexp_ids, rewrite_typ_nexp_ids = let rec rewrite_typ env (Typ_aux (typ, l) as typ_aux) = match typ with @@ -3097,6 +3133,7 @@ let rewrite_defs_lem = [ ("guarded_pats", rewrite_defs_guarded_pats); ("bitvector_exps", rewrite_bitvector_exps); (* ("register_ref_writes", rewrite_register_ref_writes); *) + ("nexp_ids", rewrite_defs_nexp_ids); ("fix_val_specs", rewrite_fix_val_specs); ("split_execute", rewrite_split_fun_constr_pats "execute"); ("recheck_defs", recheck_defs); @@ -3107,7 +3144,6 @@ let rewrite_defs_lem = [ ("trivial_sizeof", rewrite_trivial_sizeof); ("sizeof", rewrite_sizeof); ("early_return", rewrite_defs_early_return); - ("nexp_ids", rewrite_defs_nexp_ids); ("fix_val_specs", rewrite_fix_val_specs); ("remove_blocks", rewrite_defs_remove_blocks); ("letbind_effects", rewrite_defs_letbind_effects); |
