diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/gen_lib/prompt_monad.lem | 7 | ||||
| -rw-r--r-- | src/gen_lib/sail_values.lem | 16 | ||||
| -rw-r--r-- | src/gen_lib/state.lem | 65 | ||||
| -rw-r--r-- | src/gen_lib/state_lifting.lem | 27 | ||||
| -rw-r--r-- | src/gen_lib/state_monad.lem | 14 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 25 | ||||
| -rw-r--r-- | src/process_file.ml | 8 | ||||
| -rw-r--r-- | src/process_file.mli | 2 |
8 files changed, 99 insertions, 65 deletions
diff --git a/src/gen_lib/prompt_monad.lem b/src/gen_lib/prompt_monad.lem index 9eb59319..87c9af39 100644 --- a/src/gen_lib/prompt_monad.lem +++ b/src/gen_lib/prompt_monad.lem @@ -212,3 +212,10 @@ let barrier bk = Barrier bk (Done ()) val footprint : forall 'rv 'e. unit -> monad 'rv unit 'e let footprint _ = Footprint (Done ()) + +(* Define a type synonym that also takes the register state as a type parameter, + in order to make switching to the state monad without changing generated + definitions easier, see also lib/hol/prompt_monad.lem. *) + +type base_monad 'regval 'regstate 'a 'e = monad 'regval 'a 'e +type base_monadR 'regval 'regstate 'a 'r 'e = monadR 'regval 'a 'r 'e diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem index 5c6dc593..a709a8c1 100644 --- a/src/gen_lib/sail_values.lem +++ b/src/gen_lib/sail_values.lem @@ -85,7 +85,7 @@ let rec replace bs (n : integer) b' = match bs with if n = 0 then b' :: bs else b :: replace bs (n - 1) b' end -declare {isabelle} termination_argument replace = automatic +declare {isabelle; hol} termination_argument replace = automatic let upper n = n @@ -128,7 +128,7 @@ let rec just_list l = match l with | (_, _) -> Nothing end end -declare {isabelle} termination_argument just_list = automatic +declare {isabelle; hol} termination_argument just_list = automatic lemma just_list_spec: ((forall xs. (just_list xs = Nothing) <-> List.elem Nothing xs) && @@ -267,7 +267,7 @@ let rec nat_of_bools_aux acc bs = match bs with | true :: bs -> nat_of_bools_aux ((2 * acc) + 1) bs | false :: bs -> nat_of_bools_aux (2 * acc) bs end -declare {isabelle} termination_argument nat_of_bools_aux = automatic +declare {isabelle; hol} termination_argument nat_of_bools_aux = automatic let nat_of_bools bs = nat_of_bools_aux 0 bs val unsigned_of_bools : list bool -> integer @@ -306,7 +306,7 @@ let rec add_one_bool_ignore_overflow_aux bits = match bits with | false :: bits -> true :: bits | true :: bits -> false :: add_one_bool_ignore_overflow_aux bits end -declare {isabelle} termination_argument add_one_bool_ignore_overflow_aux = automatic +declare {isabelle; hol} termination_argument add_one_bool_ignore_overflow_aux = automatic let add_one_bool_ignore_overflow bits = List.reverse (add_one_bool_ignore_overflow_aux (List.reverse bits)) @@ -369,7 +369,7 @@ let rec add_one_bit_ignore_overflow_aux bits = match bits with | B1 :: bits -> B0 :: add_one_bit_ignore_overflow_aux bits | BU :: bits -> BU :: List.map (fun _ -> BU) bits end -declare {isabelle} termination_argument add_one_bit_ignore_overflow_aux = automatic +declare {isabelle; hol} termination_argument add_one_bit_ignore_overflow_aux = automatic let add_one_bit_ignore_overflow bits = List.reverse (add_one_bit_ignore_overflow_aux (List.reverse bits)) @@ -417,7 +417,7 @@ let rec hexstring_of_bits bs = match bs with | [] -> Just [] | _ -> Nothing end -declare {isabelle} termination_argument hexstring_of_bits = automatic +declare {isabelle; hol} termination_argument hexstring_of_bits = automatic let show_bitlist bs = match hexstring_of_bits bs with @@ -630,7 +630,7 @@ let rec byte_chunks bs = match bs with Maybe.bind (byte_chunks rest) (fun rest -> Just ([a;b;c;d;e;f;g;h] :: rest)) | _ -> Nothing end -declare {isabelle} termination_argument byte_chunks = automatic +declare {isabelle; hol} termination_argument byte_chunks = automatic val bytes_of_bits : forall 'a. Bitvector 'a => 'a -> maybe (list memory_byte) let bytes_of_bits bs = byte_chunks (bits_of bs) @@ -854,7 +854,7 @@ let rec foreach l vars body = | (x :: xs) -> foreach xs (body x vars) body end -declare {isabelle} termination_argument foreach = automatic +declare {isabelle; hol} termination_argument foreach = automatic val index_list : integer -> integer -> integer -> list integer let rec index_list from to step = diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem index 61477258..f69f59c1 100644 --- a/src/gen_lib/state.lem +++ b/src/gen_lib/state.lem @@ -1,33 +1,8 @@ open import Pervasives_extra -(*open import Sail_impl_base*) open import Sail_values -open import Prompt_monad -open import Prompt open import State_monad open import {isabelle} `State_monad_lemmas` -(* State monad wrapper around prompt monad *) - -val liftState : forall 'regval 'regs 'a 'e. register_accessors 'regs 'regval -> monad 'regval 'a 'e -> monadS 'regs 'a 'e -let rec liftState ra s = match s with - | (Done a) -> returnS a - | (Read_mem rk a sz k) -> bindS (read_mem_bytesS rk a sz) (fun v -> liftState ra (k v)) - | (Read_tag t k) -> bindS (read_tagS t) (fun v -> liftState ra (k v)) - | (Write_memv a k) -> bindS (write_mem_bytesS a) (fun v -> liftState ra (k v)) - | (Write_tag a t k) -> bindS (write_tagS a t) (fun v -> liftState ra (k v)) - | (Read_reg r k) -> bindS (read_regvalS ra r) (fun v -> liftState ra (k v)) - | (Excl_res k) -> bindS (excl_resultS ()) (fun v -> liftState ra (k v)) - | (Undefined k) -> bindS (undefined_boolS ()) (fun v -> liftState ra (k v)) - | (Write_ea wk a sz k) -> seqS (write_mem_eaS wk a sz) (liftState ra k) - | (Write_reg r v k) -> seqS (write_regvalS ra r v) (liftState ra k) - | (Footprint k) -> liftState ra k - | (Barrier _ k) -> liftState ra k - | (Print _ k) -> liftState ra k (* TODO *) - | (Fail descr) -> failS descr - | (Exception e) -> throwS e -end - - val iterS_aux : forall 'rv 'a 'e. integer -> (integer -> 'a -> monadS 'rv unit 'e) -> list 'a -> monadS 'rv unit 'e let rec iterS_aux i f xs = match xs with | x :: xs -> f i x >>$ iterS_aux (i + 1) f xs @@ -53,6 +28,46 @@ end declare {isabelle} termination_argument foreachS = automatic +val and_boolS : forall 'rv 'e. monadS 'rv bool 'e -> monadS 'rv bool 'e -> monadS 'rv bool 'e +let and_boolS l r = l >>$= (fun l -> if l then r else returnS false) + +val or_boolS : forall 'rv 'e. monadS 'rv bool 'e -> monadS 'rv bool 'e -> monadS 'rv bool 'e +let or_boolS l r = l >>$= (fun l -> if l then returnS true else r) + +val bool_of_bitU_fail : forall 'rv 'e. bitU -> monadS 'rv bool 'e +let bool_of_bitU_fail = function + | B0 -> returnS false + | B1 -> returnS true + | BU -> failS "bool_of_bitU" +end + +val bool_of_bitU_oracleS : forall 'rv 'e. bitU -> monadS 'rv bool 'e +let bool_of_bitU_oracleS = function + | B0 -> returnS false + | B1 -> returnS true + | BU -> undefined_boolS () +end + +val bools_of_bits_oracleS : forall 'rv 'e. list bitU -> monadS 'rv (list bool) 'e +let bools_of_bits_oracleS bits = + foreachS bits [] + (fun b bools -> + bool_of_bitU_oracleS b >>$= (fun b -> + returnS (bools ++ [b]))) + +val of_bits_oracleS : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monadS 'rv 'a 'e +let of_bits_oracleS bits = + bools_of_bits_oracleS bits >>$= (fun bs -> + returnS (of_bools bs)) + +val of_bits_failS : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monadS 'rv 'a 'e +let of_bits_failS bits = maybe_failS "of_bits" (of_bits bits) + +val mword_oracleS : forall 'rv 'a 'e. Size 'a => unit -> monadS 'rv (mword 'a) 'e +let mword_oracleS () = + bools_of_bits_oracleS (repeat [BU] (integerFromNat size)) >>$= (fun bs -> + returnS (wordFromBitlist bs)) + val whileS : forall 'rv 'vars 'e. 'vars -> ('vars -> monadS 'rv bool 'e) -> ('vars -> monadS 'rv 'vars 'e) -> monadS 'rv 'vars 'e diff --git a/src/gen_lib/state_lifting.lem b/src/gen_lib/state_lifting.lem new file mode 100644 index 00000000..7e569a7e --- /dev/null +++ b/src/gen_lib/state_lifting.lem @@ -0,0 +1,27 @@ +open import Pervasives_extra +open import Sail_values +open import Prompt_monad +open import Prompt +open import State_monad +open import {isabelle} `State_monad_lemmas` + +(* State monad wrapper around prompt monad *) + +val liftState : forall 'regval 'regs 'a 'e. register_accessors 'regs 'regval -> monad 'regval 'a 'e -> monadS 'regs 'a 'e +let rec liftState ra s = match s with + | (Done a) -> returnS a + | (Read_mem rk a sz k) -> bindS (read_mem_bytesS rk a sz) (fun v -> liftState ra (k v)) + | (Read_tag t k) -> bindS (read_tagS t) (fun v -> liftState ra (k v)) + | (Write_memv a k) -> bindS (write_mem_bytesS a) (fun v -> liftState ra (k v)) + | (Write_tag a t k) -> bindS (write_tagS a t) (fun v -> liftState ra (k v)) + | (Read_reg r k) -> bindS (read_regvalS ra r) (fun v -> liftState ra (k v)) + | (Excl_res k) -> bindS (excl_resultS ()) (fun v -> liftState ra (k v)) + | (Undefined k) -> bindS (undefined_boolS ()) (fun v -> liftState ra (k v)) + | (Write_ea wk a sz k) -> seqS (write_mem_eaS wk a sz) (liftState ra k) + | (Write_reg r v k) -> seqS (write_regvalS ra r v) (liftState ra k) + | (Footprint k) -> liftState ra k + | (Barrier _ k) -> liftState ra k + | (Print _ k) -> liftState ra k (* TODO *) + | (Fail descr) -> failS descr + | (Exception e) -> throwS e +end diff --git a/src/gen_lib/state_monad.lem b/src/gen_lib/state_monad.lem index a2919762..89021890 100644 --- a/src/gen_lib/state_monad.lem +++ b/src/gen_lib/state_monad.lem @@ -94,12 +94,12 @@ let assert_expS exp msg = if exp then returnS () else failS msg (* For early return, we abuse exceptions by throwing and catching the return value. The exception type is "either 'r 'e", where "Right e" represents a proper exception and "Left r" an early return of value "r". *) -type monadSR 'regs 'a 'r 'e = monadS 'regs 'a (either 'r 'e) +type monadRS 'regs 'a 'r 'e = monadS 'regs 'a (either 'r 'e) -val early_returnS : forall 'regs 'a 'r 'e. 'r -> monadSR 'regs 'a 'r 'e +val early_returnS : forall 'regs 'a 'r 'e. 'r -> monadRS 'regs 'a 'r 'e let early_returnS r = throwS (Left r) -val catch_early_returnS : forall 'regs 'a 'e. monadSR 'regs 'a 'a 'e -> monadS 'regs 'a 'e +val catch_early_returnS : forall 'regs 'a 'e. monadRS 'regs 'a 'a 'e -> monadS 'regs 'a 'e let catch_early_returnS m = try_catchS m (function @@ -108,12 +108,12 @@ let catch_early_returnS m = end) (* Lift to monad with early return by wrapping exceptions *) -val liftSR : forall 'a 'r 'regs 'e. monadS 'regs 'a 'e -> monadSR 'regs 'a 'r 'e -let liftSR m = try_catchS m (fun e -> throwS (Right e)) +val liftRS : forall 'a 'r 'regs 'e. monadS 'regs 'a 'e -> monadRS 'regs 'a 'r 'e +let liftRS m = try_catchS m (fun e -> throwS (Right e)) (* Catch exceptions in the presence of early returns *) -val try_catchSR : forall 'regs 'a 'r 'e1 'e2. monadSR 'regs 'a 'r 'e1 -> ('e1 -> monadSR 'regs 'a 'r 'e2) -> monadSR 'regs 'a 'r 'e2 -let try_catchSR m h = +val try_catchRS : forall 'regs 'a 'r 'e1 'e2. monadRS 'regs 'a 'r 'e1 -> ('e1 -> monadRS 'regs 'a 'r 'e2) -> monadRS 'regs 'a 'r 'e2 +let try_catchRS m h = try_catchS m (function | Left r -> throwS (Left r) diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 58bbfc4b..c872d420 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -883,8 +883,7 @@ let doc_exp_lem, doc_let_lem = let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in let middle = match fst (untyp_pat pat) with - | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> - string ">>" + | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> string ">>" | P_aux (P_tup _, _) when not (IdSet.mem (mk_id "varstup") (find_e_ids e2)) -> (* Work around indentation issues in Lem when translating @@ -894,7 +893,8 @@ let doc_exp_lem, doc_let_lem = doc_pat_lem ctxt true pat; string "= varstup in"] | _ -> - separate space [string ">>= fun"; doc_pat_lem ctxt true pat; arrow] + separate space [string ">>= fun"; + doc_pat_lem ctxt true pat; arrow] in infix 0 1 middle (expV b e1) (expN e2) in @@ -908,12 +908,11 @@ let doc_exp_lem, doc_let_lem = raise (Reporting_basic.err_unreachable l "pretty-printing non-constant sizeof expressions to Lem not supported")) | E_return r -> - let ret_monad = if !opt_sequential then " : MR regstate" else " : MR" in let ta = if contains_t_pp_var ctxt (typ_of full_exp) || contains_t_pp_var ctxt (typ_of r) then empty else separate space - [string ret_monad; + [string ": MR"; parens (doc_typ_lem (typ_of full_exp)); parens (doc_typ_lem (typ_of r))] in align (parens (string "early_return" ^//^ expV true r ^//^ ta)) @@ -1438,17 +1437,11 @@ let pp_defs_lem (types_file,types_modules) (defs_file,defs_modules) (Defs defs) separate empty (List.map doc_def_lem statedefs); hardline; hardline; register_refs; hardline; - (* if !opt_sequential then - concat [ - string ("type MR 'a 'r = State_monad.MR regstate 'a 'r " ^ exc_typ); hardline; - string ("type M 'a = State_monad.M regstate 'a " ^ exc_typ); hardline; - ] - else *) - concat [ - string ("type MR 'a 'r = monadR register_value 'a 'r " ^ exc_typ); hardline; - string ("type M 'a = monad register_value 'a " ^ exc_typ); hardline - ] - ]); + concat [ + string ("type MR 'a 'r = base_monadR register_value regstate 'a 'r " ^ exc_typ); hardline; + string ("type M 'a = base_monad register_value regstate 'a " ^ exc_typ); hardline + ] + ]); (print defs_file) (concat [string "(*" ^^ (string top_line) ^^ string "*)";hardline; diff --git a/src/process_file.ml b/src/process_file.ml index 4856c20a..34c1a255 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -51,9 +51,6 @@ open PPrint open Pretty_print_common -let opt_lem_sequential = ref false -let opt_lem_mwords = ref false - type out_type = | Lem_ast_out | Lem_out of string list @@ -246,10 +243,7 @@ let output_lem filename libs defs = let generated_line = generated_line filename in (* let seq_suffix = if !Pretty_print_lem.opt_sequential then "_sequential" else "" in *) let types_module = (filename ^ "_types") in - let monad_modules = ["Prompt_monad"; "Prompt"; "State"] in - (* if !Pretty_print_lem.opt_sequential - then ["State_monad"; "State"] - else ["Prompt_monad"; "Prompt"] in *) + let monad_modules = ["Prompt_monad"; "Prompt"] in let operators_module = if !Pretty_print_lem.opt_mwords then "Sail_operators_mwords" diff --git a/src/process_file.mli b/src/process_file.mli index b38b4259..cc94888e 100644 --- a/src/process_file.mli +++ b/src/process_file.mli @@ -65,8 +65,6 @@ val rewrite_ast_check : Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs val load_file_no_check : Ast.order -> string -> unit Ast.defs val load_file : Ast.order -> Type_check.Env.t -> string -> Type_check.tannot Ast.defs * Type_check.Env.t -val opt_lem_sequential : bool ref -val opt_lem_mwords : bool ref val opt_just_check : bool ref val opt_ddump_tc_ast : bool ref val opt_ddump_rewrite_ast : ((string * int) option) ref |
