diff options
Diffstat (limited to 'src/ast_util.ml')
| -rw-r--r-- | src/ast_util.ml | 162 |
1 files changed, 64 insertions, 98 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 4a887898..8106f89c 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -108,11 +108,70 @@ let is_typ_kopt = function | KOpt_aux (KOpt_kind (K_aux (K_kind [BK_aux (BK_type, _)], _), _), _) -> true | _ -> false +let string_of_kid = function + | Kid_aux (Var v, _) -> v + +module Kid = struct + type t = kid + let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2) +end + +module Id = struct + type t = id + let compare id1 id2 = + match (id1, id2) with + | Id_aux (Id x, _), Id_aux (Id y, _) -> String.compare x y + | Id_aux (DeIid x, _), Id_aux (DeIid y, _) -> String.compare x y + | Id_aux (Id _, _), Id_aux (DeIid _, _) -> -1 + | Id_aux (DeIid _, _), Id_aux (Id _, _) -> 1 +end + +module Nexp = struct + type t = nexp + let rec compare (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = + let lex_ord (c1, c2) = if c1 = 0 then c2 else c1 in + match nexp1, nexp2 with + | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 + | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 + | Nexp_constant c1, Nexp_constant c2 -> Pervasives.compare c1 c2 + | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b) + | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b) + | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> + lex_ord (compare n1a n2a, compare n1b n2b) + | Nexp_exp n1, Nexp_exp n2 -> compare n1 n2 + | Nexp_neg n1, Nexp_neg n2 -> compare n1 n2 + | Nexp_constant _, _ -> -1 | _, Nexp_constant _ -> 1 + | Nexp_id _, _ -> -1 | _, Nexp_id _ -> 1 + | Nexp_var _, _ -> -1 | _, Nexp_var _ -> 1 + | Nexp_neg _, _ -> -1 | _, Nexp_neg _ -> 1 + | Nexp_exp _, _ -> -1 | _, Nexp_exp _ -> 1 + | Nexp_minus _, _ -> -1 | _, Nexp_minus _ -> 1 + | Nexp_sum _, _ -> -1 | _, Nexp_sum _ -> 1 + | Nexp_times _, _ -> -1 | _, Nexp_times _ -> 1 +end + +module Bindings = Map.Make(Id) +module IdSet = Set.Make(Id) +module KBindings = Map.Make(Kid) +module KidSet = Set.Make(Kid) +module NexpSet = Set.Make(Nexp) + +let rec nexp_identical nexp1 nexp2 = (Nexp.compare nexp1 nexp2 = 0) + +let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with + | Nexp_id _ | Nexp_var _ -> false + | Nexp_constant _ -> true + | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> + is_nexp_constant n1 && is_nexp_constant n2 + | Nexp_exp n | Nexp_neg n -> is_nexp_constant n + let rec nexp_simp (Nexp_aux (nexp, l)) = Nexp_aux (nexp_simp_aux nexp, l) and nexp_simp_aux = function - | Nexp_minus (Nexp_aux (Nexp_sum (Nexp_aux (n1, _), Nexp_aux (Nexp_constant c1, _)), _), Nexp_aux (Nexp_constant c2, _)) when c1 = c2 -> + | Nexp_minus (Nexp_aux (Nexp_sum (Nexp_aux (n1, _), nexp2), _), nexp3) + when nexp_identical nexp2 nexp3 -> nexp_simp_aux n1 - | Nexp_sum (Nexp_aux (Nexp_minus (Nexp_aux (n1, _), Nexp_aux (Nexp_constant c1, _)), _), Nexp_aux (Nexp_constant c2, _)) when c1 = c2 -> + | Nexp_sum (Nexp_aux (Nexp_minus (Nexp_aux (n1, _), nexp2), _), nexp3) + when nexp_identical nexp2 nexp3 -> nexp_simp_aux n1 | Nexp_sum (n1, n2) -> begin @@ -128,6 +187,8 @@ and nexp_simp_aux = function let (Nexp_aux (n1_simp, _) as n1) = nexp_simp n1 in let (Nexp_aux (n2_simp, _) as n2) = nexp_simp n2 in match n1_simp, n2_simp with + | Nexp_constant 1, _ -> n2_simp + | _, Nexp_constant 1 -> n1_simp | Nexp_constant c1, Nexp_constant c2 -> Nexp_constant (c1 * c2) | _, _ -> Nexp_times (n1, n2) end @@ -333,9 +394,6 @@ let kid_of_id = function | Id_aux (Id v, l) -> Kid_aux (Var ("'" ^ v), l) | Id_aux (DeIid v, _) -> assert false -let string_of_kid = function - | Kid_aux (Var v, _) -> v - let prepend_id str = function | Id_aux (Id v, l) -> Id_aux (Id (str ^ v), l) | Id_aux (DeIid v, l) -> Id_aux (DeIid (str ^ v), l) @@ -572,56 +630,12 @@ let id_of_fundef (FD_aux (FD_function (_, _, _, funcls), (l, _))) = | Some id -> id | None -> raise (Reporting_basic.err_typ l "funcl list is empty") -module Kid = struct - type t = kid - let compare kid1 kid2 = String.compare (string_of_kid kid1) (string_of_kid kid2) -end - module BE = struct type t = base_effect let compare be1 be2 = String.compare (string_of_base_effect be1) (string_of_base_effect be2) end -module Id = struct - type t = id - let compare id1 id2 = - match (id1, id2) with - | Id_aux (Id x, _), Id_aux (Id y, _) -> String.compare x y - | Id_aux (DeIid x, _), Id_aux (DeIid y, _) -> String.compare x y - | Id_aux (Id _, _), Id_aux (DeIid _, _) -> -1 - | Id_aux (DeIid _, _), Id_aux (Id _, _) -> 1 -end - -module Nexp = struct - type t = nexp - let rec compare (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = - let lex_ord (c1, c2) = if c1 = 0 then c2 else c1 in - match nexp1, nexp2 with - | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 - | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 - | Nexp_constant c1, Nexp_constant c2 -> Pervasives.compare c1 c2 - | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b) - | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b) - | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> - lex_ord (compare n1a n2a, compare n1b n2b) - | Nexp_exp n1, Nexp_exp n2 -> compare n1 n2 - | Nexp_neg n1, Nexp_neg n2 -> compare n1 n2 - | Nexp_constant _, _ -> -1 | _, Nexp_constant _ -> 1 - | Nexp_id _, _ -> -1 | _, Nexp_id _ -> 1 - | Nexp_var _, _ -> -1 | _, Nexp_var _ -> 1 - | Nexp_neg _, _ -> -1 | _, Nexp_neg _ -> 1 - | Nexp_exp _, _ -> -1 | _, Nexp_exp _ -> 1 - | Nexp_minus _, _ -> -1 | _, Nexp_minus _ -> 1 - | Nexp_sum _, _ -> -1 | _, Nexp_sum _ -> 1 - | Nexp_times _, _ -> -1 | _, Nexp_times _ -> 1 -end - module BESet = Set.Make(BE) -module Bindings = Map.Make(Id) -module IdSet = Set.Make(Id) -module KBindings = Map.Make(Kid) -module KidSet = Set.Make(Kid) -module NexpSet = Set.Make(Nexp) let rec nexp_frees (Nexp_aux (nexp, l)) = match nexp with @@ -634,54 +648,6 @@ let rec nexp_frees (Nexp_aux (nexp, l)) = | Nexp_exp n -> nexp_frees n | Nexp_neg n -> nexp_frees n -let rec nexp_identical nexp1 nexp2 = (Nexp.compare nexp1 nexp2 = 0) - -let rec is_nexp_constant (Nexp_aux (nexp, _)) = match nexp with - | Nexp_id _ | Nexp_var _ -> false - | Nexp_constant _ -> true - | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> - is_nexp_constant n1 && is_nexp_constant n2 - | Nexp_exp n | Nexp_neg n -> is_nexp_constant n - -let rec simplify_nexp (Nexp_aux (nexp, l)) = - let rewrap n = Nexp_aux (n, l) in - let try_binop op n1 n2 c = (match simplify_nexp n1, simplify_nexp n2 with - | Nexp_aux (Nexp_constant i1, _), Nexp_aux (Nexp_constant i2, _) -> - rewrap (Nexp_constant (op i1 i2)) - | n1, n2 -> rewrap (c n1 n2)) in - match nexp with - | Nexp_times (Nexp_aux (Nexp_constant 1,_),n') - | Nexp_times (n',Nexp_aux (Nexp_constant 1,_)) - -> simplify_nexp n' - | Nexp_times (n1, n2) -> try_binop ( * ) n1 n2 (fun n1 n2 -> Nexp_times (n1, n2)) - | Nexp_sum (n1, n2) -> try_binop ( + ) n1 n2 (fun n1 n2 -> Nexp_sum (n1, n2)) - | Nexp_minus (n', Nexp_aux (Nexp_constant 0,_)) -> simplify_nexp n' - (* A vector range x['n-1 .. 0] can result in the size "('n-1) - -1" *) - | Nexp_minus (n1, n2) -> - begin - match simplify_nexp n1, simplify_nexp n2 with - | Nexp_aux (Nexp_minus (n', Nexp_aux (Nexp_constant 1,_)),_), - Nexp_aux (Nexp_constant (-1),_) -> simplify_nexp n' - | Nexp_aux (Nexp_constant i1,_), Nexp_aux (Nexp_constant i2,_) -> - rewrap (Nexp_constant (i1-i2)) - | n1',n2' -> rewrap (Nexp_minus (n1,n2)) - end - (* | Nexp_exp n -> - (match simplify_nexp n with - | Nexp_aux (Nexp_constant i, _) -> - rewrap (Nexp_constant (power 2 i)) - | n -> rewrap (Nexp_exp n)) *) - | Nexp_neg n -> - (match simplify_nexp n with - | Nexp_aux (Nexp_constant i, _) -> - rewrap (Nexp_constant (-i)) - | n -> rewrap (Nexp_neg n)) - | _ -> rewrap nexp - (* | Nexp_sum of nexp * nexp (* sum *) - | Nexp_minus of nexp * nexp (* subtraction *) - | Nexp_exp of nexp (* exponential *) - | Nexp_neg of nexp (* For internal use *) *) - let rec lexp_to_exp (LEXP_aux (lexp_aux, annot) as le) = let rewrap e_aux = E_aux (e_aux, annot) in match lexp_aux with @@ -726,7 +692,7 @@ let typ_app_args_of = function let rec vector_typ_args_of typ = match typ_app_args_of typ with | ("vector", [Typ_arg_nexp start; Typ_arg_nexp len; Typ_arg_order ord; Typ_arg_typ etyp], _) -> - (simplify_nexp start, simplify_nexp len, ord, etyp) + (nexp_simp start, nexp_simp len, ord, etyp) | ("register", [Typ_arg_typ rtyp], _) -> vector_typ_args_of rtyp | (_, _, l) -> raise (Reporting_basic.err_typ l |
