diff options
Diffstat (limited to 'src/gen_lib')
| -rw-r--r-- | src/gen_lib/sail2_prompt_monad.lem | 82 | ||||
| -rw-r--r-- | src/gen_lib/sail2_state_lifting.lem | 36 | ||||
| -rw-r--r-- | src/gen_lib/sail2_state_monad.lem | 20 |
3 files changed, 86 insertions, 52 deletions
diff --git a/src/gen_lib/sail2_prompt_monad.lem b/src/gen_lib/sail2_prompt_monad.lem index 7503ca22..079375a3 100644 --- a/src/gen_lib/sail2_prompt_monad.lem +++ b/src/gen_lib/sail2_prompt_monad.lem @@ -9,9 +9,10 @@ type address = list bitU type monad 'regval 'a 'e = | Done of 'a (* Read a number of bytes from memory, returned in little endian order, - together with a tag. The first nat specifies the address, the second + with or without a tag. The first nat specifies the address, the second the number of bytes. *) - | Read_mem of read_kind * nat * nat * ((list memory_byte * bitU) -> monad 'regval 'a 'e) + | Read_mem of read_kind * nat * nat * (list memory_byte -> monad 'regval 'a 'e) + | Read_tagged_mem of read_kind * nat * nat * ((list memory_byte * bitU) -> monad 'regval 'a 'e) (* Tell the system a write is imminent, at the given address and with the given size. *) | Write_ea of write_kind * nat * nat * monad 'regval 'a 'e @@ -38,7 +39,8 @@ type monad 'regval 'a 'e = | Exception of 'e type event 'regval = - | E_read_mem of read_kind * nat * nat * (list memory_byte * bitU) + | E_read_mem of read_kind * nat * nat * list memory_byte + | E_read_tagged_mem of read_kind * nat * nat * (list memory_byte * bitU) | E_write_mem of write_kind * nat * nat * list memory_byte * bitU * bool | E_write_ea of write_kind * nat * nat | E_excl_res of bool @@ -57,18 +59,19 @@ let return a = Done a val bind : forall 'rv 'a 'b 'e. monad 'rv 'a 'e -> ('a -> monad 'rv 'b 'e) -> monad 'rv 'b 'e let rec bind m f = match m with | Done a -> f a - | Read_mem rk a sz k -> Read_mem rk a sz (fun v -> bind (k v) f) - | Write_mem wk a sz v t k -> Write_mem wk a sz v t (fun v -> bind (k v) f) - | Read_reg descr k -> Read_reg descr (fun v -> bind (k v) f) - | Excl_res k -> Excl_res (fun v -> bind (k v) f) - | Undefined k -> Undefined (fun v -> bind (k v) f) - | Write_ea wk a sz k -> Write_ea wk a sz (bind k f) - | Footprint k -> Footprint (bind k f) - | Barrier bk k -> Barrier bk (bind k f) - | Write_reg r v k -> Write_reg r v (bind k f) - | Print msg k -> Print msg (bind k f) - | Fail descr -> Fail descr - | Exception e -> Exception e + | Read_mem rk a sz k -> Read_mem rk a sz (fun v -> bind (k v) f) + | Read_tagged_mem rk a sz k -> Read_tagged_mem rk a sz (fun v -> bind (k v) f) + | Write_mem wk a sz v t k -> Write_mem wk a sz v t (fun v -> bind (k v) f) + | Read_reg descr k -> Read_reg descr (fun v -> bind (k v) f) + | Excl_res k -> Excl_res (fun v -> bind (k v) f) + | Undefined k -> Undefined (fun v -> bind (k v) f) + | Write_ea wk a sz k -> Write_ea wk a sz (bind k f) + | Footprint k -> Footprint (bind k f) + | Barrier bk k -> Barrier bk (bind k f) + | Write_reg r v k -> Write_reg r v (bind k f) + | Print msg k -> Print msg (bind k f) + | Fail descr -> Fail descr + | Exception e -> Exception e end val exit : forall 'rv 'a 'e. unit -> monad 'rv 'a 'e @@ -86,18 +89,19 @@ let throw e = Exception e val try_catch : forall 'rv 'a 'e1 'e2. monad 'rv 'a 'e1 -> ('e1 -> monad 'rv 'a 'e2) -> monad 'rv 'a 'e2 let rec try_catch m h = match m with | Done a -> Done a - | Read_mem rk a sz k -> Read_mem rk a sz (fun v -> try_catch (k v) h) - | Write_mem wk a sz v t k -> Write_mem wk a sz v t (fun v -> try_catch (k v) h) - | Read_reg descr k -> Read_reg descr (fun v -> try_catch (k v) h) - | Excl_res k -> Excl_res (fun v -> try_catch (k v) h) - | Undefined k -> Undefined (fun v -> try_catch (k v) h) - | Write_ea wk a sz k -> Write_ea wk a sz (try_catch k h) - | Footprint k -> Footprint (try_catch k h) - | Barrier bk k -> Barrier bk (try_catch k h) - | Write_reg r v k -> Write_reg r v (try_catch k h) - | Print msg k -> Print msg (try_catch k h) - | Fail descr -> Fail descr - | Exception e -> h e + | Read_mem rk a sz k -> Read_mem rk a sz (fun v -> try_catch (k v) h) + | Read_tagged_mem rk a sz k -> Read_tagged_mem rk a sz (fun v -> try_catch (k v) h) + | Write_mem wk a sz v t k -> Write_mem wk a sz v t (fun v -> try_catch (k v) h) + | Read_reg descr k -> Read_reg descr (fun v -> try_catch (k v) h) + | Excl_res k -> Excl_res (fun v -> try_catch (k v) h) + | Undefined k -> Undefined (fun v -> try_catch (k v) h) + | Write_ea wk a sz k -> Write_ea wk a sz (try_catch k h) + | Footprint k -> Footprint (try_catch k h) + | Barrier bk k -> Barrier bk (try_catch k h) + | Write_reg r v k -> Write_reg r v (try_catch k h) + | Print msg k -> Print msg (try_catch k h) + | Fail descr -> Fail descr + | Exception e -> h e end (* For early return, we abuse exceptions by throwing and catching @@ -135,19 +139,35 @@ let maybe_fail msg = function | Nothing -> Fail msg end -val read_mem_bytes : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv (list memory_byte * bitU) 'e +val read_tagged_mem_bytes : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv (list memory_byte * bitU) 'e +let read_tagged_mem_bytes rk addr sz = + bind + (maybe_fail "nat_of_bv" (nat_of_bv addr)) + (fun addr -> Read_tagged_mem rk addr (nat_of_int sz) return) + +val read_tagged_mem : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv ('b * bitU) 'e +let read_tagged_mem rk addr sz = + bind + (read_tagged_mem_bytes rk addr sz) + (fun (bytes, tag) -> + match of_bits (bits_of_mem_bytes bytes) with + | Just v -> return (v, tag) + | Nothing -> Fail "bits_of_mem_bytes" + end) + +val read_mem_bytes : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv (list memory_byte) 'e let read_mem_bytes rk addr sz = bind (maybe_fail "nat_of_bv" (nat_of_bv addr)) (fun addr -> Read_mem rk addr (nat_of_int sz) return) -val read_mem : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv ('b * bitU) 'e +val read_mem : forall 'rv 'a 'b 'e. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monad 'rv 'b 'e let read_mem rk addr sz = bind (read_mem_bytes rk addr sz) - (fun (bytes, tag) -> + (fun bytes -> match of_bits (bits_of_mem_bytes bytes) with - | Just v -> return (v, tag) + | Just v -> return v | Nothing -> Fail "bits_of_mem_bytes" end) diff --git a/src/gen_lib/sail2_state_lifting.lem b/src/gen_lib/sail2_state_lifting.lem index 0e7addbb..07a6215b 100644 --- a/src/gen_lib/sail2_state_lifting.lem +++ b/src/gen_lib/sail2_state_lifting.lem @@ -8,28 +8,32 @@ open import {isabelle} `Sail2_state_monad_lemmas` (* Lifting from prompt monad to state monad *) val liftState : forall 'regval 'regs 'a 'e. register_accessors 'regs 'regval -> monad 'regval 'a 'e -> monadS 'regs 'a 'e let rec liftState ra m = match m with - | (Done a) -> returnS a - | (Read_mem rk a sz k) -> bindS (read_mem_bytesS rk a sz) (fun v -> liftState ra (k v)) - | (Write_mem wk a sz v t k) -> bindS (write_mem_bytesS wk a sz v t) (fun v -> liftState ra (k v)) - | (Read_reg r k) -> bindS (read_regvalS ra r) (fun v -> liftState ra (k v)) - | (Excl_res k) -> bindS (excl_resultS ()) (fun v -> liftState ra (k v)) - | (Undefined k) -> bindS (undefined_boolS ()) (fun v -> liftState ra (k v)) - | (Write_reg r v k) -> seqS (write_regvalS ra r v) (liftState ra k) - | (Write_ea _ _ _ k) -> liftState ra k - | (Footprint k) -> liftState ra k - | (Barrier _ k) -> liftState ra k - | (Print _ k) -> liftState ra k (* TODO *) - | (Fail descr) -> failS descr - | (Exception e) -> throwS e + | (Done a) -> returnS a + | (Read_mem rk a sz k) -> bindS (read_mem_bytesS rk a sz) (fun v -> liftState ra (k v)) + | (Read_tagged_mem rk a sz k) -> bindS (read_tagged_mem_bytesS rk a sz) (fun v -> liftState ra (k v)) + | (Write_mem wk a sz v t k) -> bindS (write_mem_bytesS wk a sz v t) (fun v -> liftState ra (k v)) + | (Read_reg r k) -> bindS (read_regvalS ra r) (fun v -> liftState ra (k v)) + | (Excl_res k) -> bindS (excl_resultS ()) (fun v -> liftState ra (k v)) + | (Undefined k) -> bindS (undefined_boolS ()) (fun v -> liftState ra (k v)) + | (Write_reg r v k) -> seqS (write_regvalS ra r v) (liftState ra k) + | (Write_ea _ _ _ k) -> liftState ra k + | (Footprint k) -> liftState ra k + | (Barrier _ k) -> liftState ra k + | (Print _ k) -> liftState ra k (* TODO *) + | (Fail descr) -> failS descr + | (Exception e) -> throwS e end val emitEventS : forall 'regval 'regs 'a 'e. Eq 'regval => register_accessors 'regs 'regval -> event 'regval -> sequential_state 'regs -> maybe (sequential_state 'regs) let emitEventS ra e s = match e with - | E_read_mem _ addr sz (v, tag) -> + | E_read_mem _ addr sz v -> + Maybe.bind (get_mem_bytes addr sz s) (fun (v', _) -> + if v' = v then Just s else Nothing) + | E_read_tagged_mem _ addr sz (v, tag) -> Maybe.bind (get_mem_bytes addr sz s) (fun (v', tag') -> if v' = v && tag' = tag then Just s else Nothing) - | E_write_mem _ addr sz v tag _ -> - Just (put_mem_bytes addr sz v tag s) + | E_write_mem _ addr sz v tag success -> + if success then Just (put_mem_bytes addr sz v tag s) else Nothing | E_read_reg r v -> let (read_reg, _) = ra in Maybe.bind (read_reg r s.regstate) (fun v' -> diff --git a/src/gen_lib/sail2_state_monad.lem b/src/gen_lib/sail2_state_monad.lem index 6c1cd4bd..8626052f 100644 --- a/src/gen_lib/sail2_state_monad.lem +++ b/src/gen_lib/sail2_state_monad.lem @@ -129,18 +129,28 @@ let get_mem_bytes addr sz s = (fun mem_val -> (mem_val, List.foldl and_bit B1 (List.map (read_tag s) addrs))) (just_list (List.map (read_byte s) addrs)) -val read_mem_bytesS : forall 'regs 'e. read_kind -> nat -> nat -> monadS 'regs (list memory_byte * bitU) 'e -let read_mem_bytesS _ addr sz = +val read_tagged_mem_bytesS : forall 'regs 'e. read_kind -> nat -> nat -> monadS 'regs (list memory_byte * bitU) 'e +let read_tagged_mem_bytesS _ addr sz = readS (get_mem_bytes addr sz) >>$= maybe_failS "read_memS" -val read_memS : forall 'regs 'e 'a 'b. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monadS 'regs ('b * bitU) 'e -let read_memS rk a sz = +val read_mem_bytesS : forall 'regs 'e. read_kind -> nat -> nat -> monadS 'regs (list memory_byte) 'e +let read_mem_bytesS rk addr sz = + read_tagged_mem_bytesS rk addr sz >>$= (fun (bytes, _) -> + returnS bytes) + +val read_tagged_memS : forall 'regs 'e 'a 'b. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monadS 'regs ('b * bitU) 'e +let read_tagged_memS rk a sz = maybe_failS "nat_of_bv" (nat_of_bv a) >>$= (fun a -> - read_mem_bytesS rk a (nat_of_int sz) >>$= (fun (bytes, tag) -> + read_tagged_mem_bytesS rk a (nat_of_int sz) >>$= (fun (bytes, tag) -> maybe_failS "bits_of_mem_bytes" (of_bits (bits_of_mem_bytes bytes)) >>$= (fun mem_val -> returnS (mem_val, tag)))) +val read_memS : forall 'regs 'e 'a 'b. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monadS 'regs 'b 'e +let read_memS rk a sz = + read_tagged_memS rk a sz >>$= (fun (bytes, _) -> + returnS bytes) + val excl_resultS : forall 'regs 'e. unit -> monadS 'regs bool 'e let excl_resultS = (* TODO: This used to be more deterministic, checking a flag in the state |
