diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 8 | ||||
| -rw-r--r-- | src/ast_util.mli | 1 | ||||
| -rw-r--r-- | src/gen_lib/prompt.lem | 149 | ||||
| -rw-r--r-- | src/gen_lib/state.lem | 174 | ||||
| -rw-r--r-- | src/lem_interp/sail_impl_base.lem | 27 | ||||
| -rw-r--r-- | src/monomorphise.ml | 2 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 52 | ||||
| -rw-r--r-- | src/rewriter.ml | 5 | ||||
| -rw-r--r-- | src/rewrites.ml | 8 |
9 files changed, 249 insertions, 177 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 2fd43db5..4ceb3e7f 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -679,6 +679,14 @@ let id_of_fundef (FD_aux (FD_function (_, _, _, funcls), (l, _))) = | Some id -> id | None -> raise (Reporting_basic.err_typ l "funcl list is empty") +let id_of_type_def_aux = function + | TD_abbrev (id, _, _) + | TD_record (id, _, _, _, _) + | TD_variant (id, _, _, _, _) + | TD_enum (id, _, _, _) + | TD_register (id, _, _, _) -> id +let id_of_type_def (TD_aux (td_aux, _)) = id_of_type_def_aux td_aux + module BE = struct type t = base_effect let compare be1 be2 = String.compare (string_of_base_effect be1) (string_of_base_effect be2) diff --git a/src/ast_util.mli b/src/ast_util.mli index a45ca4e9..68955387 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -192,6 +192,7 @@ val string_of_letbind : 'a letbind -> string val string_of_index_range : index_range -> string val id_of_fundef : 'a fundef -> id +val id_of_type_def : 'a type_def -> id val id_of_kid : kid -> id val kid_of_id : id -> kid diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem index 4646ef6f..77a39096 100644 --- a/src/gen_lib/prompt.lem +++ b/src/gen_lib/prompt.lem @@ -2,13 +2,12 @@ open import Pervasives_extra open import Sail_impl_base open import Sail_values -type MR 'a 'r = outcome_r 'a 'r -type M 'a = outcome 'a +type M 'a 'e = outcome 'a 'e -val return : forall 'a 'r. 'a -> MR 'a 'r +val return : forall 'a 'e. 'a -> M 'a 'e let return a = Done a -val bind : forall 'a 'b 'r. MR 'a 'r -> ('a -> MR 'b 'r) -> MR 'b 'r +val bind : forall 'a 'b 'e. M 'a 'e -> ('a -> M 'b 'e) -> M 'b 'e let rec bind m f = match m with | Done a -> f a | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (bind o f,opt)) @@ -22,61 +21,73 @@ let rec bind m f = match m with | Escape descr -> Escape descr | Fail descr -> Fail descr | Error descr -> Error descr - | Return a -> Return a + | Exception e -> Exception e | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (bind o f ,opt)) end let inline (>>=) = bind -val (>>) : forall 'b 'r. MR unit 'r -> MR 'b 'r -> MR 'b 'r -let inline (>>) m n = m >>= fun _ -> n +val (>>) : forall 'b 'e. M unit 'e -> M 'b 'e -> M 'b 'e +let inline (>>) m n = m >>= fun (_ : unit) -> n -val exit : forall 'a 'b. 'b -> M 'a -let exit s = Fail Nothing +val exit : forall 'a 'e. unit -> M 'a 'e +let exit () = Fail Nothing -val assert_exp : bool -> string -> M unit +val assert_exp : forall 'e. bool -> string -> M unit 'e let assert_exp exp msg = if exp then Done () else Fail (Just msg) -val early_return : forall 'r. 'r -> MR unit 'r -let early_return r = Return r +val throw : forall 'a 'e. 'e -> M 'a 'e +let throw e = Exception e -val liftR : forall 'a 'r. M 'a -> MR 'a 'r -let rec liftR m = match m with +val try_catch : forall 'a 'e1 'e2. M 'a 'e1 -> ('e1 -> M 'a 'e2) -> M 'a 'e2 +let rec try_catch m h = match m with | Done a -> Done a - | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (liftR o,opt)) - | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (liftR o,opt)) - | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (liftR o,opt)) - | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (liftR o,opt)) - | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (liftR o,opt)) - | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (liftR o,opt)) - | Footprint o_s -> Footprint (let (o,opt) = o_s in (liftR o,opt)) - | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (liftR o,opt)) - | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (liftR o,opt)) + | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (try_catch o h,opt)) + | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (try_catch o h,opt)) + | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (try_catch o h,opt)) + | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (try_catch o h,opt)) + | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (try_catch o h,opt)) + | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (try_catch o h,opt)) + | Footprint o_s -> Footprint (let (o,opt) = o_s in (try_catch o h,opt)) + | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (try_catch o h,opt)) | Escape descr -> Escape descr | Fail descr -> Fail descr | Error descr -> Error descr - | Return _ -> Error "uncaught early return" + | Exception e -> h e + | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (try_catch o h ,opt)) end -val catch_early_return : forall 'a 'r. MR 'a 'a -> M 'a -let rec catch_early_return m = match m with - | Done a -> Done a - | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (catch_early_return o,opt)) - | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (catch_early_return o,opt)) - | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (catch_early_return o,opt)) - | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (catch_early_return o,opt)) - | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (catch_early_return o,opt)) - | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (catch_early_return o,opt)) - | Footprint o_s -> Footprint (let (o,opt) = o_s in (catch_early_return o,opt)) - | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (catch_early_return o,opt)) - | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (catch_early_return o,opt)) - | Escape descr -> Escape descr - | Fail descr -> Fail descr - | Error descr -> Error descr - | Return a -> Done a -end - -val read_mem : forall 'a 'b. Bitvector 'a, Bitvector 'b => bool -> read_kind -> 'a -> integer -> M 'b +(* For early return, we abuse exceptions by throwing and catching + the return value. The exception type is "either 'r 'e", where "Right e" + represents a proper exception and "Left r" an early return of value "r". *) +type MR 'a 'r 'e = M 'a (either 'r 'e) + +val early_return : forall 'a 'r 'e. 'r -> MR 'a 'r 'e +let early_return r = throw (Left r) + +val catch_early_return : forall 'a 'e. MR 'a 'a 'e -> M 'a 'e +let catch_early_return m = + try_catch m + (function + | Left a -> return a + | Right e -> throw e + end) + +(* Lift to monad with early return by wrapping exceptions *) +val liftR : forall 'a 'r 'e. M 'a 'e -> MR 'a 'r 'e +let liftR m = try_catch m (fun e -> throw (Right e)) + +(* Catch exceptions in the presence of early returns *) +val try_catchR : forall 'a 'r 'e1 'e2. MR 'a 'r 'e1 -> ('e1 -> MR 'a 'r 'e2) -> MR 'a 'r 'e2 +let try_catchR m h = + try_catch m + (function + | Left r -> throw (Left r) + | Right e -> h e + end) + + +val read_mem : forall 'a 'b 'e. Bitvector 'a, Bitvector 'b => bool -> read_kind -> 'a -> integer -> M 'b 'e let read_mem dir rk addr sz = let addr = address_lifted_of_bitv (bits_of addr) in let sz = natFromInteger sz in @@ -85,24 +96,24 @@ let read_mem dir rk addr sz = (Done bitv,Nothing) in Read_mem (rk,addr,sz) k -val excl_result : unit -> M bool +val excl_result : forall 'e. unit -> M bool 'e let excl_result () = let k successful = (return successful,Nothing) in Excl_res k -val write_mem_ea : forall 'a. Bitvector 'a => write_kind -> 'a -> integer -> M unit +val write_mem_ea : forall 'a 'e. Bitvector 'a => write_kind -> 'a -> integer -> M unit 'e let write_mem_ea wk addr sz = let addr = address_lifted_of_bitv (bits_of addr) in let sz = natFromInteger sz in Write_ea (wk,addr,sz) (Done (),Nothing) -val write_mem_val : forall 'a. Bitvector 'a => 'a -> M bool +val write_mem_val : forall 'a 'e. Bitvector 'a => 'a -> M bool 'e let write_mem_val v = let v = external_mem_value (bits_of v) in let k successful = (return successful,Nothing) in Write_memv v k -val read_reg_aux : forall 'a. Bitvector 'a => reg_name -> M 'a +val read_reg_aux : forall 'a 'e. Bitvector 'a => reg_name -> M 'a 'e let read_reg_aux reg = let k reg_value = let v = of_bits (internal_reg_value reg_value) in @@ -124,7 +135,7 @@ let read_reg_bitfield reg regfield = let reg_deref = read_reg -val write_reg_aux : forall 'a. Bitvector 'a => reg_name -> 'a -> M unit +val write_reg_aux : forall 'a 'e. Bitvector 'a => reg_name -> 'a -> M unit 'e let write_reg_aux reg_name v = let regval = external_reg_value reg_name (bits_of v) in Write_reg (reg_name,regval) (Done (), Nothing) @@ -150,29 +161,29 @@ let write_reg_field_bit = write_reg_field_pos let write_reg_ref (reg, v) = write_reg reg v -val barrier : barrier_kind -> M unit +val barrier : forall 'e. barrier_kind -> M unit 'e let barrier bk = Barrier bk (Done (), Nothing) -val footprint : M unit +val footprint : forall 'e. M unit 'e let footprint = Footprint (Done (),Nothing) -val iter_aux : forall 'regs 'e 'a. integer -> (integer -> 'a -> MR unit 'e) -> list 'a -> MR unit 'e +val iter_aux : forall 'a 'e. integer -> (integer -> 'a -> M unit 'e) -> list 'a -> M unit 'e let rec iter_aux i f xs = match xs with | x :: xs -> f i x >> iter_aux (i + 1) f xs | [] -> return () end -val iteri : forall 'regs 'e 'a. (integer -> 'a -> MR unit 'e) -> list 'a -> MR unit 'e +val iteri : forall 'a 'e. (integer -> 'a -> M unit 'e) -> list 'a -> M unit 'e let iteri f xs = iter_aux 0 f xs -val iter : forall 'regs 'e 'a. ('a -> MR unit 'e) -> list 'a -> MR unit 'e +val iter : forall 'a 'e. ('a -> M unit 'e) -> list 'a -> M unit 'e let iter f xs = iteri (fun _ x -> f x) xs -val foreachM_inc : forall 'vars 'r. (integer * integer * integer) -> 'vars -> - (integer -> 'vars -> MR 'vars 'r) -> MR 'vars 'r +val foreachM_inc : forall 'vars 'e. (integer * integer * integer) -> 'vars -> + (integer -> 'vars -> M 'vars 'e) -> M 'vars 'e let rec foreachM_inc (i,stop,by) vars body = if (by > 0 && i <= stop) || (by < 0 && stop <= i) then @@ -181,8 +192,8 @@ let rec foreachM_inc (i,stop,by) vars body = else return vars -val foreachM_dec : forall 'vars 'r. (integer * integer * integer) -> 'vars -> - (integer -> 'vars -> MR 'vars 'r) -> MR 'vars 'r +val foreachM_dec : forall 'vars 'e. (integer * integer * integer) -> 'vars -> + (integer -> 'vars -> M 'vars 'e) -> M 'vars 'e let rec foreachM_dec (i,stop,by) vars body = if (by > 0 && i >= stop) || (by < 0 && stop >= i) then @@ -194,21 +205,21 @@ val while_PP : forall 'vars. 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'va let rec while_PP vars cond body = if cond vars then while_PP (body vars) cond body else vars -val while_PM : forall 'vars 'r. 'vars -> ('vars -> bool) -> - ('vars -> MR 'vars 'r) -> MR 'vars 'r +val while_PM : forall 'vars 'e. 'vars -> ('vars -> bool) -> + ('vars -> M 'vars 'e) -> M 'vars 'e let rec while_PM vars cond body = if cond vars then body vars >>= fun vars -> while_PM vars cond body else return vars -val while_MP : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) -> - ('vars -> 'vars) -> MR 'vars 'r +val while_MP : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) -> + ('vars -> 'vars) -> M 'vars 'e let rec while_MP vars cond body = cond vars >>= fun cond_val -> if cond_val then while_MP (body vars) cond body else return vars -val while_MM : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) -> - ('vars -> MR 'vars 'r) -> MR 'vars 'r +val while_MM : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) -> + ('vars -> M 'vars 'e) -> M 'vars 'e let rec while_MM vars cond body = cond vars >>= fun cond_val -> if cond_val then @@ -220,21 +231,21 @@ let rec until_PP vars cond body = let vars = body vars in if (cond vars) then vars else until_PP (body vars) cond body -val until_PM : forall 'vars 'r. 'vars -> ('vars -> bool) -> - ('vars -> MR 'vars 'r) -> MR 'vars 'r +val until_PM : forall 'vars 'e. 'vars -> ('vars -> bool) -> + ('vars -> M 'vars 'e) -> M 'vars 'e let rec until_PM vars cond body = body vars >>= fun vars -> if (cond vars) then return vars else until_PM vars cond body -val until_MP : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) -> - ('vars -> 'vars) -> MR 'vars 'r +val until_MP : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) -> + ('vars -> 'vars) -> M 'vars 'e let rec until_MP vars cond body = let vars = body vars in cond vars >>= fun cond_val -> if cond_val then return vars else until_MP vars cond body -val until_MM : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) -> - ('vars -> MR 'vars 'r) -> MR 'vars 'r +val until_MM : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) -> + ('vars -> M 'vars 'e) -> M 'vars 'e let rec until_MM vars cond body = body vars >>= fun vars -> cond vars >>= fun cond_val -> diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem index fa0fcd24..f9011323 100644 --- a/src/gen_lib/state.lem +++ b/src/gen_lib/state.lem @@ -23,55 +23,79 @@ let init_state regs = write_ea = Nothing; last_exclusive_operation_was_load = false |> -(* State, nondeterminism and exception monad with result type 'a - and exception type 'e. *) -type ME 'regs 'a 'e = sequential_state 'regs -> list ((either 'a 'e) * sequential_state 'regs) - -(* By default, we use strings to distinguish between different types of exceptions *) -type M 'regs 'a = ME 'regs 'a string +type ex 'e = + | Exit + | Assert of string + | Throw of 'e -(* For early return, we abuse exceptions by throwing and catching - the return value. The exception type is "either 'r string", where "Right e" - represents a proper exception and "Left r" an early return of value "r". *) -type MR 'regs 'a 'r = ME 'regs 'a (either 'r string) +type result 'a 'e = + | Value of 'a + | Exception of (ex 'e) -val liftR : forall 'a 'r 'regs. M 'regs 'a -> MR 'regs 'a 'r -let liftR m s = List.map (function - | (Left a, s') -> (Left a, s') - | (Right e, s') -> (Right (Right e), s') - end) (m s) +(* State, nondeterminism and exception monad with result value type 'a + and exception type 'e. *) +type M 'regs 'a 'e = sequential_state 'regs -> list (result 'a 'e * sequential_state 'regs) -val return : forall 'regs 'a 'e. 'a -> ME 'regs 'a 'e -let return a s = [(Left a,s)] +val return : forall 'regs 'a 'e. 'a -> M 'regs 'a 'e +let return a s = [(Value a,s)] -val bind : forall 'regs 'a 'b 'e. ME 'regs 'a 'e -> ('a -> ME 'regs 'b 'e) -> ME 'regs 'b 'e +val bind : forall 'regs 'a 'b 'e. M 'regs 'a 'e -> ('a -> M 'regs 'b 'e) -> M 'regs 'b 'e let bind m f (s : sequential_state 'regs) = List.concatMap (function - | (Left a, s') -> f a s' - | (Right e, s') -> [(Right e, s')] + | (Value a, s') -> f a s' + | (Exception e, s') -> [(Exception e, s')] end) (m s) let inline (>>=) = bind -val (>>): forall 'regs 'b 'e. ME 'regs unit 'e -> ME 'regs 'b 'e -> ME 'regs 'b 'e -let inline (>>) m n = m >>= fun _ -> n +val (>>): forall 'regs 'b 'e. M 'regs unit 'e -> M 'regs 'b 'e -> M 'regs 'b 'e +let inline (>>) m n = m >>= fun (_ : unit) -> n -val exit : forall 'regs 'e 'a. 'e -> M 'regs 'a -let exit _ s = [(Right "exit", s)] +val throw : forall 'regs 'a 'e. 'e -> M 'regs 'a 'e +let throw e s = [(Exception (Throw e), s)] -val assert_exp : forall 'regs. bool -> string -> M 'regs unit -let assert_exp exp msg s = if exp then [(Left (), s)] else [(Right msg, s)] +val try_catch : forall 'regs 'a 'e1 'e2. M 'regs 'a 'e1 -> ('e1 -> M 'regs 'a 'e2) -> M 'regs 'a 'e2 +let try_catch m h s = + List.concatMap (function + | (Value a, s') -> return a s' + | (Exception (Throw e), s') -> h e s' + | (Exception Exit, s') -> [(Exception Exit, s')] + | (Exception (Assert msg), s') -> [(Exception (Assert msg), s')] + end) (m s) -val early_return : forall 'regs 'a 'r. 'r -> MR 'regs 'a 'r -let early_return r s = [(Right (Left r), s)] +val exit : forall 'regs 'e 'a. unit -> M 'regs 'a 'e +let exit () s = [(Exception Exit, s)] -val catch_early_return : forall 'regs 'a. MR 'regs 'a 'a -> M 'regs 'a -let catch_early_return m s = - List.map +val assert_exp : forall 'regs 'e. bool -> string -> M 'regs unit 'e +let assert_exp exp msg s = if exp then [(Value (), s)] else [(Exception (Assert msg), s)] + +(* For early return, we abuse exceptions by throwing and catching + the return value. The exception type is "either 'r 'e", where "Right e" + represents a proper exception and "Left r" an early return of value "r". *) +type MR 'regs 'a 'r 'e = M 'regs 'a (either 'r 'e) + +val early_return : forall 'regs 'a 'r 'e. 'r -> MR 'regs 'a 'r 'e +let early_return r = throw (Left r) + +val catch_early_return : forall 'regs 'a 'e. MR 'regs 'a 'a 'e -> M 'regs 'a 'e +let catch_early_return m = + try_catch m + (function + | Left a -> return a + | Right e -> throw e + end) + +(* Lift to monad with early return by wrapping exceptions *) +val liftR : forall 'a 'r 'regs 'e. M 'regs 'a 'e -> MR 'regs 'a 'r 'e +let liftR m = try_catch m (fun e -> throw (Right e)) + +(* Catch exceptions in the presence of early returns *) +val try_catchR : forall 'regs 'a 'r 'e1 'e2. MR 'regs 'a 'r 'e1 -> ('e1 -> MR 'regs 'a 'r 'e2) -> MR 'regs 'a 'r 'e2 +let try_catchR m h = + try_catch m (function - | (Right (Left a), s') -> (Left a, s') - | (Right (Right e), s') -> (Right e, s') - | (Left a, s') -> (Left a, s') - end) (m s) + | Left r -> throw (Left r) + | Right e -> h e + end) val range : integer -> integer -> list integer let rec range i j = @@ -103,20 +127,20 @@ let is_exclusive = function end -val read_mem : forall 'regs 'a 'b. Bitvector 'a, Bitvector 'b => bool -> read_kind -> 'a -> integer -> M 'regs 'b +val read_mem : forall 'regs 'a 'b 'e. Bitvector 'a, Bitvector 'b => bool -> read_kind -> 'a -> integer -> M 'regs 'b 'e let read_mem dir read_kind addr sz state = let addr = unsigned addr in let addrs = range addr (addr+sz-1) in let memory_value = List.map (fun addr -> Map_extra.find addr state.memstate) addrs in let value = of_bits (Sail_values.internal_mem_value dir memory_value) in if is_exclusive read_kind - then [(Left value, <| state with last_exclusive_operation_was_load = true |>)] - else [(Left value, state)] + then [(Value value, <| state with last_exclusive_operation_was_load = true |>)] + else [(Value value, state)] (* caps are aligned at 32 bytes *) let cap_alignment = (32 : integer) -val read_tag : forall 'regs 'a. Bitvector 'a => bool -> read_kind -> 'a -> M 'regs bitU +val read_tag : forall 'regs 'a 'e. Bitvector 'a => bool -> read_kind -> 'a -> M 'regs bitU 'e let read_tag dir read_kind addr state = let addr = (unsigned addr) / cap_alignment in let tag = match (Map.lookup addr state.tagstate) with @@ -124,20 +148,20 @@ let read_tag dir read_kind addr state = | Nothing -> B0 end in if is_exclusive read_kind - then [(Left tag, <| state with last_exclusive_operation_was_load = true |>)] - else [(Left tag, state)] + then [(Value tag, <| state with last_exclusive_operation_was_load = true |>)] + else [(Value tag, state)] -val excl_result : forall 'regs. unit -> M 'regs bool +val excl_result : forall 'regs 'e. unit -> M 'regs bool 'e let excl_result () state = let success = - (Left true, <| state with last_exclusive_operation_was_load = false |>) in - (Left false, state) :: if state.last_exclusive_operation_was_load then [success] else [] + (Value true, <| state with last_exclusive_operation_was_load = false |>) in + (Value false, state) :: if state.last_exclusive_operation_was_load then [success] else [] -val write_mem_ea : forall 'regs 'a. Bitvector 'a => write_kind -> 'a -> integer -> M 'regs unit +val write_mem_ea : forall 'regs 'a 'e. Bitvector 'a => write_kind -> 'a -> integer -> M 'regs unit 'e let write_mem_ea write_kind addr sz state = - [(Left (), <| state with write_ea = Just (write_kind,unsigned addr,sz) |>)] + [(Value (), <| state with write_ea = Just (write_kind,unsigned addr,sz) |>)] -val write_mem_val : forall 'a 'regs 'b. Bitvector 'a => 'a -> M 'regs bool +val write_mem_val : forall 'a 'regs 'b 'e. Bitvector 'a => 'a -> M 'regs bool 'e let write_mem_val v state = let (write_kind,addr,sz) = match state.write_ea with | Nothing -> failwith "write ea has not been announced yet" @@ -147,27 +171,27 @@ let write_mem_val v state = let addresses_with_value = List.zip addrs v in let memstate = List.foldl (fun mem (addr,v) -> Map.insert addr v mem) state.memstate addresses_with_value in - [(Left true, <| state with memstate = memstate |>)] + [(Value true, <| state with memstate = memstate |>)] -val write_tag : forall 'regs. bitU -> M 'regs bool +val write_tag : forall 'regs 'e. bitU -> M 'regs bool 'e let write_tag t state = let (write_kind,addr,sz) = match state.write_ea with | Nothing -> failwith "write ea has not been announced yet" | Just write_ea -> write_ea end in let taddr = addr / cap_alignment in let tagstate = Map.insert taddr t state.tagstate in - [(Left true, <| state with tagstate = tagstate |>)] + [(Value true, <| state with tagstate = tagstate |>)] -val read_reg : forall 'regs 'a. register_ref 'regs 'a -> M 'regs 'a +val read_reg : forall 'regs 'a 'e. register_ref 'regs 'a -> M 'regs 'a 'e let read_reg reg state = let v = reg.read_from state.regstate in - [(Left v,state)] + [(Value v,state)] (*let read_reg_range reg i j state = let v = slice (get_reg state (name_of_reg reg)) i j in - [(Left (vec_to_bvec v),state)] + [(Value (vec_to_bvec v),state)] let read_reg_bit reg i state = let v = access (get_reg state (name_of_reg reg)) i in - [(Left v,state)] + [(Value v,state)] let read_reg_field reg regfield = let (i,j) = register_field_indices reg regfield in read_reg_range reg i j @@ -177,17 +201,17 @@ let read_reg_bitfield reg regfield = let reg_deref = read_reg -val write_reg : forall 'regs 'a. register_ref 'regs 'a -> 'a -> M 'regs unit +val write_reg : forall 'regs 'a 'e. register_ref 'regs 'a -> 'a -> M 'regs unit 'e let write_reg reg v state = - [(Left (), <| state with regstate = reg.write_to state.regstate v |>)] + [(Value (), <| state with regstate = reg.write_to state.regstate v |>)] let write_reg_ref (reg, v) = write_reg reg v -val update_reg : forall 'regs 'a 'b. register_ref 'regs 'a -> ('a -> 'b -> 'a) -> 'b -> M 'regs unit +val update_reg : forall 'regs 'a 'b 'e. register_ref 'regs 'a -> ('a -> 'b -> 'a) -> 'b -> M 'regs unit 'e let update_reg reg f v state = let current_value = get_reg state reg in let new_value = f current_value v in - [(Left (), set_reg state reg new_value)] + [(Value (), set_reg state reg new_value)] let write_reg_field reg regfield = update_reg reg regfield.set_field @@ -219,26 +243,26 @@ let update_reg_field_bit regfield i reg_val bit = regfield.set_field reg_val new_field_value let write_reg_field_bit reg regfield i = update_reg reg (update_reg_field_bit regfield i) -val barrier : forall 'regs. barrier_kind -> M 'regs unit +val barrier : forall 'regs 'e. barrier_kind -> M 'regs unit 'e let barrier _ = return () -val footprint : forall 'regs. M 'regs unit +val footprint : forall 'regs 'e. M 'regs unit 'e let footprint s = return () s -val iter_aux : forall 'regs 'e 'a. integer -> (integer -> 'a -> ME 'regs unit 'e) -> list 'a -> ME 'regs unit 'e +val iter_aux : forall 'regs 'e 'a. integer -> (integer -> 'a -> M 'regs unit 'e) -> list 'a -> M 'regs unit 'e let rec iter_aux i f xs = match xs with | x :: xs -> f i x >> iter_aux (i + 1) f xs | [] -> return () end -val iteri : forall 'regs 'e 'a. (integer -> 'a -> ME 'regs unit 'e) -> list 'a -> ME 'regs unit 'e +val iteri : forall 'regs 'e 'a. (integer -> 'a -> M 'regs unit 'e) -> list 'a -> M 'regs unit 'e let iteri f xs = iter_aux 0 f xs -val iter : forall 'regs 'e 'a. ('a -> ME 'regs unit 'e) -> list 'a -> ME 'regs unit 'e +val iter : forall 'regs 'e 'a. ('a -> M 'regs unit 'e) -> list 'a -> M 'regs unit 'e let iter f xs = iteri (fun _ x -> f x) xs val foreachM_inc : forall 'regs 'vars 'e. (integer * integer * integer) -> 'vars -> - (integer -> 'vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e + (integer -> 'vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e let rec foreachM_inc (i,stop,by) vars body = if (by > 0 && i <= stop) || (by < 0 && stop <= i) then @@ -248,7 +272,7 @@ let rec foreachM_inc (i,stop,by) vars body = val foreachM_dec : forall 'regs 'vars 'e. (integer * integer * integer) -> 'vars -> - (integer -> 'vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e + (integer -> 'vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e let rec foreachM_dec (i,stop,by) vars body = if (by > 0 && i >= stop) || (by < 0 && stop >= i) then @@ -261,20 +285,20 @@ let rec while_PP vars cond body = if cond vars then while_PP (body vars) cond body else vars val while_PM : forall 'regs 'vars 'e. 'vars -> ('vars -> bool) -> - ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e + ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e let rec while_PM vars cond body = if cond vars then body vars >>= fun vars -> while_PM vars cond body else return vars -val while_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) -> - ('vars -> 'vars) -> ME 'regs 'vars 'e +val while_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) -> + ('vars -> 'vars) -> M 'regs 'vars 'e let rec while_MP vars cond body = cond vars >>= fun cond_val -> if cond_val then while_MP (body vars) cond body else return vars -val while_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) -> - ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e +val while_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) -> + ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e let rec while_MM vars cond body = cond vars >>= fun cond_val -> if cond_val then @@ -287,20 +311,20 @@ let rec until_PP vars cond body = if (cond vars) then vars else until_PP (body vars) cond body val until_PM : forall 'regs 'vars 'e. 'vars -> ('vars -> bool) -> - ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e + ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e let rec until_PM vars cond body = body vars >>= fun vars -> if (cond vars) then return vars else until_PM vars cond body -val until_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) -> - ('vars -> 'vars) -> ME 'regs 'vars 'e +val until_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) -> + ('vars -> 'vars) -> M 'regs 'vars 'e let rec until_MP vars cond body = let vars = body vars in cond vars >>= fun cond_val -> if cond_val then return vars else until_MP vars cond body -val until_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) -> - ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e +val until_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) -> + ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e let rec until_MM vars cond body = body vars >>= fun vars -> cond vars >>= fun cond_val -> diff --git a/src/lem_interp/sail_impl_base.lem b/src/lem_interp/sail_impl_base.lem index 421219da..368f7505 100644 --- a/src/lem_interp/sail_impl_base.lem +++ b/src/lem_interp/sail_impl_base.lem @@ -905,36 +905,35 @@ end (* the address_lifted types should go away here and be replaced by address *) type with_aux 'o = 'o * maybe ((unit -> (string * string)) * ((list (reg_name * register_value)) -> list event)) -type outcome_r 'a 'r = +type outcome 'a 'e = (* Request to read memory, value is location to read, integer is size to read, followed by registers that were used in computing that size *) - | Read_mem of (read_kind * address_lifted * nat) * (memory_value -> with_aux (outcome_r 'a 'r)) + | Read_mem of (read_kind * address_lifted * nat) * (memory_value -> with_aux (outcome 'a 'e)) (* Tell the system a write is imminent, at address lifted, of size nat *) - | Write_ea of (write_kind * address_lifted * nat) * (with_aux (outcome_r 'a 'r)) + | Write_ea of (write_kind * address_lifted * nat) * (with_aux (outcome 'a 'e)) (* Request the result of store-exclusive *) - | Excl_res of (bool -> with_aux (outcome_r 'a 'r)) + | Excl_res of (bool -> with_aux (outcome 'a 'e)) (* Request to write memory at last signalled address. Memory value should be 8 times the size given in ea signal *) - | Write_memv of memory_value * (bool -> with_aux (outcome_r 'a 'r)) + | Write_memv of memory_value * (bool -> with_aux (outcome 'a 'e)) (* Request a memory barrier *) - | Barrier of barrier_kind * with_aux (outcome_r 'a 'r) + | Barrier of barrier_kind * with_aux (outcome 'a 'e) (* Tell the system to dynamically recalculate dependency footprint *) - | Footprint of with_aux (outcome_r 'a 'r) + | Footprint of with_aux (outcome 'a 'e) (* Request to read register, will track dependency when mode.track_values *) - | Read_reg of reg_name * (register_value -> with_aux (outcome_r 'a 'r)) + | Read_reg of reg_name * (register_value -> with_aux (outcome 'a 'e)) (* Request to write register *) - | Write_reg of (reg_name * register_value) * with_aux (outcome_r 'a 'r) + | Write_reg of (reg_name * register_value) * with_aux (outcome 'a 'e) | Escape of maybe string (*Result of a failed assert with possible error message to report*) | Fail of maybe string - (* Early return with value of type 'r *) - | Return of 'r - | Internal of (maybe string * maybe (unit -> string)) * with_aux (outcome_r 'a 'r) + (* Exception of type 'e *) + | Exception of 'e + | Internal of (maybe string * maybe (unit -> string)) * with_aux (outcome 'a 'e) | Done of 'a | Error of string -type outcome 'a = outcome_r 'a unit -type outcome_s 'a = with_aux (outcome 'a) +type outcome_s 'a 'e = with_aux (outcome 'a 'e) (* first string : output of instruction_stack_to_string second string: output of local_variables_to_string *) diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 8730b019..0cbeda49 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -102,7 +102,7 @@ let rec subst_nc substs (NC_aux (nc,l) as n_constraint) = begin match KBindings.find kid substs with | Nexp_aux (Nexp_constant i,_) -> - if List.mem i is then re NC_true else re NC_false + if List.exists (fun j -> Big_int.eq_big_int i j) is then re NC_true else re NC_false | nexp -> raise (Reporting_basic.err_general l ("Unable to substitute " ^ string_of_nexp nexp ^ diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 96f312fb..6a3d1293 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -249,12 +249,12 @@ let doc_typ_lem, doc_atomic_typ_lem = | Typ_fn(arg,ret,efct) -> let ret_typ = if effectful efct - then separate space [string "_M"; fn_typ true ret] + then separate space [string "M"; fn_typ true ret] else separate space [fn_typ false ret] in let arg_typs = match arg with | Typ_aux (Typ_tup typs, _) -> - List.map (app_typ true) typs - | _ -> [tup_typ true arg] in + List.map (app_typ false) typs + | _ -> [tup_typ false arg] in let tpp = separate (space ^^ arrow ^^ space) (arg_typs @ [ret_typ]) in (* once we have proper excetions we need to know what the exceptions type is *) if atyp_needed then parens tpp else tpp @@ -355,7 +355,7 @@ let doc_tannot_lem eff typ = if contains_t_pp_var typ then empty else let ta = doc_typ_lem typ in - if eff then string " : _M " ^^ parens ta + if eff then string " : M " ^^ parens ta else string " : " ^^ ta (* doc_lit_lem gets as an additional parameter the type information from the @@ -685,7 +685,7 @@ let doc_exp_lem, doc_let_lem = contains_t_pp_var (typ_of full_exp) then aexp_needed, epp else - let tannot = separate space [string "_MR"; + let tannot = separate space [string "MR"; doc_atomic_typ_lem false (typ_of full_exp); doc_atomic_typ_lem false (typ_of exp)] in true, doc_op colon epp tannot in @@ -855,7 +855,20 @@ let doc_exp_lem, doc_let_lem = (separate_map (break 1) (doc_case early_ret) pexps) ^/^ (string "end")) in if aexp_needed then parens (align epp) else align epp - | E_exit e -> liftR (separate space [string "exit"; expY e;]) + | E_try (e, pexps) -> + if effectful (effect_of e) then + let try_catch = if early_ret then "try_catchR" else "try_catch" in + let epp = + group ((separate space [string try_catch; expY e; string "(function "]) ^/^ + (separate_map (break 1) (doc_case early_ret) pexps) ^/^ + (string "end)")) in + if aexp_needed then parens (align epp) else align epp + else + raise (Reporting_basic.err_todo l "Warning: try-block around pure expression") + | E_throw e -> + let epp = liftR (separate space [string "throw"; expY e]) in + if aexp_needed then parens (align epp) else align epp + | E_exit e -> liftR (separate space [string "exit"; expY e]) | E_assert (e1,e2) -> let epp = liftR (separate space [string "assert_exp"; expY e1; expY e2]) in if aexp_needed then parens (align epp) else align epp @@ -1495,6 +1508,12 @@ let find_registers (Defs defs) = | _ -> acc ) [] defs +let find_exc_typ (Defs defs) = + let is_exc_typ_def = function + | DEF_type td -> string_of_id (id_of_type_def td) = "exception" + | _ -> false in + if List.exists is_exc_typ_def defs then "exception" else "unit" + let doc_regstate_lem registers = let l = Parse_ast.Unknown in let annot = (l, None) in @@ -1528,12 +1547,7 @@ let doc_regstate_lem registers = doc_op equals (string "let initial_regstate") (doc_exp_lem false false exp) else empty in - concat [ - doc_typdef_lem (TD_aux (regstate, annot)); hardline; - hardline; - string "type _MR 'a 'r = MR regstate 'a 'r"; hardline; - string "type _M 'a = M regstate 'a" - ], + doc_typdef_lem (TD_aux (regstate, annot)), initregstate let doc_register_refs_lem registers = @@ -1564,6 +1578,7 @@ let pp_defs_lem (types_file,types_modules) (defs_file,defs_modules) d top_line = let (typdefs,valdefs) = doc_defs_lem d in let regstate_def, initregstate_def = doc_regstate_lem (find_registers d) in let register_refs = doc_register_refs_lem (find_registers d) in + let exc_typ = find_exc_typ d in (print types_file) (concat [string "(*" ^^ (string top_line) ^^ string "*)";hardline; @@ -1585,10 +1600,17 @@ let pp_defs_lem (types_file,types_modules) (defs_file,defs_modules) d top_line = hardline; if !opt_sequential then concat [regstate_def; hardline; - hardline; - register_refs] + hardline; + string ("type MR 'a 'r = State.MR regstate 'a 'r " ^ exc_typ); hardline; + string ("type M 'a = State.M regstate 'a " ^ exc_typ); hardline; + hardline; + register_refs + ] else - concat [string "type _MR 'a 'r = MR 'a 'r"; hardline; string "type _M 'a = M 'a"; hardline] + concat [ + string ("type MR 'a 'r = Prompt.MR 'a 'r " ^ exc_typ); hardline; + string ("type M 'a = Prompt.M 'a " ^ exc_typ); hardline + ] ]); (print defs_file) (concat diff --git a/src/rewriter.ml b/src/rewriter.ml index 88fb17a5..31bcb577 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -130,12 +130,11 @@ let fix_eff_exp (E_aux (e,((l,_) as annot))) = match snd annot with | E_record_update(e,fexps) -> union_effects (effect_of e) (effect_of_fexps fexps) | E_field (e,_) -> effect_of e - | E_case (e,pexps) -> + | E_case (e,pexps) | E_try (e,pexps) -> List.fold_left union_effects (effect_of e) (List.map effect_of_pexp pexps) | E_let (lb,e) -> union_effects (effect_of_lb lb) (effect_of e) | E_assign (lexp,e) -> union_effects (effect_of_lexp lexp) (effect_of e) - | E_exit e -> union_effects eff (effect_of e) - | E_return e -> union_effects eff (effect_of e) + | E_exit e | E_return e | E_throw e -> union_effects eff (effect_of e) | E_sizeof _ | E_sizeof_internal _ | E_constraint _ -> no_effect | E_assert (c,m) -> union_effects eff (union_eff_exps [c; m]) | E_comment _ | E_comment_struc _ -> no_effect diff --git a/src/rewrites.ml b/src/rewrites.ml index a42335b9..32ffe54a 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2378,6 +2378,11 @@ let rewrite_defs_letbind_effects = n_exp_name exp1 (fun exp1 -> n_pexpL newreturn pexps (fun pexps -> k (rewrap (E_case (exp1,pexps))))) + | E_try (exp1,pexps) -> + let newreturn = effectful exp1 || List.exists effectful_pexp pexps in + n_exp_name exp1 (fun exp1 -> + n_pexpL newreturn pexps (fun pexps -> + k (rewrap (E_try (exp1,pexps))))) | E_let (lb,body) -> n_lb lb (fun lb -> rewrap (E_let (lb,n_exp body k))) @@ -2416,6 +2421,9 @@ let rewrite_defs_letbind_effects = | E_return exp' -> n_exp_name exp' (fun exp' -> k (rewrap (E_return exp'))) + | E_throw exp' -> + n_exp_name exp' (fun exp' -> + k (rewrap (E_throw exp'))) | E_internal_plet _ -> failwith "E_internal_plet should not be here yet" in let rewrite_fun _ (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),fdannot)) = |
