diff options
| author | Christopher Pulte | 2015-11-19 14:37:19 +0000 |
|---|---|---|
| committer | Christopher Pulte | 2015-11-19 14:37:19 +0000 |
| commit | a1d41f415a555bbe31e86375601e75f8ecf37f54 (patch) | |
| tree | a404c7bd198763b1ffa9b3048a7419ea3ddefe4d /src | |
| parent | 3323f7a685f0aa7d125a9f348112b6e25fb392ae (diff) | |
fixes for cumulative effect anotations
Diffstat (limited to 'src')
| -rw-r--r-- | src/gen_lib/sail_values.lem | 107 | ||||
| -rw-r--r-- | src/gen_lib/state.lem | 147 | ||||
| -rw-r--r-- | src/gen_lib/vector.lem | 67 | ||||
| -rw-r--r-- | src/pretty_print.ml | 248 | ||||
| -rw-r--r-- | src/rewriter.ml | 137 |
5 files changed, 405 insertions, 301 deletions
diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem index 2104d072..4d74976b 100644 --- a/src/gen_lib/sail_values.lem +++ b/src/gen_lib/sail_values.lem @@ -3,46 +3,7 @@ open import State open import Vector open import Arch -let to_bool = function - | O -> false - | I -> true - end - -let get_start (V _ s _) = s -let length (V bs _ _) = length bs - -let write_two_regs r1 r2 vec = - let size = length_reg r1 in - let start = get_start vec in - let vsize = length vec in - let r1_v = slice vec start ((if defaultDir then size - start else start - size) - 1) in - let r2_v = - (slice vec) - (if defaultDir then size - start else start - size) - (if defaultDir then vsize - start else start - vsize) in - write_reg r1 r1_v >> write_reg r2 r2_v - -let rec replace bs ((n : nat),b') = match (n,bs) with - | (_, []) -> [] - | (0, _::bs) -> b' :: bs - | (n+1, b::bs) -> b :: replace bs (n,b') - end - -let make_indexed_vector_reg entries default start length = - let (Just v) = default in - V (List.foldl replace (replicate length v) entries) start - -let make_indexed_vector_bit entries default start length = - let default = match default with Nothing -> U | Just v -> v end in - V (List.foldl replace (replicate length default) entries) start - -let make_bitvector_undef length = - V (replicate length U) 0 true - -let vector_concat (V bs start is_inc) (V bs' _ _) = - V(bs ++ bs') start is_inc - -let (^^) = vector_concat +let length l = integerFromNat (length l) let has_undef (V bs _ _) = List.any (function U -> true | _ -> false end) bs @@ -100,7 +61,7 @@ let unsigned (V bs _ _ as v) : integer = (fun (acc,exp) b -> (acc + (if b = I then integerPow 2 exp else 0),exp +1)) (0,0) bs) end -let signed v = +let signed v : integer = match most_significant v with | I -> 0 - (1 + (unsigned (bitwise_not v))) | O -> unsigned v @@ -119,43 +80,43 @@ let min_8 = (0 - 128 : integer) let max_5 = (31 : integer) let min_5 = (0 - 32 : integer) -let get_max_representable_in sign n = +let get_max_representable_in sign (n : integer) : integer = if (n = 64) then match sign with | true -> max_64 | false -> max_64u end else if (n=32) then match sign with | true -> max_32 | false -> max_32u end else if (n=8) then max_8 else if (n=5) then max_5 - else match sign with | true -> integerPow 2 (n -1) - | false -> integerPow 2 n + else match sign with | true -> integerPow 2 ((natFromInteger n) -1) + | false -> integerPow 2 (natFromInteger n) end -let get_min_representable_in _ n = +let get_min_representable_in _ (n : integer) : integer = if n = 64 then min_64 else if n = 32 then min_32 else if n = 8 then min_8 else if n = 5 then min_5 - else 0 - (integerPow 2 n) + else 0 - (integerPow 2 (natFromInteger n)) -let rec divide_by_2 bs i (n : integer) = +let rec divide_by_2 bs (i : integer) (n : integer) = if i < 0 || n = 0 then bs else if (n mod 2 = 1) - then divide_by_2 (replace bs (i,I)) (i - 1) (n / 2) + then divide_by_2 (replace bs (natFromInteger i,I)) (i - 1) (n / 2) else divide_by_2 bs (i-1) (n div 2) -let rec add_one_bit bs co i = +let rec add_one_bit bs co (i : integer) = if i < 0 then bs else match (nth bs i,co) with - | (O,false) -> replace bs (i,I) - | (O,true) -> add_one_bit (replace bs (i,I)) true (i-1) - | (I,false) -> add_one_bit (replace bs (i,O)) true (i-1) + | (O,false) -> replace bs (natFromInteger i,I) + | (O,true) -> add_one_bit (replace bs (natFromInteger i,I)) true (i-1) + | (I,false) -> add_one_bit (replace bs (natFromInteger i,O)) true (i-1) | (I,true) -> add_one_bit bs true (i-1) (* | Vundef,_ -> assert false*) end -let to_vec is_inc (len,(n : integer)) = - let bs = List.replicate len O in +let to_vec is_inc ((len : integer),(n : integer)) = + let bs = List.replicate (natFromInteger len) O in let start = if is_inc then 0 else len-1 in if n = 0 then V bs start is_inc @@ -169,8 +130,11 @@ let to_vec is_inc (len,(n : integer)) = let to_vec_inc = to_vec true let to_vec_dec = to_vec false -let to_vec_undef is_inc len = - V (replicate len U) (if is_inc then 0 else len-1) is_inc +let to_vec_undef is_inc (len : integer) = + V (replicate (natFromInteger len) U) (if is_inc then 0 else len-1) is_inc + +let to_vec_inc_undef = to_vec_undef true +let to_vec_dec_undef = to_vec_undef false let add = uncurry integerAdd let add_signed = uncurry integerAdd @@ -180,7 +144,7 @@ let modulo = uncurry integerMod let quot = uncurry integerDiv let power = uncurry integerPow -let arith_op_vec op sign size ((V _ _ is_inc as l),r) = +let arith_op_vec op sign (size : integer) ((V _ _ is_inc as l),r) = let (l',r') = (to_num sign l, to_num sign r) in let n = op l' r' in to_vec is_inc (size * (length l),n) @@ -228,7 +192,7 @@ let arith_op_vec_vec_range op sign ((V _ _ is_inc as l),r) = let add_vec_vec_range = arith_op_vec_vec_range integerAdd false let add_vec_vec_range_signed = arith_op_vec_vec_range integerAdd true -let arith_op_vec_bit op sign (size : nat) ((V _ _ is_inc as l),r) = +let arith_op_vec_bit op sign (size : integer) ((V _ _ is_inc as l),r) = let l' = to_num sign l in let n = op l' match r with | I -> (1 : integer) | _ -> 0 end in to_vec is_inc (length l * size,n) @@ -260,7 +224,7 @@ let minus_overflow_vec_signed = arith_op_overflow_vec integerMinus true 1 let mult_overflow_vec = arith_op_overflow_vec integerMult false 2 let mult_overflow_vec_signed = arith_op_overflow_vec integerMult true 2 -let rec arith_op_overflow_vec_bit (op : integer -> integer -> integer) sign (size : nat) +let rec arith_op_overflow_vec_bit (op : integer -> integer -> integer) sign (size : integer) ((V _ _ is_inc as l),r_bit) = let act_size = length l * size in let l' = to_num sign l in @@ -286,17 +250,16 @@ let minus_overflow_vec_bit_signed = arith_op_overflow_vec_bit integerMinus true type shift = LL | RR | LLL -let shift_op_vec op ((V bs start is_inc as l),r) = - let len = List.length bs in - let n = r in +let shift_op_vec op ((V bs start is_inc as l),(n : integer)) = + let len = integerFromNat (List.length bs) in match op with | LL (*"<<"*) -> - let right_vec = V (List.replicate n O) 0 true in + let right_vec = V (List.replicate (natFromInteger n) O) 0 true in let left_vec = slice l n (if is_inc then len + start else start - len) in vector_concat left_vec right_vec | RR (*">>"*) -> let right_vec = slice l start n in - let left_vec = V (List.replicate n O) 0 true in + let left_vec = V (List.replicate (natFromInteger n) O) 0 true in vector_concat left_vec right_vec | LLL (*"<<<"*) -> let left_vec = slice l n (if is_inc then len + start else start - len) in @@ -326,7 +289,7 @@ let rec arith_op_vec_no0 (op : integer -> integer -> integer) sign size (((V _ s end in if representable then to_vec is_inc (act_size,n') - else V (List.replicate act_size U) start is_inc + else V (List.replicate (natFromInteger act_size) U) start is_inc let mod_vec = arith_op_vec_no0 integerMod false 1 let quot_vec = arith_op_vec_no0 integerDiv false 1 @@ -350,8 +313,8 @@ let arith_op_overflow_no0_vec op sign size (((V _ start is_inc) as l),r) = if representable then (to_vec is_inc (act_size,n'),to_vec is_inc (act_size + 1,n_u')) else - (V (List.replicate act_size U) start is_inc, - V (List.replicate (act_size + 1) U) start is_inc) in + (V (List.replicate (natFromInteger act_size) U) start is_inc, + V (List.replicate (natFromInteger (act_size + 1)) U) start is_inc) in let overflow = if representable then O else I in (correct_size_num,overflow,most_significant one_more) @@ -364,7 +327,7 @@ let arith_op_vec_range_no0 op sign size ((V _ _ is_inc as l),r) = let mod_vec_range = arith_op_vec_range_no0 integerMod false 1 let duplicate (bit,length) = - V (List.replicate length bit) 0 true + V (List.replicate (natFromInteger length) bit) 0 true let compare_op op (l,r) = bool_to_bit (op l r) @@ -415,3 +378,11 @@ let eq_vec_vec (l,r) = eq (to_num true l, to_num true r) let neq (l,r) = bitwise_not_bit (eq (l,r)) let neq_vec (l,r) = bitwise_not_bit (eq_vec_vec (l,r)) +let neq_vec_range (l,r) = bitwise_not_bit (eq_vec_range (l,r)) +let neq_range_vec (l,r) = bitwise_not_bit (eq_range_vec (l,r)) + + +let EXTS (v1,(V _ _ is_inc as v)) = + to_vec is_inc (v1,signed v) + +let EXTZ = EXTS diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem index dee300ef..ac65a347 100644 --- a/src/gen_lib/state.lem +++ b/src/gen_lib/state.lem @@ -2,19 +2,60 @@ open import Pervasives open import Vector open import Arch -type M 's 'a = 's -> ('a * 's) +(* 'a is result type, 'e is error type *) +type M 's 'e 'a = 's -> (either 'a 'e * 's) -val return : forall 's 'a. 'a -> M 's 'a -let return a = fun s -> (a,s) +val return : forall 's 'e 'a. 'a -> M 's 'e 'a +let return a s = (Left a,s) -val bind : forall 's 'a 'b. M 's 'a -> ('a -> M 's 'b) -> M 's 'b -let bind m f s = let (a,s') = m s in f a s' +val bind : forall 's 'e 'a 'b. M 's 'e 'a -> ('a -> M 's 'e 'b) -> M 's 'e 'b +let bind m f s = match m s with + | (Left a,s') -> f a s' + | (Right error,s') -> (Right error,s') + end + +val exit : forall 's 'e 'a. 'e -> M 's 'e 'a +let exit e s = (Right e,s) let (>>=) = bind let (>>) m n = m >>= fun _ -> n -val foreach_inc : forall 's 'vars. (nat * nat * nat) -> 'vars -> - (nat -> 'vars -> (unit * 'vars)) -> (unit * 'vars) +val read_reg_range : forall 'e. register -> (integer * integer) (*(nat * nat)*) -> M state 'e (vector bit) +let read_reg_range reg (i,j) s = + let v = slice (read_regstate s reg) i j in + (Left v,s) + +val read_reg_bit : forall 'e. register -> integer (*nat*) -> M state 'e bit +let read_reg_bit reg i s = + let v = access (read_regstate s reg) i in + (Left v,s) + +val write_reg_range : forall 'e. register -> (integer * integer) (*(nat * nat)*) -> vector bit -> M state 'e unit +let write_reg_range (reg : register) (i,j) (v : vector bit) s = + let v' = update (read_regstate s reg) i j v in + let s' = write_regstate s reg v' in + (Left (),s') + +val write_reg_bit : forall 'e. register -> integer (*nat*) -> bit -> M state 'e unit +let write_reg_bit reg i bit s = + let v = read_regstate s reg in + let v' = update_pos v i bit in + let s' = write_regstate s reg v' in + (Left (),s') + +val read_reg : forall 'e. register -> M state 'e (vector bit) +let read_reg reg s = + let v = read_regstate s reg in + (Left v,s) + +val write_reg : forall 'e. register -> vector bit -> M state 'e unit +let write_reg reg v s = + let s' = write_regstate s reg v in + (Left (),s') + + +val foreach_inc : forall 's 'e 'vars. (integer * integer * integer) (*(nat * nat * nat)*) -> 'vars -> + (integer (*nat*) -> 'vars -> (unit * 'vars)) -> (unit * 'vars) let rec foreach_inc (i,stop,by) vars body = if i <= stop then @@ -23,8 +64,8 @@ let rec foreach_inc (i,stop,by) vars body = else ((),vars) -val foreach_dec : forall 's 'vars. (nat * nat * nat) -> 'vars -> - (nat -> 'vars -> (unit * 'vars)) -> (unit * 'vars) +val foreach_dec : forall 's 'e 'vars. (integer * integer * integer) (*(nat * nat * nat)*) -> 'vars -> + (integer (*nat*) -> 'vars -> (unit * 'vars)) -> (unit * 'vars) let rec foreach_dec (i,stop,by) vars body = if i >= stop then @@ -33,8 +74,8 @@ let rec foreach_dec (i,stop,by) vars body = else ((),vars) -val foreachM_inc : forall 's 'vars. (nat * nat * nat) -> 'vars -> - (nat -> 'vars -> M 's (unit * 'vars)) -> M 's (unit * 'vars) +val foreachM_inc : forall 's 'e 'vars. (nat * nat * nat) -> 'vars -> + (nat -> 'vars -> M 's 'e (unit * 'vars)) -> M 's 'e (unit * 'vars) let rec foreachM_inc (i,stop,by) vars body = if i <= stop then @@ -43,8 +84,8 @@ let rec foreachM_inc (i,stop,by) vars body = else return ((),vars) -val foreachM_dec : forall 's 'vars. (nat * nat * nat) -> 'vars -> - (nat -> 'vars -> M 's (unit * 'vars)) -> M 's (unit * 'vars) +val foreachM_dec : forall 's 'e 'vars. (nat * nat * nat) -> 'vars -> + (nat -> 'vars -> M 's 'e (unit * 'vars)) -> M 's 'e (unit * 'vars) let rec foreachM_dec (i,stop,by) vars body = if i >= stop then @@ -52,72 +93,28 @@ let rec foreachM_dec (i,stop,by) vars body = foreachM_dec (i - by,stop,by) vars body else return ((),vars) - - -let slice (V bs start is_inc) n m = - let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in - let (_,suffix) = List.splitAt offset bs in - let (subvector,_) = List.splitAt length suffix in - V subvector n is_inc - -let update (V bs start is_inc) n m (V bs' _ _) = - let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in - let (prefix,_) = List.splitAt offset bs in - let (_,suffix) = List.splitAt (offset + length) bs in - V (prefix ++ (List.take length bs') ++ suffix) start is_inc - -let hd (x :: _) = x - -val access : forall 'a. vector 'a -> nat -> 'a -let access (V bs start is_inc) n = - if is_inc then nth bs (n - start) else nth bs (start - n) - -val update_pos : forall 'a. vector 'a -> nat -> 'a -> vector 'a -let update_pos v n b = - update v n n (V [b] 0 defaultDir) - -val read_reg_range : register -> (nat * nat) -> M state (vector bit) -let read_reg_range reg (i,j) s = - let v = slice (read_regstate s reg) i j in - (v,s) - -val read_reg_bit : register -> nat -> M state bit -let read_reg_bit reg i s = - let v = access (read_regstate s reg) i in - (v,s) - -val write_reg_range : register -> (nat * nat) -> vector bit -> M state unit -let write_reg_range (reg : register) (i,j) (v : vector bit) s = - let v' = update (read_regstate s reg) i j v in - let s' = write_regstate s reg v' in - ((),s') - -val write_reg_bit : register -> nat -> bit -> M state unit -let write_reg_bit reg i bit s = - let v = read_regstate s reg in - let v' = update_pos v i bit in - let s' = write_regstate s reg v' in - ((),s') - - -val read_reg : register -> M state (vector bit) -let read_reg reg s = - let v = read_regstate s reg in - (v,s) - -val write_reg : register -> vector bit -> M state unit -let write_reg reg v s = - let s' = write_regstate s reg v in - ((),s') - -val read_reg_field : register -> register_field -> M state (vector bit) +val read_reg_field : forall 'e. register -> register_field -> M state 'e (vector bit) let read_reg_field reg rfield = read_reg_range reg (field_indices rfield) -val write_reg_field : register -> register_field -> vector bit -> M state unit +val write_reg_field : forall 'e. register -> register_field -> vector bit -> M state 'e unit let write_reg_field reg rfield = write_reg_range reg (field_indices rfield) -val read_reg_field_bit : register -> register_field_bit -> M state bit +val read_reg_field_bit : forall 'e. register -> register_field_bit -> M state 'e bit let read_reg_field_bit reg rbit = read_reg_bit reg (field_index_bit rbit) -val write_reg_field_bit : register -> register_field_bit -> bit -> M state unit +val write_reg_field_bit : forall 'e. register -> register_field_bit -> bit -> M state 'e unit let write_reg_field_bit reg rbit = write_reg_bit reg (field_index_bit rbit) + + +let length l = integerFromNat (length l) + +let write_two_regs r1 r2 vec = + let size = length_reg r1 in + let start = get_start vec in + let vsize = length vec in + let r1_v = slice vec start ((if defaultDir then size - start else start - size) - 1) in + let r2_v = + (slice vec) + (if defaultDir then size - start else start - size) + (if defaultDir then vsize - start else start - vsize) in + write_reg r1 r1_v >> write_reg r2 r2_v diff --git a/src/gen_lib/vector.lem b/src/gen_lib/vector.lem index f409ceb7..5e78e010 100644 --- a/src/gen_lib/vector.lem +++ b/src/gen_lib/vector.lem @@ -1,9 +1,72 @@ open import Pervasives type bit = O | I | U -type vector 'a = V of list 'a * nat * bool +type vector 'a = V of list 'a * integer * bool -let rec nth xs (n : nat) = match (n,xs) with +let rec nth xs (n : integer) = match (n,xs) with | (0,x :: xs) -> x | (n + 1,x :: xs) -> nth xs n end + + +let to_bool = function + | O -> false + | I -> true + end + +let get_start (V _ s _) = s +let length (V bs _ _) = length bs + +let rec replace bs ((n : nat),b') = match (n,bs) with + | (_, []) -> [] + | (0, _::bs) -> b' :: bs + | (n+1, b::bs) -> b :: replace bs (n,b') + end + +let make_indexed_vector_reg entries default start length = + let (Just v) = default in + V (List.foldl replace (replicate length v) entries) start + +let make_indexed_vector_bit entries default start length = + let default = match default with Nothing -> U | Just v -> v end in + V (List.foldl replace (replicate length default) entries) start + +let make_bitvector_undef length = + V (replicate length U) 0 true + +let vector_concat (V bs start is_inc) (V bs' _ _) = + V (bs ++ bs') start is_inc + +let (^^) = vector_concat + +val slice : vector bit -> integer -> integer -> vector bit +let slice (V bs start is_inc) n m = + let n = natFromInteger n in + let m = natFromInteger m in + let start = natFromInteger start in + let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in + let (_,suffix) = List.splitAt offset bs in + let (subvector,_) = List.splitAt length suffix in + let n = integerFromNat n in + V subvector n is_inc + +let update (V bs start is_inc) n m (V bs' _ _) = + let n = natFromInteger n in + let m = natFromInteger m in + let start = natFromInteger start in + let (length,offset) = if is_inc then (m-n+1,n-start) else (n-m+1,start-n) in + let (prefix,_) = List.splitAt offset bs in + let (_,suffix) = List.splitAt (offset + length) bs in + let start = integerFromNat start in + V (prefix ++ (List.take length bs') ++ suffix) start is_inc + +let hd (x :: _) = x + + +val access : forall 'a. vector 'a -> (*nat*) integer -> 'a +let access (V bs start is_inc) n = + if is_inc then nth bs (n - start) else nth bs (start - n) + +val update_pos : forall 'a. vector 'a -> (*nat*) integer -> 'a -> vector 'a +let update_pos v n b = + update v n n (V [b] 0 true) diff --git a/src/pretty_print.ml b/src/pretty_print.ml index 457effbb..02fc62a5 100644 --- a/src/pretty_print.ml +++ b/src/pretty_print.ml @@ -1185,8 +1185,8 @@ let doc_def def = group (match def with let doc_defs (Defs(defs)) = separate_map hardline doc_def defs -let print ?(len=80) channel doc = ToChannel.pretty 1. len channel doc -let to_buf ?(len=80) buf doc = ToBuffer.pretty 1. len buf doc +let print ?(len=100) channel doc = ToChannel.pretty 1. len channel doc +let to_buf ?(len=100) buf doc = ToBuffer.pretty 1. len buf doc let pp_defs f d = print f (doc_defs d) let pp_exp b e = to_buf b (doc_exp e) @@ -1826,14 +1826,15 @@ let doc_id_lem_type (Id_aux(i,_)) = * token in case of x ending with star. *) parens (separate space [colon; string x; empty]) -let doc_id_lem_ctor (Id_aux(i,_)) = +let doc_id_lem_ctor aexp_needed (Id_aux(i,_)) = match i with | Id("bit") -> string "bit" | Id i -> string (String.capitalize i) | DeIid x -> (* add an extra space through empty to avoid a closing-comment * token in case of x ending with star. *) - parens (separate space [colon; string (String.capitalize x); empty]) + let epp = separate space [colon; string (String.capitalize x); empty] in + if aexp_needed then parens epp else epp let doc_typ_lem, doc_atomic_typ_lem = (* following the structure of parser for precedence *) @@ -1850,8 +1851,8 @@ let doc_typ_lem, doc_atomic_typ_lem = Typ_arg_aux(Typ_arg_nexp n, _); Typ_arg_aux(Typ_arg_nexp m, _); Typ_arg_aux (Typ_arg_order ord, _); - Typ_arg_aux (Typ_arg_typ typ, _)]) -> - string "vector" + Typ_arg_aux (Typ_arg_typ typa, _)]) -> + string "vector" ^^ space ^^ parens (typ typa) | Typ_app(Id_aux (Id "range", _), [ Typ_arg_aux(Typ_arg_nexp n, _); Typ_arg_aux(Typ_arg_nexp m, _);]) -> @@ -1883,7 +1884,10 @@ let doc_lit_lem in_pat (L_aux(l,_)) = | L_one -> "I" | L_false -> "O" | L_true -> "I" - | L_num i -> if i < 0 then "(0 " ^ string_of_int i ^ ")" else string_of_int i + | L_num i -> + let ipp = string_of_int i in + (if i < 0 then "((0"^ipp^") : integer)" + else "("^ipp^" : integer)") | L_hex n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0x" ^ n) ^ ")" (*shouldn't happen*)*) | L_bin n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0b" ^ n) ^ ")" (*shouldn't happen*)*) | L_undef -> "U" @@ -1904,7 +1908,7 @@ let doc_pat_lem apat_needed = | P_app(id, ((_ :: _) as pats)) -> (match annot with | Base(_,Constructor _,_,_,_,_) -> - doc_unop (doc_id_lem_ctor id) (separate_map space pat pats) + doc_unop (doc_id_lem_ctor true id) (separate_map space pat pats) | _ -> empty) | P_lit lit -> doc_lit_lem true lit | P_wild -> underscore @@ -1913,7 +1917,7 @@ let doc_pat_lem apat_needed = | P_typ(typ,p) -> doc_op colon (pat p) (doc_typ_lem typ) | P_app(id,[]) -> (match annot with - | Base(_,Constructor n,_,_,_,_) -> doc_id_lem_ctor id + | Base(_,Constructor n,_,_,_,_) -> doc_id_lem_ctor apat_needed id | _ -> empty) | P_vector pats -> let non_bit_print () = @@ -1938,33 +1942,56 @@ let doc_pat_lem apat_needed = | P_list pats -> brackets (separate_map semi pat pats) (*Never seen but easy in lem*) in pat +let rec getregtyp (LEXP_aux (le,(_,annot))) = match le with + | LEXP_id _ + | LEXP_cast _ -> + let (Base ((_,t),_,_,_,_,_)) = annot in + (match t with + | {t = Tabbrev ({t = Tid name},_)} -> name) + | LEXP_memory _ -> failwith "This lexp writes memory" + | LEXP_vector (le,_) + | LEXP_vector_range (le,_,_) + | LEXP_field (le,_) -> + getregtyp le + let doc_exp_lem, doc_let_lem = let rec top_exp (aexp_needed : bool) (E_aux (e, (_,annot))) = let exp = top_exp true in match e with | E_assign((LEXP_aux(le_act,tannot) as le),e) -> (* can only be register writes *) - let (_,(Base ((_,_),tag,_,_,_,_))) = tannot in - (separate space) + let (_,(Base ((_,{t = t}),tag,_,_,_,_))) = tannot in + let E_aux (_,(_,(Base ((_,{t = et}),_,_,_,_,_)))) = e in + let (f,args) = (match tag with | External _ -> (match le_act with | LEXP_vector (le,e2) -> - [string "write_reg_bit";doc_lexp_deref_lem le;exp e2;exp e] + (string "write_reg_bit",[doc_lexp_deref_lem le;exp e2;exp e]) | LEXP_vector_range (le,e2,e3) -> - [string "write_reg_range";doc_lexp_deref_lem le;exp e2;exp e3;exp e] + (string "write_reg_range",[doc_lexp_deref_lem le; + parens (exp e2 ^^ comma ^^ exp e3); + exp e]) | LEXP_field (lexp,id) -> - let (Base ((_,{t = t}),_,_,_,_,_)) = annot in - (match t with - | Tid "bit" -> - [string "write_reg_field_bit";doc_lexp_deref_lem le;doc_id_lem id;exp e] - | _ -> - [string "write_reg_field";doc_lexp_deref_lem le;doc_id_lem id;exp e]) - | (LEXP_id _ | LEXP_cast _) -> [string "write_reg";doc_lexp_deref_lem le;exp e]) - | _ -> [string "write_reg";doc_lexp_deref_lem le;exp e] - ) + let typprefix = String.uncapitalize (getregtyp lexp) ^ "_" in + (match et with + | Tid "bit" + | Tabbrev (_,{t=Tid "bit"}) -> + (string "write_reg_field_bit"), + [doc_lexp_deref_lem lexp;string typprefix ^^ doc_id_lem id;exp e] + | Tapp ("vector",_) + | Tabbrev (_,{t=Tapp ("vector",_)}) -> + (string "write_reg_field", + [doc_lexp_deref_lem lexp;string typprefix ^^ doc_id_lem id;exp e]) + | _ -> failwith (t_to_string {t = et}) + ) + | (LEXP_id _ | LEXP_cast _) -> (string "write_reg",[doc_lexp_deref_lem le;exp e])) + | _ -> (string "write_reg",[doc_lexp_deref_lem le;exp e]) + ) in + prefix 2 1 f (separate (break 1) args) | E_vector_append(l,r) -> - let epp = (separate space [exp l;string "^^";exp r]) in + let epp = + separate space [exp l;string "^^"] ^//^ exp r in if aexp_needed then parens epp else epp | E_cons(l,r) -> doc_op (group (colon^^colon)) (exp l) (exp r) | E_if(c,t,e) -> @@ -1975,9 +2002,9 @@ let doc_exp_lem, doc_let_lem = | Base ((_,({t = Tid "bit"})),_,_,_,_,_) -> separate space [string "if";string "to_bool";exp c] | _ -> separate space [string "if";exp c]) - ^/^ - (prefix 2 1 (string "then") (exp t)) ^^ (break 1) ^^ - (prefix 2 1 (string "else") (exp e))) in + ^^ break 1 ^^ + (prefix 2 1 (string "then") (top_exp false t)) ^^ (break 1) ^^ + (prefix 2 1 (string "else") (top_exp false e))) in if aexp_needed then parens epp else epp | E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) -> failwith "E_for should have been removed till now" @@ -2002,19 +2029,28 @@ let doc_exp_lem, doc_let_lem = ) ) | _ -> - let call = match annot with - | Base(_,External (Some n),_,_,_,_) -> - (match n with - | "bitwise_not_bit" -> string "~" - | _ -> string n) - | Base(_,Constructor _,_,_,_,_) -> doc_id_lem_ctor f - | _ -> doc_id_lem f in - let epp = - (doc_unop call) - (match args with - | [a] -> exp a - | args -> parens (separate_map comma exp args)) in - if aexp_needed then parens epp else epp + (match annot with + | Base (_,Constructor _,_,_,_,_) -> + let epp = separate space [doc_id_lem f;separate_map space (top_exp true) args] in + if aexp_needed then parens epp else epp + | Base (_,External (Some "bitwise_not_bit"),_,_,_,_) -> + let [a] = args in + let epp = string "~" ^^ exp a in + if aexp_needed then parens epp else epp + | _ -> + let call = match annot with + | Base(_,External (Some n),_,_,_,_) -> + (match n with + | _ -> string n) + | Base(_,Constructor _,_,_,_,_) -> doc_id_lem_ctor false f + | _ -> doc_id_lem f in + let epp = + (doc_unop call) + (match args with + | [a] -> exp a + | args -> parens (separate_map comma (top_exp false) args)) in + if aexp_needed then parens epp else epp + ) ) | E_vector_access(v,e) -> let epp = separate space [string "access";exp v;exp e] in @@ -2023,16 +2059,16 @@ let doc_exp_lem, doc_let_lem = let epp = (string "slice") ^^ space ^^ (exp v) ^^ space ^^ (exp e1) ^^ space ^^ (exp e2) in if aexp_needed then parens epp else epp | E_field((E_aux(_,(_,fannot)) as fexp),id) -> - (match fannot with - | Base((_,{t= Tapp("register",_)}),_,_,_,_,_) | - Base((_,{t= Tabbrev(_,{t=Tapp("register",_)})}),_,_,_,_,_)-> - let field_f = match annot with - | Base((_,{t = Tid "bit"}),_,_,_,_,_) | - Base((_,{t = Tabbrev(_,{t=Tid "bit"})}),_,_,_,_,_) -> + let (Base ((_,{t = t}),_,_,_,_,_)) = fannot in + (match t with + | Tabbrev({t = Tid regtyp},{t=Tapp("register",_)}) -> + let field_f = match annot with + | Base((_,{t = Tid "bit"}),_,_,_,_,_) + | Base((_,{t = Tabbrev(_,{t=Tid "bit"})}),_,_,_,_,_) -> string "read_reg_field_bit" - | _ -> string "read_reg_field" in - - let epp = field_f ^^ space ^^ (exp fexp) ^^ space ^^ string_lit (doc_id_lem id) in + | _ -> string "read_reg_field" in + let epp = field_f ^^ space ^^ (exp fexp) ^^ space ^^ + string (regtyp ^ "_") ^^ doc_id_lem id in if aexp_needed then parens epp else epp | _ -> exp fexp ^^ dot ^^ doc_id_lem id) | E_block [] -> string "()" @@ -2046,7 +2082,7 @@ let doc_exp_lem, doc_let_lem = (match tag with | External _ -> separate space [string "read_reg";doc_id_lem id] | _ -> doc_id_lem id) - | Base(_,(Constructor i |Enum i),_,_,_,_) -> doc_id_lem_ctor id + | Base(_,(Constructor i |Enum i),_,_,_,_) -> doc_id_lem_ctor aexp_needed id | Base((_,t),Alias alias_info,_,_,_,_) -> (match alias_info with | Alias_field(reg,field) -> @@ -2108,6 +2144,7 @@ let doc_exp_lem, doc_let_lem = | Nconst i -> string_of_big_int i | N2n(_,Some i) -> string_of_big_int i | _ -> if dir then "0" else string_of_int (List.length exps) in + let epp = group (separate space [string "V"; brackets (separate_map (semi) exp exps); string start;string dir_out]) in if aexp_needed then parens epp else epp @@ -2149,7 +2186,7 @@ let doc_exp_lem, doc_let_lem = | _ -> raise (Reporting_basic.err_unreachable dl "nono") in parens (string "Just " ^^ parens (string ("UndefinedReg " ^ string_of_big_int n)))) in - let iexp (i,e) = parens (separate_map comma_sp (fun x -> x) [(doc_int i); (exp e)]) in + let iexp (i,e) = parens (separate_map comma (fun x -> x) [(doc_int i); (exp e)]) in let epp = (separate space) [call;(brackets (separate_map semi iexp iexps)); @@ -2178,12 +2215,12 @@ let doc_exp_lem, doc_let_lem = (match annot with | Base((_,t),External(Some name),_,_,_,_) -> let epp = match name with - | "bitwise_and_bit" -> separate space [exp e1;string "&.";exp e2] - | "bitwise_or_bit" -> separate space [exp e1;string "|.";exp e2] - | "bitwise_xor_bit" -> separate space [exp e1;string "+.";exp e2] - | "add" -> separate space [exp e1;string "+";exp e2] - | "minus" -> separate space [exp e1;string "-";exp e2] - | "multiply" -> separate space [exp e1;string "*";exp e2] + | "bitwise_and_bit" -> separate space [exp e1;string "&."] ^//^ exp e2 + | "bitwise_or_bit" -> separate space [exp e1;string "|."] ^//^ exp e2 + | "bitwise_xor_bit" -> separate space [exp e1;string "+."] ^//^ exp e2 + | "add" -> separate space [exp e1;string "+";exp e2] + | "minus" -> separate space [exp e1;string "-";exp e2] + | "multiply" -> separate space [exp e1;string "*";exp e2] (* | "lt" -> separate space [exp e1;string "<";exp e2] | "gt" -> separate space [exp e1;string ">";exp e2] | "lteq" -> separate space [exp e1;string "<=";exp e2] @@ -2192,29 +2229,32 @@ let doc_exp_lem, doc_let_lem = | "gt_vec" -> separate space [exp e1;string ">";exp e2] | "lteq_vec" -> separate space [exp e1;string "<=";exp e2] | "gteq_vec" -> separate space [exp e1;string ">=";exp e2] *) - | _ -> separate space [string name; parens (separate_map comma exp [e1;e2])] in + | _ -> separate space [string name; parens (separate_map comma (top_exp false) [e1;e2])] in if aexp_needed then parens epp else epp | _ -> - let epp = separate space [doc_id_lem id; parens (separate_map comma exp [e1;e2])] in + let epp = separate space [doc_id_lem id; parens (separate_map comma (top_exp false) [e1;e2])] in if aexp_needed then parens epp else epp) | E_internal_let(lexp, eq_exp, in_exp) -> - failwith "E_internal_lets should have been removed till now" -(* (separate + (* failwith "E_internal_lets should have been removed till now" *) + (separate space - [string "let TAKTAK"; - doc_lexp_lem true lexp; (*Rewriter/typecheck should ensure this is only cast or id*) + [string "let internal"; + (match lexp with (LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),_)) -> doc_id_lem id); coloneq; exp eq_exp; string "in"]) ^/^ - exp in_exp *) + exp in_exp | E_internal_plet (pat,e1,e2) -> - (match pat with - | P_aux (P_wild,_) -> - (separate space [exp e1; string ">>"]) ^/^ - top_exp false e2 - | _ -> - (separate space [exp e1; string ">>= fun"; doc_pat_lem true pat;arrow]) ^/^ - top_exp false e2) + let epp = + let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in + match pat with + | P_aux (P_wild,_) -> + (separate space [top_exp b e1; string ">>"]) ^/^ + top_exp false e2 + | _ -> + (separate space [top_exp b e1; string ">>= fun"; doc_pat_lem true pat;arrow]) ^/^ + top_exp false e2 in + if aexp_needed then parens epp else epp | E_internal_return (e1) -> separate space [string "return"; exp e1;] and let_exp (LB_aux(lb,_)) = match lb with @@ -2242,8 +2282,8 @@ let doc_exp_lem, doc_let_lem = (*TODO Upcase and downcase type and constructors as needed*) let doc_type_union_lem (Tu_aux(typ_u,_)) = match typ_u with - | Tu_ty_id(typ,id) -> separate space [pipe; doc_id_lem_ctor id; string "of"; doc_typ_lem typ;] - | Tu_id id -> separate space [pipe; doc_id_lem_ctor id] + | Tu_ty_id(typ,id) -> separate space [pipe; doc_id_lem_ctor false id; string "of"; doc_typ_lem typ;] + | Tu_id id -> separate space [pipe; doc_id_lem_ctor false id] let rec doc_range_lem (BF_aux(r,_)) = match r with | BF_single i -> parens (doc_op comma (doc_int i) (doc_int i)) @@ -2265,7 +2305,7 @@ let doc_typdef_lem (TD_aux(td,_)) = match td with (concat [string "type"; space; doc_id_lem_type id;]) (doc_typquant_lem typq ar_doc) | TD_enum(id,nm,enums,_) -> - let enums_doc = group (separate_map (break 1 ^^ pipe ^^ space) doc_id_lem_ctor enums) in + let enums_doc = group (separate_map (break 1 ^^ pipe ^^ space) (doc_id_lem_ctor false) enums) in doc_op equals (concat [string "type"; space; doc_id_lem_type id;]) (enums_doc) @@ -2279,7 +2319,7 @@ let doc_tannot_opt_lem (Typ_annot_opt_aux(t,_)) = match t with | Typ_annot_opt_some(tq,typ) -> doc_typquant_lem tq (doc_typ_lem typ) let doc_funcl_lem (FCL_aux(FCL_Funcl(id,pat,exp),_)) = - group (doc_op arrow (doc_pat_lem false pat) (doc_exp_lem false exp)) + group (prefix 3 1 ((doc_pat_lem false pat) ^^ space ^^ arrow) (doc_exp_lem false exp)) let get_id = function | [] -> failwith "FD_function with empty list" @@ -2352,16 +2392,22 @@ let reg_decls (Defs defs) = | _ -> (regtypes,rsranges,rsbits,defs @ [def]) ) ([],[],[],[]) defs in - let (regs,defs) = + let (regs,regaliases,defs) = List.fold_left - (fun (regs,defs) def -> + (fun (regs,regaliases,defs) def -> match def with | DEF_reg_dec (DEC_aux (DEC_reg(Typ_aux (Typ_id (Id_aux (Id typ,_)),_),Id_aux (Id name,_)),_)) -> - (regs @ [(name,Some typ)],defs) + (regs @ [(name,Some typ)],regaliases,defs) | DEF_reg_dec (DEC_aux (DEC_reg(_, Id_aux (Id name,_)),_)) -> - (regs @ [(name,None)],defs) - | def -> (regs,defs @ [def]) - ) ([],[]) defs in + (regs @ [(name,None)],regaliases,defs) + | DEF_reg_dec + (DEC_aux (DEC_alias + (Id_aux (Id name1,_), + AL_aux (AL_concat (RI_aux (RI_id (Id_aux (Id name2,_)),_), + RI_aux (RI_id (Id_aux (Id name3,_)),_)),_)),_)) -> + (regs,regaliases @ [(name1,(name2,name3))],defs) + | def -> (regs,regaliases,defs @ [def]) + ) ([],[],[]) defs in (* maybe we need a function that analyses the spec for this as well *) let default = @@ -2373,7 +2419,8 @@ let reg_decls (Defs defs) = (prefix 2 1) (separate space [string "type";string "register";equals]) ((separate_map space (fun (reg,_) -> pipe ^^ space ^^ string reg) regs) - ^^ space ^^ pipe ^^ space ^^ string "UndefinedReg of nat") in + ^^ space ^^ pipe ^^ space ^^ string "UndefinedReg of integer" ^^ + pipe ^^ space ^^ string "RegisterPair of register * register") in let regfields_pp = (prefix 2 1) @@ -2387,6 +2434,12 @@ let reg_decls (Defs defs) = (separate_map space (fun (fname,tname,_) -> pipe ^^ space ^^ string (tname ^ "_" ^ fname)) rsbits) in + let regalias_pp = + (separate_map (break 1)) + (fun (name1,(name2,name3)) -> + separate space [string "let";string name1;equals;string "RegisterPair";string name2;string name3]) + regaliases in + let state_pp = (prefix 2 1) (separate space [string "type";string "state";equals]) @@ -2397,10 +2450,10 @@ let reg_decls (Defs defs) = )) in let length_pp = - (separate space [string "val";string "length_reg";colon;string "register";arrow;string "nat"]) + (separate space [string "val";string "length_reg";colon;string "register";arrow;string "integer"]) ^/^ (prefix 2 1) - (separate space [string "let";string "length_reg";equals;string "function"]) + (separate space [string "let rec";string "length_reg";string "reg";equals;string "match reg with"]) (((separate_map (break 1)) (fun (name,typ) -> let ((n1,n2,_,_),typname) = @@ -2412,13 +2465,15 @@ let reg_decls (Defs defs) = regs) ^/^ separate space [pipe;string "UndefinedReg n";arrow; string "failwith \"Trying to compute length of undefined register\""] ^/^ + separate space [pipe;string "RegisterPair r1 r2";arrow; + string "length_reg r1 + length_reg r2"] ^/^ string "end") in let field_indices_pp = (prefix 2 1) ((separate space) [string "let";string "field_indices"; - colon;string "register_field";arrow;string "(nat * nat)"; + colon;string "register_field";arrow;string "(integer * integer)"; equals;string "function"]) ( ((separate_map (break 1)) @@ -2433,7 +2488,7 @@ let reg_decls (Defs defs) = let field_index_bit_pp = (prefix 2 1) ((separate space) [string "let";string "field_index_bit"; - colon;string "register_field_bit";arrow;string "nat"; + colon;string "register_field_bit";arrow;string "integer"; equals;string "function"]) ( ((separate_map (break 1)) @@ -2447,7 +2502,7 @@ let field_index_bit_pp = let read_regstate_pp = (prefix 2 1) - (separate space [string "let";string "read_regstate";string "s";equals;string "function"]) + (separate space [string "let rec";string "read_regstate";string "s";equals;string "function"]) ( ((separate_map (break 1)) (fun (name,_) -> @@ -2455,11 +2510,13 @@ let field_index_bit_pp = regs) ^/^ separate space [pipe;string "UndefinedReg n";arrow; string "failwith \"Trying to read from undefined register\""] ^/^ + separate space [pipe;string "RegisterPair r1 r2";arrow; + string "read_regstate s r1 ^^ read_regstate s r2"] ^/^ string "end" ^^ hardline ) in let write_regstate_pp = (prefix 2 1) - (separate space [string "let";string "write_regstate";string "s";string "reg";string "v"; + (separate space [string "let rec";string "write_regstate";string "s";string "reg";string "v"; equals;string "match reg with"]) ( ((separate_map (break 1)) @@ -2474,11 +2531,26 @@ let field_index_bit_pp = ) regs) ^/^ separate space [pipe;string "UndefinedReg n";arrow; string "failwith \"Trying to write to undefined register\""] ^/^ + ((prefix 3 1) + (separate space [pipe;string "RegisterPair r1 r2";arrow]) + ((separate (break 1)) + [ + string "let size = length_reg r1 in"; + string "let start = get_start v in"; + string "let vsize = length v in"; + string "let vsize = integerFromNat vsize in"; + string ("let r1_v = slice v start " ^ + (if is_inc then "(size - start - 1) in" else "(start - size) - 1) in")); + string ("let r2_v = slice v " ^ + (if is_inc then "(size - start)" else "(start - size)") ^ + (if is_inc then "(vsize - start) in" else ("start - vsize) in"))); + string "write_regstate (write_regstate s r1 r1_v) r2 r2_v" + ])) ^/^ string "end" ^^ hardline ) in (separate (hardline ^^ hardline) [dir_pp;regs_pp;regfields_pp;regfieldsbit_pp;field_index_bit_pp;field_indices_pp; - state_pp;length_pp;read_regstate_pp;write_regstate_pp],defs) + regalias_pp;state_pp;length_pp;read_regstate_pp;write_regstate_pp],defs) let doc_defs_lem (Defs defs) = let (decls,defs) = reg_decls (Defs defs) in diff --git a/src/rewriter.ml b/src/rewriter.ml index 3c4ff5e8..c7e6986b 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -25,6 +25,43 @@ let fresh_name () = let () = fresh_name_counter := (current + 1) in current +let geteffs_annot (_,t) = match t with + | Base (_,_,_,_,effs,_) -> effs + | NoTyp -> failwith "no effect information" + | _ -> failwith "a_normalise doesn't support Overload" + +let gettype_annot (_,t) = match t with + | Base((_,t),_,_,_,_,_) -> t + | NoTyp -> failwith "no type information" + | _ -> failwith "a_normalise doesn't support Overload" + +let gettype (E_aux (_,a)) = gettype_annot a +let geteffs (E_aux (_,a)) = geteffs_annot a + +let effectful_effs {effect = Eset effs} = + List.exists + (fun (BE_aux (be,_)) -> + match be with + | BE_nondet | BE_unspec | BE_undef | BE_lset -> false + | _ -> true + ) effs + +let effectful eaux = + effectful_effs (geteffs eaux) + +let updates_vars_effs {effect = Eset effs} = + List.exists + (fun (BE_aux (be,_)) -> + match be with + | BE_lset -> true + | _ -> false + ) effs + +let updates_vars eaux = + updates_vars_effs (geteffs eaux) + +let eff_union es = + List.fold_left (fun acc e -> union_effects acc (geteffs e)) pure_e es let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option = match ls with | [] -> None @@ -871,8 +908,9 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful let le' = rewriters.rewrite_lexp rewriters nmap le in let e' = rewrite_base e in let exps' = walker exps in - [(E_aux (E_internal_let(le', e', E_aux(E_block exps', (l, simple_annot {t=Tid "unit"}))), - (l, simple_annot t)))] + let effects = eff_union exps' in + [E_aux (E_internal_let(le', e', E_aux(E_block exps', (l, simple_annot_efr {t=Tid "unit"} effects))), + (l, simple_annot_efr t (eff_union (e::exps'))))] | ((E_aux(E_if(c,t,e),(l,annot))) as exp)::exps -> let vars_t = introduced_variables t in let vars_e = introduced_variables e in @@ -886,8 +924,9 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful let c' = rewrite_base c in let t' = rewriters.rewrite_exp rewriters new_nmap t in let e' = rewriters.rewrite_exp rewriters new_nmap e in - Envmap.fold - (fun res i (t,e) -> + let exps' = walker exps in + fst ((Envmap.fold + (fun (res,effects) i (t,e) -> let bitlit = E_aux (E_lit (L_aux(L_zero, Parse_ast.Unknown)), (Parse_ast.Unknown, simple_annot bit_t)) in let rangelit = E_aux (E_lit (L_aux (L_num 0, Parse_ast.Unknown)), @@ -907,12 +946,13 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful (Parse_ast.Unknown,simple_annot bit_t))), (Parse_ast.Unknown, simple_annot t)) | _ -> e in - [E_aux (E_internal_let (LEXP_aux (LEXP_id (Id_aux (Id i, Parse_ast.Unknown)), + let unioneffs = union_effects effects (geteffs set_exp) in + ([E_aux (E_internal_let (LEXP_aux (LEXP_id (Id_aux (Id i, Parse_ast.Unknown)), (Parse_ast.Unknown, (tag_annot t Emp_intro))), set_exp, - E_aux (E_block res, (Parse_ast.Unknown, (simple_annot unit_t)))), - (Parse_ast.Unknown, simple_annot unit_t))]) - (E_aux(E_if(c',t',e'), (Parse_ast.Unknown, annot))::(walker exps)) new_vars + E_aux (E_block res, (Parse_ast.Unknown, (simple_annot_efr unit_t effects)))), + (Parse_ast.Unknown, simple_annot_efr unit_t unioneffs))],unioneffs))) + (E_aux(E_if(c',t',e'), (Parse_ast.Unknown, annot))::exps',eff_union exps') new_vars) | e::exps -> (rewrite_rec e)::(walker exps) in rewrap (E_block (walker exps)) @@ -923,7 +963,7 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful (match le' with | LEXP_aux(_, (_,Base(_,Emp_intro,_,_,_,_))) -> let e' = rewrite_base e in - rewrap (E_internal_let(le', e', E_aux(E_block [], (l, simple_annot unit_t)))) + rewrap (E_internal_let(le', e', E_aux(E_block [], (l, simple_annot_efr unit_t (geteffs e'))))) | _ -> E_aux((E_assign(le', rewrite_base e)),(l, tag_annot unit_t Emp_set))) | _ -> rewrite_base full_exp) | _ -> rewrite_base full_exp @@ -961,52 +1001,12 @@ let rewrite_defs_ocaml defs = let defs_lifted_assign = rewrite_defs_exp_lift_assign defs_vec_concat_removed in defs_lifted_assign - - -let geteffs_annot (_,t) = match t with - | Base (_,_,_,_,effs,_) -> effs - | NoTyp -> failwith "no effect information" - | _ -> failwith "a_normalise doesn't support Overload" - -let gettype_annot (_,t) = match t with - | Base((_,t),_,_,_,_,_) -> t - | NoTyp -> failwith "no type information" - | _ -> failwith "a_normalise doesn't support Overload" - -let gettype (E_aux (_,a)) = gettype_annot a -let geteffs (E_aux (_,a)) = geteffs_annot a - -let effectful_effs {effect = Eset effs} = - List.exists - (fun (BE_aux (be,_)) -> - match be with - | BE_nondet | BE_unspec | BE_undef | BE_lset -> false - | _ -> true - ) effs - -let effectful eaux = - effectful_effs (geteffs eaux) - -let updates_vars_effs {effect = Eset effs} = - List.exists - (fun (BE_aux (be,_)) -> - match be with - | BE_lset -> true - | _ -> false - ) effs - -let updates_vars eaux = - updates_vars_effs (geteffs eaux) - -let eff_union e1 e2 = union_effects (geteffs e1) (geteffs e2) - -let remove_blocks_exp_alg = +let remove_blocks = let letbind_wild v body = - let annot_pat = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in - let annot_lb = annot_pat in - let annot_let = - (Parse_ast.Unknown,simple_annot_efr {t = Tid "unit"} (eff_union v body)) in + let annot_pat = (Parse_ast.Unknown,simple_annot (gettype v)) in + let annot_lb = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in + let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_wild,annot_pat),v),annot_lb),body),annot_let) in let rec f = function @@ -1018,7 +1018,7 @@ let remove_blocks_exp_alg = | (E_block es,annot) -> f es | (e,annot) -> E_aux (e,annot) in - { id_exp_alg with e_aux = e_aux } + fold_exp { id_exp_alg with e_aux = e_aux } let fresh_id annot = @@ -1036,7 +1036,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = let body = body e in let annot_pat = (Parse_ast.Unknown,simple_annot unit_t) in let annot_lb = annot_pat in - let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union v body)) in + let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in let pat = P_aux (P_wild,annot_pat) in if effectful v @@ -1049,7 +1049,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = let annot_pat = (Parse_ast.Unknown,simple_annot (gettype v)) in let annot_lb = annot_pat in - let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union v body)) in + let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in let pat = P_aux (P_id id,annot_pat) in if effectful v @@ -1106,7 +1106,7 @@ let rec n_exp_name (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = and n_exp_pure (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = n_exp exp (fun exp -> if not (effectful exp || updates_vars exp) then k exp else letbind exp k) - + and n_exp_nameL (exps : 'a exp list) (k : 'a exp list -> 'a exp) : 'a exp = mapCont n_exp_name exps k @@ -1170,7 +1170,6 @@ and n_lexp (lexp : 'a lexp) (k : 'a lexp -> 'a exp) : 'a exp = k (LEXP_aux (LEXP_field (lexp,id),local_eff_plus effs annot))) and n_exp_term (new_return : bool) (exp : 'a exp) : 'a exp = - let (E_aux (_,annot)) = exp in let exp = if new_return then E_aux (E_internal_return exp,(Unknown,simple_annot_efr (gettype exp) (geteffs exp))) @@ -1187,8 +1186,8 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = let (E_aux (exp_aux,annot)) = exp in let rewrap_effs effsum exp_aux = (* explicitly give effect sum *) - let (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds)) = annot in - E_aux (exp_aux, (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds))) in + let (l,Base (t,tag,nexps,eff,_,bounds)) = annot in + E_aux (exp_aux, (l,Base (t,tag,nexps,eff,effsum,bounds))) in let rewrap_localeff exp_aux = (* give exp_aux the local effect as the effect sum *) E_aux (exp_aux,only_local_eff annot) in @@ -1218,7 +1217,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = let new_return = effectful exp2 || effectful exp3 in let exp2 = n_exp_term new_return exp2 in let exp3 = n_exp_term new_return exp3 in - k (rewrap_effs (eff_union exp2 exp3) (E_if (exp1,exp2,exp3)))) + k (rewrap_effs (eff_union [exp2;exp3]) (E_if (exp1,exp2,exp3)))) | E_for (id,start,stop,by,dir,body) -> n_exp_name start (fun start -> n_exp_name stop (fun stop -> @@ -1292,7 +1291,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = | LB_aux (LB_val_explicit (_,pat,exp'),annot') | LB_aux (LB_val_implicit (pat,exp'),annot') -> if effectful exp' - then (rewrap_effs (eff_union exp' body) (E_internal_plet (pat,exp',n_exp body k))) + then (rewrap_effs (eff_union [exp';body]) (E_internal_plet (pat,exp',n_exp body k))) else (rewrap_effs (geteffs body) (E_let (lb,n_exp body k))) )) | E_assign (lexp,exp1) -> @@ -1309,7 +1308,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = (if effectful exp1 then n_exp_name exp1 else n_exp exp1) (fun exp1 -> - (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,n_exp exp2 k)))) + rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,n_exp exp2 k))) | E_internal_return exp1 -> n_exp_name exp1 (fun exp1 -> k (rewrap_localeff (E_internal_return exp1))) @@ -1318,7 +1317,8 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = let rewrite_defs_a_normalise = let rewrite_exp _ _ e = - n_exp_term (effectful e) (fold_exp remove_blocks_exp_alg e) in + let e = remove_blocks e in + n_exp_term (effectful e) e in rewrite_defs_base {rewrite_exp = rewrite_exp ; rewrite_pat = rewrite_pat @@ -1376,7 +1376,8 @@ let find_updated_vars exp = ; e_internal_cast = (fun (_,e1) -> e1) ; e_internal_exp = (fun _ -> []) ; e_internal_exp_user = (fun _ -> []) - ; e_internal_let = (fun ((None,[(id,_)]), e2, e3) -> List.filter (eqidtyp id) (e2 @ e3)) + ; e_internal_let = (fun ((None,[(id,_)]), e2, e3) -> + List.filter (fun id2 -> not (eqidtyp id id2)) (e2 @ e3)) ; e_internal_plet = (fun (_, e1, e2) -> e1 @ e2) ; e_internal_return = (fun e -> e) ; e_aux = (fun (e,_) -> e) @@ -1523,7 +1524,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = (* after a-normalisation c shouldn't need rewriting *) let t = gettype e1 in (* let () = assert (simple_annot t = simple_annot (gettype e2)) in *) - let v = E_aux (E_if (c,e1,e2), (Unknown,simple_annot_efr t (eff_union e1 e2))) in + let v = E_aux (E_if (c,e1,e2), (Unknown,simple_annot_efr t (eff_union [e1;e2]))) in let pat = (* if overwrite then mktup_pat vars @@ -1615,7 +1616,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = (match rewrite v pat with | Added_vars (v,pat) -> E_aux (E_internal_plet (pat,v,body), - (Unknown,simple_annot_efr (gettype body) (eff_union v body))) + (Unknown,simple_annot_efr (gettype body) (eff_union [v;body]))) | Same_vars v -> E_aux (E_internal_plet (pat,v,body),annot)) | E_internal_let (lexp,v,body) -> (* After a-normalisation E_internal_lets can only bind values to names, those don't @@ -1627,7 +1628,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = let pat = P_aux (P_id id, (Parse_ast.Unknown,simple_annot (gettype v))) in let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in let lb = LB_aux (LB_val_implicit (pat,v),lbannot) in - E_aux (E_let (lb,body),(Unknown,simple_annot_efr (gettype body) (eff_union v body))) + E_aux (E_let (lb,body),(Unknown,simple_annot_efr (gettype body) (eff_union [v;body]))) (* In tail-position there shouldn't be anything we need to do as the terms after * a-normalisation are pure and don't update local variables. There can't be any variable * assignments in tail-position (because of the effect), there could be pure pattern-match |
