summaryrefslogtreecommitdiff
path: root/src/gen_lib/prompt.lem
diff options
context:
space:
mode:
Diffstat (limited to 'src/gen_lib/prompt.lem')
-rw-r--r--src/gen_lib/prompt.lem149
1 files changed, 80 insertions, 69 deletions
diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem
index 4646ef6f..77a39096 100644
--- a/src/gen_lib/prompt.lem
+++ b/src/gen_lib/prompt.lem
@@ -2,13 +2,12 @@ open import Pervasives_extra
open import Sail_impl_base
open import Sail_values
-type MR 'a 'r = outcome_r 'a 'r
-type M 'a = outcome 'a
+type M 'a 'e = outcome 'a 'e
-val return : forall 'a 'r. 'a -> MR 'a 'r
+val return : forall 'a 'e. 'a -> M 'a 'e
let return a = Done a
-val bind : forall 'a 'b 'r. MR 'a 'r -> ('a -> MR 'b 'r) -> MR 'b 'r
+val bind : forall 'a 'b 'e. M 'a 'e -> ('a -> M 'b 'e) -> M 'b 'e
let rec bind m f = match m with
| Done a -> f a
| Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (bind o f,opt))
@@ -22,61 +21,73 @@ let rec bind m f = match m with
| Escape descr -> Escape descr
| Fail descr -> Fail descr
| Error descr -> Error descr
- | Return a -> Return a
+ | Exception e -> Exception e
| Internal descr o_s -> Internal descr (let (o,opt) = o_s in (bind o f ,opt))
end
let inline (>>=) = bind
-val (>>) : forall 'b 'r. MR unit 'r -> MR 'b 'r -> MR 'b 'r
-let inline (>>) m n = m >>= fun _ -> n
+val (>>) : forall 'b 'e. M unit 'e -> M 'b 'e -> M 'b 'e
+let inline (>>) m n = m >>= fun (_ : unit) -> n
-val exit : forall 'a 'b. 'b -> M 'a
-let exit s = Fail Nothing
+val exit : forall 'a 'e. unit -> M 'a 'e
+let exit () = Fail Nothing
-val assert_exp : bool -> string -> M unit
+val assert_exp : forall 'e. bool -> string -> M unit 'e
let assert_exp exp msg = if exp then Done () else Fail (Just msg)
-val early_return : forall 'r. 'r -> MR unit 'r
-let early_return r = Return r
+val throw : forall 'a 'e. 'e -> M 'a 'e
+let throw e = Exception e
-val liftR : forall 'a 'r. M 'a -> MR 'a 'r
-let rec liftR m = match m with
+val try_catch : forall 'a 'e1 'e2. M 'a 'e1 -> ('e1 -> M 'a 'e2) -> M 'a 'e2
+let rec try_catch m h = match m with
| Done a -> Done a
- | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (liftR o,opt))
- | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (liftR o,opt))
- | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (liftR o,opt))
- | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (liftR o,opt))
- | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (liftR o,opt))
- | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (liftR o,opt))
- | Footprint o_s -> Footprint (let (o,opt) = o_s in (liftR o,opt))
- | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (liftR o,opt))
- | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (liftR o,opt))
+ | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (try_catch o h,opt))
+ | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (try_catch o h,opt))
+ | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (try_catch o h,opt))
+ | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (try_catch o h,opt))
+ | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (try_catch o h,opt))
+ | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (try_catch o h,opt))
+ | Footprint o_s -> Footprint (let (o,opt) = o_s in (try_catch o h,opt))
+ | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (try_catch o h,opt))
| Escape descr -> Escape descr
| Fail descr -> Fail descr
| Error descr -> Error descr
- | Return _ -> Error "uncaught early return"
+ | Exception e -> h e
+ | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (try_catch o h ,opt))
end
-val catch_early_return : forall 'a 'r. MR 'a 'a -> M 'a
-let rec catch_early_return m = match m with
- | Done a -> Done a
- | Read_mem descr k -> Read_mem descr (fun v -> let (o,opt) = k v in (catch_early_return o,opt))
- | Read_reg descr k -> Read_reg descr (fun v -> let (o,opt) = k v in (catch_early_return o,opt))
- | Write_memv descr k -> Write_memv descr (fun v -> let (o,opt) = k v in (catch_early_return o,opt))
- | Excl_res k -> Excl_res (fun v -> let (o,opt) = k v in (catch_early_return o,opt))
- | Write_ea descr o_s -> Write_ea descr (let (o,opt) = o_s in (catch_early_return o,opt))
- | Barrier descr o_s -> Barrier descr (let (o,opt) = o_s in (catch_early_return o,opt))
- | Footprint o_s -> Footprint (let (o,opt) = o_s in (catch_early_return o,opt))
- | Write_reg descr o_s -> Write_reg descr (let (o,opt) = o_s in (catch_early_return o,opt))
- | Internal descr o_s -> Internal descr (let (o,opt) = o_s in (catch_early_return o,opt))
- | Escape descr -> Escape descr
- | Fail descr -> Fail descr
- | Error descr -> Error descr
- | Return a -> Done a
-end
-
-val read_mem : forall 'a 'b. Bitvector 'a, Bitvector 'b => bool -> read_kind -> 'a -> integer -> M 'b
+(* For early return, we abuse exceptions by throwing and catching
+ the return value. The exception type is "either 'r 'e", where "Right e"
+ represents a proper exception and "Left r" an early return of value "r". *)
+type MR 'a 'r 'e = M 'a (either 'r 'e)
+
+val early_return : forall 'a 'r 'e. 'r -> MR 'a 'r 'e
+let early_return r = throw (Left r)
+
+val catch_early_return : forall 'a 'e. MR 'a 'a 'e -> M 'a 'e
+let catch_early_return m =
+ try_catch m
+ (function
+ | Left a -> return a
+ | Right e -> throw e
+ end)
+
+(* Lift to monad with early return by wrapping exceptions *)
+val liftR : forall 'a 'r 'e. M 'a 'e -> MR 'a 'r 'e
+let liftR m = try_catch m (fun e -> throw (Right e))
+
+(* Catch exceptions in the presence of early returns *)
+val try_catchR : forall 'a 'r 'e1 'e2. MR 'a 'r 'e1 -> ('e1 -> MR 'a 'r 'e2) -> MR 'a 'r 'e2
+let try_catchR m h =
+ try_catch m
+ (function
+ | Left r -> throw (Left r)
+ | Right e -> h e
+ end)
+
+
+val read_mem : forall 'a 'b 'e. Bitvector 'a, Bitvector 'b => bool -> read_kind -> 'a -> integer -> M 'b 'e
let read_mem dir rk addr sz =
let addr = address_lifted_of_bitv (bits_of addr) in
let sz = natFromInteger sz in
@@ -85,24 +96,24 @@ let read_mem dir rk addr sz =
(Done bitv,Nothing) in
Read_mem (rk,addr,sz) k
-val excl_result : unit -> M bool
+val excl_result : forall 'e. unit -> M bool 'e
let excl_result () =
let k successful = (return successful,Nothing) in
Excl_res k
-val write_mem_ea : forall 'a. Bitvector 'a => write_kind -> 'a -> integer -> M unit
+val write_mem_ea : forall 'a 'e. Bitvector 'a => write_kind -> 'a -> integer -> M unit 'e
let write_mem_ea wk addr sz =
let addr = address_lifted_of_bitv (bits_of addr) in
let sz = natFromInteger sz in
Write_ea (wk,addr,sz) (Done (),Nothing)
-val write_mem_val : forall 'a. Bitvector 'a => 'a -> M bool
+val write_mem_val : forall 'a 'e. Bitvector 'a => 'a -> M bool 'e
let write_mem_val v =
let v = external_mem_value (bits_of v) in
let k successful = (return successful,Nothing) in
Write_memv v k
-val read_reg_aux : forall 'a. Bitvector 'a => reg_name -> M 'a
+val read_reg_aux : forall 'a 'e. Bitvector 'a => reg_name -> M 'a 'e
let read_reg_aux reg =
let k reg_value =
let v = of_bits (internal_reg_value reg_value) in
@@ -124,7 +135,7 @@ let read_reg_bitfield reg regfield =
let reg_deref = read_reg
-val write_reg_aux : forall 'a. Bitvector 'a => reg_name -> 'a -> M unit
+val write_reg_aux : forall 'a 'e. Bitvector 'a => reg_name -> 'a -> M unit 'e
let write_reg_aux reg_name v =
let regval = external_reg_value reg_name (bits_of v) in
Write_reg (reg_name,regval) (Done (), Nothing)
@@ -150,29 +161,29 @@ let write_reg_field_bit = write_reg_field_pos
let write_reg_ref (reg, v) = write_reg reg v
-val barrier : barrier_kind -> M unit
+val barrier : forall 'e. barrier_kind -> M unit 'e
let barrier bk = Barrier bk (Done (), Nothing)
-val footprint : M unit
+val footprint : forall 'e. M unit 'e
let footprint = Footprint (Done (),Nothing)
-val iter_aux : forall 'regs 'e 'a. integer -> (integer -> 'a -> MR unit 'e) -> list 'a -> MR unit 'e
+val iter_aux : forall 'a 'e. integer -> (integer -> 'a -> M unit 'e) -> list 'a -> M unit 'e
let rec iter_aux i f xs = match xs with
| x :: xs -> f i x >> iter_aux (i + 1) f xs
| [] -> return ()
end
-val iteri : forall 'regs 'e 'a. (integer -> 'a -> MR unit 'e) -> list 'a -> MR unit 'e
+val iteri : forall 'a 'e. (integer -> 'a -> M unit 'e) -> list 'a -> M unit 'e
let iteri f xs = iter_aux 0 f xs
-val iter : forall 'regs 'e 'a. ('a -> MR unit 'e) -> list 'a -> MR unit 'e
+val iter : forall 'a 'e. ('a -> M unit 'e) -> list 'a -> M unit 'e
let iter f xs = iteri (fun _ x -> f x) xs
-val foreachM_inc : forall 'vars 'r. (integer * integer * integer) -> 'vars ->
- (integer -> 'vars -> MR 'vars 'r) -> MR 'vars 'r
+val foreachM_inc : forall 'vars 'e. (integer * integer * integer) -> 'vars ->
+ (integer -> 'vars -> M 'vars 'e) -> M 'vars 'e
let rec foreachM_inc (i,stop,by) vars body =
if (by > 0 && i <= stop) || (by < 0 && stop <= i)
then
@@ -181,8 +192,8 @@ let rec foreachM_inc (i,stop,by) vars body =
else return vars
-val foreachM_dec : forall 'vars 'r. (integer * integer * integer) -> 'vars ->
- (integer -> 'vars -> MR 'vars 'r) -> MR 'vars 'r
+val foreachM_dec : forall 'vars 'e. (integer * integer * integer) -> 'vars ->
+ (integer -> 'vars -> M 'vars 'e) -> M 'vars 'e
let rec foreachM_dec (i,stop,by) vars body =
if (by > 0 && i >= stop) || (by < 0 && stop >= i)
then
@@ -194,21 +205,21 @@ val while_PP : forall 'vars. 'vars -> ('vars -> bool) -> ('vars -> 'vars) -> 'va
let rec while_PP vars cond body =
if cond vars then while_PP (body vars) cond body else vars
-val while_PM : forall 'vars 'r. 'vars -> ('vars -> bool) ->
- ('vars -> MR 'vars 'r) -> MR 'vars 'r
+val while_PM : forall 'vars 'e. 'vars -> ('vars -> bool) ->
+ ('vars -> M 'vars 'e) -> M 'vars 'e
let rec while_PM vars cond body =
if cond vars then
body vars >>= fun vars -> while_PM vars cond body
else return vars
-val while_MP : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) ->
- ('vars -> 'vars) -> MR 'vars 'r
+val while_MP : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) ->
+ ('vars -> 'vars) -> M 'vars 'e
let rec while_MP vars cond body =
cond vars >>= fun cond_val ->
if cond_val then while_MP (body vars) cond body else return vars
-val while_MM : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) ->
- ('vars -> MR 'vars 'r) -> MR 'vars 'r
+val while_MM : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) ->
+ ('vars -> M 'vars 'e) -> M 'vars 'e
let rec while_MM vars cond body =
cond vars >>= fun cond_val ->
if cond_val then
@@ -220,21 +231,21 @@ let rec until_PP vars cond body =
let vars = body vars in
if (cond vars) then vars else until_PP (body vars) cond body
-val until_PM : forall 'vars 'r. 'vars -> ('vars -> bool) ->
- ('vars -> MR 'vars 'r) -> MR 'vars 'r
+val until_PM : forall 'vars 'e. 'vars -> ('vars -> bool) ->
+ ('vars -> M 'vars 'e) -> M 'vars 'e
let rec until_PM vars cond body =
body vars >>= fun vars ->
if (cond vars) then return vars else until_PM vars cond body
-val until_MP : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) ->
- ('vars -> 'vars) -> MR 'vars 'r
+val until_MP : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) ->
+ ('vars -> 'vars) -> M 'vars 'e
let rec until_MP vars cond body =
let vars = body vars in
cond vars >>= fun cond_val ->
if cond_val then return vars else until_MP vars cond body
-val until_MM : forall 'vars 'r. 'vars -> ('vars -> MR bool 'r) ->
- ('vars -> MR 'vars 'r) -> MR 'vars 'r
+val until_MM : forall 'vars 'e. 'vars -> ('vars -> M bool 'e) ->
+ ('vars -> M 'vars 'e) -> M 'vars 'e
let rec until_MM vars cond body =
body vars >>= fun vars ->
cond vars >>= fun cond_val ->