summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast_util.ml8
-rw-r--r--src/ast_util.mli1
-rw-r--r--src/gen_lib/prompt.lem149
-rw-r--r--src/gen_lib/state.lem174
-rw-r--r--src/lem_interp/sail_impl_base.lem27
-rw-r--r--src/monomorphise.ml2
-rw-r--r--src/pretty_print_lem.ml52
-rw-r--r--src/rewriter.ml5
-rw-r--r--src/rewrites.ml8
9 files changed, 249 insertions, 177 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 2fd43db5..4ceb3e7f 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -679,6 +679,14 @@ let id_of_fundef (FD_aux (FD_function (_, _, _, funcls), (l, _))) =
| Some id -> id
| None -> raise (Reporting_basic.err_typ l "funcl list is empty")
+let id_of_type_def_aux = function
+ | TD_abbrev (id, _, _)
+ | TD_record (id, _, _, _, _)
+ | TD_variant (id, _, _, _, _)
+ | TD_enum (id, _, _, _)
+ | TD_register (id, _, _, _) -> id
+let id_of_type_def (TD_aux (td_aux, _)) = id_of_type_def_aux td_aux
+
module BE = struct
type t = base_effect
let compare be1 be2 = String.compare (string_of_base_effect be1) (string_of_base_effect be2)
diff --git a/src/ast_util.mli b/src/ast_util.mli
index a45ca4e9..68955387 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -192,6 +192,7 @@ val string_of_letbind : 'a letbind -> string
val string_of_index_range : index_range -> string
val id_of_fundef : 'a fundef -> id
+val id_of_type_def : 'a type_def -> id
val id_of_kid : kid -> id
val kid_of_id : id -> kid
diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem
index 4646ef6f..77a39096 100644
--- a/src/gen_lib/prompt.lem
+++ b/src/gen_lib/prompt.lem
@@ -2,13 +2,12 @@ open import Pervasives_extra
open import Sail_impl_base
open import Sail_values
-type MR 'a 'r = outcome_r 'a 'r
-type M 'a = outcome 'a
+type M 'a 'e = outcome 'a 'e
-val return : forall 'a 'r. 'a -> MR 'a 'r
+val return : forall 'a 'e. 'a -> M 'a 'e
let return a = Done a
-val bind : forall 'a 'b 'r. MR 'a 'r -> ('a -> MR 'b 'r) -> MR 'b 'r
+val bind : forall 'a 'b 'e. M 'a 'e -> ('a -> M 'b 'e) -> M 'b 'e
let rec bind m f = match m with
| Done a -> f a
| Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (bind o f,opt))
@@ -22,61 +21,73 @@ let rec bind m f = match m with
| Escape descr -> Escape descr
| Fail descr -> Fail descr
| Error descr -> Error descr
- | Return a -> Return a
+ | Exception e -> Exception e
| Internal descr o_s -> Internal descr (let (o,opt) = o_s in (bind o f ,opt))
end
let inline (>>=) = bind
-val (>>) : forall 'b 'r. MR unit 'r -> MR 'b 'r -> MR 'b 'r
-let inline (>>) m n = m >>= fun _ -> n
+val (>>) : forall 'b 'e. M unit 'e -> M 'b 'e -> M 'b 'e
+let inline (>>) m n = m >>= fun (_ : unit) -> n
-val exit : forall 'a 'b. 'b -> M 'a
-let exit s = Fail Nothing
+val exit : forall 'a 'e. unit -> M 'a 'e
+let exit () = Fail Nothing
-val assert_exp : bool -> string -> M unit
+val assert_exp : forall 'e. bool -> string -> M unit 'e
let assert_exp exp msg = if exp then Done () else Fail (Just msg)
-val early_return : forall 'r. 'r -> MR unit 'r
-let early_return r = Return r
+val throw : forall 'a 'e. 'e -> M 'a 'e
+let throw e = Exception e
-val liftR : forall 'a 'r. M 'a -> MR 'a 'r
-let rec liftR m = match m with
+val try_catch : forall 'a 'e1 'e2. M 'a 'e1 -> ('e1 -> M 'a 'e2) -> M 'a 'e2
+let rec try_catch m h = match m with
| Done a -> Done a
- | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (liftR o,opt))
- | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (liftR o,opt))
- | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (liftR o,opt))
- | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (liftR o,opt))
- | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (liftR o,opt))
- | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (liftR o,opt))
- | Footprint o_s -> Footprint (let (o,opt) = o_s in (liftR o,opt))
- | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (liftR o,opt))
- | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (liftR o,opt))
+ | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (try_catch o h,opt))
+ | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (try_catch o h,opt))
+ | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (try_catch o h,opt))
+ | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (try_catch o h,opt))
+ | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (try_catch o h,opt))
+ | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (try_catch o h,opt))
+ | Footprint o_s -> Footprint (let (o,opt) = o_s in (try_catch o h,opt))
+ | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (try_catch o h,opt))
| Escape descr -> Escape descr
| Fail descr -> Fail descr
| Error descr -> Error descr
- | Return _ -> Error "uncaught early return"
+ | Exception e -> h e
+ | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (try_catch o h ,opt))
end
-val catch_early_return : forall 'a 'r. MR 'a 'a -> M 'a
-let rec catch_early_return m = match m with
- | Done a -> Done a
- | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (catch_early_return o,opt))
- | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (catch_early_return o,opt))
- | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (catch_early_return o,opt))
- | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (catch_early_return o,opt))
- | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (catch_early_return o,opt))
- | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (catch_early_return o,opt))
- | Footprint o_s -> Footprint (let (o,opt) = o_s in (catch_early_return o,opt))
- | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (catch_early_return o,opt))
- | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (catch_early_return o,opt))
- | Escape descr -> Escape descr
- | Fail descr -> Fail descr
- | Error descr -> Error descr
- | Return a -> Done a
-end
-
-val read_mem : forall 'a 'b. Bitvector 'a, Bitvector 'b => bool -> read_kind -> 'a -> integer -> M 'b
+(* For early return, we abuse exceptions by throwing and catching
+ the return value. The exception type is "either 'r 'e", where "Right e"
+ represents a proper exception and "Left r" an early return of value "r". *)
+type MR 'a 'r 'e = M 'a (either 'r 'e)
+
+val early_return : forall 'a 'r 'e. 'r -> MR 'a 'r 'e
+let early_return r = throw (Left r)
+
+val catch_early_return : forall 'a 'e. MR 'a 'a 'e -> M 'a 'e
+let catch_early_return m =
+ try_catch m
+ (function
+ | Left a -> return a
+ | Right e -> throw e
+ end)
+
+(* Lift to monad with early return by wrapping exceptions *)
+val liftR : forall 'a 'r 'e. M 'a 'e -> MR 'a 'r 'e
+let liftR m = try_catch m (fun e -> throw (Right e))
+
+(* Catch exceptions in the presence of early returns *)
+val try_catchR : forall 'a 'r 'e1 'e2. MR 'a 'r 'e1 -> ('e1 -> MR 'a 'r 'e2) -> MR 'a 'r 'e2
+let try_catchR m h =
+ try_catch m
+ (function
+ | Left r -> throw (Left r)
+ | Right e -> h e
+ end)
+
+
+val read_mem : forall 'a 'b 'e. Bitvector 'a, Bitvector 'b => bool -> read_kind -> 'a -> integer -> M 'b 'e
let read_mem dir rk addr sz =
let addr = address_lifted_of_bitv (bits_of addr) in
let sz = natFromInteger sz in
@@ -85,24 +96,24 @@ let read_mem dir rk addr sz =
(Done bitv,Nothing) in
Read_mem (rk,addr,sz) k
-val excl_result : unit -> M bool
+val excl_result : forall 'e. unit -> M bool 'e
let excl_result () =
let k successful = (return successful,Nothing) in
Excl_res k
-val write_mem_ea : forall 'a. Bitvector 'a => write_kind -> 'a -> integer -> M unit
+val write_mem_ea : forall 'a 'e. Bitvector 'a => write_kind -> 'a -> integer -> M unit 'e
let write_mem_ea wk addr sz =
let addr = address_lifted_of_bitv (bits_of addr) in
let sz = natFromInteger sz in
Write_ea (wk,addr,sz) (Done (),Nothing)
-val write_mem_val : forall 'a. Bitvector 'a => 'a -> M bool
+val write_mem_val : forall 'a 'e. Bitvector 'a => 'a -> M bool 'e
let write_mem_val v =
let v = external_mem_value (bits_of v) in
let k successful = (return successful,Nothing) in
Write_memv v k
-val read_reg_aux : forall 'a. Bitvector 'a => reg_name -> M 'a
+val read_reg_aux : forall 'a 'e. Bitvector 'a => reg_name -> M 'a 'e
let read_reg_aux reg =
let k reg_value =
let v = of_bits (internal_reg_value reg_value) in
@@ -124,7 +135,7 @@ let read_reg_bitfield reg regfield =
let reg_deref = read_reg
-val write_reg_aux : forall 'a. Bitvector 'a => reg_name -> 'a -> M unit
+val write_reg_aux : forall 'a 'e. Bitvector 'a => reg_name -> 'a -> M unit 'e
let write_reg_aux reg_name v =
let regval = external_reg_value reg_name (bits_of v) in
Write_reg (reg_name,regval) (Done (), Nothing)
@@ -150,29 +161,29 @@ let write_reg_field_bit = write_reg_field_pos
let write_reg_ref (reg, v) = write_reg reg v
-val barrier : barrier_kind -> M unit
+val barrier : forall 'e. barrier_kind -> M unit 'e
let barrier bk = Barrier bk (Done (), Nothing)
-val footprint : M unit
+val footprint : forall 'e. M unit 'e
let footprint = Footprint (Done (),Nothing)
-val iter_aux : forall 'regs 'e 'a. integer -> (integer -> 'a -> MR unit 'e) -> list 'a -> MR unit 'e
+val iter_aux : forall 'a 'e. integer -> (integer -> 'a -> M unit 'e) -> list 'a -> M unit 'e
let rec iter_aux i f xs = match xs with
| x :: xs -> f i x >> iter_aux (i + 1) f xs
| [] -> return ()
end
-val iteri : forall 'regs 'e 'a. (integer -> 'a -> MR unit 'e) -> list 'a -> MR unit 'e
+val iteri : forall 'a 'e. (integer -> 'a -> M unit 'e) -> list 'a -> M unit 'e
let iteri f xs = iter_aux 0 f xs
-val iter : forall 'regs 'e 'a. ('a -> MR unit 'e) -> list 'a -> MR unit 'e
+val iter : forall 'a 'e. ('a -> M unit 'e) -> list 'a -> M unit 'e
let iter f xs = iteri (fun _ x -> f x) xs
-val foreachM_inc : forall 'vars 'r. (integer * integer * integer) -> 'vars ->
- (integer -> 'vars -> MR 'vars 'r) -> MR 'vars 'r
+val foreachM_inc : forall 'vars 'e. (integer * integer * integer) -> 'vars ->
+ (integer -> 'vars -> M 'vars 'e) -> M 'vars 'e
let rec foreachM_inc (i,stop,by) vars body =
if (by > 0 && i <= stop) || (by < 0 && stop <= i)
then
@@ -181,8 +192,8 @@ let rec foreachM_inc (i,stop,by) vars body =
else return vars
-val foreachM_dec : forall 'vars 'r. (integer * integer * integer) -> 'vars ->
- (integer -> 'vars -> MR 'vars 'r) -> MR 'vars 'r
+val foreachM_dec : forall 'vars 'e. (integer * integer * integer) -> 'vars ->
+ (integer -> 'vars -> M 'vars 'e) -> M 'vars 'e
let rec foreachM_dec (i,stop,by) vars body =
if (by > 0 && i >= stop) || (by < 0 && stop >= i)
then
@@ -194,21 +205,21 @@ val while_PP : forall 'vars. 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'va
let rec while_PP vars cond body =
if cond vars then while_PP (body vars) cond body else vars
-val while_PM : forall 'vars 'r. 'vars -> ('vars -> bool) ->
- ('vars -> MR 'vars 'r) -> MR 'vars 'r
+val while_PM : forall 'vars 'e. 'vars -> ('vars -> bool) ->
+ ('vars -> M 'vars 'e) -> M 'vars 'e
let rec while_PM vars cond body =
if cond vars then
body vars >>= fun vars -> while_PM vars cond body
else return vars
-val while_MP : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) ->
- ('vars -> 'vars) -> MR 'vars 'r
+val while_MP : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) ->
+ ('vars -> 'vars) -> M 'vars 'e
let rec while_MP vars cond body =
cond vars >>= fun cond_val ->
if cond_val then while_MP (body vars) cond body else return vars
-val while_MM : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) ->
- ('vars -> MR 'vars 'r) -> MR 'vars 'r
+val while_MM : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) ->
+ ('vars -> M 'vars 'e) -> M 'vars 'e
let rec while_MM vars cond body =
cond vars >>= fun cond_val ->
if cond_val then
@@ -220,21 +231,21 @@ let rec until_PP vars cond body =
let vars = body vars in
if (cond vars) then vars else until_PP (body vars) cond body
-val until_PM : forall 'vars 'r. 'vars -> ('vars -> bool) ->
- ('vars -> MR 'vars 'r) -> MR 'vars 'r
+val until_PM : forall 'vars 'e. 'vars -> ('vars -> bool) ->
+ ('vars -> M 'vars 'e) -> M 'vars 'e
let rec until_PM vars cond body =
body vars >>= fun vars ->
if (cond vars) then return vars else until_PM vars cond body
-val until_MP : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) ->
- ('vars -> 'vars) -> MR 'vars 'r
+val until_MP : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) ->
+ ('vars -> 'vars) -> M 'vars 'e
let rec until_MP vars cond body =
let vars = body vars in
cond vars >>= fun cond_val ->
if cond_val then return vars else until_MP vars cond body
-val until_MM : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) ->
- ('vars -> MR 'vars 'r) -> MR 'vars 'r
+val until_MM : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) ->
+ ('vars -> M 'vars 'e) -> M 'vars 'e
let rec until_MM vars cond body =
body vars >>= fun vars ->
cond vars >>= fun cond_val ->
diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem
index fa0fcd24..f9011323 100644
--- a/src/gen_lib/state.lem
+++ b/src/gen_lib/state.lem
@@ -23,55 +23,79 @@ let init_state regs =
write_ea = Nothing;
last_exclusive_operation_was_load = false |>
-(* State, nondeterminism and exception monad with result type 'a
- and exception type 'e. *)
-type ME 'regs 'a 'e = sequential_state 'regs -> list ((either 'a 'e) * sequential_state 'regs)
-
-(* By default, we use strings to distinguish between different types of exceptions *)
-type M 'regs 'a = ME 'regs 'a string
+type ex 'e =
+ | Exit
+ | Assert of string
+ | Throw of 'e
-(* For early return, we abuse exceptions by throwing and catching
- the return value. The exception type is "either 'r string", where "Right e"
- represents a proper exception and "Left r" an early return of value "r". *)
-type MR 'regs 'a 'r = ME 'regs 'a (either 'r string)
+type result 'a 'e =
+ | Value of 'a
+ | Exception of (ex 'e)
-val liftR : forall 'a 'r 'regs. M 'regs 'a -> MR 'regs 'a 'r
-let liftR m s = List.map (function
- | (Left a, s') -> (Left a, s')
- | (Right e, s') -> (Right (Right e), s')
- end) (m s)
+(* State, nondeterminism and exception monad with result value type 'a
+ and exception type 'e. *)
+type M 'regs 'a 'e = sequential_state 'regs -> list (result 'a 'e * sequential_state 'regs)
-val return : forall 'regs 'a 'e. 'a -> ME 'regs 'a 'e
-let return a s = [(Left a,s)]
+val return : forall 'regs 'a 'e. 'a -> M 'regs 'a 'e
+let return a s = [(Value a,s)]
-val bind : forall 'regs 'a 'b 'e. ME 'regs 'a 'e -> ('a -> ME 'regs 'b 'e) -> ME 'regs 'b 'e
+val bind : forall 'regs 'a 'b 'e. M 'regs 'a 'e -> ('a -> M 'regs 'b 'e) -> M 'regs 'b 'e
let bind m f (s : sequential_state 'regs) =
List.concatMap (function
- | (Left a, s') -> f a s'
- | (Right e, s') -> [(Right e, s')]
+ | (Value a, s') -> f a s'
+ | (Exception e, s') -> [(Exception e, s')]
end) (m s)
let inline (>>=) = bind
-val (>>): forall 'regs 'b 'e. ME 'regs unit 'e -> ME 'regs 'b 'e -> ME 'regs 'b 'e
-let inline (>>) m n = m >>= fun _ -> n
+val (>>): forall 'regs 'b 'e. M 'regs unit 'e -> M 'regs 'b 'e -> M 'regs 'b 'e
+let inline (>>) m n = m >>= fun (_ : unit) -> n
-val exit : forall 'regs 'e 'a. 'e -> M 'regs 'a
-let exit _ s = [(Right "exit", s)]
+val throw : forall 'regs 'a 'e. 'e -> M 'regs 'a 'e
+let throw e s = [(Exception (Throw e), s)]
-val assert_exp : forall 'regs. bool -> string -> M 'regs unit
-let assert_exp exp msg s = if exp then [(Left (), s)] else [(Right msg, s)]
+val try_catch : forall 'regs 'a 'e1 'e2. M 'regs 'a 'e1 -> ('e1 -> M 'regs 'a 'e2) -> M 'regs 'a 'e2
+let try_catch m h s =
+ List.concatMap (function
+ | (Value a, s') -> return a s'
+ | (Exception (Throw e), s') -> h e s'
+ | (Exception Exit, s') -> [(Exception Exit, s')]
+ | (Exception (Assert msg), s') -> [(Exception (Assert msg), s')]
+ end) (m s)
-val early_return : forall 'regs 'a 'r. 'r -> MR 'regs 'a 'r
-let early_return r s = [(Right (Left r), s)]
+val exit : forall 'regs 'e 'a. unit -> M 'regs 'a 'e
+let exit () s = [(Exception Exit, s)]
-val catch_early_return : forall 'regs 'a. MR 'regs 'a 'a -> M 'regs 'a
-let catch_early_return m s =
- List.map
+val assert_exp : forall 'regs 'e. bool -> string -> M 'regs unit 'e
+let assert_exp exp msg s = if exp then [(Value (), s)] else [(Exception (Assert msg), s)]
+
+(* For early return, we abuse exceptions by throwing and catching
+ the return value. The exception type is "either 'r 'e", where "Right e"
+ represents a proper exception and "Left r" an early return of value "r". *)
+type MR 'regs 'a 'r 'e = M 'regs 'a (either 'r 'e)
+
+val early_return : forall 'regs 'a 'r 'e. 'r -> MR 'regs 'a 'r 'e
+let early_return r = throw (Left r)
+
+val catch_early_return : forall 'regs 'a 'e. MR 'regs 'a 'a 'e -> M 'regs 'a 'e
+let catch_early_return m =
+ try_catch m
+ (function
+ | Left a -> return a
+ | Right e -> throw e
+ end)
+
+(* Lift to monad with early return by wrapping exceptions *)
+val liftR : forall 'a 'r 'regs 'e. M 'regs 'a 'e -> MR 'regs 'a 'r 'e
+let liftR m = try_catch m (fun e -> throw (Right e))
+
+(* Catch exceptions in the presence of early returns *)
+val try_catchR : forall 'regs 'a 'r 'e1 'e2. MR 'regs 'a 'r 'e1 -> ('e1 -> MR 'regs 'a 'r 'e2) -> MR 'regs 'a 'r 'e2
+let try_catchR m h =
+ try_catch m
(function
- | (Right (Left a), s') -> (Left a, s')
- | (Right (Right e), s') -> (Right e, s')
- | (Left a, s') -> (Left a, s')
- end) (m s)
+ | Left r -> throw (Left r)
+ | Right e -> h e
+ end)
val range : integer -> integer -> list integer
let rec range i j =
@@ -103,20 +127,20 @@ let is_exclusive = function
end
-val read_mem : forall 'regs 'a 'b. Bitvector 'a, Bitvector 'b => bool -> read_kind -> 'a -> integer -> M 'regs 'b
+val read_mem : forall 'regs 'a 'b 'e. Bitvector 'a, Bitvector 'b => bool -> read_kind -> 'a -> integer -> M 'regs 'b 'e
let read_mem dir read_kind addr sz state =
let addr = unsigned addr in
let addrs = range addr (addr+sz-1) in
let memory_value = List.map (fun addr -> Map_extra.find addr state.memstate) addrs in
let value = of_bits (Sail_values.internal_mem_value dir memory_value) in
if is_exclusive read_kind
- then [(Left value, <| state with last_exclusive_operation_was_load = true |>)]
- else [(Left value, state)]
+ then [(Value value, <| state with last_exclusive_operation_was_load = true |>)]
+ else [(Value value, state)]
(* caps are aligned at 32 bytes *)
let cap_alignment = (32 : integer)
-val read_tag : forall 'regs 'a. Bitvector 'a => bool -> read_kind -> 'a -> M 'regs bitU
+val read_tag : forall 'regs 'a 'e. Bitvector 'a => bool -> read_kind -> 'a -> M 'regs bitU 'e
let read_tag dir read_kind addr state =
let addr = (unsigned addr) / cap_alignment in
let tag = match (Map.lookup addr state.tagstate) with
@@ -124,20 +148,20 @@ let read_tag dir read_kind addr state =
| Nothing -> B0
end in
if is_exclusive read_kind
- then [(Left tag, <| state with last_exclusive_operation_was_load = true |>)]
- else [(Left tag, state)]
+ then [(Value tag, <| state with last_exclusive_operation_was_load = true |>)]
+ else [(Value tag, state)]
-val excl_result : forall 'regs. unit -> M 'regs bool
+val excl_result : forall 'regs 'e. unit -> M 'regs bool 'e
let excl_result () state =
let success =
- (Left true, <| state with last_exclusive_operation_was_load = false |>) in
- (Left false, state) :: if state.last_exclusive_operation_was_load then [success] else []
+ (Value true, <| state with last_exclusive_operation_was_load = false |>) in
+ (Value false, state) :: if state.last_exclusive_operation_was_load then [success] else []
-val write_mem_ea : forall 'regs 'a. Bitvector 'a => write_kind -> 'a -> integer -> M 'regs unit
+val write_mem_ea : forall 'regs 'a 'e. Bitvector 'a => write_kind -> 'a -> integer -> M 'regs unit 'e
let write_mem_ea write_kind addr sz state =
- [(Left (), <| state with write_ea = Just (write_kind,unsigned addr,sz) |>)]
+ [(Value (), <| state with write_ea = Just (write_kind,unsigned addr,sz) |>)]
-val write_mem_val : forall 'a 'regs 'b. Bitvector 'a => 'a -> M 'regs bool
+val write_mem_val : forall 'a 'regs 'b 'e. Bitvector 'a => 'a -> M 'regs bool 'e
let write_mem_val v state =
let (write_kind,addr,sz) = match state.write_ea with
| Nothing -> failwith "write ea has not been announced yet"
@@ -147,27 +171,27 @@ let write_mem_val v state =
let addresses_with_value = List.zip addrs v in
let memstate = List.foldl (fun mem (addr,v) -> Map.insert addr v mem)
state.memstate addresses_with_value in
- [(Left true, <| state with memstate = memstate |>)]
+ [(Value true, <| state with memstate = memstate |>)]
-val write_tag : forall 'regs. bitU -> M 'regs bool
+val write_tag : forall 'regs 'e. bitU -> M 'regs bool 'e
let write_tag t state =
let (write_kind,addr,sz) = match state.write_ea with
| Nothing -> failwith "write ea has not been announced yet"
| Just write_ea -> write_ea end in
let taddr = addr / cap_alignment in
let tagstate = Map.insert taddr t state.tagstate in
- [(Left true, <| state with tagstate = tagstate |>)]
+ [(Value true, <| state with tagstate = tagstate |>)]
-val read_reg : forall 'regs 'a. register_ref 'regs 'a -> M 'regs 'a
+val read_reg : forall 'regs 'a 'e. register_ref 'regs 'a -> M 'regs 'a 'e
let read_reg reg state =
let v = reg.read_from state.regstate in
- [(Left v,state)]
+ [(Value v,state)]
(*let read_reg_range reg i j state =
let v = slice (get_reg state (name_of_reg reg)) i j in
- [(Left (vec_to_bvec v),state)]
+ [(Value (vec_to_bvec v),state)]
let read_reg_bit reg i state =
let v = access (get_reg state (name_of_reg reg)) i in
- [(Left v,state)]
+ [(Value v,state)]
let read_reg_field reg regfield =
let (i,j) = register_field_indices reg regfield in
read_reg_range reg i j
@@ -177,17 +201,17 @@ let read_reg_bitfield reg regfield =
let reg_deref = read_reg
-val write_reg : forall 'regs 'a. register_ref 'regs 'a -> 'a -> M 'regs unit
+val write_reg : forall 'regs 'a 'e. register_ref 'regs 'a -> 'a -> M 'regs unit 'e
let write_reg reg v state =
- [(Left (), <| state with regstate = reg.write_to state.regstate v |>)]
+ [(Value (), <| state with regstate = reg.write_to state.regstate v |>)]
let write_reg_ref (reg, v) = write_reg reg v
-val update_reg : forall 'regs 'a 'b. register_ref 'regs 'a -> ('a -> 'b -> 'a) -> 'b -> M 'regs unit
+val update_reg : forall 'regs 'a 'b 'e. register_ref 'regs 'a -> ('a -> 'b -> 'a) -> 'b -> M 'regs unit 'e
let update_reg reg f v state =
let current_value = get_reg state reg in
let new_value = f current_value v in
- [(Left (), set_reg state reg new_value)]
+ [(Value (), set_reg state reg new_value)]
let write_reg_field reg regfield = update_reg reg regfield.set_field
@@ -219,26 +243,26 @@ let update_reg_field_bit regfield i reg_val bit =
regfield.set_field reg_val new_field_value
let write_reg_field_bit reg regfield i = update_reg reg (update_reg_field_bit regfield i)
-val barrier : forall 'regs. barrier_kind -> M 'regs unit
+val barrier : forall 'regs 'e. barrier_kind -> M 'regs unit 'e
let barrier _ = return ()
-val footprint : forall 'regs. M 'regs unit
+val footprint : forall 'regs 'e. M 'regs unit 'e
let footprint s = return () s
-val iter_aux : forall 'regs 'e 'a. integer -> (integer -> 'a -> ME 'regs unit 'e) -> list 'a -> ME 'regs unit 'e
+val iter_aux : forall 'regs 'e 'a. integer -> (integer -> 'a -> M 'regs unit 'e) -> list 'a -> M 'regs unit 'e
let rec iter_aux i f xs = match xs with
| x :: xs -> f i x >> iter_aux (i + 1) f xs
| [] -> return ()
end
-val iteri : forall 'regs 'e 'a. (integer -> 'a -> ME 'regs unit 'e) -> list 'a -> ME 'regs unit 'e
+val iteri : forall 'regs 'e 'a. (integer -> 'a -> M 'regs unit 'e) -> list 'a -> M 'regs unit 'e
let iteri f xs = iter_aux 0 f xs
-val iter : forall 'regs 'e 'a. ('a -> ME 'regs unit 'e) -> list 'a -> ME 'regs unit 'e
+val iter : forall 'regs 'e 'a. ('a -> M 'regs unit 'e) -> list 'a -> M 'regs unit 'e
let iter f xs = iteri (fun _ x -> f x) xs
val foreachM_inc : forall 'regs 'vars 'e. (integer * integer * integer) -> 'vars ->
- (integer -> 'vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+ (integer -> 'vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e
let rec foreachM_inc (i,stop,by) vars body =
if (by > 0 && i <= stop) || (by < 0 && stop <= i)
then
@@ -248,7 +272,7 @@ let rec foreachM_inc (i,stop,by) vars body =
val foreachM_dec : forall 'regs 'vars 'e. (integer * integer * integer) -> 'vars ->
- (integer -> 'vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+ (integer -> 'vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e
let rec foreachM_dec (i,stop,by) vars body =
if (by > 0 && i >= stop) || (by < 0 && stop >= i)
then
@@ -261,20 +285,20 @@ let rec while_PP vars cond body =
if cond vars then while_PP (body vars) cond body else vars
val while_PM : forall 'regs 'vars 'e. 'vars -> ('vars -> bool) ->
- ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+ ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e
let rec while_PM vars cond body =
if cond vars then
body vars >>= fun vars -> while_PM vars cond body
else return vars
-val while_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
- ('vars -> 'vars) -> ME 'regs 'vars 'e
+val while_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) ->
+ ('vars -> 'vars) -> M 'regs 'vars 'e
let rec while_MP vars cond body =
cond vars >>= fun cond_val ->
if cond_val then while_MP (body vars) cond body else return vars
-val while_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
- ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+val while_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) ->
+ ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e
let rec while_MM vars cond body =
cond vars >>= fun cond_val ->
if cond_val then
@@ -287,20 +311,20 @@ let rec until_PP vars cond body =
if (cond vars) then vars else until_PP (body vars) cond body
val until_PM : forall 'regs 'vars 'e. 'vars -> ('vars -> bool) ->
- ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+ ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e
let rec until_PM vars cond body =
body vars >>= fun vars ->
if (cond vars) then return vars else until_PM vars cond body
-val until_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
- ('vars -> 'vars) -> ME 'regs 'vars 'e
+val until_MP : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) ->
+ ('vars -> 'vars) -> M 'regs 'vars 'e
let rec until_MP vars cond body =
let vars = body vars in
cond vars >>= fun cond_val ->
if cond_val then return vars else until_MP vars cond body
-val until_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> ME 'regs bool 'e) ->
- ('vars -> ME 'regs 'vars 'e) -> ME 'regs 'vars 'e
+val until_MM : forall 'regs 'vars 'e. 'vars -> ('vars -> M 'regs bool 'e) ->
+ ('vars -> M 'regs 'vars 'e) -> M 'regs 'vars 'e
let rec until_MM vars cond body =
body vars >>= fun vars ->
cond vars >>= fun cond_val ->
diff --git a/src/lem_interp/sail_impl_base.lem b/src/lem_interp/sail_impl_base.lem
index 421219da..368f7505 100644
--- a/src/lem_interp/sail_impl_base.lem
+++ b/src/lem_interp/sail_impl_base.lem
@@ -905,36 +905,35 @@ end
(* the address_lifted types should go away here and be replaced by address *)
type with_aux 'o = 'o * maybe ((unit -> (string * string)) * ((list (reg_name * register_value)) -> list event))
-type outcome_r 'a 'r =
+type outcome 'a 'e =
(* Request to read memory, value is location to read, integer is size to read,
followed by registers that were used in computing that size *)
- | Read_mem of (read_kind * address_lifted * nat) * (memory_value -> with_aux (outcome_r 'a 'r))
+ | Read_mem of (read_kind * address_lifted * nat) * (memory_value -> with_aux (outcome 'a 'e))
(* Tell the system a write is imminent, at address lifted, of size nat *)
- | Write_ea of (write_kind * address_lifted * nat) * (with_aux (outcome_r 'a 'r))
+ | Write_ea of (write_kind * address_lifted * nat) * (with_aux (outcome 'a 'e))
(* Request the result of store-exclusive *)
- | Excl_res of (bool -> with_aux (outcome_r 'a 'r))
+ | Excl_res of (bool -> with_aux (outcome 'a 'e))
(* Request to write memory at last signalled address. Memory value should be 8
times the size given in ea signal *)
- | Write_memv of memory_value * (bool -> with_aux (outcome_r 'a 'r))
+ | Write_memv of memory_value * (bool -> with_aux (outcome 'a 'e))
(* Request a memory barrier *)
- | Barrier of barrier_kind * with_aux (outcome_r 'a 'r)
+ | Barrier of barrier_kind * with_aux (outcome 'a 'e)
(* Tell the system to dynamically recalculate dependency footprint *)
- | Footprint of with_aux (outcome_r 'a 'r)
+ | Footprint of with_aux (outcome 'a 'e)
(* Request to read register, will track dependency when mode.track_values *)
- | Read_reg of reg_name * (register_value -> with_aux (outcome_r 'a 'r))
+ | Read_reg of reg_name * (register_value -> with_aux (outcome 'a 'e))
(* Request to write register *)
- | Write_reg of (reg_name * register_value) * with_aux (outcome_r 'a 'r)
+ | Write_reg of (reg_name * register_value) * with_aux (outcome 'a 'e)
| Escape of maybe string
(*Result of a failed assert with possible error message to report*)
| Fail of maybe string
- (* Early return with value of type 'r *)
- | Return of 'r
- | Internal of (maybe string * maybe (unit -> string)) * with_aux (outcome_r 'a 'r)
+ (* Exception of type 'e *)
+ | Exception of 'e
+ | Internal of (maybe string * maybe (unit -> string)) * with_aux (outcome 'a 'e)
| Done of 'a
| Error of string
-type outcome 'a = outcome_r 'a unit
-type outcome_s 'a = with_aux (outcome 'a)
+type outcome_s 'a 'e = with_aux (outcome 'a 'e)
(* first string : output of instruction_stack_to_string
second string: output of local_variables_to_string *)
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 8730b019..0cbeda49 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -102,7 +102,7 @@ let rec subst_nc substs (NC_aux (nc,l) as n_constraint) =
begin
match KBindings.find kid substs with
| Nexp_aux (Nexp_constant i,_) ->
- if List.mem i is then re NC_true else re NC_false
+ if List.exists (fun j -> Big_int.eq_big_int i j) is then re NC_true else re NC_false
| nexp ->
raise (Reporting_basic.err_general l
("Unable to substitute " ^ string_of_nexp nexp ^
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 96f312fb..6a3d1293 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -249,12 +249,12 @@ let doc_typ_lem, doc_atomic_typ_lem =
| Typ_fn(arg,ret,efct) ->
let ret_typ =
if effectful efct
- then separate space [string "_M"; fn_typ true ret]
+ then separate space [string "M"; fn_typ true ret]
else separate space [fn_typ false ret] in
let arg_typs = match arg with
| Typ_aux (Typ_tup typs, _) ->
- List.map (app_typ true) typs
- | _ -> [tup_typ true arg] in
+ List.map (app_typ false) typs
+ | _ -> [tup_typ false arg] in
let tpp = separate (space ^^ arrow ^^ space) (arg_typs @ [ret_typ]) in
(* once we have proper excetions we need to know what the exceptions type is *)
if atyp_needed then parens tpp else tpp
@@ -355,7 +355,7 @@ let doc_tannot_lem eff typ =
if contains_t_pp_var typ then empty
else
let ta = doc_typ_lem typ in
- if eff then string " : _M " ^^ parens ta
+ if eff then string " : M " ^^ parens ta
else string " : " ^^ ta
(* doc_lit_lem gets as an additional parameter the type information from the
@@ -685,7 +685,7 @@ let doc_exp_lem, doc_let_lem =
contains_t_pp_var (typ_of full_exp) then
aexp_needed, epp
else
- let tannot = separate space [string "_MR";
+ let tannot = separate space [string "MR";
doc_atomic_typ_lem false (typ_of full_exp);
doc_atomic_typ_lem false (typ_of exp)] in
true, doc_op colon epp tannot in
@@ -855,7 +855,20 @@ let doc_exp_lem, doc_let_lem =
(separate_map (break 1) (doc_case early_ret) pexps) ^/^
(string "end")) in
if aexp_needed then parens (align epp) else align epp
- | E_exit e -> liftR (separate space [string "exit"; expY e;])
+ | E_try (e, pexps) ->
+ if effectful (effect_of e) then
+ let try_catch = if early_ret then "try_catchR" else "try_catch" in
+ let epp =
+ group ((separate space [string try_catch; expY e; string "(function "]) ^/^
+ (separate_map (break 1) (doc_case early_ret) pexps) ^/^
+ (string "end)")) in
+ if aexp_needed then parens (align epp) else align epp
+ else
+ raise (Reporting_basic.err_todo l "Warning: try-block around pure expression")
+ | E_throw e ->
+ let epp = liftR (separate space [string "throw"; expY e]) in
+ if aexp_needed then parens (align epp) else align epp
+ | E_exit e -> liftR (separate space [string "exit"; expY e])
| E_assert (e1,e2) ->
let epp = liftR (separate space [string "assert_exp"; expY e1; expY e2]) in
if aexp_needed then parens (align epp) else align epp
@@ -1495,6 +1508,12 @@ let find_registers (Defs defs) =
| _ -> acc
) [] defs
+let find_exc_typ (Defs defs) =
+ let is_exc_typ_def = function
+ | DEF_type td -> string_of_id (id_of_type_def td) = "exception"
+ | _ -> false in
+ if List.exists is_exc_typ_def defs then "exception" else "unit"
+
let doc_regstate_lem registers =
let l = Parse_ast.Unknown in
let annot = (l, None) in
@@ -1528,12 +1547,7 @@ let doc_regstate_lem registers =
doc_op equals (string "let initial_regstate") (doc_exp_lem false false exp)
else empty
in
- concat [
- doc_typdef_lem (TD_aux (regstate, annot)); hardline;
- hardline;
- string "type _MR 'a 'r = MR regstate 'a 'r"; hardline;
- string "type _M 'a = M regstate 'a"
- ],
+ doc_typdef_lem (TD_aux (regstate, annot)),
initregstate
let doc_register_refs_lem registers =
@@ -1564,6 +1578,7 @@ let pp_defs_lem (types_file,types_modules) (defs_file,defs_modules) d top_line =
let (typdefs,valdefs) = doc_defs_lem d in
let regstate_def, initregstate_def = doc_regstate_lem (find_registers d) in
let register_refs = doc_register_refs_lem (find_registers d) in
+ let exc_typ = find_exc_typ d in
(print types_file)
(concat
[string "(*" ^^ (string top_line) ^^ string "*)";hardline;
@@ -1585,10 +1600,17 @@ let pp_defs_lem (types_file,types_modules) (defs_file,defs_modules) d top_line =
hardline;
if !opt_sequential then
concat [regstate_def; hardline;
- hardline;
- register_refs]
+ hardline;
+ string ("type MR 'a 'r = State.MR regstate 'a 'r " ^ exc_typ); hardline;
+ string ("type M 'a = State.M regstate 'a " ^ exc_typ); hardline;
+ hardline;
+ register_refs
+ ]
else
- concat [string "type _MR 'a 'r = MR 'a 'r"; hardline; string "type _M 'a = M 'a"; hardline]
+ concat [
+ string ("type MR 'a 'r = Prompt.MR 'a 'r " ^ exc_typ); hardline;
+ string ("type M 'a = Prompt.M 'a " ^ exc_typ); hardline
+ ]
]);
(print defs_file)
(concat
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 88fb17a5..31bcb577 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -130,12 +130,11 @@ let fix_eff_exp (E_aux (e,((l,_) as annot))) = match snd annot with
| E_record_update(e,fexps) ->
union_effects (effect_of e) (effect_of_fexps fexps)
| E_field (e,_) -> effect_of e
- | E_case (e,pexps) ->
+ | E_case (e,pexps) | E_try (e,pexps) ->
List.fold_left union_effects (effect_of e) (List.map effect_of_pexp pexps)
| E_let (lb,e) -> union_effects (effect_of_lb lb) (effect_of e)
| E_assign (lexp,e) -> union_effects (effect_of_lexp lexp) (effect_of e)
- | E_exit e -> union_effects eff (effect_of e)
- | E_return e -> union_effects eff (effect_of e)
+ | E_exit e | E_return e | E_throw e -> union_effects eff (effect_of e)
| E_sizeof _ | E_sizeof_internal _ | E_constraint _ -> no_effect
| E_assert (c,m) -> union_effects eff (union_eff_exps [c; m])
| E_comment _ | E_comment_struc _ -> no_effect
diff --git a/src/rewrites.ml b/src/rewrites.ml
index a42335b9..32ffe54a 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -2378,6 +2378,11 @@ let rewrite_defs_letbind_effects =
n_exp_name exp1 (fun exp1 ->
n_pexpL newreturn pexps (fun pexps ->
k (rewrap (E_case (exp1,pexps)))))
+ | E_try (exp1,pexps) ->
+ let newreturn = effectful exp1 || List.exists effectful_pexp pexps in
+ n_exp_name exp1 (fun exp1 ->
+ n_pexpL newreturn pexps (fun pexps ->
+ k (rewrap (E_try (exp1,pexps)))))
| E_let (lb,body) ->
n_lb lb (fun lb ->
rewrap (E_let (lb,n_exp body k)))
@@ -2416,6 +2421,9 @@ let rewrite_defs_letbind_effects =
| E_return exp' ->
n_exp_name exp' (fun exp' ->
k (rewrap (E_return exp')))
+ | E_throw exp' ->
+ n_exp_name exp' (fun exp' ->
+ k (rewrap (E_throw exp')))
| E_internal_plet _ -> failwith "E_internal_plet should not be here yet" in
let rewrite_fun _ (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),fdannot)) =