diff options
| author | Alasdair Armstrong | 2017-07-05 16:51:26 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2017-07-05 16:51:26 +0100 |
| commit | 5f46cb01ed07fbca390535440f77f1816cb61684 (patch) | |
| tree | c2874e1b085b4101f7eb295eb37fcd0acaa2a2d5 /src | |
| parent | ddc487757360eb2c0ec1adaf2f333c23df3f250a (diff) | |
| parent | 9cb879efde58abfd5cc4ae8b2d0344902c983cde (diff) | |
Merge remote-tracking branch 'origin/word' into sail_new_tc
Diffstat (limited to 'src')
| -rw-r--r-- | src/Makefile | 2 | ||||
| -rw-r--r-- | src/gen_lib/prompt.lem | 4 | ||||
| -rw-r--r-- | src/gen_lib/sail_values.lem | 241 | ||||
| -rw-r--r-- | src/gen_lib/sail_values_word.lem | 1026 | ||||
| -rw-r--r-- | src/gen_lib/state.lem | 72 | ||||
| -rw-r--r-- | src/monomorphise.ml | 621 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 288 | ||||
| -rw-r--r-- | src/pretty_print_ocaml.ml | 6 | ||||
| -rw-r--r-- | src/rewriter.ml | 570 | ||||
| -rw-r--r-- | src/sail.ml | 14 | ||||
| -rw-r--r-- | src/type_check.ml | 2 | ||||
| -rw-r--r-- | src/type_internal.ml | 41 | ||||
| -rw-r--r-- | src/type_internal.mli | 3 | ||||
| -rw-r--r-- | src/util.ml | 7 | ||||
| -rw-r--r-- | src/util.mli | 3 |
15 files changed, 1611 insertions, 1289 deletions
diff --git a/src/Makefile b/src/Makefile index be1eb9e5..a58646b1 100644 --- a/src/Makefile +++ b/src/Makefile @@ -40,7 +40,7 @@ # SUCH DAMAGE. # ########################################################################## -.PHONY: all sail test clean doc lib power test_power test_idempotence +.PHONY: all sail sail.native sail.byte test clean doc lib power test_power test_idempotence # set to -p on command line to enable gprof profiling OCAML_OPTS?= diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem index 426b0811..70850dc1 100644 --- a/src/gen_lib/prompt.lem +++ b/src/gen_lib/prompt.lem @@ -71,12 +71,12 @@ let read_reg_range reg i j = read_reg_aux (external_reg_slice reg (natFromInteger i,natFromInteger j)) let read_reg_bit reg i = read_reg_aux (external_reg_slice reg (natFromInteger i,natFromInteger i)) >>= fun v -> - return (extract_only_bit v) + return (extract_only_element v) let read_reg_field reg regfield = read_reg_aux (external_reg_field_whole reg regfield) let read_reg_bitfield reg regfield = read_reg_aux (external_reg_field_whole reg regfield) >>= fun v -> - return (extract_only_bit v) + return (extract_only_element v) val write_reg_aux : reg_name -> vector bitU -> M unit let write_reg_aux reg_name v = diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem index 4fded5a1..d7318567 100644 --- a/src/gen_lib/sail_values.lem +++ b/src/gen_lib/sail_values.lem @@ -1,4 +1,7 @@ +(* Version of sail_values.lem that uses Lem's machine words library *) + open import Pervasives_extra +open import Machine_word open import Sail_impl_base @@ -197,58 +200,141 @@ val update_pos : forall 'a. vector 'a -> integer -> 'a -> vector 'a let update_pos v n b = update_aux v n n [b] +(*** Bitvectors *) + +(* element list * start * has increasing direction *) +type bitvector 'a = Bitvector of mword 'a * integer * bool +declare isabelle target_sorts bitvector = `len` + +let showBitvector (Bitvector elems start inc) = + "Bitvector " ^ show elems ^ " " ^ show start ^ " " ^ show inc + +let bvget_dir (Bitvector _ _ ord) = ord +let bvget_start (Bitvector _ s _) = s +let bvget_elems (Bitvector elems _ _) = elems +let bvlength (Bitvector bs _ _) = integerFromNat (word_length bs) + +instance forall 'a. (Show (bitvector 'a)) + let show = showBitvector +end + +let bvec_to_vec (Bitvector bs start is_inc) = + let bits = List.map bool_to_bitU (bitlistFromWord bs) in + Vector bits start is_inc + +let vec_to_bvec (Vector elems start is_inc) = + let word = wordFromBitlist (List.map bitU_to_bool elems) in + Bitvector word start is_inc + +(*** Vector operations *) + +val set_bitvector_start : forall 'a. integer -> bitvector 'a -> bitvector 'a +let set_bitvector_start new_start (Bitvector bs _ is_inc) = + Bitvector bs new_start is_inc + +let reset_bitvector_start v = + set_bitvector_start (if (bvget_dir v) then 0 else (bvlength v - 1)) v + +let set_bitvector_start_to_length v = + set_bitvector_start (bvlength v - 1) v + +let bitvector_concat (Bitvector bs start is_inc) (Bitvector bs' _ _) = + Bitvector (word_concat bs bs') start is_inc + +let inline (^^^) = bitvector_concat + +val bvslice : forall 'a 'b. bitvector 'a -> integer -> integer -> bitvector 'b +let bvslice (Bitvector bs start is_inc) i j = + let iN = natFromInteger i in + let jN = natFromInteger j in + let startN = natFromInteger start in + let (lo,hi) = if is_inc then (iN-startN,jN-startN) else (startN-iN,startN-jN) in + let subvector_bits = word_extract lo hi bs in + Bitvector subvector_bits i is_inc + +(* this is for the vector slicing introduced in vector-concat patterns: i and j +index into the "raw data", the list of bits. Therefore getting the bit list is +easy, but the start index has to be transformed to match the old vector start +and the direction. *) +val bvslice_raw : forall 'a 'b. Size 'b => bitvector 'a -> integer -> integer -> bitvector 'b +let bvslice_raw (Bitvector bs start is_inc) i j = + let iN = natFromInteger i in + let jN = natFromInteger j in + let bits = word_extract iN jN bs in + let len = integerFromNat (word_length bits) in + Bitvector bits (if is_inc then 0 else len - 1) is_inc + +val bvupdate_aux : forall 'a 'b. bitvector 'a -> integer -> integer -> mword 'b -> bitvector 'a +let bvupdate_aux (Bitvector bs start is_inc) i j bs' = + let iN = natFromInteger i in + let jN = natFromInteger j in + let startN = natFromInteger start in + let (lo,hi) = if is_inc then (iN-startN,jN-startN) else (startN-iN,startN-jN) in + let bits = word_update bs lo hi bs' in + Bitvector bits start is_inc + +val bvupdate : forall 'a 'b. bitvector 'a -> integer -> integer -> bitvector 'b -> bitvector 'a +let bvupdate v i j (Bitvector bs' _ _) = + bvupdate_aux v i j bs' + +(* TODO: decide between nat/natural, change either here or in machine_word *) +val getBit' : forall 'a. mword 'a -> nat -> bool +let getBit' w n = getBit w (naturalFromNat n) + +val bvaccess : forall 'a. bitvector 'a -> integer -> bitU +let bvaccess (Bitvector bs start is_inc) n = bool_to_bitU ( + if is_inc then getBit' bs (natFromInteger (n - start)) + else getBit' bs (natFromInteger (start - n))) + +val bvupdate_pos : forall 'a. Size 'a => bitvector 'a -> integer -> bitU -> bitvector 'a +let bvupdate_pos v n b = + bvupdate_aux v n n ((wordFromNatural (if bitU_to_bool b then 1 else 0)) : mword ty1) (*** Bit vector operations *) -let extract_only_bit (Vector elems _ _) = match elems with - | [] -> failwith "extract_single_bit called for empty vector" +let extract_only_element (Vector elems _ _) = match elems with + | [] -> failwith "extract_only_element called for empty vector" | [e] -> e - | _ -> failwith "extract_single_bit called for vector with more bits" + | _ -> failwith "extract_only_element called for vector with more elements" end +val extract_only_bit : bitvector ty1 -> bitU +let extract_only_bit (Bitvector elems _ _) = + (*let l = word_length elems in + if l = 1 then*) + bool_to_bitU (msb elems) + (*else if l = 0 then + failwith "extract_single_bit called for empty vector" + else + failwith "extract_single_bit called for vector with more bits"*) + let pp_bitu_vector (Vector elems start inc) = let elems_pp = List.foldl (fun acc elem -> acc ^ showBitU elem) "" elems in "Vector [" ^ elems_pp ^ "] " ^ show start ^ " " ^ show inc -let most_significant = function - | (Vector (b :: _) _ _) -> b - | _ -> failwith "most_significant applied to empty vector" - end +let most_significant (Bitvector v _ _) = + if word_length v = 0 then + failwith "most_significant applied to empty vector" + else + bool_to_bitU (msb v) let bitwise_not_bitlist = List.map bitwise_not_bit -let bitwise_not (Vector bs start is_inc) = - Vector (bitwise_not_bitlist bs) start is_inc - -let bitwise_binop op (Vector bsl start is_inc, Vector bsr _ _) = - let revbs = foldl (fun acc pair -> bitwise_binop_bit op pair :: acc) [] (zip bsl bsr) in - Vector (reverse revbs) start is_inc - -let bitwise_and = bitwise_binop (&&) -let bitwise_or = bitwise_binop (||) -let bitwise_xor = bitwise_binop xor - -let unsigned (Vector bs _ _) : integer = - let (sum,_) = - List.foldr - (fun b (acc,exp) -> - match b with - | B1 -> (acc + integerPow 2 exp,exp + 1) - | B0 -> (acc, exp + 1) - | BU -> failwith "unsigned: vector has undefined bits" - end) - (0,0) bs in - sum +let bitwise_not (Bitvector bs start is_inc) = + Bitvector (lNot bs) start is_inc + +let bitwise_binop op (Bitvector bsl start is_inc, Bitvector bsr _ _) = + Bitvector (op bsl bsr) start is_inc +let bitwise_and = bitwise_binop lAnd +let bitwise_or = bitwise_binop lOr +let bitwise_xor = bitwise_binop lXor + +let unsigned (Bitvector bs _ _) : integer = unsignedIntegerFromWord bs let unsigned_big = unsigned -let signed v : integer = - match most_significant v with - | B1 -> 0 - (1 + (unsigned (bitwise_not v))) - | B0 -> unsigned v - | BU -> failwith "signed applied to vector with undefined bits" - end +let signed (Bitvector v _ _) : integer = signedIntegerFromWord v let hardware_mod (a: integer) (b:integer) : integer = if a < 0 && b < 0 @@ -319,36 +405,31 @@ end let add_one_bit_ignore_overflow bits = List.reverse (add_one_bit_ignore_overflow_aux (List.reverse bits)) - let to_vec is_inc ((len : integer),(n : integer)) = let start = if is_inc then 0 else len - 1 in - let bits = to_bin (naturalFromInteger (abs n)) in - let len_bits = integerFromNat (List.length bits) in - let longer = len - len_bits in - let bits' = - if longer < 0 then drop (natFromInteger (abs (longer))) bits - else pad_zero bits longer in - if n > (0 : integer) - then Vector bits' start is_inc - else Vector (add_one_bit_ignore_overflow (bitwise_not_bitlist bits')) - start is_inc + let bits = wordFromInteger n in + if integerFromNat (word_length bits) = len then + Bitvector bits start is_inc + else + failwith "Vector length mismatch in to_vec" let to_vec_big = to_vec let to_vec_inc = to_vec true let to_vec_dec = to_vec false +(* TODO: Think about undefined bit(vector)s *) let to_vec_undef is_inc (len : integer) = - Vector (replicate (natFromInteger len) BU) (if is_inc then 0 else len-1) is_inc + Bitvector (failwith "undefined bitvector") (if is_inc then 0 else len-1) is_inc let to_vec_inc_undef = to_vec_undef true let to_vec_dec_undef = to_vec_undef false -let exts (len, vec) = to_vec (get_dir vec) (len,signed vec) -let extz (len, vec) = to_vec (get_dir vec) (len,unsigned vec) +let exts (len, vec) = to_vec (bvget_dir vec) (len,signed vec) +let extz (len, vec) = to_vec (bvget_dir vec) (len,unsigned vec) -let exts_big (len, vec) = to_vec_big (get_dir vec) (len, signed_big vec) -let extz_big (len, vec) = to_vec_big (get_dir vec) (len, unsigned_big vec) +let exts_big (len, vec) = to_vec_big (bvget_dir vec) (len, signed_big vec) +let extz_big (len, vec) = to_vec_big (bvget_dir vec) (len, unsigned_big vec) let add = integerAdd let add_signed = integerAdd @@ -358,10 +439,13 @@ let modulo = hardware_mod let quot = hardware_quot let power = integerPow -let arith_op_vec op sign (size : integer) (Vector _ _ is_inc as l) r = +(* TODO: this, and the definitions that use it, currently require Size for + to_vec, which I'd rather avoid in favour of library versions; the + double-size results for multiplication may be a problem *) +let arith_op_vec op sign (size : integer) (Bitvector _ _ is_inc as l) r = let (l',r') = (to_num sign l, to_num sign r) in let n = op l' r' in - to_vec is_inc (size * (length l),n) + to_vec is_inc (size * (bvlength l),n) (* add_vec @@ -376,8 +460,9 @@ let minus_VVV = arith_op_vec integerMinus false 1 let mult_VVV = arith_op_vec integerMult false 2 let multS_VVV = arith_op_vec integerMult true 2 -let arith_op_vec_range op sign size (Vector _ _ is_inc as l) r = - arith_op_vec op sign size l (to_vec is_inc (length l,r)) +val arith_op_vec_range : forall 'a 'b. Size 'a, Size 'b => (integer -> integer -> integer) -> bool -> integer -> bitvector 'a -> integer -> bitvector 'b +let arith_op_vec_range op sign size (Bitvector _ _ is_inc as l) r = + arith_op_vec op sign size l ((to_vec is_inc (bvlength l,r)) : bitvector 'a) (* add_vec_range * add_vec_range_signed @@ -391,8 +476,9 @@ let minus_VIV = arith_op_vec_range integerMinus false 1 let mult_VIV = arith_op_vec_range integerMult false 2 let multS_VIV = arith_op_vec_range integerMult true 2 -let arith_op_range_vec op sign size l (Vector _ _ is_inc as r) = - arith_op_vec op sign size (to_vec is_inc (length r, l)) r +val arith_op_range_vec : forall 'a 'b. Size 'a, Size 'b => (integer -> integer -> integer) -> bool -> integer -> integer -> bitvector 'a -> bitvector 'b +let arith_op_range_vec op sign size l (Bitvector _ _ is_inc as r) = + arith_op_vec op sign size ((to_vec is_inc (bvlength r, l)) : bitvector 'a) r (* add_range_vec * add_range_vec_signed @@ -438,10 +524,10 @@ let arith_op_vec_vec_range op sign l r = let add_VVI = arith_op_vec_vec_range integerAdd false let addS_VVI = arith_op_vec_vec_range integerAdd true -let arith_op_vec_bit op sign (size : integer) (Vector _ _ is_inc as l)r = +let arith_op_vec_bit op sign (size : integer) (Bitvector _ _ is_inc as l)r = let l' = to_num sign l in let n = op l' (match r with | B1 -> (1 : integer) | _ -> 0 end) in - to_vec is_inc (length l * size,n) + to_vec is_inc (bvlength l * size,n) (* add_vec_bit * add_vec_bit_signed @@ -451,8 +537,10 @@ let add_VBV = arith_op_vec_bit integerAdd false 1 let addS_VBV = arith_op_vec_bit integerAdd true 1 let minus_VBV = arith_op_vec_bit integerMinus true 1 -let rec arith_op_overflow_vec (op : integer -> integer -> integer) sign size (Vector _ _ is_inc as l) r = - let len = length l in +(* TODO: these can't be done directly in Lem because of the one_more size calculation +val arith_op_overflow_vec : forall 'a 'b. Size 'a, Size 'b => (integer -> integer -> integer) -> bool -> integer -> bitvector 'a -> bitvector 'a -> bitvector 'b * bitU * bool +let rec arith_op_overflow_vec op sign size (Bitvector _ _ is_inc as l) r = + let len = bvlength l in let act_size = len * size in let (l_sign,r_sign) = (to_num sign l,to_num sign r) in let (l_unsign,r_unsign) = (to_num false l,to_num false r) in @@ -481,9 +569,11 @@ let minusSO_VVV = arith_op_overflow_vec integerMinus true 1 let multO_VVV = arith_op_overflow_vec integerMult false 2 let multSO_VVV = arith_op_overflow_vec integerMult true 2 +val arith_op_overflow_vec_bit : forall 'a 'b. Size 'a, Size 'b => (integer -> integer -> integer) -> bool -> integer -> + bitvector 'a -> bitU -> bitvector 'b * bitU * bool let rec arith_op_overflow_vec_bit (op : integer -> integer -> integer) sign (size : integer) - (Vector _ _ is_inc as l) r_bit = - let act_size = length l * size in + (Bitvector _ _ is_inc as l) r_bit = + let act_size = bvlength l * size in let l' = to_num sign l in let l_u = to_num false l in let (n,nu,changed) = match r_bit with @@ -509,18 +599,18 @@ let rec arith_op_overflow_vec_bit (op : integer -> integer -> integer) sign (siz let addSO_VBV = arith_op_overflow_vec_bit integerAdd true 1 let minusO_VBV = arith_op_overflow_vec_bit integerMinus false 1 let minusSO_VBV = arith_op_overflow_vec_bit integerMinus true 1 - +*) type shift = LL_shift | RR_shift | LLL_shift -let shift_op_vec op (Vector bs start is_inc,(n : integer)) = +let shift_op_vec op (Bitvector bs start is_inc,(n : integer)) = let n = natFromInteger n in match op with | LL_shift (*"<<"*) -> - Vector (sublist bs (n,List.length bs -1) ++ List.replicate n B0) start is_inc + Bitvector (shiftLeft bs (naturalFromNat n)) start is_inc | RR_shift (*">>"*) -> - Vector (List.replicate n B0 ++ sublist bs (0,n-1)) start is_inc + Bitvector (shiftRight bs (naturalFromNat n)) start is_inc | LLL_shift (*"<<<"*) -> - Vector (sublist bs (n,List.length bs - 1) ++ sublist bs (0,n-1)) start is_inc + Bitvector (rotateLeft (naturalFromNat n) bs) start is_inc end let bitwise_leftshift = shift_op_vec LL_shift (*"<<"*) @@ -531,9 +621,9 @@ let rec arith_op_no0 (op : integer -> integer -> integer) l r = if r = 0 then Nothing else Just (op l r) - -let rec arith_op_vec_no0 (op : integer -> integer -> integer) sign size ((Vector _ start is_inc) as l) r = - let act_size = length l * size in +(* TODO +let rec arith_op_vec_no0 (op : integer -> integer -> integer) sign size ((Bitvector _ start is_inc) as l) r = + let act_size = bvlength l * size in let (l',r') = (to_num sign l,to_num sign r) in let n = arith_op_no0 op l' r' in let (representable,n') = @@ -581,7 +671,7 @@ let arith_op_vec_range_no0 op sign size (Vector _ _ is_inc as l) r = arith_op_vec_no0 op sign size l (to_vec is_inc (length l,r)) let mod_VIV = arith_op_vec_range_no0 hardware_mod false 1 - +*) val repeat : forall 'a. list 'a -> integer -> list 'a let rec repeat xs n = if n = 0 then [] @@ -663,9 +753,9 @@ let make_bitvector_undef length = (* let bitwise_not_range_bit n = bitwise_not (to_vec defaultDir n) *) -let mask (n,Vector bits start dir) = - let current_size = List.length bits in - Vector (drop (current_size - (natFromInteger n)) bits) (if dir then 0 else (n-1)) dir +let mask (n,bv) = + let len = bvlength bv in + bvslice_raw bv (len - n) (len - 1) val byte_chunks : forall 'a. nat -> list 'a -> list (list 'a) @@ -952,4 +1042,3 @@ let diafp_to_dia reginfo = function | DIAFP_concrete v -> DIA_concrete_address (address_of_bitv v) | DIAFP_reg r -> DIA_register (regfp_to_reg reginfo r) end - diff --git a/src/gen_lib/sail_values_word.lem b/src/gen_lib/sail_values_word.lem deleted file mode 100644 index ef7b03b9..00000000 --- a/src/gen_lib/sail_values_word.lem +++ /dev/null @@ -1,1026 +0,0 @@ -(* Version of sail_values.lem that uses Lem's machine words library *) - -open import Pervasives_extra -open import Machine_word -open import Sail_impl_base - - -type ii = integer -type nn = natural - -val pow : integer -> integer -> integer -let pow m n = m ** (natFromInteger n) - -let rec replace bs ((n : integer),b') = match bs with - | [] -> [] - | b :: bs -> - if n = 0 then b' :: bs - else b :: replace bs (n - 1,b') - end - - -(*** Bits *) -type bitU = B0 | B1 | BU - -let showBitU = function - | B0 -> "O" - | B1 -> "I" - | BU -> "U" -end - -instance (Show bitU) - let show = showBitU -end - - -let bitU_to_bool = function - | B0 -> false - | B1 -> true - | BU -> failwith "to_bool applied to BU" - end - -let bit_lifted_of_bitU = function - | B0 -> Bitl_zero - | B1 -> Bitl_one - | BU -> Bitl_undef - end - -let bitU_of_bit = function - | Bitc_zero -> B0 - | Bitc_one -> B1 - end - -let bit_of_bitU = function - | B0 -> Bitc_zero - | B1 -> Bitc_one - | BU -> failwith "bit_of_bitU: BU" - end - -let bitU_of_bit_lifted = function - | Bitl_zero -> B0 - | Bitl_one -> B1 - | Bitl_undef -> BU - | Bitl_unknown -> failwith "bitU_of_bit_lifted Bitl_unknown" - end - -let bitwise_not_bit = function - | B1 -> B0 - | B0 -> B1 - | BU -> BU - end - -let inline (~) = bitwise_not_bit - -val is_one : integer -> bitU -let is_one i = - if i = 1 then B1 else B0 - -let bool_to_bitU b = if b then B1 else B0 - -let bitwise_binop_bit op = function - | (BU,_) -> BU (*Do we want to do this or to respect | of I and & of B0 rules?*) - | (_,BU) -> BU (*Do we want to do this or to respect | of I and & of B0 rules?*) - | (x,y) -> bool_to_bitU (op (bitU_to_bool x) (bitU_to_bool y)) - end - -val bitwise_and_bit : bitU * bitU -> bitU -let bitwise_and_bit = bitwise_binop_bit (&&) - -val bitwise_or_bit : bitU * bitU -> bitU -let bitwise_or_bit = bitwise_binop_bit (||) - -val bitwise_xor_bit : bitU * bitU -> bitU -let bitwise_xor_bit = bitwise_binop_bit xor - -val (&.) : bitU -> bitU -> bitU -let inline (&.) x y = bitwise_and_bit (x,y) - -val (|.) : bitU -> bitU -> bitU -let inline (|.) x y = bitwise_or_bit (x,y) - -val (+.) : bitU -> bitU -> bitU -let inline (+.) x y = bitwise_xor_bit (x,y) - - - -(*** Vectors *) - -(* element list * start * has increasing direction *) -type vector 'a = Vector of list 'a * integer * bool - -let showVector (Vector elems start inc) = - "Vector " ^ show elems ^ " " ^ show start ^ " " ^ show inc - -let get_dir (Vector _ _ ord) = ord -let get_start (Vector _ s _) = s -let get_elems (Vector elems _ _) = elems -let length (Vector bs _ _) = integerFromNat (length bs) - -instance forall 'a. Show 'a => (Show (vector 'a)) - let show = showVector -end - -let dir is_inc = if is_inc then D_increasing else D_decreasing -let bool_of_dir = function - | D_increasing -> true - | D_decreasing -> false - end - -(*** Vector operations *) - -val set_vector_start : forall 'a. integer -> vector 'a -> vector 'a -let set_vector_start new_start (Vector bs _ is_inc) = - Vector bs new_start is_inc - -let reset_vector_start v = - set_vector_start (if (get_dir v) then 0 else (length v - 1)) v - -let set_vector_start_to_length v = - set_vector_start (length v - 1) v - -let vector_concat (Vector bs start is_inc) (Vector bs' _ _) = - Vector (bs ++ bs') start is_inc - -let inline (^^) = vector_concat - -val sublist : forall 'a. list 'a -> (nat * nat) -> list 'a -let sublist xs (i,j) = - let (toJ,_suffix) = List.splitAt (j+1) xs in - let (_prefix,fromItoJ) = List.splitAt i toJ in - fromItoJ - -val update_sublist : forall 'a. list 'a -> (nat * nat) -> list 'a -> list 'a -let update_sublist xs (i,j) xs' = - let (toJ,suffix) = List.splitAt (j+1) xs in - let (prefix,_fromItoJ) = List.splitAt i toJ in - prefix ++ xs' ++ suffix - -val slice : forall 'a. vector 'a -> integer -> integer -> vector 'a -let slice (Vector bs start is_inc) i j = - let iN = natFromInteger i in - let jN = natFromInteger j in - let startN = natFromInteger start in - let subvector_bits = - sublist bs (if is_inc then (iN-startN,jN-startN) else (startN-iN,startN-jN)) in - Vector subvector_bits i is_inc - -(* this is for the vector slicing introduced in vector-concat patterns: i and j -index into the "raw data", the list of bits. Therefore getting the bit list is -easy, but the start index has to be transformed to match the old vector start -and the direction. *) -val slice_raw : forall 'a. vector 'a -> integer -> integer -> vector 'a -let slice_raw (Vector bs start is_inc) i j = - let iN = natFromInteger i in - let jN = natFromInteger j in - let bits = sublist bs (iN,jN) in - let len = integerFromNat (List.length bits) in - Vector bits (if is_inc then 0 else len - 1) is_inc - - -val update_aux : forall 'a. vector 'a -> integer -> integer -> list 'a -> vector 'a -let update_aux (Vector bs start is_inc) i j bs' = - let iN = natFromInteger i in - let jN = natFromInteger j in - let startN = natFromInteger start in - let bits = - (update_sublist bs) - (if is_inc then (iN-startN,jN-startN) else (startN-iN,startN-jN)) bs' in - Vector bits start is_inc - -val update : forall 'a. vector 'a -> integer -> integer -> vector 'a -> vector 'a -let update v i j (Vector bs' _ _) = - update_aux v i j bs' - -val access : forall 'a. vector 'a -> integer -> 'a -let access (Vector bs start is_inc) n = - if is_inc then List_extra.nth bs (natFromInteger (n - start)) - else List_extra.nth bs (natFromInteger (start - n)) - -val update_pos : forall 'a. vector 'a -> integer -> 'a -> vector 'a -let update_pos v n b = - update_aux v n n [b] - -(*** Bitvectors *) - -(* element list * start * has increasing direction *) -type bitvector 'a = Bitvector of mword 'a * integer * bool - -let showBitvector (Bitvector elems start inc) = - "Bitvector " ^ show elems ^ " " ^ show start ^ " " ^ show inc - -let bvget_dir (Bitvector _ _ ord) = ord -let bvget_start (Bitvector _ s _) = s -let bvget_elems (Bitvector elems _ _) = elems -let bvlength (Bitvector bs _ _) = integerFromNat (word_length bs) - -instance forall 'a. Show 'a => (Show (bitvector 'a)) - let show = showBitvector -end - -(*** Vector operations *) - -val set_bitvector_start : forall 'a. integer -> bitvector 'a -> bitvector 'a -let set_bitvector_start new_start (Bitvector bs _ is_inc) = - Bitvector bs new_start is_inc - -let reset_bitvector_start v = - set_bitvector_start (if (bvget_dir v) then 0 else (bvlength v - 1)) v - -let set_bitvector_start_to_length v = - set_bitvector_start (bvlength v - 1) v - -let bitvector_concat (Bitvector bs start is_inc) (Bitvector bs' _ _) = - Bitvector (word_concat bs bs') start is_inc - -let inline (^^^) = bitvector_concat - -val bvslice : forall 'a 'b. bitvector 'a -> integer -> integer -> bitvector 'b -let bvslice (Bitvector bs start is_inc) i j = - let iN = natFromInteger i in - let jN = natFromInteger j in - let startN = natFromInteger start in - let (lo,hi) = if is_inc then (iN-startN,jN-startN) else (startN-iN,startN-jN) in - let subvector_bits = word_extract lo hi bs in - Bitvector subvector_bits i is_inc - -(* this is for the vector slicing introduced in vector-concat patterns: i and j -index into the "raw data", the list of bits. Therefore getting the bit list is -easy, but the start index has to be transformed to match the old vector start -and the direction. *) -val bvslice_raw : forall 'a 'b. Size 'b => bitvector 'a -> integer -> integer -> bitvector 'b -let bvslice_raw (Bitvector bs start is_inc) i j = - let iN = natFromInteger i in - let jN = natFromInteger j in - let bits = word_extract iN jN bs in - let len = integerFromNat (word_length bits) in - Bitvector bits (if is_inc then 0 else len - 1) is_inc - -val bvupdate_aux : forall 'a 'b. bitvector 'a -> integer -> integer -> mword 'b -> bitvector 'a -let bvupdate_aux (Bitvector bs start is_inc) i j bs' = - let iN = natFromInteger i in - let jN = natFromInteger j in - let startN = natFromInteger start in - let (lo,hi) = if is_inc then (iN-startN,jN-startN) else (startN-iN,startN-jN) in - let bits = word_update bs lo hi bs' in - Bitvector bits start is_inc - -val bvupdate : forall 'a. bitvector 'a -> integer -> integer -> bitvector 'a -> bitvector 'a -let bvupdate v i j (Bitvector bs' _ _) = - bvupdate_aux v i j bs' - -(* TODO: decide between nat/natural, change either here or in machine_word *) -val getBit' : forall 'a. mword 'a -> nat -> bool -let getBit' w n = getBit w (naturalFromNat n) - -val bvaccess : forall 'a. bitvector 'a -> integer -> bool -let bvaccess (Bitvector bs start is_inc) n = - if is_inc then getBit' bs (natFromInteger (n - start)) - else getBit' bs (natFromInteger (start - n)) - -val bvupdate_pos : forall 'a. Size 'a => bitvector 'a -> integer -> bool -> bitvector 'a -let bvupdate_pos v n b = - bvupdate_aux v n n (wordFromNatural (if b then 1 else 0)) - -(*** Bit vector operations *) - -let extract_only_bit (Bitvector elems _ _) = - let l = word_length elems in - if l = 1 then - msb elems - else if l = 0 then - failwith "extract_single_bit called for empty vector" - else - failwith "extract_single_bit called for vector with more bits" - -let pp_bitu_vector (Vector elems start inc) = - let elems_pp = List.foldl (fun acc elem -> acc ^ showBitU elem) "" elems in - "Vector [" ^ elems_pp ^ "] " ^ show start ^ " " ^ show inc - - -let most_significant (Bitvector v _ _) = - if word_length v = 0 then - failwith "most_significant applied to empty vector" - else - msb v - -let bitwise_not_bitlist = List.map bitwise_not_bit - -let bitwise_not (Bitvector bs start is_inc) = - Bitvector (lNot bs) start is_inc - -let bitwise_binop op (Bitvector bsl start is_inc, Bitvector bsr _ _) = - Bitvector (op bsl bsr) start is_inc - -let bitwise_and = bitwise_binop lAnd -let bitwise_or = bitwise_binop lOr -let bitwise_xor = bitwise_binop lXor - -let unsigned (Bitvector bs _ _) : integer = unsignedIntegerFromWord bs -let unsigned_big = unsigned - -let signed (Bitvector v _ _) : integer = signedIntegerFromWord v - -let hardware_mod (a: integer) (b:integer) : integer = - if a < 0 && b < 0 - then (abs a) mod (abs b) - else if (a < 0 && b >= 0) - then (a mod b) - b - else a mod b - -let hardware_quot (a:integer) (b:integer) : integer = - if a < 0 && b < 0 - then (abs a) / (abs b) - else if (a < 0 && b > 0) - then (a/b) + 1 - else a/b - -let quot_signed = hardware_quot - - -let signed_big = signed - -let to_num sign = if sign then signed else unsigned - -let max_64u = (integerPow 2 64) - 1 -let max_64 = (integerPow 2 63) - 1 -let min_64 = 0 - (integerPow 2 63) -let max_32u = (4294967295 : integer) -let max_32 = (2147483647 : integer) -let min_32 = (0 - 2147483648 : integer) -let max_8 = (127 : integer) -let min_8 = (0 - 128 : integer) -let max_5 = (31 : integer) -let min_5 = (0 - 32 : integer) - -let get_max_representable_in sign (n : integer) : integer = - if (n = 64) then match sign with | true -> max_64 | false -> max_64u end - else if (n=32) then match sign with | true -> max_32 | false -> max_32u end - else if (n=8) then max_8 - else if (n=5) then max_5 - else match sign with | true -> integerPow 2 ((natFromInteger n) -1) - | false -> integerPow 2 (natFromInteger n) - end - -let get_min_representable_in _ (n : integer) : integer = - if n = 64 then min_64 - else if n = 32 then min_32 - else if n = 8 then min_8 - else if n = 5 then min_5 - else 0 - (integerPow 2 (natFromInteger n)) - -val to_bin_aux : natural -> list bitU -let rec to_bin_aux x = - if x = 0 then [] - else (if x mod 2 = 1 then B1 else B0) :: to_bin_aux (x / 2) -let to_bin n = List.reverse (to_bin_aux n) - -val pad_zero : list bitU -> integer -> list bitU -let rec pad_zero bits n = - if n = 0 then bits else pad_zero (B0 :: bits) (n -1) - - -let rec add_one_bit_ignore_overflow_aux bits = match bits with - | [] -> [] - | B0 :: bits -> B1 :: bits - | B1 :: bits -> B0 :: add_one_bit_ignore_overflow_aux bits - | BU :: _ -> failwith "add_one_bit_ignore_overflow: undefined bit" -end - -let add_one_bit_ignore_overflow bits = - List.reverse (add_one_bit_ignore_overflow_aux (List.reverse bits)) - -let to_vec is_inc ((len : integer),(n : integer)) = - let start = if is_inc then 0 else len - 1 in - let bits = wordFromInteger n in - if integerFromNat (word_length bits) = len then - Bitvector bits start is_inc - else - failwith "Vector length mismatch in to_vec" - -let to_vec_big = to_vec - -let to_vec_inc = to_vec true -let to_vec_dec = to_vec false -(* TODO?? -let to_vec_undef is_inc (len : integer) = - Vector (replicate (natFromInteger len) BU) (if is_inc then 0 else len-1) is_inc - -let to_vec_inc_undef = to_vec_undef true -let to_vec_dec_undef = to_vec_undef false -*) -let exts (len, vec) = to_vec (bvget_dir vec) (len,signed vec) -let extz (len, vec) = to_vec (bvget_dir vec) (len,unsigned vec) - -let exts_big (len, vec) = to_vec_big (bvget_dir vec) (len, signed_big vec) -let extz_big (len, vec) = to_vec_big (bvget_dir vec) (len, unsigned_big vec) - -let add = integerAdd -let add_signed = integerAdd -let minus = integerMinus -let multiply = integerMult -let modulo = hardware_mod -let quot = hardware_quot -let power = integerPow - -(* TODO: this, and the definitions that use it, currently requires Size for - to_vec, which I'd rather avoid *) -let arith_op_vec op sign (size : integer) (Bitvector _ _ is_inc as l) r = - let (l',r') = (to_num sign l, to_num sign r) in - let n = op l' r' in - to_vec is_inc (size * (bvlength l),n) - - -(* add_vec - * add_vec_signed - * minus_vec - * multiply_vec - * multiply_vec_signed - *) -let add_VVV = arith_op_vec integerAdd false 1 -let addS_VVV = arith_op_vec integerAdd true 1 -let minus_VVV = arith_op_vec integerMinus false 1 -let mult_VVV = arith_op_vec integerMult false 2 -let multS_VVV = arith_op_vec integerMult true 2 - -val arith_op_vec_range : forall 'a. Size 'a => (integer -> integer -> integer) -> bool -> integer -> bitvector 'a -> integer -> bitvector 'a -let arith_op_vec_range op sign size (Bitvector _ _ is_inc as l) r = - arith_op_vec op sign size l (to_vec is_inc (bvlength l,r)) - -(* add_vec_range - * add_vec_range_signed - * minus_vec_range - * mult_vec_range - * mult_vec_range_signed - *) -let add_VIV = arith_op_vec_range integerAdd false 1 -let addS_VIV = arith_op_vec_range integerAdd true 1 -let minus_VIV = arith_op_vec_range integerMinus false 1 -let mult_VIV = arith_op_vec_range integerMult false 2 -let multS_VIV = arith_op_vec_range integerMult true 2 - -val arith_op_range_vec : forall 'a. Size 'a => (integer -> integer -> integer) -> bool -> integer -> integer -> bitvector 'a -> bitvector 'a -let arith_op_range_vec op sign size l (Bitvector _ _ is_inc as r) = - arith_op_vec op sign size (to_vec is_inc (bvlength r, l)) r - -(* add_range_vec - * add_range_vec_signed - * minus_range_vec - * mult_range_vec - * mult_range_vec_signed - *) -let add_IVV = arith_op_range_vec integerAdd false 1 -let addS_IVV = arith_op_range_vec integerAdd true 1 -let minus_IVV = arith_op_range_vec integerMinus false 1 -let mult_IVV = arith_op_range_vec integerMult false 2 -let multS_IVV = arith_op_range_vec integerMult true 2 - -let arith_op_range_vec_range op sign l r = op l (to_num sign r) - -(* add_range_vec_range - * add_range_vec_range_signed - * minus_range_vec_range - *) -let add_IVI = arith_op_range_vec_range integerAdd false -let addS_IVI = arith_op_range_vec_range integerAdd true -let minus_IVI = arith_op_range_vec_range integerMinus false - -let arith_op_vec_range_range op sign l r = op (to_num sign l) r - -(* add_vec_range_range - * add_vec_range_range_signed - * minus_vec_range_range - *) -let add_VII = arith_op_vec_range_range integerAdd false -let addS_VII = arith_op_vec_range_range integerAdd true -let minus_VII = arith_op_vec_range_range integerMinus false - - - -let arith_op_vec_vec_range op sign l r = - let (l',r') = (to_num sign l,to_num sign r) in - op l' r' - -(* add_vec_vec_range - * add_vec_vec_range_signed - *) -let add_VVI = arith_op_vec_vec_range integerAdd false -let addS_VVI = arith_op_vec_vec_range integerAdd true - -let arith_op_vec_bit op sign (size : integer) (Bitvector _ _ is_inc as l)r = - let l' = to_num sign l in - let n = op l' (match r with | B1 -> (1 : integer) | _ -> 0 end) in - to_vec is_inc (bvlength l * size,n) - -(* add_vec_bit - * add_vec_bit_signed - * minus_vec_bit_signed - *) -let add_VBV = arith_op_vec_bit integerAdd false 1 -let addS_VBV = arith_op_vec_bit integerAdd true 1 -let minus_VBV = arith_op_vec_bit integerMinus true 1 - -val arith_op_overflow_vec : forall 'a. Size 'a => (integer -> integer -> integer) -> bool -> integer -> bitvector 'a -> bitvector 'a -> bitvector 'a * bitU * bool -let rec arith_op_overflow_vec op sign size (Bitvector _ _ is_inc as l) r = - let len = bvlength l in - let act_size = len * size in - let (l_sign,r_sign) = (to_num sign l,to_num sign r) in - let (l_unsign,r_unsign) = (to_num false l,to_num false r) in - let n = op l_sign r_sign in - let n_unsign = op l_unsign r_unsign in - let correct_size_num = to_vec is_inc (act_size,n) in - let one_more_size_u = to_vec is_inc (act_size + 1,n_unsign) in - let overflow = - if n <= get_max_representable_in sign len && - n >= get_min_representable_in sign len - then B0 else B1 in - let c_out = most_significant one_more_size_u in - (correct_size_num,overflow,c_out) - -(* add_overflow_vec - * add_overflow_vec_signed - * minus_overflow_vec - * minus_overflow_vec_signed - * mult_overflow_vec - * mult_overflow_vec_signed - *) -let addO_VVV = arith_op_overflow_vec integerAdd false 1 -let addSO_VVV = arith_op_overflow_vec integerAdd true 1 -let minusO_VVV = arith_op_overflow_vec integerMinus false 1 -let minusSO_VVV = arith_op_overflow_vec integerMinus true 1 -let multO_VVV = arith_op_overflow_vec integerMult false 2 -let multSO_VVV = arith_op_overflow_vec integerMult true 2 - -val arith_op_overflow_vec_bit : forall 'a. Size 'a => (integer -> integer -> integer) -> bool -> integer -> - bitvector 'a -> bitU -> bitvector 'a * bitU * bool -let rec arith_op_overflow_vec_bit (op : integer -> integer -> integer) sign (size : integer) - (Bitvector _ _ is_inc as l) r_bit = - let act_size = bvlength l * size in - let l' = to_num sign l in - let l_u = to_num false l in - let (n,nu,changed) = match r_bit with - | B1 -> (op l' 1, op l_u 1, true) - | B0 -> (l',l_u,false) - | BU -> failwith "arith_op_overflow_vec_bit applied to undefined bit" - end in -(* | _ -> assert false *) - let correct_size_num = to_vec is_inc (act_size,n) in - let one_larger = to_vec is_inc (act_size + 1,nu) in - let overflow = - if changed - then - if n <= get_max_representable_in sign act_size && n >= get_min_representable_in sign act_size - then B0 else B1 - else B0 in - (correct_size_num,overflow,most_significant one_larger) - -(* add_overflow_vec_bit_signed - * minus_overflow_vec_bit - * minus_overflow_vec_bit_signed - *) -let addSO_VBV = arith_op_overflow_vec_bit integerAdd true 1 -let minusO_VBV = arith_op_overflow_vec_bit integerMinus false 1 -let minusSO_VBV = arith_op_overflow_vec_bit integerMinus true 1 - -type shift = LL_shift | RR_shift | LLL_shift - -let shift_op_vec op (Bitvector bs start is_inc,(n : integer)) = - let n = natFromInteger n in - match op with - | LL_shift (*"<<"*) -> - Bitvector (shiftLeft bs (naturalFromNat n)) start is_inc - | RR_shift (*">>"*) -> - Bitvector (shiftRight bs (naturalFromNat n)) start is_inc - | LLL_shift (*"<<<"*) -> - Bitvector (rotateLeft (naturalFromNat n) bs) start is_inc - end - -let bitwise_leftshift = shift_op_vec LL_shift (*"<<"*) -let bitwise_rightshift = shift_op_vec RR_shift (*">>"*) -let bitwise_rotate = shift_op_vec LLL_shift (*"<<<"*) - -let rec arith_op_no0 (op : integer -> integer -> integer) l r = - if r = 0 - then Nothing - else Just (op l r) -(* TODO -let rec arith_op_vec_no0 (op : integer -> integer -> integer) sign size ((Bitvector _ start is_inc) as l) r = - let act_size = bvlength l * size in - let (l',r') = (to_num sign l,to_num sign r) in - let n = arith_op_no0 op l' r' in - let (representable,n') = - match n with - | Just n' -> - (n' <= get_max_representable_in sign act_size && - n' >= get_min_representable_in sign act_size, n') - | _ -> (false,0) - end in - if representable - then to_vec is_inc (act_size,n') - else Vector (List.replicate (natFromInteger act_size) BU) start is_inc - -let mod_VVV = arith_op_vec_no0 hardware_mod false 1 -let quot_VVV = arith_op_vec_no0 hardware_quot false 1 -let quotS_VVV = arith_op_vec_no0 hardware_quot true 1 - -let arith_op_overflow_no0_vec op sign size ((Vector _ start is_inc) as l) r = - let rep_size = length r * size in - let act_size = length l * size in - let (l',r') = (to_num sign l,to_num sign r) in - let (l_u,r_u) = (to_num false l,to_num false r) in - let n = arith_op_no0 op l' r' in - let n_u = arith_op_no0 op l_u r_u in - let (representable,n',n_u') = - match (n, n_u) with - | (Just n',Just n_u') -> - ((n' <= get_max_representable_in sign rep_size && - n' >= (get_min_representable_in sign rep_size)), n', n_u') - | _ -> (true,0,0) - end in - let (correct_size_num,one_more) = - if representable then - (to_vec is_inc (act_size,n'),to_vec is_inc (act_size + 1,n_u')) - else - (Vector (List.replicate (natFromInteger act_size) BU) start is_inc, - Vector (List.replicate (natFromInteger (act_size + 1)) BU) start is_inc) in - let overflow = if representable then B0 else B1 in - (correct_size_num,overflow,most_significant one_more) - -let quotO_VVV = arith_op_overflow_no0_vec hardware_quot false 1 -let quotSO_VVV = arith_op_overflow_no0_vec hardware_quot true 1 - -let arith_op_vec_range_no0 op sign size (Vector _ _ is_inc as l) r = - arith_op_vec_no0 op sign size l (to_vec is_inc (length l,r)) - -let mod_VIV = arith_op_vec_range_no0 hardware_mod false 1 -*) -val repeat : forall 'a. list 'a -> integer -> list 'a -let rec repeat xs n = - if n = 0 then [] - else xs ++ repeat xs (n-1) - -(* -let duplicate bit length = - Vector (repeat [bit] length) (if dir then 0 else length - 1) dir - *) - -let compare_op op (l,r) = bool_to_bitU (op l r) - -let lt = compare_op (<) -let gt = compare_op (>) -let lteq = compare_op (<=) -let gteq = compare_op (>=) - - -let compare_op_vec op sign (l,r) = - let (l',r') = (to_num sign l, to_num sign r) in - compare_op op (l',r') - -let lt_vec = compare_op_vec (<) true -let gt_vec = compare_op_vec (>) true -let lteq_vec = compare_op_vec (<=) true -let gteq_vec = compare_op_vec (>=) true - -let lt_vec_signed = compare_op_vec (<) true -let gt_vec_signed = compare_op_vec (>) true -let lteq_vec_signed = compare_op_vec (<=) true -let gteq_vec_signed = compare_op_vec (>=) true -let lt_vec_unsigned = compare_op_vec (<) false -let gt_vec_unsigned = compare_op_vec (>) false -let lteq_vec_unsigned = compare_op_vec (<=) false -let gteq_vec_unsigned = compare_op_vec (>=) false - -let compare_op_vec_range op sign (l,r) = - compare_op op ((to_num sign l),r) - -let lt_vec_range = compare_op_vec_range (<) true -let gt_vec_range = compare_op_vec_range (>) true -let lteq_vec_range = compare_op_vec_range (<=) true -let gteq_vec_range = compare_op_vec_range (>=) true - -let compare_op_range_vec op sign (l,r) = - compare_op op (l, (to_num sign r)) - -let lt_range_vec = compare_op_range_vec (<) true -let gt_range_vec = compare_op_range_vec (>) true -let lteq_range_vec = compare_op_range_vec (<=) true -let gteq_range_vec = compare_op_range_vec (>=) true - -let eq (l,r) = bool_to_bitU (l = r) -let eq_range (l,r) = bool_to_bitU (l = r) -let eq_vec (l,r) = bool_to_bitU (l = r) -let eq_bit (l,r) = bool_to_bitU (l = r) -let eq_vec_range (l,r) = eq (to_num false l,r) -let eq_range_vec (l,r) = eq (l, to_num false r) -let eq_vec_vec (l,r) = eq (to_num true l, to_num true r) - -let neq (l,r) = bitwise_not_bit (eq (l,r)) -let neq_bit (l,r) = bitwise_not_bit (eq_bit (l,r)) -let neq_range (l,r) = bitwise_not_bit (eq_range (l,r)) -let neq_vec (l,r) = bitwise_not_bit (eq_vec_vec (l,r)) -let neq_vec_range (l,r) = bitwise_not_bit (eq_vec_range (l,r)) -let neq_range_vec (l,r) = bitwise_not_bit (eq_range_vec (l,r)) - - -val make_indexed_vector : forall 'a. list (integer * 'a) -> 'a -> integer -> integer -> bool -> vector 'a -let make_indexed_vector entries default start length dir = - let length = natFromInteger length in - Vector (List.foldl replace (replicate length default) entries) start dir - -(* -val make_bit_vector_undef : integer -> vector bitU -let make_bitvector_undef length = - Vector (replicate (natFromInteger length) BU) 0 true - *) - -(* let bitwise_not_range_bit n = bitwise_not (to_vec defaultDir n) *) - -let mask (n,Vector bits start dir) = - let current_size = List.length bits in - Vector (drop (current_size - (natFromInteger n)) bits) (if dir then 0 else (n-1)) dir - - -val byte_chunks : forall 'a. nat -> list 'a -> list (list 'a) -let rec byte_chunks n list = match (n,list) with - | (0,_) -> [] - | (n+1, a::b::c::d::e::f::g::h::rest) -> [a;b;c;d;e;f;g;h] :: byte_chunks n rest - | _ -> failwith "byte_chunks not given enough bits" -end - -val bitv_of_byte_lifteds : bool -> list Sail_impl_base.byte_lifted -> vector bitU -let bitv_of_byte_lifteds dir v = - let bits = foldl (fun x (Byte_lifted y) -> x ++ (List.map bitU_of_bit_lifted y)) [] v in - let len = integerFromNat (List.length bits) in - Vector bits (if dir then 0 else len - 1) dir - -val bitv_of_bytes : bool -> list Sail_impl_base.byte -> vector bitU -let bitv_of_bytes dir v = - let bits = foldl (fun x (Byte y) -> x ++ (List.map bitU_of_bit y)) [] v in - let len = integerFromNat (List.length bits) in - Vector bits (if dir then 0 else len - 1) dir - - -val byte_lifteds_of_bitv : vector bitU -> list byte_lifted -let byte_lifteds_of_bitv (Vector bits length is_inc) = - let bits = List.map bit_lifted_of_bitU bits in - byte_lifteds_of_bit_lifteds bits - -val bytes_of_bitv : vector bitU -> list byte -let bytes_of_bitv (Vector bits length is_inc) = - let bits = List.map bit_of_bitU bits in - bytes_of_bits bits - -val bit_lifteds_of_bitUs : list bitU -> list bit_lifted -let bit_lifteds_of_bitUs bits = List.map bit_lifted_of_bitU bits - -val bit_lifteds_of_bitv : vector bitU -> list bit_lifted -let bit_lifteds_of_bitv v = bit_lifteds_of_bitUs (get_elems v) - - -val address_lifted_of_bitv : vector bitU -> address_lifted -let address_lifted_of_bitv v = - let byte_lifteds = byte_lifteds_of_bitv v in - let maybe_address_integer = - match (maybe_all (List.map byte_of_byte_lifted byte_lifteds)) with - | Just bs -> Just (integer_of_byte_list bs) - | _ -> Nothing - end in - Address_lifted byte_lifteds maybe_address_integer - -val address_of_bitv : vector bitU -> address -let address_of_bitv v = - let bytes = bytes_of_bitv v in - address_of_byte_list bytes - - - -(*** Registers *) - -type register_field = string -type register_field_index = string * (integer * integer) (* name, start and end *) - -type register = - | Register of string * (* name *) - integer * (* length *) - integer * (* start index *) - bool * (* is increasing *) - list register_field_index - | UndefinedRegister of integer (* length *) - | RegisterPair of register * register - -let name_of_reg = function - | Register name _ _ _ _ -> name - | UndefinedRegister _ -> failwith "name_of_reg UndefinedRegister" - | RegisterPair _ _ -> failwith "name_of_reg RegisterPair" -end - -let size_of_reg = function - | Register _ size _ _ _ -> size - | UndefinedRegister size -> size - | RegisterPair _ _ -> failwith "size_of_reg RegisterPair" -end - -let start_of_reg = function - | Register _ _ start _ _ -> start - | UndefinedRegister _ -> failwith "start_of_reg UndefinedRegister" - | RegisterPair _ _ -> failwith "start_of_reg RegisterPair" -end - -let is_inc_of_reg = function - | Register _ _ _ is_inc _ -> is_inc - | UndefinedRegister _ -> failwith "is_inc_of_reg UndefinedRegister" - | RegisterPair _ _ -> failwith "in_inc_of_reg RegisterPair" -end - -let dir_of_reg = function - | Register _ _ _ is_inc _ -> dir is_inc - | UndefinedRegister _ -> failwith "dir_of_reg UndefinedRegister" - | RegisterPair _ _ -> failwith "dir_of_reg RegisterPair" -end - -let size_of_reg_nat reg = natFromInteger (size_of_reg reg) -let start_of_reg_nat reg = natFromInteger (start_of_reg reg) - -val register_field_indices_aux : register -> register_field -> maybe (integer * integer) -let rec register_field_indices_aux register rfield = - match register with - | Register _ _ _ _ rfields -> List.lookup rfield rfields - | RegisterPair r1 r2 -> - let m_indices = register_field_indices_aux r1 rfield in - if isJust m_indices then m_indices else register_field_indices_aux r2 rfield - | UndefinedRegister _ -> Nothing - end - -val register_field_indices : register -> register_field -> integer * integer -let register_field_indices register rfield = - match register_field_indices_aux register rfield with - | Just indices -> indices - | Nothing -> failwith "Invalid register/register-field combination" - end - -let register_field_indices_nat reg regfield= - let (i,j) = register_field_indices reg regfield in - (natFromInteger i,natFromInteger j) - -let rec external_reg_value reg_name v = - let (internal_start, external_start, direction) = - match reg_name with - | Reg _ start size dir -> - (start, (if dir = D_increasing then start else (start - (size +1))), dir) - | Reg_slice _ reg_start dir (slice_start, slice_end) -> - ((if dir = D_increasing then slice_start else (reg_start - slice_start)), - slice_start, dir) - | Reg_field _ reg_start dir _ (slice_start, slice_end) -> - ((if dir = D_increasing then slice_start else (reg_start - slice_start)), - slice_start, dir) - | Reg_f_slice _ reg_start dir _ _ (slice_start, slice_end) -> - ((if dir = D_increasing then slice_start else (reg_start - slice_start)), - slice_start, dir) - end in - let bits = bit_lifteds_of_bitv v in - <| rv_bits = bits; - rv_dir = direction; - rv_start = external_start; - rv_start_internal = internal_start |> - -val internal_reg_value : register_value -> vector bitU -let internal_reg_value v = - Vector (List.map bitU_of_bit_lifted v.rv_bits) - (integerFromNat v.rv_start_internal) - (v.rv_dir = D_increasing) - - -let external_slice (d:direction) (start:nat) ((i,j):(nat*nat)) = - match d with - (*This is the case the thread/concurrecny model expects, so no change needed*) - | D_increasing -> (i,j) - | D_decreasing -> let slice_i = start - i in - let slice_j = (i - j) + slice_i in - (slice_i,slice_j) - end - -let external_reg_whole reg = - Reg (name_of_reg reg) (start_of_reg_nat reg) (size_of_reg_nat reg) (dir_of_reg reg) - -let external_reg_slice reg (i,j) = - let start = start_of_reg_nat reg in - let dir = dir_of_reg reg in - Reg_slice (name_of_reg reg) start dir (external_slice dir start (i,j)) - -let external_reg_field_whole reg rfield = - let (m,n) = register_field_indices_nat reg rfield in - let start = start_of_reg_nat reg in - let dir = dir_of_reg reg in - Reg_field (name_of_reg reg) start dir rfield (external_slice dir start (m,n)) - -let external_reg_field_slice reg rfield (i,j) = - let (m,n) = register_field_indices_nat reg rfield in - let start = start_of_reg_nat reg in - let dir = dir_of_reg reg in - Reg_f_slice (name_of_reg reg) start dir rfield - (external_slice dir start (m,n)) - (external_slice dir start (i,j)) - -let external_mem_value v = - byte_lifteds_of_bitv v $> List.reverse - -let internal_mem_value direction bytes = - List.reverse bytes $> bitv_of_byte_lifteds direction - - - - - -val foreach_inc : forall 'vars. (integer * integer * integer) -> 'vars -> - (integer -> 'vars -> 'vars) -> 'vars -let rec foreach_inc (i,stop,by) vars body = - if i <= stop - then let vars = body i vars in - foreach_inc (i + by,stop,by) vars body - else vars - -val foreach_dec : forall 'vars. (integer * integer * integer) -> 'vars -> - (integer -> 'vars -> 'vars) -> 'vars -let rec foreach_dec (i,stop,by) vars body = - if i >= stop - then let vars = body i vars in - foreach_dec (i - by,stop,by) vars body - else vars - -let assert' b msg_opt = - let msg = match msg_opt with - | Just msg -> msg - | Nothing -> "unspecified error" - end in - if bitU_to_bool b then () else failwith msg - -(* convert numbers unsafely to naturals *) - -class (ToNatural 'a) val toNatural : 'a -> natural end -(* eta-expanded for Isabelle output, otherwise it breaks *) -instance (ToNatural integer) let toNatural = (fun n -> naturalFromInteger n) end -instance (ToNatural int) let toNatural = (fun n -> naturalFromInt n) end -instance (ToNatural nat) let toNatural = (fun n -> naturalFromNat n) end -instance (ToNatural natural) let toNatural = (fun n -> n) end - -let toNaturalFiveTup (n1,n2,n3,n4,n5) = - (toNatural n1, - toNatural n2, - toNatural n3, - toNatural n4, - toNatural n5) - - -type regfp = - | RFull of (string) - | RSlice of (string * integer * integer) - | RSliceBit of (string * integer) - | RField of (string * string) - -type niafp = - | NIAFP_successor - | NIAFP_concrete_address of vector bitU - | NIAFP_LR - | NIAFP_CTR - | NIAFP_register of regfp - -(* only for MIPS *) -type diafp = - | DIAFP_none - | DIAFP_concrete of vector bitU - | DIAFP_reg of regfp - -let regfp_to_reg (reg_info : string -> maybe string -> (nat * nat * direction * (nat * nat))) = function - | RFull name -> - let (start,length,direction,_) = reg_info name Nothing in - Reg name start length direction - | RSlice (name,i,j) -> - let i = natFromInteger i in - let j = natFromInteger j in - let (start,length,direction,_) = reg_info name Nothing in - let slice = external_slice direction start (i,j) in - Reg_slice name start direction slice - | RSliceBit (name,i) -> - let i = natFromInteger i in - let (start,length,direction,_) = reg_info name Nothing in - let slice = external_slice direction start (i,i) in - Reg_slice name start direction slice - | RField (name,field_name) -> - let (start,length,direction,span) = reg_info name (Just field_name) in - let slice = external_slice direction start span in - Reg_field name start direction field_name slice -end - -let niafp_to_nia reginfo = function - | NIAFP_successor -> NIA_successor - | NIAFP_concrete_address v -> NIA_concrete_address (address_of_bitv v) - | NIAFP_LR -> NIA_LR - | NIAFP_CTR -> NIA_CTR - | NIAFP_register r -> NIA_register (regfp_to_reg reginfo r) -end - -let diafp_to_dia reginfo = function - | DIAFP_none -> DIA_none - | DIAFP_concrete v -> DIA_concrete_address (address_of_bitv v) - | DIAFP_reg r -> DIA_register (regfp_to_reg reginfo r) -end - diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem index 430ee562..709052fe 100644 --- a/src/gen_lib/state.lem +++ b/src/gen_lib/state.lem @@ -47,12 +47,12 @@ let set_reg state reg bitv = <| state with regstate = Map.insert reg bitv state.regstate |> -val read_mem : bool -> read_kind -> vector bitU -> integer -> M (vector bitU) +val read_mem : forall 'a 'b. Size 'b => bool -> read_kind -> bitvector 'a -> integer -> M (bitvector 'b) let read_mem dir read_kind addr sz state = - let addr = integer_of_address (address_of_bitv addr) in + let addr = unsigned addr in let addrs = range addr (addr+sz-1) in let memory_value = List.map (fun addr -> Map_extra.find addr state.memstate) addrs in - let value = Sail_values.internal_mem_value dir memory_value in + let value = vec_to_bvec (Sail_values.internal_mem_value dir memory_value) in let is_exclusive = match read_kind with | Sail_impl_base.Read_plain -> false | Sail_impl_base.Read_reserve -> true @@ -69,9 +69,9 @@ let read_mem dir read_kind addr sz state = (* caps are aligned at 32 bytes *) let cap_alignment = (32 : integer) -val read_tag : bool -> read_kind -> vector bitU -> M bitU +val read_tag : forall 'a. bool -> read_kind -> bitvector 'a -> M bitU let read_tag dir read_kind addr state = - let addr = (integer_of_address (address_of_bitv addr)) / cap_alignment in + let addr = (unsigned addr) / cap_alignment in let tag = match (Map.lookup addr state.tagstate) with | Just t -> t | Nothing -> B0 @@ -96,18 +96,18 @@ let excl_result () state = (Left true, <| state with last_exclusive_operation_was_load = false |>) in (Left false, state) :: if state.last_exclusive_operation_was_load then [success] else [] -val write_mem_ea : write_kind -> vector bitU -> integer -> M unit +val write_mem_ea : forall 'a. write_kind -> bitvector 'a -> integer -> M unit let write_mem_ea write_kind addr sz state = - let addr = integer_of_address (address_of_bitv addr) in + let addr = unsigned addr in [(Left (), <| state with write_ea = Just (write_kind,addr,sz) |>)] -val write_mem_val : vector bitU -> M bool +val write_mem_val : forall 'b. bitvector 'b -> M bool let write_mem_val v state = let (write_kind,addr,sz) = match state.write_ea with | Nothing -> failwith "write ea has not been announced yet" | Just write_ea -> write_ea end in let addrs = range addr (addr+sz-1) in - let v = external_mem_value v in + let v = external_mem_value (bvec_to_vec v) in let addresses_with_value = List.zip addrs v in let memstate = List.foldl (fun mem (addr,v) -> Map.insert addr v mem) state.memstate addresses_with_value in @@ -122,16 +122,16 @@ let write_tag t state = let tagstate = Map.insert taddr t state.tagstate in [(Left true, <| state with tagstate = tagstate |>)] -val read_reg : register -> M (vector bitU) +val read_reg : forall 'a. Size 'a => register -> M (bitvector 'a) let read_reg reg state = - let v = Map_extra.find (name_of_reg reg) state.regstate in + let v = get_reg state (name_of_reg reg) in + [(Left (vec_to_bvec v),state)] +let read_reg_range reg i j state = + let v = slice (get_reg state (name_of_reg reg)) i j in + [(Left (vec_to_bvec v),state)] +let read_reg_bit reg i state = + let v = access (get_reg state (name_of_reg reg)) i in [(Left v,state)] -let read_reg_range reg i j = - read_reg reg >>= fun rv -> - return (slice rv i j) -let read_reg_bit reg i = - read_reg_range reg i i >>= fun v -> - return (extract_only_bit v) let read_reg_field reg regfield = let (i,j) = register_field_indices reg regfield in read_reg_range reg i j @@ -139,25 +139,30 @@ let read_reg_bitfield reg regfield = let (i,_) = register_field_indices reg regfield in read_reg_bit reg i -val write_reg : register -> vector bitU -> M unit +val write_reg : forall 'a. Size 'a => register -> bitvector 'a -> M unit let write_reg reg v state = - [(Left (),<| state with regstate = Map.insert (name_of_reg reg) v state.regstate |>)] -let write_reg_range reg i j v = - read_reg reg >>= fun current_value -> - let new_value = update current_value i j v in - write_reg reg new_value -let write_reg_bit reg i bit = - write_reg_range reg i i (Vector [bit] i (is_inc_of_reg reg)) + [(Left (), set_reg state (name_of_reg reg) (bvec_to_vec v))] +let write_reg_range reg i j v state = + let current_value = get_reg state (name_of_reg reg) in + let new_value = update current_value i j (bvec_to_vec v) in + [(Left (), set_reg state (name_of_reg reg) new_value)] +let write_reg_bit reg i bit state = + let current_value = get_reg state (name_of_reg reg) in + let new_value = update_pos current_value i bit in + [(Left (), set_reg state (name_of_reg reg) new_value)] let write_reg_field reg regfield = - let (i,j) = register_field_indices reg regfield in + let (i,j) = register_field_indices reg regfield in write_reg_range reg i j let write_reg_bitfield reg regfield = let (i,_) = register_field_indices reg regfield in write_reg_bit reg i -let write_reg_field_range reg regfield i j v = - read_reg_field reg regfield >>= fun current_field_value -> - let new_field_value = update current_field_value i j v in - write_reg_field reg regfield new_field_value +let write_reg_field_range reg regfield i j v state = + let (i0,j0) = register_field_indices reg regfield in + let current_value = get_reg state (name_of_reg reg) in + let current_field_value = slice current_value i0 j0 in + let new_field_value = update current_field_value i j (bvec_to_vec v) in + let new_value = update current_value i j new_field_value in + [(Left (), set_reg state (name_of_reg reg) new_value)] val barrier : barrier_kind -> M unit @@ -186,7 +191,8 @@ let rec foreachM_dec (i,stop,by) vars body = foreachM_dec (i - by,stop,by) vars body else return vars -let write_two_regs r1 r2 vec = +let write_two_regs r1 r2 bvec state = + let vec = bvec_to_vec bvec in let is_inc = let is_inc_r1 = is_inc_of_reg r1 in let is_inc_r2 = is_inc_of_reg r2 in @@ -205,4 +211,6 @@ let write_two_regs r1 r2 vec = if is_inc then slice vec (size_r1 - start_vec) (size_vec - start_vec) else slice vec (start_vec - size_r1) (start_vec - size_vec) in - write_reg r1 r1_v >> write_reg r2 r2_v + let state1 = set_reg state (name_of_reg r1) r1_v in + let state2 = set_reg state1 (name_of_reg r2) r2_v in + [(Left (), state2)] diff --git a/src/monomorphise.ml b/src/monomorphise.ml new file mode 100644 index 00000000..14ea30ba --- /dev/null +++ b/src/monomorphise.ml @@ -0,0 +1,621 @@ +open Parse_ast +open Ast +open Type_internal + +(* TODO: put this somewhere common *) + +let id_to_string (Id_aux(id,l)) = + match id with + | Id(s) -> s + | DeIid(s) -> s + +(* TODO: check for temporary failwiths *) + +let optmap v f = + match v with + | None -> None + | Some v -> Some (f v) + +let disable_const_propagation = ref false + +(* Based on current type checker's behaviour *) +let pat_id_is_variable t_env id = + match Envmap.apply t_env id with + | Some (Base(_,Constructor _,_,_,_,_)) + | Some (Base(_,Enum _,_,_,_,_)) + -> false + | _ -> true + + +let nexp_subst substs exp = + let s_t t = typ_subst substs true t in +(* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in + hopefully don't need this anyway *) + let s_typschm tsh = tsh in + let s_tannot = function + | Base ((params,t),tag,ranges,effl,effc,bounds) -> + (* TODO: do other fields need mapped? *) + Base ((params,s_t t),tag,ranges,effl,effc,bounds) + | tannot -> tannot + in + let rec s_pat (P_aux (p,(l,annot))) = + let re p = P_aux (p,(l,s_tannot annot)) in + match p with + | P_lit _ | P_wild | P_id _ -> re p + | P_as (p',id) -> re (P_as (s_pat p', id)) + | P_typ (ty,p') -> re (P_typ (ty,s_pat p')) + | P_app (id,ps) -> re (P_app (id, List.map s_pat ps)) + | P_record (fps,flag) -> re (P_record (List.map s_fpat fps, flag)) + | P_vector ps -> re (P_vector (List.map s_pat ps)) + | P_vector_indexed ips -> re (P_vector_indexed (List.map (fun (i,p) -> (i,s_pat p)) ips)) + | P_vector_concat ps -> re (P_vector_concat (List.map s_pat ps)) + | P_tup ps -> re (P_tup (List.map s_pat ps)) + | P_list ps -> re (P_list (List.map s_pat ps)) + and s_fpat (FP_aux (FP_Fpat (id, p), (l,annot))) = + FP_aux (FP_Fpat (id, s_pat p), (l,s_tannot annot)) + in + let rec s_exp (E_aux (e,(l,annot))) = + let re e = E_aux (e,(l,s_tannot annot)) in + match e with + | E_block es -> re (E_block (List.map s_exp es)) + | E_nondet es -> re (E_nondet (List.map s_exp es)) + | E_id _ + | E_lit _ + | E_comment _ -> re e + | E_sizeof ne -> re (E_sizeof ne) (* TODO: do this need done? does it appear in type checked code? *) + | E_internal_exp (l,annot) -> re (E_internal_exp (l, s_tannot annot)) + | E_sizeof_internal (l,annot) -> re (E_sizeof_internal (l, s_tannot annot)) + | E_internal_exp_user ((l1,annot1),(l2,annot2)) -> + re (E_internal_exp_user ((l1, s_tannot annot1),(l2, s_tannot annot2))) + | E_cast (t,e') -> re (E_cast (t, s_exp e')) + | E_app (id,es) -> re (E_app (id,List.map s_exp es)) + | E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2)) + | E_tuple es -> re (E_tuple (List.map s_exp es)) + | E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3)) + | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,s_exp e1,s_exp e2,s_exp e3,ord,s_exp e4)) + | E_vector es -> re (E_vector (List.map s_exp es)) + | E_vector_indexed (ies,ed) -> re (E_vector_indexed (List.map (fun (i,e) -> (i,s_exp e)) ies, + s_opt_default ed)) + | E_vector_access (e1,e2) -> re (E_vector_access (s_exp e1,s_exp e2)) + | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (s_exp e1,s_exp e2,s_exp e3)) + | E_vector_update (e1,e2,e3) -> re (E_vector_update (s_exp e1,s_exp e2,s_exp e3)) + | E_vector_update_subrange (e1,e2,e3,e4) -> re (E_vector_update_subrange (s_exp e1,s_exp e2,s_exp e3,s_exp e4)) + | E_vector_append (e1,e2) -> re (E_vector_append (s_exp e1,s_exp e2)) + | E_list es -> re (E_list (List.map s_exp es)) + | E_cons (e1,e2) -> re (E_cons (s_exp e1,s_exp e2)) + | E_record fes -> re (E_record (s_fexps fes)) + | E_record_update (e,fes) -> re (E_record_update (s_exp e, s_fexps fes)) + | E_field (e,id) -> re (E_field (s_exp e,id)) + | E_case (e,cases) -> re (E_case (s_exp e, List.map s_pexp cases)) + | E_let (lb,e) -> re (E_let (s_letbind lb, s_exp e)) + | E_assign (le,e) -> re (E_assign (s_lexp le, s_exp e)) + | E_exit e -> re (E_exit (s_exp e)) + | E_return e -> re (E_return (s_exp e)) + | E_assert (e1,e2) -> re (E_assert (s_exp e1,s_exp e2)) + | E_internal_cast ((l,ann),e) -> re (E_internal_cast ((l,s_tannot ann),s_exp e)) + | E_comment_struc e -> re (E_comment_struc e) + | E_internal_let (le,e1,e2) -> re (E_internal_let (s_lexp le, s_exp e1, s_exp e2)) + | E_internal_plet (p,e1,e2) -> re (E_internal_plet (s_pat p, s_exp e1, s_exp e2)) + | E_internal_return e -> re (E_internal_return (s_exp e)) + and s_opt_default (Def_val_aux (ed,(l,annot))) = + match ed with + | Def_val_empty -> Def_val_aux (Def_val_empty,(l,s_tannot annot)) + | Def_val_dec e -> Def_val_aux (Def_val_dec (s_exp e),(l,s_tannot annot)) + and s_fexps (FES_aux (FES_Fexps (fes,flag), (l,annot))) = + FES_aux (FES_Fexps (List.map s_fexp fes, flag), (l,s_tannot annot)) + and s_fexp (FE_aux (FE_Fexp (id,e), (l,annot))) = + FE_aux (FE_Fexp (id,s_exp e),(l,s_tannot annot)) + and s_pexp (Pat_aux (Pat_exp (p,e),(l,annot))) = + Pat_aux (Pat_exp (s_pat p, s_exp e),(l,s_tannot annot)) + and s_letbind (LB_aux (lb,(l,annot))) = + match lb with + | LB_val_explicit (tysch,p,e) -> + LB_aux (LB_val_explicit (s_typschm tysch,s_pat p,s_exp e), (l,s_tannot annot)) + | LB_val_implicit (p,e) -> LB_aux (LB_val_implicit (s_pat p,s_exp e), (l,s_tannot annot)) + and s_lexp (LEXP_aux (e,(l,annot))) = + let re e = LEXP_aux (e,(l,s_tannot annot)) in + match e with + | LEXP_id _ + | LEXP_cast _ + -> re e + | LEXP_memory (id,es) -> re (LEXP_memory (id,List.map s_exp es)) + | LEXP_tup les -> re (LEXP_tup (List.map s_lexp les)) + | LEXP_vector (le,e) -> re (LEXP_vector (s_lexp le, s_exp e)) + | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (s_lexp le, s_exp e1, s_exp e2)) + | LEXP_field (le,id) -> re (LEXP_field (s_lexp le, id)) + in s_exp exp + +let bindings_from_pat t_env p = + let rec aux_pat (P_aux (p,annot)) = + match p with + | P_lit _ + | P_wild + -> [] + | P_as (p,id) -> id_to_string id::(aux_pat p) + | P_typ (_,p) -> aux_pat p + | P_id id -> + let i = id_to_string id in + if pat_id_is_variable t_env i then [i] else [] + | P_vector ps + | P_vector_concat ps + | P_app (_,ps) + | P_tup ps + | P_list ps + -> List.concat (List.map aux_pat ps) + | P_record (fps,_) -> List.concat (List.map aux_fpat fps) + | P_vector_indexed ips -> List.concat (List.map (fun (_,p) -> aux_pat p) ips) + and aux_fpat (FP_aux (FP_Fpat (_,p), _)) = aux_pat p + in aux_pat p + +let remove_bound t_env env pat = + let bound = bindings_from_pat t_env pat in + List.fold_left (fun sub v -> Envmap.remove env v) env bound + + +let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = + + let can_match (E_aux (e,(l,annot)) as exp) cases = + match e with + | E_id id -> + let i = id_to_string id in + (match Envmap.apply t_env i with + | Some(Base(_,Enum _,_,_,_,_)) -> + let rec findpat cases = + match cases with + | [] -> (Reporting_basic.print_err false true l "Monomorphisation" + ("Failed to find a case for " ^ i); None) + | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some exp + | (Pat_aux (Pat_exp (P_aux (P_id id',_),exp),_))::tl + | (Pat_aux (Pat_exp (P_aux (P_app (id',[]),_),exp),_))::tl -> + if i = id_to_string id' then Some exp else findpat tl + | (Pat_aux (Pat_exp (P_aux (_,(l',_)),_),_))::_ -> + (Reporting_basic.print_err false true l' "Monomorphisation" + "Unexpected kind of pattern for enumeration"; None) + in findpat cases + | _ -> None) + (* TODO: could generalise Lit matching *) + | E_lit (L_aux ((L_zero | L_one | L_true | L_false) as bit, _)) -> + let rec findpat cases = + match cases with + | [] -> (Reporting_basic.print_err false true l "Monomorphisation" + ("Failed to find a case for bit"); None) + | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some exp + | (Pat_aux (Pat_exp (P_aux (P_lit (L_aux (lit, _)),_),exp),_))::tl -> + (match bit,lit with + | (L_zero | L_false), (L_zero | L_false) -> Some exp + | (L_one | L_true ), (L_one | L_true ) -> Some exp + | _ -> findpat tl) + | (Pat_aux (Pat_exp (P_aux (_,(l',_)),_),_))::_ -> + (Reporting_basic.print_err false true l' "Monomorphisation" + "Unexpected kind of pattern for bit"; None) + in findpat cases + | _ -> None + in + + (* TODO: doublecheck *) + let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) = + match l1,l2 with + | (L_zero|L_false), (L_zero|L_false) + | (L_one |L_true ), (L_one |L_true) + -> Some true + | L_undef, _ | _, L_undef -> None + | _ -> Some (l1 = l2) + in + + (* TODO: any useful type information revealed? (probably not) *) + let try_app_infix (l,ann) (E_aux (e1,ann1)) id (E_aux (e2,ann2)) = + let new_l = Generated l in + match e1, id, e2 with + | E_lit l1, ("=="|"!="), E_lit l2 -> + let lit b = if b then L_true else L_false in + let lit b = lit (if id = "==" then b else not b) in + (match lit_eq l1 l2 with + | Some b -> Some (E_aux (E_lit (L_aux (lit b,new_l)), (l,ann))) + | None -> None) + | _ -> None + in + + let build_nexp_subst l t1 t2 = + let rec from_types t1 t2 = + let t1 = match t1.t with Tabbrev(_,t) -> t | _ -> t1 in + let t2 = match t2.t with Tabbrev(_,t) -> t | _ -> t2 in + if t1 = t2 then [] else + match t1.t,t2.t with + | Tapp (s1,args1), Tapp (s2,args2) -> + if s1 = s2 then + List.concat (List.map2 from_args args1 args2) + else (Reporting_basic.print_err false true l "Monomorphisation" + "Unexpected type mismatch"; []) + | Ttup ts1, Ttup ts2 -> + if List.length ts1 = List.length ts2 then + List.concat (List.map2 from_types ts1 ts2) + else (Reporting_basic.print_err false true l "Monomorphisation" + "Unexpected type mismatch"; []) + | _ -> [] + and from_args arg1 arg2 = + match arg1,arg2 with + | TA_typ t1, TA_typ t2 -> from_types t1 t2 + | TA_nexp n1, TA_nexp n2 -> from_nexps n1 n2 + | _ -> [] + and from_nexps n1 n2 = + match n1.nexp, n2.nexp with + | Nvar s, Nvar s' when s = s' -> [] + | Nvar s, _ -> [(s,n2)] + | Nadd (n3,n4), Nadd (n5,n6) + | Nsub (n3,n4), Nsub (n5,n6) + | Nmult (n3,n4), Nmult (n5,n6) + -> from_nexps n3 n5 @ from_nexps n4 n6 + | N2n (n3,p1), N2n (n4,p2) when p1 = p2 -> from_nexps n3 n4 + | Npow (n3,p1), Npow (n4,p2) when p1 = p2 -> from_nexps n3 n4 + | Nneg n3, Nneg n4 -> from_nexps n3 n4 + | _ -> [] + in match t1,t2 with + | Base ((_,t1),_,_,_,_,_),Base ((_,t2),_,_,_,_,_) -> from_types t1 t2 + | _ -> [] + in + + let nexp_substs = ref [] in + + (* Constant propogation *) + let rec const_prop_exp substs ((E_aux (e,(l,annot))) as exp) = + let re e = E_aux (e,(l,annot)) in + match e with + (* TODO: are there circumstances in which we should get rid of these? *) + | E_block es -> re (E_block (List.map (const_prop_exp substs) es)) + | E_nondet es -> re (E_nondet (List.map (const_prop_exp substs) es)) + + | E_id id -> + (match Envmap.apply substs (id_to_string id) with + | None -> exp + | Some exp' -> exp') + | E_lit _ + | E_sizeof _ + | E_internal_exp _ + | E_sizeof_internal _ + | E_internal_exp_user _ + | E_comment _ + -> exp + | E_cast (t,e') -> re (E_cast (t, const_prop_exp substs e')) + | E_app (id,es) -> re (E_app (id,List.map (const_prop_exp substs) es)) + | E_app_infix (e1,id,e2) -> + let e1',e2' = const_prop_exp substs e1,const_prop_exp substs e2 in + (match try_app_infix (l,annot) e1' (id_to_string id) e2' with + | Some exp -> exp + | None -> re (E_app_infix (e1',id,e2'))) + | E_tuple es -> re (E_tuple (List.map (const_prop_exp substs) es)) + | E_if (e1,e2,e3) -> + let e1' = const_prop_exp substs e1 in + let e2',e3' = const_prop_exp substs e2, const_prop_exp substs e3 in + (match e1' with + | E_aux (E_lit (L_aux ((L_true|L_false) as lit ,_)),_) -> + let e' = match lit with L_true -> e2' | _ -> e3' in + (match e' with E_aux (_,(_,annot')) -> + nexp_substs := build_nexp_subst l annot annot' @ !nexp_substs; + e') + | _ -> re (E_if (e1',e2',e3'))) + | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3,ord,const_prop_exp (Envmap.remove substs (id_to_string id)) e4)) + | E_vector es -> re (E_vector (List.map (const_prop_exp substs) es)) + | E_vector_indexed (ies,ed) -> re (E_vector_indexed (List.map (fun (i,e) -> (i,const_prop_exp substs e)) ies, + const_prop_opt_default substs ed)) + | E_vector_access (e1,e2) -> re (E_vector_access (const_prop_exp substs e1,const_prop_exp substs e2)) + | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3)) + | E_vector_update (e1,e2,e3) -> re (E_vector_update (const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3)) + | E_vector_update_subrange (e1,e2,e3,e4) -> re (E_vector_update_subrange (const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3,const_prop_exp substs e4)) + | E_vector_append (e1,e2) -> re (E_vector_append (const_prop_exp substs e1,const_prop_exp substs e2)) + | E_list es -> re (E_list (List.map (const_prop_exp substs) es)) + | E_cons (e1,e2) -> re (E_cons (const_prop_exp substs e1,const_prop_exp substs e2)) + | E_record fes -> re (E_record (const_prop_fexps substs fes)) + | E_record_update (e,fes) -> re (E_record_update (const_prop_exp substs e, const_prop_fexps substs fes)) + | E_field (e,id) -> re (E_field (const_prop_exp substs e,id)) + | E_case (e,cases) -> + let e' = const_prop_exp substs e in + (match can_match e' cases with + | None -> re (E_case (e', List.map (const_prop_pexp substs) cases)) + | Some (E_aux (_,(_,annot')) as exp) -> + nexp_substs := build_nexp_subst l annot annot' @ !nexp_substs; + const_prop_exp substs exp) + | E_let (lb,e) -> + let (lb',substs') = const_prop_letbind substs lb in + re (E_let (lb', const_prop_exp substs' e)) + | E_assign (le,e) -> re (E_assign (const_prop_lexp substs le, const_prop_exp substs e)) + | E_exit e -> re (E_exit (const_prop_exp substs e)) + | E_return e -> re (E_return (const_prop_exp substs e)) + | E_assert (e1,e2) -> re (E_assert (const_prop_exp substs e1,const_prop_exp substs e2)) + | E_internal_cast (ann,e) -> re (E_internal_cast (ann,const_prop_exp substs e)) + | E_comment_struc e -> re (E_comment_struc e) + | E_internal_let _ + | E_internal_plet _ + | E_internal_return _ + -> raise (Reporting_basic.err_unreachable l + "Unexpected internal expression encountered in monomorphisation") + and const_prop_opt_default substs ((Def_val_aux (ed,annot)) as eda) = + match ed with + | Def_val_empty -> eda + | Def_val_dec e -> Def_val_aux (Def_val_dec (const_prop_exp substs e),annot) + and const_prop_fexps substs (FES_aux (FES_Fexps (fes,flag), annot)) = + FES_aux (FES_Fexps (List.map (const_prop_fexp substs) fes, flag), annot) + and const_prop_fexp substs (FE_aux (FE_Fexp (id,e), annot)) = + FE_aux (FE_Fexp (id,const_prop_exp substs e),annot) + and const_prop_pexp substs (Pat_aux (Pat_exp (p,e),l)) = + Pat_aux (Pat_exp (p,const_prop_exp (remove_bound t_env substs p) e),l) + and const_prop_letbind substs (LB_aux (lb,annot)) = + match lb with + | LB_val_explicit (tysch,p,e) -> + (LB_aux (LB_val_explicit (tysch,p,const_prop_exp substs e), annot), + remove_bound t_env substs p) + | LB_val_implicit (p,e) -> + (LB_aux (LB_val_implicit (p,const_prop_exp substs e), annot), + remove_bound t_env substs p) + and const_prop_lexp substs ((LEXP_aux (e,annot)) as le) = + let re e = LEXP_aux (e,annot) in + match e with + | LEXP_id _ (* shouldn't end up substituting here *) + | LEXP_cast _ + -> le + | LEXP_memory (id,es) -> re (LEXP_memory (id,List.map (const_prop_exp substs) es)) (* or here *) + | LEXP_tup les -> re (LEXP_tup (List.map (const_prop_lexp substs) les)) + | LEXP_vector (le,e) -> re (LEXP_vector (const_prop_lexp substs le, const_prop_exp substs e)) + | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (const_prop_lexp substs le, const_prop_exp substs e1, const_prop_exp substs e2)) + | LEXP_field (le,id) -> re (LEXP_field (const_prop_lexp substs le, id)) + in + + let subst_exp subst exp = + if !disable_const_propagation then + (* TODO: This just sticks a let in - we really need propogation *) + let (subi,(E_aux (_,subannot) as sube)) = subst in + let E_aux (e,(l,annot)) = exp in + let lg = Generated l in + let p = P_aux (P_id (Id_aux (Id subi, lg)), subannot) in + E_aux (E_let (LB_aux (LB_val_implicit (p,sube),(lg,annot)), exp),(lg,annot)) + else + let substs = Envmap.from_list [subst] in + let () = nexp_substs := [] in + let exp' = const_prop_exp substs exp in + (* Substitute what we've learned about nvars into the term *) + let nsubsts = Envmap.from_list (List.map (fun (id,ne) -> (id,TA_nexp ne)) !nexp_substs) in + let () = nexp_substs := [] in + nexp_subst nsubsts exp' + in + + + (* Split a variable pattern into every possible value *) + + let split id l tannot = + let new_l = Generated l in + let new_id i = Id_aux (Id i, new_l) in + match tannot with + | Type_internal.NoTyp -> + raise (Reporting_basic.err_general l ("No type information for variable " ^ id ^ " to split on")) + | Type_internal.Overload _ -> + raise (Reporting_basic.err_general l ("Type for variable " ^ id ^ " to split on is overloaded")) + | Type_internal.Base ((tparams,ty0),_,cs,_,_,_) -> + let () = match tparams with + | [] -> () + | _ -> raise (Reporting_basic.err_general l ("Type for variable " ^ id ^ " to split on has parameters")) + in + let ty = match ty0.t with Tabbrev(_,ty) -> ty | _ -> ty0 in + let cannot () = + raise (Reporting_basic.err_general l + ("Cannot split type " ^ Type_internal.t_to_string ty ^ " for variable " ^ id)) + in + (match ty.t with + | Tid i -> + (match Envmap.apply d_env.enum_env i with + (* enumerations *) + | Some ns -> List.map (fun n -> (P_aux (P_id (new_id n),(l,tannot)), + (id,E_aux (E_id (new_id n),(new_l,tannot))))) ns + | None -> + if i = "bit" then + List.map (fun b -> + P_aux (P_lit (L_aux (b,new_l)),(l,tannot)), + (id,E_aux (E_lit (L_aux (b,new_l)),(new_l, tannot)))) + [L_zero; L_one] + else cannot ()) + (*| vectors TODO *) + (*| numbers TODO *) + | _ -> cannot ()) + in + + (* Split variable patterns at the given locations *) + + let map_locs ls (Defs defs) = + let rec match_l = function + | Unknown + | Int _ -> [] + | Generated l -> [] (* Could do match_l l, but only want to split user-written patterns *) + | Range (p,q) -> + List.filter (fun ((filename,line),_) -> + Filename.basename p.Lexing.pos_fname = filename && + p.Lexing.pos_lnum <= line && line <= q.Lexing.pos_lnum) ls + in + + let split_pat var p = + let rec list f = function + | [] -> None + | h::t -> + match f h with + | None -> (match list f t with None -> None | Some (l,ps,r) -> Some (h::l,ps,r)) + | Some ps -> Some ([],ps,t) + in + let rec spl (P_aux (p,(l,annot))) = + let relist f ctx ps = + optmap (list f ps) + (fun (left,ps,right) -> + List.map (fun (p,sub) -> P_aux (ctx (left@p::right),(l,annot)),sub) ps) + in + let re f p = + optmap (spl p) + (fun ps -> List.map (fun (p,sub) -> (P_aux (f p,(l,annot)), sub)) ps) + in + let fpat (FP_aux ((FP_Fpat (id,p),annot))) = + optmap (spl p) + (fun ps -> List.map (fun (p,sub) -> FP_aux (FP_Fpat (id,p), annot), sub) ps) + in + let ipat (i,p) = optmap (spl p) (List.map (fun (p,sub) -> (i,p),sub)) + in + match p with + | P_lit _ + | P_wild + -> None + | P_as (p',id) -> + let i = id_to_string id in + if i = var + then raise (Reporting_basic.err_general l + ("Cannot split " ^ var ^ " on 'as' pattern")) + else re (fun p -> P_as (p,id)) p' + | P_typ (t,p') -> re (fun p -> P_typ (t,p)) p' + | P_id id -> + let i = id_to_string id in + if i = var + then Some (split i l annot) + else None + | P_app (id,ps) -> + relist spl (fun ps -> P_app (id,ps)) ps + | P_record (fps,flag) -> + relist fpat (fun fps -> P_record (fps,flag)) fps + | P_vector ps -> + relist spl (fun ps -> P_vector ps) ps + | P_vector_indexed ips -> + relist ipat (fun ips -> P_vector_indexed ips) ips + | P_vector_concat ps -> + relist spl (fun ps -> P_vector_concat ps) ps + | P_tup ps -> + relist spl (fun ps -> P_tup ps) ps + | P_list ps -> + relist spl (fun ps -> P_list ps) ps + in spl p + in + + let map_pat (P_aux (_,(l,_)) as p) = + match match_l l with + | [] -> None + | [(_,var)] -> split_pat var p + | lvs -> raise (Reporting_basic.err_general l + ("Multiple variables to split on: " ^ String.concat ", " (List.map snd lvs))) + in + + let check_single_pat (P_aux (_,(l,_)) as p) = + match match_l l with + | [] -> p + | lvs -> + let pvs = bindings_from_pat t_env p in + let overlap = List.exists (fun (_,v) -> List.mem v pvs) lvs in + let () = + if overlap then + Reporting_basic.print_err false true l "Monomorphisation" + "Splitting a singleton pattern is not possible" + in p + in + + let rec map_exp ((E_aux (e,annot)) as ea) = + let re e = E_aux (e,annot) in + match e with + | E_block es -> re (E_block (List.map map_exp es)) + | E_nondet es -> re (E_nondet (List.map map_exp es)) + | E_id _ + | E_lit _ + | E_sizeof _ + | E_internal_exp _ + | E_sizeof_internal _ + | E_internal_exp_user _ + | E_comment _ + -> ea + | E_cast (t,e') -> re (E_cast (t, map_exp e')) + | E_app (id,es) -> re (E_app (id,List.map map_exp es)) + | E_app_infix (e1,id,e2) -> re (E_app_infix (map_exp e1,id,map_exp e2)) + | E_tuple es -> re (E_tuple (List.map map_exp es)) + | E_if (e1,e2,e3) -> re (E_if (map_exp e1, map_exp e2, map_exp e3)) + | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,map_exp e1,map_exp e2,map_exp e3,ord,map_exp e4)) + | E_vector es -> re (E_vector (List.map map_exp es)) + | E_vector_indexed (ies,ed) -> re (E_vector_indexed (List.map (fun (i,e) -> (i,map_exp e)) ies, + map_opt_default ed)) + | E_vector_access (e1,e2) -> re (E_vector_access (map_exp e1,map_exp e2)) + | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (map_exp e1,map_exp e2,map_exp e3)) + | E_vector_update (e1,e2,e3) -> re (E_vector_update (map_exp e1,map_exp e2,map_exp e3)) + | E_vector_update_subrange (e1,e2,e3,e4) -> re (E_vector_update_subrange (map_exp e1,map_exp e2,map_exp e3,map_exp e4)) + | E_vector_append (e1,e2) -> re (E_vector_append (map_exp e1,map_exp e2)) + | E_list es -> re (E_list (List.map map_exp es)) + | E_cons (e1,e2) -> re (E_cons (map_exp e1,map_exp e2)) + | E_record fes -> re (E_record (map_fexps fes)) + | E_record_update (e,fes) -> re (E_record_update (map_exp e, map_fexps fes)) + | E_field (e,id) -> re (E_field (map_exp e,id)) + | E_case (e,cases) -> re (E_case (map_exp e, List.concat (List.map map_pexp cases))) + | E_let (lb,e) -> re (E_let (map_letbind lb, map_exp e)) + | E_assign (le,e) -> re (E_assign (map_lexp le, map_exp e)) + | E_exit e -> re (E_exit (map_exp e)) + | E_return e -> re (E_return (map_exp e)) + | E_assert (e1,e2) -> re (E_assert (map_exp e1,map_exp e2)) + | E_internal_cast (ann,e) -> re (E_internal_cast (ann,map_exp e)) + | E_comment_struc e -> re (E_comment_struc e) + | E_internal_let (le,e1,e2) -> re (E_internal_let (map_lexp le, map_exp e1, map_exp e2)) + | E_internal_plet (p,e1,e2) -> re (E_internal_plet (check_single_pat p, map_exp e1, map_exp e2)) + | E_internal_return e -> re (E_internal_return (map_exp e)) + and map_opt_default ((Def_val_aux (ed,annot)) as eda) = + match ed with + | Def_val_empty -> eda + | Def_val_dec e -> Def_val_aux (Def_val_dec (map_exp e),annot) + and map_fexps (FES_aux (FES_Fexps (fes,flag), annot)) = + FES_aux (FES_Fexps (List.map map_fexp fes, flag), annot) + and map_fexp (FE_aux (FE_Fexp (id,e), annot)) = + FE_aux (FE_Fexp (id,map_exp e),annot) + and map_pexp (Pat_aux (Pat_exp (p,e),l)) = + match map_pat p with + | None -> [Pat_aux (Pat_exp (p,map_exp e),l)] + | Some patsubsts -> + List.map (fun (pat',subst) -> + let exp' = subst_exp subst e in + Pat_aux (Pat_exp (pat', map_exp exp'),l)) + patsubsts + and map_letbind (LB_aux (lb,annot)) = + match lb with + | LB_val_explicit (tysch,p,e) -> LB_aux (LB_val_explicit (tysch,check_single_pat p,map_exp e), annot) + | LB_val_implicit (p,e) -> LB_aux (LB_val_implicit (check_single_pat p,map_exp e), annot) + and map_lexp ((LEXP_aux (e,annot)) as le) = + let re e = LEXP_aux (e,annot) in + match e with + | LEXP_id _ + | LEXP_cast _ + -> le + | LEXP_memory (id,es) -> re (LEXP_memory (id,List.map map_exp es)) + | LEXP_tup les -> re (LEXP_tup (List.map map_lexp les)) + | LEXP_vector (le,e) -> re (LEXP_vector (map_lexp le, map_exp e)) + | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (map_lexp le, map_exp e1, map_exp e2)) + | LEXP_field (le,id) -> re (LEXP_field (map_lexp le, id)) + in + + let map_funcl (FCL_aux (FCL_Funcl (id,pat,exp),annot)) = + match map_pat pat with + | None -> [FCL_aux (FCL_Funcl (id, pat, map_exp exp), annot)] + | Some patsubsts -> + List.map (fun (pat',subst) -> + let exp' = subst_exp subst exp in + FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot)) + patsubsts + in + + let map_fundef (FD_aux (FD_function (r,t,e,fcls),annot)) = + FD_aux (FD_function (r,t,e,List.concat (List.map map_funcl fcls)),annot) + in + let map_scattered_def sd = + match sd with + | SD_aux (SD_scattered_funcl fcl, annot) -> + List.map (fun fcl' -> SD_aux (SD_scattered_funcl fcl', annot)) (map_funcl fcl) + | _ -> [sd] + in + let map_def d = + match d with + | DEF_kind _ + | DEF_type _ + | DEF_spec _ + | DEF_default _ + | DEF_reg_dec _ + | DEF_comm _ + -> [d] + | DEF_fundef fd -> [DEF_fundef (map_fundef fd)] + | DEF_val lb -> [DEF_val (map_letbind lb)] + | DEF_scattered sd -> List.map (fun x -> DEF_scattered x) (map_scattered_def sd) + + in + Defs (List.concat (List.map map_def defs)) + + in map_locs splits defs diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 9758b2de..5f2e9888 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -118,18 +118,24 @@ let doc_id_lem_ctor (Id_aux(i,_)) = * token in case of x ending with star. *) separate space [colon; string (String.capitalize x); empty] +let effectful_set = + List.exists + (fun (BE_aux (eff,_)) -> + match eff with + | BE_rreg | BE_wreg | BE_rmem | BE_rmemt | BE_wmem | BE_eamem + | BE_exmem | BE_wmv | BE_wmvt | BE_barr | BE_depend | BE_nondet + | BE_escape -> true + | _ -> false) + let effectful (Effect_aux (eff,_)) = match eff with | Effect_var _ -> failwith "effectful: Effect_var not supported" - | Effect_set effs -> - List.exists - (fun (BE_aux (eff,_)) -> - match eff with - | BE_rreg | BE_wreg | BE_rmem | BE_rmemt | BE_wmem | BE_eamem - | BE_exmem | BE_wmv | BE_wmvt | BE_barr | BE_depend | BE_nondet - | BE_escape -> true - | _ -> false) - effs + | Effect_set effs -> effectful_set effs + +let effectful_t eff = + match eff.effect with + | Eset effs -> effectful_set effs + | _ -> false let rec is_number {t=t} = match t with @@ -160,9 +166,26 @@ let doc_typ_lem, doc_atomic_typ_lem = if atyp_needed then parens tpp else tpp | _ -> app_typ regtypes atyp_needed ty and app_typ regtypes atyp_needed ((Typ_aux (t, _)) as ty) = match t with - | Typ_app(Id_aux (Id "vector", _),[_;_;_;Typ_arg_aux (Typ_arg_typ typa, _)]) -> - let tpp = string "vector" ^^ space ^^ typ regtypes typa in + | Typ_app(Id_aux (Id "vector", _), [ + Typ_arg_aux (Typ_arg_nexp n, _); + Typ_arg_aux (Typ_arg_nexp m, _); + Typ_arg_aux (Typ_arg_order ord, _); + Typ_arg_aux (Typ_arg_typ elem_typ, _)]) -> + let tpp = match elem_typ with + | Typ_aux (Typ_id (Id_aux (Id "bit",_)),_) -> + let len = match m with + | (Nexp_aux(Nexp_constant i,_)) -> string "ty" ^^ doc_int i + | _ -> doc_nexp m in + string "bitvector" ^^ space ^^ len + | _ -> string "vector" ^^ space ^^ typ regtypes elem_typ in if atyp_needed then parens tpp else tpp + | Typ_app(Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ etyp, _)]) -> + (* TODO: Better distinguish register names and contents? + The former are represented in the Lem library using a type + "register" (without parameters), the latter just using the content + type (e.g. "bitvector ty64"). We assume the latter is meant here + and drop the "register" keyword. *) + fn_typ regtypes atyp_needed etyp | Typ_app(Id_aux (Id "range", _),_) -> (string "integer") | Typ_app(Id_aux (Id "implicit", _),_) -> @@ -189,12 +212,17 @@ let doc_typ_lem, doc_atomic_typ_lem = let tpp = typ regtypes ty in if atyp_needed then parens tpp else tpp and doc_typ_arg_lem regtypes (Typ_arg_aux(t,_)) = match t with - | Typ_arg_typ t -> app_typ regtypes false t + | Typ_arg_typ t -> app_typ regtypes true t | Typ_arg_nexp n -> empty | Typ_arg_order o -> empty | Typ_arg_effect e -> empty in typ', atomic_typ +let doc_tannot_lem regtypes eff t = + let ta = doc_typ_lem regtypes (t_to_typ (normalize_t t)) in + if eff then string " : M " ^^ parens ta + else string " : " ^^ ta + (* doc_lit_lem gets as an additional parameter the type information from the * expression around it: that's a hack, but how else can we distinguish between * undefined values of different types ? *) @@ -270,6 +298,44 @@ let rec doc_pat_lem regtypes apat_needed (P_aux (p,(l,annot)) as pa) = match p w | _ -> parens (separate_map comma_sp (doc_pat_lem regtypes false) pats)) | P_list pats -> brackets (separate_map semi (doc_pat_lem regtypes false) pats) (*Never seen but easy in lem*) +let rec contains_bitvector_type t = match t.t with + | Ttup ts -> List.exists contains_bitvector_type ts + | Tapp (_, targs) -> is_bit_vector t || List.exists contains_bitvector_type_arg targs + | Tabbrev (_,t') -> contains_bitvector_type t' + | Tfn (t1,t2,_,_) -> contains_bitvector_type t1 || contains_bitvector_type t2 + | _ -> false +and contains_bitvector_type_arg targ = match targ with + | TA_typ t -> contains_bitvector_type t + | _ -> false + +let const_nexp nexp = match nexp.nexp with + | Nconst _ -> true + | _ -> false + +(* Check for variables in types that would be pretty-printed. + In particular, in case of vector types, only the element type and the + length argument are checked for variables, and the latter only if it is + a bitvector; for other types of vectors, the length is not pretty-printed + in the type, and the start index is never pretty-printed in vector types. *) +let rec contains_t_pp_var t = match t.t with + | Tvar _ -> true + | Tfn (t1,t2,_,_) -> contains_t_pp_var t1 || contains_t_pp_var t2 + | Ttup ts -> List.exists contains_t_pp_var ts + | Tapp ("vector",[_;TA_nexp m;_;TA_typ t']) -> + if is_bit_vector t then not (const_nexp (normalize_nexp m)) + else contains_t_pp_var t' + | Tapp (c,targs) -> List.exists contains_t_arg_pp_var targs + | Tabbrev (_,t') -> contains_t_pp_var t' + | Toptions (t1,t2o) -> + contains_t_pp_var t1 || + (match t2o with Some t2 -> contains_t_pp_var t2 | _ -> false) + | Tuvar _ -> true + | Tid _ -> false +and contains_t_arg_pp_var targ = match targ with + | TA_typ t -> contains_t_pp_var t + | TA_nexp nexp -> not (const_nexp (normalize_nexp nexp)) + | _ -> false + let prefix_recordtype = true let report = Reporting_basic.err_unreachable let doc_exp_lem, doc_let_lem = @@ -336,8 +402,15 @@ let doc_exp_lem, doc_let_lem = | _ -> (prefix 2 1) (string "write_reg") (doc_lexp_deref_lem regtypes le ^/^ expY e)) | E_vector_append(l,r) -> + let (Base((_,t),_,_,_,_,_)) = annot in + let (call,ta,aexp_needed) = + if is_bit_vector t then + if not (contains_t_pp_var t) + then ("bitvector_concat", doc_tannot_lem regtypes false t, true) + else ("bitvector_concat", empty, aexp_needed) + else ("vector_concat",empty,aexp_needed) in let epp = - align (group (separate space [expY l;string "^^"] ^/^ expY r)) in + align (group (separate space [string call;expY l;expY r])) ^^ ta in if aexp_needed then parens epp else epp | E_cons(l,r) -> doc_op (group (colon^^colon)) (expY l) (expY r) | E_if(c,t,e) -> @@ -382,7 +455,15 @@ let doc_exp_lem, doc_let_lem = if aexp_needed then parens (align epp) else epp | Id_aux (Id "slice_raw",_) -> let [e1;e2;e3] = args in - let epp = separate space [string "slice_raw";expY e1;expY e2;expY e3] in + let (E_aux (_,(_,Base((_,t1),_,_,_,_,_)))) = e1 in + let call = if is_bit_vector t1 then "bvslice_raw" else "slice_raw" in + let epp = separate space [string call;expY e1;expY e2;expY e3] in + if aexp_needed then parens (align epp) else epp + | Id_aux (Id "length",_) -> + let [arg] = args in + let (E_aux (_,(_,Base((_,targ),_,_,_,_,_)))) = arg in + let call = if is_bit_vector targ then "bvlength" else "length" in + let epp = separate space [string call;expY arg] in if aexp_needed then parens (align epp) else epp | _ -> begin match annot with @@ -392,10 +473,13 @@ let doc_exp_lem, doc_let_lem = if aexp_needed then parens (align epp) else epp | Base (_,Constructor _,_,_,_,_) -> let argpp a_needed arg = - let (E_aux (_,(_,Base((_,{t=t}),_,_,_,_,_)))) = arg in - match t with - | Tapp("vector",_) -> - let epp = concat [string "reset_vector_start";space;expY arg] in + let (E_aux (_,(_,Base((_,t),_,_,_,_,_)))) = arg in + match t.t with + | Tapp("vector",[_;_;_;_]) -> + let call = + if is_bit_vector t then "reset_bitvector_start" + else "reset_vector_start" in + let epp = concat [string call;space;expY arg] in if a_needed then parens epp else epp | _ -> expV a_needed arg in let epp = @@ -411,17 +495,25 @@ let doc_exp_lem, doc_let_lem = | Base(_,External (Some n),_,_,_,_) -> string n | _ -> doc_id_lem f in let argpp a_needed arg = - let (E_aux (_,(_,Base((_,{t=t}),_,_,_,_,_)))) = arg in - match t with - | Tapp("vector",_) -> - let epp = concat [string "reset_vector_start";space;expY arg] in + let (E_aux (_,(_,Base((_,t),_,_,_,_,_)))) = arg in + match t.t with + | Tapp("vector",[_;_;_;_]) -> + let call = + if is_bit_vector t then "reset_bitvector_start" + else "reset_vector_start" in + let epp = concat [string call;space;expY arg] in if a_needed then parens epp else epp | _ -> expV a_needed arg in let argspp = match args with | [arg] -> argpp true arg | args -> parens (align (separate_map (comma ^^ break 0) (argpp false) args)) in let epp = align (call ^//^ argspp) in - if aexp_needed then parens (align epp) else epp + let (taepp,aexp_needed) = + let (Base ((_,t),_,_,eff,_,_)) = annot in + if contains_bitvector_type t && not (contains_t_pp_var t) + then (align epp ^^ (doc_tannot_lem regtypes (effectful_t eff) t), true) + else (epp, aexp_needed) in + if aexp_needed then parens (align taepp) else taepp end end | E_vector_access (v,e) -> @@ -430,27 +522,40 @@ let doc_exp_lem, doc_let_lem = if has_rreg_effect eff then separate space [string "read_reg_bit";expY v;expY e] else - separate space [string "access";expY v;expY e] in + let (E_aux (_,(_,Base ((_,tv),_,_,_,_,_)))) = v in + let call = if is_bit_vector tv then "bvaccess" else "access" in + separate space [string call;expY v;expY e] in if aexp_needed then parens (align epp) else epp | E_vector_subrange (v,e1,e2) -> - let (Base (_,_,_,_,eff,_)) = annot in - let epp = + let (Base ((_,t),_,_,_,eff,_)) = annot in + let (epp,aexp_needed) = if has_rreg_effect eff then - align (string "read_reg_range" ^^ space ^^ expY v ^//^ expY e1 ^//^ expY e2) + let epp = align (string "read_reg_range" ^^ space ^^ expY v ^//^ expY e1 ^//^ expY e2) in + if contains_bitvector_type t && not (contains_t_pp_var t) + then (epp ^^ doc_tannot_lem regtypes true t, true) + else (epp, aexp_needed) else - align (string "slice" ^^ space ^^ expY v ^//^ expY e1 ^//^ expY e2) in + if is_bit_vector t then + let bepp = string "bvslice" ^^ space ^^ expY v ^//^ expY e1 ^//^ expY e2 in + if not (contains_t_pp_var t) + then (bepp ^^ doc_tannot_lem regtypes false t, true) + else (bepp, aexp_needed) + else (string "slice" ^^ space ^^ expY v ^//^ expY e1 ^//^ expY e2, aexp_needed) in if aexp_needed then parens (align epp) else epp | E_field((E_aux(_,(l,fannot)) as fexp),id) -> - let (Base ((_,{t = t}),_,_,_,_,_)) = fannot in - (match t with + let (Base ((_,{t = ft}),_,_,_,_,_)) = fannot in + (match ft with | Tabbrev({t = Tid regtyp},{t=Tapp("register",_)}) -> - let field_f = match annot with - | Base((_,{t = Tid "bit"}),_,_,_,_,_) - | Base((_,{t = Tabbrev(_,{t=Tid "bit"})}),_,_,_,_,_) -> - string "read_reg_bitfield" + let (Base((_,t),_,_,_,_,_)) = annot in + let field_f = match t.t with + | Tid "bit" | Tabbrev(_,{t=Tid "bit"}) -> string "read_reg_bitfield" | _ -> string "read_reg_field" in + let (ta,aexp_needed) = + if contains_bitvector_type t && not (contains_t_pp_var t) + then (doc_tannot_lem regtypes true t, true) + else (empty, aexp_needed) in let epp = field_f ^^ space ^^ (expY fexp) ^^ space ^^ string_lit (doc_id_lem id) in - if aexp_needed then parens (align epp) else epp + if aexp_needed then parens (align epp ^^ ta) else (epp ^^ ta) | Tid recordtyp | Tabbrev ({t = Tid recordtyp},_) -> let fname = @@ -464,57 +569,73 @@ let doc_exp_lem, doc_let_lem = | E_block exps -> raise (report l "Blocks should have been removed till now.") | E_nondet exps -> raise (report l "Nondet blocks not supported.") | E_id id -> + let (Base((_,t),_,_,_,_,_)) = annot in (match annot with | Base((_, ({t = Tapp("register",_)} | {t=Tabbrev(_,{t=Tapp("register",_)})})), External _,_,eff,_,_) -> if has_rreg_effect eff then - separate space [string "read_reg";doc_id_lem id] + let epp = separate space [string "read_reg";doc_id_lem id] in + if contains_bitvector_type t && not (contains_t_pp_var t) + then parens (epp ^^ doc_tannot_lem regtypes true t) + else epp else doc_id_lem id | Base(_,(Constructor i |Enum i),_,_,_,_) -> doc_id_lem_ctor id | Base((_,t),Alias alias_info,_,eff,_,_) -> (match alias_info with | Alias_field(reg,field) -> - let epp = match t.t with - | Tid "bit" | Tabbrev (_,{t=Tid "bit"}) -> - (separate space) - [string "read_reg_bitfield"; string reg;string_lit(string field)] - | _ -> - (separate space) - [string "read_reg_field"; string reg; string_lit(string field)] in - if aexp_needed then parens (align epp) else epp + let call = match t.t with + | Tid "bit" | Tabbrev (_,{t=Tid "bit"}) -> "read_reg_bitfield" + | _ -> "read_reg_field" in + let ta = + if contains_bitvector_type t && not (contains_t_pp_var t) + then doc_tannot_lem regtypes true t else empty in + let epp = separate space [string call;string reg;string_lit(string field)] ^^ ta in + if aexp_needed then parens (align epp) else epp | Alias_pair(reg1,reg2) -> - let epp = - if has_rreg_effect eff then - separate space [string "read_two_regs";string reg1;string reg2] - else - separate space [string "RegisterPair";string reg1;string reg2] in - if aexp_needed then parens (align epp) else epp + let (call,ta) = + if has_rreg_effect eff then + let ta = + if contains_bitvector_type t && not (contains_t_pp_var t) + then doc_tannot_lem regtypes true t else empty in + ("read_two_regs", ta) + else + ("RegisterPair", empty) in + let epp = separate space [string call;string reg1;string reg2] ^^ ta in + if aexp_needed then parens (align epp) else epp | Alias_extract(reg,start,stop) -> - let epp = - if start = stop then - (separate space) - [string "access";doc_int start; - parens (string "read_reg" ^^ space ^^ string reg)] - else - (separate space) - [string "slice"; doc_int start; doc_int stop; - parens (string "read_reg" ^^ space ^^ string reg)] in - if aexp_needed then parens (align epp) else epp + let epp = + if start = stop then + separate space [string "read_reg_bit";string reg;doc_int start] + else + let ta = + if contains_bitvector_type t && not (contains_t_pp_var t) + then doc_tannot_lem regtypes true t else empty in + separate space [string "read_reg_range";string reg;doc_int start;doc_int stop] ^^ ta in + if aexp_needed then parens (align epp) else epp ) | _ -> doc_id_lem id) | E_lit lit -> doc_lit_lem false lit annot | E_cast(Typ_aux (typ,_),e) -> (match annot with - | Base(_,External _,_,_,_,_) -> string "read_reg" ^^ space ^^ expY e - | _ -> + | Base((_,t),External _,_,_,_,_) -> + let epp = string "read_reg" ^^ space ^^ expY e in + if contains_bitvector_type t && not (contains_t_pp_var t) + then parens (epp ^^ doc_tannot_lem regtypes true t) else epp + | Base((_,t),_,_,_,_,_) -> (match typ with | Typ_app (Id_aux (Id "vector",_), [Typ_arg_aux (Typ_arg_nexp(Nexp_aux (Nexp_constant i,_)),_);_;_;_]) -> - let epp = (concat [string "set_vector_start";space;string (string_of_int i)]) ^//^ + let call = + if is_bit_vector t then "set_bitvector_start" + else "set_vector_start" in + let epp = (concat [string call;space;string (string_of_int i)]) ^//^ expY e in if aexp_needed then parens epp else epp | Typ_var (Kid_aux (Var "length",_)) -> - let epp = (string "set_vector_start_to_length") ^//^ expY e in + let call = + if is_bit_vector t then "set_bitvector_start_to_length" + else "set_vector_start_to_length" in + let epp = (string call) ^//^ expY e in if aexp_needed then parens epp else epp | _ -> expV aexp_needed e)) (*(parens (doc_op colon (group (expY e)) (doc_typ_lem typ)))) *) @@ -543,8 +664,8 @@ let doc_exp_lem, doc_let_lem = (match annot with | Base((_,t),_,_,_,_,_) -> match t.t with - | Tapp("vector", [TA_nexp start; _; TA_ord order; _]) - | Tabbrev(_,{t= Tapp("vector", [TA_nexp start; _; TA_ord order; _])}) -> + | Tapp("vector", [TA_nexp start; TA_nexp len; TA_ord order; TA_typ etyp]) + | Tabbrev(_,{t= Tapp("vector", [TA_nexp start; TA_nexp len; TA_ord order; TA_typ etyp])}) -> let dir,dir_out = match order.order with | Oinc -> true,"true" | _ -> false, "false" in @@ -566,6 +687,13 @@ let doc_exp_lem, doc_let_lem = align (group expspp) in let epp = group (separate space [string "Vector"; brackets expspp;string start;string dir_out]) in + let (epp,aexp_needed) = + if etyp.t = Tid "bit" then + let bepp = string "vec_to_bvec" ^^ space ^^ parens (align epp) in + if contains_t_pp_var t + then (bepp, aexp_needed) + else (bepp ^^ doc_tannot_lem regtypes false t, true) + else (epp,aexp_needed) in if aexp_needed then parens (align epp) else epp ) | E_vector_indexed (iexps, (Def_val_aux (default,(dl,dannot)))) -> @@ -621,12 +749,20 @@ let doc_exp_lem, doc_let_lem = let epp = align (group (call ^//^ brackets expspp ^/^ separate space [default_string;string start;string size;string dir_out])) in - if aexp_needed then parens (align epp) else epp + let (bepp, aexp_needed) = + if is_bit_vector t + then (string "vec_to_bvec" ^^ space ^^ parens (epp) ^^ doc_tannot_lem regtypes false t, true) + else (epp, aexp_needed) in + if aexp_needed then parens (align bepp) else bepp | E_vector_update(v,e1,e2) -> - let epp = separate space [string "update_pos";expY v;expY e1;expY e2] in + let (Base((_,t),_,_,_,_,_)) = annot in + let call = if is_bit_vector t then "bvupdate_pos" else "update_pos" in + let epp = separate space [string call;expY v;expY e1;expY e2] in if aexp_needed then parens (align epp) else epp | E_vector_update_subrange(v,e1,e2,e3) -> - let epp = align (string "update" ^//^ + let (Base((_,t),_,_,_,_,_)) = annot in + let call = if is_bit_vector t then "bvupdate" else "update" in + let epp = align (string call ^//^ group (group (expY v) ^/^ group (expY e1) ^/^ group (expY e2)) ^/^ group (expY e3)) in if aexp_needed then parens (align epp) else epp @@ -664,9 +800,13 @@ let doc_exp_lem, doc_let_lem = (match annot with | Base((_,t),External(Some name),_,_,_,_) -> let argpp arg = - let (E_aux (_,(_,Base((_,{t=t}),_,_,_,_,_)))) = arg in - match t with - | Tapp("vector",_) -> parens (concat [string "reset_vector_start";space;expY arg]) + let (E_aux (_,(_,Base((_,t),_,_,_,_,_)))) = arg in + match t.t with + | Tapp("vector",_) -> + let call = + if is_bit_vector t then "reset_bitvector_start" + else "reset_vector_start" in + parens (concat [string call;space;expY arg]) | _ -> expY arg in let epp = let aux name = align (argpp e1 ^^ space ^^ string name ^//^ argpp e2) in @@ -734,6 +874,10 @@ let doc_exp_lem, doc_let_lem = | _ -> string name ^//^ parens (expN e1 ^^ comma ^/^ expN e2)) in + let (epp,aexp_needed) = + if contains_bitvector_type t && not (contains_t_pp_var t) + then (parens epp ^^ doc_tannot_lem regtypes false t, true) + else (epp, aexp_needed) in if aexp_needed then parens (align epp) else epp | _ -> let epp = diff --git a/src/pretty_print_ocaml.ml b/src/pretty_print_ocaml.ml index 3772f549..adca6b12 100644 --- a/src/pretty_print_ocaml.ml +++ b/src/pretty_print_ocaml.ml @@ -338,10 +338,7 @@ let doc_exp_ocaml, doc_let_ocaml = | Typ_var (Kid_aux (Var "length",_)) -> parens ((string "set_start_to_length") ^//^ exp e) | _ -> - parens (doc_op colon (group (exp e)) (doc_typ_ocaml typ))) - - -) + parens (doc_op colon (group (exp e)) (doc_typ_ocaml typ)))) | E_tuple exps -> parens (separate_map comma exp exps) | E_record(FES_aux(FES_Fexps(fexps,_),_)) -> @@ -753,4 +750,3 @@ let pp_defs_ocaml f d top_line opens = print f (string "(*" ^^ (string top_line) ^^ string "*)" ^/^ (separate_map hardline (fun lib -> (string "open") ^^ space ^^ (string lib)) opens) ^/^ (doc_defs_ocaml d)) - diff --git a/src/rewriter.ml b/src/rewriter.ml index d26879e9..8e120dda 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -63,15 +63,6 @@ type 'a rewriters = { let (>>) f g = fun x -> g(f(x)) -let fresh_name_counter = ref 0 - -let fresh_name () = - let current = !fresh_name_counter in - let () = fresh_name_counter := (current + 1) in - current -let reset_fresh_name_counter () = - fresh_name_counter := 0 - let get_effsum_annot (_,t) = match t with | Base (_,_,_,_,effs,_) -> effs | NoTyp -> failwith "no effect information" @@ -89,6 +80,31 @@ let get_type_annot (_,t) = match t with let get_type (E_aux (_,a)) = get_type_annot a +let get_loc (E_aux (_,(l,_))) = l + +let fresh_name_counter = ref 0 + +let fresh_name () = + let current = !fresh_name_counter in + let () = fresh_name_counter := (current + 1) in + current +let reset_fresh_name_counter () = + fresh_name_counter := 0 + +let fresh_id pre l = + let current = fresh_name () in + Id_aux (Id (pre ^ string_of_int current), Parse_ast.Generated l) + +let fresh_id_exp pre ((l,_) as annot) = + let id = fresh_id pre l in + let annot_var = (Parse_ast.Generated l,simple_annot (get_type_annot annot)) in + E_aux (E_id id, annot_var) + +let fresh_id_pat pre ((l,_) as annot) = + let id = fresh_id pre l in + let annot_var = (Parse_ast.Generated l,simple_annot (get_type_annot annot)) in + P_aux (P_id id, annot_var) + let union_effs effs = List.fold_left (fun acc eff -> union_effects acc eff) pure_e effs @@ -210,6 +226,11 @@ let updates_vars_effs {effect = Eset effs} = let updates_vars eaux = updates_vars_effs (get_effsum_exp eaux) +let id_to_string (Id_aux(id,l)) = + match id with + | Id(s) -> s + | DeIid(s) -> s + let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option = match ls with | [] -> None @@ -217,6 +238,10 @@ let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b let mk_atom_typ i = {t=Tapp("atom",[TA_nexp i])} +let simple_num l n : tannot exp = + let typ = simple_annot (mk_atom_typ (mk_c (big_int_of_int n))) in + E_aux (E_lit (L_aux (L_num n,l)), (l,typ)) + let rec rewrite_nexp_to_exp program_vars l nexp = let rewrite n = rewrite_nexp_to_exp program_vars l n in let typ = mk_atom_typ nexp in @@ -832,10 +857,8 @@ let remove_vector_concat_pat pat = let pat = remove_typed_patterns pat in - let fresh_name l = - let current = fresh_name () in - Id_aux (Id ("v__" ^ string_of_int current), Parse_ast.Generated l) in - + let fresh_id_v = fresh_id "v__" in + (* expects that P_typ elements have been removed from AST, that the length of all vectors involved is known, that we don't have indexed vectors *) @@ -860,7 +883,7 @@ let remove_vector_concat_pat pat = | P_vector_concat pats -> (if contained_in_p_as then P_aux (pat,annot) - else P_aux (P_as (P_aux (pat,annot),fresh_name l),annot)) + else P_aux (P_as (P_aux (pat,annot),fresh_id_v l),annot)) | _ -> P_aux (pat,annot) ) ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) @@ -873,7 +896,7 @@ let remove_vector_concat_pat pat = let name_vector_concat_elements = let p_vector_concat pats = let aux ((P_aux (p,((l,_) as a))) as pat) = match p with - | P_vector _ -> P_aux (P_as (pat,fresh_name l),a) + | P_vector _ -> P_aux (P_as (pat,fresh_id_v l),a) | P_id id -> P_aux (P_id id,a) | P_as (p,id) -> P_aux (P_as (p,id),a) | P_wild -> P_aux (P_wild,a) @@ -908,17 +931,13 @@ let remove_vector_concat_pat pat = let (Id_aux (Id rootname,_)) = rootid in let (Id_aux (Id childname,_)) = child in - let simple_num n : tannot exp = - let typ = simple_annot (mk_atom_typ (mk_c (big_int_of_int n))) in - E_aux (E_lit (L_aux (L_num n,l)), (l,typ)) in - let vlength_info (Base ((_,{t = Tapp("vector",[_;TA_nexp nexp;_;_])}),_,_,_,_,_)) = nexp in let root : tannot exp = E_aux (E_id rootid,rannot) in - let index_i = simple_num i in + let index_i = simple_num l i in let index_j : tannot exp = match j with - | Some j -> simple_num j + | Some j -> simple_num l j | None -> let length_root_nexp = vlength_info (snd rannot) in let length_app_exp : tannot exp = @@ -950,41 +969,31 @@ let remove_vector_concat_pat pat = let p_aux = function | ((P_as (P_aux (P_vector_concat pats,rannot'),rootid),decls),rannot) -> let aux (pos,pat_acc,decl_acc) (P_aux (p,cannot),is_last) = match cannot with - | (_,Base((_,{t = Tapp ("vector",[_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_,_)) - | (_,Base((_,{t = Tabbrev (_,{t = Tapp ("vector",[_;TA_nexp {nexp = Nconst length};_;_])})}),_,_,_,_,_)) -> - let length = int_of_big_int length in + | (l,Base((_,({t = Tapp ("vector",[_;TA_nexp length;_;_])} as t)),_,_,_,_,_)) + | (l,Base((_,({t = Tabbrev (_,{t = Tapp ("vector",[_;TA_nexp length;_;_])})} as t)),_,_,_,_,_)) -> + let (pos',index_j) = match has_const_vector_length t with + | Some i -> + let length = int_of_big_int i in + (pos+length, Some(pos+length-1)) + | None -> + if is_last then (pos,None) + else + raise + (Reporting_basic.err_unreachable + l ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern")) in (match p with (* if we see a named vector pattern, remove the name and remember to declare it later *) | P_as (P_aux (p,cannot),cname) -> - let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,Some(pos+length-1)) in - (pos + length, pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)]) + let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in + (pos', pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)]) (* if we see a P_id variable, remember to declare it later *) | P_id cname -> - let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,Some(pos+length-1)) in - (pos + length, pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)]) + let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in + (pos', pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)]) (* normal vector patterns are fine *) - | _ -> (pos + length, pat_acc @ [P_aux (p,cannot)],decl_acc) ) + | _ -> (pos', pat_acc @ [P_aux (p,cannot)],decl_acc) ) (* non-vector patterns aren't *) - | (l,Base((_,{t = Tapp ("vector",[_;_;_;_])}),_,_,_,_,_)) - | (l,Base((_,{t = Tabbrev (_,{t = Tapp ("vector",[_;_;_;_])})}),_,_,_,_,_)) -> - if is_last then - match p with - (* if we see a named vector pattern, remove the name and remember to - declare it later *) - | P_as (P_aux (p,cannot),cname) -> - let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,None) in - (pos, pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)]) - (* if we see a P_id variable, remember to declare it later *) - | P_id cname -> - let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,None) in - (pos, pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)]) - (* normal vector patterns are fine *) - | _ -> (pos, pat_acc @ [P_aux (p,cannot)],decl_acc) - else - raise - (Reporting_basic.err_unreachable - l ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern")) | (l,Base((_,t),_,_,_,_,_)) -> raise (Reporting_basic.err_unreachable @@ -1162,7 +1171,15 @@ let rewrite_fun_remove_vector_concat_pat (FCL_aux (FCL_Funcl (id,pat,rewriters.rewrite_exp rewriters None (decls exp)),(l,annot))) in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot)) -let rewrite_defs_remove_vector_concat_pat rewriters (Defs defs) = +let rewrite_defs_remove_vector_concat (Defs defs) = + let rewriters = + {rewrite_exp = rewrite_exp_remove_vector_concat_pat; + rewrite_pat = rewrite_pat; + rewrite_let = rewrite_let; + rewrite_lexp = rewrite_lexp; + rewrite_fun = rewrite_fun_remove_vector_concat_pat; + rewrite_def = rewrite_def; + rewrite_defs = rewrite_defs_base} in let rewrite_def d = let d = rewriters.rewrite_def rewriters d in match d with @@ -1174,18 +1191,451 @@ let rewrite_defs_remove_vector_concat_pat rewriters (Defs defs) = let (pat,letbinds,_) = remove_vector_concat_pat pat in let defvals = List.map (fun lb -> DEF_val lb) letbinds in [DEF_val (LB_aux (LB_val_implicit (pat,exp),a))] @ defvals - | d -> [rewriters.rewrite_def rewriters d] in + | d -> [d] in Defs (List.flatten (List.map rewrite_def defs)) -let rewrite_defs_remove_vector_concat defs = rewrite_defs_base - {rewrite_exp = rewrite_exp_remove_vector_concat_pat; +let map_default f = function +| None -> None +| Some x -> f x + +let rec binop_opt f x y = match x, y with +| None, None -> None +| Some x, None -> Some x +| None, Some y -> Some y +| Some x, Some y -> Some (f x y) + +let rec contains_bitvector_pat (P_aux (pat,annot)) = match pat with +| P_lit _ | P_wild | P_id _ -> false +| P_as (pat,_) | P_typ (_,pat) -> contains_bitvector_pat pat +| P_vector _ | P_vector_concat _ | P_vector_indexed _ -> + is_bit_vector (get_type_annot annot) +| P_app (_,pats) | P_tup pats | P_list pats -> + List.exists contains_bitvector_pat pats +| P_record (fpats,_) -> + List.exists (fun (FP_aux (FP_Fpat (_,pat),_)) -> contains_bitvector_pat pat) fpats + +let remove_bitvector_pat pat = + + (* first introduce names for bitvector patterns *) + let name_bitvector_roots = + { p_lit = (fun lit -> P_lit lit) + ; p_typ = (fun (typ,p) -> P_typ (typ,p false)) + ; p_wild = P_wild + ; p_as = (fun (pat,id) -> P_as (pat true,id)) + ; p_id = (fun id -> P_id id) + ; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps)) + ; p_record = (fun (fpats,b) -> P_record (fpats, b)) + ; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps)) + ; p_vector_indexed = (fun ps -> P_vector_indexed (List.map (fun (i,p) -> (i,p false)) ps)) + ; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps)) + ; p_tup = (fun ps -> P_tup (List.map (fun p -> p false) ps)) + ; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps)) + ; p_aux = + (fun (pat,annot) contained_in_p_as -> + let t = get_type_annot annot in + let (l,_) = annot in + match pat, is_bit_vector t, contained_in_p_as with + | P_vector _, true, false + | P_vector_indexed _, true, false -> + P_aux (P_as (P_aux (pat,annot),fresh_id "b__" l), annot) + | _ -> P_aux (pat,annot) + ) + ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) + ; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false)) + } in + let pat = (fold_pat name_bitvector_roots pat) false in + + (* Then collect guard expressions testing whether the literal bits of a + bitvector pattern match those of a given bitvector, and collect let + bindings for the bits bound by P_id or P_as patterns *) + + (* Helper functions for calculating vector indices *) + let vec_ord t = match (normalize_t t).t with + | Tapp("vector", [_;_;TA_ord {order = ord}; _]) -> ord + | _ -> Oinc (* TODO Use default order *) in + + let vec_is_inc t = match vec_ord t with Oinc -> true | _ -> false in + + let vec_start t = match (normalize_t t).t with + | Tapp("vector", [TA_nexp {nexp = Nconst i};_;_; _]) -> int_of_big_int i + | _ -> 0 in + + let vec_length t = match (normalize_t t).t with + | Tapp("vector", [_;TA_nexp {nexp = Nconst j};_; _]) -> int_of_big_int j + | _ -> 0 in + + (* Helper functions for generating guard expressions *) + let bit_annot l = (Parse_ast.Generated l, simple_annot {t = Tid "bit"}) in + + let access_bit_exp (rootid,rannot) l idx = + let root : tannot exp = E_aux (E_id rootid,rannot) in + E_aux (E_vector_access (root,simple_num l idx), bit_annot l) in + + let test_bit_exp rootid l t idx exp = + let rannot = (Parse_ast.Generated l, simple_annot t) in + let elem = access_bit_exp (rootid,rannot) l idx in + let eqid = Id_aux (Id "==", Parse_ast.Generated l) in + let eqannot = (Parse_ast.Generated l, + tag_annot {t = Tid "bit"} (External (Some "eq_bit"))) in + let eqexp : tannot exp = E_aux (E_app_infix(elem,eqid,exp), eqannot) in + Some (eqexp) in + + let test_subvec_exp rootid l t i j lits = + let l' = Parse_ast.Generated l in + let t' = mk_vector {t = Tid "bit"} {order = vec_ord t} + (mk_c_int i) (mk_c_int (List.length lits)) in + let subvec_exp = + if vec_start t = i && vec_length t = List.length lits + then E_id rootid + else E_vector_subrange ( + E_aux (E_id rootid, (l', simple_annot t)), + simple_num l i, + simple_num l j) in + E_aux (E_app_infix( + E_aux (subvec_exp, (l', simple_annot t')), + Id_aux (Id "==", l'), + E_aux (E_vector lits, (l', simple_annot t'))), + (l', tag_annot {t = Tid "bit"} (External (Some "eq_vec")))) in + + let letbind_bit_exp rootid l t idx id = + let rannot = (Parse_ast.Generated l, simple_annot t) in + let elem = access_bit_exp (rootid,rannot) l idx in + let e = P_aux (P_id id, bit_annot l) in + let letbind = LB_aux (LB_val_implicit (e,elem), bit_annot l) in + let letexp = (fun body -> + let (E_aux (_,(_,bannot))) = body in + E_aux (E_let (letbind,body), (Parse_ast.Generated l, bannot))) in + (letexp, letbind) in + + (* Helper functions for composing guards *) + let bitwise_and exp1 exp2 = + let (E_aux (_,(l,_))) = exp1 in + let andid = Id_aux (Id "&", Parse_ast.Generated l) in + let andannot = (Parse_ast.Generated l, + tag_annot {t = Tid "bit"} (External (Some "bitwise_and_bit"))) in + E_aux (E_app_infix(exp1,andid,exp2), andannot) in + + let compose_guards guards = + List.fold_right (binop_opt bitwise_and) guards None in + + let flatten_guards_decls gd = + let (guards,decls,letbinds) = Util.split3 gd in + (compose_guards guards, (List.fold_right (@@) decls), List.flatten letbinds) in + + (* Collect guards and let bindings *) + let guard_bitvector_pat = + let collect_guards_decls ps rootid t = + let rec collect current (guards,dls) idx ps = + let idx' = if vec_is_inc t then idx + 1 else idx - 1 in + (match ps with + | pat :: ps' -> + (match pat with + | P_aux (P_lit lit, (l,annot)) -> + let e = E_aux (E_lit lit, (Parse_ast.Generated l, annot)) in + let current' = (match current with + | Some (l,i,j,lits) -> Some (l,i,idx,lits @ [e]) + | None -> Some (l,idx,idx,[e])) in + collect current' (guards, dls) idx' ps' + | P_aux (P_as (pat',id), (l,annot)) -> + let dl = letbind_bit_exp rootid l t idx id in + collect current (guards, dls @ [dl]) idx (pat' :: ps') + | _ -> + let dls' = (match pat with + | P_aux (P_id id, (l,annot)) -> + dls @ [letbind_bit_exp rootid l t idx id] + | _ -> dls) in + let guards' = (match current with + | Some (l,i,j,lits) -> + guards @ [Some (test_subvec_exp rootid l t i j lits)] + | None -> guards) in + collect None (guards', dls') idx' ps') + | [] -> + let guards' = (match current with + | Some (l,i,j,lits) -> + guards @ [Some (test_subvec_exp rootid l t i j lits)] + | None -> guards) in + (guards',dls)) in + let (guards,dls) = collect None ([],[]) (vec_start t) ps in + let (decls,letbinds) = List.split dls in + (compose_guards guards, List.fold_right (@@) decls, letbinds) in + + let collect_guards_decls_indexed ips rootid t = + let rec guard_decl (idx,pat) = (match pat with + | P_aux (P_lit lit, (l,annot)) -> + let exp = E_aux (E_lit lit, (l,annot)) in + (test_bit_exp rootid l t idx exp, (fun b -> b), []) + | P_aux (P_as (pat',id), (l,annot)) -> + let (guard,decls,letbinds) = guard_decl (idx,pat') in + let (letexp,letbind) = letbind_bit_exp rootid l t idx id in + (guard, decls >> letexp, letbind :: letbinds) + | P_aux (P_id id, (l,annot)) -> + let (letexp,letbind) = letbind_bit_exp rootid l t idx id in + (None, letexp, [letbind]) + | _ -> (None, (fun b -> b), [])) in + let (guards,decls,letbinds) = Util.split3 (List.map guard_decl ips) in + (compose_guards guards, List.fold_right (@@) decls, List.flatten letbinds) in + + { p_lit = (fun lit -> (P_lit lit, (None, (fun b -> b), []))) + ; p_wild = (P_wild, (None, (fun b -> b), [])) + ; p_as = (fun ((pat,gdls),id) -> (P_as (pat,id), gdls)) + ; p_typ = (fun (typ,(pat,gdls)) -> (P_typ (typ,pat), gdls)) + ; p_id = (fun id -> (P_id id, (None, (fun b -> b), []))) + ; p_app = (fun (id,ps) -> let (ps,gdls) = List.split ps in + (P_app (id,ps), flatten_guards_decls gdls)) + ; p_record = (fun (ps,b) -> let (ps,gdls) = List.split ps in + (P_record (ps,b), flatten_guards_decls gdls)) + ; p_vector = (fun ps -> let (ps,gdls) = List.split ps in + (P_vector ps, flatten_guards_decls gdls)) + ; p_vector_indexed = (fun p -> let (is,p) = List.split p in + let (ps,gdls) = List.split p in + let ps = List.combine is ps in + (P_vector_indexed ps, flatten_guards_decls gdls)) + ; p_vector_concat = (fun ps -> let (ps,gdls) = List.split ps in + (P_vector_concat ps, flatten_guards_decls gdls)) + ; p_tup = (fun ps -> let (ps,gdls) = List.split ps in + (P_tup ps, flatten_guards_decls gdls)) + ; p_list = (fun ps -> let (ps,gdls) = List.split ps in + (P_list ps, flatten_guards_decls gdls)) + ; p_aux = (fun ((pat,gdls),annot) -> + let t = get_type_annot annot in + (match pat, is_bit_vector t with + | P_as (P_aux (P_vector ps, _), id), true -> + (P_aux (P_id id, annot), collect_guards_decls ps id t) + | P_as (P_aux (P_vector_indexed ips, _), id), true -> + (P_aux (P_id id, annot), collect_guards_decls_indexed ips id t) + | _, _ -> (P_aux (pat,annot), gdls))) + ; fP_aux = (fun ((fpat,gdls),annot) -> (FP_aux (fpat,annot), gdls)) + ; fP_Fpat = (fun (id,(pat,gdls)) -> (FP_Fpat (id,pat), gdls)) + } in + fold_pat guard_bitvector_pat pat + +let remove_wildcards pre (P_aux (_,(l,_)) as pat) = + fold_pat + {id_pat_alg with + p_aux = function + | (P_wild,(l,annot)) -> P_aux (P_id (fresh_id pre l),(l,annot)) + | (p,annot) -> P_aux (p,annot) } + pat + +(* Check if one pattern subsumes the other, and if so, calculate a + substitution of variables that are used in the same position. + TODO: Check somewhere that there are no variable clashes (the same variable + name used in different positions of the patterns) + *) +let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,_) as pat2) = + let rewrap p = P_aux (p,annot1) in + let subsumes_list s pats1 pats2 = + if List.length pats1 = List.length pats2 + then + let subs = List.map2 s pats1 pats2 in + List.fold_right + (fun p acc -> match p, acc with + | Some subst, Some substs -> Some (subst @ substs) + | _ -> None) + subs (Some []) + else None in + match p1, p2 with + | P_lit (L_aux (lit1,_)), P_lit (L_aux (lit2,_)) -> + if lit1 = lit2 then Some [] else None + | P_as (pat1,_), _ -> subsumes_pat pat1 pat2 + | _, P_as (pat2,_) -> subsumes_pat pat1 pat2 + | P_typ (_,pat1), _ -> subsumes_pat pat1 pat2 + | _, P_typ (_,pat2) -> subsumes_pat pat1 pat2 + | P_id (Id_aux (id1,_) as aid1), P_id (Id_aux (id2,_) as aid2) -> + if id1 = id2 then Some [] else Some [(id2,id1)] + | P_id id1, _ -> Some [] + | P_wild, _ -> Some [] + | P_app (Id_aux (id1,l1),args1), P_app (Id_aux (id2,_),args2) -> + if id1 = id2 then subsumes_list subsumes_pat args1 args2 else None + | P_record (fps1,b1), P_record (fps2,b2) -> + if b1 = b2 then subsumes_list subsumes_fpat fps1 fps2 else None + | P_vector pats1, P_vector pats2 + | P_vector_concat pats1, P_vector_concat pats2 + | P_tup pats1, P_tup pats2 + | P_list pats1, P_list pats2 -> + subsumes_list subsumes_pat pats1 pats2 + | P_vector_indexed ips1, P_vector_indexed ips2 -> + let (is1,ps1) = List.split ips1 in + let (is2,ps2) = List.split ips2 in + if is1 = is2 then subsumes_list subsumes_pat ps1 ps2 else None + | _ -> None +and subsumes_fpat (FP_aux (FP_Fpat (id1,pat1),_)) (FP_aux (FP_Fpat (id2,pat2),_)) = + if id1 = id2 then subsumes_pat pat1 pat2 else None + +let equiv_pats pat1 pat2 = + match subsumes_pat pat1 pat2, subsumes_pat pat2 pat1 with + | Some _, Some _ -> true + | _, _ -> false + +let subst_id_pat pat (id1,id2) = + let p_id (Id_aux (id,l)) = (if id = id1 then P_id (Id_aux (id2,l)) else P_id (Id_aux (id,l))) in + fold_pat {id_pat_alg with p_id = p_id} pat + +let subst_id_exp exp (id1,id2) = + (* TODO Don't substitute bound occurrences inside let expressions etc *) + let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in + fold_exp {id_exp_alg with e_id = e_id} exp + +let gen_annot l t efr = (Parse_ast.Generated l,simple_annot_efr t efr) + +let rec pat_to_exp (P_aux (pat,(l,annot))) = + let rewrap e = E_aux (e,(l,annot)) in + match pat with + | P_lit lit -> rewrap (E_lit lit) + | P_wild -> raise (Reporting_basic.err_unreachable l + "pat_to_exp given wildcard pattern") + | P_as (pat,id) -> rewrap (E_id id) + | P_typ (_,pat) -> pat_to_exp pat + | P_id id -> rewrap (E_id id) + | P_app (id,pats) -> rewrap (E_app (id, List.map pat_to_exp pats)) + | P_record (fpats,b) -> + rewrap (E_record (FES_aux (FES_Fexps (List.map fpat_to_fexp fpats,b),(l,annot)))) + | P_vector pats -> rewrap (E_vector (List.map pat_to_exp pats)) + | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l + "pat_to_exp not implemented for P_vector_concat") + (* We assume that vector concatenation patterns have been transformed + away already *) + | P_tup pats -> rewrap (E_tuple (List.map pat_to_exp pats)) + | P_list pats -> rewrap (E_list (List.map pat_to_exp pats)) + | P_vector_indexed ipats -> raise (Reporting_basic.err_unreachable l + "pat_to_exp not implemented for P_vector_indexed") + (* TODO: We can't guess the default value for the indexed vector + expression here. We should make sure that indexed vector patterns are + bound to a variable via P_as before calling pat_to_exp *) +and fpat_to_fexp (FP_aux (FP_Fpat (id,pat),(l,annot))) = + FE_aux (FE_Fexp (id, pat_to_exp pat),(l,annot)) + +let case_exp e t cs = + let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in + let ps = List.map pexp cs in + (* let efr = union_effs (List.map get_effsum_pexp ps) in *) + fix_effsum_exp (E_aux (E_case (e,ps), gen_annot (get_loc e) t pure_e)) + +let rewrite_guarded_clauses l cs = + let rec group clauses = + let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in + let rec group_aux current acc = (function + | ((pat,guard,body,annot) as c) :: cs -> + let (current_pat,_,_) = current in + (match subsumes_pat current_pat pat with + | Some substs -> + let pat' = List.fold_left subst_id_pat pat substs in + let guard' = (match guard with + | Some exp -> Some (List.fold_left subst_id_exp exp substs) + | None -> None) in + let body' = List.fold_left subst_id_exp body substs in + let c' = (pat',guard',body',annot) in + group_aux (add_clause current c') acc cs + | None -> + let pat = remove_wildcards "g__" pat in + group_aux (pat,[c],annot) (acc @ [current]) cs) + | [] -> acc @ [current]) in + let groups = match clauses with + | ((pat,guard,body,annot) as c) :: cs -> + group_aux (remove_wildcards "g__" pat, [c], annot) [] cs + | _ -> + raise (Reporting_basic.err_unreachable l + "group given empty list in rewrite_guarded_clauses") in + List.map (fun cs -> if_pexp cs) groups + and if_pexp (pat,cs,annot) = (match cs with + | c :: _ -> + (* fix_effsum_pexp (pexp *) + let body = if_exp pat cs in + let pexp = fix_effsum_pexp (Pat_aux (Pat_exp (pat,body),annot)) in + let (Pat_aux (Pat_exp (_,_),annot)) = pexp in + (pat, body, annot) + | [] -> + raise (Reporting_basic.err_unreachable l + "if_pexp given empty list in rewrite_guarded_clauses")) + and if_exp current_pat = (function + | (pat,guard,body,annot) :: ((pat',guard',body',annot') as c') :: cs -> + (match guard with + | Some exp -> + let else_exp = + if equiv_pats current_pat pat' + then if_exp current_pat (c' :: cs) + else case_exp (pat_to_exp current_pat) (get_type_annot annot') (group (c' :: cs)) in + fix_effsum_exp (E_aux (E_if (exp,body,else_exp), annot)) + | None -> body) + | [(pat,guard,body,annot)] -> body + | [] -> + raise (Reporting_basic.err_unreachable l + "if_exp given empty list in rewrite_guarded_clauses")) in + group cs + +let rewrite_exp_remove_bitvector_pat rewriters nmap (E_aux (exp,(l,annot)) as full_exp) = + let rewrap e = E_aux (e,(l,annot)) in + let rewrite_rec = rewriters.rewrite_exp rewriters nmap in + let rewrite_base = rewrite_exp rewriters nmap in + match exp with + | E_case (e,ps) + when List.exists (fun (Pat_aux (Pat_exp (pat,_),_)) -> contains_bitvector_pat pat) ps -> + let clause (Pat_aux (Pat_exp (pat,body),annot')) = + let (pat',(guard,decls,_)) = remove_bitvector_pat pat in + let body' = decls (rewrite_rec body) in + (pat',guard,body',annot') in + let clauses = rewrite_guarded_clauses l (List.map clause ps) in + if (effectful e) then + let e = rewrite_rec e in + let (E_aux (_,(el,eannot))) = e in + let pat_e' = fresh_id_pat "p__" (el,eannot) in + let exp_e' = pat_to_exp pat_e' in + (* let fresh = fresh_id "p__" el in + let exp_e' = E_aux (E_id fresh, gen_annot l (get_type e) pure_e) in + let pat_e' = P_aux (P_id fresh, gen_annot l (get_type e) pure_e) in *) + let letbind_e = LB_aux (LB_val_implicit (pat_e',e), gen_annot l (get_type e) (get_effsum_exp e)) in + let exp' = case_exp exp_e' (get_type full_exp) clauses in + rewrap (E_let (letbind_e, exp')) + else case_exp e (get_type full_exp) clauses + | E_let (LB_aux (LB_val_explicit (typ,pat,v),annot'),body) -> + let (pat,(_,decls,_)) = remove_bitvector_pat pat in + rewrap (E_let (LB_aux (LB_val_explicit (typ,pat,rewrite_rec v),annot'), + decls (rewrite_rec body))) + | E_let (LB_aux (LB_val_implicit (pat,v),annot'),body) -> + let (pat,(_,decls,_)) = remove_bitvector_pat pat in + rewrap (E_let (LB_aux (LB_val_implicit (pat,rewrite_rec v),annot'), + decls (rewrite_rec body))) + | _ -> rewrite_base full_exp + +let rewrite_fun_remove_bitvector_pat + rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) = + let _ = reset_fresh_name_counter () in + (* TODO Can there be clauses with different id's in one FD_function? *) + let funcls = match funcls with + | (FCL_aux (FCL_Funcl(id,_,_),_) :: _) -> + let clause (FCL_aux (FCL_Funcl(_,pat,exp),annot)) = + let (pat,(guard,decls,_)) = remove_bitvector_pat pat in + let exp = decls (rewriters.rewrite_exp rewriters None exp) in + (pat,guard,exp,annot) in + let cs = rewrite_guarded_clauses l (List.map clause funcls) in + List.map (fun (pat,exp,annot) -> FCL_aux (FCL_Funcl(id,pat,exp),annot)) cs + | _ -> funcls (* TODO is the empty list possible here? *) in + FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot)) + +let rewrite_defs_remove_bitvector_pats (Defs defs) = + let rewriters = + {rewrite_exp = rewrite_exp_remove_bitvector_pat; rewrite_pat = rewrite_pat; rewrite_let = rewrite_let; rewrite_lexp = rewrite_lexp; - rewrite_fun = rewrite_fun_remove_vector_concat_pat; + rewrite_fun = rewrite_fun_remove_bitvector_pat; rewrite_def = rewrite_def; - rewrite_defs = rewrite_defs_remove_vector_concat_pat} defs - + rewrite_defs = rewrite_defs_base } in + let rewrite_def d = + let d = rewriters.rewrite_def rewriters d in + match d with + | DEF_val (LB_aux (LB_val_explicit (t,pat,exp),a)) -> + let (pat',(_,_,letbinds)) = remove_bitvector_pat pat in + let defvals = List.map (fun lb -> DEF_val lb) letbinds in + [DEF_val (LB_aux (LB_val_explicit (t,pat',exp),a))] @ defvals + | DEF_val (LB_aux (LB_val_implicit (pat,exp),a)) -> + let (pat',(_,_,letbinds)) = remove_bitvector_pat pat in + let defvals = List.map (fun lb -> DEF_val lb) letbinds in + [DEF_val (LB_aux (LB_val_implicit (pat',exp),a))] @ defvals + | d -> [d] in + Defs (List.flatten (List.map rewrite_def defs)) + + (*Expects to be called after rewrite_defs; thus the following should not appear: internal_exp of any form lit vectors in patterns or expressions @@ -1384,12 +1834,6 @@ let rewrite_defs_remove_blocks = -let fresh_id ((l,_) as annot) = - let current = fresh_name () in - let id = Id_aux (Id ("w__" ^ string_of_int current), Parse_ast.Generated l) in - let annot_var = (Parse_ast.Generated l,simple_annot (get_type_annot annot)) in - E_aux (E_id id, annot_var) - let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = (* body is a function : E_id variable -> actual body *) match get_type v with @@ -1405,7 +1849,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) | _ -> let (E_aux (_,((l,_) as annot))) = v in - let ((E_aux (E_id id,_)) as e_id) = fresh_id annot in + let ((E_aux (E_id id,_)) as e_id) = fresh_id_exp "w__" annot in let body = body e_id in let annot_pat = (Parse_ast.Generated l,simple_annot (get_type v)) in @@ -2146,6 +2590,7 @@ let rewrite_defs_remove_e_assign = let rewrite_defs_lem = top_sort_defs >> rewrite_defs_remove_vector_concat >> + rewrite_defs_remove_bitvector_pats >> rewrite_defs_exp_lift_assign >> rewrite_defs_remove_blocks >> rewrite_defs_letbind_effects >> @@ -2154,4 +2599,3 @@ let rewrite_defs_lem = rewrite_defs_remove_superfluous_letbinds >> rewrite_defs_remove_superfluous_returns - diff --git a/src/sail.ml b/src/sail.ml index 452e6e72..41c42fe4 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -53,6 +53,7 @@ let opt_print_ocaml = ref false let opt_libs_lem = ref ([]:string list) let opt_libs_ocaml = ref ([]:string list) let opt_file_arguments = ref ([]:string list) +let opt_mono_split = ref ([]:((string * int) * string) list) let options = Arg.align ([ ( "-o", Arg.String (fun f -> opt_file_out := Some f), @@ -86,6 +87,13 @@ let options = Arg.align ([ ( "-skip_constraints", Arg.Clear Type_internal.do_resolve_constraints, " (debug) skip constraint resolution in type-checking"); + ( "-mono-split", + Arg.String (fun s -> + let l = String.split_on_char ':' s in + match l with + | [fname;line;var] -> opt_mono_split := ((fname,int_of_string line),var)::!opt_mono_split + | _ -> raise (Arg.Bad (s ^ " not of form <filename>:<line>:<variable>"))), + "<filename>:<line>:<variable> to case split for monomorphisation"); ( "-new_typecheck", Arg.Set opt_new_typecheck, " (experimental) use new typechecker with Z3 constraint solving"); @@ -141,6 +149,12 @@ let main() = -> Parse_ast.Defs (ast_nodes@later_nodes)) parsed (Parse_ast.Defs []) in let (ast,kenv,ord) = convert_ast ast in let (ast,type_envs) = check_ast ast kenv ord in + + let ast = match !opt_mono_split with + | [] -> ast + | l -> Monomorphise.split_defs l type_envs ast + in + let ast = rewrite_ast ast in let out_name = match !opt_file_out with | None -> fst (List.hd parsed) diff --git a/src/type_check.ml b/src/type_check.ml index bc6d67a8..73520825 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -1593,7 +1593,7 @@ let rec check_exp envs (imp_param:nexp option) (widen_num:bool) (widen_vec:bool) let (pexps',t,cs',ef') = check_cases envs imp_param ret_t t' expect_t (if (List.length pexps) = 1 then Solo else Switch) pexps in let effects = union_effects ef ef' in - (E_aux(E_case(e',pexps'),(l,simple_annot_efr t effects)),t, + (E_aux(E_case(e',pexps'),(l,simple_annot_efr expect_t effects)),t, t_env,cs@[BranchCons(Expr l, None, cs')],nob,effects) | E_let(lbind,body) -> let (lb',t_env',cs,b_env',ef) = (check_lbind envs imp_param false (Some ret_t) Emp_local lbind) in diff --git a/src/type_internal.ml b/src/type_internal.ml index 2932cf33..5df5e94d 100644 --- a/src/type_internal.ml +++ b/src/type_internal.ml @@ -560,15 +560,6 @@ let rec pow_i i n = | n -> mult_int_big_int i (pow_i i (n-1)) let two_pow = pow_i 2 -let is_bit_vector t = match t.t with - | Tapp("vector", [_;_;_; TA_typ t]) - | Tabbrev(_,{t=Tapp("vector",[_;_;_; TA_typ t])}) - | Tapp("reg", [TA_typ {t=Tapp("vector",[_;_;_; TA_typ t])}])-> - (match t.t with - | Tid "bit" | Tabbrev(_,{t=Tid "bit"}) | Tapp("reg",[TA_typ {t=Tid "bit"}]) -> true - | _ -> false) - | _ -> false - (* predicate to determine if pushing a constant in for addition or multiplication could change the form *) let rec contains_const n = match n.nexp with @@ -929,8 +920,40 @@ let rec normalize_n_rec recur_ok n = let normalize_nexp = normalize_n_rec true +let rec normalize_t t = match t.t with + | Tfn (t1,t2,i,eff) -> {t = Tfn (normalize_t t1,normalize_t t2,i,eff)} + | Ttup ts -> {t = Ttup (List.map normalize_t ts)} + | Tapp (c,args) -> {t = Tapp (c, List.map normalize_t_arg args)} + | Tabbrev (_,t') -> t' + | _ -> t +and normalize_t_arg targ = match targ with + | TA_typ t -> TA_typ (normalize_t t) + | TA_nexp nexp -> TA_nexp (normalize_nexp nexp) + | _ -> targ + let int_to_nexp = mk_c_int +let is_bit t = match t.t with + | Tid "bit" + | Tabbrev(_,{t=Tid "bit"}) + | Tapp("register",[TA_typ {t=Tid "bit"}]) -> true + | _ -> false + +let rec is_bit_vector t = match t.t with + | Tapp("vector", [_;_;_; TA_typ t]) -> is_bit t + | Tapp("register", [TA_typ t']) -> is_bit_vector t' + | Tabbrev(_,t') -> is_bit_vector t' + | _ -> false + +let rec has_const_vector_length t = match t.t with + | Tapp("vector", [_;TA_nexp m;_;_]) -> + (match (normalize_nexp m).nexp with + | Nconst i -> Some i + | _ -> None) + | Tapp("register", [TA_typ t']) -> has_const_vector_length t' + | Tabbrev(_,t') -> has_const_vector_length t' + | _ -> None + let v_count = ref 0 let t_count = ref 0 let tuvars = ref [] diff --git a/src/type_internal.mli b/src/type_internal.mli index ee2e3988..f4924a63 100644 --- a/src/type_internal.mli +++ b/src/type_internal.mli @@ -313,6 +313,7 @@ val get_abbrev : def_envs -> t -> (t * nexp_range list) val is_enum_typ : def_envs -> t -> int option val is_bit_vector : t -> bool +val has_const_vector_length : t -> big_int option val extract_bounds : def_envs -> string -> t -> bounds_env val merge_bounds : bounds_env -> bounds_env -> bounds_env @@ -324,6 +325,7 @@ val merge_option_maps : nexp_map option -> nexp_map option -> nexp_map option val expand_nexp : nexp -> nexp list val normalize_nexp : nexp -> nexp +val normalize_t : t -> t val get_index : nexp -> int (*expose nindex through this for debugging purposes*) val get_all_nvar : nexp -> string list (*Pull out all of the contained nvar and nuvars in nexp*) @@ -387,4 +389,3 @@ val tannot_merge : constraint_origin -> def_envs -> bool -> tannot -> tannot -> val initial_typ_env : tannot Envmap.t val initial_typ_env_list : (string * ((string * tannot) list)) list - diff --git a/src/util.ml b/src/util.ml index 2b6f81f8..bb277016 100644 --- a/src/util.ml +++ b/src/util.ml @@ -240,6 +240,12 @@ let split_after n l = | _ -> raise (Failure "index too large") in aux [] n l +let rec split3 = function + | (x, y, z) :: xs -> + let (xs, ys, zs) = split3 xs in + (x :: xs, y :: ys, z :: zs) + | [] -> ([], [], []) + let list_mapi (f : int -> 'a -> 'b) (l : 'a list) : 'b list = let rec aux f i l = match l with @@ -324,4 +330,3 @@ let rec string_of_list sep string_of = function | [] -> "" | [x] -> string_of x | x::ls -> (string_of x) ^ sep ^ (string_of_list sep string_of ls) - diff --git a/src/util.mli b/src/util.mli index c565cdce..496c63cf 100644 --- a/src/util.mli +++ b/src/util.mli @@ -145,6 +145,9 @@ val undo_list_to_front : int -> 'a list -> 'a list [l1] and [l2], with [length l1 = n] and [l1 @ l2 = l]. Fails if n is too small or large. *) val split_after : int -> 'a list -> 'a list * 'a list +(** [split3 l] splits a list of triples into a triple of lists *) +val split3 : ('a * 'b * 'c) list -> 'a list * 'b list * 'c list + val compare_list : ('a -> 'b -> int) -> 'a list -> 'b list -> int |
