summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/ast_util.ml162
-rw-r--r--src/ast_util.mli1
-rw-r--r--src/pretty_print_lem.ml36
-rw-r--r--src/rewriter.ml8
4 files changed, 86 insertions, 121 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 4a887898..8106f89c 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -108,11 +108,70 @@ let is_typ_kopt = function
| KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_type, _)], _), _), _) -> true
| _ -> false
+let string_of_kid = function
+ | Kid_aux (Var v, _) -> v
+
+module Kid = struct
+ type t = kid
+ let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2)
+end
+
+module Id = struct
+ type t = id
+ let compare id1 id2 =
+ match (id1, id2) with
+ | Id_aux (Id x, _), Id_aux (Id y, _) -> String.compare x y
+ | Id_aux (DeIid x, _), Id_aux (DeIid y, _) -> String.compare x y
+ | Id_aux (Id _, _), Id_aux (DeIid _, _) -> -1
+ | Id_aux (DeIid _, _), Id_aux (Id _, _) -> 1
+end
+
+module Nexp = struct
+ type t = nexp
+ let rec compare (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) =
+ let lex_ord (c1, c2) = if c1 = 0 then c2 else c1 in
+ match nexp1, nexp2 with
+ | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2
+ | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2
+ | Nexp_constant c1, Nexp_constant c2 -> Pervasives.compare c1 c2
+ | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b)
+ | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b)
+ | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) ->
+ lex_ord (compare n1a n2a, compare n1b n2b)
+ | Nexp_exp n1, Nexp_exp n2 -> compare n1 n2
+ | Nexp_neg n1, Nexp_neg n2 -> compare n1 n2
+ | Nexp_constant _, _ -> -1 | _, Nexp_constant _ -> 1
+ | Nexp_id _, _ -> -1 | _, Nexp_id _ -> 1
+ | Nexp_var _, _ -> -1 | _, Nexp_var _ -> 1
+ | Nexp_neg _, _ -> -1 | _, Nexp_neg _ -> 1
+ | Nexp_exp _, _ -> -1 | _, Nexp_exp _ -> 1
+ | Nexp_minus _, _ -> -1 | _, Nexp_minus _ -> 1
+ | Nexp_sum _, _ -> -1 | _, Nexp_sum _ -> 1
+ | Nexp_times _, _ -> -1 | _, Nexp_times _ -> 1
+end
+
+module Bindings = Map.Make(Id)
+module IdSet = Set.Make(Id)
+module KBindings = Map.Make(Kid)
+module KidSet = Set.Make(Kid)
+module NexpSet = Set.Make(Nexp)
+
+let rec nexp_identical nexp1 nexp2 = (Nexp.compare nexp1 nexp2 = 0)
+
+let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with
+ | Nexp_id _ | Nexp_var _ -> false
+ | Nexp_constant _ -> true
+ | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) ->
+ is_nexp_constant n1 && is_nexp_constant n2
+ | Nexp_exp n | Nexp_neg n -> is_nexp_constant n
+
let rec nexp_simp (Nexp_aux (nexp, l)) = Nexp_aux (nexp_simp_aux nexp, l)
and nexp_simp_aux = function
- | Nexp_minus (Nexp_aux (Nexp_sum (Nexp_aux (n1, _), Nexp_aux (Nexp_constant c1, _)), _), Nexp_aux (Nexp_constant c2, _)) when c1 = c2 ->
+ | Nexp_minus (Nexp_aux (Nexp_sum (Nexp_aux (n1, _), nexp2), _), nexp3)
+ when nexp_identical nexp2 nexp3 ->
nexp_simp_aux n1
- | Nexp_sum (Nexp_aux (Nexp_minus (Nexp_aux (n1, _), Nexp_aux (Nexp_constant c1, _)), _), Nexp_aux (Nexp_constant c2, _)) when c1 = c2 ->
+ | Nexp_sum (Nexp_aux (Nexp_minus (Nexp_aux (n1, _), nexp2), _), nexp3)
+ when nexp_identical nexp2 nexp3 ->
nexp_simp_aux n1
| Nexp_sum (n1, n2) ->
begin
@@ -128,6 +187,8 @@ and nexp_simp_aux = function
let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in
let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in
match n1_simp, n2_simp with
+ | Nexp_constant 1, _ -> n2_simp
+ | _, Nexp_constant 1 -> n1_simp
| Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (c1 * c2)
| _, _ -> Nexp_times (n1, n2)
end
@@ -333,9 +394,6 @@ let kid_of_id = function
| Id_aux (Id v, l) -> Kid_aux (Var ("'" ^ v), l)
| Id_aux (DeIid v, _) -> assert false
-let string_of_kid = function
- | Kid_aux (Var v, _) -> v
-
let prepend_id str = function
| Id_aux (Id v, l) -> Id_aux (Id (str ^ v), l)
| Id_aux (DeIid v, l) -> Id_aux (DeIid (str ^ v), l)
@@ -572,56 +630,12 @@ let id_of_fundef (FD_aux (FD_function (_, _, _, funcls), (l, _))) =
| Some id -> id
| None -> raise (Reporting_basic.err_typ l "funcl list is empty")
-module Kid = struct
- type t = kid
- let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2)
-end
-
module BE = struct
type t = base_effect
let compare be1 be2 = String.compare (string_of_base_effect be1) (string_of_base_effect be2)
end
-module Id = struct
- type t = id
- let compare id1 id2 =
- match (id1, id2) with
- | Id_aux (Id x, _), Id_aux (Id y, _) -> String.compare x y
- | Id_aux (DeIid x, _), Id_aux (DeIid y, _) -> String.compare x y
- | Id_aux (Id _, _), Id_aux (DeIid _, _) -> -1
- | Id_aux (DeIid _, _), Id_aux (Id _, _) -> 1
-end
-
-module Nexp = struct
- type t = nexp
- let rec compare (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) =
- let lex_ord (c1, c2) = if c1 = 0 then c2 else c1 in
- match nexp1, nexp2 with
- | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2
- | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2
- | Nexp_constant c1, Nexp_constant c2 -> Pervasives.compare c1 c2
- | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b)
- | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b)
- | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) ->
- lex_ord (compare n1a n2a, compare n1b n2b)
- | Nexp_exp n1, Nexp_exp n2 -> compare n1 n2
- | Nexp_neg n1, Nexp_neg n2 -> compare n1 n2
- | Nexp_constant _, _ -> -1 | _, Nexp_constant _ -> 1
- | Nexp_id _, _ -> -1 | _, Nexp_id _ -> 1
- | Nexp_var _, _ -> -1 | _, Nexp_var _ -> 1
- | Nexp_neg _, _ -> -1 | _, Nexp_neg _ -> 1
- | Nexp_exp _, _ -> -1 | _, Nexp_exp _ -> 1
- | Nexp_minus _, _ -> -1 | _, Nexp_minus _ -> 1
- | Nexp_sum _, _ -> -1 | _, Nexp_sum _ -> 1
- | Nexp_times _, _ -> -1 | _, Nexp_times _ -> 1
-end
-
module BESet = Set.Make(BE)
-module Bindings = Map.Make(Id)
-module IdSet = Set.Make(Id)
-module KBindings = Map.Make(Kid)
-module KidSet = Set.Make(Kid)
-module NexpSet = Set.Make(Nexp)
let rec nexp_frees (Nexp_aux (nexp, l)) =
match nexp with
@@ -634,54 +648,6 @@ let rec nexp_frees (Nexp_aux (nexp, l)) =
| Nexp_exp n -> nexp_frees n
| Nexp_neg n -> nexp_frees n
-let rec nexp_identical nexp1 nexp2 = (Nexp.compare nexp1 nexp2 = 0)
-
-let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with
- | Nexp_id _ | Nexp_var _ -> false
- | Nexp_constant _ -> true
- | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) ->
- is_nexp_constant n1 && is_nexp_constant n2
- | Nexp_exp n | Nexp_neg n -> is_nexp_constant n
-
-let rec simplify_nexp (Nexp_aux (nexp, l)) =
- let rewrap n = Nexp_aux (n, l) in
- let try_binop op n1 n2 c = (match simplify_nexp n1, simplify_nexp n2 with
- | Nexp_aux (Nexp_constant i1, _), Nexp_aux (Nexp_constant i2, _) ->
- rewrap (Nexp_constant (op i1 i2))
- | n1, n2 -> rewrap (c n1 n2)) in
- match nexp with
- | Nexp_times (Nexp_aux (Nexp_constant 1,_),n')
- | Nexp_times (n',Nexp_aux (Nexp_constant 1,_))
- -> simplify_nexp n'
- | Nexp_times (n1, n2) -> try_binop ( * ) n1 n2 (fun n1 n2 -> Nexp_times (n1, n2))
- | Nexp_sum (n1, n2) -> try_binop ( + ) n1 n2 (fun n1 n2 -> Nexp_sum (n1, n2))
- | Nexp_minus (n', Nexp_aux (Nexp_constant 0,_)) -> simplify_nexp n'
- (* A vector range x['n-1 .. 0] can result in the size "('n-1) - -1" *)
- | Nexp_minus (n1, n2) ->
- begin
- match simplify_nexp n1, simplify_nexp n2 with
- | Nexp_aux (Nexp_minus (n', Nexp_aux (Nexp_constant 1,_)),_),
- Nexp_aux (Nexp_constant (-1),_) -> simplify_nexp n'
- | Nexp_aux (Nexp_constant i1,_), Nexp_aux (Nexp_constant i2,_) ->
- rewrap (Nexp_constant (i1-i2))
- | n1',n2' -> rewrap (Nexp_minus (n1,n2))
- end
- (* | Nexp_exp n ->
- (match simplify_nexp n with
- | Nexp_aux (Nexp_constant i, _) ->
- rewrap (Nexp_constant (power 2 i))
- | n -> rewrap (Nexp_exp n)) *)
- | Nexp_neg n ->
- (match simplify_nexp n with
- | Nexp_aux (Nexp_constant i, _) ->
- rewrap (Nexp_constant (-i))
- | n -> rewrap (Nexp_neg n))
- | _ -> rewrap nexp
- (* | Nexp_sum of nexp * nexp (* sum *)
- | Nexp_minus of nexp * nexp (* subtraction *)
- | Nexp_exp of nexp (* exponential *)
- | Nexp_neg of nexp (* For internal use *) *)
-
let rec lexp_to_exp (LEXP_aux (lexp_aux, annot) as le) =
let rewrap e_aux = E_aux (e_aux, annot) in
match lexp_aux with
@@ -726,7 +692,7 @@ let typ_app_args_of = function
let rec vector_typ_args_of typ = match typ_app_args_of typ with
| ("vector", [Typ_arg_nexp start; Typ_arg_nexp len; Typ_arg_order ord; Typ_arg_typ etyp], _) ->
- (simplify_nexp start, simplify_nexp len, ord, etyp)
+ (nexp_simp start, nexp_simp len, ord, etyp)
| ("register", [Typ_arg_typ rtyp], _) -> vector_typ_args_of rtyp
| (_, _, l) ->
raise (Reporting_basic.err_typ l
diff --git a/src/ast_util.mli b/src/ast_util.mli
index d497a687..1a056e58 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -224,7 +224,6 @@ end
val nexp_frees : nexp -> KidSet.t
val nexp_identical : nexp -> nexp -> bool
val is_nexp_constant : nexp -> bool
-val simplify_nexp : nexp -> nexp
val lexp_to_exp : 'a lexp -> 'a exp
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index a0a4878b..7d0d3459 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -199,12 +199,12 @@ let rec lem_nexps_of_typ sequential mwords (Typ_aux (t,_)) =
Typ_arg_aux (Typ_arg_nexp m, _);
Typ_arg_aux (Typ_arg_order ord, _);
Typ_arg_aux (Typ_arg_typ elem_typ, _)]) ->
- let m = simplify_nexp m in
+ let m = nexp_simp m in
if mwords && is_bit_typ elem_typ && not (is_nexp_constant m) then
NexpSet.singleton (orig_nexp m)
else trec elem_typ
(* NexpSet.union
- (if mwords then tyvars_of_nexp (simplify_nexp m) else NexpSet.empty)
+ (if mwords then tyvars_of_nexp (nexp_simp m) else NexpSet.empty)
(trec elem_typ) *)
| Typ_app(Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ etyp, _)]) ->
if sequential then trec etyp else NexpSet.empty
@@ -219,7 +219,7 @@ let rec lem_nexps_of_typ sequential mwords (Typ_aux (t,_)) =
List.fold_left (fun s k -> NexpSet.remove k s) s (List.map nvar kids)
and lem_nexps_of_typ_arg sequential mwords (Typ_arg_aux (ta,_)) =
match ta with
- | Typ_arg_nexp nexp -> NexpSet.singleton (orig_nexp (simplify_nexp nexp))
+ | Typ_arg_nexp nexp -> NexpSet.singleton (nexp_simp (orig_nexp nexp))
| Typ_arg_typ typ -> lem_nexps_of_typ sequential mwords typ
| Typ_arg_order _ -> NexpSet.empty
@@ -256,8 +256,8 @@ let doc_typ_lem, doc_atomic_typ_lem =
Typ_arg_aux (Typ_arg_typ elem_typ, _)]) ->
let tpp = match elem_typ with
| Typ_aux (Typ_id (Id_aux (Id "bit",_)),_) when mwords ->
- string "bitvector " ^^ doc_nexp_lem (simplify_nexp m)
- (* (match simplify_nexp m with
+ string "bitvector " ^^ doc_nexp_lem (nexp_simp m)
+ (* (match nexp_simp m with
| (Nexp_aux(Nexp_constant i,_)) -> string "bitvector ty" ^^ doc_int i
| (Nexp_aux(Nexp_var _, _)) -> separate space [string "bitvector"; doc_nexp m]
| _ -> raise (Reporting_basic.err_unreachable l
@@ -308,7 +308,7 @@ let doc_typ_lem, doc_atomic_typ_lem =
end
and doc_typ_arg_lem sequential mwords (Typ_arg_aux(t,_)) = match t with
| Typ_arg_typ t -> app_typ sequential mwords true t
- | Typ_arg_nexp n -> doc_nexp_lem (simplify_nexp n)
+ | Typ_arg_nexp n -> doc_nexp_lem (nexp_simp n)
| Typ_arg_order o -> empty
in typ', atomic_typ
@@ -329,11 +329,11 @@ let rec contains_t_pp_var (Typ_aux (t,a) as typ) = match t with
if Ast_util.is_number typ then false
else if is_bitvector_typ typ then
let (_,length,_,_) = vector_typ_args_of typ in
- not (is_nexp_constant (simplify_nexp length))
+ not (is_nexp_constant (nexp_simp length))
else List.exists contains_t_arg_pp_var targs
and contains_t_arg_pp_var (Typ_arg_aux (targ, _)) = match targ with
| Typ_arg_typ t -> contains_t_pp_var t
- | Typ_arg_nexp nexp -> not (is_nexp_constant (simplify_nexp nexp))
+ | Typ_arg_nexp nexp -> not (is_nexp_constant (nexp_simp nexp))
| _ -> false
let doc_tannot_lem sequential mwords eff typ =
@@ -426,8 +426,8 @@ let rec typeclass_nexps (Typ_aux(t,_)) = match t with
| Typ_app (Id_aux (Id "vector",_),
[_;Typ_arg_aux (Typ_arg_nexp size_nexp,_);
_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
- if is_nexp_constant (simplify_nexp size_nexp) then NexpSet.empty else
- NexpSet.singleton (orig_nexp size_nexp)
+ if is_nexp_constant (nexp_simp size_nexp) then NexpSet.empty else
+ NexpSet.singleton (nexp_simp (orig_nexp size_nexp))
| Typ_app _ -> NexpSet.empty
| Typ_exist (kids,_,t) -> NexpSet.empty (* todo *)
@@ -748,7 +748,7 @@ let doc_exp_lem, doc_let_lem =
let ord_suffix = if is_order_inc ord then "_inc" else "_dec" in
let bit_prefix = if is_bitvector_typ vtyp then "bit" else "" in
let call = bit_prefix ^ "vector_access" ^ ord_suffix in
- let start_idx = match simplify_nexp (start) with
+ let start_idx = match nexp_simp (start) with
| Nexp_aux (Nexp_constant i, _) -> expN (simple_num l i)
| _ ->
let nc = nc_eq start (nminus len (nconstant 1)) in
@@ -776,7 +776,7 @@ let doc_exp_lem, doc_let_lem =
let ord_suffix = if is_order_inc ord then "_inc" else "_dec" in
let bit_prefix = if is_bitvector_typ vtyp then "bit" else "" in
let call = bit_prefix ^ "vector_subrange" ^ ord_suffix in
- let start_idx = match simplify_nexp (start) with
+ let start_idx = match nexp_simp (start) with
| Nexp_aux (Nexp_constant i, _) -> expN (simple_num l i)
| _ ->
let nc = nc_eq start (nminus len (nconstant 1)) in
@@ -942,7 +942,7 @@ let doc_exp_lem, doc_let_lem =
| Tapp("vector", [TA_nexp start; TA_nexp len; TA_ord order; TA_typ etyp])
| Tabbrev(_,{t= Tapp("vector", [TA_nexp start; TA_nexp len; TA_ord order; TA_typ etyp])}) ->*)
let dir,dir_out = if is_order_inc order then (true,"true") else (false, "false") in
- let start = match simplify_nexp start with
+ let start = match nexp_simp start with
| Nexp_aux (Nexp_constant i, _) -> string_of_int i
| _ -> if dir then "0" else string_of_int (List.length exps) in
let expspp =
@@ -972,7 +972,7 @@ let doc_exp_lem, doc_let_lem =
let ord_suffix = if is_order_inc ord then "_inc" else "_dec" in
let bit_prefix = if is_bitvector_typ t then "bit" else "" in
let call = bit_prefix ^ "vector_update_pos" ^ ord_suffix in
- let start_idx = match simplify_nexp (start) with
+ let start_idx = match nexp_simp (start) with
| Nexp_aux (Nexp_constant i, _) -> expN (simple_num l i)
| _ ->
let nc = nc_eq start (nminus len (nconstant 1)) in
@@ -989,7 +989,7 @@ let doc_exp_lem, doc_let_lem =
let ord_suffix = if is_order_inc ord then "_inc" else "_dec" in
let bit_prefix = if is_bitvector_typ t then "bit" else "" in
let call = bit_prefix ^ "vector_update_subrange" ^ ord_suffix in
- let start_idx = match simplify_nexp (start) with
+ let start_idx = match nexp_simp (start) with
| Nexp_aux (Nexp_constant i, _) -> expN (simple_num l i)
| _ ->
let nc = nc_eq start (nminus len (nconstant 1)) in
@@ -1146,7 +1146,7 @@ let doc_exp_lem, doc_let_lem =
| E_internal_return (e1) ->
separate space [string "return"; expY e1;]
| E_sizeof nexp ->
- (match simplify_nexp nexp with
+ (match nexp_simp nexp with
| Nexp_aux (Nexp_constant i, _) -> doc_lit_lem sequential mwords false (L_aux (L_num i, l)) annot
| _ ->
raise (Reporting_basic.err_unreachable l
@@ -1249,7 +1249,7 @@ let doc_typdef_lem sequential mwords (TD_aux(td, (l, annot))) = match td with
let (start, is_inc) =
try
let (start, _, ord, _) = vector_typ_args_of base_ftyp in
- match simplify_nexp start with
+ match nexp_simp start with
| Nexp_aux (Nexp_constant i, _) -> (i, is_order_inc ord)
| _ ->
raise (Reporting_basic.err_unreachable Parse_ast.Unknown
@@ -1761,7 +1761,7 @@ let doc_register_refs_lem registers =
let (start, is_inc) =
try
let (start, _, ord, _) = vector_typ_args_of base_typ in
- match simplify_nexp start with
+ match nexp_simp start with
| Nexp_aux (Nexp_constant i, _) -> (i, is_order_inc ord)
| _ ->
raise (Reporting_basic.err_unreachable Parse_ast.Unknown
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 5329b01d..99f7b15f 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -1070,7 +1070,7 @@ let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp =
E_aux (E_lit (L_aux (L_num c, l)), (l, Some (env, atom_typ nexp, no_effect)))
| E_sizeof nexp ->
begin
- match simplify_nexp (rewrite_nexp_ids (env_of orig_exp) nexp) with
+ match nexp_simp (rewrite_nexp_ids (env_of orig_exp) nexp) with
| Nexp_aux (Nexp_constant c, _) ->
E_aux (E_lit (L_aux (L_num c, l)), (l, Some (env, atom_typ nexp, no_effect)))
| _ ->
@@ -1142,7 +1142,7 @@ let rewrite_sizeof (Defs defs) =
Id_aux (Id op, Parse_ast.Unknown),
E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2))
) in
- let (Nexp_aux (nexp, l) as nexp_aux) = simplify_nexp nexp_aux in
+ let (Nexp_aux (nexp, l) as nexp_aux) = nexp_simp nexp_aux in
(match nexp with
| Nexp_constant i -> E_lit (L_aux (L_num i, l))
| Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2
@@ -2735,7 +2735,7 @@ let rewrite_tuple_vector_assignments defs =
let ltyp = Env.base_typ_of env (typ_of_annot lannot) in
if is_vector_typ ltyp then
let (_, len, _, _) = vector_typ_args_of ltyp in
- match simplify_nexp len with
+ match nexp_simp len with
| Nexp_aux (Nexp_constant len, _) -> len
| _ -> 1
else 1 in
@@ -2743,7 +2743,7 @@ let rewrite_tuple_vector_assignments defs =
if is_order_inc ord
then (i + step - 1, i + step)
else (i - step + 1, i - step) in
- let i = match simplify_nexp start with
+ let i = match nexp_simp start with
| (Nexp_aux (Nexp_constant i, _)) -> i
| _ -> if is_order_inc ord then 0 else List.length lexps - 1 in
let l = gen_loc (fst annot) in