diff options
Diffstat (limited to 'src/gen_lib/state_monad.lem')
| -rw-r--r-- | src/gen_lib/state_monad.lem | 32 |
1 files changed, 16 insertions, 16 deletions
diff --git a/src/gen_lib/state_monad.lem b/src/gen_lib/state_monad.lem index 26179244..26f912fd 100644 --- a/src/gen_lib/state_monad.lem +++ b/src/gen_lib/state_monad.lem @@ -38,17 +38,17 @@ type result 'a 'e = (* State, nondeterminism and exception monad with result value type 'a and exception type 'e. *) -type monadS 'regs 'a 'e = sequential_state 'regs -> list (result 'a 'e * sequential_state 'regs) +type monadS 'regs 'a 'e = sequential_state 'regs -> set (result 'a 'e * sequential_state 'regs) val returnS : forall 'regs 'a 'e. 'a -> monadS 'regs 'a 'e -let returnS a s = [(Value a,s)] +let returnS a s = {(Value a,s)} val bindS : forall 'regs 'a 'b 'e. monadS 'regs 'a 'e -> ('a -> monadS 'regs 'b 'e) -> monadS 'regs 'b 'e let bindS m f (s : sequential_state 'regs) = - List.concatMap (function - | (Value a, s') -> f a s' - | (Ex e, s') -> [(Ex e, s')] - end) (m s) + Set.bigunion (Set.map (function + | (Value a, s') -> f a s' + | (Ex e, s') -> {(Ex e, s')} + end) (m s)) val seqS: forall 'regs 'b 'e. monadS 'regs unit 'e -> monadS 'regs 'b 'e -> monadS 'regs 'b 'e let seqS m n = bindS m (fun (_ : unit) -> n) @@ -56,8 +56,8 @@ let seqS m n = bindS m (fun (_ : unit) -> n) let inline (>>$=) = bindS let inline (>>$) = seqS -val chooseS : forall 'regs 'a 'e. list 'a -> monadS 'regs 'a 'e -let chooseS xs s = List.map (fun x -> (Value x, s)) xs +val chooseS : forall 'regs 'a 'e. SetType 'a => set 'a -> monadS 'regs 'a 'e +let chooseS xs s = Set.map (fun x -> (Value x, s)) xs val readS : forall 'regs 'a 'e. (sequential_state 'regs -> 'a) -> monadS 'regs 'a 'e let readS f = (fun s -> returnS (f s) s) @@ -66,7 +66,7 @@ val updateS : forall 'regs 'e. (sequential_state 'regs -> sequential_state 'regs let updateS f = (fun s -> returnS () (f s)) val failS : forall 'regs 'a 'e. string -> monadS 'regs 'a 'e -let failS msg s = [(Ex (Failure msg), s)] +let failS msg s = {(Ex (Failure msg), s)} val undefined_boolS : forall 'regval 'regs 'a 'e. unit -> monadS 'regs bool 'e let undefined_boolS () = @@ -78,15 +78,15 @@ val exitS : forall 'regs 'e 'a. unit -> monadS 'regs 'a 'e let exitS () = failS "exit" val throwS : forall 'regs 'a 'e. 'e -> monadS 'regs 'a 'e -let throwS e s = [(Ex (Throw e), s)] +let throwS e s = {(Ex (Throw e), s)} val try_catchS : forall 'regs 'a 'e1 'e2. monadS 'regs 'a 'e1 -> ('e1 -> monadS 'regs 'a 'e2) -> monadS 'regs 'a 'e2 let try_catchS m h s = - List.concatMap (function - | (Value a, s') -> returnS a s' - | (Ex (Throw e), s') -> h e s' - | (Ex (Failure msg), s') -> [(Ex (Failure msg), s')] - end) (m s) + Set.bigunion (Set.map (function + | (Value a, s') -> returnS a s' + | (Ex (Throw e), s') -> h e s' + | (Ex (Failure msg), s') -> {(Ex (Failure msg), s')} + end) (m s)) val assert_expS : forall 'regs 'e. bool -> string -> monadS 'regs unit 'e let assert_expS exp msg = if exp then returnS () else failS msg @@ -150,7 +150,7 @@ val excl_resultS : forall 'regs 'e. unit -> monadS 'regs bool 'e let excl_resultS () = readS (fun s -> s.last_exclusive_operation_was_load) >>$= (fun excl_load -> updateS (fun s -> <| s with last_exclusive_operation_was_load = false |>) >>$ - chooseS (if excl_load then [false; true] else [false])) + chooseS (if excl_load then {false; true} else {false})) val write_mem_eaS : forall 'regs 'e 'a. Bitvector 'a => write_kind -> 'a -> nat -> monadS 'regs unit 'e let write_mem_eaS write_kind addr sz = |
