summaryrefslogtreecommitdiff
path: root/src/gen_lib/state.lem
diff options
context:
space:
mode:
Diffstat (limited to 'src/gen_lib/state.lem')
-rw-r--r--src/gen_lib/state.lem252
1 files changed, 195 insertions, 57 deletions
diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem
index 88e29522..2fff7344 100644
--- a/src/gen_lib/state.lem
+++ b/src/gen_lib/state.lem
@@ -6,45 +6,77 @@ open import Sail_values
type memstate = map integer memory_byte
type tagstate = map integer bitU
-type regstate = map string (vector bitU)
-
-type sequential_state = <| regstate : regstate;
- memstate : memstate;
- tagstate : tagstate;
- write_ea : maybe (write_kind * integer * integer);
- last_exclusive_operation_was_load : bool|>
-
-type M 'a = sequential_state -> list ((either 'a string) * sequential_state)
-
-val return : forall 'a. 'a -> M 'a
+(* type regstate = map string (vector bitU) *)
+
+type sequential_state 'regs =
+ <| regstate : 'regs;
+ memstate : memstate;
+ tagstate : tagstate;
+ write_ea : maybe (write_kind * integer * integer);
+ last_exclusive_operation_was_load : bool|>
+
+(* State, nondeterminism and exception monad with result type 'a
+ and exception type 'e. *)
+type ME 'regs 'a 'e = sequential_state 'regs -> list ((either 'a 'e) * sequential_state 'regs)
+
+(* By default, we use strings to distinguish between different types of exceptions *)
+type M 'regs 'a = ME 'regs 'a string
+
+(* For early return, we abuse exceptions by throwing and catching
+ the return value. The exception type is "either 'r string", where "Right e"
+ represents a proper exception and "Left r" an early return of value "r". *)
+type MR 'regs 'a 'r = ME 'regs 'a (either 'r string)
+
+val liftR : forall 'a 'r 'regs. M 'regs 'a -> MR 'regs 'a 'r
+let liftR m s = List.map (function
+ | (Left a, s') -> (Left a, s')
+ | (Right e, s') -> (Right (Right e), s')
+ end) (m s)
+
+val return : forall 'regs 'a 'e. 'a -> ME 'regs 'a 'e
let return a s = [(Left a,s)]
-val bind : forall 'a 'b. M 'a -> ('a -> M 'b) -> M 'b
-let bind m f (s : sequential_state) =
+val bind : forall 'regs 'a 'b 'e. ME 'regs 'a 'e -> ('a -> ME 'regs 'b 'e) -> ME 'regs 'b 'e
+let bind m f (s : sequential_state 'regs) =
List.concatMap (function
| (Left a, s') -> f a s'
| (Right e, s') -> [(Right e, s')]
end) (m s)
let inline (>>=) = bind
-val (>>): forall 'b. M unit -> M 'b -> M 'b
+val (>>): forall 'regs 'b 'e. ME 'regs unit 'e -> ME 'regs 'b 'e -> ME 'regs 'b 'e
let inline (>>) m n = m >>= fun _ -> n
-val exit : forall 'e 'a. 'e -> M 'a
-let exit _ s = [(Right "exit",s)]
+val exit : forall 'regs 'e 'a. 'e -> M 'regs 'a
+let exit _ s = [(Right "exit", s)]
+val assert_exp : forall 'regs. bool -> string -> M 'regs unit
+let assert_exp exp msg s = if exp then [(Left (), s)] else [(Right msg, s)]
+
+val early_return : forall 'regs 'a 'r. 'r -> MR 'regs 'a 'r
+let early_return r s = [(Right (Left r), s)]
+
+val catch_early_return : forall 'regs 'a. MR 'regs 'a 'a -> M 'regs 'a
+let catch_early_return m s =
+ List.map
+ (function
+ | (Right (Left a), s') -> (Left a, s')
+ | (Right (Right e), s') -> (Right e, s')
+ | (Left a, s') -> (Left a, s')
+ end) (m s)
val range : integer -> integer -> list integer
-let rec range i j =
- if i = j then [i]
+let rec range i j =
+ if j < i then []
+ else if i = j then [i]
else i :: range (i+1) j
-val get_reg : sequential_state -> string -> vector bitU
-let get_reg state reg = Map_extra.find reg state.regstate
+val get_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a
+let get_reg state reg = reg.read_from state.regstate
-val set_reg : sequential_state -> string -> vector bitU -> sequential_state
-let set_reg state reg bitv =
- <| state with regstate = Map.insert reg bitv state.regstate |>
+val set_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a -> sequential_state 'regs
+let set_reg state reg v =
+ <| state with regstate = reg.write_to state.regstate v |>
let is_exclusive = function
@@ -65,10 +97,18 @@ end
val read_mem : bool -> read_kind -> vector bitU -> integer -> M (vector bitU)
let read_mem dir read_kind addr sz state =
- let addr = integer_of_address (address_of_bitv addr) in
+ let addr = unsigned addr in
let addrs = range addr (addr+sz-1) in
let memory_value = List.map (fun addr -> Map_extra.find addr state.memstate) addrs in
- let value = Sail_values.internal_mem_value dir memory_value in
+ let value = of_bits (Sail_values.internal_mem_value dir memory_value) in
+ let is_exclusive = match read_kind with
+ | Sail_impl_base.Read_plain -> false
+ | Sail_impl_base.Read_reserve -> true
+ | Sail_impl_base.Read_acquire -> false
+ | Sail_impl_base.Read_exclusive -> true
+ | Sail_impl_base.Read_exclusive_acquire -> true
+ | Sail_impl_base.Read_stream -> false
+ end in
if is_exclusive read_kind
then [(Left value, <| state with last_exclusive_operation_was_load = true |>)]
@@ -77,42 +117,50 @@ let read_mem dir read_kind addr sz state =
(* caps are aligned at 32 bytes *)
let cap_alignment = (32 : integer)
-val read_tag : bool -> read_kind -> vector bitU -> M bitU
+val read_tag : forall 'regs 'a. Bitvector 'a => bool -> read_kind -> 'a -> M 'regs bitU
let read_tag dir read_kind addr state =
- let addr = (integer_of_address (address_of_bitv addr)) / cap_alignment in
+ let addr = (unsigned addr) / cap_alignment in
let tag = match (Map.lookup addr state.tagstate) with
| Just t -> t
| Nothing -> B0
end in
+ let is_exclusive = match read_kind with
+ | Sail_impl_base.Read_plain -> false
+ | Sail_impl_base.Read_reserve -> true
+ | Sail_impl_base.Read_acquire -> false
+ | Sail_impl_base.Read_exclusive -> true
+ | Sail_impl_base.Read_exclusive_acquire -> true
+ | Sail_impl_base.Read_stream -> false
+ end in
+
(* TODO Should reading a tag set the exclusive flag? *)
if is_exclusive read_kind
then [(Left tag, <| state with last_exclusive_operation_was_load = true |>)]
else [(Left tag, state)]
-val excl_result : unit -> M bool
+val excl_result : forall 'regs. unit -> M 'regs bool
let excl_result () state =
let success =
(Left true, <| state with last_exclusive_operation_was_load = false |>) in
(Left false, state) :: if state.last_exclusive_operation_was_load then [success] else []
-val write_mem_ea : write_kind -> vector bitU -> integer -> M unit
+val write_mem_ea : forall 'regs 'a. Bitvector 'a => write_kind -> 'a -> integer -> M 'regs unit
let write_mem_ea write_kind addr sz state =
- let addr = integer_of_address (address_of_bitv addr) in
- [(Left (), <| state with write_ea = Just (write_kind,addr,sz) |>)]
+ [(Left (), <| state with write_ea = Just (write_kind,unsigned addr,sz) |>)]
-val write_mem_val : vector bitU -> M bool
+val write_mem_val : forall 'a 'regs 'b. Bitvector 'a => 'a -> M 'regs bool
let write_mem_val v state =
let (write_kind,addr,sz) = match state.write_ea with
| Nothing -> failwith "write ea has not been announced yet"
| Just write_ea -> write_ea end in
let addrs = range addr (addr+sz-1) in
- let v = external_mem_value v in
+ let v = external_mem_value (bits_of v) in
let addresses_with_value = List.zip addrs v in
let memstate = List.foldl (fun mem (addr,v) -> Map.insert addr v mem)
state.memstate addresses_with_value in
[(Left true, <| state with memstate = memstate |>)]
-val write_tag : bitU -> M bool
+val write_tag : forall 'regs. bitU -> M 'regs bool
let write_tag t state =
let (write_kind,addr,sz) = match state.write_ea with
| Nothing -> failwith "write ea has not been announced yet"
@@ -121,24 +169,26 @@ let write_tag t state =
let tagstate = Map.insert taddr t state.tagstate in
[(Left true, <| state with tagstate = tagstate |>)]
-val read_reg : register -> M (vector bitU)
+val read_reg : forall 'regs 'a. register_ref 'regs 'a -> M 'regs 'a
let read_reg reg state =
- let v = Map_extra.find (name_of_reg reg) state.regstate in
+ let v = reg.read_from state.regstate in
+ [(Left v,state)]
+(*let read_reg_range reg i j state =
+ let v = slice (get_reg state (name_of_reg reg)) i j in
+ [(Left (vec_to_bvec v),state)]
+let read_reg_bit reg i state =
+ let v = access (get_reg state (name_of_reg reg)) i in
[(Left v,state)]
-let read_reg_range reg i j =
- read_reg reg >>= fun rv ->
- return (slice rv i j)
-let read_reg_bit reg i =
- read_reg_range reg i i >>= fun v ->
- return (extract_only_bit v)
let read_reg_field reg regfield =
let (i,j) = register_field_indices reg regfield in
read_reg_range reg i j
let read_reg_bitfield reg regfield =
let (i,_) = register_field_indices reg regfield in
- read_reg_bit reg i
+ read_reg_bit reg i *)
-val write_reg : register -> vector bitU -> M unit
+let reg_deref = read_reg
+
+val write_reg : forall 'regs 'a. register_ref 'regs 'a -> 'a -> M 'regs unit
let write_reg reg v state =
[(Left (),<| state with regstate = Map.insert (name_of_reg reg) v state.regstate |>)]
let write_reg_range reg i j v =
@@ -158,34 +208,120 @@ let write_reg_field_range reg regfield i j v =
let new_field_value = update current_field_value i j v in
write_reg_field reg regfield new_field_value
+val update_reg : forall 'regs 'a 'b. register_ref 'regs 'a -> ('a -> 'b -> 'a) -> 'b -> M 'regs unit
+let update_reg reg f v state =
+ let current_value = get_reg state reg in
+ let new_value = f current_value v in
+ [(Left (), set_reg state reg new_value)]
+
+let write_reg_field reg regfield = update_reg reg regfield.set_field
-val barrier : barrier_kind -> M unit
+val update_reg_range : forall 'regs 'a 'b. Bitvector 'a, Bitvector 'b => register_ref 'regs 'a -> integer -> integer -> 'a -> 'b -> 'a
+let update_reg_range reg i j reg_val new_val = set_bits (reg.reg_is_inc) (reg.reg_start) reg_val i j (bits_of new_val)
+let write_reg_range reg i j = update_reg reg (update_reg_range reg i j)
+
+let update_reg_pos reg i reg_val x = update_pos reg_val i x
+let write_reg_pos reg i = update_reg reg (update_reg_pos reg i)
+
+let update_reg_bit reg i reg_val bit = set_bit (reg.reg_is_inc) (reg.reg_start) reg_val i (to_bitU bit)
+let write_reg_bit reg i = update_reg reg (update_reg_bit reg i)
+
+let update_reg_field_range regfield i j reg_val new_val =
+ let current_field_value = regfield.get_field reg_val in
+ let new_field_value = set_bits (regfield.field_is_inc) (regfield.field_start) current_field_value i j (bits_of new_val) in
+ regfield.set_field reg_val new_field_value
+let write_reg_field_range reg regfield i j = update_reg reg (update_reg_field_range regfield i j)
+
+let update_reg_field_pos regfield i reg_val x =
+ let current_field_value = regfield.get_field reg_val in
+ let new_field_value = update_pos current_field_value i x in
+ regfield.set_field reg_val new_field_value
+let write_reg_field_pos reg regfield i = update_reg reg (update_reg_field_pos regfield i)
+
+let update_reg_field_bit regfield i reg_val bit =
+ let current_field_value = regfield.get_field reg_val in
+ let new_field_value = set_bit (regfield.field_is_inc) (regfield.field_start) current_field_value i (to_bitU bit) in
+ regfield.set_field reg_val new_field_value
+let write_reg_field_bit reg regfield i = update_reg reg (update_reg_field_bit regfield i)
+
+val barrier : forall 'regs. barrier_kind -> M 'regs unit
let barrier _ = return ()
-val footprint : M unit
-let footprint = return ()
+val footprint : forall 'regs. M 'regs unit
+let footprint s = return () s
-val foreachM_inc : forall 'vars. (integer * integer * integer) -> 'vars ->
- (integer -> 'vars -> M 'vars) -> M 'vars
+val foreachM_inc : forall 'regs 'vars 'e. (integer * integer * integer) -> 'vars ->
+ (integer -> 'vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
let rec foreachM_inc (i,stop,by) vars body =
- if i <= stop
+ if (by > 0 && i <= stop) || (by < 0 && stop <= i)
then
body i vars >>= fun vars ->
foreachM_inc (i + by,stop,by) vars body
else return vars
-val foreachM_dec : forall 'vars. (integer * integer * integer) -> 'vars ->
- (integer -> 'vars -> M 'vars) -> M 'vars
-let rec foreachM_dec (i,stop,by) vars body =
- if i >= stop
+val foreachM_dec : forall 'regs 'vars 'e. (integer * integer * integer) -> 'vars ->
+ (integer -> 'vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+let rec foreachM_dec (stop,i,by) vars body =
+ if (by > 0 && i >= stop) || (by < 0 && stop >= i)
then
body i vars >>= fun vars ->
- foreachM_dec (i - by,stop,by) vars body
+ foreachM_dec (stop,i - by,by) vars body
+ else return vars
+
+val while_PP : forall 'vars. 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'vars
+let rec while_PP vars cond body =
+ if cond vars then while_PP (body vars) cond body else vars
+
+val while_PM : forall 'regs 'vars 'e. 'vars -> ('vars -> bool) ->
+ ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+let rec while_PM vars cond body =
+ if cond vars then
+ body vars >>= fun vars -> while_PM vars cond body
+ else return vars
+
+val while_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
+ ('vars -> 'vars) -> ME 'regs 'vars 'e
+let rec while_MP vars cond body =
+ cond vars >>= fun cond_val ->
+ if cond_val then while_MP (body vars) cond body else return vars
+
+val while_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
+ ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+let rec while_MM vars cond body =
+ cond vars >>= fun cond_val ->
+ if cond_val then
+ body vars >>= fun vars -> while_MM vars cond body
else return vars
-let write_two_regs r1 r2 vec =
+val until_PP : forall 'vars. 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'vars
+let rec until_PP vars cond body =
+ let vars = body vars in
+ if (cond vars) then vars else until_PP (body vars) cond body
+
+val until_PM : forall 'regs 'vars 'e. 'vars -> ('vars -> bool) ->
+ ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+let rec until_PM vars cond body =
+ body vars >>= fun vars ->
+ if (cond vars) then return vars else until_PM vars cond body
+
+val until_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
+ ('vars -> 'vars) -> ME 'regs 'vars 'e
+let rec until_MP vars cond body =
+ let vars = body vars in
+ cond vars >>= fun cond_val ->
+ if cond_val then return vars else until_MP vars cond body
+
+val until_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
+ ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+let rec until_MM vars cond body =
+ body vars >>= fun vars ->
+ cond vars >>= fun cond_val ->
+ if cond_val then return vars else until_MM vars cond body
+
+(*let write_two_regs r1 r2 bvec state =
+ let vec = bvec_to_vec bvec in
let is_inc =
let is_inc_r1 = is_inc_of_reg r1 in
let is_inc_r2 = is_inc_of_reg r2 in
@@ -204,4 +340,6 @@ let write_two_regs r1 r2 vec =
if is_inc
then slice vec (size_r1 - start_vec) (size_vec - start_vec)
else slice vec (start_vec - size_r1) (start_vec - size_vec) in
- write_reg r1 r1_v >> write_reg r2 r2_v
+ let state1 = set_reg state (name_of_reg r1) r1_v in
+ let state2 = set_reg state1 (name_of_reg r2) r2_v in
+ [(Left (), state2)]*)