diff options
| author | Christopher Pulte | 2016-10-27 00:06:19 +0100 |
|---|---|---|
| committer | Christopher Pulte | 2016-10-27 00:06:19 +0100 |
| commit | 5cbc35eb6d253e87185bbf247aa1e3db12d998d8 (patch) | |
| tree | 6f0ca304671576b7bea45aca286743a59505d4b9 /src | |
| parent | 587f9dc7c6409d5ef89719fd65fe7bbb8f8d86b7 (diff) | |
more shallow embedding fixes
Diffstat (limited to 'src')
| -rw-r--r-- | src/gen_lib/prompt.lem | 22 | ||||
| -rw-r--r-- | src/gen_lib/sail_values.lem | 388 | ||||
| -rw-r--r-- | src/pretty_print.ml | 127 | ||||
| -rw-r--r-- | src/rewriter.ml | 3 |
4 files changed, 295 insertions, 245 deletions
diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem index 6dea2e9b..b369fd21 100644 --- a/src/gen_lib/prompt.lem +++ b/src/gen_lib/prompt.lem @@ -35,16 +35,16 @@ let inline (>>) m n = m >>= fun _ -> n val exit : forall 'a. string -> M 'a let exit s = Fail (Just s) -val read_memory : read_kind -> vector bitU -> integer -> M (vector bitU) -let read_memory rk addr sz = +val read_memory : bool -> read_kind -> vector bitU -> integer -> M (vector bitU) +let read_memory dir rk addr sz = let addr = address_lifted_of_bitv addr in let sz = natFromInteger sz in let k memory_value = - let bitv = bitv_of_byte_lifteds memory_value in + let bitv = bitv_of_byte_lifteds dir memory_value in (Done bitv,Nothing) in Read_mem (rk,addr,sz) k -val write_memory_ea : write_kind -> vector bitU -> integer -> M unit +val write_memory_ea : write_kind -> vector bitU -> integer -> M unit let write_memory_ea wk addr sz = let addr = address_lifted_of_bitv addr in let sz = natFromInteger sz in @@ -63,7 +63,7 @@ let read_reg_range reg i j = let reg = Reg_slice (name_of_reg reg) (start_of_reg_nat reg) (dir_of_reg reg) (if i<j then (i,j) else (j,i)) in let k register_value = - let v = bitvFromRegisterValue register_value in + let v = bitv_of_register_value register_value in (Done v,Nothing) in Read_reg reg k @@ -76,7 +76,7 @@ let read_reg reg = let reg = Reg (name_of_reg reg) (start_of_reg_nat reg) (size_of_reg_nat reg) (dir_of_reg reg) in let k register_value = - let v = bitvFromRegisterValue register_value in + let v = bitv_of_register_value register_value in (Done v,Nothing) in Read_reg reg k @@ -87,7 +87,7 @@ let read_reg_field reg regfield = let reg = Reg_slice (name_of_reg reg) (start_of_reg_nat reg) (dir_of_reg reg) (if i<j then (i,j) else (j,i)) in let k register_value = - let v = bitvFromRegisterValue register_value in + let v = bitv_of_register_value register_value in (Done v,Nothing) in Read_reg reg k @@ -98,19 +98,21 @@ let read_reg_bitfield reg rbit = val write_reg_range : register -> integer -> integer -> vector bitU -> M unit let write_reg_range reg i j v = - let rv = registerValueFromBitv v reg in let (i,j) = (natFromInteger i,natFromInteger j) in let reg = Reg_slice (name_of_reg reg) (start_of_reg_nat reg) (dir_of_reg reg) (i,j) in + let rv = extern_reg_value reg v in Write_reg (reg,rv) (Done (),Nothing) val write_reg_bit : register -> integer -> bitU -> M unit -let write_reg_bit reg i bit = write_reg_range reg i i (Vector [bit] 0 true) +let write_reg_bit reg i bit = + write_reg_range reg i i (Vector [bit] 0 (is_inc_of_reg reg)) + (* the zero start index shouldn't matter *) val write_reg : register -> vector bitU -> M unit let write_reg reg v = - let rv = registerValueFromBitv v reg in let reg = Reg (name_of_reg reg) (start_of_reg_nat reg) (size_of_reg_nat reg) (dir_of_reg reg) in + let rv = extern_reg_value reg v in Write_reg (reg,rv) (Done (),Nothing) val write_reg_field : register -> register_field -> vector bitU -> M unit diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem index 6f5dfb28..0052f493 100644 --- a/src/gen_lib/sail_values.lem +++ b/src/gen_lib/sail_values.lem @@ -1,14 +1,30 @@ open import Pervasives_extra -open import Sail_impl_base open import Interp (* only for converting between shallow- and deep-embedding values *) open import Interp_ast (* only for converting between shallow- and deep-embedding values *) +open import Sail_impl_base + type ii = integer type nn = natural -type bitU = O | I | Undef +val pow : integer -> integer -> integer +let pow m n = m ** (natFromInteger n) + +let rec replace bs ((n : integer),b') = match bs with + | [] -> [] + | b :: bs -> + if n = 0 then b' :: bs + else b :: replace bs (n - 1,b') + end + + + + +(*** Bits *) +type bitU = O | I | Undef + let showBitU = function | O -> "O" | I -> "I" @@ -19,24 +35,95 @@ instance (Show bitU) let show = showBitU end + +let to_bool = function + | O -> false + | I -> true + | Undef -> failwith "to_bool applied to Undef" + end + +let bit_lifted_of_bitU = function + | O -> Bitl_zero + | I -> Bitl_one + | Undef -> Bitl_undef + end + +let bitU_of_bit = function + | Bitc_zero -> O + | Bitc_one -> I + end + +let bitU_of_bit_lifted = function + | Bitl_zero -> O + | Bitl_one -> I + | Bitl_undef -> Undef + | Bitl_unknown -> failwith "bitU_of_bit_lifted Bitl_unknown" + end + +let bitwise_not_bit = function + | I -> O + | O -> I + | Undef -> Undef + end + +let inline (~) = bitwise_not_bit + +val is_one : integer -> bitU +let is_one i = + if i = 1 then I else O + +let bool_to_bit b = if b then I else O + +let bitwise_binop_bit op = function + | (Undef,_) -> Undef (*Do we want to do this or to respect | of I and & of B0 rules?*) + | (_,Undef) -> Undef (*Do we want to do this or to respect | of I and & of B0 rules?*) + | (x,y) -> bool_to_bit (op (to_bool x) (to_bool y)) + end + +val bitwise_and_bit : bitU * bitU -> bitU +let bitwise_and_bit = bitwise_binop_bit (&&) + +val bitwise_or_bit : bitU * bitU -> bitU +let bitwise_or_bit = bitwise_binop_bit (||) + +val bitwise_xor_bit : bitU * bitU -> bitU +let bitwise_xor_bit = bitwise_binop_bit xor + +val (&.) : bitU -> bitU -> bitU +let inline (&.) x y = bitwise_and_bit (x,y) + +val (|.) : bitU -> bitU -> bitU +let inline (|.) x y = bitwise_or_bit (x,y) + +val (+.) : bitU -> bitU -> bitU +let inline (+.) x y = bitwise_xor_bit (x,y) + + + +(*** Vectors *) + (* element list * start * has increasing direction *) type vector 'a = Vector of list 'a * integer * bool let showVector (Vector elems start inc) = "Vector " ^ show elems ^ " " ^ show start ^ " " ^ show inc +let get_dir (Vector _ _ ord) = ord +let get_start (Vector _ s _) = s +let get_elems (Vector elems _ _) = elems +let length (Vector bs _ _) = integerFromNat (length bs) + instance forall 'a. Show 'a => (Show (vector 'a)) let show = showVector end -let pp_bitu_vector (Vector elems start inc) = - let elems_pp = List.foldl (fun acc elem -> acc ^ showBitU elem) "" elems in - "Vector [" ^ elems_pp ^ "] " ^ show start ^ " " ^ show inc +let dir is_inc = if is_inc then D_increasing else D_decreasing +let bool_of_dir = function + | D_increasing -> true + | D_decreasing -> false + end - -let get_dir (Vector _ _ ord) = ord -let get_start (Vector _ s _) = s -let length (Vector bs _ _) = integerFromNat (length bs) +(*** Vector operations *) val set_vector_start : forall 'a. integer -> vector 'a -> vector 'a let set_vector_start new_start (Vector bs _ is_inc) = @@ -72,10 +159,23 @@ let slice (Vector bs start is_inc) i j = let startN = natFromInteger start in let subvector_bits = sublist bs (if is_inc then (iN-startN,jN-startN) else (startN-iN,startN-jN)) in - (Vector subvector_bits i is_inc) + Vector subvector_bits i is_inc + +(* this is for the vector slicing introduced in vector-concat patterns: i and j +index into the "raw data", the list of bits. Therefore getting the bit list is +easy, but the start index has to be transformed to match the old vector start +and the direction. *) +val slice_raw : forall 'a. vector 'a -> integer -> integer -> vector 'a +let slice_raw (Vector bs start is_inc) i j = + let iN = natFromInteger i in + let jN = natFromInteger j in + let bits = sublist bs (iN,jN) in + let new_start = if is_inc then start + i else start - i in + Vector bits new_start is_inc -val update : forall 'a. vector 'a -> integer -> integer -> vector 'a -> vector 'a -let update (Vector bs start is_inc) i j (Vector bs' _ _) = + +val update_aux : forall 'a. vector 'a -> integer -> integer -> list 'a -> vector 'a +let update_aux (Vector bs start is_inc) i j bs' = let iN = natFromInteger i in let jN = natFromInteger j in let startN = natFromInteger start in @@ -84,6 +184,10 @@ let update (Vector bs start is_inc) i j (Vector bs' _ _) = (if is_inc then (iN-startN,jN-startN) else (startN-iN,startN-jN)) bs' in Vector bits start is_inc +val update : forall 'a. vector 'a -> integer -> integer -> vector 'a -> vector 'a +let update v i j (Vector bs' _ _) = + update_aux v i j bs' + val access : forall 'a. vector 'a -> integer -> 'a let access (Vector bs start is_inc) n = if is_inc then List_extra.nth bs (natFromInteger (n - start)) @@ -91,133 +195,26 @@ let access (Vector bs start is_inc) n = val update_pos : forall 'a. vector 'a -> integer -> 'a -> vector 'a let update_pos v n b = - update v n n (Vector [b] 0 true) + update_aux v n n [b] -type register_field = string -type register_field_index = string * (integer * integer) (* name, start and end *) - -type register = - | Register of string * (* name *) - integer * (* length *) - integer * (* start index *) - bool * (* is increasing *) - list register_field_index - | UndefinedRegister of integer (* length *) - | RegisterPair of register * register +(*** Bit vector operations *) -let dir is_inc = if is_inc then D_increasing else D_decreasing -let bool_of_dir = function - | D_increasing -> true - | D_decreasing -> false - end - -let name_of_reg (Register name _ _ _ _) = name -let size_of_reg (Register _ size _ _ _) = size -let start_of_reg (Register _ _ start _ _) = start -let is_inc_of_reg (Register _ _ _ is_inc _) = is_inc -let dir_of_reg (Register _ _ _ is_inc _) = dir is_inc - -let size_of_reg_nat reg = natFromInteger (size_of_reg reg) -let start_of_reg_nat reg = natFromInteger (start_of_reg reg) - - -val register_field_indices_aux : register -> register_field -> maybe (integer * integer) -let rec register_field_indices_aux register rfield = - match register with - | Register _ _ _ _ rfields -> List.lookup rfield rfields - | RegisterPair r1 r2 -> - let m_indices = register_field_indices_aux r1 rfield in - if isJust m_indices then m_indices else register_field_indices_aux r2 rfield - | UndefinedRegister -> Nothing - end - -val register_field_indices : register -> register_field -> integer * integer -let register_field_indices register rfield = - match register_field_indices_aux register rfield with - | Just indices -> indices - | Nothing -> failwith "Invalid register/register-field combination" - end - -let register_field_indices_nat reg regfield= - let (i,j) = register_field_indices reg regfield in - (natFromInteger i,natFromInteger j) - -let to_bool = function - | O -> false - | I -> true - | Undef -> failwith "to_bool applied to Undef" - end - -let bit_lifted_of_bitU = function - | O -> Bitl_zero - | I -> Bitl_one - | Undef -> Bitl_undef - end - -let bitU_of_bit = function - | Bitc_zero -> O - | Bitc_one -> I - end +let pp_bitu_vector (Vector elems start inc) = + let elems_pp = List.foldl (fun acc elem -> acc ^ showBitU elem) "" elems in + "Vector [" ^ elems_pp ^ "] " ^ show start ^ " " ^ show inc -let bitU_of_bit_lifted = function - | Bitl_zero -> O - | Bitl_one -> I - | Bitl_undef -> Undef - | Bitl_unknown -> failwith "bitU_of_bit_lifted Bitl_unknown" - end let most_significant = function | (Vector (b :: _) _ _) -> b | _ -> failwith "most_significant applied to empty vector" end -let bitwise_not_bit = function - | I -> O - | O -> I - | Undef -> Undef - end - -let inline (~) = bitwise_not_bit - -val pow : integer -> integer -> integer -let pow m n = m ** (natFromInteger n) - let bitwise_not_bitlist = List.map bitwise_not_bit let bitwise_not (Vector bs start is_inc) = Vector (bitwise_not_bitlist bs) start is_inc -val is_one : integer -> bitU -let is_one i = - if i = 1 then I else O - -let bool_to_bit b = if b then I else O - -let bitwise_binop_bit op = function - | (Undef,_) -> Undef (*Do we want to do this or to respect | of I and & of B0 rules?*) - | (_,Undef) -> Undef (*Do we want to do this or to respect | of I and & of B0 rules?*) - | (x,y) -> bool_to_bit (op (to_bool x) (to_bool y)) - end - -val bitwise_and_bit : bitU * bitU -> bitU -let bitwise_and_bit = bitwise_binop_bit (&&) - -val bitwise_or_bit : bitU * bitU -> bitU -let bitwise_or_bit = bitwise_binop_bit (||) - -val bitwise_xor_bit : bitU * bitU -> bitU -let bitwise_xor_bit = bitwise_binop_bit xor - -val (&.) : bitU -> bitU -> bitU -let inline (&.) x y = bitwise_and_bit (x,y) - -val (|.) : bitU -> bitU -> bitU -let inline (|.) x y = bitwise_or_bit (x,y) - -val (+.) : bitU -> bitU -> bitU -let inline (+.) x y = bitwise_xor_bit (x,y) - let bitwise_binop op (Vector bsl start is_inc, Vector bsr _ _) = let revbs = foldl (fun acc pair -> bitwise_binop_bit op pair :: acc) [] (zip bsl bsr) in Vector (reverse revbs) start is_inc @@ -493,21 +490,15 @@ let minusSO_VBV = arith_op_overflow_vec_bit integerMinus true 1 type shift = LL_shift | RR_shift | LLL_shift let shift_op_vec op ((Vector bs start is_inc as l),(n : integer)) = - let len = integerFromNat (List.length bs) in - match op with - | LL_shift (*"<<"*) -> - let right_vec = Vector (List.replicate (natFromInteger n) O) 0 true in - let left_vec = slice l n (if is_inc then len + start else start - len) in - vector_concat left_vec right_vec - | RR_shift (*">>"*) -> - let right_vec = slice l start n in - let left_vec = Vector (List.replicate (natFromInteger n) O) 0 true in - vector_concat left_vec right_vec - | LLL_shift (*"<<<"*) -> - let left_vec = slice l n (if is_inc then len + start else start - len) in - let right_vec = slice l start n in - vector_concat left_vec right_vec - end + let n = natFromInteger n in + match op with + | LL_shift (*"<<"*) -> + Vector (sublist bs (n,List.length bs -1) ++ List.replicate n O) start is_inc + | RR_shift (*">>"*) -> + Vector (List.replicate n O ++ sublist bs (0,n-1)) start is_inc + | LLL_shift (*"<<<"*) -> + Vector (sublist bs (n,List.length bs - 1) ++ sublist bs (0,n-1)) start is_inc + end let bitwise_leftshift = shift_op_vec LL_shift (*"<<"*) let bitwise_rightshift = shift_op_vec RR_shift (*">>"*) @@ -568,18 +559,11 @@ let arith_op_vec_range_no0 op sign size (Vector _ _ is_inc as l) r = let mod_VIV = arith_op_vec_range_no0 integerMod false 1 -let duplicate (bit,length) = - Vector (List.replicate (natFromInteger length) bit) 0 true - val repeat : forall 'a. list 'a -> integer -> list 'a let rec repeat xs n = if n = 0 then [] else xs ++ repeat xs (n-1) -let duplicate_bits (Vector bits start direction,len) = - let bits' = repeat bits len in - Vector bits' start direction - let compare_op op (l,r) = bool_to_bit (op l r) let lt = compare_op (<) @@ -637,22 +621,17 @@ let neq_vec (l,r) = bitwise_not_bit (eq_vec_vec (l,r)) let neq_vec_range (l,r) = bitwise_not_bit (eq_vec_range (l,r)) let neq_range_vec (l,r) = bitwise_not_bit (eq_range_vec (l,r)) -let rec replace bs ((n : integer),b') = match bs with - | [] -> [] - | b :: bs -> - if n = 0 then b' :: bs - else b :: replace bs (n - 1,b') - end - val make_indexed_vector : forall 'a. list (integer * 'a) -> 'a -> integer -> integer -> bool -> vector 'a let make_indexed_vector entries default start length dir = let length = natFromInteger length in Vector (List.foldl replace (replicate length default) entries) start dir +(* val make_bit_vector_undef : integer -> vector bitU let make_bitvector_undef length = Vector (replicate (natFromInteger length) Undef) 0 true + *) (* let bitwise_not_range_bit n = bitwise_not (to_vec defaultDir n) *) @@ -668,14 +647,17 @@ let rec byte_chunks n list = match (n,list) with | _ -> failwith "byte_chunks not given enough bits" end +val bitv_of_byte_lifteds : bool -> list Sail_impl_base.byte_lifted -> vector bitU +let bitv_of_byte_lifteds dir v = + let bits = foldl (fun x (Byte_lifted y) -> x ++ (List.map bitU_of_bit_lifted y)) [] v in + let len = integerFromNat (List.length bits) in + Vector bits (if dir then 0 else len - 1) dir -val bitv_of_byte_lifteds : list Sail_impl_base.byte_lifted -> vector bitU -let bitv_of_byte_lifteds v = - Vector (foldl (fun x (Byte_lifted y) -> x ++ (List.map bitU_of_bit_lifted y)) [] v) 0 true - -val bitv_of_bytes : list Sail_impl_base.byte -> vector bitU -let bitv_of_bytes v = - Vector (foldl (fun x (Byte y) -> x ++ (List.map bitU_of_bit y)) [] v) 0 true +val bitv_of_bytes : bool -> list Sail_impl_base.byte -> vector bitU +let bitv_of_bytes dir v = + let bits = foldl (fun x (Byte y) -> x ++ (List.map bitU_of_bit y)) [] v in + let len = integerFromNat (List.length bits) in + Vector bits (if dir then 0 else len - 1) dir val byte_lifteds_of_bitv : vector bitU -> list byte_lifted @@ -683,6 +665,13 @@ let byte_lifteds_of_bitv (Vector bits length is_inc) = let bits = List.map bit_lifted_of_bitU bits in byte_lifteds_of_bit_lifteds bits +val bit_lifteds_of_bitUs : list bitU -> list bit_lifted +let bit_lifteds_of_bitUs bits = List.map bit_lifted_of_bitU bits + +val bit_lifteds_of_bitv : vector bitU -> list bit_lifted +let bit_lifteds_of_bitv v = bit_lifteds_of_bitUs (get_elems v) + + val address_lifted_of_bitv : vector bitU -> address_lifted let address_lifted_of_bitv v = let byte_lifteds = byte_lifteds_of_bitv v in @@ -694,22 +683,79 @@ let address_lifted_of_bitv v = Address_lifted byte_lifteds maybe_address_integer -val bitvFromRegisterValue : register_value -> vector bitU -let bitvFromRegisterValue v = + +(*** Registers *) + +type register_field = string +type register_field_index = string * (integer * integer) (* name, start and end *) + +type register = + | Register of string * (* name *) + integer * (* length *) + integer * (* start index *) + bool * (* is increasing *) + list register_field_index + | UndefinedRegister of integer (* length *) + | RegisterPair of register * register + +let name_of_reg (Register name _ _ _ _) = name +let size_of_reg (Register _ size _ _ _) = size +let start_of_reg (Register _ _ start _ _) = start +let is_inc_of_reg (Register _ _ _ is_inc _) = is_inc +let dir_of_reg (Register _ _ _ is_inc _) = dir is_inc + +let size_of_reg_nat reg = natFromInteger (size_of_reg reg) +let start_of_reg_nat reg = natFromInteger (start_of_reg reg) + +val register_field_indices_aux : register -> register_field -> maybe (integer * integer) +let rec register_field_indices_aux register rfield = + match register with + | Register _ _ _ _ rfields -> List.lookup rfield rfields + | RegisterPair r1 r2 -> + let m_indices = register_field_indices_aux r1 rfield in + if isJust m_indices then m_indices else register_field_indices_aux r2 rfield + | UndefinedRegister -> Nothing + end + +val register_field_indices : register -> register_field -> integer * integer +let register_field_indices register rfield = + match register_field_indices_aux register rfield with + | Just indices -> indices + | Nothing -> failwith "Invalid register/register-field combination" + end + +let register_field_indices_nat reg regfield= + let (i,j) = register_field_indices reg regfield in + (natFromInteger i,natFromInteger j) + +val bitv_of_register_value : register_value -> vector bitU +let bitv_of_register_value v = Vector (List.map bitU_of_bit_lifted v.rv_bits) (integerFromNat v.rv_start_internal) (v.rv_dir = D_increasing) -val registerValueFromBitv : vector bitU -> register -> register_value -let registerValueFromBitv (Vector bits start is_inc) reg = - let start = natFromInteger start in - let bit_lifteds = - List.map bit_lifted_of_bitU bits in - <| rv_bits = bit_lifteds; - rv_dir = dir is_inc; - rv_start_internal = start; - rv_start = if is_inc then start else start+1 - (size_of_reg_nat reg) |> +let rec extern_reg_value reg_name v = + let (internal_start, external_start, direction) = + (match reg_name with + | Reg _ start size dir -> + (start, (if dir = D_increasing then start else (start - (size +1))), dir) + | Reg_slice _ reg_start dir (slice_start, slice_end) -> + ((if dir = D_increasing then slice_start else (reg_start - slice_start)), + slice_start, dir) + | Reg_field _ reg_start dir _ (slice_start, slice_end) -> + ((if dir = D_increasing then slice_start else (reg_start - slice_start)), + slice_start, dir) + | Reg_f_slice _ reg_start dir _ _ (slice_start, slice_end) -> + ((if dir = D_increasing then slice_start else (reg_start - slice_start)), + slice_start, dir) end) in + let bits = bit_lifteds_of_bitv v in + <| rv_bits = bits; + rv_dir = direction; + rv_start = external_start; + rv_start_internal = internal_start |> + + class (ToNatural 'a) diff --git a/src/pretty_print.ml b/src/pretty_print.ml index e87f42a6..baa51857 100644 --- a/src/pretty_print.ml +++ b/src/pretty_print.ml @@ -1638,7 +1638,7 @@ let doc_exp_ocaml, doc_let_ocaml = (match annot with | Base(_,External _,_,_,_,_) -> if read_registers - then parens( string "read_register" ^^ space ^^ exp e) + then parens (string "read_register" ^^ space ^^ exp e) else exp e | _ -> (parens (doc_op colon (group (exp e)) (doc_typ_ocaml typ)))) | E_tuple exps -> @@ -2329,68 +2329,69 @@ let doc_exp_lem, doc_let_lem = let epp = let_exp regtypes leb ^^ space ^^ string "in" ^^ hardline ^^ expN e in if aexp_needed then parens epp else epp | E_app(f,args) -> - (match f with - (* temporary hack to make the loop body a function of the temporary variables *) - | Id_aux ((Id (("foreach_inc" | "foreach_dec" | - "foreachM_inc" | "foreachM_dec" ) as loopf),_)) -> - let [id;indices;body;e5] = args in - let varspp = match e5 with - | E_aux (E_tuple vars,_) -> - let vars = List.map (fun (E_aux (E_id (Id_aux (Id name,_)),_)) -> string name) vars in - begin match vars with - | [v] -> v - | _ -> parens (separate comma vars) end - | E_aux (E_id (Id_aux (Id name,_)),_) -> - string name - | E_aux (E_lit (L_aux (L_unit,_)),_) -> - string "_" in - parens ( - (prefix 2 1) - ((separate space) [string loopf;group (expY indices);expY e5]) - (parens - (prefix 1 1 (separate space [string "fun";expY id;varspp;arrow]) (expN body)) - ) - ) - | Id_aux (Id "append",_) -> - let [e1;e2] = args in - let epp = align (expY e1 ^^ space ^^ string "++" ^//^ expY e2) in - if aexp_needed then parens (align epp) else epp - | _ -> - (match annot with - | Base (_,Constructor _,_,_,_,_) -> - let epp = - match args with - | [] -> doc_id_lem_ctor f - | [arg] -> doc_id_lem_ctor f ^^ space ^^ expY arg - | _ -> - doc_id_lem_ctor f ^^ space ^^ - parens (separate_map comma expY args) in - if aexp_needed then parens (align epp) else epp - | Base (_,External (Some "bitwise_not_bit"),_,_,_,_) -> - let [a] = args in - let epp = align (string "~" ^^ expY a) in - if aexp_needed then parens (align epp) else epp - | _ -> - let call = match annot with - | Base(_,External (Some n),_,_,_,_) -> - (match n with - | _ -> string n) - | Base(_,Constructor _,_,_,_,_) -> doc_id_lem_ctor f - | _ -> doc_id_lem f in - let argpp a_needed arg = - let (E_aux (_,(_,Base((_,{t=t}),_,_,_,_,_)))) = arg in - match t with - | Tapp("vector",_) -> - let epp = concat [string "copy";space;expY arg] in - if a_needed then parens epp else epp - | _ -> expV a_needed arg in - let argspp = match args with - | [arg] -> argpp true arg - | args -> parens (align (separate_map (comma ^^ break 0) (argpp false) args)) in - let epp = align (call ^//^ argspp) in - if aexp_needed then parens (align epp) else epp - ) - ) + begin match f with + (* temporary hack to make the loop body a function of the temporary variables *) + | Id_aux ((Id (("foreach_inc" | "foreach_dec" | + "foreachM_inc" | "foreachM_dec" ) as loopf),_)) -> + let [id;indices;body;e5] = args in + let varspp = match e5 with + | E_aux (E_tuple vars,_) -> + let vars = List.map (fun (E_aux (E_id (Id_aux (Id name,_)),_)) -> string name) vars in + begin match vars with + | [v] -> v + | _ -> parens (separate comma vars) end + | E_aux (E_id (Id_aux (Id name,_)),_) -> + string name + | E_aux (E_lit (L_aux (L_unit,_)),_) -> + string "_" in + parens ( + (prefix 2 1) + ((separate space) [string loopf;group (expY indices);expY e5]) + (parens + (prefix 1 1 (separate space [string "fun";expY id;varspp;arrow]) (expN body)) + ) + ) + | Id_aux (Id "append",_) -> + let [e1;e2] = args in + let epp = align (expY e1 ^^ space ^^ string "++" ^//^ expY e2) in + if aexp_needed then parens (align epp) else epp + | Id_aux (Id "slice_raw",_) -> + let [e1;e2;e3] = args in + let epp = separate space [string "slice_raw";expY e1;expY e2;expY e3] in + if aexp_needed then parens (align epp) else epp + | _ -> + begin match annot with + | Base (_,External (Some "bitwise_not_bit"),_,_,_,_) -> + let [a] = args in + let epp = align (string "~" ^^ expY a) in + if aexp_needed then parens (align epp) else epp + | Base (_,Constructor _,_,_,_,_) -> + let epp = + match args with + | [] -> doc_id_lem_ctor f + | [arg] -> doc_id_lem_ctor f ^^ space ^^ expY arg + | _ -> + doc_id_lem_ctor f ^^ space ^^ + parens (separate_map comma expY args) in + if aexp_needed then parens (align epp) else epp + | _ -> + let call = match annot with + | Base(_,External (Some n),_,_,_,_) -> string n + | _ -> doc_id_lem f in + let argpp a_needed arg = + let (E_aux (_,(_,Base((_,{t=t}),_,_,_,_,_)))) = arg in + match t with + | Tapp("vector",_) -> + let epp = concat [string "copy";space;expY arg] in + if a_needed then parens epp else epp + | _ -> expV a_needed arg in + let argspp = match args with + | [arg] -> argpp true arg + | args -> parens (align (separate_map (comma ^^ break 0) (argpp false) args)) in + let epp = align (call ^//^ argspp) in + if aexp_needed then parens (align epp) else epp + end + end | E_vector_access (v,e) -> let (Base (_,_,_,_,eff,_)) = annot in let epp = diff --git a/src/rewriter.ml b/src/rewriter.ml index 4b6598c0..dab16ae4 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -895,7 +895,8 @@ let remove_vector_concat_pat pat = E_aux (E_app_infix(length_app_exp,minus,one_exp),annot) in exp in - let subv = E_aux (E_vector_subrange (root,index_i,index_j),cannot) in + let subv = E_aux (E_app (Id_aux (Id "slice_raw",Unknown), + [root;index_i;index_j]),cannot) in let typ = (Parse_ast.Generated l,simple_annot {t = Tid "unit"}) in |
