summaryrefslogtreecommitdiff
path: root/src/gen_lib/state_monad.lem
diff options
context:
space:
mode:
authorThomas Bauereiss2018-01-30 18:15:29 +0000
committerThomas Bauereiss2018-01-31 12:49:20 +0000
commit3cad2ad60f5f5f05ef94ba38590539939d3ccda0 (patch)
treeaa58990cbecf5a0367990039321c0d7672dbddce /src/gen_lib/state_monad.lem
parent0beb4619c72f2e1cc2123e278c1ed7744e350899 (diff)
Split base definitions of Lem monads and further built-ins (e.g. loop combinators)
Add Isabelle-specific theories imported directly after monad definitions, but before other combinators. These theories contain lemmas that tell the function package how to deal with monadic binds in function definitions.
Diffstat (limited to 'src/gen_lib/state_monad.lem')
-rw-r--r--src/gen_lib/state_monad.lem250
1 files changed, 250 insertions, 0 deletions
diff --git a/src/gen_lib/state_monad.lem b/src/gen_lib/state_monad.lem
new file mode 100644
index 00000000..2d8e412e
--- /dev/null
+++ b/src/gen_lib/state_monad.lem
@@ -0,0 +1,250 @@
+open import Pervasives_extra
+open import Sail_impl_base
+open import Sail_values
+
+(* 'a is result type *)
+
+type memstate = map integer memory_byte
+type tagstate = map integer bitU
+(* type regstate = map string (vector bitU) *)
+
+type sequential_state 'regs =
+ <| regstate : 'regs;
+ memstate : memstate;
+ tagstate : tagstate;
+ write_ea : maybe (write_kind * integer * integer);
+ last_exclusive_operation_was_load : bool|>
+
+val init_state : forall 'regs. 'regs -> sequential_state 'regs
+let init_state regs =
+ <| regstate = regs;
+ memstate = Map.empty;
+ tagstate = Map.empty;
+ write_ea = Nothing;
+ last_exclusive_operation_was_load = false |>
+
+type ex 'e =
+ | Exit
+ | Assert of string
+ | Throw of 'e
+
+type result 'a 'e =
+ | Value of 'a
+ | Exception of (ex 'e)
+
+(* State, nondeterminism and exception monad with result value type 'a
+ and exception type 'e. *)
+type M 'regs 'a 'e = sequential_state 'regs -> list (result 'a 'e * sequential_state 'regs)
+
+val return : forall 'regs 'a 'e. 'a -> M 'regs 'a 'e
+let return a s = [(Value a,s)]
+
+val bind : forall 'regs 'a 'b 'e. M 'regs 'a 'e -> ('a -> M 'regs 'b 'e) -> M 'regs 'b 'e
+let bind m f (s : sequential_state 'regs) =
+ List.concatMap (function
+ | (Value a, s') -> f a s'
+ | (Exception e, s') -> [(Exception e, s')]
+ end) (m s)
+
+let inline (>>=) = bind
+val (>>): forall 'regs 'b 'e. M 'regs unit 'e -> M 'regs 'b 'e -> M 'regs 'b 'e
+let inline (>>) m n = m >>= fun (_ : unit) -> n
+
+val throw : forall 'regs 'a 'e. 'e -> M 'regs 'a 'e
+let throw e s = [(Exception (Throw e), s)]
+
+val try_catch : forall 'regs 'a 'e1 'e2. M 'regs 'a 'e1 -> ('e1 -> M 'regs 'a 'e2) -> M 'regs 'a 'e2
+let try_catch m h s =
+ List.concatMap (function
+ | (Value a, s') -> return a s'
+ | (Exception (Throw e), s') -> h e s'
+ | (Exception Exit, s') -> [(Exception Exit, s')]
+ | (Exception (Assert msg), s') -> [(Exception (Assert msg), s')]
+ end) (m s)
+
+val exit : forall 'regs 'e 'a. unit -> M 'regs 'a 'e
+let exit () s = [(Exception Exit, s)]
+
+val assert_exp : forall 'regs 'e. bool -> string -> M 'regs unit 'e
+let assert_exp exp msg s = if exp then [(Value (), s)] else [(Exception (Assert msg), s)]
+
+(* For early return, we abuse exceptions by throwing and catching
+ the return value. The exception type is "either 'r 'e", where "Right e"
+ represents a proper exception and "Left r" an early return of value "r". *)
+type MR 'regs 'a 'r 'e = M 'regs 'a (either 'r 'e)
+
+val early_return : forall 'regs 'a 'r 'e. 'r -> MR 'regs 'a 'r 'e
+let early_return r = throw (Left r)
+
+val catch_early_return : forall 'regs 'a 'e. MR 'regs 'a 'a 'e -> M 'regs 'a 'e
+let catch_early_return m =
+ try_catch m
+ (function
+ | Left a -> return a
+ | Right e -> throw e
+ end)
+
+(* Lift to monad with early return by wrapping exceptions *)
+val liftR : forall 'a 'r 'regs 'e. M 'regs 'a 'e -> MR 'regs 'a 'r 'e
+let liftR m = try_catch m (fun e -> throw (Right e))
+
+(* Catch exceptions in the presence of early returns *)
+val try_catchR : forall 'regs 'a 'r 'e1 'e2. MR 'regs 'a 'r 'e1 -> ('e1 -> MR 'regs 'a 'r 'e2) -> MR 'regs 'a 'r 'e2
+let try_catchR m h =
+ try_catch m
+ (function
+ | Left r -> throw (Left r)
+ | Right e -> h e
+ end)
+
+val range : integer -> integer -> list integer
+let rec range i j =
+ if j < i then []
+ else if i = j then [i]
+ else i :: range (i+1) j
+
+val get_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a
+let get_reg state reg = reg.read_from state.regstate
+
+val set_reg : forall 'regs 'a. sequential_state 'regs -> register_ref 'regs 'a -> 'a -> sequential_state 'regs
+let set_reg state reg v =
+ <| state with regstate = reg.write_to state.regstate v |>
+
+
+let is_exclusive = function
+ | Sail_impl_base.Read_plain -> false
+ | Sail_impl_base.Read_reserve -> true
+ | Sail_impl_base.Read_acquire -> false
+ | Sail_impl_base.Read_exclusive -> true
+ | Sail_impl_base.Read_exclusive_acquire -> true
+ | Sail_impl_base.Read_stream -> false
+ | Sail_impl_base.Read_RISCV_acquire -> false
+ | Sail_impl_base.Read_RISCV_strong_acquire -> false
+ | Sail_impl_base.Read_RISCV_reserved -> true
+ | Sail_impl_base.Read_RISCV_reserved_acquire -> true
+ | Sail_impl_base.Read_RISCV_reserved_strong_acquire -> true
+ | Sail_impl_base.Read_X86_locked -> true
+end
+
+
+val read_mem : forall 'regs 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> M 'regs 'b 'e
+let read_mem read_kind addr sz state =
+ let addr = unsigned addr in
+ let addrs = range addr (addr+sz-1) in
+ let memory_value = List.map (fun addr -> Map_extra.find addr state.memstate) addrs in
+ let value = of_bits (Sail_values.internal_mem_value memory_value) in
+ if is_exclusive read_kind
+ then [(Value value, <| state with last_exclusive_operation_was_load = true |>)]
+ else [(Value value, state)]
+
+(* caps are aligned at 32 bytes *)
+let cap_alignment = (32 : integer)
+
+val read_tag : forall 'regs 'a 'e. Bitvector 'a => read_kind -> 'a -> M 'regs bitU 'e
+let read_tag read_kind addr state =
+ let addr = (unsigned addr) / cap_alignment in
+ let tag = match (Map.lookup addr state.tagstate) with
+ | Just t -> t
+ | Nothing -> B0
+ end in
+ if is_exclusive read_kind
+ then [(Value tag, <| state with last_exclusive_operation_was_load = true |>)]
+ else [(Value tag, state)]
+
+val excl_result : forall 'regs 'e. unit -> M 'regs bool 'e
+let excl_result () state =
+ let success =
+ (Value true, <| state with last_exclusive_operation_was_load = false |>) in
+ (Value false, state) :: if state.last_exclusive_operation_was_load then [success] else []
+
+val write_mem_ea : forall 'regs 'a 'e. Bitvector 'a => write_kind -> 'a -> integer -> M 'regs unit 'e
+let write_mem_ea write_kind addr sz state =
+ [(Value (), <| state with write_ea = Just (write_kind,unsigned addr,sz) |>)]
+
+val write_mem_val : forall 'a 'regs 'b 'e. Bitvector 'a => 'a -> M 'regs bool 'e
+let write_mem_val v state =
+ let (_,addr,sz) = match state.write_ea with
+ | Nothing -> failwith "write ea has not been announced yet"
+ | Just write_ea -> write_ea end in
+ let addrs = range addr (addr+sz-1) in
+ let v = external_mem_value (bits_of v) in
+ let addresses_with_value = List.zip addrs v in
+ let memstate = List.foldl (fun mem (addr,v) -> Map.insert addr v mem)
+ state.memstate addresses_with_value in
+ [(Value true, <| state with memstate = memstate |>)]
+
+val write_tag : forall 'regs 'e. bitU -> M 'regs bool 'e
+let write_tag t state =
+ let (_,addr,_) = match state.write_ea with
+ | Nothing -> failwith "write ea has not been announced yet"
+ | Just write_ea -> write_ea end in
+ let taddr = addr / cap_alignment in
+ let tagstate = Map.insert taddr t state.tagstate in
+ [(Value true, <| state with tagstate = tagstate |>)]
+
+val read_reg : forall 'regs 'a 'e. register_ref 'regs 'a -> M 'regs 'a 'e
+let read_reg reg state =
+ let v = reg.read_from state.regstate in
+ [(Value v,state)]
+(*let read_reg_range reg i j state =
+ let v = slice (get_reg state (name_of_reg reg)) i j in
+ [(Value (vec_to_bvec v),state)]
+let read_reg_bit reg i state =
+ let v = access (get_reg state (name_of_reg reg)) i in
+ [(Value v,state)]
+let read_reg_field reg regfield =
+ let (i,j) = register_field_indices reg regfield in
+ read_reg_range reg i j
+let read_reg_bitfield reg regfield =
+ let (i,_) = register_field_indices reg regfield in
+ read_reg_bit reg i *)
+
+let reg_deref = read_reg
+
+val write_reg : forall 'regs 'a 'e. register_ref 'regs 'a -> 'a -> M 'regs unit 'e
+let write_reg reg v state =
+ [(Value (), <| state with regstate = reg.write_to state.regstate v |>)]
+
+let write_reg_ref (reg, v) = write_reg reg v
+
+val update_reg : forall 'regs 'a 'b 'e. register_ref 'regs 'a -> ('a -> 'b -> 'a) -> 'b -> M 'regs unit 'e
+let update_reg reg f v state =
+ let current_value = get_reg state reg in
+ let new_value = f current_value v in
+ [(Value (), set_reg state reg new_value)]
+
+let write_reg_field reg regfield = update_reg reg regfield.set_field
+
+val update_reg_range : forall 'regs 'a 'b. Bitvector 'a, Bitvector 'b => register_ref 'regs 'a -> integer -> integer -> 'a -> 'b -> 'a
+let update_reg_range reg i j reg_val new_val = set_bits (reg.reg_is_inc) reg_val i j (bits_of new_val)
+let write_reg_range reg i j = update_reg reg (update_reg_range reg i j)
+
+let update_reg_pos reg i reg_val x = update_list reg.reg_is_inc reg_val i x
+let write_reg_pos reg i = update_reg reg (update_reg_pos reg i)
+
+let update_reg_bit reg i reg_val bit = set_bit (reg.reg_is_inc) reg_val i (to_bitU bit)
+let write_reg_bit reg i = update_reg reg (update_reg_bit reg i)
+
+let update_reg_field_range regfield i j reg_val new_val =
+ let current_field_value = regfield.get_field reg_val in
+ let new_field_value = set_bits (regfield.field_is_inc) current_field_value i j (bits_of new_val) in
+ regfield.set_field reg_val new_field_value
+let write_reg_field_range reg regfield i j = update_reg reg (update_reg_field_range regfield i j)
+
+let update_reg_field_pos regfield i reg_val x =
+ let current_field_value = regfield.get_field reg_val in
+ let new_field_value = update_list regfield.field_is_inc current_field_value i x in
+ regfield.set_field reg_val new_field_value
+let write_reg_field_pos reg regfield i = update_reg reg (update_reg_field_pos regfield i)
+
+let update_reg_field_bit regfield i reg_val bit =
+ let current_field_value = regfield.get_field reg_val in
+ let new_field_value = set_bit (regfield.field_is_inc) current_field_value i (to_bitU bit) in
+ regfield.set_field reg_val new_field_value
+let write_reg_field_bit reg regfield i = update_reg reg (update_reg_field_bit regfield i)
+
+val barrier : forall 'regs 'e. barrier_kind -> M 'regs unit 'e
+let barrier _ = return ()
+
+val footprint : forall 'regs 'e. M 'regs unit 'e
+let footprint s = return () s