summaryrefslogtreecommitdiff
path: root/src/gen_lib/state.lem
blob: 69b9e301fdabb126c54bc803ce17218c6dbc6059 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
open import Pervasives_extra
open import Sail_impl_base
open import Sail_values

(* 'a is result type *)

(* Memory contents: a map from a byte address to the byte stored there. *)
type memstate = map integer memory_byte
(* Tag memory: a map from a tag index to a tag bit.  The tag accessors in
   this file use the byte address divided by cap_alignment as the index. *)
type tagstate = map integer bitU
(* type regstate = map string (vector bitU) *)

(* State threaded through the sequential monad, parameterised by the type
   'regs of the register state.
   - regstate: the register contents, accessed through register_ref values.
   - memstate: the byte-addressed memory.
   - tagstate: the tag memory.
   - write_ea: the most recently announced write (kind, address, size), set
     by write_mem_ea and consumed by write_mem_val and write_tag; Nothing
     until a write has been announced.
   - last_exclusive_operation_was_load: set by reserve/exclusive reads and
     cleared in the successful outcome of excl_result. *)
type sequential_state 'regs =
  <| regstate : 'regs;
     memstate : memstate;
     tagstate : tagstate;
     write_ea : maybe (write_kind * integer * integer);
     last_exclusive_operation_was_load : bool|>

(* State, nondeterminism and exception monad with result type 'a
   and exception type 'e.  A computation maps an input state to a list of
   outcomes, one per nondeterministic alternative.  Each outcome pairs the
   resulting state with either a value (Left) or an exception (Right). *)
type ME 'regs 'a 'e = sequential_state 'regs -> list ((either 'a 'e) * sequential_state 'regs)

(* By default, exceptions are strings: the string carried by a Right outcome
   distinguishes between different kinds of exception. *)
type M 'regs 'a = ME 'regs 'a string

(* Monad with early return.  Early return is encoded by throwing and catching
   the return value as an exception.  The exception type is "either 'r string",
   where "Right e" represents a proper exception and "Left r" an early return
   of the value "r". *)
type MR 'regs 'a 'r = ME 'regs 'a (either 'r string)

(* liftR m embeds a computation of the plain monad M into the early-return
   monad MR.  Values are passed through unchanged; an exception e becomes the
   proper exception Right e of MR. *)
val liftR : forall 'a 'r 'regs. M 'regs 'a -> MR 'regs 'a 'r
let liftR m s =
  let lift_outcome (outcome, s') = match outcome with
    | Right e -> (Right (Right e), s')
    | Left a -> (Left a, s')
  end in
  List.map lift_outcome (m s)

(* return a is the computation with the single outcome a; the state is left
   unchanged. *)
val return : forall 'regs 'a 'e. 'a -> ME 'regs 'a 'e
let return a s =
  let outcome = Left a in
  [(outcome, s)]

(* bind m f runs m and then, for every outcome of m that carries a value,
   runs f on that value in the outcome's state.  Outcomes of m that carry an
   exception are propagated unchanged without calling f.  The outcome lists
   are concatenated in order. *)
val bind : forall 'regs 'a 'b 'e. ME 'regs 'a 'e -> ('a -> ME 'regs 'b 'e) -> ME 'regs 'b 'e
let bind m f (s : sequential_state 'regs) =
  let continue_with (outcome, s') = match outcome with
    | Right e -> [(Right e, s')]
    | Left a -> f a s'
  end in
  List.concatMap continue_with (m s)

(* Infix sequencing operators.  m >>= f is bind m f. *)
let inline (>>=) = bind
(* m >> n runs m, discards its unit result, and then runs n. *)
val (>>): forall 'regs 'b 'e. ME 'regs unit 'e -> ME 'regs 'b 'e -> ME 'regs 'b 'e
let inline (>>) m n = bind m (fun _ -> n)

(* exit e aborts the computation with the exception "exit".  The argument is
   ignored and the state is left unchanged. *)
val exit : forall 'regs 'e 'a. 'e -> M 'regs 'a
let exit _ s =
  let outcome = Right "exit" in
  [(outcome, s)]

(* assert_exp exp msg succeeds with unit when exp holds and otherwise raises
   msg as an exception.  The state is unchanged in both cases. *)
val assert_exp : forall 'regs. bool -> string -> M 'regs unit
let assert_exp exp msg s =
  let outcome = if exp then Left () else Right msg in
  [(outcome, s)]

(* early_return r signals an early return of the value r by raising it as the
   exception Left r.  The state is left unchanged. *)
val early_return : forall 'regs 'a 'r. 'r -> MR 'regs 'a 'r
let early_return r s =
  let outcome = Right (Left r) in
  [(outcome, s)]

(* catch_early_return m converts a computation with early return back into
   the plain monad: an early return of a becomes the ordinary result a,
   ordinary results are kept, and proper exceptions stay exceptions.  States
   are passed through unchanged. *)
val catch_early_return : forall 'regs 'a. MR 'regs 'a 'a -> M 'regs 'a
let catch_early_return m s =
  let unwrap (outcome, s') = match outcome with
    | Left a -> (Left a, s')
    | Right (Left a) -> (Left a, s')
    | Right (Right e) -> (Right e, s')
  end in
  List.map unwrap (m s)

(* range i j is the ascending list of the integers from i to j, both
   inclusive.  It is empty when j < i. *)
val range : integer -> integer -> list integer
let rec range i j =
  if i > j then []
  else if i < j then i :: range (i+1) j
  else [i]

(* get_reg state reg reads the register denoted by reg out of the register
   part of state, using the reference's read_from accessor. *)
val get_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a
let get_reg state reg =
  let regs = state.regstate in
  reg.read_from regs

(* set_reg state reg v returns state with its register part replaced by the
   result of the reference's write_to accessor applied to the old register
   state and v.  All other components of state are kept. *)
val set_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a -> sequential_state 'regs
let set_reg state reg v =
  let regs' = reg.write_to state.regstate v in
  <| state with regstate = regs' |>


(* read_mem dir read_kind addr sz reads the sz bytes at the addresses
   unsigned addr .. unsigned addr + sz - 1 from memory and returns them as a
   bitvector, assembled by Sail_values.internal_mem_value dir.  Each byte is
   fetched with Map_extra.find (presumably this fails for an address with no
   entry in memory; confirm against Map_extra).  Reserve and exclusive read
   kinds additionally set last_exclusive_operation_was_load. *)
val read_mem : forall 'regs 'a 'b. Bitvector 'a, Bitvector 'b => bool -> read_kind -> 'a -> integer -> M 'regs 'b
let read_mem dir read_kind addr sz state =
  let first_addr = unsigned addr in
  let last_addr = first_addr + sz - 1 in
  let bytes_read =
    List.map (fun a -> Map_extra.find a state.memstate) (range first_addr last_addr) in
  let value = of_bits (Sail_values.internal_mem_value dir bytes_read) in
  let is_exclusive = match read_kind with
    | Sail_impl_base.Read_reserve -> true
    | Sail_impl_base.Read_exclusive -> true
    | Sail_impl_base.Read_exclusive_acquire -> true
    | Sail_impl_base.Read_plain -> false
    | Sail_impl_base.Read_acquire -> false
    | Sail_impl_base.Read_stream -> false
  end in
  let state' =
    if is_exclusive
    then <| state with last_exclusive_operation_was_load = true |>
    else state in
  [(Left value, state')]

(* Capability alignment in bytes: caps are aligned at 32 bytes.  read_tag and
   write_tag divide a byte address by this value to obtain the index into the
   tag memory. *)
let cap_alignment = (32 : integer)

(* read_tag dir read_kind addr returns the tag bit stored for the index
   (unsigned addr) / cap_alignment, or B0 when the tag memory has no entry
   for that index.  The dir argument is not used.  As in read_mem, reserve
   and exclusive read kinds set last_exclusive_operation_was_load. *)
val read_tag : forall 'regs 'a. Bitvector 'a => bool -> read_kind -> 'a -> M 'regs bitU
let read_tag dir read_kind addr state =
  let tag_index = (unsigned addr) / cap_alignment in
  let tag = match (Map.lookup tag_index state.tagstate) with
    | Nothing -> B0
    | Just t -> t
  end in
  let is_exclusive = match read_kind with
    | Sail_impl_base.Read_reserve -> true
    | Sail_impl_base.Read_exclusive -> true
    | Sail_impl_base.Read_exclusive_acquire -> true
    | Sail_impl_base.Read_plain -> false
    | Sail_impl_base.Read_acquire -> false
    | Sail_impl_base.Read_stream -> false
  end in

  (* TODO Should reading a tag set the exclusive flag? *)
  let state' =
    if is_exclusive
    then <| state with last_exclusive_operation_was_load = true |>
    else state in
  [(Left tag, state')]

(* excl_result () nondeterministically yields a boolean.  The outcome false,
   with the state unchanged, is always possible and is listed first.  When
   last_exclusive_operation_was_load is set there is a second outcome true,
   whose state has that flag cleared. *)
val excl_result : forall 'regs. unit -> M 'regs bool
let excl_result () state =
  let failure = (Left false, state) in
  if state.last_exclusive_operation_was_load
  then [failure; (Left true, <| state with last_exclusive_operation_was_load = false |>)]
  else [failure]

(* write_mem_ea write_kind addr sz announces a write: it records the write
   kind, the address unsigned addr and the size sz in the write_ea component
   of the state, replacing any earlier announcement.  Memory itself is not
   modified. *)
val write_mem_ea : forall 'regs 'a. Bitvector 'a => write_kind -> 'a -> integer -> M 'regs unit
let write_mem_ea write_kind addr sz state =
  let announced = Just (write_kind, unsigned addr, sz) in
  [(Left (), <| state with write_ea = announced |>)]

(* write_mem_val v writes the bytes of v (as produced by external_mem_value)
   to the addresses addr .. addr + sz - 1 of the previously announced write
   and returns true.  Calls failwith if no write has been announced.  The
   announced write kind is not used, and write_ea is left in place.
   NOTE(review): addresses and bytes are paired with List.zip, so a value
   whose byte count differs from sz is presumably truncated silently; confirm
   against the List library. *)
val write_mem_val : forall 'a 'regs 'b. Bitvector 'a => 'a -> M 'regs bool
let write_mem_val v state =
  let (_,addr,sz) = match state.write_ea with
    | Just write_ea -> write_ea
    | Nothing -> failwith "write ea has not been announced yet" end in
  let bytes_to_write = external_mem_value (bits_of v) in
  let targets = List.zip (range addr (addr+sz-1)) bytes_to_write in
  let store mem (a, b) = Map.insert a b mem in
  let memstate' = List.foldl store state.memstate targets in
  [(Left true,  <| state with memstate = memstate' |>)]

(* write_tag t stores the tag bit t at index addr / cap_alignment, where addr
   is the address of the previously announced write, and returns true.  Calls
   failwith if no write has been announced.  The announced write kind and
   size are not used. *)
val write_tag : forall 'regs. bitU -> M 'regs bool
let write_tag t state =
  let (_,addr,_) = match state.write_ea with
    | Just write_ea -> write_ea
    | Nothing -> failwith "write ea has not been announced yet" end in
  let tag_index = addr / cap_alignment in
  let tagstate' = Map.insert tag_index t state.tagstate in
  [(Left true,  <| state with tagstate = tagstate' |>)]

(* read_reg reg returns the current value of the register denoted by reg,
   read through the reference's read_from accessor.  The state is left
   unchanged. *)
val read_reg : forall 'regs 'a. register_ref 'regs 'a -> M 'regs 'a
let read_reg reg state =
  [(Left (reg.read_from state.regstate), state)]
(*let read_reg_range reg i j state =
  let v = slice (get_reg state (name_of_reg reg)) i j in
  [(Left (vec_to_bvec v),state)]
let read_reg_bit reg i state =
  let v = access (get_reg state (name_of_reg reg)) i in
  [(Left v,state)]
let read_reg_field reg regfield =
  let (i,j) = register_field_indices reg regfield in
  read_reg_range reg i j
let read_reg_bitfield reg regfield =
  let (i,_) = register_field_indices reg regfield in
  read_reg_bit reg i *)

(* reg_deref is an alias for read_reg. *)
let reg_deref = read_reg

(* write_reg reg v sets the register denoted by reg to v through the
   reference's write_to accessor and returns unit. *)
val write_reg : forall 'regs 'a. register_ref 'regs 'a -> 'a -> M 'regs unit
let write_reg reg v state =
  let regs' = reg.write_to state.regstate v in
  [(Left (), <| state with regstate = regs' |>)]

(* update_reg reg f v performs a read-modify-write of a register: it reads
   the current value of reg, applies f to that value and v, and stores the
   result back into reg. *)
val update_reg : forall 'regs 'a 'b. register_ref 'regs 'a -> ('a -> 'b -> 'a) -> 'b -> M 'regs unit
let update_reg reg f v state =
  let updated = f (get_reg state reg) v in
  [(Left (), set_reg state reg updated)]

(* write_reg_field reg regfield v updates one field of register reg: the
   field's set_field function is applied to the current register value and
   v, and the result is written back (see update_reg). *)
let write_reg_field reg regfield = update_reg reg regfield.set_field

(* update_reg_range reg i j reg_val new_val returns reg_val with the bits
   selected by i and j replaced by the bits of new_val.  The replacement is
   done by set_bits, using the direction (reg_is_inc) and start index
   (reg_start) recorded in the register reference. *)
val update_reg_range : forall 'regs 'a 'b. Bitvector 'a, Bitvector 'b => register_ref 'regs 'a -> integer -> integer -> 'a -> 'b -> 'a
let update_reg_range reg i j reg_val new_val =
  let new_bits = bits_of new_val in
  set_bits (reg.reg_is_inc) (reg.reg_start) reg_val i j new_bits
(* write_reg_range reg i j v: monadic read-modify-write of register reg that
   replaces the bit range given by i and j with v (update_reg combined with
   update_reg_range). *)
let write_reg_range reg i j = update_reg reg (update_reg_range reg i j)

(* update_reg_pos reg i reg_val x returns update_pos reg_val i x, i.e. reg_val
   with position i updated to x.  The reg argument is not used. *)
let update_reg_pos reg i reg_val x = update_pos reg_val i x
(* write_reg_pos reg i x: monadic read-modify-write of register reg that
   updates position i to x (update_reg combined with update_reg_pos). *)
let write_reg_pos reg i = update_reg reg (update_reg_pos reg i)

(* update_reg_bit reg i reg_val bit returns reg_val with bit i replaced by
   bit, converted with to_bitU.  The replacement is done by set_bit, using
   the direction and start index recorded in the register reference. *)
let update_reg_bit reg i reg_val bit =
  let new_bit = to_bitU bit in
  set_bit (reg.reg_is_inc) (reg.reg_start) reg_val i new_bit
(* write_reg_bit reg i bit: monadic read-modify-write of register reg that
   sets bit i (update_reg combined with update_reg_bit). *)
let write_reg_bit reg i = update_reg reg (update_reg_bit reg i)

(* update_reg_field_range regfield i j reg_val new_val returns reg_val with a
   bit range of one field replaced: the field is extracted with get_field,
   the bits selected by i and j are overwritten with the bits of new_val
   (using the field's own direction and start index), and the modified field
   is stored back with set_field. *)
let update_reg_field_range regfield i j reg_val new_val =
  let old_field = regfield.get_field reg_val in
  let new_bits = bits_of new_val in
  let new_field = set_bits (regfield.field_is_inc) (regfield.field_start) old_field i j new_bits in
  regfield.set_field reg_val new_field
(* write_reg_field_range reg regfield i j v: monadic read-modify-write of
   register reg that replaces the bit range given by i and j within the field
   regfield by v (update_reg combined with update_reg_field_range). *)
let write_reg_field_range reg regfield i j = update_reg reg (update_reg_field_range regfield i j)

(* update_reg_field_pos regfield i reg_val x returns reg_val with position i
   of one field updated to x: the field is extracted with get_field, changed
   with update_pos, and stored back with set_field. *)
let update_reg_field_pos regfield i reg_val x =
  let old_field = regfield.get_field reg_val in
  let new_field = update_pos old_field i x in
  regfield.set_field reg_val new_field
(* write_reg_field_pos reg regfield i x: monadic read-modify-write of
   register reg that updates position i of the field regfield to x
   (update_reg combined with update_reg_field_pos). *)
let write_reg_field_pos reg regfield i = update_reg reg (update_reg_field_pos regfield i)

(* update_reg_field_bit regfield i reg_val bit returns reg_val with bit i of
   one field replaced by bit (converted with to_bitU): the field is extracted
   with get_field, changed with set_bit using the field's own direction and
   start index, and stored back with set_field. *)
let update_reg_field_bit regfield i reg_val bit =
  let old_field = regfield.get_field reg_val in
  let new_bit = to_bitU bit in
  let new_field = set_bit (regfield.field_is_inc) (regfield.field_start) old_field i new_bit in
  regfield.set_field reg_val new_field
(* write_reg_field_bit reg regfield i bit: monadic read-modify-write of
   register reg that sets bit i of the field regfield (update_reg combined
   with update_reg_field_bit). *)
let write_reg_field_bit reg regfield i = update_reg reg (update_reg_field_bit regfield i)

(* barrier k has no effect in this sequential model: the barrier kind is
   ignored, the state is unchanged and unit is returned. *)
val barrier : forall 'regs. barrier_kind -> M 'regs unit
let barrier _ s = [(Left (), s)]

(* footprint has no effect in this sequential model: the state is unchanged
   and unit is returned. *)
val footprint : forall 'regs. M 'regs unit
let footprint s = [(Left (), s)]


(* foreachM_inc (i,stop,by) vars body is a monadic counting loop.  Starting
   with counter i, it runs body on the counter and the loop variables, adds
   by to the counter, and repeats with the variables returned by body.  It
   continues while (by > 0 and i <= stop) or (by < 0 and stop <= i), and
   returns the final loop variables.  With by = 0 the body never runs. *)
val foreachM_inc : forall 'regs 'vars 'e. (integer * integer * integer) -> 'vars ->
                  (integer -> 'vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
let rec foreachM_inc (i,stop,by) vars body =
  let in_range = (by > 0 && i <= stop) || (by < 0 && stop <= i) in
  if not in_range then return vars
  else
    bind (body i vars)
         (fun vars' -> foreachM_inc (i + by,stop,by) vars' body)


(* foreachM_dec (stop,i,by) vars body is a monadic counting loop.  Note that
   the counter i is the SECOND component of the tuple.  It runs body on the
   counter and the loop variables, subtracts by from the counter, and repeats
   with the variables returned by body.  It continues while
   (by > 0 and i >= stop) or (by < 0 and stop >= i), and returns the final
   loop variables.  With by = 0 the body never runs.
   NOTE(review): the tuple order differs from foreachM_inc, which takes
   (i,stop,by); confirm that callers pass the bounds in this order. *)
val foreachM_dec : forall 'regs 'vars 'e. (integer * integer * integer) -> 'vars ->
                  (integer -> 'vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
let rec foreachM_dec (stop,i,by) vars body =
  let in_range = (by > 0 && i >= stop) || (by < 0 && stop >= i) in
  if not in_range then return vars
  else
    bind (body i vars)
         (fun vars' -> foreachM_dec (stop,i - by,by) vars' body)

(* while_PP vars cond body is a pure while loop: as long as cond holds for the
   loop variables, body is applied to them.  The condition is tested before
   each iteration, so body may run zero times.  Returns the final variables. *)
val while_PP : forall 'vars. 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'vars
let rec while_PP vars cond body =
  if not (cond vars) then vars
  else while_PP (body vars) cond body

(* while_PM vars cond body is a while loop with a pure condition and a monadic
   body.  The condition is tested before each iteration; the loop variables
   returned by body are used for the next test.  Returns the final
   variables. *)
val while_PM : forall 'regs 'vars 'e. 'vars -> ('vars -> bool) ->
                ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
let rec while_PM vars cond body =
  if not (cond vars) then return vars
  else bind (body vars) (fun vars' -> while_PM vars' cond body)

(* while_MP vars cond body is a while loop with a monadic condition and a pure
   body.  The condition is run before each iteration; when it yields false
   the current loop variables are returned. *)
val while_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
                ('vars -> 'vars) -> ME 'regs 'vars 'e
let rec while_MP vars cond body =
  bind (cond vars) (fun keep_going ->
    if keep_going
    then while_MP (body vars) cond body
    else return vars)

(* while_MM vars cond body is a while loop whose condition and body are both
   monadic.  The condition is run before each iteration; when it yields false
   the current loop variables are returned, otherwise body is run and the
   loop continues with the variables it returns. *)
val while_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
                ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
let rec while_MM vars cond body =
  bind (cond vars) (fun keep_going ->
    if keep_going
    then bind (body vars) (fun vars' -> while_MM vars' cond body)
    else return vars)

(* until_PP vars cond body is a pure repeat-until loop: body is applied to the
   loop variables, then cond is tested on the result.  The loop stops and
   returns the variables as soon as cond holds, so body runs at least once
   and exactly once per test of cond. *)
val until_PP : forall 'vars. 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'vars
let rec until_PP vars cond body =
  let vars = body vars in
  (* vars already holds the result of this iteration's body.  Recurse on it
     directly: applying body again here would run the body twice per
     iteration and skip the condition check on every other state. *)
  if (cond vars) then vars else until_PP vars cond body

(* until_PM vars cond body is a repeat-until loop with a pure condition and a
   monadic body.  The body is run first; the loop returns the resulting
   variables once cond holds for them and otherwise repeats with them. *)
val until_PM : forall 'regs 'vars 'e. 'vars -> ('vars -> bool) ->
                ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
let rec until_PM vars cond body =
  bind (body vars) (fun vars' ->
    if (cond vars')
    then return vars'
    else until_PM vars' cond body)

(* until_MP vars cond body is a repeat-until loop with a monadic condition and
   a pure body.  The body is applied first; the loop returns the resulting
   variables once the condition yields true for them and otherwise repeats
   with them. *)
val until_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
                ('vars -> 'vars) -> ME 'regs 'vars 'e
let rec until_MP vars cond body =
  let vars' = body vars in
  bind (cond vars') (fun finished ->
    if finished
    then return vars'
    else until_MP vars' cond body)

(* until_MM vars cond body is a repeat-until loop whose condition and body are
   both monadic.  The body is run first, then the condition on the resulting
   variables; the loop returns those variables once the condition yields true
   and otherwise repeats with them. *)
val until_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
                ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
let rec until_MM vars cond body =
  bind (body vars) (fun vars' ->
    bind (cond vars') (fun finished ->
      if finished
      then return vars'
      else until_MM vars' cond body))

(*let write_two_regs r1 r2 bvec state =
  let vec = bvec_to_vec bvec in
  let is_inc =
    let is_inc_r1 = is_inc_of_reg r1 in
    let is_inc_r2 = is_inc_of_reg r2 in
    let () = ensure (is_inc_r1 = is_inc_r2)
                    "write_two_regs called with vectors of different direction" in
    is_inc_r1 in

  let (size_r1 : integer) = size_of_reg r1 in
  let (start_vec : integer) = get_start vec in
  let size_vec = length vec in
  let r1_v =
    if is_inc
    then slice vec start_vec (size_r1 - start_vec - 1)
    else slice vec start_vec (start_vec - size_r1 - 1) in
  let r2_v =
    if is_inc
    then slice vec (size_r1 - start_vec) (size_vec - start_vec)
    else slice vec (start_vec - size_r1) (start_vec - size_vec) in
  let state1 = set_reg state (name_of_reg r1) r1_v in
  let state2 = set_reg state1 (name_of_reg r2) r2_v in
  [(Left (), state2)]*)