summaryrefslogtreecommitdiff
path: root/src/ast_util.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/ast_util.ml')
-rw-r--r--src/ast_util.ml162
1 files changed, 64 insertions, 98 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 4a887898..8106f89c 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -108,11 +108,70 @@ let is_typ_kopt = function
| KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_type, _)], _), _), _) -> true
| _ -> false
+let string_of_kid = function
+ | Kid_aux (Var v, _) -> v
+
+module Kid = struct
+ type t = kid
+ let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2)
+end
+
+module Id = struct
+ type t = id
+ let compare id1 id2 =
+ match (id1, id2) with
+ | Id_aux (Id x, _), Id_aux (Id y, _) -> String.compare x y
+ | Id_aux (DeIid x, _), Id_aux (DeIid y, _) -> String.compare x y
+ | Id_aux (Id _, _), Id_aux (DeIid _, _) -> -1
+ | Id_aux (DeIid _, _), Id_aux (Id _, _) -> 1
+end
+
+module Nexp = struct
+ type t = nexp
+ let rec compare (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) =
+ let lex_ord (c1, c2) = if c1 = 0 then c2 else c1 in
+ match nexp1, nexp2 with
+ | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2
+ | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2
+ | Nexp_constant c1, Nexp_constant c2 -> Pervasives.compare c1 c2
+ | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b)
+ | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b)
+ | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) ->
+ lex_ord (compare n1a n2a, compare n1b n2b)
+ | Nexp_exp n1, Nexp_exp n2 -> compare n1 n2
+ | Nexp_neg n1, Nexp_neg n2 -> compare n1 n2
+ | Nexp_constant _, _ -> -1 | _, Nexp_constant _ -> 1
+ | Nexp_id _, _ -> -1 | _, Nexp_id _ -> 1
+ | Nexp_var _, _ -> -1 | _, Nexp_var _ -> 1
+ | Nexp_neg _, _ -> -1 | _, Nexp_neg _ -> 1
+ | Nexp_exp _, _ -> -1 | _, Nexp_exp _ -> 1
+ | Nexp_minus _, _ -> -1 | _, Nexp_minus _ -> 1
+ | Nexp_sum _, _ -> -1 | _, Nexp_sum _ -> 1
+ | Nexp_times _, _ -> -1 | _, Nexp_times _ -> 1
+end
+
+module Bindings = Map.Make(Id)
+module IdSet = Set.Make(Id)
+module KBindings = Map.Make(Kid)
+module KidSet = Set.Make(Kid)
+module NexpSet = Set.Make(Nexp)
+
+let rec nexp_identical nexp1 nexp2 = (Nexp.compare nexp1 nexp2 = 0)
+
+let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with
+ | Nexp_id _ | Nexp_var _ -> false
+ | Nexp_constant _ -> true
+ | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) ->
+ is_nexp_constant n1 && is_nexp_constant n2
+ | Nexp_exp n | Nexp_neg n -> is_nexp_constant n
+
let rec nexp_simp (Nexp_aux (nexp, l)) = Nexp_aux (nexp_simp_aux nexp, l)
and nexp_simp_aux = function
- | Nexp_minus (Nexp_aux (Nexp_sum (Nexp_aux (n1, _), Nexp_aux (Nexp_constant c1, _)), _), Nexp_aux (Nexp_constant c2, _)) when c1 = c2 ->
+ | Nexp_minus (Nexp_aux (Nexp_sum (Nexp_aux (n1, _), nexp2), _), nexp3)
+ when nexp_identical nexp2 nexp3 ->
nexp_simp_aux n1
- | Nexp_sum (Nexp_aux (Nexp_minus (Nexp_aux (n1, _), Nexp_aux (Nexp_constant c1, _)), _), Nexp_aux (Nexp_constant c2, _)) when c1 = c2 ->
+ | Nexp_sum (Nexp_aux (Nexp_minus (Nexp_aux (n1, _), nexp2), _), nexp3)
+ when nexp_identical nexp2 nexp3 ->
nexp_simp_aux n1
| Nexp_sum (n1, n2) ->
begin
@@ -128,6 +187,8 @@ and nexp_simp_aux = function
let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in
let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in
match n1_simp, n2_simp with
+ | Nexp_constant 1, _ -> n2_simp
+ | _, Nexp_constant 1 -> n1_simp
| Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (c1 * c2)
| _, _ -> Nexp_times (n1, n2)
end
@@ -333,9 +394,6 @@ let kid_of_id = function
| Id_aux (Id v, l) -> Kid_aux (Var ("'" ^ v), l)
| Id_aux (DeIid v, _) -> assert false
-let string_of_kid = function
- | Kid_aux (Var v, _) -> v
-
let prepend_id str = function
| Id_aux (Id v, l) -> Id_aux (Id (str ^ v), l)
| Id_aux (DeIid v, l) -> Id_aux (DeIid (str ^ v), l)
@@ -572,56 +630,12 @@ let id_of_fundef (FD_aux (FD_function (_, _, _, funcls), (l, _))) =
| Some id -> id
| None -> raise (Reporting_basic.err_typ l "funcl list is empty")
-module Kid = struct
- type t = kid
- let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2)
-end
-
module BE = struct
type t = base_effect
let compare be1 be2 = String.compare (string_of_base_effect be1) (string_of_base_effect be2)
end
-module Id = struct
- type t = id
- let compare id1 id2 =
- match (id1, id2) with
- | Id_aux (Id x, _), Id_aux (Id y, _) -> String.compare x y
- | Id_aux (DeIid x, _), Id_aux (DeIid y, _) -> String.compare x y
- | Id_aux (Id _, _), Id_aux (DeIid _, _) -> -1
- | Id_aux (DeIid _, _), Id_aux (Id _, _) -> 1
-end
-
-module Nexp = struct
- type t = nexp
- let rec compare (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) =
- let lex_ord (c1, c2) = if c1 = 0 then c2 else c1 in
- match nexp1, nexp2 with
- | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2
- | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2
- | Nexp_constant c1, Nexp_constant c2 -> Pervasives.compare c1 c2
- | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b)
- | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b)
- | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) ->
- lex_ord (compare n1a n2a, compare n1b n2b)
- | Nexp_exp n1, Nexp_exp n2 -> compare n1 n2
- | Nexp_neg n1, Nexp_neg n2 -> compare n1 n2
- | Nexp_constant _, _ -> -1 | _, Nexp_constant _ -> 1
- | Nexp_id _, _ -> -1 | _, Nexp_id _ -> 1
- | Nexp_var _, _ -> -1 | _, Nexp_var _ -> 1
- | Nexp_neg _, _ -> -1 | _, Nexp_neg _ -> 1
- | Nexp_exp _, _ -> -1 | _, Nexp_exp _ -> 1
- | Nexp_minus _, _ -> -1 | _, Nexp_minus _ -> 1
- | Nexp_sum _, _ -> -1 | _, Nexp_sum _ -> 1
- | Nexp_times _, _ -> -1 | _, Nexp_times _ -> 1
-end
-
module BESet = Set.Make(BE)
-module Bindings = Map.Make(Id)
-module IdSet = Set.Make(Id)
-module KBindings = Map.Make(Kid)
-module KidSet = Set.Make(Kid)
-module NexpSet = Set.Make(Nexp)
let rec nexp_frees (Nexp_aux (nexp, l)) =
match nexp with
@@ -634,54 +648,6 @@ let rec nexp_frees (Nexp_aux (nexp, l)) =
| Nexp_exp n -> nexp_frees n
| Nexp_neg n -> nexp_frees n
-let rec nexp_identical nexp1 nexp2 = (Nexp.compare nexp1 nexp2 = 0)
-
-let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with
- | Nexp_id _ | Nexp_var _ -> false
- | Nexp_constant _ -> true
- | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) ->
- is_nexp_constant n1 && is_nexp_constant n2
- | Nexp_exp n | Nexp_neg n -> is_nexp_constant n
-
-let rec simplify_nexp (Nexp_aux (nexp, l)) =
- let rewrap n = Nexp_aux (n, l) in
- let try_binop op n1 n2 c = (match simplify_nexp n1, simplify_nexp n2 with
- | Nexp_aux (Nexp_constant i1, _), Nexp_aux (Nexp_constant i2, _) ->
- rewrap (Nexp_constant (op i1 i2))
- | n1, n2 -> rewrap (c n1 n2)) in
- match nexp with
- | Nexp_times (Nexp_aux (Nexp_constant 1,_),n')
- | Nexp_times (n',Nexp_aux (Nexp_constant 1,_))
- -> simplify_nexp n'
- | Nexp_times (n1, n2) -> try_binop ( * ) n1 n2 (fun n1 n2 -> Nexp_times (n1, n2))
- | Nexp_sum (n1, n2) -> try_binop ( + ) n1 n2 (fun n1 n2 -> Nexp_sum (n1, n2))
- | Nexp_minus (n', Nexp_aux (Nexp_constant 0,_)) -> simplify_nexp n'
- (* A vector range x['n-1 .. 0] can result in the size "('n-1) - -1" *)
- | Nexp_minus (n1, n2) ->
- begin
- match simplify_nexp n1, simplify_nexp n2 with
- | Nexp_aux (Nexp_minus (n', Nexp_aux (Nexp_constant 1,_)),_),
- Nexp_aux (Nexp_constant (-1),_) -> simplify_nexp n'
- | Nexp_aux (Nexp_constant i1,_), Nexp_aux (Nexp_constant i2,_) ->
- rewrap (Nexp_constant (i1-i2))
- | n1',n2' -> rewrap (Nexp_minus (n1,n2))
- end
- (* | Nexp_exp n ->
- (match simplify_nexp n with
- | Nexp_aux (Nexp_constant i, _) ->
- rewrap (Nexp_constant (power 2 i))
- | n -> rewrap (Nexp_exp n)) *)
- | Nexp_neg n ->
- (match simplify_nexp n with
- | Nexp_aux (Nexp_constant i, _) ->
- rewrap (Nexp_constant (-i))
- | n -> rewrap (Nexp_neg n))
- | _ -> rewrap nexp
- (* | Nexp_sum of nexp * nexp (* sum *)
- | Nexp_minus of nexp * nexp (* subtraction *)
- | Nexp_exp of nexp (* exponential *)
- | Nexp_neg of nexp (* For internal use *) *)
-
let rec lexp_to_exp (LEXP_aux (lexp_aux, annot) as le) =
let rewrap e_aux = E_aux (e_aux, annot) in
match lexp_aux with
@@ -726,7 +692,7 @@ let typ_app_args_of = function
let rec vector_typ_args_of typ = match typ_app_args_of typ with
| ("vector", [Typ_arg_nexp start; Typ_arg_nexp len; Typ_arg_order ord; Typ_arg_typ etyp], _) ->
- (simplify_nexp start, simplify_nexp len, ord, etyp)
+ (nexp_simp start, nexp_simp len, ord, etyp)
| ("register", [Typ_arg_typ rtyp], _) -> vector_typ_args_of rtyp
| (_, _, l) ->
raise (Reporting_basic.err_typ l