diff options
| author | Jon French | 2018-06-11 15:25:02 +0100 |
|---|---|---|
| committer | Jon French | 2018-06-11 15:25:02 +0100 |
| commit | 826e94548a86a88d8fefeb1edef177c02bf5d68d (patch) | |
| tree | fc9a5484440e030cc479101c5cab345c1c77468e /lib | |
| parent | 5717bb3d0cef5932cb2b33bc66b3b2f0c0552164 (diff) | |
| parent | 4336409f923c10a8c5e4acc91fa7e6ef5551a88f (diff) | |
Merge branch 'sail2' into mappings
(involved some manual tinkering with gitignore, type_check, riscv)
Diffstat (limited to 'lib')
31 files changed, 4910 insertions, 175 deletions
@@ -1,4 +1,3 @@ true: debug - <*.m{l,li}>: package(lem), package(linksem), package(zarith) <main.native>: package(lem), package(linksem), package(zarith) diff --git a/lib/arith.sail b/lib/arith.sail index 54ecdbbc..f713805a 100644 --- a/lib/arith.sail +++ b/lib/arith.sail @@ -5,41 +5,43 @@ $include <flow.sail> // ***** Addition ***** -val add_atom = {ocaml: "add_int", lem: "integerAdd", c: "add_int"} : forall 'n 'm. +val add_atom = {ocaml: "add_int", lem: "integerAdd", c: "add_int", coq: "Z.add"} : forall 'n 'm. (atom('n), atom('m)) -> atom('n + 'm) -val add_int = {ocaml: "add_int", lem: "integerAdd", c: "add_int"} : (int, int) -> int +val add_int = {ocaml: "add_int", lem: "integerAdd", c: "add_int", coq: "Z.add"} : (int, int) -> int overload operator + = {add_atom, add_int} // ***** Subtraction ***** -val sub_atom = {ocaml: "sub_int", lem: "integerMinus", c: "sub_int"} : forall 'n 'm. +val sub_atom = {ocaml: "sub_int", lem: "integerMinus", c: "sub_int", coq: "Z.sub"} : forall 'n 'm. (atom('n), atom('m)) -> atom('n - 'm) -val sub_int = {ocaml: "sub_int", lem: "integerMinus", c: "sub_int"} : (int, int) -> int +val sub_int = {ocaml: "sub_int", lem: "integerMinus", c: "sub_int", coq: "Z.sub"} : (int, int) -> int overload operator - = {sub_atom, sub_int} // ***** Negation ***** -val negate_atom = {ocaml: "negate", lem: "integerNegate", c: "neg_int"} : forall 'n. atom('n) -> atom(- 'n) +val negate_atom = {ocaml: "negate", lem: "integerNegate", c: "neg_int", coq: "Z.opp"} : forall 'n. atom('n) -> atom(- 'n) -val negate_int = {ocaml: "negate", lem: "integerNegate", c: "neg_int"} : int -> int +val negate_int = {ocaml: "negate", lem: "integerNegate", c: "neg_int", coq: "Z.opp"} : int -> int overload negate = {negate_atom, negate_int} // ***** Multiplication ***** -val mult_atom = {ocaml: "mult", lem: "integerMult", c: "mult_int"} : forall 'n 'm. +val mult_atom = {ocaml: "mult", lem: "integerMult", c: "mult_int", coq: "Z.mul"} : forall 'n 'm. (atom('n), atom('m)) -> atom('n * 'm) -val mult_int = {ocaml: "mult", lem: "integerMult", c: "mult_int"} : (int, int) -> int +val mult_int = {ocaml: "mult", lem: "integerMult", c: "mult_int", coq: "Z.mul"} : (int, int) -> int overload operator * = {mult_atom, mult_int} val "print_int" : (string, int) -> unit +val "prerr_int" : (string, int) -> unit + // ***** Integer shifts ***** val shl_int = "shl_int" : (int, int) -> int @@ -52,7 +54,8 @@ val div_int = { smt: "div", ocaml: "quotient", lem: "integerDiv", - c: "div_int" + c: "div_int", + coq: "Z.quot" } : (int, int) -> int overload operator / = {div_int} @@ -61,7 +64,8 @@ val mod_int = { smt: "mod", ocaml: "modulus", lem: "integerMod", - c: "mod_int" + c: "mod_int", + coq: "Z.rem" } : (int, int) -> int overload operator % = {mod_int} @@ -69,7 +73,8 @@ overload operator % = {mod_int} val abs_int = { smt : "abs", ocaml: "abs_int", - lem: "abs_int" + lem: "abs_int", + coq: "Z.abs" } : (int, int) -> int $endif diff --git a/lib/coq/.gitignore b/lib/coq/.gitignore new file mode 100644 index 00000000..1aa62803 --- /dev/null +++ b/lib/coq/.gitignore @@ -0,0 +1 @@ +deps
\ No newline at end of file diff --git a/lib/coq/Makefile b/lib/coq/Makefile new file mode 100644 index 00000000..d974b692 --- /dev/null +++ b/lib/coq/Makefile @@ -0,0 +1,24 @@ +BBV_DIR=../../../bbv + +SRC=Prompt_monad.v Prompt.v Sail_impl_base.v Sail_instr_kinds.v Sail_operators_bitlists.v Sail_operators_mwords.v Sail_operators.v Sail_values.v State_monad.v State.v + +COQ_LIBS = -R . Sail -R "$(BBV_DIR)" bbv + +TARGETS=$(SRC:.v=.vo) + +.PHONY: all clean *.ide + +all: $(TARGETS) +clean: + rm -f -- $(TARGETS) $(TARGETS:.vo=.glob) $(TARGETS:%.vo=.%.aux) deps + +%.vo: %.v + coqc $(COQ_LIBS) $< + +%.ide: %.v + coqide $(COQ_LIBS) $< + +deps: $(SRC) + coqdep $(COQ_LIBS) $(SRC) > deps + +-include deps diff --git a/lib/coq/Prompt.v b/lib/coq/Prompt.v new file mode 100644 index 00000000..6c4be18e --- /dev/null +++ b/lib/coq/Prompt.v @@ -0,0 +1,72 @@ +(*Require Import Sail_impl_base*) +Require Import Sail_values. +Require Import Prompt_monad. + +Require Import List. +Import ListNotations. +(* + +val iter_aux : forall 'rv 'a 'e. integer -> (integer -> 'a -> monad 'rv unit 'e) -> list 'a -> monad 'rv unit 'e +let rec iter_aux i f xs = match xs with + | x :: xs -> f i x >> iter_aux (i + 1) f xs + | [] -> return () + end + +declare {isabelle} termination_argument iter_aux = automatic + +val iteri : forall 'rv 'a 'e. (integer -> 'a -> monad 'rv unit 'e) -> list 'a -> monad 'rv unit 'e +let iteri f xs = iter_aux 0 f xs + +val iter : forall 'rv 'a 'e. ('a -> monad 'rv unit 'e) -> list 'a -> monad 'rv unit 'e +let iter f xs = iteri (fun _ x -> f x) xs + +val foreachM : forall 'a 'rv 'vars 'e. + list 'a -> 'vars -> ('a -> 'vars -> monad 'rv 'vars 'e) -> monad 'rv 'vars 'e*) +Fixpoint foreachM {a rv Vars e} (l : list a) (vars : Vars) (body : a -> Vars -> monad rv Vars e) : monad rv Vars e := +match l with +| [] => returnm vars +| (x :: xs) => + body x vars >>= fun vars => + foreachM xs vars body +end. + +(*declare {isabelle} termination_argument foreachM = automatic + + +val whileM : forall 'rv 'vars 'e. 'vars -> ('vars -> monad 'rv bool 'e) -> + ('vars -> monad 'rv 'vars 'e) -> monad 'rv 'vars 'e +let rec whileM vars cond body = + cond vars >>= fun cond_val -> + if cond_val then + body vars >>= fun vars -> whileM vars cond body + else return vars + +val untilM : forall 'rv 'vars 'e. 'vars -> ('vars -> monad 'rv bool 'e) -> + ('vars -> monad 'rv 'vars 'e) -> monad 'rv 'vars 'e +let rec untilM vars cond body = + body vars >>= fun vars -> + cond vars >>= fun cond_val -> + if cond_val then return vars else untilM vars cond body + +(*let write_two_regs r1 r2 vec = + let is_inc = + let is_inc_r1 = is_inc_of_reg r1 in + let is_inc_r2 = is_inc_of_reg r2 in + let () = ensure (is_inc_r1 = is_inc_r2) + "write_two_regs called with vectors of different direction" in + is_inc_r1 in + + let (size_r1 : integer) = size_of_reg r1 in + let (start_vec : integer) = get_start vec in + let size_vec = length vec in + let r1_v = + if is_inc + then slice vec start_vec (size_r1 - start_vec - 1) + else slice vec start_vec (start_vec - size_r1 - 1) in + let r2_v = + if is_inc + then slice vec (size_r1 - start_vec) (size_vec - start_vec) + else slice vec (start_vec - size_r1) (start_vec - size_vec) in + write_reg r1 r1_v >> write_reg r2 r2_v*) + +*) diff --git a/lib/coq/Prompt_monad.v b/lib/coq/Prompt_monad.v new file mode 100644 index 00000000..ef399444 --- /dev/null +++ b/lib/coq/Prompt_monad.v @@ -0,0 +1,228 @@ +Require Import String. +(*Require Import Sail_impl_base*) +Require Import Sail_instr_kinds. +Require Import Sail_values. + + + +Definition register_name := string. +Definition address := list bitU. + +Inductive monad regval a e := + | Done : a -> monad regval a e + (* Read a number : bytes from memory, returned in little endian order *) + | Read_mem : read_kind -> address -> nat -> (list memory_byte -> monad regval a e) -> monad regval a e + (* Read the tag : a memory address *) + | Read_tag : address -> (bitU -> monad regval a e) -> monad regval a e + (* Tell the system a write is imminent, at address lifted, : size nat *) + | Write_ea : write_kind -> address -> nat -> monad regval a e -> monad regval a e + (* Request the result : store-exclusive *) + | Excl_res : (bool -> monad regval a e) -> monad regval a e + (* Request to write memory at last signalled address. Memory value should be 8 + times the size given in ea signal, given in little endian order *) + | Write_memv : list memory_byte -> (bool -> monad regval a e) -> monad regval a e + (* Request to write the tag at last signalled address. *) + | Write_tagv : bitU -> (bool -> monad regval a e) -> monad regval a e + (* Tell the system to dynamically recalculate dependency footprint *) + | Footprint : monad regval a e -> monad regval a e + (* Request a memory barrier *) + | Barrier : barrier_kind -> monad regval a e -> monad regval a e + (* Request to read register, will track dependency when mode.track_values *) + | Read_reg : register_name -> (regval -> monad regval a e) -> monad regval a e + (* Request to write register *) + | Write_reg : register_name -> regval -> monad regval a e -> monad regval a e + (*Result : a failed assert with possible error message to report*) + | Fail : string -> monad regval a e + | Error : string -> monad regval a e + (* Exception : type e *) + | Exception : e -> monad regval a e. + (* TODO: Reading/writing tags *) + +Arguments Done [_ _ _]. +Arguments Read_mem [_ _ _]. +Arguments Read_tag [_ _ _]. +Arguments Write_ea [_ _ _]. +Arguments Excl_res [_ _ _]. +Arguments Write_memv [_ _ _]. +Arguments Write_tagv [_ _ _]. +Arguments Footprint [_ _ _]. +Arguments Barrier [_ _ _]. +Arguments Read_reg [_ _ _]. +Arguments Write_reg [_ _ _]. +Arguments Fail [_ _ _]. +Arguments Error [_ _ _]. +Arguments Exception [_ _ _]. + +(*val return : forall rv a e. a -> monad rv a e*) +Definition returnm {rv A E} (a : A) : monad rv A E := Done a. + +(*val bind : forall rv a b e. monad rv a e -> (a -> monad rv b e) -> monad rv b e*) +Fixpoint bind {rv A B E} (m : monad rv A E) (f : A -> monad rv B E) := match m with + | Done a => f a + | Read_mem rk a sz k => Read_mem rk a sz (fun v => bind (k v) f) + | Read_tag a k => Read_tag a (fun v => bind (k v) f) + | Write_memv descr k => Write_memv descr (fun v => bind (k v) f) + | Write_tagv t k => Write_tagv t (fun v => bind (k v) f) + | Read_reg descr k => Read_reg descr (fun v => bind (k v) f) + | Excl_res k => Excl_res (fun v => bind (k v) f) + | Write_ea wk a sz k => Write_ea wk a sz (bind k f) + | Footprint k => Footprint (bind k f) + | Barrier bk k => Barrier bk (bind k f) + | Write_reg r v k => Write_reg r v (bind k f) + | Fail descr => Fail descr + | Error descr => Error descr + | Exception e => Exception e +end. + +Notation "m >>= f" := (bind m f) (at level 50, left associativity). +(*val (>>) : forall rv b e. monad rv unit e -> monad rv b e -> monad rv b e*) +Definition bind0 {rv A E} (m : monad rv unit E) (n : monad rv A E) := + m >>= fun (_ : unit) => n. +Notation "m >> n" := (bind0 m n) (at level 50, left associativity). + +(*val exit : forall rv a e. unit -> monad rv a e*) +Definition exit {rv A E} (_ : unit) : monad rv A E := Fail "exit". + +(*val assert_exp : forall rv e. bool -> string -> monad rv unit e*) +Definition assert_exp {rv E} (exp :bool) msg : monad rv unit E := + if exp then Done tt else Fail msg. + +Definition assert_exp' {rv E} (exp :bool) msg : monad rv (exp = true) E := + if exp return monad rv (exp = true) E then Done eq_refl else Fail msg. +Definition bindH {rv A P E} (m : monad rv P E) (n : monad rv A E) := + m >>= fun (H : P) => n. +Notation "m >>> n" := (bindH m n) (at level 50, left associativity). + +(*val throw : forall rv a e. e -> monad rv a e*) +Definition throw {rv A E} e : monad rv A E := Exception e. + +(*val try_catch : forall rv a e1 e2. monad rv a e1 -> (e1 -> monad rv a e2) -> monad rv a e2*) +Fixpoint try_catch {rv A E1 E2} (m : monad rv A E1) (h : E1 -> monad rv A E2) := match m with + | Done a => Done a + | Read_mem rk a sz k => Read_mem rk a sz (fun v => try_catch (k v) h) + | Read_tag a k => Read_tag a (fun v => try_catch (k v) h) + | Write_memv descr k => Write_memv descr (fun v => try_catch (k v) h) + | Write_tagv t k => Write_tagv t (fun v => try_catch (k v) h) + | Read_reg descr k => Read_reg descr (fun v => try_catch (k v) h) + | Excl_res k => Excl_res (fun v => try_catch (k v) h) + | Write_ea wk a sz k => Write_ea wk a sz (try_catch k h) + | Footprint k => Footprint (try_catch k h) + | Barrier bk k => Barrier bk (try_catch k h) + | Write_reg r v k => Write_reg r v (try_catch k h) + | Fail descr => Fail descr + | Error descr => Error descr + | Exception e => h e +end. + +(* For early return, we abuse exceptions by throwing and catching + the return value. The exception type is "either r e", where "inr e" + represents a proper exception and "inl r" an early return : value "r". *) +Definition monadR rv a r e := monad rv a (sum r e). + +(*val early_return : forall rv a r e. r -> monadR rv a r e*) +Definition early_return {rv A R E} (r : R) : monadR rv A R E := throw (inl r). + +(*val catch_early_return : forall rv a e. monadR rv a a e -> monad rv a e*) +Definition catch_early_return {rv A E} (m : monadR rv A A E) := + try_catch m + (fun r => match r with + | inl a => returnm a + | inr e => throw e + end). + +(* Lift to monad with early return by wrapping exceptions *) +(*val liftR : forall rv a r e. monad rv a e -> monadR rv a r e*) +Definition liftR {rv A R E} (m : monad rv A E) : monadR rv A R E := + try_catch m (fun e => throw (inr e)). + +(* Catch exceptions in the presence : early returns *) +(*val try_catchR : forall rv a r e1 e2. monadR rv a r e1 -> (e1 -> monadR rv a r e2) -> monadR rv a r e2*) +Definition try_catchR {rv A R E1 E2} (m : monadR rv A R E1) (h : E1 -> monadR rv A R E2) := + try_catch m + (fun r => match r with + | inl r => throw (inl r) + | inr e => h e + end). + + +(*Parameter read_mem : forall {rv n m e}, read_kind -> mword m -> Z -> monad rv (mword n) e.*) +Definition read_mem {a b e rv : Type} `{Bitvector a} `{Bitvector b} (rk : read_kind ) (addr : a) (sz : Z ) : monad rv b e:= + let k bytes : monad rv b e := Done (bits_of_mem_bytes bytes) in + Read_mem rk (bits_of addr) (Z.to_nat sz) k. + +(*val read_tag : forall rv a e. Bitvector a => a -> monad rv bitU e*) +Definition read_tag {rv a e} `{Bitvector a} (addr : a) : monad rv bitU e := + Read_tag (bits_of addr) returnm. + +(*val excl_result : forall rv e. unit -> monad rv bool e*) +Definition excl_result {rv e} (_:unit) : monad rv bool e := + let k successful := (returnm successful) in + Excl_res k. + +Definition write_mem_ea {rv a E} `{Bitvector a} wk (addr: a) sz : monad rv unit E := + Write_ea wk (bits_of addr) (Z.to_nat sz) (Done tt). + +Definition write_mem_val {rv a e} `{Bitvector a} (v : a) : monad rv bool e := match mem_bytes_of_bits v with + | Some v => Write_memv v returnm + | None => Fail "write_mem_val" +end. + +(*val write_tag_val : forall rv e. bitU -> monad rv bool e*) +Definition write_tag_val {rv e} (b : bitU) : monad rv bool e := Write_tagv b returnm. + +Definition read_reg {s rv a e} (reg : register_ref s rv a) : monad rv a e := + let k v := + match reg.(of_regval) v with + | Some v => Done v + | None => Error "read_reg: unrecognised value" + end + in + Read_reg reg.(name) k. + +(* TODO +val read_reg_range : forall s r rv a e. Bitvector a => register_ref s rv r -> integer -> integer -> monad rv a e +Definition read_reg_range reg i j := + read_reg_aux of_bits (external_reg_slice reg (natFromInteger i,natFromInteger j)) + +Definition read_reg_bit reg i := + read_reg_aux (fun v -> v) (external_reg_slice reg (natFromInteger i,natFromInteger i)) >>= fun v -> + returnm (extract_only_element v) + +Definition read_reg_field reg regfield := + read_reg_aux (external_reg_field_whole reg regfield) + +Definition read_reg_bitfield reg regfield := + read_reg_aux (external_reg_field_whole reg regfield) >>= fun v -> + returnm (extract_only_element v)*) + +Definition reg_deref {s rv a e} := @read_reg s rv a e. + +(*Parameter write_reg : forall {s rv a e}, register_ref s rv a -> a -> monad rv unit e.*) +Definition write_reg {s rv a e} (reg : register_ref s rv a) (v : a) : monad rv unit e := + Write_reg reg.(name) (reg.(regval_of) v) (Done tt). + +(* TODO +Definition write_reg reg v := + write_reg_aux (external_reg_whole reg) v +Definition write_reg_range reg i j v := + write_reg_aux (external_reg_slice reg (natFromInteger i,natFromInteger j)) v +Definition write_reg_pos reg i v := + let iN := natFromInteger i in + write_reg_aux (external_reg_slice reg (iN,iN)) [v] +Definition write_reg_bit := write_reg_pos +Definition write_reg_field reg regfield v := + write_reg_aux (external_reg_field_whole reg regfield.field_name) v +Definition write_reg_field_bit reg regfield bit := + write_reg_aux (external_reg_field_whole reg regfield.field_name) + (Vector [bit] 0 (is_inc_of_reg reg)) +Definition write_reg_field_range reg regfield i j v := + write_reg_aux (external_reg_field_slice reg regfield.field_name (natFromInteger i,natFromInteger j)) v +Definition write_reg_field_pos reg regfield i v := + write_reg_field_range reg regfield i i [v] +Definition write_reg_field_bit := write_reg_field_pos*) + +(*val barrier : forall rv e. barrier_kind -> monad rv unit e*) +Definition barrier {rv e} bk : monad rv unit e := Barrier bk (Done tt). + +(*val footprint : forall rv e. unit -> monad rv unit e*) +Definition footprint {rv e} (_ : unit) : monad rv unit e := Footprint (Done tt). diff --git a/lib/coq/Sail_impl_base.v b/lib/coq/Sail_impl_base.v new file mode 100644 index 00000000..df7854e7 --- /dev/null +++ b/lib/coq/Sail_impl_base.v @@ -0,0 +1,1103 @@ +(*========================================================================*) +(* Sail *) +(* *) +(* Copyright (c) 2013-2017 *) +(* Kathyrn Gray *) +(* Shaked Flur *) +(* Stephen Kell *) +(* Gabriel Kerneis *) +(* Robert Norton-Wright *) +(* Christopher Pulte *) +(* Peter Sewell *) +(* Alasdair Armstrong *) +(* Brian Campbell *) +(* Thomas Bauereiss *) +(* Anthony Fox *) +(* Jon French *) +(* Dominic Mulligan *) +(* Stephen Kell *) +(* Mark Wassell *) +(* *) +(* All rights reserved. *) +(* *) +(* This software was developed by the University of Cambridge Computer *) +(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *) +(* (REMS) project, funded by EPSRC grant EP/K008528/1. *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(*========================================================================*) + +Require Import Sail_instr_kinds. + +(* +class ( EnumerationType 'a ) + val toNat : 'a -> nat +end + + +val enumeration_typeCompare : forall 'a. EnumerationType 'a => 'a -> 'a -> ordering +let ~{ocaml} enumeration_typeCompare e1 e2 = + compare (toNat e1) (toNat e2) +let inline {ocaml} enumeration_typeCompare = defaultCompare + + +default_instance forall 'a. EnumerationType 'a => (Ord 'a) + let compare = enumeration_typeCompare + let (<) r1 r2 = (enumeration_typeCompare r1 r2) = LT + let (<=) r1 r2 = (enumeration_typeCompare r1 r2) <> GT + let (>) r1 r2 = (enumeration_typeCompare r1 r2) = GT + let (>=) r1 r2 = (enumeration_typeCompare r1 r2) <> LT +end + + + +(* maybe isn't a member of type Ord - this should be in the Lem standard library*) +instance forall 'a. Ord 'a => (Ord (maybe 'a)) + let compare = maybeCompare compare + let (<) r1 r2 = (maybeCompare compare r1 r2) = LT + let (<=) r1 r2 = (maybeCompare compare r1 r2) <> GT + let (>) r1 r2 = (maybeCompare compare r1 r2) = GT + let (>=) r1 r2 = (maybeCompare compare r1 r2) <> LT +end + +type word8 = nat (* bounded at a byte, for when lem supports it*) + +type end_flag = + | E_big_endian + | E_little_endian + +type bit = + | Bitc_zero + | Bitc_one + +type bit_lifted = + | Bitl_zero + | Bitl_one + | Bitl_undef (* used for modelling h/w arch unspecified bits *) + | Bitl_unknown (* used for interpreter analysis exhaustive execution *) + +type direction = + | D_increasing + | D_decreasing + +let dir_of_bool is_inc = if is_inc then D_increasing else D_decreasing +let bool_of_dir = function + | D_increasing -> true + | D_decreasing -> false + end + +(* at some point this should probably not mention bit_lifted anymore *) +type register_value = <| + rv_bits: list bit_lifted (* MSB first, smallest index number *); + rv_dir: direction; + rv_start: nat ; + rv_start_internal: nat; + (*when dir is increasing, rv_start = rv_start_internal. + Otherwise, tells interpreter how to reconstruct a proper decreasing value*) + |> + +type byte_lifted = Byte_lifted of list bit_lifted (* of length 8 *) (*MSB first everywhere*) + +type instruction_field_value = list bit + +type byte = Byte of list bit (* of length 8 *) (*MSB first everywhere*) + +type address_lifted = Address_lifted of list byte_lifted (* of length 8 for 64bit machines*) * maybe integer +(* for both values of end_flag, MSBy first *) + +type memory_byte = byte_lifted (* of length 8 *) (*MSB first everywhere*) + +type memory_value = list memory_byte +(* the list is of length >=1 *) +(* the head of the list is the byte stored at the lowest address; +when calling a Sail function with a wmv effect, the least significant 8 +bits of the bit vector passed to the function will be interpreted as +the lowest address byte; similarly, when calling a Sail function with +rmem effect, the lowest address byte will be placed in the least +significant 8 bits of the bit vector returned by the function; this +behaviour is consistent with little-endian. *) + + +(* not sure which of these is more handy yet *) +type address = Address of list byte (* of length 8 *) * integer +(* type address = Address of integer *) + +type opcode = Opcode of list byte (* of length 4 *) + +(** typeclass instantiations *) + +instance (EnumerationType bit) + let toNat = function + | Bitc_zero -> 0 + | Bitc_one -> 1 + end +end + +instance (EnumerationType bit_lifted) + let toNat = function + | Bitl_zero -> 0 + | Bitl_one -> 1 + | Bitl_undef -> 2 + | Bitl_unknown -> 3 + end +end + +let ~{ocaml} byte_liftedCompare (Byte_lifted b1) (Byte_lifted b2) = compare b1 b2 +let inline {ocaml} byte_liftedCompare = defaultCompare + +let ~{ocaml} byte_liftedLess b1 b2 = byte_liftedCompare b1 b2 = LT +let ~{ocaml} byte_liftedLessEq b1 b2 = byte_liftedCompare b1 b2 <> GT +let ~{ocaml} byte_liftedGreater b1 b2 = byte_liftedCompare b1 b2 = GT +let ~{ocaml} byte_liftedGreaterEq b1 b2 = byte_liftedCompare b1 b2 <> LT + +let inline {ocaml} byte_liftedLess = defaultLess +let inline {ocaml} byte_liftedLessEq = defaultLessEq +let inline {ocaml} byte_liftedGreater = defaultGreater +let inline {ocaml} byte_liftedGreaterEq = defaultGreaterEq + +instance (Ord byte_lifted) + let compare = byte_liftedCompare + let (<) = byte_liftedLess + let (<=) = byte_liftedLessEq + let (>) = byte_liftedGreater + let (>=) = byte_liftedGreaterEq +end + +let ~{ocaml} byteCompare (Byte b1) (Byte b2) = compare b1 b2 +let inline {ocaml} byteCompare = defaultCompare + +let ~{ocaml} byteLess b1 b2 = byteCompare b1 b2 = LT +let ~{ocaml} byteLessEq b1 b2 = byteCompare b1 b2 <> GT +let ~{ocaml} byteGreater b1 b2 = byteCompare b1 b2 = GT +let ~{ocaml} byteGreaterEq b1 b2 = byteCompare b1 b2 <> LT + +let inline {ocaml} byteLess = defaultLess +let inline {ocaml} byteLessEq = defaultLessEq +let inline {ocaml} byteGreater = defaultGreater +let inline {ocaml} byteGreaterEq = defaultGreaterEq + +instance (Ord byte) + let compare = byteCompare + let (<) = byteLess + let (<=) = byteLessEq + let (>) = byteGreater + let (>=) = byteGreaterEq +end + + + + + +let ~{ocaml} opcodeCompare (Opcode o1) (Opcode o2) = + compare o1 o2 +let {ocaml} opcodeCompare = defaultCompare + +let ~{ocaml} opcodeLess b1 b2 = opcodeCompare b1 b2 = LT +let ~{ocaml} opcodeLessEq b1 b2 = opcodeCompare b1 b2 <> GT +let ~{ocaml} opcodeGreater b1 b2 = opcodeCompare b1 b2 = GT +let ~{ocaml} opcodeGreaterEq b1 b2 = opcodeCompare b1 b2 <> LT + +let inline {ocaml} opcodeLess = defaultLess +let inline {ocaml} opcodeLessEq = defaultLessEq +let inline {ocaml} opcodeGreater = defaultGreater +let inline {ocaml} opcodeGreaterEq = defaultGreaterEq + +instance (Ord opcode) + let compare = opcodeCompare + let (<) = opcodeLess + let (<=) = opcodeLessEq + let (>) = opcodeGreater + let (>=) = opcodeGreaterEq +end + +let addressCompare (Address b1 i1) (Address b2 i2) = compare i1 i2 +(* this cannot be defaultCompare for OCaml because addresses contain big ints *) + +let addressLess b1 b2 = addressCompare b1 b2 = LT +let addressLessEq b1 b2 = addressCompare b1 b2 <> GT +let addressGreater b1 b2 = addressCompare b1 b2 = GT +let addressGreaterEq b1 b2 = addressCompare b1 b2 <> LT + +instance (SetType address) + let setElemCompare = addressCompare +end + +instance (Ord address) + let compare = addressCompare + let (<) = addressLess + let (<=) = addressLessEq + let (>) = addressGreater + let (>=) = addressGreaterEq +end + +let {coq; ocaml} addressEqual a1 a2 = (addressCompare a1 a2) = EQ +let inline {hol; isabelle} addressEqual = unsafe_structural_equality + +let {coq; ocaml} addressInequal a1 a2 = not (addressEqual a1 a2) +let inline {hol; isabelle} addressInequal = unsafe_structural_inequality + +instance (Eq address) + let (=) = addressEqual + let (<>) = addressInequal +end + +let ~{ocaml} directionCompare d1 d2 = + match (d1, d2) with + | (D_decreasing, D_increasing) -> GT + | (D_increasing, D_decreasing) -> LT + | _ -> EQ + end +let inline {ocaml} directionCompare = defaultCompare + +let ~{ocaml} directionLess b1 b2 = directionCompare b1 b2 = LT +let ~{ocaml} directionLessEq b1 b2 = directionCompare b1 b2 <> GT +let ~{ocaml} directionGreater b1 b2 = directionCompare b1 b2 = GT +let ~{ocaml} directionGreaterEq b1 b2 = directionCompare b1 b2 <> LT + +let inline {ocaml} directionLess = defaultLess +let inline {ocaml} directionLessEq = defaultLessEq +let inline {ocaml} directionGreater = defaultGreater +let inline {ocaml} directionGreaterEq = defaultGreaterEq + +instance (Ord direction) + let compare = directionCompare + let (<) = directionLess + let (<=) = directionLessEq + let (>) = directionGreater + let (>=) = directionGreaterEq +end + +instance (Show direction) + let show = function D_increasing -> "D_increasing" | D_decreasing -> "D_decreasing" end +end + +let ~{ocaml} register_valueCompare rv1 rv2 = + compare (rv1.rv_bits, rv1.rv_dir, rv1.rv_start, rv1.rv_start_internal) + (rv2.rv_bits, rv2.rv_dir, rv2.rv_start, rv2.rv_start_internal) +let inline {ocaml} register_valueCompare = defaultCompare + +let ~{ocaml} register_valueLess b1 b2 = register_valueCompare b1 b2 = LT +let ~{ocaml} register_valueLessEq b1 b2 = register_valueCompare b1 b2 <> GT +let ~{ocaml} register_valueGreater b1 b2 = register_valueCompare b1 b2 = GT +let ~{ocaml} register_valueGreaterEq b1 b2 = register_valueCompare b1 b2 <> LT + +let inline {ocaml} register_valueLess = defaultLess +let inline {ocaml} register_valueLessEq = defaultLessEq +let inline {ocaml} register_valueGreater = defaultGreater +let inline {ocaml} register_valueGreaterEq = defaultGreaterEq + +instance (Ord register_value) + let compare = register_valueCompare + let (<) = register_valueLess + let (<=) = register_valueLessEq + let (>) = register_valueGreater + let (>=) = register_valueGreaterEq +end + +let address_liftedCompare (Address_lifted b1 i1) (Address_lifted b2 i2) = + compare (i1,b1) (i2,b2) +(* this cannot be defaultCompare for OCaml because address_lifteds contain big + ints *) + +let address_liftedLess b1 b2 = address_liftedCompare b1 b2 = LT +let address_liftedLessEq b1 b2 = address_liftedCompare b1 b2 <> GT +let address_liftedGreater b1 b2 = address_liftedCompare b1 b2 = GT +let address_liftedGreaterEq b1 b2 = address_liftedCompare b1 b2 <> LT + +instance (Ord address_lifted) + let compare = address_liftedCompare + let (<) = address_liftedLess + let (<=) = address_liftedLessEq + let (>) = address_liftedGreater + let (>=) = address_liftedGreaterEq +end + +(* Registers *) +type slice = (nat * nat) + +type reg_name = + (* do we really need this here if ppcmem already has this information by itself? *) +| Reg of string * nat * nat * direction +(*Name of the register, accessing the entire register, the start and size of this register, and its direction *) + +| Reg_slice of string * nat * direction * slice +(* Name of the register, accessing from the bit indexed by the first +to the bit indexed by the second integer of the slice, inclusive. For +machineDef* the first is a smaller number or equal to the second, adjusted +to reflect the correct span direction in the interpreter side. *) + +| Reg_field of string * nat * direction * string * slice +(*Name of the register, start and direction, and name of the field of the register +accessed. The slice specifies where this field is in the register*) + +| Reg_f_slice of string * nat * direction * string * slice * slice +(* The first four components are as in Reg_field; the final slice +specifies a part of the field, indexed w.r.t. the register as a whole *) + +let register_base_name : reg_name -> string = function + | Reg s _ _ _ -> s + | Reg_slice s _ _ _ -> s + | Reg_field s _ _ _ _ -> s + | Reg_f_slice s _ _ _ _ _ -> s + end + +let slice_of_reg_name : reg_name -> slice = function + | Reg _ start width D_increasing -> (start, start + width -1) + | Reg _ start width D_decreasing -> (start - width - 1, start) + | Reg_slice _ _ _ sl -> sl + | Reg_field _ _ _ _ sl -> sl + | Reg_f_slice _ _ _ _ _ sl -> sl + end + +let width_of_reg_name (r: reg_name) : nat = + let width_of_slice (i, j) = (* j - i + 1 in *) + + (integerFromNat j) - (integerFromNat i) + 1 + $> abs $> natFromInteger + in + match r with + | Reg _ _ width _ -> width + | Reg_slice _ _ _ sl -> width_of_slice sl + | Reg_field _ _ _ _ sl -> width_of_slice sl + | Reg_f_slice _ _ _ _ _ sl -> width_of_slice sl + end + +let reg_name_non_empty_intersection (r: reg_name) (r': reg_name) : bool = + register_base_name r = register_base_name r' && + let (i1, i2) = slice_of_reg_name r in + let (i1', i2') = slice_of_reg_name r' in + i1' <= i2 && i2' >= i1 + +let reg_nameCompare r1 r2 = + compare (register_base_name r1,slice_of_reg_name r1) + (register_base_name r2,slice_of_reg_name r2) + +let reg_nameLess b1 b2 = reg_nameCompare b1 b2 = LT +let reg_nameLessEq b1 b2 = reg_nameCompare b1 b2 <> GT +let reg_nameGreater b1 b2 = reg_nameCompare b1 b2 = GT +let reg_nameGreaterEq b1 b2 = reg_nameCompare b1 b2 <> LT + +instance (Ord reg_name) + let compare = reg_nameCompare + let (<) = reg_nameLess + let (<=) = reg_nameLessEq + let (>) = reg_nameGreater + let (>=) = reg_nameGreaterEq +end + +let {coq;ocaml} reg_nameEqual a1 a2 = (reg_nameCompare a1 a2) = EQ +let {hol;isabelle} reg_nameEqual = unsafe_structural_equality +let {coq;ocaml} reg_nameInequal a1 a2 = not (reg_nameEqual a1 a2) +let {hol;isabelle} reg_nameInequal = unsafe_structural_inequality + +instance (Eq reg_name) + let (=) = reg_nameEqual + let (<>) = reg_nameInequal +end + +instance (SetType reg_name) + let setElemCompare = reg_nameCompare +end + +let direction_of_reg_name r = match r with + | Reg _ _ _ d -> d + | Reg_slice _ _ d _ -> d + | Reg_field _ _ d _ _ -> d + | Reg_f_slice _ _ d _ _ _ -> d + end + +let start_of_reg_name r = match r with + | Reg _ start _ _ -> start + | Reg_slice _ start _ _ -> start + | Reg_field _ start _ _ _ -> start + | Reg_f_slice _ start _ _ _ _ -> start +end + +(* Data structures for building up instructions *) + +(* read_kind, write_kind, barrier_kind, trans_kind and instruction_kind have + been moved to sail_instr_kinds.lem. This removes the dependency of the + shallow embedding on the rest of sail_impl_base.lem, and helps avoid name + clashes between the different monad types. *) + +type event = + | E_read_mem of read_kind * address_lifted * nat * maybe (list reg_name) + | E_read_memt of read_kind * address_lifted * nat * maybe (list reg_name) + | E_write_mem of write_kind * address_lifted * nat * maybe (list reg_name) * memory_value * maybe (list reg_name) + | E_write_ea of write_kind * address_lifted * nat * maybe (list reg_name) + | E_excl_res + | E_write_memv of maybe address_lifted * memory_value * maybe (list reg_name) + | E_write_memvt of maybe address_lifted * (bit_lifted * memory_value) * maybe (list reg_name) + | E_barrier of barrier_kind + | E_footprint + | E_read_reg of reg_name + | E_write_reg of reg_name * register_value + | E_escape + | E_error of string + + +let eventCompare e1 e2 = + match (e1,e2) with + | (E_read_mem rk1 v1 i1 tr1, E_read_mem rk2 v2 i2 tr2) -> + compare (rk1, (v1,i1,tr1)) (rk2,(v2, i2, tr2)) + | (E_read_memt rk1 v1 i1 tr1, E_read_memt rk2 v2 i2 tr2) -> + compare (rk1, (v1,i1,tr1)) (rk2,(v2, i2, tr2)) + | (E_write_mem wk1 v1 i1 tr1 v1' tr1', E_write_mem wk2 v2 i2 tr2 v2' tr2') -> + compare ((wk1,v1,i1),(tr1,v1',tr1')) ((wk2,v2,i2),(tr2,v2',tr2')) + | (E_write_ea wk1 a1 i1 tr1, E_write_ea wk2 a2 i2 tr2) -> + compare (wk1, (a1, i1, tr1)) (wk2, (a2, i2, tr2)) + | (E_excl_res, E_excl_res) -> EQ + | (E_write_memv _ mv1 tr1, E_write_memv _ mv2 tr2) -> compare (mv1,tr1) (mv2,tr2) + | (E_write_memvt _ mv1 tr1, E_write_memvt _ mv2 tr2) -> compare (mv1,tr1) (mv2,tr2) + | (E_barrier bk1, E_barrier bk2) -> compare bk1 bk2 + | (E_read_reg r1, E_read_reg r2) -> compare r1 r2 + | (E_write_reg r1 v1, E_write_reg r2 v2) -> compare (r1,v1) (r2,v2) + | (E_error s1, E_error s2) -> compare s1 s2 + | (E_escape,E_escape) -> EQ + | (E_read_mem _ _ _ _, _) -> LT + | (E_write_mem _ _ _ _ _ _, _) -> LT + | (E_write_ea _ _ _ _, _) -> LT + | (E_excl_res, _) -> LT + | (E_write_memv _ _ _, _) -> LT + | (E_barrier _, _) -> LT + | (E_read_reg _, _) -> LT + | (E_write_reg _ _, _) -> LT + | _ -> GT + end + +let eventLess b1 b2 = eventCompare b1 b2 = LT +let eventLessEq b1 b2 = eventCompare b1 b2 <> GT +let eventGreater b1 b2 = eventCompare b1 b2 = GT +let eventGreaterEq b1 b2 = eventCompare b1 b2 <> LT + +instance (Ord event) + let compare = eventCompare + let (<) = eventLess + let (<=) = eventLessEq + let (>) = eventGreater + let (>=) = eventGreaterEq +end + +instance (SetType event) + let setElemCompare = compare +end + + +(* the address_lifted types should go away here and be replaced by address *) +type with_aux 'o = 'o * maybe ((unit -> (string * string)) * ((list (reg_name * register_value)) -> list event)) +type outcome 'a 'e = + (* Request to read memory, value is location to read, integer is size to read, + followed by registers that were used in computing that size *) + | Read_mem of (read_kind * address_lifted * nat) * (memory_value -> with_aux (outcome 'a 'e)) + (* Tell the system a write is imminent, at address lifted, of size nat *) + | Write_ea of (write_kind * address_lifted * nat) * (with_aux (outcome 'a 'e)) + (* Request the result of store-exclusive *) + | Excl_res of (bool -> with_aux (outcome 'a 'e)) + (* Request to write memory at last signalled address. Memory value should be 8 + times the size given in ea signal *) + | Write_memv of memory_value * (bool -> with_aux (outcome 'a 'e)) + (* Request a memory barrier *) + | Barrier of barrier_kind * with_aux (outcome 'a 'e) + (* Tell the system to dynamically recalculate dependency footprint *) + | Footprint of with_aux (outcome 'a 'e) + (* Request to read register, will track dependency when mode.track_values *) + | Read_reg of reg_name * (register_value -> with_aux (outcome 'a 'e)) + (* Request to write register *) + | Write_reg of (reg_name * register_value) * with_aux (outcome 'a 'e) + | Escape of maybe string + (*Result of a failed assert with possible error message to report*) + | Fail of maybe string + (* Exception of type 'e *) + | Exception of 'e + | Internal of (maybe string * maybe (unit -> string)) * with_aux (outcome 'a 'e) + | Done of 'a + | Error of string + +type outcome_s 'a 'e = with_aux (outcome 'a 'e) +(* first string : output of instruction_stack_to_string + second string: output of local_variables_to_string *) + +(** operations and coercions on basic values *) + +val word8_to_bitls : word8 -> list bit_lifted +val bitls_to_word8 : list bit_lifted -> word8 + +val integer_of_word8_list : list word8 -> integer +val word8_list_of_integer : integer -> integer -> list word8 + +val concretizable_bitl : bit_lifted -> bool +val concretizable_bytl : byte_lifted -> bool +val concretizable_bytls : list byte_lifted -> bool + +let concretizable_bitl = function + | Bitl_zero -> true + | Bitl_one -> true + | Bitl_undef -> false + | Bitl_unknown -> false +end + +let concretizable_bytl (Byte_lifted bs) = List.all concretizable_bitl bs +let concretizable_bytls = List.all concretizable_bytl + +(* constructing values *) + +val build_register_value : list bit_lifted -> direction -> nat -> nat -> register_value +let build_register_value bs dir width start_index = + <| rv_bits = bs; + rv_dir = dir; (* D_increasing for Power, D_decreasing for ARM *) + rv_start_internal = start_index; + rv_start = if dir = D_increasing + then start_index + else (start_index+1) - width; (* Smaller index, as in Power, for external interaction *) + |> + +val register_value : bit_lifted -> direction -> nat -> nat -> register_value +let register_value b dir width start_index = + build_register_value (List.replicate width b) dir width start_index + +val register_value_zeros : direction -> nat -> nat -> register_value +let register_value_zeros dir width start_index = + register_value Bitl_zero dir width start_index + +val register_value_ones : direction -> nat -> nat -> register_value +let register_value_ones dir width start_index = + register_value Bitl_one dir width start_index + +val register_value_for_reg : reg_name -> list bit_lifted -> register_value +let register_value_for_reg r bs : register_value = + let () = ensure (width_of_reg_name r = List.length bs) + ("register_value_for_reg (\"" ^ show (register_base_name r) ^ "\") length mismatch: " + ^ show (width_of_reg_name r) ^ " vs " ^ show (List.length bs)) + in + let (j1, j2) = slice_of_reg_name r in + let d = direction_of_reg_name r in + <| rv_bits = bs; + rv_dir = d; + rv_start_internal = if d = D_increasing then j1 else (start_of_reg_name r) - j1; + rv_start = j1; + |> + +val byte_lifted_undef : byte_lifted +let byte_lifted_undef = Byte_lifted (List.replicate 8 Bitl_undef) + +val byte_lifted_unknown : byte_lifted +let byte_lifted_unknown = Byte_lifted (List.replicate 8 Bitl_unknown) + +val memory_value_unknown : nat (*the number of bytes*) -> memory_value +let memory_value_unknown (width:nat) : memory_value = + List.replicate width byte_lifted_unknown + +val memory_value_undef : nat (*the number of bytes*) -> memory_value +let memory_value_undef (width:nat) : memory_value = + List.replicate width byte_lifted_undef + +val match_endianness : forall 'a. end_flag -> list 'a -> list 'a +let match_endianness endian l = + match endian with + | E_little_endian -> List.reverse l + | E_big_endian -> l + end + +(* lengths *) + +val memory_value_length : memory_value -> nat +let memory_value_length (mv:memory_value) = List.length mv + + +(* aux fns *) + +val maybe_all : forall 'a. list (maybe 'a) -> maybe (list 'a) +let rec maybe_all' xs acc = + match xs with + | [] -> Just (List.reverse acc) + | Nothing :: _ -> Nothing + | (Just y)::xs' -> maybe_all' xs' (y::acc) + end +let maybe_all xs = maybe_all' xs [] + +(** coercions *) + +(* bits and bytes *) + +let bit_to_bool = function (* TODO: rename bool_of_bit *) + | Bitc_zero -> false + | Bitc_one -> true +end + + +val bit_lifted_of_bit : bit -> bit_lifted +let bit_lifted_of_bit b = + match b with + | Bitc_zero -> Bitl_zero + | Bitc_one -> Bitl_one + end + +val bit_of_bit_lifted : bit_lifted -> maybe bit +let bit_of_bit_lifted bl = + match bl with + | Bitl_zero -> Just Bitc_zero + | Bitl_one -> Just Bitc_one + | Bitl_undef -> Nothing + | Bitl_unknown -> Nothing + end + + +val byte_lifted_of_byte : byte -> byte_lifted +let byte_lifted_of_byte (Byte bs) : byte_lifted = Byte_lifted (List.map bit_lifted_of_bit bs) + +val byte_of_byte_lifted : byte_lifted -> maybe byte +let byte_of_byte_lifted bl = + match bl with + | Byte_lifted bls -> + match maybe_all (List.map bit_of_bit_lifted bls) with + | Nothing -> Nothing + | Just bs -> Just (Byte bs) + end + end + + +val bytes_of_bits : list bit -> list byte (*assumes (length bits) mod 8 = 0*) +let rec bytes_of_bits bits = match bits with + | [] -> [] + | b0::b1::b2::b3::b4::b5::b6::b7::bits -> + (Byte [b0;b1;b2;b3;b4;b5;b6;b7])::(bytes_of_bits bits) + | _ -> failwith "bytes_of_bits not given bits divisible by 8" +end + +val byte_lifteds_of_bit_lifteds : list bit_lifted -> list byte_lifted (*assumes (length bits) mod 8 = 0*) +let rec byte_lifteds_of_bit_lifteds bits = match bits with + | [] -> [] + | b0::b1::b2::b3::b4::b5::b6::b7::bits -> + (Byte_lifted [b0;b1;b2;b3;b4;b5;b6;b7])::(byte_lifteds_of_bit_lifteds bits) + | _ -> failwith "byte_lifteds of bit_lifteds not given bits divisible by 8" +end + + +val byte_of_memory_byte : memory_byte -> maybe byte +let byte_of_memory_byte = byte_of_byte_lifted + +val memory_byte_of_byte : byte -> memory_byte +let memory_byte_of_byte = byte_lifted_of_byte + + +(* to and from nat *) + +(* this natFromBoolList could move to the Lem word.lem library *) +val natFromBoolList : list bool -> nat +let rec natFromBoolListAux (acc : nat) (bl : list bool) = + match bl with + | [] -> acc + | (true :: bl') -> natFromBoolListAux ((acc * 2) + 1) bl' + | (false :: bl') -> natFromBoolListAux (acc * 2) bl' + end +let natFromBoolList bl = + natFromBoolListAux 0 (List.reverse bl) + + +val nat_of_bit_list : list bit -> nat +let nat_of_bit_list b = + natFromBoolList (List.reverse (List.map bit_to_bool b)) + (* natFromBoolList takes a list with LSB first, for consistency with rest of Lem word library, so we reverse it. twice. *) + + +(* to and from integer *) + +val integer_of_bit_list : list bit -> integer +let integer_of_bit_list b = + integerFromBoolList (false,(List.reverse (List.map bit_to_bool b))) + (* integerFromBoolList takes a list with LSB first, so we reverse it *) + +val bit_list_of_integer : nat -> integer -> list bit +let bit_list_of_integer len b = + List.map (fun b -> if b then Bitc_one else Bitc_zero) + (reverse (boolListFrombitSeq len (bitSeqFromInteger Nothing b))) + +val integer_of_byte_list : list byte -> integer +let integer_of_byte_list bytes = integer_of_bit_list (List.concatMap (fun (Byte bs) -> bs) bytes) + +val byte_list_of_integer : nat -> integer -> list byte +let byte_list_of_integer (len:nat) (a:integer):list byte = + let bits = bit_list_of_integer (len * 8) a in bytes_of_bits bits + + +val integer_of_address : address -> integer +let integer_of_address (a:address):integer = + match a with + | Address bs i -> i + end + +val address_of_integer : integer -> address +let address_of_integer (i:integer):address = + Address (byte_list_of_integer 8 i) i + +(* to and from signed-integer *) + +val signed_integer_of_bit_list : list bit -> integer +let signed_integer_of_bit_list b = + match b with + | [] -> failwith "empty bit list" + | Bitc_zero :: b' -> + integerFromBoolList (false,(List.reverse (List.map bit_to_bool b))) + | Bitc_one :: b' -> + let b'_val = integerFromBoolList (false,(List.reverse (List.map bit_to_bool b'))) in + (* integerFromBoolList takes a list with LSB first, so we reverse it *) + let msb_val = integerPow 2 ((List.length b) - 1) in + b'_val - msb_val + end + + +(* regarding a list of int as a list of bytes in memory, MSB lowest-address first, convert to an integer *) +val integer_address_of_int_list : list int -> integer +let rec integerFromIntListAux (acc: integer) (is: list int) = + match is with + | [] -> acc + | (i :: is') -> integerFromIntListAux ((acc * 256) + integerFromInt i) is' + end +let integer_address_of_int_list (is: list int) = + integerFromIntListAux 0 is + +val address_of_byte_list : list byte -> address +let address_of_byte_list bs = + if List.length bs <> 8 then failwith "address_of_byte_list given list not of length 8" else + Address bs (integer_of_byte_list bs) + +let address_of_byte_lifted_list bls = + match maybe_all (List.map byte_of_byte_lifted bls) with + | Nothing -> Nothing + | Just bs -> Just (address_of_byte_list bs) + end + +(* operations on addresses *) + +val add_address_nat : address -> nat -> address +let add_address_nat (a:address) (i:nat) : address = + address_of_integer ((integer_of_address a) + (integerFromNat i)) + +val clear_low_order_bits_of_address : address -> address +let clear_low_order_bits_of_address a = + match a with + | Address [b0;b1;b2;b3;b4;b5;b6;b7] i -> + match b7 with + | Byte [bt0;bt1;bt2;bt3;bt4;bt5;bt6;bt7] -> + let b7' = Byte [bt0;bt1;bt2;bt3;bt4;bt5;Bitc_zero;Bitc_zero] in + let bytes = [b0;b1;b2;b3;b4;b5;b6;b7'] in + Address bytes (integer_of_byte_list bytes) + | _ -> failwith "Byte does not contain 8 bits" + end + | _ -> failwith "Address does not contain 8 bytes" + end + + + +val byte_list_of_memory_value : end_flag -> memory_value -> maybe (list byte) +let byte_list_of_memory_value endian mv = + match_endianness endian mv + $> List.map byte_of_memory_byte + $> maybe_all + + +val integer_of_memory_value : end_flag -> memory_value -> maybe integer +let integer_of_memory_value endian (mv:memory_value):maybe integer = + match byte_list_of_memory_value endian mv with + | Just bs -> Just (integer_of_byte_list bs) + | Nothing -> Nothing + end + +val memory_value_of_integer : end_flag -> nat -> integer -> memory_value +let memory_value_of_integer endian (len:nat) (i:integer):memory_value = + List.map byte_lifted_of_byte (byte_list_of_integer len i) + $> match_endianness endian + + +val integer_of_register_value : register_value -> maybe integer +let integer_of_register_value (rv:register_value):maybe integer = + match maybe_all (List.map bit_of_bit_lifted rv.rv_bits) with + | Nothing -> Nothing + | Just bs -> Just (integer_of_bit_list bs) + end + +(* NOTE: register_value_for_reg_of_integer might be easier to use *) +val register_value_of_integer : nat -> nat -> direction -> integer -> register_value +let register_value_of_integer (len:nat) (start:nat) (dir:direction) (i:integer):register_value = + let bs = bit_list_of_integer len i in + build_register_value (List.map bit_lifted_of_bit bs) dir len start + +val register_value_for_reg_of_integer : reg_name -> integer -> register_value +let register_value_for_reg_of_integer (r: reg_name) (i:integer) : register_value = + register_value_of_integer (width_of_reg_name r) (start_of_reg_name r) (direction_of_reg_name r) i + +(* *) + +val opcode_of_bytes : byte -> byte -> byte -> byte -> opcode +let opcode_of_bytes b0 b1 b2 b3 : opcode = Opcode [b0;b1;b2;b3] + +val register_value_of_address : address -> direction -> register_value +let register_value_of_address (Address bytes _) dir : register_value = + let bits = List.concatMap (fun (Byte bs) -> List.map bit_lifted_of_bit bs) bytes in + <| rv_bits = bits; + rv_dir = dir; + rv_start = 0; + rv_start_internal = if dir = D_increasing then 0 else (List.length bits) - 1 + |> + +val register_value_of_memory_value : memory_value -> direction -> register_value +let register_value_of_memory_value bytes dir : register_value = + let bitls = List.concatMap (fun (Byte_lifted bs) -> bs) bytes in + <| rv_bits = bitls; + rv_dir = dir; + rv_start = 0; + rv_start_internal = if dir = D_increasing then 0 else (List.length bitls) - 1 + |> + +val memory_value_of_register_value: register_value -> memory_value +let memory_value_of_register_value r = + (byte_lifteds_of_bit_lifteds r.rv_bits) + +val address_lifted_of_register_value : register_value -> maybe address_lifted +(* returning Nothing iff the register value is not 64 bits wide, but +allowing Bitl_undef and Bitl_unknown *) +let address_lifted_of_register_value (rv:register_value) : maybe address_lifted = + if List.length rv.rv_bits <> 64 then Nothing + else + Just (Address_lifted (byte_lifteds_of_bit_lifteds rv.rv_bits) + (if List.all concretizable_bitl rv.rv_bits + then match (maybe_all (List.map bit_of_bit_lifted rv.rv_bits)) with + | (Just(bits)) -> Just (integer_of_bit_list bits) + | Nothing -> Nothing end + else Nothing)) + +val address_of_address_lifted : address_lifted -> maybe address +(* returning Nothing iff the address contains any Bitl_undef or Bitl_unknown *) +let address_of_address_lifted (al:address_lifted): maybe address = + match al with + | Address_lifted bls (Just i)-> + match maybe_all ((List.map byte_of_byte_lifted) bls) with + | Nothing -> Nothing + | Just bs -> Just (Address bs i) + end + | _ -> Nothing +end + +val address_of_register_value : register_value -> maybe address +(* returning Nothing iff the register value is not 64 bits wide, or contains Bitl_undef or Bitl_unknown *) +let address_of_register_value (rv:register_value) : maybe address = + match address_lifted_of_register_value rv with + | Nothing -> Nothing + | Just al -> + match address_of_address_lifted al with + | Nothing -> Nothing + | Just a -> Just a + end + end + +let address_of_memory_value (endian: end_flag) (mv:memory_value) : maybe address = + match byte_list_of_memory_value endian mv with + | Nothing -> Nothing + | Just bs -> + if List.length bs <> 8 then Nothing else + Just (address_of_byte_list bs) + end + +val byte_of_int : int -> byte +let byte_of_int (i:int) : byte = + Byte (bit_list_of_integer 8 (integerFromInt i)) + +val memory_byte_of_int : int -> memory_byte +let memory_byte_of_int (i:int) : memory_byte = + memory_byte_of_byte (byte_of_int i) + +(* +val int_of_memory_byte : int -> maybe memory_byte +let int_of_memory_byte (mb:memory_byte) : int = + failwith "TODO" +*) + + + +val memory_value_of_address_lifted : end_flag -> address_lifted -> memory_value +let memory_value_of_address_lifted endian (Address_lifted bs _ :address_lifted) = + match_endianness endian bs + +val byte_list_of_address : address -> list byte +let byte_list_of_address (Address bs _) : list byte = bs + +val memory_value_of_address : end_flag -> address -> memory_value +let memory_value_of_address endian (Address bs _) = + match_endianness endian bs + $> List.map byte_lifted_of_byte + +val byte_list_of_opcode : opcode -> list byte +let byte_list_of_opcode (Opcode bs) : list byte = bs + +(** ****************************************** *) +(** show type class instantiations *) +(** ****************************************** *) + +(* matching printing_functions.ml *) +val stringFromReg_name : reg_name -> string +let stringFromReg_name r = + let norm_sl start dir (first,second) = (first,second) + (* match dir with + | D_increasing -> (first,second) + | D_decreasing -> (start - first, start - second) + end *) + in + match r with + | Reg s start size dir -> s + | Reg_slice s start dir sl -> + let (first,second) = norm_sl start dir sl in + s ^ "[" ^ show first ^ (if (first = second) then "" else ".." ^ (show second)) ^ "]" + | Reg_field s start dir f sl -> + let (first,second) = norm_sl start dir sl in + s ^ "." ^ f ^ " (" ^ (show start) ^ ", " ^ (show dir) ^ ", " ^ (show first) ^ ", " ^ (show second) ^ ")" + | Reg_f_slice s start dir f (first1,second1) (first,second) -> + let (first,second) = + match dir with + | D_increasing -> (first,second) + | D_decreasing -> (start - first, start - second) + end in + s ^ "." ^ f ^ "]" ^ show first ^ (if (first = second) then "" else ".." ^ (show second)) ^ "]" + end + +instance (Show reg_name) + let show = stringFromReg_name +end + + +(* hex pp of integers, adapting the Lem string_extra.lem code *) +val stringFromNaturalHexHelper : natural -> list char -> list char +let rec stringFromNaturalHexHelper n acc = + if n = 0 then + acc + else + stringFromNaturalHexHelper (n / 16) (String_extra.chr (natFromNatural (let nd = n mod 16 in if nd <=9 then nd + 48 else nd - 10 + 97)) :: acc) + +val stringFromNaturalHex : natural -> string +let (*~{ocaml;hol}*) stringFromNaturalHex n = + if n = 0 then "0" else toString (stringFromNaturalHexHelper n []) + +val stringFromIntegerHex : integer -> string +let (*~{ocaml}*) stringFromIntegerHex i = + if i < 0 then + "-" ^ stringFromNaturalHex (naturalFromInteger i) + else + stringFromNaturalHex (naturalFromInteger i) + + +let stringFromAddress (Address bs i) = + let i' = integer_of_byte_list bs in + if i=i' then +(*TODO: ideally this should be made to match the src/pp.ml pp_address; the following very roughly matches what's used in the ppcmem UI, enough to make exceptions readable *) + if i < 65535 then + show i + else + stringFromIntegerHex i + else + "stringFromAddress bytes and integer mismatch" + +instance (Show address) + let show = stringFromAddress +end + +let stringFromByte_lifted bl = + match byte_of_byte_lifted bl with + | Nothing -> "u?" + | Just (Byte bits) -> + let i = integer_of_bit_list bits in + show i + end + +instance (Show byte_lifted) + let show = stringFromByte_lifted +end + +(* possible next instruction address options *) +type nia = + | NIA_successor + | NIA_concrete_address of address + | NIA_indirect_address + +let niaCompare n1 n2 = match (n1,n2) with + | (NIA_successor, NIA_successor) -> EQ + | (NIA_successor, _) -> LT + | (_, NIA_successor) -> GT + | (NIA_concrete_address a1, NIA_concrete_address a2) -> compare a1 a2 + | (NIA_concrete_address _, _) -> LT + | (_, NIA_concrete_address _) -> GT + | (NIA_indirect_address, NIA_indirect_address) -> EQ + (* | (NIA_indirect_address, _) -> LT + | (_, NIA_indirect_address) -> GT *) + end + +instance (Ord nia) + let compare = niaCompare + let (<) n1 n2 = (niaCompare n1 n2) = LT + let (<=) n1 n2 = (niaCompare n1 n2) <> GT + let (>) n1 n2 = (niaCompare n1 n2) = GT + let (>=) n1 n2 = (niaCompare n1 n2) <> LT +end + +let stringFromNia = function + | NIA_successor -> "NIA_successor" + | NIA_concrete_address a -> "NIA_concrete_address " ^ show a + | NIA_indirect_address -> "NIA_indirect_address" +end + +instance (Show nia) + let show = stringFromNia +end + +type dia = + | DIA_none + | DIA_concrete_address of address + | DIA_register of reg_name + +let diaCompare d1 d2 = match (d1, d2) with + | (DIA_none, DIA_none) -> EQ + | (DIA_none, _) -> LT + | (DIA_concrete_address a1, DIA_none) -> GT + | (DIA_concrete_address a1, DIA_concrete_address a2) -> compare a1 a2 + | (DIA_concrete_address a1, _) -> LT + | (DIA_register r1, DIA_register r2) -> compare r1 r2 + | (DIA_register _, _) -> GT +end + +instance (Ord dia) + let compare = diaCompare + let (<) n1 n2 = (diaCompare n1 n2) = LT + let (<=) n1 n2 = (diaCompare n1 n2) <> GT + let (>) n1 n2 = (diaCompare n1 n2) = GT + let (>=) n1 n2 = (diaCompare n1 n2) <> LT +end + +let stringFromDia = function + | DIA_none -> "DIA_none" + | DIA_concrete_address a -> "DIA_concrete_address " ^ show a + | DIA_register r -> "DIA_delayed_register " ^ show r +end + +instance (Show dia) + let show = stringFromDia +end +*) diff --git a/lib/coq/Sail_instr_kinds.v b/lib/coq/Sail_instr_kinds.v new file mode 100644 index 00000000..57532e92 --- /dev/null +++ b/lib/coq/Sail_instr_kinds.v @@ -0,0 +1,298 @@ +(*========================================================================*) +(* Sail *) +(* *) +(* Copyright (c) 2013-2017 *) +(* Kathyrn Gray *) +(* Shaked Flur *) +(* Stephen Kell *) +(* Gabriel Kerneis *) +(* Robert Norton-Wright *) +(* Christopher Pulte *) +(* Peter Sewell *) +(* Alasdair Armstrong *) +(* Brian Campbell *) +(* Thomas Bauereiss *) +(* Anthony Fox *) +(* Jon French *) +(* Dominic Mulligan *) +(* Stephen Kell *) +(* Mark Wassell *) +(* *) +(* All rights reserved. *) +(* *) +(* This software was developed by the University of Cambridge Computer *) +(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *) +(* (REMS) project, funded by EPSRC grant EP/K008528/1. *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(*========================================================================*) + + +(* + +class ( EnumerationType 'a ) + val toNat : 'a -> nat +end + + +val enumeration_typeCompare : forall 'a. EnumerationType 'a => 'a -> 'a -> ordering +let ~{ocaml} enumeration_typeCompare e1 e2 := + compare (toNat e1) (toNat e2) +let inline {ocaml} enumeration_typeCompare := defaultCompare + + +default_instance forall 'a. EnumerationType 'a => (Ord 'a) + let compare := enumeration_typeCompare + let (<) r1 r2 := (enumeration_typeCompare r1 r2) = LT + let (<=) r1 r2 := (enumeration_typeCompare r1 r2) <> GT + let (>) r1 r2 := (enumeration_typeCompare r1 r2) = GT + let (>=) r1 r2 := (enumeration_typeCompare r1 r2) <> LT +end +*) + +(* Data structures for building up instructions *) + +(* careful: changes in the read/write/barrier kinds have to be + reflected in deep_shallow_convert *) +Inductive read_kind := + (* common reads *) + | Read_plain + (* Power reads *) + | Read_reserve + (* AArch64 reads *) + | Read_acquire | Read_exclusive | Read_exclusive_acquire | Read_stream + (* RISC-V reads *) + | Read_RISCV_acquire | Read_RISCV_strong_acquire + | Read_RISCV_reserved | Read_RISCV_reserved_acquire + | Read_RISCV_reserved_strong_acquire + (* x86 reads *) + | Read_X86_locked (* the read part of a lock'd instruction (rmw) *) +. +(* +instance (Show read_kind) + let show := function + | Read_plain -> "Read_plain" + | Read_reserve -> "Read_reserve" + | Read_acquire -> "Read_acquire" + | Read_exclusive -> "Read_exclusive" + | Read_exclusive_acquire -> "Read_exclusive_acquire" + | Read_stream -> "Read_stream" + | Read_RISCV_acquire -> "Read_RISCV_acquire" + | Read_RISCV_strong_acquire -> "Read_RISCV_strong_acquire" + | Read_RISCV_reserved -> "Read_RISCV_reserved" + | Read_RISCV_reserved_acquire -> "Read_RISCV_reserved_acquire" + | Read_RISCV_reserved_strong_acquire -> "Read_RISCV_reserved_strong_acquire" + | Read_X86_locked -> "Read_X86_locked" + end +end +*) +Inductive write_kind := + (* common writes *) + | Write_plain + (* Power writes *) + | Write_conditional + (* AArch64 writes *) + | Write_release | Write_exclusive | Write_exclusive_release + (* RISC-V *) + | Write_RISCV_release | Write_RISCV_strong_release + | Write_RISCV_conditional | Write_RISCV_conditional_release + | Write_RISCV_conditional_strong_release + (* x86 writes *) + | Write_X86_locked (* the write part of a lock'd instruction (rmw) *) +. +(* +instance (Show write_kind) + let show := function + | Write_plain -> "Write_plain" + | Write_conditional -> "Write_conditional" + | Write_release -> "Write_release" + | Write_exclusive -> "Write_exclusive" + | Write_exclusive_release -> "Write_exclusive_release" + | Write_RISCV_release -> "Write_RISCV_release" + | Write_RISCV_strong_release -> "Write_RISCV_strong_release" + | Write_RISCV_conditional -> "Write_RISCV_conditional" + | Write_RISCV_conditional_release -> "Write_RISCV_conditional_release" + | Write_RISCV_conditional_strong_release -> "Write_RISCV_conditional_strong_release" + | Write_X86_locked -> "Write_X86_locked" + end +end +*) +Inductive barrier_kind := + (* Power barriers *) + Barrier_Sync | Barrier_LwSync | Barrier_Eieio | Barrier_Isync + (* AArch64 barriers *) + | Barrier_DMB | Barrier_DMB_ST | Barrier_DMB_LD | Barrier_DSB + | Barrier_DSB_ST | Barrier_DSB_LD | Barrier_ISB + | Barrier_TM_COMMIT + (* MIPS barriers *) + | Barrier_MIPS_SYNC + (* RISC-V barriers *) + | Barrier_RISCV_rw_rw + | Barrier_RISCV_r_rw + | Barrier_RISCV_r_r + | Barrier_RISCV_rw_w + | Barrier_RISCV_w_w + | Barrier_RISCV_i + (* X86 *) + | Barrier_x86_MFENCE. + +(* +instance (Show barrier_kind) + let show := function + | Barrier_Sync -> "Barrier_Sync" + | Barrier_LwSync -> "Barrier_LwSync" + | Barrier_Eieio -> "Barrier_Eieio" + | Barrier_Isync -> "Barrier_Isync" + | Barrier_DMB -> "Barrier_DMB" + | Barrier_DMB_ST -> "Barrier_DMB_ST" + | Barrier_DMB_LD -> "Barrier_DMB_LD" + | Barrier_DSB -> "Barrier_DSB" + | Barrier_DSB_ST -> "Barrier_DSB_ST" + | Barrier_DSB_LD -> "Barrier_DSB_LD" + | Barrier_ISB -> "Barrier_ISB" + | Barrier_TM_COMMIT -> "Barrier_TM_COMMIT" + | Barrier_MIPS_SYNC -> "Barrier_MIPS_SYNC" + | Barrier_RISCV_rw_rw -> "Barrier_RISCV_rw_rw" + | Barrier_RISCV_r_rw -> "Barrier_RISCV_r_rw" + | Barrier_RISCV_r_r -> "Barrier_RISCV_r_r" + | Barrier_RISCV_rw_w -> "Barrier_RISCV_rw_w" + | Barrier_RISCV_w_w -> "Barrier_RISCV_w_w" + | Barrier_RISCV_i -> "Barrier_RISCV_i" + | Barrier_x86_MFENCE -> "Barrier_x86_MFENCE" + end +end*) + +Inductive trans_kind := + (* AArch64 *) + | Transaction_start | Transaction_commit | Transaction_abort. +(* +instance (Show trans_kind) + let show := function + | Transaction_start -> "Transaction_start" + | Transaction_commit -> "Transaction_commit" + | Transaction_abort -> "Transaction_abort" + end +end*) + +Inductive instruction_kind := + | IK_barrier : barrier_kind -> instruction_kind + | IK_mem_read : read_kind -> instruction_kind + | IK_mem_write : write_kind -> instruction_kind + | IK_mem_rmw : (read_kind * write_kind) -> instruction_kind + | IK_branch (* this includes conditional-branch (multiple nias, none of which is NIA_indirect_address), + indirect/computed-branch (single nia of kind NIA_indirect_address) + and branch/jump (single nia of kind NIA_concrete_address) *) + | IK_trans : trans_kind -> instruction_kind + | IK_simple : instruction_kind. + +(* +instance (Show instruction_kind) + let show := function + | IK_barrier barrier_kind -> "IK_barrier " ^ (show barrier_kind) + | IK_mem_read read_kind -> "IK_mem_read " ^ (show read_kind) + | IK_mem_write write_kind -> "IK_mem_write " ^ (show write_kind) + | IK_mem_rmw (r, w) -> "IK_mem_rmw " ^ (show r) ^ " " ^ (show w) + | IK_branch -> "IK_branch" + | IK_trans trans_kind -> "IK_trans " ^ (show trans_kind) + | IK_simple -> "IK_simple" + end +end +*) + +Definition read_is_exclusive r := +match r with + | Read_plain => false + | Read_reserve => true + | Read_acquire => false + | Read_exclusive => true + | Read_exclusive_acquire => true + | Read_stream => false + | Read_RISCV_acquire => false + | Read_RISCV_strong_acquire => false + | Read_RISCV_reserved => true + | Read_RISCV_reserved_acquire => true + | Read_RISCV_reserved_strong_acquire => true + | Read_X86_locked => true +end. + + +(* +instance (EnumerationType read_kind) + let toNat := function + | Read_plain -> 0 + | Read_reserve -> 1 + | Read_acquire -> 2 + | Read_exclusive -> 3 + | Read_exclusive_acquire -> 4 + | Read_stream -> 5 + | Read_RISCV_acquire -> 6 + | Read_RISCV_strong_acquire -> 7 + | Read_RISCV_reserved -> 8 + | Read_RISCV_reserved_acquire -> 9 + | Read_RISCV_reserved_strong_acquire -> 10 + | Read_X86_locked -> 11 + end +end + +instance (EnumerationType write_kind) + let toNat := function + | Write_plain -> 0 + | Write_conditional -> 1 + | Write_release -> 2 + | Write_exclusive -> 3 + | Write_exclusive_release -> 4 + | Write_RISCV_release -> 5 + | Write_RISCV_strong_release -> 6 + | Write_RISCV_conditional -> 7 + | Write_RISCV_conditional_release -> 8 + | Write_RISCV_conditional_strong_release -> 9 + | Write_X86_locked -> 10 + end +end + +instance (EnumerationType barrier_kind) + let toNat := function + | Barrier_Sync -> 0 + | Barrier_LwSync -> 1 + | Barrier_Eieio ->2 + | Barrier_Isync -> 3 + | Barrier_DMB -> 4 + | Barrier_DMB_ST -> 5 + | Barrier_DMB_LD -> 6 + | Barrier_DSB -> 7 + | Barrier_DSB_ST -> 8 + | Barrier_DSB_LD -> 9 + | Barrier_ISB -> 10 + | Barrier_TM_COMMIT -> 11 + | Barrier_MIPS_SYNC -> 12 + | Barrier_RISCV_rw_rw -> 13 + | Barrier_RISCV_r_rw -> 14 + | Barrier_RISCV_r_r -> 15 + | Barrier_RISCV_rw_w -> 16 + | Barrier_RISCV_w_w -> 17 + | Barrier_RISCV_i -> 18 + | Barrier_x86_MFENCE -> 19 + end +end +*) diff --git a/lib/coq/Sail_operators.v b/lib/coq/Sail_operators.v new file mode 100644 index 00000000..8183daea --- /dev/null +++ b/lib/coq/Sail_operators.v @@ -0,0 +1,245 @@ +Require Import Sail_values. +Require List. + +(*** Bit vector operations *) + +Section Bitvectors. +Context {a b c} `{Bitvector a} `{Bitvector b} `{Bitvector c}. + +(*val concat_bv : forall 'a 'b 'c. Bitvector 'a, Bitvector 'b, Bitvector 'c => 'a -> 'b -> 'c*) +Definition concat_bv (l : a) (r : b) : c := of_bits (bits_of l ++ bits_of r). + +(*val cons_bv : forall 'a 'b 'c. Bitvector 'a, Bitvector 'b => bitU -> 'a -> 'b*) +Definition cons_bv b' (v : a) : b := of_bits (b' :: bits_of v). + +(*Definition bool_of_bv v := extract_only_element (bits_of v). +Definition cast_unit_bv b := of_bits [b] +Definition bv_of_bit len b := of_bits (extz_bits len [b])*) +Definition int_of_bv {a} `{Bitvector a} (sign : bool) : a -> Z := if sign then signed else unsigned. +(* +Definition most_significant v := match bits_of v with + | b :: _ -> b + | _ -> failwith "most_significant applied to empty vector" + end + +Definition get_max_representable_in sign (n : integer) : integer := + if (n = 64) then match sign with | true -> max_64 | false -> max_64u end + else if (n=32) then match sign with | true -> max_32 | false -> max_32u end + else if (n=8) then max_8 + else if (n=5) then max_5 + else match sign with | true -> integerPow 2 ((natFromInteger n) -1) + | false -> integerPow 2 (natFromInteger n) + end + +Definition get_min_representable_in _ (n : integer) : integer := + if n = 64 then min_64 + else if n = 32 then min_32 + else if n = 8 then min_8 + else if n = 5 then min_5 + else 0 - (integerPow 2 (natFromInteger n)) + +val bitwise_binop_bv : forall 'a. Bitvector 'a => + (bool -> bool -> bool) -> 'a -> 'a -> 'a*) +Definition bitwise_binop_bv {a} `{Bitvector a} op (l : a) (r : a) : a := + of_bits (binop_bits op (bits_of l) (bits_of r)). + +Definition and_bv {a} `{Bitvector a} : a -> a -> a := bitwise_binop_bv (andb). +Definition or_bv {a} `{Bitvector a} : a -> a -> a := bitwise_binop_bv orb. +Definition xor_bv {a} `{Bitvector a} : a -> a -> a := bitwise_binop_bv xorb. +Definition not_bv {a} `{Bitvector a} (v : a) : a := of_bits (not_bits (bits_of v)). + +(*val arith_op_bv : forall 'a 'b. Bitvector 'a, Bitvector 'b => + (integer -> integer -> integer) -> bool -> integer -> 'a -> 'a -> 'b*) +Definition arith_op_bv op (sign : bool) size (l : a) (r : a) : b := + let (l',r') := (int_of_bv sign l, int_of_bv sign r) in + let n := op l' r' in + of_int (size * length l) n. + + +Definition add_bv := arith_op_bv Zplus false 1. +Definition sadd_bv := arith_op_bv Zplus true 1. +Definition sub_bv := arith_op_bv Zminus false 1. +Definition mult_bv := arith_op_bv Zmult false 2. +Definition smult_bv := arith_op_bv Zmult true 2. +(* +Definition inline add_mword := Machine_word.plus +Definition inline sub_mword := Machine_word.minus +val mult_mword : forall 'a 'b. Size 'b => mword 'a -> mword 'a -> mword 'b +Definition mult_mword l r := times (zeroExtend l) (zeroExtend r) + +val arith_op_bv_int : forall 'a 'b. Bitvector 'a, Bitvector 'b => + (integer -> integer -> integer) -> bool -> integer -> 'a -> integer -> 'b*) +Definition arith_op_bv_int (op : Z -> Z -> Z) (sign : bool) (size : Z) (l : a) (r : Z) : b := + let l' := int_of_bv sign l in + let n := op l' r in + of_int (size * length l) n. + +Definition add_bv_int := arith_op_bv_int Zplus false 1. +Definition sadd_bv_int := arith_op_bv_int Zplus true 1. +Definition sub_bv_int := arith_op_bv_int Zminus false 1. +Definition mult_bv_int := arith_op_bv_int Zmult false 2. +Definition smult_bv_int := arith_op_bv_int Zmult true 2. + +(*val arith_op_int_bv : forall 'a 'b. Bitvector 'a, Bitvector 'b => + (integer -> integer -> integer) -> bool -> integer -> integer -> 'a -> 'b +Definition arith_op_int_bv op sign size l r := + let r' = int_of_bv sign r in + let n = op l r' in + of_int (size * length r) n + +Definition add_int_bv = arith_op_int_bv integerAdd false 1 +Definition sadd_int_bv = arith_op_int_bv integerAdd true 1 +Definition sub_int_bv = arith_op_int_bv integerMinus false 1 +Definition mult_int_bv = arith_op_int_bv integerMult false 2 +Definition smult_int_bv = arith_op_int_bv integerMult true 2 + +Definition arith_op_bv_bit op sign (size : integer) l r := + let l' = int_of_bv sign l in + let n = op l' (match r with | B1 -> (1 : integer) | _ -> 0 end) in + of_int (size * length l) n + +Definition add_bv_bit := arith_op_bv_bit integerAdd false 1 +Definition sadd_bv_bit := arith_op_bv_bit integerAdd true 1 +Definition sub_bv_bit := arith_op_bv_bit integerMinus true 1 + +val arith_op_overflow_bv : forall 'a 'b. Bitvector 'a, Bitvector 'b => + (integer -> integer -> integer) -> bool -> integer -> 'a -> 'a -> ('b * bitU * bitU) +Definition arith_op_overflow_bv op sign size l r := + let len := length l in + let act_size := len * size in + let (l_sign,r_sign) := (int_of_bv sign l,int_of_bv sign r) in + let (l_unsign,r_unsign) := (int_of_bv false l,int_of_bv false r) in + let n := op l_sign r_sign in + let n_unsign := op l_unsign r_unsign in + let correct_size := of_int act_size n in + let one_more_size_u := bits_of_int (act_size + 1) n_unsign in + let overflow := + if n <= get_max_representable_in sign len && + n >= get_min_representable_in sign len + then B0 else B1 in + let c_out := most_significant one_more_size_u in + (correct_size,overflow,c_out) + +Definition add_overflow_bv := arith_op_overflow_bv integerAdd false 1 +Definition add_overflow_bv_signed := arith_op_overflow_bv integerAdd true 1 +Definition sub_overflow_bv := arith_op_overflow_bv integerMinus false 1 +Definition sub_overflow_bv_signed := arith_op_overflow_bv integerMinus true 1 +Definition mult_overflow_bv := arith_op_overflow_bv integerMult false 2 +Definition mult_overflow_bv_signed := arith_op_overflow_bv integerMult true 2 + +val arith_op_overflow_bv_bit : forall 'a 'b. Bitvector 'a, Bitvector 'b => + (integer -> integer -> integer) -> bool -> integer -> 'a -> bitU -> ('b * bitU * bitU) +Definition arith_op_overflow_bv_bit op sign size l r_bit := + let act_size := length l * size in + let l' := int_of_bv sign l in + let l_u := int_of_bv false l in + let (n,nu,changed) := match r_bit with + | B1 -> (op l' 1, op l_u 1, true) + | B0 -> (l',l_u,false) + | BU -> failwith "arith_op_overflow_bv_bit applied to undefined bit" + end in + let correct_size := of_int act_size n in + let one_larger := bits_of_int (act_size + 1) nu in + let overflow := + if changed + then + if n <= get_max_representable_in sign act_size && n >= get_min_representable_in sign act_size + then B0 else B1 + else B0 in + (correct_size,overflow,most_significant one_larger) + +Definition add_overflow_bv_bit := arith_op_overflow_bv_bit integerAdd false 1 +Definition add_overflow_bv_bit_signed := arith_op_overflow_bv_bit integerAdd true 1 +Definition sub_overflow_bv_bit := arith_op_overflow_bv_bit integerMinus false 1 +Definition sub_overflow_bv_bit_signed := arith_op_overflow_bv_bit integerMinus true 1 + +type shift := LL_shift | RR_shift | RR_shift_arith | LL_rot | RR_rot + +val shift_op_bv : forall 'a. Bitvector 'a => shift -> 'a -> integer -> 'a +Definition shift_op_bv op v n := + match op with + | LL_shift -> + of_bits (get_bits true v n (length v - 1) ++ repeat [B0] n) + | RR_shift -> + of_bits (repeat [B0] n ++ get_bits true v 0 (length v - n - 1)) + | RR_shift_arith -> + of_bits (repeat [most_significant v] n ++ get_bits true v 0 (length v - n - 1)) + | LL_rot -> + of_bits (get_bits true v n (length v - 1) ++ get_bits true v 0 (n - 1)) + | RR_rot -> + of_bits (get_bits false v 0 (n - 1) ++ get_bits false v n (length v - 1)) + end + +Definition shiftl_bv := shift_op_bv LL_shift (*"<<"*) +Definition shiftr_bv := shift_op_bv RR_shift (*">>"*) +Definition arith_shiftr_bv := shift_op_bv RR_shift_arith +Definition rotl_bv := shift_op_bv LL_rot (*"<<<"*) +Definition rotr_bv := shift_op_bv LL_rot (*">>>"*) + +Definition shiftl_mword w n := Machine_word.shiftLeft w (natFromInteger n) +Definition shiftr_mword w n := Machine_word.shiftRight w (natFromInteger n) +Definition rotl_mword w n := Machine_word.rotateLeft (natFromInteger n) w +Definition rotr_mword w n := Machine_word.rotateRight (natFromInteger n) w + +Definition rec arith_op_no0 (op : integer -> integer -> integer) l r := + if r = 0 + then Nothing + else Just (op l r) + +val arith_op_bv_no0 : forall 'a 'b. Bitvector 'a, Bitvector 'b => + (integer -> integer -> integer) -> bool -> integer -> 'a -> 'a -> 'b +Definition arith_op_bv_no0 op sign size l r := + let act_size := length l * size in + let (l',r') := (int_of_bv sign l,int_of_bv sign r) in + let n := arith_op_no0 op l' r' in + let (representable,n') := + match n with + | Just n' -> + (n' <= get_max_representable_in sign act_size && + n' >= get_min_representable_in sign act_size, n') + | _ -> (false,0) + end in + if representable then (of_int act_size n') else (of_bits (repeat [BU] act_size)) + +Definition mod_bv := arith_op_bv_no0 hardware_mod false 1 +Definition quot_bv := arith_op_bv_no0 hardware_quot false 1 +Definition quot_bv_signed := arith_op_bv_no0 hardware_quot true 1 + +Definition mod_mword := Machine_word.modulo +Definition quot_mword := Machine_word.unsignedDivide +Definition quot_mword_signed := Machine_word.signedDivide + +Definition arith_op_bv_int_no0 op sign size l r := + arith_op_bv_no0 op sign size l (of_int (length l) r) + +Definition quot_bv_int := arith_op_bv_int_no0 hardware_quot false 1 +Definition mod_bv_int := arith_op_bv_int_no0 hardware_mod false 1 +*) +Definition replicate_bits_bv {a b} `{Bitvector a} `{Bitvector b} (v : a) count : b := of_bits (repeat (bits_of v) count). +Import List. +Import ListNotations. +Definition duplicate_bit_bv {a} `{Bitvector a} bit len : a := replicate_bits_bv [bit] len. + +(*val eq_bv : forall 'a. Bitvector 'a => 'a -> 'a -> bool*) +Definition eq_bv {A} `{Bitvector A} (l : A) r := (unsigned l =? unsigned r). + +(*val neq_bv : forall 'a. Bitvector 'a => 'a -> 'a -> bool*) +Definition neq_bv (l : a) (r :a) : bool := (negb (unsigned l =? unsigned r)). +(* +val ucmp_bv : forall 'a. Bitvector 'a => (integer -> integer -> bool) -> 'a -> 'a -> bool +Definition ucmp_bv cmp l r := cmp (unsigned l) (unsigned r) + +val scmp_bv : forall 'a. Bitvector 'a => (integer -> integer -> bool) -> 'a -> 'a -> bool +Definition scmp_bv cmp l r := cmp (signed l) (signed r) + +Definition ult_bv := ucmp_bv (<) +Definition slt_bv := scmp_bv (<) +Definition ugt_bv := ucmp_bv (>) +Definition sgt_bv := scmp_bv (>) +Definition ulteq_bv := ucmp_bv (<=) +Definition slteq_bv := scmp_bv (<=) +Definition ugteq_bv := ucmp_bv (>=) +Definition sgteq_bv := scmp_bv (>=) +*) + +End Bitvectors. diff --git a/lib/coq/Sail_operators_bitlists.v b/lib/coq/Sail_operators_bitlists.v new file mode 100644 index 00000000..51c3e972 --- /dev/null +++ b/lib/coq/Sail_operators_bitlists.v @@ -0,0 +1,182 @@ +Require Import Sail_values. +Require Import Sail_operators. + +(* + +(* Specialisation of operators to bit lists *) + +val access_vec_inc : list bitU -> integer -> bitU +let access_vec_inc = access_bv_inc + +val access_vec_dec : list bitU -> integer -> bitU +let access_vec_dec = access_bv_dec + +val update_vec_inc : list bitU -> integer -> bitU -> list bitU +let update_vec_inc = update_bv_inc + +val update_vec_dec : list bitU -> integer -> bitU -> list bitU +let update_vec_dec = update_bv_dec + +val subrange_vec_inc : list bitU -> integer -> integer -> list bitU +let subrange_vec_inc = subrange_bv_inc + +val subrange_vec_dec : list bitU -> integer -> integer -> list bitU +let subrange_vec_dec = subrange_bv_dec + +val update_subrange_vec_inc : list bitU -> integer -> integer -> list bitU -> list bitU +let update_subrange_vec_inc = update_subrange_bv_inc + +val update_subrange_vec_dec : list bitU -> integer -> integer -> list bitU -> list bitU +let update_subrange_vec_dec = update_subrange_bv_dec + +val extz_vec : integer -> list bitU -> list bitU +let extz_vec = extz_bv + +val exts_vec : integer -> list bitU -> list bitU +let exts_vec = exts_bv + +val concat_vec : list bitU -> list bitU -> list bitU +let concat_vec = concat_bv + +val cons_vec : bitU -> list bitU -> list bitU +let cons_vec = cons_bv + +val bool_of_vec : mword ty1 -> bitU +let bool_of_vec = bool_of_bv + +val cast_unit_vec : bitU -> mword ty1 +let cast_unit_vec = cast_unit_bv + +val vec_of_bit : integer -> bitU -> list bitU +let vec_of_bit = bv_of_bit + +val msb : list bitU -> bitU +let msb = most_significant + +val int_of_vec : bool -> list bitU -> integer +let int_of_vec = int_of_bv + +val string_of_vec : list bitU -> string +let string_of_vec = string_of_bv + +val and_vec : list bitU -> list bitU -> list bitU +val or_vec : list bitU -> list bitU -> list bitU +val xor_vec : list bitU -> list bitU -> list bitU +val not_vec : list bitU -> list bitU +let and_vec = and_bv +let or_vec = or_bv +let xor_vec = xor_bv +let not_vec = not_bv + +val add_vec : list bitU -> list bitU -> list bitU +val sadd_vec : list bitU -> list bitU -> list bitU +val sub_vec : list bitU -> list bitU -> list bitU +val mult_vec : list bitU -> list bitU -> list bitU +val smult_vec : list bitU -> list bitU -> list bitU +let add_vec = add_bv +let sadd_vec = sadd_bv +let sub_vec = sub_bv +let mult_vec = mult_bv +let smult_vec = smult_bv + +val add_vec_int : list bitU -> integer -> list bitU +val sadd_vec_int : list bitU -> integer -> list bitU +val sub_vec_int : list bitU -> integer -> list bitU +val mult_vec_int : list bitU -> integer -> list bitU +val smult_vec_int : list bitU -> integer -> list bitU +let add_vec_int = add_bv_int +let sadd_vec_int = sadd_bv_int +let sub_vec_int = sub_bv_int +let mult_vec_int = mult_bv_int +let smult_vec_int = smult_bv_int + +val add_int_vec : integer -> list bitU -> list bitU +val sadd_int_vec : integer -> list bitU -> list bitU +val sub_int_vec : integer -> list bitU -> list bitU +val mult_int_vec : integer -> list bitU -> list bitU +val smult_int_vec : integer -> list bitU -> list bitU +let add_int_vec = add_int_bv +let sadd_int_vec = sadd_int_bv +let sub_int_vec = sub_int_bv +let mult_int_vec = mult_int_bv +let smult_int_vec = smult_int_bv + +val add_vec_bit : list bitU -> bitU -> list bitU +val sadd_vec_bit : list bitU -> bitU -> list bitU +val sub_vec_bit : list bitU -> bitU -> list bitU +let add_vec_bit = add_bv_bit +let sadd_vec_bit = sadd_bv_bit +let sub_vec_bit = sub_bv_bit + +val add_overflow_vec : list bitU -> list bitU -> (list bitU * bitU * bitU) +val add_overflow_vec_signed : list bitU -> list bitU -> (list bitU * bitU * bitU) +val sub_overflow_vec : list bitU -> list bitU -> (list bitU * bitU * bitU) +val sub_overflow_vec_signed : list bitU -> list bitU -> (list bitU * bitU * bitU) +val mult_overflow_vec : list bitU -> list bitU -> (list bitU * bitU * bitU) +val mult_overflow_vec_signed : list bitU -> list bitU -> (list bitU * bitU * bitU) +let add_overflow_vec = add_overflow_bv +let add_overflow_vec_signed = add_overflow_bv_signed +let sub_overflow_vec = sub_overflow_bv +let sub_overflow_vec_signed = sub_overflow_bv_signed +let mult_overflow_vec = mult_overflow_bv +let mult_overflow_vec_signed = mult_overflow_bv_signed + +val add_overflow_vec_bit : list bitU -> bitU -> (list bitU * bitU * bitU) +val add_overflow_vec_bit_signed : list bitU -> bitU -> (list bitU * bitU * bitU) +val sub_overflow_vec_bit : list bitU -> bitU -> (list bitU * bitU * bitU) +val sub_overflow_vec_bit_signed : list bitU -> bitU -> (list bitU * bitU * bitU) +let add_overflow_vec_bit = add_overflow_bv_bit +let add_overflow_vec_bit_signed = add_overflow_bv_bit_signed +let sub_overflow_vec_bit = sub_overflow_bv_bit +let sub_overflow_vec_bit_signed = sub_overflow_bv_bit_signed + +val shiftl : list bitU -> integer -> list bitU +val shiftr : list bitU -> integer -> list bitU +val arith_shiftr : list bitU -> integer -> list bitU +val rotl : list bitU -> integer -> list bitU +val rotr : list bitU -> integer -> list bitU +let shiftl = shiftl_bv +let shiftr = shiftr_bv +let arith_shiftr = arith_shiftr_bv +let rotl = rotl_bv +let rotr = rotr_bv + +val mod_vec : list bitU -> list bitU -> list bitU +val quot_vec : list bitU -> list bitU -> list bitU +val quot_vec_signed : list bitU -> list bitU -> list bitU +let mod_vec = mod_bv +let quot_vec = quot_bv +let quot_vec_signed = quot_bv_signed + +val mod_vec_int : list bitU -> integer -> list bitU +val quot_vec_int : list bitU -> integer -> list bitU +let mod_vec_int = mod_bv_int +let quot_vec_int = quot_bv_int + +val replicate_bits : list bitU -> integer -> list bitU +let replicate_bits = replicate_bits_bv + +val duplicate : bitU -> integer -> list bitU +let duplicate = duplicate_bit_bv + +val eq_vec : list bitU -> list bitU -> bool +val neq_vec : list bitU -> list bitU -> bool +val ult_vec : list bitU -> list bitU -> bool +val slt_vec : list bitU -> list bitU -> bool +val ugt_vec : list bitU -> list bitU -> bool +val sgt_vec : list bitU -> list bitU -> bool +val ulteq_vec : list bitU -> list bitU -> bool +val slteq_vec : list bitU -> list bitU -> bool +val ugteq_vec : list bitU -> list bitU -> bool +val sgteq_vec : list bitU -> list bitU -> bool +let eq_vec = eq_bv +let neq_vec = neq_bv +let ult_vec = ult_bv +let slt_vec = slt_bv +let ugt_vec = ugt_bv +let sgt_vec = sgt_bv +let ulteq_vec = ulteq_bv +let slteq_vec = slteq_bv +let ugteq_vec = ugteq_bv +let sgteq_vec = sgteq_bv +*) diff --git a/lib/coq/Sail_operators_mwords.v b/lib/coq/Sail_operators_mwords.v new file mode 100644 index 00000000..d32a7dbf --- /dev/null +++ b/lib/coq/Sail_operators_mwords.v @@ -0,0 +1,248 @@ +Require Import Sail_values. +Require Import Sail_operators. +Require bbv.Word. +Require Import Arith. +Require Import Omega. + +Definition cast_mword {m n} (x : mword m) (eq : m = n) : mword n. +rewrite <- eq. +exact x. +Defined. + +Definition cast_word {m n} (x : Word.word m) (eq : m = n) : Word.word n. +rewrite <- eq. +exact x. +Defined. + +Definition mword_of_nat {m} (x : Word.word m) : mword (Z.of_nat m). +destruct m. +- exact x. +- simpl. rewrite SuccNat2Pos.id_succ. exact x. +Defined. + +Definition cast_to_mword {m n} (x : Word.word m) (eq : Z.of_nat m = n) : mword n. +destruct n. +- constructor. +- rewrite <- eq. exact (mword_of_nat x). +- exfalso. destruct m; simpl in *; congruence. +Defined. + +(* +(* Specialisation of operators to machine words *) + +val access_vec_inc : forall 'a. Size 'a => mword 'a -> integer -> bitU*) +Definition access_vec_inc {a} : mword a -> Z -> bitU := access_mword_inc. + +(*val access_vec_dec : forall 'a. Size 'a => mword 'a -> integer -> bitU*) +Definition access_vec_dec {a} : mword a -> Z -> bitU := access_mword_dec. + +(*val update_vec_inc : forall 'a. Size 'a => mword 'a -> integer -> bitU -> mword 'a*) +Definition update_vec_inc {a} : mword a -> Z -> bitU -> mword a := update_mword_inc. + +(*val update_vec_dec : forall 'a. Size 'a => mword 'a -> integer -> bitU -> mword 'a*) +Definition update_vec_dec {a} : mword a -> Z -> bitU -> mword a := update_mword_dec. + +(*val subrange_vec_inc : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> integer -> mword 'b*) +Definition subrange_vec_inc {a b} `{ArithFact (b >= 0)} (w: mword a) : Z -> Z -> mword b := subrange_bv_inc w. + +(*val subrange_vec_dec : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> integer -> mword 'b*) +Definition subrange_vec_dec {n m} `{ArithFact (m >= 0)} (w : mword n) : Z -> Z -> mword m := subrange_bv_dec w. + +(*val update_subrange_vec_inc : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> integer -> mword 'b -> mword 'a*) +Definition update_subrange_vec_inc {a b} (v : mword a) i j (w : mword b) : mword a := update_subrange_bv_inc v i j w. + +(*val update_subrange_vec_dec : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> integer -> mword 'b -> mword 'a*) +Definition update_subrange_vec_dec {a b} (v : mword a) i j (w : mword b) : mword a := update_subrange_bv_dec v i j w. + +Lemma mword_nonneg {a} : mword a -> a >= 0. +destruct a; +auto using Z.le_ge, Zle_0_pos with zarith. +destruct 1. +Qed. + +(*val extz_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b*) +Definition extz_vec {a b} `{ArithFact (b >= 0)} `{ArithFact (b >= a)} (n : Z) (v : mword a) : mword b. +refine (cast_to_mword (Word.zext (get_word v) (Z.to_nat (b - a))) _). +unwrap_ArithFacts. +assert (a >= 0). { apply mword_nonneg. assumption. } +rewrite <- Z2Nat.inj_add; try omega. +rewrite Zplus_minus. +apply Z2Nat.id. +auto with zarith. +Defined. + +(*val exts_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b*) +Definition exts_vec {a b} `{ArithFact (b >= 0)} (n : Z) (v : mword a) : mword b := exts_bv n v. + +Definition zero_extend {a} (v : mword a) (n : Z) `{ArithFact (n >= a)} : mword n := extz_vec n v. + +Definition sign_extend {a} (v : mword a) (n : Z) `{ArithFact (n >= a)} : mword n := exts_vec n v. + +Lemma truncate_eq {m n} : m >= 0 -> m <= n -> (Z.to_nat n = Z.to_nat m + (Z.to_nat n - Z.to_nat m))%nat. +intros. +assert ((Z.to_nat m <= Z.to_nat n)%nat). +{ apply Z2Nat.inj_le; omega. } +omega. +Qed. + +Definition vector_truncate {n} (v : mword n) (m : Z) `{ArithFact (m >= 0)} `{ArithFact (m <= n)} : mword m := + cast_to_mword (Word.split1 _ _ (cast_word (get_word v) (ltac:(unwrap_ArithFacts; apply truncate_eq; auto) : Z.to_nat n = Z.to_nat m + (Z.to_nat n - Z.to_nat m))%nat)) (ltac:(unwrap_ArithFacts; apply Z2Nat.id; omega) : Z.of_nat (Z.to_nat m) = m). + +Lemma concat_eq {a b} : a >= 0 -> b >= 0 -> Z.of_nat (Z.to_nat b + Z.to_nat a)%nat = a + b. +intros. +rewrite Nat2Z.inj_add. +rewrite Z2Nat.id; auto with zarith. +rewrite Z2Nat.id; auto with zarith. +Qed. + + +(*val concat_vec : forall 'a 'b 'c. Size 'a, Size 'b, Size 'c => mword 'a -> mword 'b -> mword 'c*) +Definition concat_vec {a b} (v : mword a) (w : mword b) : mword (a + b) := + cast_to_mword (Word.combine (get_word w) (get_word v)) (ltac:(solve [auto using concat_eq, mword_nonneg with zarith]) : Z.of_nat (Z.to_nat b + Z.to_nat a)%nat = a + b). + +(*val cons_vec : forall 'a 'b 'c. Size 'a, Size 'b => bitU -> mword 'a -> mword 'b*) +(*Definition cons_vec {a b} : bitU -> mword a -> mword b := cons_bv.*) + +(*val bool_of_vec : mword ty1 -> bitU +Definition bool_of_vec := bool_of_bv + +val cast_unit_vec : bitU -> mword ty1 +Definition cast_unit_vec := cast_unit_bv + +val vec_of_bit : forall 'a. Size 'a => integer -> bitU -> mword 'a +Definition vec_of_bit := bv_of_bit*) + +Definition vec_of_bits {a} `{ArithFact (a >= 0)} (l:list bitU) : mword a := of_bits l. +(* + +val msb : forall 'a. Size 'a => mword 'a -> bitU +Definition msb := most_significant + +val int_of_vec : forall 'a. Size 'a => bool -> mword 'a -> integer +Definition int_of_vec := int_of_bv + +val string_of_vec : forall 'a. Size 'a => mword 'a -> string +Definition string_of_vec := string_of_bv + +val and_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val or_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val xor_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val not_vec : forall 'a. Size 'a => mword 'a -> mword 'a*) +Definition and_vec {n} (w : mword n) : mword n -> mword n := and_bv w. +Definition or_vec {n} (w : mword n) : mword n -> mword n := or_bv w. +Definition xor_vec {n} (w : mword n) : mword n -> mword n := xor_bv w. +Definition not_vec {n} (w : mword n) : mword n := not_bv w. + +(*val add_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val sadd_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val sub_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val mult_vec : forall 'a 'b. Size 'a, Size 'b => mword 'a -> mword 'a -> mword 'b +val smult_vec : forall 'a 'b. Size 'a, Size 'b => mword 'a -> mword 'a -> mword 'b*) +Definition add_vec {n} (w : mword n) : mword n -> mword n := add_bv w. +Definition sadd_vec {n} (w : mword n) : mword n -> mword n := sadd_bv w. +Definition sub_vec {n} (w : mword n) : mword n -> mword n := sub_bv w. +Definition mult_vec {n} (w : mword n) : mword n -> mword n := mult_bv w. +Definition smult_vec {n} (w : mword n) : mword n -> mword n := smult_bv w. + +(*val add_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val sadd_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val sub_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val mult_vec_int : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b +val smult_vec_int : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b*) +Definition add_vec_int {a} (w : mword a) : Z -> mword a := add_bv_int w. +Definition sadd_vec_int {a} (w : mword a) : Z -> mword a := sadd_bv_int w. +Definition sub_vec_int {a} (w : mword a) : Z -> mword a := sub_bv_int w. +(*Definition mult_vec_int {a b} : mword a -> Z -> mword b := mult_bv_int. +Definition smult_vec_int {a b} : mword a -> Z -> mword b := smult_bv_int.*) + +(*val add_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a +val sadd_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a +val sub_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a +val mult_int_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b +val smult_int_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b +Definition add_int_vec := add_int_bv +Definition sadd_int_vec := sadd_int_bv +Definition sub_int_vec := sub_int_bv +Definition mult_int_vec := mult_int_bv +Definition smult_int_vec := smult_int_bv + +val add_vec_bit : forall 'a. Size 'a => mword 'a -> bitU -> mword 'a +val sadd_vec_bit : forall 'a. Size 'a => mword 'a -> bitU -> mword 'a +val sub_vec_bit : forall 'a. Size 'a => mword 'a -> bitU -> mword 'a +Definition add_vec_bit := add_bv_bit +Definition sadd_vec_bit := sadd_bv_bit +Definition sub_vec_bit := sub_bv_bit + +val add_overflow_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +val add_overflow_vec_signed : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +val sub_overflow_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +val sub_overflow_vec_signed : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +val mult_overflow_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +val mult_overflow_vec_signed : forall 'a. Size 'a => mword 'a -> mword 'a -> (mword 'a * bitU * bitU) +Definition add_overflow_vec := add_overflow_bv +Definition add_overflow_vec_signed := add_overflow_bv_signed +Definition sub_overflow_vec := sub_overflow_bv +Definition sub_overflow_vec_signed := sub_overflow_bv_signed +Definition mult_overflow_vec := mult_overflow_bv +Definition mult_overflow_vec_signed := mult_overflow_bv_signed + +val add_overflow_vec_bit : forall 'a. Size 'a => mword 'a -> bitU -> (mword 'a * bitU * bitU) +val add_overflow_vec_bit_signed : forall 'a. Size 'a => mword 'a -> bitU -> (mword 'a * bitU * bitU) +val sub_overflow_vec_bit : forall 'a. Size 'a => mword 'a -> bitU -> (mword 'a * bitU * bitU) +val sub_overflow_vec_bit_signed : forall 'a. Size 'a => mword 'a -> bitU -> (mword 'a * bitU * bitU) +Definition add_overflow_vec_bit := add_overflow_bv_bit +Definition add_overflow_vec_bit_signed := add_overflow_bv_bit_signed +Definition sub_overflow_vec_bit := sub_overflow_bv_bit +Definition sub_overflow_vec_bit_signed := sub_overflow_bv_bit_signed + +val shiftl : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val shiftr : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val arith_shiftr : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val rotl : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val rotr : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +Definition shiftl := shiftl_bv +Definition shiftr := shiftr_bv +Definition arith_shiftr := arith_shiftr_bv +Definition rotl := rotl_bv +Definition rotr := rotr_bv + +val mod_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val quot_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +val quot_vec_signed : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a +Definition mod_vec := mod_bv +Definition quot_vec := quot_bv +Definition quot_vec_signed := quot_bv_signed + +val mod_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +val quot_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a +Definition mod_vec_int := mod_bv_int +Definition quot_vec_int := quot_bv_int + +val replicate_bits : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b*) +Definition replicate_bits {a b} `{ArithFact (b >= 0)} (w : mword a) : Z -> mword b := replicate_bits_bv w. + +(*val duplicate : forall 'a. Size 'a => bitU -> integer -> mword 'a +Definition duplicate := duplicate_bit_bv + +val eq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val neq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val ult_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val slt_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val ugt_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val sgt_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val ulteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val slteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val ugteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool +val sgteq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool*) +Definition eq_vec {n} (w : mword n) : mword n -> bool := eq_bv w. +Definition neq_vec {n} (w : mword n) : mword n -> bool := neq_bv w. +(*Definition ult_vec := ult_bv. +Definition slt_vec := slt_bv. +Definition ugt_vec := ugt_bv. +Definition sgt_vec := sgt_bv. +Definition ulteq_vec := ulteq_bv. +Definition slteq_vec := slteq_bv. +Definition ugteq_vec := ugteq_bv. +Definition sgteq_vec := sgteq_bv. + +*) diff --git a/lib/coq/Sail_values.v b/lib/coq/Sail_values.v new file mode 100644 index 00000000..ee3a90bb --- /dev/null +++ b/lib/coq/Sail_values.v @@ -0,0 +1,1167 @@ +(* Version of sail_values.lem that uses Lems machine words library *) + +(*Require Import Sail_impl_base*) +Require Import ZArith. +Require Export ZArith. +Require Import String. +Require Import bbv.Word. +Require Import List. +Import ListNotations. + +Open Scope Z. + +(* Constraint solving basics. A HintDb which unfolding hints and lemmata + can be added to, and a typeclass to wrap constraint arguments in to + trigger automatic solving. *) +Create HintDb sail. +Class ArithFact (P : Prop) := { fact : P }. +Lemma use_ArithFact {P} `(ArithFact P) : P. +apply fact. +Defined. + +Definition build_ex (n:Z) {P:Z -> Prop} `{H:ArithFact (P n)} : {x : Z & ArithFact (P x)} := + existT _ n H. + + +Definition ii := Z. +Definition nn := nat. + +(* +val pow : Z -> Z -> Z +Definition pow m n := m ** (Z.to_nat n) + +Definition pow2 n := pow 2 n + +Definition inline lt := (<) +Definition inline gt := (>) +Definition inline lteq := (<=) +Definition inline gteq := (>=) + +val eq : forall a. Eq a => a -> a -> bool +Definition inline eq l r := (l = r) + +val neq : forall a. Eq a => a -> a -> bool*) +Definition neq l r := (negb (l =? r)). (* Z only *) + +(*let add_int l r := integerAdd l r +Definition add_signed l r := integerAdd l r +Definition sub_int l r := integerMinus l r +Definition mult_int l r := integerMult l r +Definition div_int l r := integerDiv l r +Definition div_nat l r := natDiv l r +Definition power_int_nat l r := integerPow l r +Definition power_int_int l r := integerPow l (Z.to_nat r) +Definition negate_int i := integerNegate i +Definition min_int l r := integerMin l r +Definition max_int l r := integerMax l r + +Definition add_real l r := realAdd l r +Definition sub_real l r := realMinus l r +Definition mult_real l r := realMult l r +Definition div_real l r := realDiv l r +Definition negate_real r := realNegate r +Definition abs_real r := realAbs r +Definition power_real b e := realPowInteger b e*) + +Definition print_int (_ : string) (_ : Z) : unit := tt. + +(* +Definition or_bool l r := (l || r) +Definition and_bool l r := (l && r) +Definition xor_bool l r := xor l r +*) +Definition append_list {A:Type} (l : list A) r := l ++ r. +Definition length_list {A:Type} (xs : list A) := Z.of_nat (List.length xs). +Definition take_list {A:Type} n (xs : list A) := firstn (Z.to_nat n) xs. +Definition drop_list {A:Type} n (xs : list A) := skipn (Z.to_nat n) xs. +(* +val repeat : forall a. list a -> Z -> list a*) +Fixpoint repeat' {a} (xs : list a) n := + match n with + | O => [] + | S n => xs ++ repeat' xs n + end. +Definition repeat {a} (xs : list a) (n : Z) := + if n <=? 0 then [] + else repeat' xs (Z.to_nat n). +(*declare {isabelle} termination_argument repeat = automatic + +Definition duplicate_to_list bit length := repeat [bit] length + +Fixpoint replace bs (n : Z) b' := match bs with + | [] => [] + | b :: bs => + if n = 0 then b' :: bs + else b :: replace bs (n - 1) b' + end +declare {isabelle} termination_argument replace = automatic + +Definition upper n := n + +(* Modulus operation corresponding to quot below -- result + has sign of dividend. *) +Definition hardware_mod (a: Z) (b:Z) : Z := + let m := (abs a) mod (abs b) in + if a < 0 then ~m else m + +(* There are different possible answers for integer divide regarding +rounding behaviour on negative operands. Positive operands always +round down so derive the one we want (trucation towards zero) from +that *) +Definition hardware_quot (a:Z) (b:Z) : Z := + let q := (abs a) / (abs b) in + if ((a<0) = (b<0)) then + q (* same sign -- result positive *) + else + ~q (* different sign -- result negative *) + +Definition max_64u := (integerPow 2 64) - 1 +Definition max_64 := (integerPow 2 63) - 1 +Definition min_64 := 0 - (integerPow 2 63) +Definition max_32u := (4294967295 : Z) +Definition max_32 := (2147483647 : Z) +Definition min_32 := (0 - 2147483648 : Z) +Definition max_8 := (127 : Z) +Definition min_8 := (0 - 128 : Z) +Definition max_5 := (31 : Z) +Definition min_5 := (0 - 32 : Z) +*) +(*** Bits *) +Inductive bitU := B0 | B1 | BU. + +Definition showBitU b := +match b with + | B0 => "O" + | B1 => "I" + | BU => "U" +end%string. + +(*instance (Show bitU) + let show := showBitU +end*) + +Class BitU (a : Type) : Type := { + to_bitU : a -> bitU; + of_bitU : bitU -> a +}. + +Instance bitU_BitU : (BitU bitU) := { + to_bitU b := b; + of_bitU b := b +}. + +(* TODO: consider alternatives *) +Parameter undefined_BU : bool. + +Definition bool_of_bitU bu := match bu with + | B0 => false + | B1 => true + | BU => undefined_BU (*failwith "bool_of_bitU applied to BU"*) + end. + +Definition bitU_of_bool (b : bool) := if b then B1 else B0. + +Instance bool_BitU : (BitU bool) := { + to_bitU := bitU_of_bool; + of_bitU := bool_of_bitU +}. + +Definition cast_bit_bool := bool_of_bitU. +(* +Definition bit_lifted_of_bitU bu := match bu with + | B0 => Bitl_zero + | B1 => Bitl_one + | BU => Bitl_undef + end. + +Definition bitU_of_bit := function + | Bitc_zero => B0 + | Bitc_one => B1 + end. + +Definition bit_of_bitU := function + | B0 => Bitc_zero + | B1 => Bitc_one + | BU => failwith "bit_of_bitU: BU" + end. + +Definition bitU_of_bit_lifted := function + | Bitl_zero => B0 + | Bitl_one => B1 + | Bitl_undef => BU + | Bitl_unknown => failwith "bitU_of_bit_lifted Bitl_unknown" + end. +*) +Definition not_bit b := +match b with + | B1 => B0 + | B0 => B1 + | BU => BU + end. + +(*val is_one : Z -> bitU*) +Definition is_one (i : Z) := + if i =? 1 then B1 else B0. + +Definition binop_bit op x y := + match (x, y) with + | (BU,_) => BU (*Do we want to do this or to respect | of I and & of B0 rules?*) + | (_,BU) => BU (*Do we want to do this or to respect | of I and & of B0 rules?*) + | (x,y) => bitU_of_bool (op (bool_of_bitU x) (bool_of_bitU y)) + end. + +(*val and_bit : bitU -> bitU -> bitU +Definition and_bit := binop_bit (&&) + +val or_bit : bitU -> bitU -> bitU +Definition or_bit := binop_bit (||) + +val xor_bit : bitU -> bitU -> bitU +Definition xor_bit := binop_bit xor + +val (&.) : bitU -> bitU -> bitU +Definition inline (&.) x y := and_bit x y + +val (|.) : bitU -> bitU -> bitU +Definition inline (|.) x y := or_bit x y + +val (+.) : bitU -> bitU -> bitU +Definition inline (+.) x y := xor_bit x y +*) + +(*** Bit lists ***) + +(*val bits_of_nat_aux : natural -> list bitU*) +Fixpoint bits_of_nat_aux n x := + match n,x with + | O,_ => [] + | _,O => [] + | S n, S _ => (if x mod 2 =? 1 then B1 else B0) :: bits_of_nat_aux n (x / 2) + end%nat. +(**declare {isabelle} termination_argument bits_of_nat_aux = automatic*) +Definition bits_of_nat n := List.rev (bits_of_nat_aux n n). + +(*val nat_of_bits_aux : natural -> list bitU -> natural*) +Fixpoint nat_of_bits_aux acc bs := match bs with + | [] => acc + | B1 :: bs => nat_of_bits_aux ((2 * acc) + 1) bs + | B0 :: bs => nat_of_bits_aux (2 * acc) bs + | BU :: bs => (*failwith "nat_of_bits_aux: bit list has undefined bits"*) + nat_of_bits_aux ((2 * acc) + if undefined_BU then 1 else 0) bs +end%nat. +(*declare {isabelle} termination_argument nat_of_bits_aux = automatic*) +Definition nat_of_bits bits := nat_of_bits_aux 0 bits. + +Definition not_bits := List.map not_bit. + +Definition binop_bits op bsl bsr := + List.fold_right (fun '(bl, br) acc => binop_bit op bl br :: acc) [] (List.combine bsl bsr). +(* +Definition and_bits := binop_bits (&&) +Definition or_bits := binop_bits (||) +Definition xor_bits := binop_bits xor + +val unsigned_of_bits : list bitU -> Z*) +Definition unsigned_of_bits bs := Z.of_nat (nat_of_bits bs). + +(*val signed_of_bits : list bitU -> Z*) +Parameter undefined_Z : Z. +Definition signed_of_bits bits := + match bits with + | B1 :: _ => 0 - (1 + (unsigned_of_bits (not_bits bits))) + | B0 :: _ => unsigned_of_bits bits + | BU :: _ => undefined_Z (*failwith "signed_of_bits applied to list with undefined bits"*) + | [] => undefined_Z (*failwith "signed_of_bits applied to empty list"*) + end. + +(*val pad_bitlist : bitU -> list bitU -> Z -> list bitU*) +Fixpoint pad_bitlist_nat (b : bitU) bits n := +match n with +| O => bits +| S n' => pad_bitlist_nat b (b :: bits) n' +end. +Definition pad_bitlist b bits n := pad_bitlist_nat b bits (Z.to_nat n). (* Negative n will come out as 0 *) +(* if n <= 0 then bits else pad_bitlist b (b :: bits) (n - 1). +declare {isabelle} termination_argument pad_bitlist = automatic*) + +Definition ext_bits pad len bits := + let longer := len - (Z.of_nat (List.length bits)) in + if longer <? 0 then skipn (Z.abs_nat longer) bits + else pad_bitlist pad bits longer. + +Definition extz_bits len bits := ext_bits B0 len bits. +Parameter undefined_list_bitU : list bitU. +Definition exts_bits len bits := + match bits with + | BU :: _ => undefined_list_bitU (*failwith "exts_bits: undefined bit"*) + | B1 :: _ => ext_bits B1 len bits + | _ => ext_bits B0 len bits + end. + +Fixpoint add_one_bit_ignore_overflow_aux bits := match bits with + | [] => [] + | B0 :: bits => B1 :: bits + | B1 :: bits => B0 :: add_one_bit_ignore_overflow_aux bits + | BU :: _ => undefined_list_bitU (*failwith "add_one_bit_ignore_overflow: undefined bit"*) +end. +(*declare {isabelle} termination_argument add_one_bit_ignore_overflow_aux = automatic*) + +Definition add_one_bit_ignore_overflow bits := + rev (add_one_bit_ignore_overflow_aux (rev bits)). + +Definition bitlist_of_int n := + let bits_abs := B0 :: bits_of_nat (Zabs_nat n) in + if n >=? 0 then bits_abs + else add_one_bit_ignore_overflow (not_bits bits_abs). + +Definition bits_of_int len n := exts_bits len (bitlist_of_int n). +(* +Definition char_of_nibble := function + | (B0, B0, B0, B0) => Some #'0' + | (B0, B0, B0, B1) => Some #'1' + | (B0, B0, B1, B0) => Some #'2' + | (B0, B0, B1, B1) => Some #'3' + | (B0, B1, B0, B0) => Some #'4' + | (B0, B1, B0, B1) => Some #'5' + | (B0, B1, B1, B0) => Some #'6' + | (B0, B1, B1, B1) => Some #'7' + | (B1, B0, B0, B0) => Some #'8' + | (B1, B0, B0, B1) => Some #'9' + | (B1, B0, B1, B0) => Some #'A' + | (B1, B0, B1, B1) => Some #'B' + | (B1, B1, B0, B0) => Some #'C' + | (B1, B1, B0, B1) => Some #'D' + | (B1, B1, B1, B0) => Some #'E' + | (B1, B1, B1, B1) => Some #'F' + | _ => None + end + +Fixpoint hexstring_of_bits bs := match bs with + | b1 :: b2 :: b3 :: b4 :: bs => + let n := char_of_nibble (b1, b2, b3, b4) in + let s := hexstring_of_bits bs in + match (n, s) with + | (Some n, Some s) => Some (n :: s) + | _ => None + end + | _ => None + end +declare {isabelle} termination_argument hexstring_of_bits = automatic + +Definition show_bitlist bs := + match hexstring_of_bits bs with + | Some s => toString (#'0' :: #x' :: s) + | None => show bs + end + +(*** List operations *) + +Definition inline (^^) := append_list + +val subrange_list_inc : forall a. list a -> Z -> Z -> list a*) +Definition subrange_list_inc {A} (xs : list A) i j := + let toJ := firstn (Z.to_nat j + 1) xs in + let fromItoJ := skipn (Z.to_nat i) toJ in + fromItoJ. + +(*val subrange_list_dec : forall a. list a -> Z -> Z -> list a*) +Definition subrange_list_dec {A} (xs : list A) i j := + let top := (length_list xs) - 1 in + subrange_list_inc xs (top - i) (top - j). + +(*val subrange_list : forall a. bool -> list a -> Z -> Z -> list a*) +Definition subrange_list {A} (is_inc : bool) (xs : list A) i j := + if is_inc then subrange_list_inc xs i j else subrange_list_dec xs i j. + +Definition splitAt {A} n (l : list A) := (firstn n l, skipn n l). + +(*val update_subrange_list_inc : forall a. list a -> Z -> Z -> list a -> list a*) +Definition update_subrange_list_inc {A} (xs : list A) i j xs' := + let (toJ,suffix) := splitAt (Z.to_nat j + 1) xs in + let (prefix,_fromItoJ) := splitAt (Z.to_nat i) toJ in + prefix ++ xs' ++ suffix. + +(*val update_subrange_list_dec : forall a. list a -> Z -> Z -> list a -> list a*) +Definition update_subrange_list_dec {A} (xs : list A) i j xs' := + let top := (length_list xs) - 1 in + update_subrange_list_inc xs (top - i) (top - j) xs'. + +(*val update_subrange_list : forall a. bool -> list a -> Z -> Z -> list a -> list a*) +Definition update_subrange_list {A} (is_inc : bool) (xs : list A) i j xs' := + if is_inc then update_subrange_list_inc xs i j xs' else update_subrange_list_dec xs i j xs'. + +Open Scope nat. +Fixpoint nth_in_range {A} (n:nat) (l:list A) : n < length l -> A. +refine + (match n, l with + | O, h::_ => fun _ => h + | S m, _::t => fun H => nth_in_range A m t _ + | _,_ => fun H => _ + end). +exfalso. inversion H. +exfalso. inversion H. +simpl in H. omega. +Defined. + +Lemma nth_in_range_is_nth : forall A n (l : list A) d (H : n < length l), + nth_in_range n l H = nth n l d. +intros until d. revert n. +induction l; intros n H. +* inversion H. +* destruct n. + + reflexivity. + + apply IHl. +Qed. + +Lemma nth_Z_nat {A} {n} {xs : list A} : + (0 <= n)%Z -> (n < length_list xs)%Z -> Z.to_nat n < length xs. +unfold length_list. +intros nonneg bounded. +rewrite Z2Nat.inj_lt in bounded; auto using Zle_0_nat. +rewrite Nat2Z.id in bounded. +assumption. +Qed. + +(* +Lemma nth_top_aux {A} {n} {xs : list A} : Z.to_nat n < length xs -> let top := ((length_list xs) - 1)%Z in Z.to_nat (top - n)%Z < length xs. +unfold length_list. +generalize (length xs). +intro n0. +rewrite <- (Nat2Z.id n0). +intro H. +apply Z2Nat.inj_lt. +* omega. +*) + +Close Scope nat. + +(*val access_list_inc : forall a. list a -> Z -> a*) +Definition access_list_inc {A} (xs : list A) n `{ArithFact (0 <= n)} `{ArithFact (n < length_list xs)} := nth_in_range (Z.to_nat n) xs (nth_Z_nat (use_ArithFact _) (use_ArithFact _)). + +(*val access_list_dec : forall a. list a -> Z -> a*) +Definition access_list_dec {A} (xs : list A) n `{ArithFact (0 <= n)} `{ArithFact (n < length_list xs)} : A. +refine ( + let top := (length_list xs) - 1 in + @access_list_inc A xs (top - n) _ _). +constructor. apply use_ArithFact in H. apply use_ArithFact in H0. omega. +constructor. apply use_ArithFact in H. apply use_ArithFact in H0. omega. +Defined. + +(*val access_list : forall a. bool -> list a -> Z -> a*) +Definition access_list {A} (is_inc : bool) (xs : list A) n `{ArithFact (0 <= n)} `{ArithFact (n < length_list xs)} := + if is_inc then access_list_inc xs n else access_list_dec xs n. + +Definition access_list_opt_inc {A} (xs : list A) n := nth_error xs (Z.to_nat n). + +(*val access_list_dec : forall a. list a -> Z -> a*) +Definition access_list_opt_dec {A} (xs : list A) n := + let top := (length_list xs) - 1 in + access_list_opt_inc xs (top - n). + +(*val access_list : forall a. bool -> list a -> Z -> a*) +Definition access_list_opt {A} (is_inc : bool) (xs : list A) n := + if is_inc then access_list_opt_inc xs n else access_list_opt_dec xs n. + +Definition list_update {A} (xs : list A) n x := firstn n xs ++ x :: skipn (S n) xs. + +(*val update_list_inc : forall a. list a -> Z -> a -> list a*) +Definition update_list_inc {A} (xs : list A) n x := list_update xs (Z.to_nat n) x. + +(*val update_list_dec : forall a. list a -> Z -> a -> list a*) +Definition update_list_dec {A} (xs : list A) n x := + let top := (length_list xs) - 1 in + update_list_inc xs (top - n) x. + +(*val update_list : forall a. bool -> list a -> Z -> a -> list a*) +Definition update_list {A} (is_inc : bool) (xs : list A) n x := + if is_inc then update_list_inc xs n x else update_list_dec xs n x. + +(*Definition extract_only_element := function + | [] => failwith "extract_only_element called for empty list" + | [e] => e + | _ => failwith "extract_only_element called for list with more elements" +end + +(* just_list takes a list of maybes and returns Some xs if all elements have + a value, and None if one of the elements is None. *) +val just_list : forall a. list (option a) -> option (list a)*) +Fixpoint just_list {A} (l : list (option A)) := match l with + | [] => Some [] + | (x :: xs) => + match (x, just_list xs) with + | (Some x, Some xs) => Some (x :: xs) + | (_, _) => None + end + end. +(*declare {isabelle} termination_argument just_list = automatic + +lemma just_list_spec: + ((forall xs. (just_list xs = None) <-> List.elem None xs) && + (forall xs es. (just_list xs = Some es) <-> (xs = List.map Some es))) + +(*** Machine words *) +*) +Definition mword (n : Z) := + match n with + | Zneg _ => False + | Z0 => word 0 + | Zpos p => word (Pos.to_nat p) + end. + +Definition get_word {n} : mword n -> word (Z.to_nat n) := + match n with + | Zneg _ => fun x => match x with end + | Z0 => fun x => x + | Zpos p => fun x => x + end. + +Definition with_word {n} {P : Type -> Type} : (word (Z.to_nat n) -> P (word (Z.to_nat n))) -> mword n -> P (mword n) := +match n with +| Zneg _ => fun f w => match w with end +| Z0 => fun f w => f w +| Zpos _ => fun f w => f w +end. + +Program Definition to_word {n} : n >= 0 -> word (Z.to_nat n) -> mword n := + match n with + | Zneg _ => fun H _ => _ + | Z0 => fun _ w => w + | Zpos _ => fun _ w => w + end. + +(*val length_mword : forall a. mword a -> Z*) +Definition length_mword {n} (w : mword n) := n. + +(*val slice_mword_dec : forall a b. mword a -> Z -> Z -> mword b*) +(*Definition slice_mword_dec w i j := word_extract (Z.to_nat i) (Z.to_nat j) w. + +val slice_mword_inc : forall a b. mword a -> Z -> Z -> mword b +Definition slice_mword_inc w i j := + let top := (length_mword w) - 1 in + slice_mword_dec w (top - i) (top - j) + +val slice_mword : forall a b. bool -> mword a -> Z -> Z -> mword b +Definition slice_mword is_inc w i j := if is_inc then slice_mword_inc w i j else slice_mword_dec w i j + +val update_slice_mword_dec : forall a b. mword a -> Z -> Z -> mword b -> mword a +Definition update_slice_mword_dec w i j w' := word_update w (Z.to_nat i) (Z.to_nat j) w' + +val update_slice_mword_inc : forall a b. mword a -> Z -> Z -> mword b -> mword a +Definition update_slice_mword_inc w i j w' := + let top := (length_mword w) - 1 in + update_slice_mword_dec w (top - i) (top - j) w' + +val update_slice_mword : forall a b. bool -> mword a -> Z -> Z -> mword b -> mword a +Definition update_slice_mword is_inc w i j w' := + if is_inc then update_slice_mword_inc w i j w' else update_slice_mword_dec w i j w' + +val access_mword_dec : forall a. mword a -> Z -> bitU*) +Parameter undefined_bit : bool. +Definition getBit {n} := +match n with +| O => fun (w : word O) i => undefined_bit +| S n => fun (w : word (S n)) i => wlsb (wrshift w i) +end. + +Definition access_mword_dec {m} (w : mword m) n := bitU_of_bool (getBit (get_word w) (Z.to_nat n)). + +(*val access_mword_inc : forall a. mword a -> Z -> bitU*) +Definition access_mword_inc {m} (w : mword m) n := + let top := (length_mword w) - 1 in + access_mword_dec w (top - n). + +(*Parameter access_mword : forall {a}, bool -> mword a -> Z -> bitU.*) +Definition access_mword {a} (is_inc : bool) (w : mword a) n := + if is_inc then access_mword_inc w n else access_mword_dec w n. + +Definition setBit {n} := +match n with +| O => fun (w : word O) i b => w +| S n => fun (w : word (S n)) i (b : bool) => + let bit : word (S n) := wlshift (natToWord _ 1) i in + let mask : word (S n) := wnot bit in + let masked := wand mask w in + if b then masked else wor masked bit +end. + +(*val update_mword_dec : forall a. mword a -> Z -> bitU -> mword a*) +Definition update_mword_dec {a} (w : mword a) n b : mword a := + with_word (P := id) (fun w => setBit w (Z.to_nat n) (bool_of_bitU b)) w. + +(*val update_mword_inc : forall a. mword a -> Z -> bitU -> mword a*) +Definition update_mword_inc {a} (w : mword a) n b := + let top := (length_mword w) - 1 in + update_mword_dec w (top - n) b. + +(*Parameter update_mword : forall {a}, bool -> mword a -> Z -> bitU -> mword a.*) +Definition update_mword {a} (is_inc : bool) (w : mword a) n b := + if is_inc then update_mword_inc w n b else update_mword_dec w n b. + +(*val mword_of_int : forall a. Size a => Z -> Z -> mword a +Definition mword_of_int len n := + let w := wordFromInteger n in + if (length_mword w = len) then w else failwith "unexpected word length" +*) +Program Definition mword_of_int len {H:len >= 0} n : mword len := +match len with +| Zneg _ => _ +| Z0 => ZToWord 0 n +| Zpos p => ZToWord (Pos.to_nat p) n +end. +(* +(* Translating between a type level number (itself n) and an integer *) + +Definition size_itself_int x := Z.of_nat (size_itself x) + +(* NB: the corresponding sail type is forall n. atom(n) -> itself(n), + the actual integer is ignored. *) + +val make_the_value : forall n. Z -> itself n +Definition inline make_the_value x := the_value +*) + +Fixpoint bitlistFromWord {n} w := +match w with +| WO => [] +| WS b w => b :: bitlistFromWord w +end. + +Fixpoint wordFromBitlist l : word (length l) := +match l with +| [] => WO +| b::t => WS b (wordFromBitlist t) +end. + +Local Open Scope nat. +Program Definition fit_bbv_word {n m} (w : word n) : word m := +match Nat.compare m n with +| Gt => extz w (m - n) +| Eq => w +| Lt => split2 (n - m) m w +end. +Next Obligation. +symmetry in Heq_anonymous. +apply nat_compare_gt in Heq_anonymous. +omega. +Defined. +Next Obligation. + +symmetry in Heq_anonymous. +apply nat_compare_eq in Heq_anonymous. +omega. +Defined. +Next Obligation. + +symmetry in Heq_anonymous. +apply nat_compare_lt in Heq_anonymous. +omega. +Defined. +Local Close Scope nat. + +(*** Bitvectors *) + +Class Bitvector (a:Type) : Type := { + bits_of : a -> list bitU; + of_bits : list bitU -> a; + (* The first parameter specifies the desired length of the bitvector *) + of_int : Z -> Z -> a; + length : a -> Z; + unsigned : a -> Z; + signed : a -> Z; + (* The first parameter specifies the indexing order (true is increasing) *) + get_bit : bool -> a -> Z -> bitU; + set_bit : bool -> a -> Z -> bitU -> a; + get_bits : bool -> a -> Z -> Z -> list bitU; + set_bits : bool -> a -> Z -> Z -> list bitU -> a +}. + +Parameter undefined_bitU : bitU. (* A missing value of type bitU, as opposed to BU, which is an undefined bit *) +Definition opt_bitU {a : Type} `{BitU a} (b : option a) := +match b with +| None => undefined_bitU +| Some c => to_bitU c +end. + +Instance bitlist_Bitvector {a : Type} `{BitU a} : (Bitvector (list a)) := { + bits_of v := List.map to_bitU v; + of_bits v := List.map of_bitU v; + of_int len n := List.map of_bitU (bits_of_int len n); + length := length_list; + unsigned v := unsigned_of_bits (List.map to_bitU v); + signed v := signed_of_bits (List.map to_bitU v); + get_bit is_inc v n := opt_bitU (access_list_opt is_inc v n); + set_bit is_inc v n b := update_list is_inc v n (of_bitU b); + get_bits is_inc v i j := List.map to_bitU (subrange_list is_inc v i j); + set_bits is_inc v i j v' := update_subrange_list is_inc v i j (List.map of_bitU v') +}. + +Class ReasonableSize (a : Z) : Prop := { + isPositive : a >= 0 +}. + +Hint Resolve -> Z.gtb_lt Z.geb_le Z.ltb_lt Z.leb_le : zbool. +Hint Resolve <- Z.ge_le_iff Z.gt_lt_iff : zbool. + +Lemma ArithFact_mword (a : Z) (w : mword a) : ArithFact (a >= 0). +constructor. +destruct a. +auto with zarith. +auto using Z.le_ge, Zle_0_pos. +destruct w. +Qed. +Ltac unwrap_ArithFacts := + repeat match goal with H:(ArithFact _) |- _ => apply use_ArithFact in H end. +Ltac unbool_comparisons := + repeat match goal with + | H:context [Z.leb _ _ = true] |- _ => rewrite Z.leb_le in H + | H:context [Z.ltb _ _ = true] |- _ => rewrite Z.ltb_lt in H + | H:context [Z.geb _ _ = true] |- _ => rewrite Z.geb_le in H + | H:context [Z.gtb _ _ = true] |- _ => rewrite Z.gtb_lt in H + | H:context [Z.eqb _ _ = true] |- _ => rewrite Z.eqb_eq in H + | H:context [Z.leb _ _ = false] |- _ => rewrite Z.leb_gt in H + | H:context [Z.ltb _ _ = false] |- _ => rewrite Z.ltb_ge in H + | H:context [Z.eqb _ _ = false] |- _ => rewrite Z.eqb_neq in H + | H:context [orb _ _ = true] |- _ => rewrite Bool.orb_true_iff in H + end. +(* Split up dependent pairs to get at proofs of properties *) +(* TODO: simpl is probably too strong here *) +Ltac extract_properties := + repeat match goal with H := (projT1 ?X) |- _ => destruct X in *; simpl in H; unfold H in * end; + repeat match goal with |- context [projT1 ?X] => destruct X in *; simpl end. +Ltac reduce_list_lengths := + repeat match goal with |- context [length_list ?X] => + let r := (eval cbn in (length_list X)) in + change (length_list X) with r + end. +Ltac solve_arithfact := + extract_properties; + repeat match goal with w:mword ?n |- _ => apply ArithFact_mword in w end; + unwrap_ArithFacts; + autounfold with sail in * |- *; (* You can add Hint Unfold ... : sail to let omega see through fns *) + unbool_comparisons; + reduce_list_lengths; + solve [apply ArithFact_mword; assumption + | constructor; omega + (* The datatypes hints give us some list handling, esp In *) + | constructor; auto with datatypes zbool zarith sail]. +Hint Extern 0 (ArithFact _) => solve_arithfact : typeclass_instances. + +Hint Unfold length_mword : sail. + +Lemma ReasonableSize_witness (a : Z) (w : mword a) : ReasonableSize a. +constructor. +destruct a. +auto with zarith. +auto using Z.le_ge, Zle_0_pos. +destruct w. +Qed. + +Goal forall x y, ArithFact (x > y) -> ArithFact (y > 0) -> x >= 0. +intros. +unwrap_ArithFacts. +omega. +Abort. + +Hint Extern 0 (ReasonableSize ?A) => (unwrap_ArithFacts; solve [apply ReasonableSize_witness; assumption | constructor; omega]) : typeclass_instances. + +Instance mword_Bitvector {a : Z} `{ReasonableSize a} : (Bitvector (mword a)) := { + bits_of v := List.map to_bitU (bitlistFromWord (get_word v)); + of_bits v := to_word isPositive (fit_bbv_word (wordFromBitlist (List.map of_bitU v))); + of_int len z := @mword_of_int a isPositive z; (* cheat a little *) + length v := a; + unsigned v := Z.of_N (wordToN (get_word v)); + signed v := wordToZ (get_word v); + get_bit := access_mword; + set_bit := update_mword; + get_bits is_inc v i j := get_bits is_inc (bitlistFromWord (get_word v)) i j; + set_bits is_inc v i j v' := to_word isPositive (fit_bbv_word (wordFromBitlist (set_bits is_inc (bitlistFromWord (get_word v)) i j v'))) +}. + +Section Bitvector_defs. +Context {a b} `{Bitvector a} `{Bitvector b}. + +Definition access_bv_inc (v : a) n := get_bit true v n. +Definition access_bv_dec (v : a) n := get_bit false v n. + +Definition update_bv_inc (v : a) n b := set_bit true v n b. +Definition update_bv_dec (v : a) n b := set_bit false v n b. + +Definition subrange_bv_inc (v : a) i j : b := of_bits (get_bits true v i j). +Definition subrange_bv_dec (v : a) i j : b := of_bits (get_bits false v i j). + +Definition update_subrange_bv_inc (v : a) i j (v' : b) := set_bits true v i j (bits_of v'). +Definition update_subrange_bv_dec (v : a) i j (v' : b) := set_bits false v i j (bits_of v'). + +(*val extz_bv : forall a b. Bitvector a, Bitvector b => Z -> a -> b*) +Definition extz_bv n (v : a) : b := of_bits (extz_bits n (bits_of v)). + +(*val exts_bv : forall a b. Bitvector a, Bitvector b => Z -> a -> b*) +Definition exts_bv n (v : a) :b := of_bits (exts_bits n (bits_of v)). + +(*val string_of_bv : forall a. Bitvector a => a -> string +Definition string_of_bv v := show_bitlist (bits_of v) +*) +End Bitvector_defs. + +(*** Bytes and addresses *) + +Definition memory_byte := list bitU. + +(*val byte_chunks : forall a. list a -> option (list (list a))*) +Fixpoint byte_chunks {a} (bs : list a) := match bs with + | [] => Some [] + | a::b::c::d::e::f::g::h::rest => + match byte_chunks rest with + | None => None + | Some rest => Some ([a;b;c;d;e;f;g;h] :: rest) + end + | _ => None +end. +(*declare {isabelle} termination_argument byte_chunks = automatic*) + +Section BytesBits. +Context {a} `{Bitvector a}. + +(*val bytes_of_bits : forall a. Bitvector a => a -> option (list memory_byte)*) +Definition bytes_of_bits (bs : a) := byte_chunks (bits_of bs). + +(*val bits_of_bytes : forall a. Bitvector a => list memory_byte -> a*) +Definition bits_of_bytes (bs : list memory_byte) : a := of_bits (List.concat (List.map bits_of bs)). + +Definition mem_bytes_of_bits (bs : a) := option_map (@rev (list bitU)) (bytes_of_bits bs). +Definition bits_of_mem_bytes (bs : list memory_byte) : a := bits_of_bytes (List.rev bs). + +End BytesBits. +(* +(*val bitv_of_byte_lifteds : list Sail_impl_base.byte_lifted -> list bitU +Definition bitv_of_byte_lifteds v := + foldl (fun x (Byte_lifted y) => x ++ (List.map bitU_of_bit_lifted y)) [] v + +val bitv_of_bytes : list Sail_impl_base.byte -> list bitU +Definition bitv_of_bytes v := + foldl (fun x (Byte y) => x ++ (List.map bitU_of_bit y)) [] v + +val byte_lifteds_of_bitv : list bitU -> list byte_lifted +Definition byte_lifteds_of_bitv bits := + let bits := List.map bit_lifted_of_bitU bits in + byte_lifteds_of_bit_lifteds bits + +val bytes_of_bitv : list bitU -> list byte +Definition bytes_of_bitv bits := + let bits := List.map bit_of_bitU bits in + bytes_of_bits bits + +val bit_lifteds_of_bitUs : list bitU -> list bit_lifted +Definition bit_lifteds_of_bitUs bits := List.map bit_lifted_of_bitU bits + +val bit_lifteds_of_bitv : list bitU -> list bit_lifted +Definition bit_lifteds_of_bitv v := bit_lifteds_of_bitUs v + + +val address_lifted_of_bitv : list bitU -> address_lifted +Definition address_lifted_of_bitv v := + let byte_lifteds := byte_lifteds_of_bitv v in + let maybe_address_integer := + match (maybe_all (List.map byte_of_byte_lifted byte_lifteds)) with + | Some bs => Some (integer_of_byte_list bs) + | _ => None + end in + Address_lifted byte_lifteds maybe_address_integer + +val bitv_of_address_lifted : address_lifted -> list bitU +Definition bitv_of_address_lifted (Address_lifted bs _) := bitv_of_byte_lifteds bs + +val address_of_bitv : list bitU -> address +Definition address_of_bitv v := + let bytes := bytes_of_bitv v in + address_of_byte_list bytes*) + +Fixpoint reverse_endianness_list bits := + if List.length bits <= 8 then bits else + reverse_endianness_list (drop_list 8 bits) ++ take_list 8 bits + +val reverse_endianness : forall a. Bitvector a => a -> a +Definition reverse_endianness v := of_bits (reverse_endianness_list (bits_of v)) +*) + +(*** Registers *) + +Definition register_field := string. +Definition register_field_index : Type := string * (Z * Z). (* name, start and end *) + +Inductive register := + | Register : string * (* name *) + Z * (* length *) + Z * (* start index *) + bool * (* is increasing *) + list register_field_index + -> register + | UndefinedRegister : Z -> register (* length *) + | RegisterPair : register * register -> register. + +Record register_ref regstate regval a := + { name : string; + (*is_inc : bool;*) + read_from : regstate -> a; + write_to : a -> regstate -> regstate; + of_regval : regval -> option a; + regval_of : a -> regval }. +Notation "{[ r 'with' 'name' := e ]}" := ({| name := e; read_from := read_from r; write_to := write_to r; of_regval := of_regval r; regval_of := regval_of r |}). +Notation "{[ r 'with' 'read_from' := e ]}" := ({| read_from := e; name := name r; write_to := write_to r; of_regval := of_regval r; regval_of := regval_of r |}). +Notation "{[ r 'with' 'write_to' := e ]}" := ({| write_to := e; name := name r; read_from := read_from r; of_regval := of_regval r; regval_of := regval_of r |}). +Notation "{[ r 'with' 'of_regval' := e ]}" := ({| of_regval := e; name := name r; read_from := read_from r; write_to := write_to r; regval_of := regval_of r |}). +Notation "{[ r 'with' 'regval_of' := e ]}" := ({| regval_of := e; name := name r; read_from := read_from r; write_to := write_to r; of_regval := of_regval r |}). +Arguments name [_ _ _]. +Arguments read_from [_ _ _]. +Arguments write_to [_ _ _]. +Arguments of_regval [_ _ _]. +Arguments regval_of [_ _ _]. + +Definition register_accessors regstate regval : Type := + ((string -> regstate -> option regval) * + (string -> regval -> regstate -> option regstate)). + +Record field_ref regtype a := + { field_name : string; + field_start : Z; + field_is_inc : bool; + get_field : regtype -> a; + set_field : regtype -> a -> regtype }. +Arguments field_name [_ _]. +Arguments field_start [_ _]. +Arguments field_is_inc [_ _]. +Arguments get_field [_ _]. +Arguments set_field [_ _]. + +(* +(*let name_of_reg := function + | Register name _ _ _ _ => name + | UndefinedRegister _ => failwith "name_of_reg UndefinedRegister" + | RegisterPair _ _ => failwith "name_of_reg RegisterPair" +end + +Definition size_of_reg := function + | Register _ size _ _ _ => size + | UndefinedRegister size => size + | RegisterPair _ _ => failwith "size_of_reg RegisterPair" +end + +Definition start_of_reg := function + | Register _ _ start _ _ => start + | UndefinedRegister _ => failwith "start_of_reg UndefinedRegister" + | RegisterPair _ _ => failwith "start_of_reg RegisterPair" +end + +Definition is_inc_of_reg := function + | Register _ _ _ is_inc _ => is_inc + | UndefinedRegister _ => failwith "is_inc_of_reg UndefinedRegister" + | RegisterPair _ _ => failwith "in_inc_of_reg RegisterPair" +end + +Definition dir_of_reg := function + | Register _ _ _ is_inc _ => dir_of_bool is_inc + | UndefinedRegister _ => failwith "dir_of_reg UndefinedRegister" + | RegisterPair _ _ => failwith "dir_of_reg RegisterPair" +end + +Definition size_of_reg_nat reg := Z.to_nat (size_of_reg reg) +Definition start_of_reg_nat reg := Z.to_nat (start_of_reg reg) + +val register_field_indices_aux : register -> register_field -> option (Z * Z) +Fixpoint register_field_indices_aux register rfield := + match register with + | Register _ _ _ _ rfields => List.lookup rfield rfields + | RegisterPair r1 r2 => + let m_indices := register_field_indices_aux r1 rfield in + if isSome m_indices then m_indices else register_field_indices_aux r2 rfield + | UndefinedRegister _ => None + end + +val register_field_indices : register -> register_field -> Z * Z +Definition register_field_indices register rfield := + match register_field_indices_aux register rfield with + | Some indices => indices + | None => failwith "Invalid register/register-field combination" + end + +Definition register_field_indices_nat reg regfield= + let (i,j) := register_field_indices reg regfield in + (Z.to_nat i,Z.to_nat j)*) + +(*let rec external_reg_value reg_name v := + let (internal_start, external_start, direction) := + match reg_name with + | Reg _ start size dir => + (start, (if dir = D_increasing then start else (start - (size +1))), dir) + | Reg_slice _ reg_start dir (slice_start, _) => + ((if dir = D_increasing then slice_start else (reg_start - slice_start)), + slice_start, dir) + | Reg_field _ reg_start dir _ (slice_start, _) => + ((if dir = D_increasing then slice_start else (reg_start - slice_start)), + slice_start, dir) + | Reg_f_slice _ reg_start dir _ _ (slice_start, _) => + ((if dir = D_increasing then slice_start else (reg_start - slice_start)), + slice_start, dir) + end in + let bits := bit_lifteds_of_bitv v in + <| rv_bits := bits; + rv_dir := direction; + rv_start := external_start; + rv_start_internal := internal_start |> + +val internal_reg_value : register_value -> list bitU +Definition internal_reg_value v := + List.map bitU_of_bit_lifted v.rv_bits + (*(Z.of_nat v.rv_start_internal) + (v.rv_dir = D_increasing)*) + + +Definition external_slice (d:direction) (start:nat) ((i,j):(nat*nat)) := + match d with + (*This is the case the thread/concurrecny model expects, so no change needed*) + | D_increasing => (i,j) + | D_decreasing => let slice_i = start - i in + let slice_j = (i - j) + slice_i in + (slice_i,slice_j) + end *) + +(* TODO +Definition external_reg_whole r := + Reg (r.name) (Z.to_nat r.start) (Z.to_nat r.size) (dir_of_bool r.is_inc) + +Definition external_reg_slice r (i,j) := + let start := Z.to_nat r.start in + let dir := dir_of_bool r.is_inc in + Reg_slice (r.name) start dir (external_slice dir start (i,j)) + +Definition external_reg_field_whole reg rfield := + let (m,n) := register_field_indices_nat reg rfield in + let start := start_of_reg_nat reg in + let dir := dir_of_reg reg in + Reg_field (name_of_reg reg) start dir rfield (external_slice dir start (m,n)) + +Definition external_reg_field_slice reg rfield (i,j) := + let (m,n) := register_field_indices_nat reg rfield in + let start := start_of_reg_nat reg in + let dir := dir_of_reg reg in + Reg_f_slice (name_of_reg reg) start dir rfield + (external_slice dir start (m,n)) + (external_slice dir start (i,j))*) + +(*val external_mem_value : list bitU -> memory_value +Definition external_mem_value v := + byte_lifteds_of_bitv v $> List.reverse + +val internal_mem_value : memory_value -> list bitU +Definition internal_mem_value bytes := + List.reverse bytes $> bitv_of_byte_lifteds*) + + +val foreach : forall a vars. + (list a) -> vars -> (a -> vars -> vars) -> vars*) +Fixpoint foreach {a Vars} (l : list a) (vars : Vars) (body : a -> Vars -> Vars) : Vars := +match l with +| [] => vars +| (x :: xs) => foreach xs (body x vars) body +end. + +(*declare {isabelle} termination_argument foreach = automatic + +val index_list : Z -> Z -> Z -> list Z*) +Fixpoint index_list' from step n := + match n with + | O => [] + | S n => from :: index_list' (from + step) step n + end. + +Definition index_list from to step := + if orb (andb (step >? 0) (from <=? to)) (andb (step <? 0) (to <=? from)) then + index_list' from step (S (Z.abs_nat (from - to))) + else []. + +(*val while : forall vars. vars -> (vars -> bool) -> (vars -> vars) -> vars +Fixpoint while vars cond body := + if cond vars then while (body vars) cond body else vars + +val until : forall vars. vars -> (vars -> bool) -> (vars -> vars) -> vars +Fixpoint until vars cond body := + let vars := body vars in + if cond vars then vars else until (body vars) cond body + + +Definition assert' b msg_opt := + let msg := match msg_opt with + | Some msg => msg + | None => "unspecified error" + end in + if b then () else failwith msg + +(* convert numbers unsafely to naturals *) + +class (ToNatural a) val toNatural : a -> natural end +(* eta-expanded for Isabelle output, otherwise it breaks *) +instance (ToNatural Z) let toNatural := (fun n => naturalFromInteger n) end +instance (ToNatural int) let toNatural := (fun n => naturalFromInt n) end +instance (ToNatural nat) let toNatural := (fun n => naturalFromNat n) end +instance (ToNatural natural) let toNatural := (fun n => n) end + +Definition toNaturalFiveTup (n1,n2,n3,n4,n5) := + (toNatural n1, + toNatural n2, + toNatural n3, + toNatural n4, + toNatural n5) + +(* Let the following types be generated by Sail per spec, using either bitlists + or machine words as bitvector representation *) +(*type regfp := + | RFull of (string) + | RSlice of (string * Z * Z) + | RSliceBit of (string * Z) + | RField of (string * string) + +type niafp := + | NIAFP_successor + | NIAFP_concrete_address of vector bitU + | NIAFP_indirect_address + +(* only for MIPS *) +type diafp := + | DIAFP_none + | DIAFP_concrete of vector bitU + | DIAFP_reg of regfp + +Definition regfp_to_reg (reg_info : string -> option string -> (nat * nat * direction * (nat * nat))) := function + | RFull name => + let (start,length,direction,_) := reg_info name None in + Reg name start length direction + | RSlice (name,i,j) => + let i = Z.to_nat i in + let j = Z.to_nat j in + let (start,length,direction,_) = reg_info name None in + let slice = external_slice direction start (i,j) in + Reg_slice name start direction slice + | RSliceBit (name,i) => + let i = Z.to_nat i in + let (start,length,direction,_) = reg_info name None in + let slice = external_slice direction start (i,i) in + Reg_slice name start direction slice + | RField (name,field_name) => + let (start,length,direction,span) = reg_info name (Some field_name) in + let slice = external_slice direction start span in + Reg_field name start direction field_name slice +end + +Definition niafp_to_nia reginfo = function + | NIAFP_successor => NIA_successor + | NIAFP_concrete_address v => NIA_concrete_address (address_of_bitv v) + | NIAFP_indirect_address => NIA_indirect_address +end + +Definition diafp_to_dia reginfo = function + | DIAFP_none => DIA_none + | DIAFP_concrete v => DIA_concrete_address (address_of_bitv v) + | DIAFP_reg r => DIA_register (regfp_to_reg reginfo r) +end +*) +*) diff --git a/lib/coq/State.v b/lib/coq/State.v new file mode 100644 index 00000000..00dd1f5b --- /dev/null +++ b/lib/coq/State.v @@ -0,0 +1,68 @@ +(*Require Import Sail_impl_base*) +Require Import Sail_values. +Require Import Prompt_monad. +Require Import Prompt. +Require Import State_monad. +(* +(* State monad wrapper around prompt monad *) + +val liftState : forall 'regval 'regs 'a 'e. register_accessors 'regs 'regval -> monad 'regval 'a 'e -> monadS 'regs 'a 'e +let rec liftState ra s = match s with + | (Done a) -> returnS a + | (Read_mem rk a sz k) -> bindS (read_mem_bytesS rk a sz) (fun v -> liftState ra (k v)) + | (Read_tag t k) -> bindS (read_tagS t) (fun v -> liftState ra (k v)) + | (Write_memv a k) -> bindS (write_mem_bytesS a) (fun v -> liftState ra (k v)) + | (Write_tagv t k) -> bindS (write_tagS t) (fun v -> liftState ra (k v)) + | (Read_reg r k) -> bindS (read_regvalS ra r) (fun v -> liftState ra (k v)) + | (Excl_res k) -> bindS (excl_resultS ()) (fun v -> liftState ra (k v)) + | (Write_ea wk a sz k) -> seqS (write_mem_eaS wk a sz) (liftState ra k) + | (Write_reg r v k) -> seqS (write_regvalS ra r v) (liftState ra k) + | (Footprint k) -> liftState ra k + | (Barrier _ k) -> liftState ra k + | (Fail descr) -> failS descr + | (Error descr) -> failS descr + | (Exception e) -> throwS e +end + + +val iterS_aux : forall 'rv 'a 'e. integer -> (integer -> 'a -> monadS 'rv unit 'e) -> list 'a -> monadS 'rv unit 'e +let rec iterS_aux i f xs = match xs with + | x :: xs -> f i x >>$ iterS_aux (i + 1) f xs + | [] -> returnS () + end + +declare {isabelle} termination_argument iterS_aux = automatic + +val iteriS : forall 'rv 'a 'e. (integer -> 'a -> monadS 'rv unit 'e) -> list 'a -> monadS 'rv unit 'e +let iteriS f xs = iterS_aux 0 f xs + +val iterS : forall 'rv 'a 'e. ('a -> monadS 'rv unit 'e) -> list 'a -> monadS 'rv unit 'e +let iterS f xs = iteriS (fun _ x -> f x) xs + +val foreachS : forall 'a 'rv 'vars 'e. + list 'a -> 'vars -> ('a -> 'vars -> monadS 'rv 'vars 'e) -> monadS 'rv 'vars 'e +let rec foreachS xs vars body = match xs with + | [] -> returnS vars + | x :: xs -> + body x vars >>$= fun vars -> + foreachS xs vars body +end + +declare {isabelle} termination_argument foreachS = automatic + + +val whileS : forall 'rv 'vars 'e. 'vars -> ('vars -> monadS 'rv bool 'e) -> + ('vars -> monadS 'rv 'vars 'e) -> monadS 'rv 'vars 'e +let rec whileS vars cond body s = + (cond vars >>$= (fun cond_val s' -> + if cond_val then + (body vars >>$= (fun vars s'' -> whileS vars cond body s'')) s' + else returnS vars s')) s + +val untilS : forall 'rv 'vars 'e. 'vars -> ('vars -> monadS 'rv bool 'e) -> + ('vars -> monadS 'rv 'vars 'e) -> monadS 'rv 'vars 'e +let rec untilS vars cond body s = + (body vars >>$= (fun vars s' -> + (cond vars >>$= (fun cond_val s'' -> + if cond_val then returnS vars s'' else untilS vars cond body s'')) s')) s +*) diff --git a/lib/coq/State_monad.v b/lib/coq/State_monad.v new file mode 100644 index 00000000..f93606e8 --- /dev/null +++ b/lib/coq/State_monad.v @@ -0,0 +1,253 @@ +Require Import Sail_instr_kinds. +Require Import Sail_values. +(* +(* 'a is result type *) + +type memstate = map integer memory_byte +type tagstate = map integer bitU +(* type regstate = map string (vector bitU) *) + +type sequential_state 'regs = + <| regstate : 'regs; + memstate : memstate; + tagstate : tagstate; + write_ea : maybe (write_kind * integer * integer); + last_exclusive_operation_was_load : bool|> + +val init_state : forall 'regs. 'regs -> sequential_state 'regs +let init_state regs = + <| regstate = regs; + memstate = Map.empty; + tagstate = Map.empty; + write_ea = Nothing; + last_exclusive_operation_was_load = false |> + +type ex 'e = + | Failure of string + | Throw of 'e + +type result 'a 'e = + | Value of 'a + | Ex of (ex 'e) + +(* State, nondeterminism and exception monad with result value type 'a + and exception type 'e. *) +type monadS 'regs 'a 'e = sequential_state 'regs -> list (result 'a 'e * sequential_state 'regs) + +val returnS : forall 'regs 'a 'e. 'a -> monadS 'regs 'a 'e +let returnS a s = [(Value a,s)] + +val bindS : forall 'regs 'a 'b 'e. monadS 'regs 'a 'e -> ('a -> monadS 'regs 'b 'e) -> monadS 'regs 'b 'e +let bindS m f (s : sequential_state 'regs) = + List.concatMap (function + | (Value a, s') -> f a s' + | (Ex e, s') -> [(Ex e, s')] + end) (m s) + +val seqS: forall 'regs 'b 'e. monadS 'regs unit 'e -> monadS 'regs 'b 'e -> monadS 'regs 'b 'e +let seqS m n = bindS m (fun (_ : unit) -> n) + +let inline (>>$=) = bindS +let inline (>>$) = seqS + +val chooseS : forall 'regs 'a 'e. list 'a -> monadS 'regs 'a 'e +let chooseS xs s = List.map (fun x -> (Value x, s)) xs + +val readS : forall 'regs 'a 'e. (sequential_state 'regs -> 'a) -> monadS 'regs 'a 'e +let readS f = (fun s -> returnS (f s) s) + +val updateS : forall 'regs 'e. (sequential_state 'regs -> sequential_state 'regs) -> monadS 'regs unit 'e +let updateS f = (fun s -> returnS () (f s)) + +val failS : forall 'regs 'a 'e. string -> monadS 'regs 'a 'e +let failS msg s = [(Ex (Failure msg), s)] + +val exitS : forall 'regs 'e 'a. unit -> monadS 'regs 'a 'e +let exitS () = failS "exit" + +val throwS : forall 'regs 'a 'e. 'e -> monadS 'regs 'a 'e +let throwS e s = [(Ex (Throw e), s)] + +val try_catchS : forall 'regs 'a 'e1 'e2. monadS 'regs 'a 'e1 -> ('e1 -> monadS 'regs 'a 'e2) -> monadS 'regs 'a 'e2 +let try_catchS m h s = + List.concatMap (function + | (Value a, s') -> returnS a s' + | (Ex (Throw e), s') -> h e s' + | (Ex (Failure msg), s') -> [(Ex (Failure msg), s')] + end) (m s) + +val assert_expS : forall 'regs 'e. bool -> string -> monadS 'regs unit 'e +let assert_expS exp msg = if exp then returnS () else failS msg + +(* For early return, we abuse exceptions by throwing and catching + the return value. The exception type is "either 'r 'e", where "Right e" + represents a proper exception and "Left r" an early return of value "r". *) +type monadSR 'regs 'a 'r 'e = monadS 'regs 'a (either 'r 'e) + +val early_returnS : forall 'regs 'a 'r 'e. 'r -> monadSR 'regs 'a 'r 'e +let early_returnS r = throwS (Left r) + +val catch_early_returnS : forall 'regs 'a 'e. monadSR 'regs 'a 'a 'e -> monadS 'regs 'a 'e +let catch_early_returnS m = + try_catchS m + (function + | Left a -> returnS a + | Right e -> throwS e + end) + +(* Lift to monad with early return by wrapping exceptions *) +val liftSR : forall 'a 'r 'regs 'e. monadS 'regs 'a 'e -> monadSR 'regs 'a 'r 'e +let liftSR m = try_catchS m (fun e -> throwS (Right e)) + +(* Catch exceptions in the presence of early returns *) +val try_catchSR : forall 'regs 'a 'r 'e1 'e2. monadSR 'regs 'a 'r 'e1 -> ('e1 -> monadSR 'regs 'a 'r 'e2) -> monadSR 'regs 'a 'r 'e2 +let try_catchSR m h = + try_catchS m + (function + | Left r -> throwS (Left r) + | Right e -> h e + end) + +val read_tagS : forall 'regs 'a 'e. Bitvector 'a => 'a -> monadS 'regs bitU 'e +let read_tagS addr = + readS (fun s -> fromMaybe B0 (Map.lookup (unsigned addr) s.tagstate)) + +(* Read bytes from memory and return in little endian order *) +val read_mem_bytesS : forall 'regs 'e 'a. Bitvector 'a => read_kind -> 'a -> nat -> monadS 'regs (list memory_byte) 'e +let read_mem_bytesS read_kind addr sz = + let addr = unsigned addr in + let sz = integerFromNat sz in + let addrs = index_list addr (addr+sz-1) 1 in + let read_byte s addr = Map.lookup addr s.memstate in + readS (fun s -> just_list (List.map (read_byte s) addrs)) >>$= (function + | Just mem_val -> + updateS (fun s -> + if read_is_exclusive read_kind + then <| s with last_exclusive_operation_was_load = true |> + else s) >>$ + returnS mem_val + | Nothing -> failS "read_memS" + end) + +val read_memS : forall 'regs 'e 'a 'b. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monadS 'regs 'b 'e +let read_memS rk a sz = + read_mem_bytesS rk a (natFromInteger sz) >>$= (fun bytes -> + returnS (bits_of_mem_bytes bytes)) + +val excl_resultS : forall 'regs 'e. unit -> monadS 'regs bool 'e +let excl_resultS () = + readS (fun s -> s.last_exclusive_operation_was_load) >>$= (fun excl_load -> + updateS (fun s -> <| s with last_exclusive_operation_was_load = false |>) >>$ + chooseS (if excl_load then [false; true] else [false])) + +val write_mem_eaS : forall 'regs 'e 'a. Bitvector 'a => write_kind -> 'a -> nat -> monadS 'regs unit 'e +let write_mem_eaS write_kind addr sz = + let addr = unsigned addr in + let sz = integerFromNat sz in + updateS (fun s -> <| s with write_ea = Just (write_kind, addr, sz) |>) + +(* Write little-endian list of bytes to previously announced address *) +val write_mem_bytesS : forall 'regs 'e. list memory_byte -> monadS 'regs bool 'e +let write_mem_bytesS v = + readS (fun s -> s.write_ea) >>$= (function + | Nothing -> failS "write ea has not been announced yet" + | Just (_, addr, sz) -> + let addrs = index_list addr (addr+sz-1) 1 in + (*let v = external_mem_value (bits_of v) in*) + let a_v = List.zip addrs v in + let write_byte mem (addr, v) = Map.insert addr v mem in + updateS (fun s -> + <| s with memstate = List.foldl write_byte s.memstate a_v |>) >>$ + returnS true + end) + +val write_mem_valS : forall 'regs 'e 'a. Bitvector 'a => 'a -> monadS 'regs bool 'e +let write_mem_valS v = match mem_bytes_of_bits v with + | Just v -> write_mem_bytesS v + | Nothing -> failS "write_mem_val" +end + +val write_tagS : forall 'regs 'e. bitU -> monadS 'regs bool 'e +let write_tagS t = + readS (fun s -> s.write_ea) >>$= (function + | Nothing -> failS "write ea has not been announced yet" + | Just (_, addr, _) -> + (*let taddr = addr / cap_alignment in*) + updateS (fun s -> <| s with tagstate = Map.insert addr t s.tagstate |>) >>$ + returnS true + end) + +val read_regS : forall 'regs 'rv 'a 'e. register_ref 'regs 'rv 'a -> monadS 'regs 'a 'e +let read_regS reg = readS (fun s -> reg.read_from s.regstate) + +(* TODO +let read_reg_range reg i j state = + let v = slice (get_reg state (name_of_reg reg)) i j in + [(Value (vec_to_bvec v),state)] +let read_reg_bit reg i state = + let v = access (get_reg state (name_of_reg reg)) i in + [(Value v,state)] +let read_reg_field reg regfield = + let (i,j) = register_field_indices reg regfield in + read_reg_range reg i j +let read_reg_bitfield reg regfield = + let (i,_) = register_field_indices reg regfield in + read_reg_bit reg i *) + +val read_regvalS : forall 'regs 'rv 'e. + register_accessors 'regs 'rv -> string -> monadS 'regs 'rv 'e +let read_regvalS (read, _) reg = + readS (fun s -> read reg s.regstate) >>$= (function + | Just v -> returnS v + | Nothing -> failS ("read_regvalS " ^ reg) + end) + +val write_regvalS : forall 'regs 'rv 'e. + register_accessors 'regs 'rv -> string -> 'rv -> monadS 'regs unit 'e +let write_regvalS (_, write) reg v = + readS (fun s -> write reg v s.regstate) >>$= (function + | Just rs' -> updateS (fun s -> <| s with regstate = rs' |>) + | Nothing -> failS ("write_regvalS " ^ reg) + end) + +val write_regS : forall 'regs 'rv 'a 'e. register_ref 'regs 'rv 'a -> 'a -> monadS 'regs unit 'e +let write_regS reg v = + updateS (fun s -> <| s with regstate = reg.write_to v s.regstate |>) + +(* TODO +val update_reg : forall 'regs 'rv 'a 'b 'e. register_ref 'regs 'rv 'a -> ('a -> 'b -> 'a) -> 'b -> monadS 'regs unit 'e +let update_reg reg f v state = + let current_value = get_reg state reg in + let new_value = f current_value v in + [(Value (), set_reg state reg new_value)] + +let write_reg_field reg regfield = update_reg reg regfield.set_field + +val update_reg_range : forall 'regs 'rv 'a 'b. Bitvector 'a, Bitvector 'b => register_ref 'regs 'rv 'a -> integer -> integer -> 'a -> 'b -> 'a +let update_reg_range reg i j reg_val new_val = set_bits (reg.is_inc) reg_val i j (bits_of new_val) +let write_reg_range reg i j = update_reg reg (update_reg_range reg i j) + +let update_reg_pos reg i reg_val x = update_list reg.is_inc reg_val i x +let write_reg_pos reg i = update_reg reg (update_reg_pos reg i) + +let update_reg_bit reg i reg_val bit = set_bit (reg.is_inc) reg_val i (to_bitU bit) +let write_reg_bit reg i = update_reg reg (update_reg_bit reg i) + +let update_reg_field_range regfield i j reg_val new_val = + let current_field_value = regfield.get_field reg_val in + let new_field_value = set_bits (regfield.field_is_inc) current_field_value i j (bits_of new_val) in + regfield.set_field reg_val new_field_value +let write_reg_field_range reg regfield i j = update_reg reg (update_reg_field_range regfield i j) + +let update_reg_field_pos regfield i reg_val x = + let current_field_value = regfield.get_field reg_val in + let new_field_value = update_list regfield.field_is_inc current_field_value i x in + regfield.set_field reg_val new_field_value +let write_reg_field_pos reg regfield i = update_reg reg (update_reg_field_pos regfield i) + +let update_reg_field_bit regfield i reg_val bit = + let current_field_value = regfield.get_field reg_val in + let new_field_value = set_bit (regfield.field_is_inc) current_field_value i (to_bitU bit) in + regfield.set_field reg_val new_field_value +let write_reg_field_bit reg regfield i = update_reg reg (update_reg_field_bit regfield i)*) +*) diff --git a/lib/hol/.gitignore b/lib/hol/.gitignore new file mode 100644 index 00000000..fe652801 --- /dev/null +++ b/lib/hol/.gitignore @@ -0,0 +1,9 @@ +prompt_monadScript.sml +promptScript.sml +sail_instr_kindsScript.sml +sail_operators_bitlistsScript.sml +sail_operators_mwordsScript.sml +sail_operatorsScript.sml +sail_valuesScript.sml +state_monadScript.sml +stateScript.sml diff --git a/lib/hol/Holmakefile b/lib/hol/Holmakefile new file mode 100644 index 00000000..eac4fec8 --- /dev/null +++ b/lib/hol/Holmakefile @@ -0,0 +1,32 @@ +LEM_SCRIPTS = sail_instr_kindsScript.sml sail_valuesScript.sml sail_operatorsScript.sml \ + sail_operators_mwordsScript.sml sail_operators_bitlistsScript.sml \ + state_monadScript.sml stateScript.sml promptScript.sml prompt_monadScript.sml + +LEM_CLEANS = $(LEM_SCRIPTS) + +SCRIPTS = $(LEM_SCRIPTS) \ + sail_valuesAuxiliaryScript.sml stateAuxiliaryScript.sml + +THYS = $(patsubst %Script.sml,%Theory.uo,$(SCRIPTS)) + +LEMDIR=../../../lem/hol-lib + +INCLUDES = $(LEMDIR) + +all: $(THYS) +.PHONY: all + +EXTRA_CLEANS = $(LEM_CLEANS) + +ifdef POLY +HOLHEAP = sail-heap +EXTRA_CLEANS = $(LEM_CLEANS) $(HOLHEAP) $(HOLHEAP).o + +BASE_HEAP = $(LEMDIR)/lemheap + +$(HOLHEAP): $(BASE_HEAP) + $(protect $(HOLDIR)/bin/buildheap) -o $(HOLHEAP) -b $(BASE_HEAP) + +all: $(HOLHEAP) + +endif diff --git a/lib/hol/Makefile b/lib/hol/Makefile new file mode 100644 index 00000000..065f887a --- /dev/null +++ b/lib/hol/Makefile @@ -0,0 +1,31 @@ +LEMSRC = \ + ../../src/lem_interp/sail_instr_kinds.lem \ + ../../src/gen_lib/sail_values.lem \ + ../../src/gen_lib/sail_operators.lem \ + ../../src/gen_lib/sail_operators_mwords.lem \ + ../../src/gen_lib/sail_operators_bitlists.lem \ + ../../src/gen_lib/state_monad.lem \ + ../../src/gen_lib/state.lem \ + prompt_monad.lem \ + prompt.lem + +SCRIPTS = sail_instr_kindsScript.sml sail_valuesScript.sml sail_operatorsScript.sml \ + sail_operators_mwordsScript.sml sail_operators_bitlistsScript.sml \ + state_monadScript.sml stateScript.sml \ + prompt_monadScript.sml promptScript.sml + +THYS = $(patsubst %Script.sml,%Theory.uo,$(SCRIPTS)) + +all: sail-heap $(THYS) + +$(SCRIPTS): $(LEMSRC) + lem -hol -outdir . -auxiliary_level none -lib ../../src/lem_interp -lib ../../src/gen_lib $(LEMSRC) + +$(THYS) sail-heap: $(SCRIPTS) + Holmake + +# Holmake will also clear out the generated $(SCRIPTS) files +clean: + Holmake cleanAll + +.PHONY: all clean diff --git a/lib/hol/prompt.lem b/lib/hol/prompt.lem new file mode 100644 index 00000000..edbd3752 --- /dev/null +++ b/lib/hol/prompt.lem @@ -0,0 +1,18 @@ +open import Prompt_monad +open import State_monad +open import State + +let inline undefined_bool = undefined_boolS +let inline bool_of_bitU_oracle = bool_of_bitU_oracleS +let inline bool_of_bitU_fail = bool_of_bitU_fail +let inline bools_of_bits_oracle = bools_of_bits_oracleS +let inline of_bits_oracle = of_bits_oracleS +let inline of_bits_fail = of_bits_failS +let inline mword_oracle = mword_oracleS +let inline reg_deref = read_regS + +let inline foreachM = foreachS +let inline whileM = whileS +let inline untilM = untilS +let inline and_boolM = and_boolS +let inline or_boolM = or_boolS diff --git a/lib/hol/prompt_monad.lem b/lib/hol/prompt_monad.lem new file mode 100644 index 00000000..8fcd645a --- /dev/null +++ b/lib/hol/prompt_monad.lem @@ -0,0 +1,49 @@ +open import Pervasives_extra +open import Sail_instr_kinds +open import Sail_values +open import State_monad + +(* Fake interface of the prompt monad by redirecting to the state monad, since + the former is not currently supported by HOL4 *) + +type monad 'rv 'a 'e = monadS 'rv 'a 'e +type monadR 'rv 'a 'e 'r = monadRS 'rv 'a 'e 'r + +(* We need to use a target_rep for these because HOL doesn't handle unused + type parameters well. *) + +type base_monad 'regval 'regstate 'a 'e = monad 'regstate 'a 'e +declare hol target_rep type base_monad 'regval 'regstate 'a 'e = `monad` 'regstate 'a 'e +type base_monadR 'regval 'regstate 'a 'r 'e = monadR 'regstate 'a 'r 'e +declare hol target_rep type base_monadR 'regval 'regstate 'a 'r 'e = `monadR` 'regstate 'a 'r 'e + +let inline return = returnS +let inline bind = bindS +let inline (>>=) = (>>$=) +let inline (>>) = (>>$) + +let inline exit = exitS + +let inline throw = throwS +let inline try_catch = try_catchS + +let inline catch_early_return = catch_early_returnS +let inline early_return = early_returnS +let inline liftR = liftRS +let inline try_catchR = try_catchRS + +let inline maybe_fail = maybe_failS + +let inline read_mem_bytes = read_mem_bytesS +let inline read_reg = read_regS +let inline reg_deref = read_regS +let inline read_mem = read_memS +let inline read_tag = read_tagS +let inline excl_result = excl_resultS +let inline write_reg = write_regS +let inline write_tag = write_tagS +let inline write_mem_ea wk addr sz = write_mem_eaS wk addr (nat_of_int sz) +let inline write_mem_val = write_mem_valS +let barrier _ = return () + +let inline assert_exp = assert_expS diff --git a/lib/hol/sail_valuesAuxiliaryScript.sml b/lib/hol/sail_valuesAuxiliaryScript.sml new file mode 100644 index 00000000..aa169979 --- /dev/null +++ b/lib/hol/sail_valuesAuxiliaryScript.sml @@ -0,0 +1,139 @@ +(*Generated by Lem from ../../src/gen_lib/sail_values.lem.*) +open HolKernel Parse boolLib bossLib; +open lem_pervasives_extraTheory lem_machine_wordTheory sail_valuesTheory; +open intLib; + +val _ = numLib.prefer_num(); + + + +open lemLib; +(* val _ = lemLib.run_interactive := true; *) +val _ = new_theory "sail_valuesAuxiliary" + + +(****************************************************) +(* *) +(* Termination Proofs *) +(* *) +(****************************************************) + +(* val gst = Defn.tgoal_no_defn (shr_int_def, shr_int_ind) *) +val (shr_int_rw, shr_int_ind_rw) = + Defn.tprove_no_defn ((shr_int_def, shr_int_ind), + WF_REL_TAC`measure (Num o SND)` \\ COOPER_TAC + ) +val shr_int_rw = save_thm ("shr_int_rw", shr_int_rw); +val shr_int_ind_rw = save_thm ("shr_int_ind_rw", shr_int_ind_rw); +val () = computeLib.add_persistent_funs["shr_int_rw"]; + + +(* val gst = Defn.tgoal_no_defn (shl_int_def, shl_int_ind) *) +val (shl_int_rw, shl_int_ind_rw) = + Defn.tprove_no_defn ((shl_int_def, shl_int_ind), + WF_REL_TAC`measure (Num o SND)` \\ COOPER_TAC + ) +val shl_int_rw = save_thm ("shl_int_rw", shl_int_rw); +val shl_int_ind_rw = save_thm ("shl_int_ind_rw", shl_int_ind_rw); +val () = computeLib.add_persistent_funs["shl_int_rw"]; + + +(* val gst = Defn.tgoal_no_defn (repeat_def, repeat_ind) *) +val (repeat_rw, repeat_ind_rw) = + Defn.tprove_no_defn ((repeat_def, repeat_ind), + WF_REL_TAC`measure (Num o SND)` \\ COOPER_TAC + ) +val repeat_rw = save_thm ("repeat_rw", repeat_rw); +val repeat_ind_rw = save_thm ("repeat_ind_rw", repeat_ind_rw); +val () = computeLib.add_persistent_funs["repeat_rw"]; + + +(* val gst = Defn.tgoal_no_defn (bools_of_nat_aux_def, bools_of_nat_aux_ind) *) +val (bools_of_nat_aux_rw, bools_of_nat_aux_ind_rw) = + Defn.tprove_no_defn ((bools_of_nat_aux_def, bools_of_nat_aux_ind), + WF_REL_TAC`measure (Num o FST)` \\ COOPER_TAC + ) +val bools_of_nat_aux_rw = save_thm ("bools_of_nat_aux_rw", bools_of_nat_aux_rw); +val bools_of_nat_aux_ind_rw = save_thm ("bools_of_nat_aux_ind_rw", bools_of_nat_aux_ind_rw); +val () = computeLib.add_persistent_funs["bools_of_nat_aux_rw"]; + + +(* val gst = Defn.tgoal_no_defn (pad_list_def, pad_list_ind) *) +val (pad_list_rw, pad_list_ind_rw) = + Defn.tprove_no_defn ((pad_list_def, pad_list_ind), + WF_REL_TAC`measure (Num o SND o SND)` \\ COOPER_TAC + ) +val pad_list_rw = save_thm ("pad_list_rw", pad_list_rw); +val pad_list_ind_rw = save_thm ("pad_list_ind_rw", pad_list_ind_rw); +val () = computeLib.add_persistent_funs["pad_list_rw"]; + + +(* val gst = Defn.tgoal_no_defn (reverse_endianness_list_def, reverse_endianness_list_ind) *) +val (reverse_endianness_list_rw, reverse_endianness_list_ind_rw) = + Defn.tprove_no_defn ((reverse_endianness_list_def, reverse_endianness_list_ind), + WF_REL_TAC`measure LENGTH` \\ rw[drop_list_def,nat_of_int_def] + ) +val reverse_endianness_list_rw = save_thm ("reverse_endianness_list_rw", reverse_endianness_list_rw); +val reverse_endianness_list_ind_rw = save_thm ("reverse_endianness_list_ind_rw", reverse_endianness_list_ind_rw); +val () = computeLib.add_persistent_funs["reverse_endianness_list_rw"]; + + +(* val gst = Defn.tgoal_no_defn (index_list_def, index_list_ind) *) +val (index_list_rw, index_list_ind_rw) = + Defn.tprove_no_defn ((index_list_def, index_list_ind), + WF_REL_TAC`measure (λ(x,y,z). Num(1+(if z > 0 then int_max (-1) (y - x) else int_max (-1) (x - y))))` + \\ rw[integerTheory.INT_MAX] + \\ intLib.COOPER_TAC + ) +val index_list_rw = save_thm ("index_list_rw", index_list_rw); +val index_list_ind_rw = save_thm ("index_list_ind_rw", index_list_ind_rw); +val () = computeLib.add_persistent_funs["index_list_rw"]; + + +(* +(* val gst = Defn.tgoal_no_defn (while_def, while_ind) *) +val (while_rw, while_ind_rw) = + Defn.tprove_no_defn ((while_def, while_ind), + cheat (* the termination proof *) + ) +val while_rw = save_thm ("while_rw", while_rw); +val while_ind_rw = save_thm ("while_ind_rw", while_ind_rw); +*) + + +(* +(* val gst = Defn.tgoal_no_defn (until_def, until_ind) *) +val (until_rw, until_ind_rw) = + Defn.tprove_no_defn ((until_def, until_ind), + cheat (* the termination proof *) + ) +val until_rw = save_thm ("until_rw", until_rw); +val until_ind_rw = save_thm ("until_ind_rw", until_ind_rw); +*) + + +(****************************************************) +(* *) +(* Lemmata *) +(* *) +(****************************************************) + +val just_list_spec = store_thm("just_list_spec", +``((! xs. (just_list xs = NONE) <=> MEM NONE xs) /\ + (! xs es. (just_list xs = SOME es) <=> (xs = MAP SOME es)))``, + (* Theorem: just_list_spec*) + conj_tac + \\ ho_match_mp_tac just_list_ind + \\ Cases \\ rw[] + \\ rw[Once just_list_def] + >- ( CASE_TAC \\ fs[] \\ CASE_TAC ) + >- PROVE_TAC[] + \\ Cases_on`es` \\ fs[] + \\ CASE_TAC \\ fs[] + \\ CASE_TAC \\ fs[] +); + + + +val _ = export_theory() + diff --git a/lib/hol/stateAuxiliaryScript.sml b/lib/hol/stateAuxiliaryScript.sml new file mode 100644 index 00000000..c8269750 --- /dev/null +++ b/lib/hol/stateAuxiliaryScript.sml @@ -0,0 +1,61 @@ +(*Generated by Lem from ../../src/gen_lib/state.lem.*) +open HolKernel Parse boolLib bossLib; +open lem_pervasives_extraTheory sail_valuesTheory state_monadTheory stateTheory; + +val _ = numLib.prefer_num(); + + + +open lemLib; +(* val _ = lemLib.run_interactive := true; *) +val _ = new_theory "stateAuxiliary" + + +(****************************************************) +(* *) +(* Termination Proofs *) +(* *) +(****************************************************) + +(* val gst = Defn.tgoal_no_defn (iterS_aux_def, iterS_aux_ind) *) +val (iterS_aux_rw, iterS_aux_ind_rw) = + Defn.tprove_no_defn ((iterS_aux_def, iterS_aux_ind), + WF_REL_TAC`measure (LENGTH o SND o SND)` \\ rw[] + ) +val iterS_aux_rw = save_thm ("iterS_aux_rw", iterS_aux_rw); +val iterS_aux_ind_rw = save_thm ("iterS_aux_ind_rw", iterS_aux_ind_rw); + + +(* val gst = Defn.tgoal_no_defn (foreachS_def, foreachS_ind) *) +val (foreachS_rw, foreachS_ind_rw) = + Defn.tprove_no_defn ((foreachS_def, foreachS_ind), + WF_REL_TAC`measure (LENGTH o FST)` \\ rw[] + ) +val foreachS_rw = save_thm ("foreachS_rw", foreachS_rw); +val foreachS_ind_rw = save_thm ("foreachS_ind_rw", foreachS_ind_rw); + + +(* +These are unprovable. + +(* val gst = Defn.tgoal_no_defn (whileS_def, whileS_ind) *) +val (whileS_rw, whileS_ind_rw) = + Defn.tprove_no_defn ((whileS_def, whileS_ind), + cheat (* the termination proof *) + ) +val whileS_rw = save_thm ("whileS_rw", whileS_rw); +val whileS_ind_rw = save_thm ("whileS_ind_rw", whileS_ind_rw); + + +(* val gst = Defn.tgoal_no_defn (untilS_def, untilS_ind) *) +val (untilS_rw, untilS_ind_rw) = + Defn.tprove_no_defn ((untilS_def, untilS_ind), + cheat (* the termination proof *) + ) +val untilS_rw = save_thm ("untilS_rw", untilS_rw); +val untilS_ind_rw = save_thm ("untilS_ind_rw", untilS_ind_rw); +*) + + +val _ = export_theory() + diff --git a/lib/isabelle/Hoare.thy b/lib/isabelle/Hoare.thy index ee7a5fa6..9271e2fa 100644 --- a/lib/isabelle/Hoare.thy +++ b/lib/isabelle/Hoare.thy @@ -42,19 +42,20 @@ lemma PrePost_weaken_post: shows "PrePost A f C" using assms by (blast intro: PrePost_consequence) -named_theorems PrePost_intro +named_theorems PrePost_compositeI +named_theorems PrePost_atomI -lemma PrePost_True_post[PrePost_intro, intro, simp]: +lemma PrePost_True_post[PrePost_atomI, intro, simp]: "PrePost P m (\<lambda>_ _. True)" unfolding PrePost_def by auto lemma PrePost_any: "PrePost (\<lambda>s. \<forall>(r, s') \<in> m s. Q r s') m Q" unfolding PrePost_def by auto -lemma PrePost_returnS[intro, PrePost_intro]: "PrePost (P (Value x)) (returnS x) P" +lemma PrePost_returnS[intro, PrePost_atomI]: "PrePost (P (Value x)) (returnS x) P" unfolding PrePost_def returnS_def by auto -lemma PrePost_bindS[intro, PrePost_intro]: +lemma PrePost_bindS[intro, PrePost_compositeI]: assumes f: "\<And>s a s'. (Value a, s') \<in> m s \<Longrightarrow> PrePost (R a) (f a) Q" and m: "PrePost P m (\<lambda>r. case r of Value a \<Rightarrow> R a | Ex e \<Rightarrow> Q (Ex e))" shows "PrePost P (bindS m f) Q" @@ -89,10 +90,10 @@ lemma PrePost_bindS_unit: shows "PrePost P (bindS m f) Q" using assms by auto -lemma PrePost_readS[intro, PrePost_intro]: "PrePost (\<lambda>s. P (Value (f s)) s) (readS f) P" +lemma PrePost_readS[intro, PrePost_atomI]: "PrePost (\<lambda>s. P (Value (f s)) s) (readS f) P" unfolding PrePost_def readS_def returnS_def by auto -lemma PrePost_updateS[intro, PrePost_intro]: "PrePost (\<lambda>s. P (Value ()) (f s)) (updateS f) P" +lemma PrePost_updateS[intro, PrePost_atomI]: "PrePost (\<lambda>s. P (Value ()) (f s)) (updateS f) P" unfolding PrePost_def updateS_def returnS_def by auto lemma PrePost_if: @@ -100,7 +101,7 @@ lemma PrePost_if: shows "PrePost P (if b then f else g) Q" using assms by auto -lemma PrePost_if_branch[PrePost_intro]: +lemma PrePost_if_branch[PrePost_compositeI]: assumes "b \<Longrightarrow> PrePost Pf f Q" and "\<not>b \<Longrightarrow> PrePost Pg g Q" shows "PrePost (if b then Pf else Pg) (if b then f else g) Q" using assms by auto @@ -115,35 +116,65 @@ lemma PrePost_if_else: shows "PrePost P (if b then f else g) Q" using assms by auto -lemma PrePost_prod_cases[PrePost_intro]: +lemma PrePost_prod_cases[PrePost_compositeI]: assumes "PrePost P (f (fst x) (snd x)) Q" shows "PrePost P (case x of (a, b) \<Rightarrow> f a b) Q" using assms by (auto split: prod.splits) -lemma PrePost_option_cases[PrePost_intro]: +lemma PrePost_option_cases[PrePost_compositeI]: assumes "\<And>a. PrePost (PS a) (s a) Q" and "PrePost PN n Q" shows "PrePost (case x of Some a \<Rightarrow> PS a | None \<Rightarrow> PN) (case x of Some a \<Rightarrow> s a | None \<Rightarrow> n) Q" using assms by (auto split: option.splits) -lemma PrePost_let[intro, PrePost_intro]: +lemma PrePost_let[intro, PrePost_compositeI]: assumes "PrePost P (m y) Q" shows "PrePost P (let x = y in m x) Q" using assms by auto -lemma PrePost_assert_expS[intro, PrePost_intro]: "PrePost (if c then P (Value ()) else P (Ex (Failure m))) (assert_expS c m) P" +lemma PrePost_and_boolS[PrePost_compositeI]: + assumes r: "PrePost R r Q" + and l: "PrePost P l (\<lambda>r. case r of Value True \<Rightarrow> R | _ \<Rightarrow> Q r)" + shows "PrePost P (and_boolS l r) Q" + unfolding and_boolS_def +proof (rule PrePost_bindS) + fix s a s' + assume "(Value a, s') \<in> l s" + show "PrePost (if a then R else Q (Value False)) (if a then r else returnS False) Q" + using r by auto +next + show "PrePost P l (\<lambda>r. case r of Value a \<Rightarrow> if a then R else Q (Value False) | Ex e \<Rightarrow> Q (Ex e))" + using l by (elim PrePost_weaken_post) (auto split: result.splits) +qed + +lemma PrePost_or_boolS[PrePost_compositeI]: + assumes r: "PrePost R r Q" + and l: "PrePost P l (\<lambda>r. case r of Value False \<Rightarrow> R | _ \<Rightarrow> Q r)" + shows "PrePost P (or_boolS l r) Q" + unfolding or_boolS_def +proof (rule PrePost_bindS) + fix s a s' + assume "(Value a, s') \<in> l s" + show "PrePost (if a then Q (Value True) else R) (if a then returnS True else r) Q" + using r by auto +next + show "PrePost P l (\<lambda>r. case r of Value a \<Rightarrow> if a then Q (Value True) else R | Ex e \<Rightarrow> Q (Ex e))" + using l by (elim PrePost_weaken_post) (auto split: result.splits) +qed + +lemma PrePost_assert_expS[intro, PrePost_atomI]: "PrePost (if c then P (Value ()) else P (Ex (Failure m))) (assert_expS c m) P" unfolding PrePost_def assert_expS_def by (auto simp: returnS_def failS_def) -lemma PrePost_chooseS[intro, PrePost_intro]: "PrePost (\<lambda>s. \<forall>x \<in> xs. Q (Value x) s) (chooseS xs) Q" +lemma PrePost_chooseS[intro, PrePost_atomI]: "PrePost (\<lambda>s. \<forall>x \<in> xs. Q (Value x) s) (chooseS xs) Q" by (auto simp: PrePost_def chooseS_def) -lemma PrePost_failS[intro, PrePost_intro]: "PrePost (Q (Ex (Failure msg))) (failS msg) Q" +lemma PrePost_failS[intro, PrePost_atomI]: "PrePost (Q (Ex (Failure msg))) (failS msg) Q" by (auto simp: PrePost_def failS_def) lemma case_result_combine[simp]: "(case r of Value a \<Rightarrow> Q (Value a) | Ex e \<Rightarrow> Q (Ex e)) = Q r" by (auto split: result.splits) -lemma PrePost_foreachS_Nil[intro, simp, PrePost_intro]: +lemma PrePost_foreachS_Nil[intro, simp, PrePost_atomI]: "PrePost (Q (Value vars)) (foreachS [] vars body) Q" by auto @@ -219,20 +250,21 @@ lemma PrePostE_weaken_post: shows "PrePostE A f C E" using assms by (blast intro: PrePostE_consequence) -named_theorems PrePostE_intro +named_theorems PrePostE_compositeI +named_theorems PrePostE_atomI -lemma PrePostE_True_post[PrePost_intro, intro, simp]: +lemma PrePostE_True_post[PrePostE_atomI, intro, simp]: "PrePostE P m (\<lambda>_ _. True) (\<lambda>_ _. True)" unfolding PrePost_defs by (auto split: result.splits) lemma PrePostE_any: "PrePostE (\<lambda>s. \<forall>(r, s') \<in> m s. case r of Value a \<Rightarrow> Q a s' | Ex e \<Rightarrow> E e s') m Q E" by (intro PrePostE_I) auto -lemma PrePostE_returnS[PrePostE_intro, intro, simp]: +lemma PrePostE_returnS[PrePostE_atomI, intro, simp]: "PrePostE (P x) (returnS x) P Q" unfolding PrePostE_def by (auto intro: PrePost_strengthen_pre) -lemma PrePostE_bindS[intro, PrePostE_intro]: +lemma PrePostE_bindS[intro, PrePostE_compositeI]: assumes f: "\<And>s a s'. (Value a, s') \<in> m s \<Longrightarrow> PrePostE (R a) (f a) Q E" and m: "PrePostE P m R E" shows "PrePostE P (bindS m f) Q E" @@ -252,13 +284,13 @@ lemma PrePostE_bindS_unit: shows "PrePostE P (bindS m f) Q E" using assms by auto -lemma PrePostE_readS[PrePostE_intro, intro]: "PrePostE (\<lambda>s. Q (f s) s) (readS f) Q E" +lemma PrePostE_readS[PrePostE_atomI, intro]: "PrePostE (\<lambda>s. Q (f s) s) (readS f) Q E" unfolding PrePostE_def by (auto intro: PrePost_strengthen_pre) -lemma PrePostE_updateS[PrePostE_intro, intro]: "PrePostE (\<lambda>s. Q () (f s)) (updateS f) Q E" +lemma PrePostE_updateS[PrePostE_atomI, intro]: "PrePostE (\<lambda>s. Q () (f s)) (updateS f) Q E" unfolding PrePostE_def by (auto intro: PrePost_strengthen_pre) -lemma PrePostE_if_branch[PrePostE_intro]: +lemma PrePostE_if_branch[PrePostE_compositeI]: assumes "b \<Longrightarrow> PrePostE Pf f Q E" and "\<not>b \<Longrightarrow> PrePostE Pg g Q E" shows "PrePostE (if b then Pf else Pg) (if b then f else g) Q E" using assms by (auto) @@ -278,30 +310,44 @@ lemma PrePostE_if_else: shows "PrePostE P (if b then f else g) Q E" using assms by auto -lemma PrePostE_prod_cases[PrePostE_intro]: +lemma PrePostE_prod_cases[PrePostE_compositeI]: assumes "PrePostE P (f (fst x) (snd x)) Q E" shows "PrePostE P (case x of (a, b) \<Rightarrow> f a b) Q E" using assms by (auto split: prod.splits) -lemma PrePostE_option_cases[PrePostE_intro]: +lemma PrePostE_option_cases[PrePostE_compositeI]: assumes "\<And>a. PrePostE (PS a) (s a) Q E" and "PrePostE PN n Q E" shows "PrePostE (case x of Some a \<Rightarrow> PS a | None \<Rightarrow> PN) (case x of Some a \<Rightarrow> s a | None \<Rightarrow> n) Q E" using assms by (auto split: option.splits) -lemma PrePostE_let[PrePostE_intro]: +lemma PrePostE_let[PrePostE_compositeI]: assumes "PrePostE P (m y) Q E" shows "PrePostE P (let x = y in m x) Q E" using assms by auto -lemma PrePostE_assert_expS[PrePostE_intro, intro]: +lemma PrePostE_and_boolS[PrePostE_compositeI]: + assumes r: "PrePostE R r Q E" + and l: "PrePostE P l (\<lambda>r. if r then R else Q False) E" + shows "PrePostE P (and_boolS l r) Q E" + using assms unfolding PrePostE_def + by (intro PrePost_and_boolS) (auto elim: PrePost_weaken_post split: if_splits result.splits) + +lemma PrePostE_or_boolS[PrePostE_compositeI]: + assumes r: "PrePostE R r Q E" + and l: "PrePostE P l (\<lambda>r. if r then Q True else R) E" + shows "PrePostE P (or_boolS l r) Q E" + using assms unfolding PrePostE_def + by (intro PrePost_or_boolS) (auto elim: PrePost_weaken_post split: if_splits result.splits) + +lemma PrePostE_assert_expS[PrePostE_atomI, intro]: "PrePostE (if c then P () else Q (Failure m)) (assert_expS c m) P Q" unfolding PrePostE_def by (auto intro: PrePost_strengthen_pre) -lemma PrePostE_failS[PrePost_intro, intro]: +lemma PrePostE_failS[PrePostE_atomI, intro]: "PrePostE (E (Failure msg)) (failS msg) Q E" unfolding PrePostE_def by (auto intro: PrePost_strengthen_pre) -lemma PrePostE_chooseS[intro, PrePostE_intro]: +lemma PrePostE_chooseS[intro, PrePostE_atomI]: "PrePostE (\<lambda>s. \<forall>x \<in> xs. Q x s) (chooseS xs) Q E" unfolding PrePostE_def by (auto intro: PrePost_strengthen_pre) diff --git a/lib/isabelle/Makefile b/lib/isabelle/Makefile index f8786321..b10dde78 100644 --- a/lib/isabelle/Makefile +++ b/lib/isabelle/Makefile @@ -1,6 +1,6 @@ THYS = Sail_instr_kinds.thy Sail_values.thy Sail_operators.thy \ Sail_operators_mwords.thy Sail_operators_bitlists.thy \ - State_monad.thy State.thy Prompt_monad.thy Prompt.thy + State_monad.thy State.thy State_lifting.thy Prompt_monad.thy Prompt.thy EXTRA_THYS = State_monad_lemmas.thy State_lemmas.thy Prompt_monad_lemmas.thy \ Sail_operators_mwords_lemmas.thy Hoare.thy @@ -51,5 +51,8 @@ State_monad.thy: ../../src/gen_lib/state_monad.lem Sail_values.thy State.thy: ../../src/gen_lib/state.lem Prompt.thy State_monad.thy State_monad_lemmas.thy lem -isa -outdir . -auxiliary_level none -lib ../../src/lem_interp -lib ../../src/gen_lib $< +State_lifting.thy: ../../src/gen_lib/state_lifting.lem Prompt.thy State.thy + lem -isa -outdir . -auxiliary_level none -lib ../../src/lem_interp -lib ../../src/gen_lib $< + clean: -rm $(THYS) diff --git a/lib/isabelle/Prompt_monad_lemmas.thy b/lib/isabelle/Prompt_monad_lemmas.thy index e883c2a0..7a3a108d 100644 --- a/lib/isabelle/Prompt_monad_lemmas.thy +++ b/lib/isabelle/Prompt_monad_lemmas.thy @@ -17,6 +17,7 @@ lemmas bind_induct[case_names Done Read_mem Write_memv Read_reg Excl_res Write_e lemma bind_return[simp]: "bind (return a) f = f a" by (auto simp: return_def) +lemma bind_return_right[simp]: "bind x return = x" by (induction x) (auto simp: return_def) lemma bind_assoc[simp]: "bind (bind m f) g = bind m (\<lambda>x. bind (f x) g)" by (induction m f arbitrary: g rule: bind.induct) auto diff --git a/lib/isabelle/State_lemmas.thy b/lib/isabelle/State_lemmas.thy index 84b08e6c..cf5e4dbf 100644 --- a/lib/isabelle/State_lemmas.thy +++ b/lib/isabelle/State_lemmas.thy @@ -1,16 +1,18 @@ theory State_lemmas - imports State + imports State State_lifting begin lemma All_liftState_dom: "liftState_dom (r, m)" by (induction m) (auto intro: liftState.domintros) termination liftState using All_liftState_dom by auto -lemma liftState_bind[simp]: +named_theorems liftState_simp + +lemma liftState_bind[liftState_simp]: "liftState r (bind m f) = bindS (liftState r m) (liftState r \<circ> f)" by (induction m f rule: bind.induct) auto -lemma liftState_return[simp]: "liftState r (return a) = returnS a" by (auto simp: return_def) +lemma liftState_return[liftState_simp]: "liftState r (return a) = returnS a" by (auto simp: return_def) lemma Value_liftState_Run: assumes "(Value a, s') \<in> liftState r m s" @@ -19,45 +21,58 @@ lemma Value_liftState_Run: auto simp add: failS_def throwS_def returnS_def simp del: read_regvalS.simps; blast elim: Value_bindS_elim) -lemmas liftState_if_distrib[simp] = if_distrib[where f = "liftState ra" for ra] - -lemma liftState_throw[simp]: "liftState r (throw e) = throwS e" by (auto simp: throw_def) -lemma liftState_assert[simp]: "liftState r (assert_exp c msg) = assert_expS c msg" by (auto simp: assert_exp_def assert_expS_def) -lemma liftState_exit[simp]: "liftState r (exit0 ()) = exitS ()" by (auto simp: exit0_def exitS_def) -lemma liftState_exclResult[simp]: "liftState r (excl_result ()) = excl_resultS ()" by (auto simp: excl_result_def) -lemma liftState_barrier[simp]: "liftState r (barrier bk) = returnS ()" by (auto simp: barrier_def) -lemma liftState_footprint[simp]: "liftState r (footprint ()) = returnS ()" by (auto simp: footprint_def) -lemma liftState_undefined[simp]: "liftState r (undefined_bool ()) = undefined_boolS ()" by (auto simp: undefined_bool_def) -lemma liftState_maybe_fail[simp]: "liftState r (maybe_fail msg x) = maybe_failS msg x" - by (auto simp: maybe_fail_def maybe_failS_def split: option.splits) - -lemma liftState_try_catch[simp]: +lemmas liftState_if_distrib[liftState_simp] = if_distrib[where f = "liftState ra" for ra] + +lemma liftState_throw[liftState_simp]: "liftState r (throw e) = throwS e" + by (auto simp: throw_def) +lemma liftState_assert[liftState_simp]: "liftState r (assert_exp c msg) = assert_expS c msg" + by (auto simp: assert_exp_def assert_expS_def) +lemma liftState_exit[liftState_simp]: "liftState r (exit0 ()) = exitS ()" + by (auto simp: exit0_def exitS_def) +lemma liftState_exclResult[liftState_simp]: "liftState r (excl_result ()) = excl_resultS ()" + by (auto simp: excl_result_def liftState_simp) +lemma liftState_barrier[liftState_simp]: "liftState r (barrier bk) = returnS ()" + by (auto simp: barrier_def) +lemma liftState_footprint[liftState_simp]: "liftState r (footprint ()) = returnS ()" + by (auto simp: footprint_def) +lemma liftState_undefined[liftState_simp]: "liftState r (undefined_bool ()) = undefined_boolS ()" + by (auto simp: undefined_bool_def liftState_simp) +lemma liftState_maybe_fail[liftState_simp]: "liftState r (maybe_fail msg x) = maybe_failS msg x" + by (auto simp: maybe_fail_def maybe_failS_def liftState_simp split: option.splits) +lemma liftState_and_boolM[liftState_simp]: + "liftState r (and_boolM x y) = and_boolS (liftState r x) (liftState r y)" + by (auto simp: and_boolM_def and_boolS_def liftState_simp cong: bindS_cong if_cong) +lemma liftState_or_boolM[liftState_simp]: + "liftState r (or_boolM x y) = or_boolS (liftState r x) (liftState r y)" + by (auto simp: or_boolM_def or_boolS_def liftState_simp cong: bindS_cong if_cong) + +lemma liftState_try_catch[liftState_simp]: "liftState r (try_catch m h) = try_catchS (liftState r m) (liftState r \<circ> h)" by (induction m h rule: try_catch_induct) (auto simp: try_catchS_bindS_no_throw) -lemma liftState_early_return[simp]: +lemma liftState_early_return[liftState_simp]: "liftState r (early_return r) = early_returnS r" - by (auto simp: early_return_def early_returnS_def) + by (auto simp: early_return_def early_returnS_def liftState_simp) -lemma liftState_catch_early_return[simp]: +lemma liftState_catch_early_return[liftState_simp]: "liftState r (catch_early_return m) = catch_early_returnS (liftState r m)" - by (auto simp: catch_early_return_def catch_early_returnS_def sum.case_distrib cong: sum.case_cong) + by (auto simp: catch_early_return_def catch_early_returnS_def sum.case_distrib liftState_simp cong: sum.case_cong) -lemma liftState_liftR[simp]: - "liftState r (liftR m) = liftSR (liftState r m)" - by (auto simp: liftR_def liftSR_def) +lemma liftState_liftR[liftState_simp]: + "liftState r (liftR m) = liftRS (liftState r m)" + by (auto simp: liftR_def liftRS_def liftState_simp) -lemma liftState_try_catchR[simp]: - "liftState r (try_catchR m h) = try_catchSR (liftState r m) (liftState r \<circ> h)" - by (auto simp: try_catchR_def try_catchSR_def sum.case_distrib cong: sum.case_cong) +lemma liftState_try_catchR[liftState_simp]: + "liftState r (try_catchR m h) = try_catchRS (liftState r m) (liftState r \<circ> h)" + by (auto simp: try_catchR_def try_catchRS_def sum.case_distrib liftState_simp cong: sum.case_cong) lemma liftState_read_mem_BC: assumes "unsigned_method BC_bitU_list (bits_of_method BCa a) = unsigned_method BCa a" shows "liftState r (read_mem BCa BCb rk a sz) = read_memS BCa BCb rk a sz" using assms - by (auto simp: read_mem_def read_mem_bytes_def read_memS_def read_mem_bytesS_def maybe_failS_def split: option.splits) + by (auto simp: read_mem_def read_mem_bytes_def read_memS_def read_mem_bytesS_def maybe_failS_def liftState_simp split: option.splits) -lemma liftState_read_mem[simp]: +lemma liftState_read_mem[liftState_simp]: "\<And>a. liftState r (read_mem BC_mword BC_mword rk a sz) = read_memS BC_mword BC_mword rk a sz" "\<And>a. liftState r (read_mem BC_bitU_list BC_bitU_list rk a sz) = read_memS BC_bitU_list BC_bitU_list rk a sz" by (auto simp: liftState_read_mem_BC) @@ -67,14 +82,14 @@ lemma liftState_write_mem_ea_BC: shows "liftState r (write_mem_ea BCa rk a sz) = write_mem_eaS BCa rk a (nat sz)" using assms by (auto simp: write_mem_ea_def write_mem_eaS_def) -lemma liftState_write_mem_ea[simp]: +lemma liftState_write_mem_ea[liftState_simp]: "\<And>a. liftState r (write_mem_ea BC_mword rk a sz) = write_mem_eaS BC_mword rk a (nat sz)" "\<And>a. liftState r (write_mem_ea BC_bitU_list rk a sz) = write_mem_eaS BC_bitU_list rk a (nat sz)" by (auto simp: liftState_write_mem_ea_BC) lemma liftState_write_mem_val: "liftState r (write_mem_val BC v) = write_mem_valS BC v" - by (auto simp: write_mem_val_def write_mem_valS_def split: option.splits) + by (auto simp: write_mem_val_def write_mem_valS_def liftState_simp split: option.splits) lemma liftState_read_reg_readS: assumes "\<And>s. Option.bind (get_regval' (name reg) s) (of_regval reg) = Some (read_from reg s)" @@ -93,22 +108,23 @@ lemma liftState_write_reg_updateS: shows "liftState (get_regval', set_regval') (write_reg reg v) = updateS (regstate_update (write_to reg v))" using assms by (auto simp: write_reg_def updateS_def returnS_def bindS_readS) -lemma liftState_iter_aux[simp]: +lemma liftState_iter_aux[liftState_simp]: shows "liftState r (iter_aux i f xs) = iterS_aux i (\<lambda>i x. liftState r (f i x)) xs" - by (induction i "\<lambda>i x. liftState r (f i x)" xs rule: iterS_aux.induct) (auto cong: bindS_cong) + by (induction i "\<lambda>i x. liftState r (f i x)" xs rule: iterS_aux.induct) + (auto simp: liftState_simp cong: bindS_cong) -lemma liftState_iteri[simp]: +lemma liftState_iteri[liftState_simp]: "liftState r (iteri f xs) = iteriS (\<lambda>i x. liftState r (f i x)) xs" - by (auto simp: iteri_def iteriS_def) + by (auto simp: iteri_def iteriS_def liftState_simp) -lemma liftState_iter[simp]: +lemma liftState_iter[liftState_simp]: "liftState r (iter f xs) = iterS (liftState r \<circ> f) xs" - by (auto simp: iter_def iterS_def) + by (auto simp: iter_def iterS_def liftState_simp) -lemma liftState_foreachM[simp]: +lemma liftState_foreachM[liftState_simp]: "liftState r (foreachM xs vars body) = foreachS xs vars (\<lambda>x vars. liftState r (body x vars))" by (induction xs vars "\<lambda>x vars. liftState r (body x vars)" rule: foreachS.induct) - (auto cong: bindS_cong) + (auto simp: liftState_simp cong: bindS_cong) lemma whileS_dom_step: assumes "whileS_dom (vars, cond, body, s)" @@ -156,7 +172,7 @@ proof (use assms in \<open>induction vars "liftState r \<circ> cond" "liftState qed then show ?case using while while' that IH by auto qed auto - then show ?case by auto + then show ?case by (auto simp: liftState_simp) qed auto qed @@ -194,9 +210,51 @@ proof (use assms in \<open>induction vars "liftState r \<circ> cond" "liftState show "\<exists>t. Run (body vars) t vars'" using k by (auto elim: Value_liftState_Run) show "\<exists>t'. Run (cond vars') t' False" using until that by (auto elim: Value_liftState_Run) qed - then show ?case using k until IH by (auto simp: comp_def) + then show ?case using k until IH by (auto simp: comp_def liftState_simp) qed auto qed auto qed +(* Simplification rules for monadic Boolean connectives *) + +lemma if_return_return[simp]: "(if a then return True else return False) = return a" by auto + +lemma and_boolM_simps[simp]: + "and_boolM (return b) y = (if b then y else return False)" + "and_boolM x (return True) = x" + "and_boolM x (return False) = x \<bind> (\<lambda>_. return False)" + "\<And>x y z. and_boolM (x \<bind> y) z = (x \<bind> (\<lambda>r. and_boolM (y r) z))" + by (auto simp: and_boolM_def) + +lemmas and_boolM_if_distrib[simp] = if_distrib[where f = "\<lambda>x. and_boolM x y" for y] + +lemma or_boolM_simps[simp]: + "or_boolM (return b) y = (if b then return True else y)" + "or_boolM x (return True) = x \<bind> (\<lambda>_. return True)" + "or_boolM x (return False) = x" + "\<And>x y z. or_boolM (x \<bind> y) z = (x \<bind> (\<lambda>r. or_boolM (y r) z))" + by (auto simp: or_boolM_def) + +lemmas or_boolM_if_distrib[simp] = if_distrib[where f = "\<lambda>x. or_boolM x y" for y] + +lemma if_returnS_returnS[simp]: "(if a then returnS True else returnS False) = returnS a" by auto + +lemma and_boolS_simps[simp]: + "and_boolS (returnS b) y = (if b then y else returnS False)" + "and_boolS x (returnS True) = x" + "and_boolS x (returnS False) = bindS x (\<lambda>_. returnS False)" + "\<And>x y z. and_boolS (bindS x y) z = (bindS x (\<lambda>r. and_boolS (y r) z))" + by (auto simp: and_boolS_def) + +lemmas and_boolS_if_distrib[simp] = if_distrib[where f = "\<lambda>x. and_boolS x y" for y] + +lemma or_boolS_simps[simp]: + "or_boolS (returnS b) y = (if b then returnS True else y)" + "or_boolS x (returnS True) = bindS x (\<lambda>_. returnS True)" + "or_boolS x (returnS False) = x" + "\<And>x y z. or_boolS (bindS x y) z = (bindS x (\<lambda>r. or_boolS (y r) z))" + by (auto simp: or_boolS_def) + +lemmas or_boolS_if_distrib[simp] = if_distrib[where f = "\<lambda>x. or_boolS x y" for y] + end diff --git a/lib/isabelle/manual/Manual.thy b/lib/isabelle/manual/Manual.thy index 6cdfbfa1..0c83ebea 100644 --- a/lib/isabelle/manual/Manual.thy +++ b/lib/isabelle/manual/Manual.thy @@ -288,7 +288,7 @@ exception handler as arguments. The exception mechanism is also used to implement early returns by throwing and catching return values: A function body with one or more early returns of type @{typ 'a} (and exception type @{typ 'e}) is lifted to a monadic expression with exception type @{typ "('a + 'e)"} using -@{term liftSR}, such that an early return of the value @{term a} throws @{term "Inl a"}, and a +@{term liftRS}, such that an early return of the value @{term a} throws @{term "Inl a"}, and a regular exception @{term e} is thrown as @{term "Inr e"}. The function body is then wrapped in @{term catch_early_returnS} to lower it back to the default monad and exception type. These liftings and lowerings are automatically inserted by Sail for functions with early returns.\<^footnote>\<open>To be @@ -412,7 +412,9 @@ case, defined as @{thm [display, names_short] PrePostE_def} The theory includes standard proof rules for both of these variants, in particular rules giving weakest preconditions of the predefined primitives of the monad, collected under the names -@{attribute PrePost_intro} and @{attribute PrePostE_intro}, respectively. +@{attribute PrePost_atomI} for atoms such as @{term return} and @{attribute PrePost_compositeI} +for composites such as @{term bind} (or @{attribute PrePostE_atomI} and +@{attribute PrePostE_compositeI}, respectively, for the quadruple variant). The instruction we are considering is defined as @{thm [display] execute_ITYPE.simps[of _ rs for rs]} @@ -448,16 +450,17 @@ lemma shows "PrePostE pre (liftS instr) post (\<lambda>_ _. False)" unfolding pre_def instr_def post_def - by (simp add: rX_def wX_def cong: bindS_cong if_cong split del: if_split) - (rule PrePostE_strengthen_pre, (rule PrePostE_intro)+, auto simp: uint_0_iff) + by (simp add: rX_def wX_def liftState_simp cong: bindS_cong if_cong split del: if_split) + (rule PrePostE_strengthen_pre, (rule PrePostE_atomI PrePostE_compositeI)+, auto simp: uint_0_iff) text \<open>The proof begins with a simplification step, which not only unfolds the definitions of the auxiliary functions @{term rX} and @{term wX}, but also performs the lifting from the free monad to the state monad. We apply the rule @{thm [source] PrePostE_strengthen_pre} (in a -backward manner) to allow a weaker precondition, then use the rules in @{attribute PrePostE_intro} -to derive a weakest precondition, and then use @{method auto} to show that it is implied by -the given precondition. For more serious proofs, one will want to set up specialised proof -tactics. This example uses only basic proof methods, to make the reasoning steps more explicit.\<close> +backward manner) to allow a weaker precondition, then use the rules in +@{attribute PrePostE_compositeI} and @{attribute PrePostE_atomI} to derive a weakest precondition, +and then use @{method auto} to show that it is implied by the given precondition. For more serious +proofs, one will want to set up specialised proof tactics. This example uses only basic proof +methods, to make the reasoning steps more explicit.\<close> (*<*) end diff --git a/lib/main.ml b/lib/main.ml index 5733425f..e9dcb4e0 100644 --- a/lib/main.ml +++ b/lib/main.ml @@ -52,14 +52,37 @@ open Elf_loader;; let opt_file_arguments = ref ([] : string list) - -let options = Arg.align [] +let opt_raw_files = ref ([] : (string * Nat_big_num.num) list) +let options = Arg.align [ + ( "-raw", + Arg.String (fun s -> + let l = Util.split_on_char '@' s in + let (file, addr) = match l with + | [fname;addr] -> (fname, Nat_big_num.of_string addr) + | _ -> raise (Arg.Bad (s ^ " not of form <filename>@<addr>")) in + opt_raw_files := (file, addr) :: !opt_raw_files), + "<file@0xADDR> load a raw binary in memory at given address.")] let usage_msg = "Sail OCaml RTS options:" let () = Arg.parse options (fun s -> opt_file_arguments := !opt_file_arguments @ [s]) usage_msg +let rec load_raw_files = function + | (file, addr) :: files -> begin + let ic = open_in_bin file in + let addr' = ref addr in + try + while true do + let b = input_byte ic in + Sail_lib.wram !addr' b; + addr' := Nat_big_num.succ !addr'; + done + with End_of_file -> (); + load_raw_files files + end + | [] -> () + let () = Random.self_init (); begin @@ -67,4 +90,5 @@ let () = | f :: _ -> load_elf f | _ -> () end; + load_raw_files !opt_raw_files; (* ocaml_backend.ml will append from here *) diff --git a/lib/mono_rewrites.sail b/lib/mono_rewrites.sail new file mode 100644 index 00000000..c9164b6c --- /dev/null +++ b/lib/mono_rewrites.sail @@ -0,0 +1,157 @@ +/* Definitions for use with the -mono_rewrites option */ + +/* External definitions not in the usual asl prelude */ + +infix 6 << + +val shiftleft = "shiftl" : forall 'n ('ord : Order). + (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure + +overload operator << = {shiftleft} + +infix 6 >> + +val shiftright = "shiftr" : forall 'n ('ord : Order). + (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure + +overload operator >> = {shiftright} + +val arith_shiftright = "arith_shiftr" : forall 'n ('ord : Order). + (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure + +val "extz_vec" : forall 'n 'm. (atom('m),vector('n, dec, bit)) -> vector('m, dec, bit) effect pure +val extzv : forall 'n 'm. vector('n, dec, bit) -> vector('m, dec, bit) effect pure +function extzv(v) = extz_vec(sizeof('m),v) + +val "exts_vec" : forall 'n 'm. (atom('m),vector('n, dec, bit)) -> vector('m, dec, bit) effect pure +val extsv : forall 'n 'm. vector('n, dec, bit) -> vector('m, dec, bit) effect pure +function extsv(v) = exts_vec(sizeof('m),v) + +/* This is generated internally to deal with case splits which reveal the size + of a bitvector */ +val bitvector_cast = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect pure + +/* Definitions for the rewrites */ + +val slice_mask : forall 'n, 'n >= 0. (int, int) -> bits('n) effect pure +function slice_mask(i,l) = + let one : bits('n) = extzv(0b1) in + ((one << l) - one) << i + +val is_zero_subrange : forall 'n, 'n >= 0. + (bits('n), int, int) -> bool effect pure + +function is_zero_subrange (xs, i, j) = { + (xs & slice_mask(j, i-j)) == extzv(0b0) +} + +val is_ones_subrange : forall 'n, 'n >= 0. + (bits('n), int, int) -> bool effect pure + +function is_ones_subrange (xs, i, j) = { + let m : bits('n) = slice_mask(j,j-i) in + (xs & m) == m +} + +val slice_slice_concat : forall 'n 'm 'r, 'n >= 0 & 'm >= 0 & 'r >= 0. + (bits('n), int, int, bits('m), int, int) -> bits('r) effect pure + +function slice_slice_concat (xs, i, l, ys, i', l') = { + let xs = (xs & slice_mask(i,l)) >> i in + let ys = (ys & slice_mask(i',l')) >> i' in + extzv(xs) << l' | extzv(ys) +} + +val slice_zeros_concat : forall 'n 'p 'q 'r, 'r = 'p + 'q & 'n >= 0 & 'p >= 0 & 'q >= 0 & 'r >= 0. + (bits('n), int, atom('p), atom('q)) -> bits('r) effect pure + +function slice_zeros_concat (xs, i, l, l') = { + let xs = (xs & slice_mask(i,l)) >> i in + extzv(xs) << l' +} + +/* Assumes initial vectors are of equal size */ + +val subrange_subrange_eq : forall 'n, 'n >= 0. + (bits('n), int, int, bits('n), int, int) -> bool effect pure + +function subrange_subrange_eq (xs, i, j, ys, i', j') = { + let xs = (xs & slice_mask(j,i-j)) >> j in + let ys = (ys & slice_mask(j',i'-j')) >> j' in + xs == ys +} + +val subrange_subrange_concat : forall 'n 'o 'p 'm 'q 'r 's, 's = 'o - ('p - 1) + 'q - ('r - 1) & 'n >= 0 & 'm >= 0. + (bits('n), atom('o), atom('p), bits('m), atom('q), atom('r)) -> bits('s) effect pure + +function subrange_subrange_concat (xs, i, j, ys, i', j') = { + let xs = (xs & slice_mask(j,i-j)) >> j in + let ys = (ys & slice_mask(j',i'-j')) >> j' in + extzv(xs) << i' - (j' - 1) | extzv(ys) +} + +val place_subrange : forall 'n 'm, 'n >= 0 & 'm >= 0. + (bits('n), int, int, int) -> bits('m) effect pure + +function place_subrange(xs,i,j,shift) = { + let xs = (xs & slice_mask(j,i-j)) >> j in + extzv(xs) << shift +} + +val place_slice : forall 'n 'm, 'n >= 0 & 'm >= 0. + (bits('n), int, int, int) -> bits('m) effect pure + +function place_slice(xs,i,l,shift) = { + let xs = (xs & slice_mask(i,l)) >> i in + extzv(xs) << shift +} + +val zext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0. + (bits('n), int, int) -> bits('m) effect pure + +function zext_slice(xs,i,l) = { + let xs = (xs & slice_mask(i,l)) >> i in + extzv(xs) +} + +val sext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0. + (bits('n), int, int) -> bits('m) effect pure + +function sext_slice(xs,i,l) = { + let xs = arith_shiftright(((xs & slice_mask(i,l)) << ('n - i - l)), 'n - l) in + extsv(xs) +} + +/* This has different names in the aarch64 prelude (UInt) and the other + preludes (unsigned). To avoid variable name clashes, we redeclare it + here with a suitably awkward name. */ +val _builtin_unsigned = { + ocaml: "uint", + lem: "uint", + interpreter: "uint", + c: "sail_uint" +} : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1) + +val unsigned_slice : forall 'n, 'n >= 0. + (bits('n), int, int) -> int effect pure + +function unsigned_slice(xs,i,l) = { + let xs = (xs & slice_mask(i,l)) >> i in + _builtin_unsigned(xs) +} + +val unsigned_subrange : forall 'n, 'n >= 0. + (bits('n), int, int) -> int effect pure + +function unsigned_subrange(xs,i,j) = { + let xs = (xs & slice_mask(j,i-j)) >> i in + _builtin_unsigned(xs) +} + + +val zext_ones : forall 'n, 'n >= 0. int -> bits('n) effect pure + +function zext_ones(m) = { + let v : bits('n) = extsv(0b1) in + v >> ('n - m) +} @@ -7,6 +7,7 @@ #include<stdbool.h> #include<string.h> #include<gmp.h> +#include<time.h> typedef int unit; @@ -16,6 +17,10 @@ unit undefined_unit(const unit u) { return UNIT; } +bool eq_unit(const unit u1, const unit u2) { + return true; +} + typedef struct { mp_bitcnt_t len; mpz_t *bits; @@ -23,32 +28,60 @@ typedef struct { typedef char *sail_string; -// This function should be called whenever a pattern match failure -// occurs. Pattern match failures are always fatal. +/* Temporary mpzs for use in functions below. To avoid conflicts, only + * use in functions that do not call other functions in this file. */ +static mpz_t sail_lib_tmp1, sail_lib_tmp2; + +/* Wrapper around >> operator to avoid UB when shift amount is greater + than or equal to 64. */ +uint64_t safe_rshift(const uint64_t x, const uint64_t n) { + if (n >= 64) { + return 0ul; + } else { + return x >> n; + } +} + +/* This function should be called whenever a pattern match failure + occurs. Pattern match failures are always fatal. */ void sail_match_failure(sail_string msg) { fprintf(stderr, "Pattern match failure in %s\n", msg); exit(EXIT_FAILURE); } +/* sail_assert implements the assert construct in Sail. If any + assertion fails we immediately exit the model. */ unit sail_assert(bool b, sail_string msg) { if (b) return UNIT; fprintf(stderr, "Assertion failed: %s\n", msg); exit(EXIT_FAILURE); } +/* If the sail model calls the exit() function we print a message and + exit successfully. */ unit sail_exit(const unit u) { - fprintf(stderr, "exit\n"); + fprintf(stderr, "Sail model exit\n"); exit(EXIT_SUCCESS); } +uint64_t g_elf_entry; + void elf_entry(mpz_t *rop, const unit u) { - mpz_set_ui(*rop, 0x400130ul); + mpz_set_ui(*rop, g_elf_entry); } void elf_tohost(mpz_t *rop, const unit u) { mpz_set_ui(*rop, 0x0ul); } +/* ASL->Sail model has an EnterLowPowerState() function that calls a + sleep request builtin. If it gets called we print a message and + exit the model. */ +unit sleep_request(const unit u) { + fprintf(stderr, "Sail model going to sleep\n"); + exit(EXIT_SUCCESS); +} + // Sail bits are mapped to uint64_t where bitzero = 0ul and bitone = 1ul bool eq_bit(const uint64_t a, const uint64_t b) { return a == b; @@ -56,6 +89,10 @@ bool eq_bit(const uint64_t a, const uint64_t b) { uint64_t undefined_bit(unit u) { return 0; } +unit skip(const unit u) { + return UNIT; +} + // ***** Sail booleans ***** bool not(const bool b) { @@ -98,6 +135,16 @@ void set_sail_string(sail_string *str1, const sail_string str2) { *str1 = strcpy(*str1, str2); } +void dec_str(sail_string *str, const mpz_t n) { + free(*str); + gmp_asprintf(str, "%Zd", n); +} + +void hex_str(sail_string *str, const mpz_t n) { + free(*str); + gmp_asprintf(str, "0x%Zx", n); +} + void clear_sail_string(sail_string *str) { free(*str); } @@ -106,6 +153,16 @@ bool eq_string(const sail_string str1, const sail_string str2) { return strcmp(str1, str2) == 0; } +void concat_str(sail_string *stro, const sail_string str1, const sail_string str2) { + *stro = realloc(*stro, strlen(str1) + strlen(str2) + 1); + (*stro)[0] = '\0'; + strcat(*stro, str1); + strcat(*stro, str2); +} + +void undefined_string(sail_string *str, const unit u) { +} + unit print_endline(const sail_string str) { printf("%s\n", str); return UNIT; @@ -136,6 +193,7 @@ unit print_int64(const sail_string str, const int64_t op) { unit sail_putchar(const mpz_t op) { char c = (char) mpz_get_ui(op); putchar(c); + return UNIT; } // ***** Arbitrary precision integers ***** @@ -209,6 +267,10 @@ void shl_int(mpz_t *rop, const mpz_t op1, const mpz_t op2) { mpz_mul_2exp(*rop, op1, mpz_get_ui(op2)); } +void shr_int(mpz_t *rop, const mpz_t op1, const mpz_t op2) { + mpz_fdiv_q_2exp(*rop, op1, mpz_get_ui(op2)); +} + void undefined_int(mpz_t *rop, const unit u) { mpz_set_ui(*rop, 0ul); } @@ -274,14 +336,34 @@ void pow2(mpz_t *rop, mpz_t exp) { mpz_clear(base); } +void get_time_ns(mpz_t *rop, const unit u) { + struct timespec t; + clock_gettime(CLOCK_REALTIME, &t); + mpz_set_si(*rop, t.tv_sec); + mpz_mul_ui(*rop, *rop, 1000000000); + mpz_add_ui(*rop, *rop, t.tv_nsec); +} + // ***** Sail bitvectors ***** -unit print_bits(const sail_string str, const bv_t op) +void string_of_int(sail_string *str, mpz_t i) { + gmp_asprintf(str, "%Zd", i); +} + +void string_of_bits(sail_string *str, const bv_t op) { + if ((op.len % 4) == 0) { + gmp_asprintf(str, "0x%*0Zx", op.len / 4, *op.bits); + } else { + gmp_asprintf(str, "0b%*0Zb", op.len, *op.bits); + } +} + +unit fprint_bits(const sail_string str, const bv_t op, FILE *stream) { - fputs(str, stdout); + fputs(str, stream); if (op.len % 4 == 0) { - fputs("0x", stdout); + fputs("0x", stream); mpz_t buf; mpz_init_set(buf, *op.bits); @@ -294,19 +376,30 @@ unit print_bits(const sail_string str, const bv_t op) } for (int i = op.len / 4; i > 0; --i) { - fputc(hex[i - 1], stdout); + fputc(hex[i - 1], stream); } free(hex); mpz_clear(buf); } else { - fputs("0b", stdout); + fputs("0b", stream); for (int i = op.len; i > 0; --i) { - fputc(mpz_tstbit(*op.bits, i - 1) + 0x30, stdout); + fputc(mpz_tstbit(*op.bits, i - 1) + 0x30, stream); } } - fputs("\n", stdout); + fputs("\n", stream); + return UNIT; +} + +unit print_bits(const sail_string str, const bv_t op) +{ + return fprint_bits(str, op, stdout); +} + +unit prerr_bits(const sail_string str, const bv_t op) +{ + return fprint_bits(str, op, stderr); } void length_bv_t(mpz_t *rop, const bv_t op) { @@ -324,6 +417,14 @@ void reinit_bv_t(bv_t *rop) { mpz_set_ui(*rop->bits, 0); } +void normalise_bv_t(bv_t *rop) { + /* TODO optimisation: keep a set of masks of various sizes handy */ + mpz_set_ui(sail_lib_tmp1, 1); + mpz_mul_2exp(sail_lib_tmp1, sail_lib_tmp1, rop->len); + mpz_sub_ui(sail_lib_tmp1, sail_lib_tmp1, 1); + mpz_and(*rop->bits, *rop->bits, sail_lib_tmp1); +} + void init_bv_t_of_uint64_t(bv_t *rop, const uint64_t op, const uint64_t len, const bool direction) { rop->bits = malloc(sizeof(mpz_t)); rop->len = len; @@ -349,7 +450,7 @@ void append_64(bv_t *rop, const bv_t op, const uint64_t chunk) { void append(bv_t *rop, const bv_t op1, const bv_t op2) { rop->len = op1.len + op2.len; mpz_mul_2exp(*rop->bits, *op1.bits, op2.len); - mpz_add(*rop->bits, *rop->bits, *op2.bits); + mpz_ior(*rop->bits, *rop->bits, *op2.bits); } void replicate_bits(bv_t *rop, const bv_t op1, const mpz_t op2) { @@ -363,9 +464,9 @@ void replicate_bits(bv_t *rop, const bv_t op1, const mpz_t op2) { } uint64_t fast_replicate_bits(const uint64_t shift, const uint64_t v, const int64_t times) { - uint64_t r = 0; - for (int i = 0; i < times; ++i) { - r |= v << shift; + uint64_t r = v; + for (int i = 1; i < times; ++i) { + r |= (r << shift); } return r; } @@ -397,10 +498,16 @@ void zero_extend(bv_t *rop, const bv_t op, const mpz_t len) { mpz_set(*rop->bits, *op.bits); } -/* FIXME */ void sign_extend(bv_t *rop, const bv_t op, const mpz_t len) { rop->len = mpz_get_ui(len); - mpz_set(*rop->bits, *op.bits); + if(mpz_tstbit(*op.bits, op.len - 1)) { + mpz_set(*rop->bits, *op.bits); + for(mp_bitcnt_t i = rop->len - 1; i >= op.len; i--) { + mpz_setbit(*rop->bits, i); + } + } else { + mpz_set(*rop->bits, *op.bits); + } } void clear_bv_t(bv_t *rop) { @@ -441,7 +548,10 @@ void or_bits(bv_t *rop, const bv_t op1, const bv_t op2) { void not_bits(bv_t *rop, const bv_t op) { rop->len = op.len; - mpz_com(*rop->bits, *op.bits); + mpz_set(*rop->bits, *op.bits); + for (mp_bitcnt_t i = 0; i < op.len; i++) { + mpz_combit(*rop->bits, i); + } } void xor_bits(bv_t *rop, const bv_t op1, const bv_t op2) { @@ -449,8 +559,6 @@ void xor_bits(bv_t *rop, const bv_t op1, const bv_t op2) { mpz_xor(*rop->bits, *op1.bits, *op2.bits); } -mpz_t eq_bits_test; - bool eq_bits(const bv_t op1, const bv_t op2) { for (mp_bitcnt_t i = 0; i < op1.len; i++) { @@ -459,39 +567,25 @@ bool eq_bits(const bv_t op1, const bv_t op2) return true; } -// These aren't very efficient, but they work. Question is how best to -// do these given GMP uses a sign bit representation? void sail_uint(mpz_t *rop, const bv_t op) { - mpz_set_ui(*rop, 0ul); - for (mp_bitcnt_t i = 0; i < op.len; ++i) { - if (mpz_tstbit(*op.bits, i)) { - mpz_setbit(*rop, i); - } else { - mpz_clrbit(*rop, i); - } - } + /* Normal form of bv_t is always positive so just return the bits. */ + mpz_set(*rop, *op.bits); } void sint(mpz_t *rop, const bv_t op) { - mpz_set_ui(*rop, 0ul); - if (mpz_tstbit(*op.bits, op.len - 1)) { - for (mp_bitcnt_t i = 0; i < op.len; ++i) { - if (mpz_tstbit(*op.bits, i)) { - mpz_clrbit(*rop, i); - } else { - mpz_setbit(*rop, i); - } - }; - mpz_add_ui(*rop, *rop, 1ul); - mpz_neg(*rop, *rop); + if (op.len == 0) { + mpz_set_ui(*rop, 0); } else { - for (mp_bitcnt_t i = 0; i < op.len; ++i) { - if (mpz_tstbit(*op.bits, i)) { - mpz_setbit(*rop, i); - } else { - mpz_clrbit(*rop, i); - } + mp_bitcnt_t sign_bit = op.len - 1; + mpz_set(*rop, *op.bits); + if (mpz_tstbit(*op.bits, sign_bit) != 0) { + /* If sign bit is unset then we are done, + otherwise clear sign_bit and subtract 2**sign_bit */ + mpz_set_ui(sail_lib_tmp1, 1); + mpz_mul_2exp(sail_lib_tmp1, sail_lib_tmp1, sign_bit); /* 2**sign_bit */ + mpz_combit(*rop, sign_bit); /* clear sign_bit */ + mpz_sub(*rop, *rop, sail_lib_tmp1); } } } @@ -499,19 +593,103 @@ void sint(mpz_t *rop, const bv_t op) void add_bits(bv_t *rop, const bv_t op1, const bv_t op2) { rop->len = op1.len; mpz_add(*rop->bits, *op1.bits, *op2.bits); + normalise_bv_t(rop); +} + +void sub_bits(bv_t *rop, const bv_t op1, const bv_t op2) { + rop->len = op1.len; + mpz_sub(*rop->bits, *op1.bits, *op2.bits); + normalise_bv_t(rop); } void add_bits_int(bv_t *rop, const bv_t op1, const mpz_t op2) { rop->len = op1.len; mpz_add(*rop->bits, *op1.bits, op2); + normalise_bv_t(rop); } void sub_bits_int(bv_t *rop, const bv_t op1, const mpz_t op2) { - printf("sub_bits_int\n"); rop->len = op1.len; mpz_sub(*rop->bits, *op1.bits, op2); } +void mults_vec(bv_t *rop, const bv_t op1, const bv_t op2) { + mpz_t op1_int, op2_int; + mpz_init(op1_int); + mpz_init(op2_int); + sint(&op1_int, op1); + sint(&op2_int, op2); + rop->len = op1.len * 2; + mpz_mul(*rop->bits, op1_int, op2_int); + normalise_bv_t(rop); + mpz_clear(op1_int); + mpz_clear(op2_int); +} + +void mult_vec(bv_t *rop, const bv_t op1, const bv_t op2) { + rop->len = op1.len * 2; + mpz_mul(*rop->bits, *op1.bits, *op2.bits); + normalise_bv_t(rop); /* necessary? */ +} + +void shift_bits_left(bv_t *rop, const bv_t op1, const bv_t op2) { + rop->len = op1.len; + mpz_mul_2exp(*rop->bits, *op1.bits, mpz_get_ui(*op2.bits)); + normalise_bv_t(rop); +} + +void shift_bits_right(bv_t *rop, const bv_t op1, const bv_t op2) { + rop->len = op1.len; + mpz_tdiv_q_2exp(*rop->bits, *op1.bits, mpz_get_ui(*op2.bits)); +} + +/* FIXME */ +void shift_bits_right_arith(bv_t *rop, const bv_t op1, const bv_t op2) { + rop->len = op1.len; + mp_bitcnt_t shift_amt = mpz_get_ui(*op2.bits); + mp_bitcnt_t sign_bit = op1.len - 1; + mpz_fdiv_q_2exp(*rop->bits, *op1.bits, shift_amt); + if(mpz_tstbit(*op1.bits, sign_bit) != 0) { + /* */ + for(; shift_amt > 0; shift_amt--) { + mpz_setbit(*rop->bits, sign_bit - shift_amt + 1); + } + } +} + +void reverse_endianness(bv_t *rop, const bv_t op) { + rop->len = op.len; + if (rop->len == 64ul) { + uint64_t x = mpz_get_ui(*op.bits); + x = (x & 0xFFFFFFFF00000000) >> 32 | (x & 0x00000000FFFFFFFF) << 32; + x = (x & 0xFFFF0000FFFF0000) >> 16 | (x & 0x0000FFFF0000FFFF) << 16; + x = (x & 0xFF00FF00FF00FF00) >> 8 | (x & 0x00FF00FF00FF00FF) << 8; + mpz_set_ui(*rop->bits, x); + } else if (rop->len == 32ul) { + uint64_t x = mpz_get_ui(*op.bits); + x = (x & 0xFFFF0000FFFF0000) >> 16 | (x & 0x0000FFFF0000FFFF) << 16; + x = (x & 0xFF00FF00FF00FF00) >> 8 | (x & 0x00FF00FF00FF00FF) << 8; + mpz_set_ui(*rop->bits, x); + } else if (rop->len == 16ul) { + uint64_t x = mpz_get_ui(*op.bits); + x = (x & 0xFF00FF00FF00FF00) >> 8 | (x & 0x00FF00FF00FF00FF) << 8; + mpz_set_ui(*rop->bits, x); + } else if (rop->len == 8ul) { + mpz_set(*rop->bits, *op.bits); + } else { + /* For other numbers of bytes we reverse the bytes. + * XXX could use mpz_import/export for this. */ + mpz_set_ui(sail_lib_tmp1, 0xff); // byte mask + mpz_set_ui(*rop->bits, 0); // reset accumulator for result + for(mp_bitcnt_t byte = 0; byte < op.len; byte+=8) { + mpz_tdiv_q_2exp(sail_lib_tmp2, *op.bits, byte); // shift byte to bottom + mpz_and(sail_lib_tmp2, sail_lib_tmp2, sail_lib_tmp1); // and with mask + mpz_mul_2exp(*rop->bits, *rop->bits, 8); // shift result left 8 + mpz_ior(*rop->bits, *rop->bits, sail_lib_tmp2); // or byte into result + } + } +} + // Takes a slice of the (two's complement) binary representation of // integer n, starting at bit start, and of length len. With the // argument in the following order: @@ -767,26 +945,18 @@ uint64_t MASK = 0xFFFFul; // are used in the second argument. void write_mem(uint64_t address, uint64_t byte) { + //printf("ADDR: %lu, BYTE: %lu\n", address, byte); + uint64_t mask = address & ~MASK; uint64_t offset = address & MASK; - struct block *prev = NULL; struct block *current = sail_memory; while (current != NULL) { if (current->block_id == mask) { current->mem[offset] = (uint8_t) byte; - - /* Move the accessed block to the front of the block list */ - if (prev != NULL) { - prev->next = current->next; - } - current->next = sail_memory->next; - sail_memory = current; - return; } else { - prev = current; current = current->next; } } @@ -886,6 +1056,18 @@ void read_ram(bv_t *data, mpz_clear(byte); } +unit load_raw(uint64_t addr, const sail_string file) { + FILE *fp = fopen(file, "r"); + + uint64_t byte; + while ((byte = (uint64_t)fgetc(fp)) != EOF) { + write_mem(addr, byte); + addr++; + } + + return UNIT; +} + void load_image(char *file) { FILE *fp = fopen(file, "r"); @@ -904,7 +1086,15 @@ void load_image(char *file) { ssize_t data_len = getline(&data, &len, fp); if (data_len == -1) break; - write_mem((uint64_t) atoll(addr), (uint64_t) atoll(data)); + if (!strcmp(addr, "elf_entry\n")) { + if (sscanf(data, "%" PRIu64 "\n", &g_elf_entry) != 1) { + fprintf(stderr, "Failed to parse elf_entry\n"); + exit(EXIT_FAILURE); + }; + printf("Elf entry point: %" PRIx64 "\n", g_elf_entry); + } else { + write_mem((uint64_t) atoll(addr), (uint64_t) atoll(data)); + } } free(addr); @@ -923,10 +1113,12 @@ void load_instr(uint64_t addr, uint32_t instr) { void setup_library(void) { mpf_set_default_prec(FLOAT_PRECISION); - mpz_init(eq_bits_test); + mpz_init(sail_lib_tmp1); + mpz_init(sail_lib_tmp2); } void cleanup_library(void) { - mpz_clear(eq_bits_test); + mpz_clear(sail_lib_tmp1); + mpz_clear(sail_lib_tmp2); kill_mem(); } diff --git a/lib/vector_dec.sail b/lib/vector_dec.sail index 17603e03..d9b80b32 100644 --- a/lib/vector_dec.sail +++ b/lib/vector_dec.sail @@ -10,7 +10,8 @@ val "eq_bit" : (bit, bit) -> bool val eq_bits = { ocaml: "eq_list", lem: "eq_vec", - c: "eq_bits" + c: "eq_bits", + coq: "eq_vec" } : forall 'n. (vector('n, dec, bit), vector('n, dec, bit)) -> bool overload operator == = {eq_bit, eq_bits} @@ -20,33 +21,36 @@ val bitvector_length = {coq: "length_mword", _:"length"} : forall 'n. bits('n) - val vector_length = { ocaml: "length", lem: "length_list", - c: "length" + c: "length", + coq: "length_list" } : forall 'n ('a : Type). vector('n, dec, 'a) -> atom('n) overload length = {bitvector_length, vector_length} -val "zeros" : forall 'n. atom('n) -> bits('n) +val sail_zeros = "zeros" : forall 'n. atom('n) -> bits('n) val "print_bits" : forall 'n. (string, bits('n)) -> unit -val "sign_extend" : forall 'n 'm, 'm >= 'n. (bits('n), atom('m)) -> bits('m) +val "prerr_bits" : forall 'n. (string, bits('n)) -> unit -val "zero_extend" : forall 'n 'm, 'm >= 'n. (bits('n), atom('m)) -> bits('m) +val sail_sign_extend = "sign_extend" : forall 'n 'm, 'm >= 'n. (bits('n), atom('m)) -> bits('m) + +val sail_zero_extend = "zero_extend" : forall 'n 'm, 'm >= 'n. (bits('n), atom('m)) -> bits('m) val truncate = { ocaml: "vector_truncate", lem: "vector_truncate", coq: "vector_truncate", c: "truncate" -} : forall 'm 'n, 'm <= 'n. (vector('n, dec, bit), atom('m)) -> vector('m, dec, bit) +} : forall 'm 'n, 'm >= 0 & 'm <= 'n. (vector('n, dec, bit), atom('m)) -> vector('m, dec, bit) -val mask : forall 'len 'v, 'v >= 0. (atom('len), vector('v, dec, bit)) -> vector('len, dec, bit) +val sail_mask : forall 'len 'v, 'len >= 0 & 'v >= 0. (atom('len), vector('v, dec, bit)) -> vector('len, dec, bit) -function mask(len, v) = if len <= length(v) then truncate(v, len) else zero_extend(v, len) +function sail_mask(len, v) = if len <= length(v) then truncate(v, len) else sail_zero_extend(v, len) -overload operator ^ = {mask} +overload operator ^ = {sail_mask} -val bitvector_concat = {ocaml: "append", lem: "concat_vec", c: "append"} : forall ('n : Int) ('m : Int). +val bitvector_concat = {ocaml: "append", lem: "concat_vec", c: "append", coq: "concat_vec"} : forall ('n : Int) ('m : Int). (bits('n), bits('m)) -> bits('n + 'm) overload append = {bitvector_concat} @@ -57,25 +61,29 @@ val "append_64" : forall 'n. (bits('n), bits(64)) -> bits('n + 64) val vector_access = { ocaml: "access", lem: "access_list_dec", + coq: "access_list_dec", c: "vector_access" } : forall ('n : Int) ('m : Int) ('a : Type), 0 <= 'm < 'n. (vector('n, dec, 'a), atom('m)) -> 'a val vector_update = { ocaml: "update", lem: "update_list_dec", + coq: "update_list_dec", c: "vector_update" } : forall 'n ('a : Type). (vector('n, dec, 'a), int, 'a) -> vector('n, dec, 'a) val add_bits = { ocaml: "add_vec", lem: "add_vec", - c: "add_bits" + c: "add_bits", + coq: "add_vec" } : forall 'n. (bits('n), bits('n)) -> bits('n) val add_bits_int = { ocaml: "add_vec_int", lem: "add_vec_int", - c: "add_bits_int" + c: "add_bits_int", + coq: "add_vec_int" } : forall 'n. (bits('n), int) -> bits('n) overload operator + = {add_bits, add_bits_int} @@ -83,14 +91,16 @@ overload operator + = {add_bits, add_bits_int} val vector_subrange = { ocaml: "subrange", lem: "subrange_vec_dec", - c: "vector_subrange" + c: "vector_subrange", + coq: "subrange_vec_dec" } : forall ('n : Int) ('m : Int) ('o : Int), 'o <= 'm <= 'n. (bits('n), atom('m), atom('o)) -> bits('m - ('o - 1)) val vector_update_subrange = { ocaml: "update_subrange", lem: "update_subrange_vec_dec", - c: "vector_update_subrange" + c: "vector_update_subrange", + coq: "update_subrange_vec_dec" } : forall 'n 'm 'o. (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n) // Some ARM specific builtins @@ -111,7 +121,8 @@ val unsigned = { ocaml: "uint", lem: "uint", interpreter: "uint", - c: "sail_uint" + c: "sail_uint", + coq: "unsigned" } : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1) val signed = "sint" : forall 'n. bits('n) -> range(- (2 ^ ('n - 1)), 2 ^ ('n - 1) - 1) diff --git a/lib/vector_inc.sail b/lib/vector_inc.sail index 04c95996..b13e053c 100644 --- a/lib/vector_inc.sail +++ b/lib/vector_inc.sail @@ -10,17 +10,19 @@ val "eq_bit" : (bit, bit) -> bool val eq_bits = { ocaml: "eq_list", lem: "eq_vec", - c: "eq_bits" + c: "eq_bits", + coq: "eq_vec" } : forall 'n. (vector('n, inc, bit), vector('n, inc, bit)) -> bool overload operator == = {eq_bit, eq_bits} -val bitvector_length = "length" : forall 'n. bits('n) -> atom('n) +val bitvector_length = {coq: "length_mword", _:"length"} : forall 'n. bits('n) -> atom('n) val vector_length = { ocaml: "length", lem: "length_list", - c: "length" + c: "length", + coq: "length_list" } : forall 'n ('a : Type). vector('n, inc, 'a) -> atom('n) overload length = {bitvector_length, vector_length} @@ -36,16 +38,17 @@ val "zero_extend" : forall 'n 'm, 'm >= 'n. (bits('n), atom('m)) -> bits('m) val truncate = { ocaml: "vector_truncate", lem: "vector_truncate", + coq: "vector_truncate", c: "truncate" -} : forall 'm 'n, 'm <= 'n. (vector('n, inc, bit), atom('m)) -> vector('m, inc, bit) +} : forall 'm 'n, 'm >= 0 & 'm <= 'n. (vector('n, inc, bit), atom('m)) -> vector('m, inc, bit) -val mask : forall 'len 'v, 'v >= 0. (atom('len), vector('v, inc, bit)) -> vector('len, inc, bit) +val mask : forall 'len 'v, 'len >= 0 & 'v >= 0. (atom('len), vector('v, inc, bit)) -> vector('len, inc, bit) function mask(len, v) = if len <= length(v) then truncate(v, len) else zero_extend(v, len) overload operator ^ = {mask} -val bitvector_concat = {ocaml: "append", lem: "concat_vec", c: "append"} : forall ('n : Int) ('m : Int). +val bitvector_concat = {ocaml: "append", lem: "concat_vec", c: "append", coq: "concat_vec"} : forall ('n : Int) ('m : Int). (bits('n), bits('m)) -> bits('n + 'm) overload append = {bitvector_concat} @@ -56,12 +59,14 @@ val "append_64" : forall 'n. (bits('n), bits(64)) -> bits('n + 64) val vector_access = { ocaml: "access", lem: "access_list_inc", + coq: "access_list_inc", c: "vector_access" } : forall ('n : Int) ('m : Int) ('a : Type), 0 <= 'm < 'n. (vector('n, inc, 'a), atom('m)) -> 'a val vector_update = { ocaml: "update", lem: "update_list_inc", + coq: "update_list_inc", c: "vector_update" } : forall 'n ('a : Type). (vector('n, inc, 'a), int, 'a) -> vector('n, inc, 'a) @@ -80,14 +85,16 @@ overload operator + = {add_bits, add_bits_int} val vector_subrange = { ocaml: "subrange", lem: "subrange_vec_inc", - c: "vector_subrange" + c: "vector_subrange", + coq: "subrange_vec_inc" } : forall ('n : Int) ('m : Int) ('o : Int), 'o <= 'm <= 'n. (bits('n), atom('m), atom('o)) -> bits('m - ('o - 1)) val vector_update_subrange = { ocaml: "update_subrange", lem: "update_subrange_vec_inc", - c: "vector_update_subrange" + c: "vector_update_subrange", + coq: "update_subrange_vec_inc" } : forall 'n 'm 'o. (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n) // Some ARM specific builtins @@ -107,7 +114,8 @@ val unsigned = { ocaml: "uint", lem: "uint", interpreter: "uint", - c: "sail_uint" + c: "sail_uint", + coq: "uint" } : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1) val signed = "sint" : forall 'n. bits('n) -> range(- (2 ^ ('n - 1)), 2 ^ ('n - 1) - 1) |
