summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorChristopher Pulte2015-11-19 14:37:19 +0000
committerChristopher Pulte2015-11-19 14:37:19 +0000
commita1d41f415a555bbe31e86375601e75f8ecf37f54 (patch)
treea404c7bd198763b1ffa9b3048a7419ea3ddefe4d /src
parent3323f7a685f0aa7d125a9f348112b6e25fb392ae (diff)
fixes for cumulative effect anotations
Diffstat (limited to 'src')
-rw-r--r--src/gen_lib/sail_values.lem107
-rw-r--r--src/gen_lib/state.lem147
-rw-r--r--src/gen_lib/vector.lem67
-rw-r--r--src/pretty_print.ml248
-rw-r--r--src/rewriter.ml137
5 files changed, 405 insertions, 301 deletions
diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem
index 2104d072..4d74976b 100644
--- a/src/gen_lib/sail_values.lem
+++ b/src/gen_lib/sail_values.lem
@@ -3,46 +3,7 @@ open import State
open import Vector
open import Arch
-let to_bool = function
- | O -> false
- | I -> true
- end
-
-let get_start (V _ s _) = s
-let length (V bs _ _) = length bs
-
-let write_two_regs r1 r2 vec =
- let size = length_reg r1 in
- let start = get_start vec in
- let vsize = length vec in
- let r1_v = slice vec start ((if defaultDir then size - start else start - size) - 1) in
- let r2_v =
- (slice vec)
- (if defaultDir then size - start else start - size)
- (if defaultDir then vsize - start else start - vsize) in
- write_reg r1 r1_v >> write_reg r2 r2_v
-
-let rec replace bs ((n : nat),b') = match (n,bs) with
- | (_, []) -> []
- | (0, _::bs) -> b' :: bs
- | (n+1, b::bs) -> b :: replace bs (n,b')
- end
-
-let make_indexed_vector_reg entries default start length =
- let (Just v) = default in
- V (List.foldl replace (replicate length v) entries) start
-
-let make_indexed_vector_bit entries default start length =
- let default = match default with Nothing -> U | Just v -> v end in
- V (List.foldl replace (replicate length default) entries) start
-
-let make_bitvector_undef length =
- V (replicate length U) 0 true
-
-let vector_concat (V bs start is_inc) (V bs' _ _) =
- V(bs ++ bs') start is_inc
-
-let (^^) = vector_concat
+let length l = integerFromNat (length l)
let has_undef (V bs _ _) = List.any (function U -> true | _ -> false end) bs
@@ -100,7 +61,7 @@ let unsigned (V bs _ _ as v) : integer =
(fun (acc,exp) b -> (acc + (if b = I then integerPow 2 exp else 0),exp +1)) (0,0) bs)
end
-let signed v =
+let signed v : integer =
match most_significant v with
| I -> 0 - (1 + (unsigned (bitwise_not v)))
| O -> unsigned v
@@ -119,43 +80,43 @@ let min_8 = (0 - 128 : integer)
let max_5 = (31 : integer)
let min_5 = (0 - 32 : integer)
-let get_max_representable_in sign n =
+let get_max_representable_in sign (n : integer) : integer =
if (n = 64) then match sign with | true -> max_64 | false -> max_64u end
else if (n=32) then match sign with | true -> max_32 | false -> max_32u end
else if (n=8) then max_8
else if (n=5) then max_5
- else match sign with | true -> integerPow 2 (n -1)
- | false -> integerPow 2 n
+ else match sign with | true -> integerPow 2 ((natFromInteger n) -1)
+ | false -> integerPow 2 (natFromInteger n)
end
-let get_min_representable_in _ n =
+let get_min_representable_in _ (n : integer) : integer =
if n = 64 then min_64
else if n = 32 then min_32
else if n = 8 then min_8
else if n = 5 then min_5
- else 0 - (integerPow 2 n)
+ else 0 - (integerPow 2 (natFromInteger n))
-let rec divide_by_2 bs i (n : integer) =
+let rec divide_by_2 bs (i : integer) (n : integer) =
if i < 0 || n = 0
then bs
else
if (n mod 2 = 1)
- then divide_by_2 (replace bs (i,I)) (i - 1) (n / 2)
+ then divide_by_2 (replace bs (natFromInteger i,I)) (i - 1) (n / 2)
else divide_by_2 bs (i-1) (n div 2)
-let rec add_one_bit bs co i =
+let rec add_one_bit bs co (i : integer) =
if i < 0 then bs
else match (nth bs i,co) with
- | (O,false) -> replace bs (i,I)
- | (O,true) -> add_one_bit (replace bs (i,I)) true (i-1)
- | (I,false) -> add_one_bit (replace bs (i,O)) true (i-1)
+ | (O,false) -> replace bs (natFromInteger i,I)
+ | (O,true) -> add_one_bit (replace bs (natFromInteger i,I)) true (i-1)
+ | (I,false) -> add_one_bit (replace bs (natFromInteger i,O)) true (i-1)
| (I,true) -> add_one_bit bs true (i-1)
(* | Vundef,_ -> assert false*)
end
-let to_vec is_inc (len,(n : integer)) =
- let bs = List.replicate len O in
+let to_vec is_inc ((len : integer),(n : integer)) =
+ let bs = List.replicate (natFromInteger len) O in
let start = if is_inc then 0 else len-1 in
if n = 0 then
V bs start is_inc
@@ -169,8 +130,11 @@ let to_vec is_inc (len,(n : integer)) =
let to_vec_inc = to_vec true
let to_vec_dec = to_vec false
-let to_vec_undef is_inc len =
- V (replicate len U) (if is_inc then 0 else len-1) is_inc
+let to_vec_undef is_inc (len : integer) =
+ V (replicate (natFromInteger len) U) (if is_inc then 0 else len-1) is_inc
+
+let to_vec_inc_undef = to_vec_undef true
+let to_vec_dec_undef = to_vec_undef false
let add = uncurry integerAdd
let add_signed = uncurry integerAdd
@@ -180,7 +144,7 @@ let modulo = uncurry integerMod
let quot = uncurry integerDiv
let power = uncurry integerPow
-let arith_op_vec op sign size ((V _ _ is_inc as l),r) =
+let arith_op_vec op sign (size : integer) ((V _ _ is_inc as l),r) =
let (l',r') = (to_num sign l, to_num sign r) in
let n = op l' r' in
to_vec is_inc (size * (length l),n)
@@ -228,7 +192,7 @@ let arith_op_vec_vec_range op sign ((V _ _ is_inc as l),r) =
let add_vec_vec_range = arith_op_vec_vec_range integerAdd false
let add_vec_vec_range_signed = arith_op_vec_vec_range integerAdd true
-let arith_op_vec_bit op sign (size : nat) ((V _ _ is_inc as l),r) =
+let arith_op_vec_bit op sign (size : integer) ((V _ _ is_inc as l),r) =
let l' = to_num sign l in
let n = op l' match r with | I -> (1 : integer) | _ -> 0 end in
to_vec is_inc (length l * size,n)
@@ -260,7 +224,7 @@ let minus_overflow_vec_signed = arith_op_overflow_vec integerMinus true 1
let mult_overflow_vec = arith_op_overflow_vec integerMult false 2
let mult_overflow_vec_signed = arith_op_overflow_vec integerMult true 2
-let rec arith_op_overflow_vec_bit (op : integer -> integer -> integer) sign (size : nat)
+let rec arith_op_overflow_vec_bit (op : integer -> integer -> integer) sign (size : integer)
((V _ _ is_inc as l),r_bit) =
let act_size = length l * size in
let l' = to_num sign l in
@@ -286,17 +250,16 @@ let minus_overflow_vec_bit_signed = arith_op_overflow_vec_bit integerMinus true
type shift = LL | RR | LLL
-let shift_op_vec op ((V bs start is_inc as l),r) =
- let len = List.length bs in
- let n = r in
+let shift_op_vec op ((V bs start is_inc as l),(n : integer)) =
+ let len = integerFromNat (List.length bs) in
match op with
| LL (*"<<"*) ->
- let right_vec = V (List.replicate n O) 0 true in
+ let right_vec = V (List.replicate (natFromInteger n) O) 0 true in
let left_vec = slice l n (if is_inc then len + start else start - len) in
vector_concat left_vec right_vec
| RR (*">>"*) ->
let right_vec = slice l start n in
- let left_vec = V (List.replicate n O) 0 true in
+ let left_vec = V (List.replicate (natFromInteger n) O) 0 true in
vector_concat left_vec right_vec
| LLL (*"<<<"*) ->
let left_vec = slice l n (if is_inc then len + start else start - len) in
@@ -326,7 +289,7 @@ let rec arith_op_vec_no0 (op : integer -> integer -> integer) sign size (((V _ s
end in
if representable
then to_vec is_inc (act_size,n')
- else V (List.replicate act_size U) start is_inc
+ else V (List.replicate (natFromInteger act_size) U) start is_inc
let mod_vec = arith_op_vec_no0 integerMod false 1
let quot_vec = arith_op_vec_no0 integerDiv false 1
@@ -350,8 +313,8 @@ let arith_op_overflow_no0_vec op sign size (((V _ start is_inc) as l),r) =
if representable then
(to_vec is_inc (act_size,n'),to_vec is_inc (act_size + 1,n_u'))
else
- (V (List.replicate act_size U) start is_inc,
- V (List.replicate (act_size + 1) U) start is_inc) in
+ (V (List.replicate (natFromInteger act_size) U) start is_inc,
+ V (List.replicate (natFromInteger (act_size + 1)) U) start is_inc) in
let overflow = if representable then O else I in
(correct_size_num,overflow,most_significant one_more)
@@ -364,7 +327,7 @@ let arith_op_vec_range_no0 op sign size ((V _ _ is_inc as l),r) =
let mod_vec_range = arith_op_vec_range_no0 integerMod false 1
let duplicate (bit,length) =
- V (List.replicate length bit) 0 true
+ V (List.replicate (natFromInteger length) bit) 0 true
let compare_op op (l,r) = bool_to_bit (op l r)
@@ -415,3 +378,11 @@ let eq_vec_vec (l,r) = eq (to_num true l, to_num true r)
let neq (l,r) = bitwise_not_bit (eq (l,r))
let neq_vec (l,r) = bitwise_not_bit (eq_vec_vec (l,r))
+let neq_vec_range (l,r) = bitwise_not_bit (eq_vec_range (l,r))
+let neq_range_vec (l,r) = bitwise_not_bit (eq_range_vec (l,r))
+
+
+let EXTS (v1,(V _ _ is_inc as v)) =
+ to_vec is_inc (v1,signed v)
+
+let EXTZ = EXTS
diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem
index dee300ef..ac65a347 100644
--- a/src/gen_lib/state.lem
+++ b/src/gen_lib/state.lem
@@ -2,19 +2,60 @@ open import Pervasives
open import Vector
open import Arch
-type M 's 'a = 's -> ('a * 's)
+(* 'a is result type, 'e is error type *)
+type M 's 'e 'a = 's -> (either 'a 'e * 's)
-val return : forall 's 'a. 'a -> M 's 'a
-let return a = fun s -> (a,s)
+val return : forall 's 'e 'a. 'a -> M 's 'e 'a
+let return a s = (Left a,s)
-val bind : forall 's 'a 'b. M 's 'a -> ('a -> M 's 'b) -> M 's 'b
-let bind m f s = let (a,s') = m s in f a s'
+val bind : forall 's 'e 'a 'b. M 's 'e 'a -> ('a -> M 's 'e 'b) -> M 's 'e 'b
+let bind m f s = match m s with
+ | (Left a,s') -> f a s'
+ | (Right error,s') -> (Right error,s')
+ end
+
+val exit : forall 's 'e 'a. 'e -> M 's 'e 'a
+let exit e s = (Right e,s)
let (>>=) = bind
let (>>) m n = m >>= fun _ -> n
-val foreach_inc : forall 's 'vars. (nat * nat * nat) -> 'vars ->
- (nat -> 'vars -> (unit * 'vars)) -> (unit * 'vars)
+val read_reg_range : forall 'e. register -> (integer * integer) (*(nat * nat)*) -> M state 'e (vector bit)
+let read_reg_range reg (i,j) s =
+ let v = slice (read_regstate s reg) i j in
+ (Left v,s)
+
+val read_reg_bit : forall 'e. register -> integer (*nat*) -> M state 'e bit
+let read_reg_bit reg i s =
+ let v = access (read_regstate s reg) i in
+ (Left v,s)
+
+val write_reg_range : forall 'e. register -> (integer * integer) (*(nat * nat)*) -> vector bit -> M state 'e unit
+let write_reg_range (reg : register) (i,j) (v : vector bit) s =
+ let v' = update (read_regstate s reg) i j v in
+ let s' = write_regstate s reg v' in
+ (Left (),s')
+
+val write_reg_bit : forall 'e. register -> integer (*nat*) -> bit -> M state 'e unit
+let write_reg_bit reg i bit s =
+ let v = read_regstate s reg in
+ let v' = update_pos v i bit in
+ let s' = write_regstate s reg v' in
+ (Left (),s')
+
+val read_reg : forall 'e. register -> M state 'e (vector bit)
+let read_reg reg s =
+ let v = read_regstate s reg in
+ (Left v,s)
+
+val write_reg : forall 'e. register -> vector bit -> M state 'e unit
+let write_reg reg v s =
+ let s' = write_regstate s reg v in
+ (Left (),s')
+
+
+val foreach_inc : forall 's 'e 'vars. (integer * integer * integer) (*(nat * nat * nat)*) -> 'vars ->
+ (integer (*nat*) -> 'vars -> (unit * 'vars)) -> (unit * 'vars)
let rec foreach_inc (i,stop,by) vars body =
if i <= stop
then
@@ -23,8 +64,8 @@ let rec foreach_inc (i,stop,by) vars body =
else ((),vars)
-val foreach_dec : forall 's 'vars. (nat * nat * nat) -> 'vars ->
- (nat -> 'vars -> (unit * 'vars)) -> (unit * 'vars)
+val foreach_dec : forall 's 'e 'vars. (integer * integer * integer) (*(nat * nat * nat)*) -> 'vars ->
+ (integer (*nat*) -> 'vars -> (unit * 'vars)) -> (unit * 'vars)
let rec foreach_dec (i,stop,by) vars body =
if i >= stop
then
@@ -33,8 +74,8 @@ let rec foreach_dec (i,stop,by) vars body =
else ((),vars)
-val foreachM_inc : forall 's 'vars. (nat * nat * nat) -> 'vars ->
- (nat -> 'vars -> M 's (unit * 'vars)) -> M 's (unit * 'vars)
+val foreachM_inc : forall 's 'e 'vars. (nat * nat * nat) -> 'vars ->
+ (nat -> 'vars -> M 's 'e (unit * 'vars)) -> M 's 'e (unit * 'vars)
let rec foreachM_inc (i,stop,by) vars body =
if i <= stop
then
@@ -43,8 +84,8 @@ let rec foreachM_inc (i,stop,by) vars body =
else return ((),vars)
-val foreachM_dec : forall 's 'vars. (nat * nat * nat) -> 'vars ->
- (nat -> 'vars -> M 's (unit * 'vars)) -> M 's (unit * 'vars)
+val foreachM_dec : forall 's 'e 'vars. (nat * nat * nat) -> 'vars ->
+ (nat -> 'vars -> M 's 'e (unit * 'vars)) -> M 's 'e (unit * 'vars)
let rec foreachM_dec (i,stop,by) vars body =
if i >= stop
then
@@ -52,72 +93,28 @@ let rec foreachM_dec (i,stop,by) vars body =
foreachM_dec (i - by,stop,by) vars body
else return ((),vars)
-
-
-let slice (V bs start is_inc) n m =
- let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in
- let (_,suffix) = List.splitAt offset bs in
- let (subvector,_) = List.splitAt length suffix in
- V subvector n is_inc
-
-let update (V bs start is_inc) n m (V bs' _ _) =
- let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in
- let (prefix,_) = List.splitAt offset bs in
- let (_,suffix) = List.splitAt (offset + length) bs in
- V (prefix ++ (List.take length bs') ++ suffix) start is_inc
-
-let hd (x :: _) = x
-
-val access : forall 'a. vector 'a -> nat -> 'a
-let access (V bs start is_inc) n =
- if is_inc then nth bs (n - start) else nth bs (start - n)
-
-val update_pos : forall 'a. vector 'a -> nat -> 'a -> vector 'a
-let update_pos v n b =
- update v n n (V [b] 0 defaultDir)
-
-val read_reg_range : register -> (nat * nat) -> M state (vector bit)
-let read_reg_range reg (i,j) s =
- let v = slice (read_regstate s reg) i j in
- (v,s)
-
-val read_reg_bit : register -> nat -> M state bit
-let read_reg_bit reg i s =
- let v = access (read_regstate s reg) i in
- (v,s)
-
-val write_reg_range : register -> (nat * nat) -> vector bit -> M state unit
-let write_reg_range (reg : register) (i,j) (v : vector bit) s =
- let v' = update (read_regstate s reg) i j v in
- let s' = write_regstate s reg v' in
- ((),s')
-
-val write_reg_bit : register -> nat -> bit -> M state unit
-let write_reg_bit reg i bit s =
- let v = read_regstate s reg in
- let v' = update_pos v i bit in
- let s' = write_regstate s reg v' in
- ((),s')
-
-
-val read_reg : register -> M state (vector bit)
-let read_reg reg s =
- let v = read_regstate s reg in
- (v,s)
-
-val write_reg : register -> vector bit -> M state unit
-let write_reg reg v s =
- let s' = write_regstate s reg v in
- ((),s')
-
-val read_reg_field : register -> register_field -> M state (vector bit)
+val read_reg_field : forall 'e. register -> register_field -> M state 'e (vector bit)
let read_reg_field reg rfield = read_reg_range reg (field_indices rfield)
-val write_reg_field : register -> register_field -> vector bit -> M state unit
+val write_reg_field : forall 'e. register -> register_field -> vector bit -> M state 'e unit
let write_reg_field reg rfield = write_reg_range reg (field_indices rfield)
-val read_reg_field_bit : register -> register_field_bit -> M state bit
+val read_reg_field_bit : forall 'e. register -> register_field_bit -> M state 'e bit
let read_reg_field_bit reg rbit = read_reg_bit reg (field_index_bit rbit)
-val write_reg_field_bit : register -> register_field_bit -> bit -> M state unit
+val write_reg_field_bit : forall 'e. register -> register_field_bit -> bit -> M state 'e unit
let write_reg_field_bit reg rbit = write_reg_bit reg (field_index_bit rbit)
+
+
+let length l = integerFromNat (length l)
+
+let write_two_regs r1 r2 vec =
+ let size = length_reg r1 in
+ let start = get_start vec in
+ let vsize = length vec in
+ let r1_v = slice vec start ((if defaultDir then size - start else start - size) - 1) in
+ let r2_v =
+ (slice vec)
+ (if defaultDir then size - start else start - size)
+ (if defaultDir then vsize - start else start - vsize) in
+ write_reg r1 r1_v >> write_reg r2 r2_v
diff --git a/src/gen_lib/vector.lem b/src/gen_lib/vector.lem
index f409ceb7..5e78e010 100644
--- a/src/gen_lib/vector.lem
+++ b/src/gen_lib/vector.lem
@@ -1,9 +1,72 @@
open import Pervasives
type bit = O | I | U
-type vector 'a = V of list 'a * nat * bool
+type vector 'a = V of list 'a * integer * bool
-let rec nth xs (n : nat) = match (n,xs) with
+let rec nth xs (n : integer) = match (n,xs) with
| (0,x :: xs) -> x
| (n + 1,x :: xs) -> nth xs n
end
+
+
+let to_bool = function
+ | O -> false
+ | I -> true
+ end
+
+let get_start (V _ s _) = s
+let length (V bs _ _) = length bs
+
+let rec replace bs ((n : nat),b') = match (n,bs) with
+ | (_, []) -> []
+ | (0, _::bs) -> b' :: bs
+ | (n+1, b::bs) -> b :: replace bs (n,b')
+ end
+
+let make_indexed_vector_reg entries default start length =
+ let (Just v) = default in
+ V (List.foldl replace (replicate length v) entries) start
+
+let make_indexed_vector_bit entries default start length =
+ let default = match default with Nothing -> U | Just v -> v end in
+ V (List.foldl replace (replicate length default) entries) start
+
+let make_bitvector_undef length =
+ V (replicate length U) 0 true
+
+let vector_concat (V bs start is_inc) (V bs' _ _) =
+ V (bs ++ bs') start is_inc
+
+let (^^) = vector_concat
+
+val slice : vector bit -> integer -> integer -> vector bit
+let slice (V bs start is_inc) n m =
+ let n = natFromInteger n in
+ let m = natFromInteger m in
+ let start = natFromInteger start in
+ let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in
+ let (_,suffix) = List.splitAt offset bs in
+ let (subvector,_) = List.splitAt length suffix in
+ let n = integerFromNat n in
+ V subvector n is_inc
+
+let update (V bs start is_inc) n m (V bs' _ _) =
+ let n = natFromInteger n in
+ let m = natFromInteger m in
+ let start = natFromInteger start in
+ let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in
+ let (prefix,_) = List.splitAt offset bs in
+ let (_,suffix) = List.splitAt (offset + length) bs in
+ let start = integerFromNat start in
+ V (prefix ++ (List.take length bs') ++ suffix) start is_inc
+
+let hd (x :: _) = x
+
+
+val access : forall 'a. vector 'a -> (*nat*) integer -> 'a
+let access (V bs start is_inc) n =
+ if is_inc then nth bs (n - start) else nth bs (start - n)
+
+val update_pos : forall 'a. vector 'a -> (*nat*) integer -> 'a -> vector 'a
+let update_pos v n b =
+ update v n n (V [b] 0 true)
diff --git a/src/pretty_print.ml b/src/pretty_print.ml
index 457effbb..02fc62a5 100644
--- a/src/pretty_print.ml
+++ b/src/pretty_print.ml
@@ -1185,8 +1185,8 @@ let doc_def def = group (match def with
let doc_defs (Defs(defs)) =
separate_map hardline doc_def defs
-let print ?(len=80) channel doc = ToChannel.pretty 1. len channel doc
-let to_buf ?(len=80) buf doc = ToBuffer.pretty 1. len buf doc
+let print ?(len=100) channel doc = ToChannel.pretty 1. len channel doc
+let to_buf ?(len=100) buf doc = ToBuffer.pretty 1. len buf doc
let pp_defs f d = print f (doc_defs d)
let pp_exp b e = to_buf b (doc_exp e)
@@ -1826,14 +1826,15 @@ let doc_id_lem_type (Id_aux(i,_)) =
* token in case of x ending with star. *)
parens (separate space [colon; string x; empty])
-let doc_id_lem_ctor (Id_aux(i,_)) =
+let doc_id_lem_ctor aexp_needed (Id_aux(i,_)) =
match i with
| Id("bit") -> string "bit"
| Id i -> string (String.capitalize i)
| DeIid x ->
(* add an extra space through empty to avoid a closing-comment
* token in case of x ending with star. *)
- parens (separate space [colon; string (String.capitalize x); empty])
+ let epp = separate space [colon; string (String.capitalize x); empty] in
+ if aexp_needed then parens epp else epp
let doc_typ_lem, doc_atomic_typ_lem =
(* following the structure of parser for precedence *)
@@ -1850,8 +1851,8 @@ let doc_typ_lem, doc_atomic_typ_lem =
Typ_arg_aux(Typ_arg_nexp n, _);
Typ_arg_aux(Typ_arg_nexp m, _);
Typ_arg_aux (Typ_arg_order ord, _);
- Typ_arg_aux (Typ_arg_typ typ, _)]) ->
- string "vector"
+ Typ_arg_aux (Typ_arg_typ typa, _)]) ->
+ string "vector" ^^ space ^^ parens (typ typa)
| Typ_app(Id_aux (Id "range", _), [
Typ_arg_aux(Typ_arg_nexp n, _);
Typ_arg_aux(Typ_arg_nexp m, _);]) ->
@@ -1883,7 +1884,10 @@ let doc_lit_lem in_pat (L_aux(l,_)) =
| L_one -> "I"
| L_false -> "O"
| L_true -> "I"
- | L_num i -> if i < 0 then "(0 " ^ string_of_int i ^ ")" else string_of_int i
+ | L_num i ->
+ let ipp = string_of_int i in
+ (if i < 0 then "((0"^ipp^") : integer)"
+ else "("^ipp^" : integer)")
| L_hex n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0x" ^ n) ^ ")" (*shouldn't happen*)*)
| L_bin n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0b" ^ n) ^ ")" (*shouldn't happen*)*)
| L_undef -> "U"
@@ -1904,7 +1908,7 @@ let doc_pat_lem apat_needed =
| P_app(id, ((_ :: _) as pats)) ->
(match annot with
| Base(_,Constructor _,_,_,_,_) ->
- doc_unop (doc_id_lem_ctor id) (separate_map space pat pats)
+ doc_unop (doc_id_lem_ctor true id) (separate_map space pat pats)
| _ -> empty)
| P_lit lit -> doc_lit_lem true lit
| P_wild -> underscore
@@ -1913,7 +1917,7 @@ let doc_pat_lem apat_needed =
| P_typ(typ,p) -> doc_op colon (pat p) (doc_typ_lem typ)
| P_app(id,[]) ->
(match annot with
- | Base(_,Constructor n,_,_,_,_) -> doc_id_lem_ctor id
+ | Base(_,Constructor n,_,_,_,_) -> doc_id_lem_ctor apat_needed id
| _ -> empty)
| P_vector pats ->
let non_bit_print () =
@@ -1938,33 +1942,56 @@ let doc_pat_lem apat_needed =
| P_list pats -> brackets (separate_map semi pat pats) (*Never seen but easy in lem*)
in pat
+let rec getregtyp (LEXP_aux (le,(_,annot))) = match le with
+ | LEXP_id _
+ | LEXP_cast _ ->
+ let (Base ((_,t),_,_,_,_,_)) = annot in
+ (match t with
+ | {t = Tabbrev ({t = Tid name},_)} -> name)
+ | LEXP_memory _ -> failwith "This lexp writes memory"
+ | LEXP_vector (le,_)
+ | LEXP_vector_range (le,_,_)
+ | LEXP_field (le,_) ->
+ getregtyp le
+
let doc_exp_lem, doc_let_lem =
let rec top_exp (aexp_needed : bool) (E_aux (e, (_,annot))) =
let exp = top_exp true in
match e with
| E_assign((LEXP_aux(le_act,tannot) as le),e) ->
(* can only be register writes *)
- let (_,(Base ((_,_),tag,_,_,_,_))) = tannot in
- (separate space)
+ let (_,(Base ((_,{t = t}),tag,_,_,_,_))) = tannot in
+ let E_aux (_,(_,(Base ((_,{t = et}),_,_,_,_,_)))) = e in
+ let (f,args) =
(match tag with
| External _ ->
(match le_act with
| LEXP_vector (le,e2) ->
- [string "write_reg_bit";doc_lexp_deref_lem le;exp e2;exp e]
+ (string "write_reg_bit",[doc_lexp_deref_lem le;exp e2;exp e])
| LEXP_vector_range (le,e2,e3) ->
- [string "write_reg_range";doc_lexp_deref_lem le;exp e2;exp e3;exp e]
+ (string "write_reg_range",[doc_lexp_deref_lem le;
+ parens (exp e2 ^^ comma ^^ exp e3);
+ exp e])
| LEXP_field (lexp,id) ->
- let (Base ((_,{t = t}),_,_,_,_,_)) = annot in
- (match t with
- | Tid "bit" ->
- [string "write_reg_field_bit";doc_lexp_deref_lem le;doc_id_lem id;exp e]
- | _ ->
- [string "write_reg_field";doc_lexp_deref_lem le;doc_id_lem id;exp e])
- | (LEXP_id _ | LEXP_cast _) -> [string "write_reg";doc_lexp_deref_lem le;exp e])
- | _ -> [string "write_reg";doc_lexp_deref_lem le;exp e]
- )
+ let typprefix = String.uncapitalize (getregtyp lexp) ^ "_" in
+ (match et with
+ | Tid "bit"
+ | Tabbrev (_,{t=Tid "bit"}) ->
+ (string "write_reg_field_bit"),
+ [doc_lexp_deref_lem lexp;string typprefix ^^ doc_id_lem id;exp e]
+ | Tapp ("vector",_)
+ | Tabbrev (_,{t=Tapp ("vector",_)}) ->
+ (string "write_reg_field",
+ [doc_lexp_deref_lem lexp;string typprefix ^^ doc_id_lem id;exp e])
+ | _ -> failwith (t_to_string {t = et})
+ )
+ | (LEXP_id _ | LEXP_cast _) -> (string "write_reg",[doc_lexp_deref_lem le;exp e]))
+ | _ -> (string "write_reg",[doc_lexp_deref_lem le;exp e])
+ ) in
+ prefix 2 1 f (separate (break 1) args)
| E_vector_append(l,r) ->
- let epp = (separate space [exp l;string "^^";exp r]) in
+ let epp =
+ separate space [exp l;string "^^"] ^//^ exp r in
if aexp_needed then parens epp else epp
| E_cons(l,r) -> doc_op (group (colon^^colon)) (exp l) (exp r)
| E_if(c,t,e) ->
@@ -1975,9 +2002,9 @@ let doc_exp_lem, doc_let_lem =
| Base ((_,({t = Tid "bit"})),_,_,_,_,_) ->
separate space [string "if";string "to_bool";exp c]
| _ -> separate space [string "if";exp c])
- ^/^
- (prefix 2 1 (string "then") (exp t)) ^^ (break 1) ^^
- (prefix 2 1 (string "else") (exp e))) in
+ ^^ break 1 ^^
+ (prefix 2 1 (string "then") (top_exp false t)) ^^ (break 1) ^^
+ (prefix 2 1 (string "else") (top_exp false e))) in
if aexp_needed then parens epp else epp
| E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) ->
failwith "E_for should have been removed till now"
@@ -2002,19 +2029,28 @@ let doc_exp_lem, doc_let_lem =
)
)
| _ ->
- let call = match annot with
- | Base(_,External (Some n),_,_,_,_) ->
- (match n with
- | "bitwise_not_bit" -> string "~"
- | _ -> string n)
- | Base(_,Constructor _,_,_,_,_) -> doc_id_lem_ctor f
- | _ -> doc_id_lem f in
- let epp =
- (doc_unop call)
- (match args with
- | [a] -> exp a
- | args -> parens (separate_map comma exp args)) in
- if aexp_needed then parens epp else epp
+ (match annot with
+ | Base (_,Constructor _,_,_,_,_) ->
+ let epp = separate space [doc_id_lem f;separate_map space (top_exp true) args] in
+ if aexp_needed then parens epp else epp
+ | Base (_,External (Some "bitwise_not_bit"),_,_,_,_) ->
+ let [a] = args in
+ let epp = string "~" ^^ exp a in
+ if aexp_needed then parens epp else epp
+ | _ ->
+ let call = match annot with
+ | Base(_,External (Some n),_,_,_,_) ->
+ (match n with
+ | _ -> string n)
+ | Base(_,Constructor _,_,_,_,_) -> doc_id_lem_ctor false f
+ | _ -> doc_id_lem f in
+ let epp =
+ (doc_unop call)
+ (match args with
+ | [a] -> exp a
+ | args -> parens (separate_map comma (top_exp false) args)) in
+ if aexp_needed then parens epp else epp
+ )
)
| E_vector_access(v,e) ->
let epp = separate space [string "access";exp v;exp e] in
@@ -2023,16 +2059,16 @@ let doc_exp_lem, doc_let_lem =
let epp = (string "slice") ^^ space ^^ (exp v) ^^ space ^^ (exp e1) ^^ space ^^ (exp e2) in
if aexp_needed then parens epp else epp
| E_field((E_aux(_,(_,fannot)) as fexp),id) ->
- (match fannot with
- | Base((_,{t= Tapp("register",_)}),_,_,_,_,_) |
- Base((_,{t= Tabbrev(_,{t=Tapp("register",_)})}),_,_,_,_,_)->
- let field_f = match annot with
- | Base((_,{t = Tid "bit"}),_,_,_,_,_) |
- Base((_,{t = Tabbrev(_,{t=Tid "bit"})}),_,_,_,_,_) ->
+ let (Base ((_,{t = t}),_,_,_,_,_)) = fannot in
+ (match t with
+ | Tabbrev({t = Tid regtyp},{t=Tapp("register",_)}) ->
+ let field_f = match annot with
+ | Base((_,{t = Tid "bit"}),_,_,_,_,_)
+ | Base((_,{t = Tabbrev(_,{t=Tid "bit"})}),_,_,_,_,_) ->
string "read_reg_field_bit"
- | _ -> string "read_reg_field" in
-
- let epp = field_f ^^ space ^^ (exp fexp) ^^ space ^^ string_lit (doc_id_lem id) in
+ | _ -> string "read_reg_field" in
+ let epp = field_f ^^ space ^^ (exp fexp) ^^ space ^^
+ string (regtyp ^ "_") ^^ doc_id_lem id in
if aexp_needed then parens epp else epp
| _ -> exp fexp ^^ dot ^^ doc_id_lem id)
| E_block [] -> string "()"
@@ -2046,7 +2082,7 @@ let doc_exp_lem, doc_let_lem =
(match tag with
| External _ -> separate space [string "read_reg";doc_id_lem id]
| _ -> doc_id_lem id)
- | Base(_,(Constructor i |Enum i),_,_,_,_) -> doc_id_lem_ctor id
+ | Base(_,(Constructor i |Enum i),_,_,_,_) -> doc_id_lem_ctor aexp_needed id
| Base((_,t),Alias alias_info,_,_,_,_) ->
(match alias_info with
| Alias_field(reg,field) ->
@@ -2108,6 +2144,7 @@ let doc_exp_lem, doc_let_lem =
| Nconst i -> string_of_big_int i
| N2n(_,Some i) -> string_of_big_int i
| _ -> if dir then "0" else string_of_int (List.length exps) in
+
let epp = group (separate space [string "V"; brackets (separate_map (semi) exp exps);
string start;string dir_out]) in
if aexp_needed then parens epp else epp
@@ -2149,7 +2186,7 @@ let doc_exp_lem, doc_let_lem =
| _ ->
raise (Reporting_basic.err_unreachable dl "nono") in
parens (string "Just " ^^ parens (string ("UndefinedReg " ^ string_of_big_int n)))) in
- let iexp (i,e) = parens (separate_map comma_sp (fun x -> x) [(doc_int i); (exp e)]) in
+ let iexp (i,e) = parens (separate_map comma (fun x -> x) [(doc_int i); (exp e)]) in
let epp =
(separate space)
[call;(brackets (separate_map semi iexp iexps));
@@ -2178,12 +2215,12 @@ let doc_exp_lem, doc_let_lem =
(match annot with
| Base((_,t),External(Some name),_,_,_,_) ->
let epp = match name with
- | "bitwise_and_bit" -> separate space [exp e1;string "&.";exp e2]
- | "bitwise_or_bit" -> separate space [exp e1;string "|.";exp e2]
- | "bitwise_xor_bit" -> separate space [exp e1;string "+.";exp e2]
- | "add" -> separate space [exp e1;string "+";exp e2]
- | "minus" -> separate space [exp e1;string "-";exp e2]
- | "multiply" -> separate space [exp e1;string "*";exp e2]
+ | "bitwise_and_bit" -> separate space [exp e1;string "&."] ^//^ exp e2
+ | "bitwise_or_bit" -> separate space [exp e1;string "|."] ^//^ exp e2
+ | "bitwise_xor_bit" -> separate space [exp e1;string "+."] ^//^ exp e2
+ | "add" -> separate space [exp e1;string "+";exp e2]
+ | "minus" -> separate space [exp e1;string "-";exp e2]
+ | "multiply" -> separate space [exp e1;string "*";exp e2]
(* | "lt" -> separate space [exp e1;string "<";exp e2]
| "gt" -> separate space [exp e1;string ">";exp e2]
| "lteq" -> separate space [exp e1;string "<=";exp e2]
@@ -2192,29 +2229,32 @@ let doc_exp_lem, doc_let_lem =
| "gt_vec" -> separate space [exp e1;string ">";exp e2]
| "lteq_vec" -> separate space [exp e1;string "<=";exp e2]
| "gteq_vec" -> separate space [exp e1;string ">=";exp e2] *)
- | _ -> separate space [string name; parens (separate_map comma exp [e1;e2])] in
+ | _ -> separate space [string name; parens (separate_map comma (top_exp false) [e1;e2])] in
if aexp_needed then parens epp else epp
| _ ->
- let epp = separate space [doc_id_lem id; parens (separate_map comma exp [e1;e2])] in
+ let epp = separate space [doc_id_lem id; parens (separate_map comma (top_exp false) [e1;e2])] in
if aexp_needed then parens epp else epp)
| E_internal_let(lexp, eq_exp, in_exp) ->
- failwith "E_internal_lets should have been removed till now"
-(* (separate
+ (* failwith "E_internal_lets should have been removed till now" *)
+ (separate
space
- [string "let TAKTAK";
- doc_lexp_lem true lexp; (*Rewriter/typecheck should ensure this is only cast or id*)
+ [string "let internal";
+ (match lexp with (LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),_)) -> doc_id_lem id);
coloneq;
exp eq_exp;
string "in"]) ^/^
- exp in_exp *)
+ exp in_exp
| E_internal_plet (pat,e1,e2) ->
- (match pat with
- | P_aux (P_wild,_) ->
- (separate space [exp e1; string ">>"]) ^/^
- top_exp false e2
- | _ ->
- (separate space [exp e1; string ">>= fun"; doc_pat_lem true pat;arrow]) ^/^
- top_exp false e2)
+ let epp =
+ let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in
+ match pat with
+ | P_aux (P_wild,_) ->
+ (separate space [top_exp b e1; string ">>"]) ^/^
+ top_exp false e2
+ | _ ->
+ (separate space [top_exp b e1; string ">>= fun"; doc_pat_lem true pat;arrow]) ^/^
+ top_exp false e2 in
+ if aexp_needed then parens epp else epp
| E_internal_return (e1) ->
separate space [string "return"; exp e1;]
and let_exp (LB_aux(lb,_)) = match lb with
@@ -2242,8 +2282,8 @@ let doc_exp_lem, doc_let_lem =
(*TODO Upcase and downcase type and constructors as needed*)
let doc_type_union_lem (Tu_aux(typ_u,_)) = match typ_u with
- | Tu_ty_id(typ,id) -> separate space [pipe; doc_id_lem_ctor id; string "of"; doc_typ_lem typ;]
- | Tu_id id -> separate space [pipe; doc_id_lem_ctor id]
+ | Tu_ty_id(typ,id) -> separate space [pipe; doc_id_lem_ctor false id; string "of"; doc_typ_lem typ;]
+ | Tu_id id -> separate space [pipe; doc_id_lem_ctor false id]
let rec doc_range_lem (BF_aux(r,_)) = match r with
| BF_single i -> parens (doc_op comma (doc_int i) (doc_int i))
@@ -2265,7 +2305,7 @@ let doc_typdef_lem (TD_aux(td,_)) = match td with
(concat [string "type"; space; doc_id_lem_type id;])
(doc_typquant_lem typq ar_doc)
| TD_enum(id,nm,enums,_) ->
- let enums_doc = group (separate_map (break 1 ^^ pipe ^^ space) doc_id_lem_ctor enums) in
+ let enums_doc = group (separate_map (break 1 ^^ pipe ^^ space) (doc_id_lem_ctor false) enums) in
doc_op equals
(concat [string "type"; space; doc_id_lem_type id;])
(enums_doc)
@@ -2279,7 +2319,7 @@ let doc_tannot_opt_lem (Typ_annot_opt_aux(t,_)) = match t with
| Typ_annot_opt_some(tq,typ) -> doc_typquant_lem tq (doc_typ_lem typ)
let doc_funcl_lem (FCL_aux(FCL_Funcl(id,pat,exp),_)) =
- group (doc_op arrow (doc_pat_lem false pat) (doc_exp_lem false exp))
+ group (prefix 3 1 ((doc_pat_lem false pat) ^^ space ^^ arrow) (doc_exp_lem false exp))
let get_id = function
| [] -> failwith "FD_function with empty list"
@@ -2352,16 +2392,22 @@ let reg_decls (Defs defs) =
| _ -> (regtypes,rsranges,rsbits,defs @ [def])
) ([],[],[],[]) defs in
- let (regs,defs) =
+ let (regs,regaliases,defs) =
List.fold_left
- (fun (regs,defs) def ->
+ (fun (regs,regaliases,defs) def ->
match def with
| DEF_reg_dec (DEC_aux (DEC_reg(Typ_aux (Typ_id (Id_aux (Id typ,_)),_),Id_aux (Id name,_)),_)) ->
- (regs @ [(name,Some typ)],defs)
+ (regs @ [(name,Some typ)],regaliases,defs)
| DEF_reg_dec (DEC_aux (DEC_reg(_, Id_aux (Id name,_)),_)) ->
- (regs @ [(name,None)],defs)
- | def -> (regs,defs @ [def])
- ) ([],[]) defs in
+ (regs @ [(name,None)],regaliases,defs)
+ | DEF_reg_dec
+ (DEC_aux (DEC_alias
+ (Id_aux (Id name1,_),
+ AL_aux (AL_concat (RI_aux (RI_id (Id_aux (Id name2,_)),_),
+ RI_aux (RI_id (Id_aux (Id name3,_)),_)),_)),_)) ->
+ (regs,regaliases @ [(name1,(name2,name3))],defs)
+ | def -> (regs,regaliases,defs @ [def])
+ ) ([],[],[]) defs in
(* maybe we need a function that analyses the spec for this as well *)
let default =
@@ -2373,7 +2419,8 @@ let reg_decls (Defs defs) =
(prefix 2 1)
(separate space [string "type";string "register";equals])
((separate_map space (fun (reg,_) -> pipe ^^ space ^^ string reg) regs)
- ^^ space ^^ pipe ^^ space ^^ string "UndefinedReg of nat") in
+ ^^ space ^^ pipe ^^ space ^^ string "UndefinedReg of integer" ^^
+ pipe ^^ space ^^ string "RegisterPair of register * register") in
let regfields_pp =
(prefix 2 1)
@@ -2387,6 +2434,12 @@ let reg_decls (Defs defs) =
(separate_map space (fun (fname,tname,_) ->
pipe ^^ space ^^ string (tname ^ "_" ^ fname)) rsbits) in
+ let regalias_pp =
+ (separate_map (break 1))
+ (fun (name1,(name2,name3)) ->
+ separate space [string "let";string name1;equals;string "RegisterPair";string name2;string name3])
+ regaliases in
+
let state_pp =
(prefix 2 1)
(separate space [string "type";string "state";equals])
@@ -2397,10 +2450,10 @@ let reg_decls (Defs defs) =
)) in
let length_pp =
- (separate space [string "val";string "length_reg";colon;string "register";arrow;string "nat"])
+ (separate space [string "val";string "length_reg";colon;string "register";arrow;string "integer"])
^/^
(prefix 2 1)
- (separate space [string "let";string "length_reg";equals;string "function"])
+ (separate space [string "let rec";string "length_reg";string "reg";equals;string "match reg with"])
(((separate_map (break 1))
(fun (name,typ) ->
let ((n1,n2,_,_),typname) =
@@ -2412,13 +2465,15 @@ let reg_decls (Defs defs) =
regs) ^/^
separate space [pipe;string "UndefinedReg n";arrow;
string "failwith \"Trying to compute length of undefined register\""] ^/^
+ separate space [pipe;string "RegisterPair r1 r2";arrow;
+ string "length_reg r1 + length_reg r2"] ^/^
string "end") in
let field_indices_pp =
(prefix 2 1)
((separate space) [string "let";string "field_indices";
- colon;string "register_field";arrow;string "(nat * nat)";
+ colon;string "register_field";arrow;string "(integer * integer)";
equals;string "function"])
(
((separate_map (break 1))
@@ -2433,7 +2488,7 @@ let reg_decls (Defs defs) =
let field_index_bit_pp =
(prefix 2 1)
((separate space) [string "let";string "field_index_bit";
- colon;string "register_field_bit";arrow;string "nat";
+ colon;string "register_field_bit";arrow;string "integer";
equals;string "function"])
(
((separate_map (break 1))
@@ -2447,7 +2502,7 @@ let field_index_bit_pp =
let read_regstate_pp =
(prefix 2 1)
- (separate space [string "let";string "read_regstate";string "s";equals;string "function"])
+ (separate space [string "let rec";string "read_regstate";string "s";equals;string "function"])
(
((separate_map (break 1))
(fun (name,_) ->
@@ -2455,11 +2510,13 @@ let field_index_bit_pp =
regs) ^/^
separate space [pipe;string "UndefinedReg n";arrow;
string "failwith \"Trying to read from undefined register\""] ^/^
+ separate space [pipe;string "RegisterPair r1 r2";arrow;
+ string "read_regstate s r1 ^^ read_regstate s r2"] ^/^
string "end" ^^ hardline ) in
let write_regstate_pp =
(prefix 2 1)
- (separate space [string "let";string "write_regstate";string "s";string "reg";string "v";
+ (separate space [string "let rec";string "write_regstate";string "s";string "reg";string "v";
equals;string "match reg with"])
(
((separate_map (break 1))
@@ -2474,11 +2531,26 @@ let field_index_bit_pp =
) regs) ^/^
separate space [pipe;string "UndefinedReg n";arrow;
string "failwith \"Trying to write to undefined register\""] ^/^
+ ((prefix 3 1)
+ (separate space [pipe;string "RegisterPair r1 r2";arrow])
+ ((separate (break 1))
+ [
+ string "let size = length_reg r1 in";
+ string "let start = get_start v in";
+ string "let vsize = length v in";
+ string "let vsize = integerFromNat vsize in";
+ string ("let r1_v = slice v start " ^
+ (if is_inc then "(size - start - 1) in" else "(start - size) - 1) in"));
+ string ("let r2_v = slice v " ^
+ (if is_inc then "(size - start)" else "(start - size)") ^
+ (if is_inc then "(vsize - start) in" else ("start - vsize) in")));
+ string "write_regstate (write_regstate s r1 r1_v) r2 r2_v"
+ ])) ^/^
string "end" ^^ hardline ) in
(separate (hardline ^^ hardline)
[dir_pp;regs_pp;regfields_pp;regfieldsbit_pp;field_index_bit_pp;field_indices_pp;
- state_pp;length_pp;read_regstate_pp;write_regstate_pp],defs)
+ regalias_pp;state_pp;length_pp;read_regstate_pp;write_regstate_pp],defs)
let doc_defs_lem (Defs defs) =
let (decls,defs) = reg_decls (Defs defs) in
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 3c4ff5e8..c7e6986b 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -25,6 +25,43 @@ let fresh_name () =
let () = fresh_name_counter := (current + 1) in
current
+let geteffs_annot (_,t) = match t with
+ | Base (_,_,_,_,effs,_) -> effs
+ | NoTyp -> failwith "no effect information"
+ | _ -> failwith "a_normalise doesn't support Overload"
+
+let gettype_annot (_,t) = match t with
+ | Base((_,t),_,_,_,_,_) -> t
+ | NoTyp -> failwith "no type information"
+ | _ -> failwith "a_normalise doesn't support Overload"
+
+let gettype (E_aux (_,a)) = gettype_annot a
+let geteffs (E_aux (_,a)) = geteffs_annot a
+
+let effectful_effs {effect = Eset effs} =
+ List.exists
+ (fun (BE_aux (be,_)) ->
+ match be with
+ | BE_nondet | BE_unspec | BE_undef | BE_lset -> false
+ | _ -> true
+ ) effs
+
+let effectful eaux =
+ effectful_effs (geteffs eaux)
+
+let updates_vars_effs {effect = Eset effs} =
+ List.exists
+ (fun (BE_aux (be,_)) ->
+ match be with
+ | BE_lset -> true
+ | _ -> false
+ ) effs
+
+let updates_vars eaux =
+ updates_vars_effs (geteffs eaux)
+
+let eff_union es =
+ List.fold_left (fun acc e -> union_effects acc (geteffs e)) pure_e es
let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option = match ls with
| [] -> None
@@ -871,8 +908,9 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful
let le' = rewriters.rewrite_lexp rewriters nmap le in
let e' = rewrite_base e in
let exps' = walker exps in
- [(E_aux (E_internal_let(le', e', E_aux(E_block exps', (l, simple_annot {t=Tid "unit"}))),
- (l, simple_annot t)))]
+ let effects = eff_union exps' in
+ [E_aux (E_internal_let(le', e', E_aux(E_block exps', (l, simple_annot_efr {t=Tid "unit"} effects))),
+ (l, simple_annot_efr t (eff_union (e::exps'))))]
| ((E_aux(E_if(c,t,e),(l,annot))) as exp)::exps ->
let vars_t = introduced_variables t in
let vars_e = introduced_variables e in
@@ -886,8 +924,9 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful
let c' = rewrite_base c in
let t' = rewriters.rewrite_exp rewriters new_nmap t in
let e' = rewriters.rewrite_exp rewriters new_nmap e in
- Envmap.fold
- (fun res i (t,e) ->
+ let exps' = walker exps in
+ fst ((Envmap.fold
+ (fun (res,effects) i (t,e) ->
let bitlit = E_aux (E_lit (L_aux(L_zero, Parse_ast.Unknown)),
(Parse_ast.Unknown, simple_annot bit_t)) in
let rangelit = E_aux (E_lit (L_aux (L_num 0, Parse_ast.Unknown)),
@@ -907,12 +946,13 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful
(Parse_ast.Unknown,simple_annot bit_t))),
(Parse_ast.Unknown, simple_annot t))
| _ -> e in
- [E_aux (E_internal_let (LEXP_aux (LEXP_id (Id_aux (Id i, Parse_ast.Unknown)),
+ let unioneffs = union_effects effects (geteffs set_exp) in
+ ([E_aux (E_internal_let (LEXP_aux (LEXP_id (Id_aux (Id i, Parse_ast.Unknown)),
(Parse_ast.Unknown, (tag_annot t Emp_intro))),
set_exp,
- E_aux (E_block res, (Parse_ast.Unknown, (simple_annot unit_t)))),
- (Parse_ast.Unknown, simple_annot unit_t))])
- (E_aux(E_if(c',t',e'), (Parse_ast.Unknown, annot))::(walker exps)) new_vars
+ E_aux (E_block res, (Parse_ast.Unknown, (simple_annot_efr unit_t effects)))),
+ (Parse_ast.Unknown, simple_annot_efr unit_t unioneffs))],unioneffs)))
+ (E_aux(E_if(c',t',e'), (Parse_ast.Unknown, annot))::exps',eff_union exps') new_vars)
| e::exps -> (rewrite_rec e)::(walker exps)
in
rewrap (E_block (walker exps))
@@ -923,7 +963,7 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful
(match le' with
| LEXP_aux(_, (_,Base(_,Emp_intro,_,_,_,_))) ->
let e' = rewrite_base e in
- rewrap (E_internal_let(le', e', E_aux(E_block [], (l, simple_annot unit_t))))
+ rewrap (E_internal_let(le', e', E_aux(E_block [], (l, simple_annot_efr unit_t (geteffs e')))))
| _ -> E_aux((E_assign(le', rewrite_base e)),(l, tag_annot unit_t Emp_set)))
| _ -> rewrite_base full_exp)
| _ -> rewrite_base full_exp
@@ -961,52 +1001,12 @@ let rewrite_defs_ocaml defs =
let defs_lifted_assign = rewrite_defs_exp_lift_assign defs_vec_concat_removed in
defs_lifted_assign
-
-
-let geteffs_annot (_,t) = match t with
- | Base (_,_,_,_,effs,_) -> effs
- | NoTyp -> failwith "no effect information"
- | _ -> failwith "a_normalise doesn't support Overload"
-
-let gettype_annot (_,t) = match t with
- | Base((_,t),_,_,_,_,_) -> t
- | NoTyp -> failwith "no type information"
- | _ -> failwith "a_normalise doesn't support Overload"
-
-let gettype (E_aux (_,a)) = gettype_annot a
-let geteffs (E_aux (_,a)) = geteffs_annot a
-
-let effectful_effs {effect = Eset effs} =
- List.exists
- (fun (BE_aux (be,_)) ->
- match be with
- | BE_nondet | BE_unspec | BE_undef | BE_lset -> false
- | _ -> true
- ) effs
-
-let effectful eaux =
- effectful_effs (geteffs eaux)
-
-let updates_vars_effs {effect = Eset effs} =
- List.exists
- (fun (BE_aux (be,_)) ->
- match be with
- | BE_lset -> true
- | _ -> false
- ) effs
-
-let updates_vars eaux =
- updates_vars_effs (geteffs eaux)
-
-let eff_union e1 e2 = union_effects (geteffs e1) (geteffs e2)
-
-let remove_blocks_exp_alg =
+let remove_blocks =
let letbind_wild v body =
- let annot_pat = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in
- let annot_lb = annot_pat in
- let annot_let =
- (Parse_ast.Unknown,simple_annot_efr {t = Tid "unit"} (eff_union v body)) in
+ let annot_pat = (Parse_ast.Unknown,simple_annot (gettype v)) in
+ let annot_lb = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in
+ let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in
E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_wild,annot_pat),v),annot_lb),body),annot_let) in
let rec f = function
@@ -1018,7 +1018,7 @@ let remove_blocks_exp_alg =
| (E_block es,annot) -> f es
| (e,annot) -> E_aux (e,annot) in
- { id_exp_alg with e_aux = e_aux }
+ fold_exp { id_exp_alg with e_aux = e_aux }
let fresh_id annot =
@@ -1036,7 +1036,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
let body = body e in
let annot_pat = (Parse_ast.Unknown,simple_annot unit_t) in
let annot_lb = annot_pat in
- let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union v body)) in
+ let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in
let pat = P_aux (P_wild,annot_pat) in
if effectful v
@@ -1049,7 +1049,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
let annot_pat = (Parse_ast.Unknown,simple_annot (gettype v)) in
let annot_lb = annot_pat in
- let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union v body)) in
+ let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in
let pat = P_aux (P_id id,annot_pat) in
if effectful v
@@ -1106,7 +1106,7 @@ let rec n_exp_name (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
and n_exp_pure (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
n_exp exp (fun exp -> if not (effectful exp || updates_vars exp) then k exp else letbind exp k)
-
+
and n_exp_nameL (exps : 'a exp list) (k : 'a exp list -> 'a exp) : 'a exp =
mapCont n_exp_name exps k
@@ -1170,7 +1170,6 @@ and n_lexp (lexp : 'a lexp) (k : 'a lexp -> 'a exp) : 'a exp =
k (LEXP_aux (LEXP_field (lexp,id),local_eff_plus effs annot)))
and n_exp_term (new_return : bool) (exp : 'a exp) : 'a exp =
- let (E_aux (_,annot)) = exp in
let exp =
if new_return
then E_aux (E_internal_return exp,(Unknown,simple_annot_efr (gettype exp) (geteffs exp)))
@@ -1187,8 +1186,8 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
let (E_aux (exp_aux,annot)) = exp in
let rewrap_effs effsum exp_aux = (* explicitly give effect sum *)
- let (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds)) = annot in
- E_aux (exp_aux, (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds))) in
+ let (l,Base (t,tag,nexps,eff,_,bounds)) = annot in
+ E_aux (exp_aux, (l,Base (t,tag,nexps,eff,effsum,bounds))) in
let rewrap_localeff exp_aux = (* give exp_aux the local effect as the effect sum *)
E_aux (exp_aux,only_local_eff annot) in
@@ -1218,7 +1217,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
let new_return = effectful exp2 || effectful exp3 in
let exp2 = n_exp_term new_return exp2 in
let exp3 = n_exp_term new_return exp3 in
- k (rewrap_effs (eff_union exp2 exp3) (E_if (exp1,exp2,exp3))))
+ k (rewrap_effs (eff_union [exp2;exp3]) (E_if (exp1,exp2,exp3))))
| E_for (id,start,stop,by,dir,body) ->
n_exp_name start (fun start ->
n_exp_name stop (fun stop ->
@@ -1292,7 +1291,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
| LB_aux (LB_val_explicit (_,pat,exp'),annot')
| LB_aux (LB_val_implicit (pat,exp'),annot') ->
if effectful exp'
- then (rewrap_effs (eff_union exp' body) (E_internal_plet (pat,exp',n_exp body k)))
+ then (rewrap_effs (eff_union [exp';body]) (E_internal_plet (pat,exp',n_exp body k)))
else (rewrap_effs (geteffs body) (E_let (lb,n_exp body k)))
))
| E_assign (lexp,exp1) ->
@@ -1309,7 +1308,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
(if effectful exp1
then n_exp_name exp1
else n_exp exp1) (fun exp1 ->
- (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,n_exp exp2 k))))
+ rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,n_exp exp2 k)))
| E_internal_return exp1 ->
n_exp_name exp1 (fun exp1 ->
k (rewrap_localeff (E_internal_return exp1)))
@@ -1318,7 +1317,8 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
let rewrite_defs_a_normalise =
let rewrite_exp _ _ e =
- n_exp_term (effectful e) (fold_exp remove_blocks_exp_alg e) in
+ let e = remove_blocks e in
+ n_exp_term (effectful e) e in
rewrite_defs_base
{rewrite_exp = rewrite_exp
; rewrite_pat = rewrite_pat
@@ -1376,7 +1376,8 @@ let find_updated_vars exp =
; e_internal_cast = (fun (_,e1) -> e1)
; e_internal_exp = (fun _ -> [])
; e_internal_exp_user = (fun _ -> [])
- ; e_internal_let = (fun ((None,[(id,_)]), e2, e3) -> List.filter (eqidtyp id) (e2 @ e3))
+ ; e_internal_let = (fun ((None,[(id,_)]), e2, e3) ->
+ List.filter (fun id2 -> not (eqidtyp id id2)) (e2 @ e3))
; e_internal_plet = (fun (_, e1, e2) -> e1 @ e2)
; e_internal_return = (fun e -> e)
; e_aux = (fun (e,_) -> e)
@@ -1523,7 +1524,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) =
(* after a-normalisation c shouldn't need rewriting *)
let t = gettype e1 in
(* let () = assert (simple_annot t = simple_annot (gettype e2)) in *)
- let v = E_aux (E_if (c,e1,e2), (Unknown,simple_annot_efr t (eff_union e1 e2))) in
+ let v = E_aux (E_if (c,e1,e2), (Unknown,simple_annot_efr t (eff_union [e1;e2]))) in
let pat =
(* if overwrite then
mktup_pat vars
@@ -1615,7 +1616,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) =
(match rewrite v pat with
| Added_vars (v,pat) ->
E_aux (E_internal_plet (pat,v,body),
- (Unknown,simple_annot_efr (gettype body) (eff_union v body)))
+ (Unknown,simple_annot_efr (gettype body) (eff_union [v;body])))
| Same_vars v -> E_aux (E_internal_plet (pat,v,body),annot))
| E_internal_let (lexp,v,body) ->
(* After a-normalisation E_internal_lets can only bind values to names, those don't
@@ -1627,7 +1628,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) =
let pat = P_aux (P_id id, (Parse_ast.Unknown,simple_annot (gettype v))) in
let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in
let lb = LB_aux (LB_val_implicit (pat,v),lbannot) in
- E_aux (E_let (lb,body),(Unknown,simple_annot_efr (gettype body) (eff_union v body)))
+ E_aux (E_let (lb,body),(Unknown,simple_annot_efr (gettype body) (eff_union [v;body])))
(* In tail-position there shouldn't be anything we need to do as the terms after
* a-normalisation are pure and don't update local variables. There can't be any variable
* assignments in tail-position (because of the effect), there could be pure pattern-match