diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/arith.sail | 20 | ||||
| -rw-r--r-- | lib/coq/Sail2_operators_mwords.v | 29 | ||||
| -rw-r--r-- | lib/coq/Sail2_prompt.v | 10 | ||||
| -rw-r--r-- | lib/coq/Sail2_string.v | 8 | ||||
| -rw-r--r-- | lib/coq/Sail2_values.v | 57 | ||||
| -rw-r--r-- | lib/flow.sail | 53 | ||||
| -rw-r--r-- | lib/mono_rewrites.sail | 4 | ||||
| -rw-r--r-- | lib/option.sail | 2 | ||||
| -rw-r--r-- | lib/regfp.sail | 1 | ||||
| -rw-r--r-- | lib/rts.c | 24 | ||||
| -rw-r--r-- | lib/rts.h | 24 | ||||
| -rw-r--r-- | lib/sail.c | 353 | ||||
| -rw-r--r-- | lib/sail.h | 204 | ||||
| -rw-r--r-- | lib/smt.sail | 6 | ||||
| -rw-r--r-- | lib/vector_dec.sail | 29 | ||||
| -rw-r--r-- | lib/vector_inc.sail | 8 |
16 files changed, 573 insertions, 259 deletions
diff --git a/lib/arith.sail b/lib/arith.sail index 3a1b0927..eed257fb 100644 --- a/lib/arith.sail +++ b/lib/arith.sail @@ -50,7 +50,25 @@ val "prerr_int" : (string, int) -> unit // ***** Integer shifts ***** -val shl_int = "shl_int" : (int, int) -> int +/*! +A common idiom in asl is to take two bits of an opcode and convert in into a variable like +``` +let elsize = shl_int(8, UInt(size)) +``` +THIS ensures that in this case the typechecker knows that the end result will be a value in the set `{8, 16, 32, 64}` +*/ +val _shl8 = {c: "shl_mach_int", _: "shl_int"} : + forall 'n, 0 <= 'n <= 3. (int(8), int('n)) -> {'m, 'm in {8, 16, 32, 64}. int('m)} + +/*! +Similarly, we can shift 32 by either 0 or 1 to get a value in `{32, 64}` +*/ +val _shl32 = {c: "shl_mach_int", _: "shl_int"} : + forall 'n, 'n in {0, 1}. (int(32), int('n)) -> {'m, 'm in {32, 64}. int('m)} + +val _shl_int = "shl_int" : (int, int) -> int + +overload shl_int = {_shl8, _shl32, _shl_int} val shr_int = "shr_int" : (int, int) -> int diff --git a/lib/coq/Sail2_operators_mwords.v b/lib/coq/Sail2_operators_mwords.v index 497b4a46..809f9d89 100644 --- a/lib/coq/Sail2_operators_mwords.v +++ b/lib/coq/Sail2_operators_mwords.v @@ -9,12 +9,6 @@ Require Import ZArith. Require Import Omega. Require Import Eqdep_dec. -Module Z_eq_dec. -Definition U := Z. -Definition eq_dec := Z.eq_dec. -End Z_eq_dec. -Module ZEqdep := DecidableEqDep (Z_eq_dec). - Fixpoint cast_positive (T : positive -> Type) (p q : positive) : T p -> p = q -> T q. refine ( match p, q with @@ -178,6 +172,13 @@ Definition zero_extend {a} (v : mword a) (n : Z) `{ArithFact (n >= a)} : mword n Definition sign_extend {a} (v : mword a) (n : Z) `{ArithFact (n >= a)} : mword n := exts_vec n v. +Definition zeros (n : Z) `{ArithFact (n >= 0)} : mword n. +refine (cast_to_mword (Word.wzero (Z.to_nat n)) _). +unwrap_ArithFacts. +apply Z2Nat.id. +auto with zarith. +Defined. + Lemma truncate_eq {m n} : m >= 0 -> m <= n -> (Z.to_nat n = Z.to_nat m + (Z.to_nat n - Z.to_nat m))%nat. intros. assert ((Z.to_nat m <= Z.to_nat n)%nat). @@ -444,6 +445,20 @@ Definition sgteq_vec := sgteq_bv. *) +Definition eq_vec_dec {n} : forall (x y : mword n), {x = y} + {x <> y}. +refine (match n with +| Z0 => _ +| Zpos m => _ +| Zneg m => _ +end). +* simpl. apply Word.weq. +* simpl. apply Word.weq. +* simpl. destruct x. +Defined. + +Instance Decidable_eq_mword {n} : forall (x y : mword n), Decidable (x = y) := + Decidable_eq_from_dec eq_vec_dec. + Program Fixpoint reverse_endianness_word {n} (bits : word n) : word n := match n with | S (S (S (S (S (S (S (S m))))))) => @@ -471,3 +486,5 @@ Definition set_slice_int len n lo (v : mword len) : Z := let bs : mword (hi + 1) := mword_of_int n in (int_of_mword true (update_subrange_vec_dec bs hi lo v)) else n. + +Definition prerr_bits {a} (s : string) (bs : mword a) : unit := tt. diff --git a/lib/coq/Sail2_prompt.v b/lib/coq/Sail2_prompt.v index 1b12f360..85ca95f6 100644 --- a/lib/coq/Sail2_prompt.v +++ b/lib/coq/Sail2_prompt.v @@ -1,7 +1,7 @@ (*Require Import Sail_impl_base*) Require Import Sail2_values. Require Import Sail2_prompt_monad. - +Require Export ZArith.Zwf. Require Import List. Import ListNotations. (* @@ -77,6 +77,14 @@ match b with | BU => undefined_bool tt end. +(* For termination of recursive functions. We don't name assertions, so use + the type class mechanism to find it. *) +Definition _limit_reduces {_limit} (_acc:Acc (Zwf 0) _limit) `{ArithFact (_limit >= 0)} : Acc (Zwf 0) (_limit - 1). +refine (Acc_inv _acc _). +destruct H. +red. +omega. +Defined. (*val whileM : forall 'rv 'vars 'e. 'vars -> ('vars -> monad 'rv bool 'e) -> ('vars -> monad 'rv 'vars 'e) -> monad 'rv 'vars 'e diff --git a/lib/coq/Sail2_string.v b/lib/coq/Sail2_string.v index a02556b2..0a00f8d7 100644 --- a/lib/coq/Sail2_string.v +++ b/lib/coq/Sail2_string.v @@ -7,12 +7,12 @@ Definition string_startswith s expected := let prefix := String.substring 0 (String.length expected) s in generic_eq prefix expected. -Definition string_drop s (n : {n : Z & ArithFact (n >= 0)}) := - let n := Z.to_nat (projT1 n) in +Definition string_drop s (n : Z) `{ArithFact (n >= 0)} := + let n := Z.to_nat n in String.substring n (String.length s - n) s. -Definition string_take s (n : {n : Z & ArithFact (n >= 0)}) := - let n := Z.to_nat (projT1 n) in +Definition string_take s (n : Z) `{ArithFact (n >= 0)} := + let n := Z.to_nat n in String.substring 0 n s. Definition string_length s : {n : Z & ArithFact (n >= 0)} := diff --git a/lib/coq/Sail2_values.v b/lib/coq/Sail2_values.v index 83fe1dc7..e3e039c2 100644 --- a/lib/coq/Sail2_values.v +++ b/lib/coq/Sail2_values.v @@ -8,10 +8,18 @@ Require Import bbv.Word. Require Export List. Require Export Sumbool. Require Export DecidableClass. +Require Import Eqdep_dec. Import ListNotations. Open Scope Z. +Module Z_eq_dec. +Definition U := Z. +Definition eq_dec := Z.eq_dec. +End Z_eq_dec. +Module ZEqdep := DecidableEqDep (Z_eq_dec). + + (* Constraint solving basics. A HintDb which unfolding hints and lemmata can be added to, and a typeclass to wrap constraint arguments in to trigger automatic solving. *) @@ -93,6 +101,22 @@ split. tauto. Qed. +Definition generic_dec {T:Type} (x y:T) `{Decidable (x = y)} : {x = y} + {x <> y}. +refine ((if Decidable_witness as b return (b = true <-> x = y -> _) then fun H' => _ else fun H' => _) Decidable_spec). +* left. tauto. +* right. intuition. +Defined. + +(* Used by generated code that builds Decidable equality instances for records. *) +Ltac cmp_record_field x y := + let H := fresh "H" in + case (generic_dec x y); + intro H; [ | + refine (Build_Decidable _ false _); + split; [congruence | intros Z; destruct H; injection Z; auto] + ]. + + (* Project away range constraints in comparisons *) Definition ltb_range_l {lo hi} (l : {x & ArithFact (lo <= x /\ x <= hi)}) r := Z.ltb (projT1 l) r. @@ -1101,6 +1125,10 @@ repeat end. *) +(* The linear solver doesn't like existentials. *) +Ltac destruct_exists := + repeat match goal with H:@ex Z _ |- _ => destruct H end. + Ltac prepare_for_solver := (*dump_context;*) clear_irrelevant_defns; @@ -1110,6 +1138,7 @@ Ltac prepare_for_solver := extract_properties; repeat match goal with w:mword ?n |- _ => apply ArithFact_mword in w end; unwrap_ArithFacts; + destruct_exists; unbool_comparisons; unfold_In; (* after unbool_comparisons to deal with && and || *) reduce_list_lengths; @@ -1151,6 +1180,8 @@ prepare_for_solver; [ match goal with |- ArithFact (?x _) => is_evar x; idtac "Warning: unknown constraint"; constructor; exact (I : (fun _ => True) _) end | apply ArithFact_mword; assumption | constructor; omega with Z + (* Try sail hints before dropping the existential *) + | constructor; eauto 3 with zarith sail (* The datatypes hints give us some list handling, esp In *) | constructor; drop_exists; eauto 3 with datatypes zarith sail | constructor; idtac "Unable to solve constraint"; dump_context; fail @@ -1736,6 +1767,20 @@ Qed. Definition list_of_vec {A n} (v : vec A n) : list A := projT1 v. +Definition vec_eq_dec {T n} (D : forall x y : T, {x = y} + {x <> y}) (x y : vec T n) : + {x = y} + {x <> y}. +refine (if List.list_eq_dec D (projT1 x) (projT1 y) then left _ else right _). +* apply eq_sigT_hprop; auto using ZEqdep.UIP. +* contradict n0. rewrite n0. reflexivity. +Defined. + +Instance Decidable_eq_vec {T : Type} {n} `(DT : forall x y : T, Decidable (x = y)) : + forall x y : vec T n, Decidable (x = y) := { + Decidable_witness := proj1_sig (bool_of_sumbool (vec_eq_dec (fun x y => generic_dec x y) x y)) +}. +destruct (vec_eq_dec _ x y); simpl; split; congruence. +Defined. + Program Definition vec_of_list {A} n (l : list A) : option (vec A n) := if sumbool_of_bool (n =? length_list l) then Some (existT _ l _) else None. Next Obligation. @@ -1752,7 +1797,15 @@ match a with | None => None end. -Definition sub_nat (x : {x : Z & ArithFact (x >= 0)}) (y : {y : Z & ArithFact (y >= 0)}) : +Definition sub_nat (x : Z) `{ArithFact (x >= 0)} (y : Z) `{ArithFact (y >= 0)} : {z : Z & ArithFact (z >= 0)} := - let z := projT1 x - projT1 y in + let z := x - y in if sumbool_of_bool (z >=? 0) then build_ex z else build_ex 0. + +Definition min_nat (x : Z) `{ArithFact (x >= 0)} (y : Z) `{ArithFact (y >= 0)} : + {z : Z & ArithFact (z >= 0)} := + build_ex (Z.min x y). + +Definition max_nat (x : Z) `{ArithFact (x >= 0)} (y : Z) `{ArithFact (y >= 0)} : + {z : Z & ArithFact (z >= 0)} := + build_ex (Z.max x y). diff --git a/lib/flow.sail b/lib/flow.sail index 7c6f1ebb..eb7b8038 100644 --- a/lib/flow.sail +++ b/lib/flow.sail @@ -11,7 +11,7 @@ therefore be included in just about every Sail specification. val eq_unit : (unit, unit) -> bool -val "eq_bit" : (bit, bit) -> bool +val eq_bit = { lem : "eq", _ : "eq_bit" } : (bit, bit) -> bool function eq_unit(_, _) = true @@ -20,34 +20,9 @@ val not_bool = {coq: "negb", _: "not"} : bool -> bool or_bool that are not shown here. */ val and_bool = {coq: "andb", _: "and_bool"} : (bool, bool) -> bool val or_bool = {coq: "orb", _: "or_bool"} : (bool, bool) -> bool - -val eq_atom = {ocaml: "eq_int", interpreter: "eq_int", lem: "eq", c: "eq_int", coq: "Z.eqb"} : forall 'n 'm. (atom('n), atom('m)) -> bool - -val neq_atom = {lem: "neq", coq: "neq_atom"} : forall 'n 'm. (atom('n), atom('m)) -> bool - -function neq_atom (x, y) = not_bool(eq_atom(x, y)) - -val lteq_atom = {coq: "Z.leb", _: "lteq"} : forall 'n 'm. (atom('n), atom('m)) -> bool -val gteq_atom = {coq: "Z.geb", _: "gteq"} : forall 'n 'm. (atom('n), atom('m)) -> bool -val lt_atom = {coq: "Z.ltb", _: "lt"} : forall 'n 'm. (atom('n), atom('m)) -> bool -val gt_atom = {coq: "Z.gtb", _: "gt"} : forall 'n 'm. (atom('n), atom('m)) -> bool - -val lt_range_atom = {coq: "ltb_range_l", _: "lt"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool -val lteq_range_atom = {coq: "leb_range_l", _: "lteq"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool -val gt_range_atom = {coq: "gtb_range_l", _: "gt"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool -val gteq_range_atom = {coq: "geb_range_l", _: "gteq"} : forall 'n 'm 'o. (range('n, 'm), atom('o)) -> bool -val lt_atom_range = {coq: "ltb_range_r", _: "lt"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool -val lteq_atom_range = {coq: "leb_range_r", _: "lteq"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool -val gt_atom_range = {coq: "gtb_range_r", _: "gt"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool -val gteq_atom_range = {coq: "geb_range_r", _: "gteq"} : forall 'n 'm 'o. (atom('n), range('m, 'o)) -> bool - -val eq_range = {ocaml: "eq_int", interpreter: "eq_int", lem: "eq", c: "eq_int", coq: "eq_range"} : forall 'n 'm 'o 'p. (range('n, 'm), range('o, 'p)) -> bool val eq_int = {ocaml: "eq_int", interpreter: "eq_int", lem: "eq", c: "eq_int", coq: "Z.eqb"} : (int, int) -> bool val eq_bool = {ocaml: "eq_bool", interpreter: "eq_bool", lem: "eq", c: "eq_bool", coq: "Bool.eqb"} : (bool, bool) -> bool -val neq_range = {lem: "neq"} : forall 'n 'm 'o 'p. (range('n, 'm), range('o, 'p)) -> bool -function neq_range (x, y) = not_bool(eq_range(x, y)) - val neq_int = {lem: "neq"} : (int, int) -> bool function neq_int (x, y) = not_bool(eq_int(x, y)) @@ -59,28 +34,14 @@ val gteq_int = {coq: "Z.geb", _:"gteq"} : (int, int) -> bool val lt_int = {coq: "Z.ltb", _:"lt"} : (int, int) -> bool val gt_int = {coq: "Z.gtb", _:"gt"} : (int, int) -> bool -overload operator == = {eq_atom, eq_range, eq_int, eq_bit, eq_bool, eq_unit} -overload operator != = {neq_atom, neq_range, neq_int, neq_bool} +overload operator == = {eq_int, eq_bit, eq_bool, eq_unit} +overload operator != = {neq_int, neq_bool} overload operator | = {or_bool} overload operator & = {and_bool} -overload operator <= = {lteq_atom, lteq_range_atom, lteq_atom_range, lteq_int} -overload operator < = {lt_atom, lt_range_atom, lt_atom_range, lt_int} -overload operator >= = {gteq_atom, gteq_range_atom, gteq_atom_range, gteq_int} -overload operator > = {gt_atom, gt_range_atom, gt_atom_range, gt_int} - -$ifdef TEST - -val __flow_test : forall 'n 'm. (atom('n), atom('m)) -> unit - -function __flow_test (x, y) = { - if lteq_atom(x, y) then { - _prove(constraint('n <= 'm)) - } else { - _prove(constraint('n > 'm)) - } -} - -$endif +overload operator <= = {lteq_int} +overload operator < = {lt_int} +overload operator >= = {gteq_int} +overload operator > = {gt_int} $endif diff --git a/lib/mono_rewrites.sail b/lib/mono_rewrites.sail index 9e837e10..93ad3db5 100644 --- a/lib/mono_rewrites.sail +++ b/lib/mono_rewrites.sail @@ -63,7 +63,7 @@ function slice_slice_concat (xs, i, l, ys, i', l') = { extzv(xs) << l' | extzv(ys) } -val slice_zeros_concat : forall 'n 'p 'q 'r, 'r = 'p + 'q & 'n >= 0 & 'p >= 0 & 'q >= 0 & 'r >= 0. +val slice_zeros_concat : forall 'n 'p 'q 'r, 'r == 'p + 'q & 'n >= 0 & 'p >= 0 & 'q >= 0 & 'r >= 0. (bits('n), int, atom('p), atom('q)) -> bits('r) effect pure function slice_zeros_concat (xs, i, l, l') = { @@ -82,7 +82,7 @@ function subrange_subrange_eq (xs, i, j, ys, i', j') = { xs == ys } -val subrange_subrange_concat : forall 'n 'o 'p 'm 'q 'r 's, 's = 'o - ('p - 1) + 'q - ('r - 1) & 'n >= 0 & 'm >= 0. +val subrange_subrange_concat : forall 'n 'o 'p 'm 'q 'r 's, 's >= 0 & 'n >= 0 & 'm >= 0. (bits('n), atom('o), atom('p), bits('m), atom('q), atom('r)) -> bits('s) effect pure function subrange_subrange_concat (xs, i, j, ys, i', j') = { diff --git a/lib/option.sail b/lib/option.sail index 3869167b..514cf7ba 100644 --- a/lib/option.sail +++ b/lib/option.sail @@ -6,7 +6,7 @@ $define _OPTION // this won't work - also no other type should be created with // constructors named Some or None. -union option ('a : Type) = { +union option('a: Type) = { Some : 'a, None : unit } diff --git a/lib/regfp.sail b/lib/regfp.sail index d69728df..731f1a8d 100644 --- a/lib/regfp.sail +++ b/lib/regfp.sail @@ -82,6 +82,7 @@ enum barrier_kind = { Barrier_RISCV_rw_r, Barrier_RISCV_r_w, Barrier_RISCV_w_r, + Barrier_RISCV_tso, Barrier_RISCV_i, Barrier_x86_MFENCE } @@ -32,7 +32,7 @@ unit sail_exit(unit u) static uint64_t g_verbosity = 0; -mach_bits sail_get_verbosity(const unit u) +fbits sail_get_verbosity(const unit u) { return g_verbosity; } @@ -94,6 +94,10 @@ void write_mem(uint64_t address, uint64_t byte) uint64_t mask = address & ~MASK; uint64_t offset = address & MASK; + //if ((byte >= 97 && byte <= 122) || (byte >= 64 && byte <= 90) || (byte >= 48 && byte <= 57) || byte == 10 || byte == 32) { + // fprintf(stderr, "%c", (char) byte); + //} + struct block *current = sail_memory; while (current != NULL) { @@ -210,9 +214,9 @@ void kill_mem() bool write_ram(const mpz_t addr_size, // Either 32 or 64 const mpz_t data_size_mpz, // Number of bytes - const sail_bits hex_ram, // Currently unused - const sail_bits addr_bv, - const sail_bits data) + const lbits hex_ram, // Currently unused + const lbits addr_bv, + const lbits data) { uint64_t addr = mpz_get_ui(*addr_bv.bits); uint64_t data_size = mpz_get_ui(data_size_mpz); @@ -234,11 +238,11 @@ bool write_ram(const mpz_t addr_size, // Either 32 or 64 return true; } -void read_ram(sail_bits *data, +void read_ram(lbits *data, const mpz_t addr_size, const mpz_t data_size_mpz, - const sail_bits hex_ram, - const sail_bits addr_bv) + const lbits hex_ram, + const lbits addr_bv) { uint64_t addr = mpz_get_ui(*addr_bv.bits); uint64_t data_size = mpz_get_ui(data_size_mpz); @@ -257,7 +261,7 @@ void read_ram(sail_bits *data, mpz_clear(byte); } -unit load_raw(mach_bits addr, const sail_string file) +unit load_raw(fbits addr, const sail_string file) { FILE *fp = fopen(file, "r"); @@ -335,7 +339,7 @@ bool is_tracing(const unit u) return g_trace_enabled; } -void trace_mach_bits(const mach_bits x) { +void trace_fbits(const fbits x) { if (g_trace_enabled) fprintf(stderr, "0x%" PRIx64, x); } @@ -351,7 +355,7 @@ void trace_sail_int(const sail_int op) { if (g_trace_enabled) mpz_out_str(stderr, 10, op); } -void trace_sail_bits(const sail_bits op) { +void trace_lbits(const lbits op) { if (g_trace_enabled) fprint_bits("", op, "", stderr); } @@ -26,7 +26,7 @@ unit sail_exit(unit); * The intention is that you can use individual bits to turn on/off different * pieces of debugging output. */ -mach_bits sail_get_verbosity(const unit u); +fbits sail_get_verbosity(const unit u); /* * Put processor to sleep until an external device calls wakeup_request(). @@ -55,20 +55,20 @@ uint64_t read_mem(uint64_t); bool write_ram(const mpz_t addr_size, // Either 32 or 64 const mpz_t data_size_mpz, // Number of bytes - const sail_bits hex_ram, // Currently unused - const sail_bits addr_bv, - const sail_bits data); + const lbits hex_ram, // Currently unused + const lbits addr_bv, + const lbits data); -void read_ram(sail_bits *data, +void read_ram(lbits *data, const mpz_t addr_size, const mpz_t data_size_mpz, - const sail_bits hex_ram, - const sail_bits addr_bv); + const lbits hex_ram, + const lbits addr_bv); -unit write_tag_bool(const mach_bits, const bool); -bool read_tag_bool(const mach_bits); +unit write_tag_bool(const fbits, const bool); +bool read_tag_bool(const fbits); -unit load_raw(mach_bits addr, const sail_string file); +unit load_raw(fbits addr, const sail_string file); void load_image(char *); @@ -106,8 +106,8 @@ void trace_sail_int(const sail_int); void trace_bool(const bool); void trace_unit(const unit); void trace_sail_string(const sail_string); -void trace_mach_bits(const mach_bits); -void trace_sail_bits(const sail_bits); +void trace_fbits(const fbits); +void trace_lbits(const lbits); void trace_unknown(void); void trace_argsep(void); @@ -7,6 +7,8 @@ #include<string.h> #include<time.h> +#include <x86intrin.h> + #include"sail.h" /* @@ -53,7 +55,7 @@ unit skip(const unit u) /* ***** Sail bit type ***** */ -bool eq_bit(const mach_bits a, const mach_bits b) +bool eq_bit(const fbits a, const fbits b) { return a == b; } @@ -231,6 +233,12 @@ void CREATE_OF(sail_int, sail_string)(sail_int *rop, sail_string str) } inline +void CONVERT_OF(sail_int, sail_string)(sail_int *rop, sail_string str) +{ + mpz_set_str(*rop, str, 10); +} + +inline void RECREATE_OF(sail_int, sail_string)(mpz_t *rop, sail_string str) { mpz_set_str(*rop, str, 10); @@ -291,6 +299,12 @@ void shl_int(sail_int *rop, const sail_int op1, const sail_int op2) } inline +mach_int shl_mach_int(const mach_int op1, const mach_int op2) +{ + return op1 << op2; +} + +inline void shr_int(sail_int *rop, const sail_int op1, const sail_int op2) { mpz_fdiv_q_2exp(*rop, op1, mpz_get_ui(op2)); @@ -409,69 +423,144 @@ void pow2(sail_int *rop, const sail_int exp) /* ***** Sail bitvectors ***** */ -bool EQUAL(mach_bits)(const mach_bits op1, const mach_bits op2) +bool EQUAL(fbits)(const fbits op1, const fbits op2) { return op1 == op2; } -void CREATE(sail_bits)(sail_bits *rop) +void CREATE(lbits)(lbits *rop) { rop->bits = malloc(sizeof(mpz_t)); rop->len = 0; mpz_init(*rop->bits); } -void RECREATE(sail_bits)(sail_bits *rop) +void RECREATE(lbits)(lbits *rop) { rop->len = 0; mpz_set_ui(*rop->bits, 0); } -void COPY(sail_bits)(sail_bits *rop, const sail_bits op) +void COPY(lbits)(lbits *rop, const lbits op) { rop->len = op.len; mpz_set(*rop->bits, *op.bits); } -void KILL(sail_bits)(sail_bits *rop) +void KILL(lbits)(lbits *rop) { mpz_clear(*rop->bits); free(rop->bits); } -void CREATE_OF(sail_bits, mach_bits)(sail_bits *rop, const uint64_t op, const uint64_t len, const bool direction) +void CREATE_OF(lbits, fbits)(lbits *rop, const uint64_t op, const uint64_t len, const bool direction) { rop->bits = malloc(sizeof(mpz_t)); rop->len = len; mpz_init_set_ui(*rop->bits, op); } -void RECREATE_OF(sail_bits, mach_bits)(sail_bits *rop, const uint64_t op, const uint64_t len, const bool direction) +fbits CREATE_OF(fbits, lbits)(const lbits op, const bool direction) +{ + return mpz_get_ui(*op.bits); +} + +sbits CREATE_OF(sbits, lbits)(const lbits op, const bool direction) +{ + sbits rop; + rop.bits = mpz_get_ui(*op.bits); + rop.len = op.len; + return rop; +} + +sbits CREATE_OF(sbits, fbits)(const fbits op, const uint64_t len, const bool direction) +{ + sbits rop; + rop.bits = op; + rop.len = len; + return rop; +} + +void RECREATE_OF(lbits, fbits)(lbits *rop, const uint64_t op, const uint64_t len, const bool direction) { rop->len = len; mpz_set_ui(*rop->bits, op); } -mach_bits CONVERT_OF(mach_bits, sail_bits)(const sail_bits op, const bool direction) +void CREATE_OF(lbits, sbits)(lbits *rop, const sbits op, const bool direction) +{ + rop->bits = malloc(sizeof(mpz_t)); + rop->len = op.len; + mpz_init_set_ui(*rop->bits, op.bits); +} + +void RECREATE_OF(lbits, sbits)(lbits *rop, const sbits op, const bool direction) +{ + rop->len = op.len; + mpz_set_ui(*rop->bits, op.bits); +} + +// Bitvector conversions + +inline +fbits CONVERT_OF(fbits, lbits)(const lbits op, const bool direction) { return mpz_get_ui(*op.bits); } -void CONVERT_OF(sail_bits, mach_bits)(sail_bits *rop, const mach_bits op, const uint64_t len, const bool direction) +inline +fbits CONVERT_OF(fbits, sbits)(const sbits op, const bool direction) +{ + return op.bits; +} + +void CONVERT_OF(lbits, fbits)(lbits *rop, const fbits op, const uint64_t len, const bool direction) { rop->len = len; // use safe_rshift to correctly handle the case when we have a 0-length vector. mpz_set_ui(*rop->bits, op & safe_rshift(UINT64_MAX, 64 - len)); } -void UNDEFINED(sail_bits)(sail_bits *rop, const sail_int len, const mach_bits bit) +void CONVERT_OF(lbits, sbits)(lbits *rop, const sbits op, const bool direction) +{ + rop->len = op.len; + mpz_set_ui(*rop->bits, op.bits & safe_rshift(UINT64_MAX, 64 - op.len)); +} + +inline +sbits CONVERT_OF(sbits, fbits)(const fbits op, const uint64_t len, const bool direction) +{ + sbits rop; + rop.len = len; + rop.bits = op; + return rop; +} + +inline +sbits CONVERT_OF(sbits, lbits)(const lbits op, const bool direction) +{ + sbits rop; + rop.len = op.len; + rop.bits = mpz_get_ui(*op.bits); + return rop; +} + +void UNDEFINED(lbits)(lbits *rop, const sail_int len, const fbits bit) { zeros(rop, len); } -mach_bits UNDEFINED(mach_bits)(const unit u) { return 0; } +fbits UNDEFINED(fbits)(const unit u) { return 0; } -mach_bits safe_rshift(const mach_bits x, const mach_bits n) +sbits undefined_sbits(void) +{ + sbits rop; + rop.bits = UINT64_C(0); + rop.len = UINT64_C(0); + return rop; +} + +fbits safe_rshift(const fbits x, const fbits n) { if (n >= 64) { return 0ul; @@ -480,7 +569,7 @@ mach_bits safe_rshift(const mach_bits x, const mach_bits n) } } -void normalize_sail_bits(sail_bits *rop) { +void normalize_lbits(lbits *rop) { /* TODO optimisation: keep a set of masks of various sizes handy */ mpz_set_ui(sail_lib_tmp1, 1); mpz_mul_2exp(sail_lib_tmp1, sail_lib_tmp1, rop->len); @@ -488,64 +577,64 @@ void normalize_sail_bits(sail_bits *rop) { mpz_and(*rop->bits, *rop->bits, sail_lib_tmp1); } -void append_64(sail_bits *rop, const sail_bits op, const mach_bits chunk) +void append_64(lbits *rop, const lbits op, const fbits chunk) { rop->len = rop->len + 64ul; mpz_mul_2exp(*rop->bits, *op.bits, 64ul); mpz_add_ui(*rop->bits, *rop->bits, chunk); } -void add_bits(sail_bits *rop, const sail_bits op1, const sail_bits op2) +void add_bits(lbits *rop, const lbits op1, const lbits op2) { rop->len = op1.len; mpz_add(*rop->bits, *op1.bits, *op2.bits); - normalize_sail_bits(rop); + normalize_lbits(rop); } -void sub_bits(sail_bits *rop, const sail_bits op1, const sail_bits op2) +void sub_bits(lbits *rop, const lbits op1, const lbits op2) { assert(op1.len == op2.len); rop->len = op1.len; mpz_sub(*rop->bits, *op1.bits, *op2.bits); - normalize_sail_bits(rop); + normalize_lbits(rop); } -void add_bits_int(sail_bits *rop, const sail_bits op1, const mpz_t op2) +void add_bits_int(lbits *rop, const lbits op1, const mpz_t op2) { rop->len = op1.len; mpz_add(*rop->bits, *op1.bits, op2); - normalize_sail_bits(rop); + normalize_lbits(rop); } -void sub_bits_int(sail_bits *rop, const sail_bits op1, const mpz_t op2) +void sub_bits_int(lbits *rop, const lbits op1, const mpz_t op2) { rop->len = op1.len; mpz_sub(*rop->bits, *op1.bits, op2); - normalize_sail_bits(rop); + normalize_lbits(rop); } -void and_bits(sail_bits *rop, const sail_bits op1, const sail_bits op2) +void and_bits(lbits *rop, const lbits op1, const lbits op2) { assert(op1.len == op2.len); rop->len = op1.len; mpz_and(*rop->bits, *op1.bits, *op2.bits); } -void or_bits(sail_bits *rop, const sail_bits op1, const sail_bits op2) +void or_bits(lbits *rop, const lbits op1, const lbits op2) { assert(op1.len == op2.len); rop->len = op1.len; mpz_ior(*rop->bits, *op1.bits, *op2.bits); } -void xor_bits(sail_bits *rop, const sail_bits op1, const sail_bits op2) +void xor_bits(lbits *rop, const lbits op1, const lbits op2) { assert(op1.len == op2.len); rop->len = op1.len; mpz_xor(*rop->bits, *op1.bits, *op2.bits); } -void not_bits(sail_bits *rop, const sail_bits op) +void not_bits(lbits *rop, const lbits op) { rop->len = op.len; mpz_set(*rop->bits, *op.bits); @@ -554,7 +643,7 @@ void not_bits(sail_bits *rop, const sail_bits op) } } -void mults_vec(sail_bits *rop, const sail_bits op1, const sail_bits op2) +void mults_vec(lbits *rop, const lbits op1, const lbits op2) { mpz_t op1_int, op2_int; mpz_init(op1_int); @@ -563,33 +652,33 @@ void mults_vec(sail_bits *rop, const sail_bits op1, const sail_bits op2) sail_signed(&op2_int, op2); rop->len = op1.len * 2; mpz_mul(*rop->bits, op1_int, op2_int); - normalize_sail_bits(rop); + normalize_lbits(rop); mpz_clear(op1_int); mpz_clear(op2_int); } -void mult_vec(sail_bits *rop, const sail_bits op1, const sail_bits op2) +void mult_vec(lbits *rop, const lbits op1, const lbits op2) { rop->len = op1.len * 2; mpz_mul(*rop->bits, *op1.bits, *op2.bits); - normalize_sail_bits(rop); /* necessary? */ + normalize_lbits(rop); /* necessary? */ } -void zeros(sail_bits *rop, const sail_int op) +void zeros(lbits *rop, const sail_int op) { rop->len = mpz_get_ui(op); mpz_set_ui(*rop->bits, 0); } -void zero_extend(sail_bits *rop, const sail_bits op, const sail_int len) +void zero_extend(lbits *rop, const lbits op, const sail_int len) { assert(op.len <= mpz_get_ui(len)); rop->len = mpz_get_ui(len); mpz_set(*rop->bits, *op.bits); } -void sign_extend(sail_bits *rop, const sail_bits op, const sail_int len) +void sign_extend(lbits *rop, const lbits op, const sail_int len) { assert(op.len <= mpz_get_ui(len)); rop->len = mpz_get_ui(len); @@ -603,12 +692,12 @@ void sign_extend(sail_bits *rop, const sail_bits op, const sail_int len) } } -void length_sail_bits(sail_int *rop, const sail_bits op) +void length_lbits(sail_int *rop, const lbits op) { mpz_set_ui(*rop, op.len); } -bool eq_bits(const sail_bits op1, const sail_bits op2) +bool eq_bits(const lbits op1, const lbits op2) { assert(op1.len == op2.len); for (mp_bitcnt_t i = 0; i < op1.len; i++) { @@ -617,12 +706,12 @@ bool eq_bits(const sail_bits op1, const sail_bits op2) return true; } -bool EQUAL(sail_bits)(const sail_bits op1, const sail_bits op2) +bool EQUAL(lbits)(const lbits op1, const lbits op2) { return eq_bits(op1, op2); } -bool neq_bits(const sail_bits op1, const sail_bits op2) +bool neq_bits(const lbits op1, const lbits op2) { assert(op1.len == op2.len); for (mp_bitcnt_t i = 0; i < op1.len; i++) { @@ -631,8 +720,8 @@ bool neq_bits(const sail_bits op1, const sail_bits op2) return false; } -void vector_subrange_sail_bits(sail_bits *rop, - const sail_bits op, +void vector_subrange_lbits(lbits *rop, + const lbits op, const sail_int n_mpz, const sail_int m_mpz) { @@ -641,30 +730,40 @@ void vector_subrange_sail_bits(sail_bits *rop, rop->len = n - (m - 1ul); mpz_fdiv_q_2exp(*rop->bits, *op.bits, m); - normalize_sail_bits(rop); + normalize_lbits(rop); } -void sail_truncate(sail_bits *rop, const sail_bits op, const sail_int len) +void sail_truncate(lbits *rop, const lbits op, const sail_int len) { assert(op.len >= mpz_get_ui(len)); rop->len = mpz_get_ui(len); mpz_set(*rop->bits, *op.bits); - normalize_sail_bits(rop); + normalize_lbits(rop); +} + +void sail_truncateLSB(lbits *rop, const lbits op, const sail_int len) +{ + uint64_t rlen = mpz_get_ui(len); + assert(op.len >= rlen); + rop->len = rlen; + // similar to vector_subrange_lbits above -- right shift LSBs away + mpz_fdiv_q_2exp(*rop->bits, *op.bits, op.len - rlen); + normalize_lbits(rop); } -mach_bits bitvector_access(const sail_bits op, const sail_int n_mpz) +fbits bitvector_access(const lbits op, const sail_int n_mpz) { uint64_t n = mpz_get_ui(n_mpz); - return (mach_bits) mpz_tstbit(*op.bits, n); + return (fbits) mpz_tstbit(*op.bits, n); } -void sail_unsigned(sail_int *rop, const sail_bits op) +void sail_unsigned(sail_int *rop, const lbits op) { /* Normal form of bv_t is always positive so just return the bits. */ mpz_set(*rop, *op.bits); } -void sail_signed(sail_int *rop, const sail_bits op) +void sail_signed(sail_int *rop, const lbits op) { if (op.len == 0) { mpz_set_ui(*rop, 0); @@ -682,14 +781,44 @@ void sail_signed(sail_int *rop, const sail_bits op) } } -void append(sail_bits *rop, const sail_bits op1, const sail_bits op2) +inline +mach_int fast_unsigned(const fbits op) +{ + return (mach_int) op; +} + +void append(lbits *rop, const lbits op1, const lbits op2) { rop->len = op1.len + op2.len; mpz_mul_2exp(*rop->bits, *op1.bits, op2.len); mpz_ior(*rop->bits, *rop->bits, *op2.bits); } -void replicate_bits(sail_bits *rop, const sail_bits op1, const mpz_t op2) +sbits append_sf(const sbits op1, const fbits op2, const uint64_t len) +{ + sbits rop; + rop.bits = (op1.bits << len) | op2; + rop.len = op1.len + len; + return rop; +} + +sbits append_fs(const fbits op1, const uint64_t len, const sbits op2) +{ + sbits rop; + rop.bits = (op1 << op2.len) | op2.bits; + rop.len = len + op2.len; + return rop; +} + +sbits append_ss(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = (op1.bits << op2.len) | op2.bits; + rop.len = op1.len + op2.len; + return rop; +} + +void replicate_bits(lbits *rop, const lbits op1, const mpz_t op2) { uint64_t op2_ui = mpz_get_ui(op2); rop->len = op1.len * op2_ui; @@ -725,7 +854,7 @@ uint64_t fast_replicate_bits(const uint64_t shift, const uint64_t v, const int64 // <-------^ // (8 bit) 4 // -void get_slice_int(sail_bits *rop, const sail_int len_mpz, const sail_int n, const sail_int start_mpz) +void get_slice_int(lbits *rop, const sail_int len_mpz, const sail_int n, const sail_int start_mpz) { uint64_t start = mpz_get_ui(start_mpz); uint64_t len = mpz_get_ui(len_mpz); @@ -744,7 +873,7 @@ void set_slice_int(sail_int *rop, const sail_int len_mpz, const sail_int n, const sail_int start_mpz, - const sail_bits slice) + const lbits slice) { uint64_t start = mpz_get_ui(start_mpz); @@ -759,11 +888,11 @@ void set_slice_int(sail_int *rop, } } -void vector_update_subrange_sail_bits(sail_bits *rop, - const sail_bits op, +void vector_update_subrange_lbits(lbits *rop, + const lbits op, const sail_int n_mpz, const sail_int m_mpz, - const sail_bits slice) + const lbits slice) { uint64_t n = mpz_get_ui(n_mpz); uint64_t m = mpz_get_ui(m_mpz); @@ -780,7 +909,7 @@ void vector_update_subrange_sail_bits(sail_bits *rop, } } -void slice(sail_bits *rop, const sail_bits op, const sail_int start_mpz, const sail_int len_mpz) +void slice(lbits *rop, const lbits op, const sail_int start_mpz, const sail_int len_mpz) { assert(mpz_get_ui(start_mpz) + mpz_get_ui(len_mpz) <= op.len); uint64_t start = mpz_get_ui(start_mpz); @@ -794,12 +923,25 @@ void slice(sail_bits *rop, const sail_bits op, const sail_int start_mpz, const s } } -void set_slice(sail_bits *rop, +inline +sbits sslice(const fbits op, const mach_int start, const mach_int len) +{ + sbits rop; +#ifdef INTRINSICS + rop.bits = _bzhi_u64(op >> start, len); +#else + rop.bits = (op >> start) & safe_rshift(UINT64_MAX, 64 - len); +#endif + rop.len = len; + return rop; +} + +void set_slice(lbits *rop, const sail_int len_mpz, const sail_int slen_mpz, - const sail_bits op, + const lbits op, const sail_int start_mpz, - const sail_bits slice) + const lbits slice) { uint64_t start = mpz_get_ui(start_mpz); @@ -815,21 +957,21 @@ void set_slice(sail_bits *rop, } } -void shift_bits_left(sail_bits *rop, const sail_bits op1, const sail_bits op2) +void shift_bits_left(lbits *rop, const lbits op1, const lbits op2) { rop->len = op1.len; mpz_mul_2exp(*rop->bits, *op1.bits, mpz_get_ui(*op2.bits)); - normalize_sail_bits(rop); + normalize_lbits(rop); } -void shift_bits_right(sail_bits *rop, const sail_bits op1, const sail_bits op2) +void shift_bits_right(lbits *rop, const lbits op1, const lbits op2) { rop->len = op1.len; mpz_tdiv_q_2exp(*rop->bits, *op1.bits, mpz_get_ui(*op2.bits)); } /* FIXME */ -void shift_bits_right_arith(sail_bits *rop, const sail_bits op1, const sail_bits op2) +void shift_bits_right_arith(lbits *rop, const lbits op1, const lbits op2) { rop->len = op1.len; mp_bitcnt_t shift_amt = mpz_get_ui(*op2.bits); @@ -843,20 +985,20 @@ void shift_bits_right_arith(sail_bits *rop, const sail_bits op1, const sail_bits } } -void shiftl(sail_bits *rop, const sail_bits op1, const sail_int op2) +void shiftl(lbits *rop, const lbits op1, const sail_int op2) { rop->len = op1.len; mpz_mul_2exp(*rop->bits, *op1.bits, mpz_get_ui(op2)); - normalize_sail_bits(rop); + normalize_lbits(rop); } -void shiftr(sail_bits *rop, const sail_bits op1, const sail_int op2) +void shiftr(lbits *rop, const lbits op1, const sail_int op2) { rop->len = op1.len; mpz_tdiv_q_2exp(*rop->bits, *op1.bits, mpz_get_ui(op2)); } -void reverse_endianness(sail_bits *rop, const sail_bits op) +void reverse_endianness(lbits *rop, const lbits op) { rop->len = op.len; if (rop->len == 64ul) { @@ -890,6 +1032,26 @@ void reverse_endianness(sail_bits *rop, const sail_bits op) } } +inline +bool eq_sbits(const sbits op1, const sbits op2) +{ + return op1.bits == op2.bits; +} + +inline +bool neq_sbits(const sbits op1, const sbits op2) +{ + return op1.bits != op2.bits; +} + +sbits xor_sbits(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = op1.bits ^ op2.bits; + rop.len = op1.len; + return rop; +} + /* ***** Sail Reals ***** */ void CREATE(real)(real *rop) @@ -1100,6 +1262,23 @@ void CREATE_OF(real, sail_string)(real *rop, const sail_string op) mpq_add(*rop, *rop, sail_lib_tmp_real); } +void CONVERT_OF(real, sail_string)(real *rop, const sail_string op) +{ + int decimal; + int total; + + gmp_sscanf(op, "%Zd.%n%Zd%n", sail_lib_tmp1, &decimal, sail_lib_tmp2, &total); + + int len = total - decimal; + mpz_ui_pow_ui(sail_lib_tmp3, 10, len); + mpz_set(mpq_numref(*rop), sail_lib_tmp2); + mpz_set(mpq_denref(*rop), sail_lib_tmp3); + mpq_canonicalize(*rop); + mpz_set(mpq_numref(sail_lib_tmp_real), sail_lib_tmp1); + mpz_set_ui(mpq_denref(sail_lib_tmp_real), 1); + mpq_add(*rop, *rop, sail_lib_tmp_real); +} + unit print_real(const sail_string str, const real op) { gmp_printf("%s%Qd\n", str, op); @@ -1127,42 +1306,53 @@ void random_real(real *rop, const unit u) void string_of_int(sail_string *str, const sail_int i) { + free(*str); gmp_asprintf(str, "%Zd", i); } -/* asprinf is a GNU extension, but it should exist on BSD */ -void string_of_mach_bits(sail_string *str, const mach_bits op) +/* asprintf is a GNU extension, but it should exist on BSD */ +void string_of_fbits(sail_string *str, const fbits op) { + free(*str); int bytes = asprintf(str, "0x%" PRIx64, op); if (bytes == -1) { fprintf(stderr, "Could not print bits 0x%" PRIx64 "\n", op); } } -void string_of_sail_bits(sail_string *str, const sail_bits op) +void string_of_lbits(sail_string *str, const lbits op) { + free(*str); if ((op.len % 4) == 0) { - gmp_asprintf(str, "0x%*0Zx", op.len / 4, *op.bits); + gmp_asprintf(str, "0x%*0ZX", op.len / 4, *op.bits); } else { - gmp_asprintf(str, "0b%*0Zb", op.len, *op.bits); + *str = (char *) malloc((op.len + 3) * sizeof(char)); + (*str)[0] = '0'; + (*str)[1] = 'b'; + for (int i = 1; i <= op.len; ++i) { + (*str)[i + 1] = mpz_tstbit(*op.bits, op.len - i) + 0x30; + } + (*str)[op.len + 2] = '\0'; } } -void decimal_string_of_mach_bits(sail_string *str, const mach_bits op) +void decimal_string_of_fbits(sail_string *str, const fbits op) { + free(*str); int bytes = asprintf(str, "%" PRId64, op); if (bytes == -1) { fprintf(stderr, "Could not print bits %" PRId64 "\n", op); } } -void decimal_string_of_sail_bits(sail_string *str, const sail_bits op) +void decimal_string_of_lbits(sail_string *str, const lbits op) { + free(*str); gmp_asprintf(str, "%Z", *op.bits); } void fprint_bits(const sail_string pre, - const sail_bits op, + const lbits op, const sail_string post, FILE *stream) { @@ -1197,13 +1387,13 @@ void fprint_bits(const sail_string pre, fputs(post, stream); } -unit print_bits(const sail_string str, const sail_bits op) +unit print_bits(const sail_string str, const lbits op) { fprint_bits(str, op, "\n", stdout); return UNIT; } -unit prerr_bits(const sail_string str, const sail_bits op) +unit prerr_bits(const sail_string str, const lbits op) { fprint_bits(str, op, "\n", stderr); return UNIT; @@ -1265,3 +1455,14 @@ void get_time_ns(sail_int *rop, const unit u) mpz_mul_ui(*rop, *rop, 1000000000); mpz_add_ui(*rop, *rop, t.tv_nsec); } + +// ARM specific optimisations + +void arm_align(lbits *rop, const lbits x_bv, const sail_int y_mpz) { + uint64_t x = mpz_get_ui(*x_bv.bits); + uint64_t y = mpz_get_ui(y_mpz); + uint64_t z = y * (x / y); + mp_bitcnt_t n = x_bv.len; + mpz_set_ui(*rop->bits, safe_rshift(UINT64_MAX, 64l - (n - 1)) & z); + rop->len = n; +} @@ -4,10 +4,7 @@ #include<stdlib.h> #include<stdio.h> #include<stdbool.h> - -#ifndef USE_INT128 #include<gmp.h> -#endif #include<time.h> @@ -85,14 +82,6 @@ typedef int64_t mach_int; bool EQUAL(mach_int)(const mach_int, const mach_int); -/* - * Integers can be either stack-allocated as 128-bit integers if - * __int128 is available, or use GMP arbitrary precision - * integers. This affects the function signatures, so use a macro to - * paper over the differences. - */ -#ifndef USE_INT128 - typedef mpz_t sail_int; #define SAIL_INT_FUNCTION(fname, rtype, ...) void fname(rtype*, __VA_ARGS__) @@ -107,16 +96,11 @@ mach_int CREATE_OF(mach_int, sail_int)(const sail_int); void CREATE_OF(sail_int, sail_string)(sail_int *, const sail_string); void RECREATE_OF(sail_int, sail_string)(mpz_t *, const sail_string); +void CONVERT_OF(sail_int, sail_string)(sail_int *, const sail_string); + mach_int CONVERT_OF(mach_int, sail_int)(const sail_int); void CONVERT_OF(sail_int, mach_int)(sail_int *, const mach_int); -#else - -typedef __int128 sail_int; -#define SAIL_INT_FUNCTION(fname, rtype, ...) rtype fname(__VA_ARGS__) - -#endif - /* * Comparison operators for integers */ @@ -131,6 +115,7 @@ bool gteq(const sail_int, const sail_int); /* * Left and right shift for integers */ +mach_int shl_mach_int(const mach_int, const mach_int); SAIL_INT_FUNCTION(shl_int, sail_int, const sail_int, const sail_int); SAIL_INT_FUNCTION(shr_int, sail_int, const sail_int, const sail_int); @@ -169,120 +154,164 @@ SAIL_INT_FUNCTION(pow2, sail_int, const sail_int); /* ***** Sail bitvectors ***** */ -typedef uint64_t mach_bits; +typedef uint64_t fbits; + +bool eq_bit(const fbits a, const fbits b); -bool eq_bit(const mach_bits a, const mach_bits b); +bool EQUAL(fbits)(const fbits, const fbits); -bool EQUAL(mach_bits)(const mach_bits, const mach_bits); +typedef struct { + uint64_t len; + uint64_t bits; +} sbits; typedef struct { mp_bitcnt_t len; mpz_t *bits; -} sail_bits; +} lbits; + +// For backwards compatability +typedef uint64_t mach_bits; +typedef lbits sail_bits; -SAIL_BUILTIN_TYPE(sail_bits); +SAIL_BUILTIN_TYPE(lbits); -void CREATE_OF(sail_bits, mach_bits)(sail_bits *, - const mach_bits op, - const mach_bits len, - const bool direction); +void CREATE_OF(lbits, fbits)(lbits *, + const fbits op, + const uint64_t len, + const bool direction); -void RECREATE_OF(sail_bits, mach_bits)(sail_bits *, - const mach_bits op, - const mach_bits len, - const bool direction); +void RECREATE_OF(lbits, fbits)(lbits *, + const fbits op, + const uint64_t len, + const bool direction); -mach_bits CONVERT_OF(mach_bits, sail_bits)(const sail_bits, const bool); -void CONVERT_OF(sail_bits, mach_bits)(sail_bits *, const mach_bits, const uint64_t, const bool); +void CREATE_OF(lbits, sbits)(lbits *, + const sbits op, + const bool direction); -void UNDEFINED(sail_bits)(sail_bits *, const sail_int len, const mach_bits bit); -mach_bits UNDEFINED(mach_bits)(const unit); +void RECREATE_OF(lbits, sbits)(lbits *, + const sbits op, + const bool direction); + +sbits CREATE_OF(sbits, lbits)(const lbits op, const bool direction); +fbits CREATE_OF(fbits, lbits)(const lbits op, const bool direction); +sbits CREATE_OF(sbits, fbits)(const fbits op, const uint64_t len, const bool direction); + +/* Bitvector conversions */ + +fbits CONVERT_OF(fbits, lbits)(const lbits, const bool); +fbits CONVERT_OF(fbits, sbits)(const sbits, const bool); + +void CONVERT_OF(lbits, fbits)(lbits *, const fbits, const uint64_t, const bool); +void CONVERT_OF(lbits, sbits)(lbits *, const sbits, const bool); + +sbits CONVERT_OF(sbits, fbits)(const fbits, const uint64_t, const bool); +sbits CONVERT_OF(sbits, lbits)(const lbits, const bool); + +void UNDEFINED(lbits)(lbits *, const sail_int len, const fbits bit); +fbits UNDEFINED(fbits)(const unit); + +sbits undefined_sbits(void); /* * Wrapper around >> operator to avoid UB when shift amount is greater * than or equal to 64. */ -mach_bits safe_rshift(const mach_bits, const mach_bits); +fbits safe_rshift(const fbits, const fbits); /* * Used internally to construct large bitvector literals. */ -void append_64(sail_bits *rop, const sail_bits op, const mach_bits chunk); +void append_64(lbits *rop, const lbits op, const fbits chunk); -void add_bits(sail_bits *rop, const sail_bits op1, const sail_bits op2); -void sub_bits(sail_bits *rop, const sail_bits op1, const sail_bits op2); +void add_bits(lbits *rop, const lbits op1, const lbits op2); +void sub_bits(lbits *rop, const lbits op1, const lbits op2); -void add_bits_int(sail_bits *rop, const sail_bits op1, const mpz_t op2); -void sub_bits_int(sail_bits *rop, const sail_bits op1, const mpz_t op2); +void add_bits_int(lbits *rop, const lbits op1, const mpz_t op2); +void sub_bits_int(lbits *rop, const lbits op1, const mpz_t op2); -void and_bits(sail_bits *rop, const sail_bits op1, const sail_bits op2); -void or_bits(sail_bits *rop, const sail_bits op1, const sail_bits op2); -void xor_bits(sail_bits *rop, const sail_bits op1, const sail_bits op2); -void not_bits(sail_bits *rop, const sail_bits op); +void and_bits(lbits *rop, const lbits op1, const lbits op2); +void or_bits(lbits *rop, const lbits op1, const lbits op2); +void xor_bits(lbits *rop, const lbits op1, const lbits op2); +void not_bits(lbits *rop, const lbits op); -void mults_vec(sail_bits *rop, const sail_bits op1, const sail_bits op2); -void mult_vec(sail_bits *rop, const sail_bits op1, const sail_bits op2); +void mults_vec(lbits *rop, const lbits op1, const lbits op2); +void mult_vec(lbits *rop, const lbits op1, const lbits op2); -void zeros(sail_bits *rop, const sail_int op); +void zeros(lbits *rop, const sail_int op); -void zero_extend(sail_bits *rop, const sail_bits op, const sail_int len); -void sign_extend(sail_bits *rop, const sail_bits op, const sail_int len); +void zero_extend(lbits *rop, const lbits op, const sail_int len); +void sign_extend(lbits *rop, const lbits op, const sail_int len); -void length_sail_bits(sail_int *rop, const sail_bits op); +void length_lbits(sail_int *rop, const lbits op); -bool eq_bits(const sail_bits op1, const sail_bits op2); -bool EQUAL(sail_bits)(const sail_bits op1, const sail_bits op2); -bool neq_bits(const sail_bits op1, const sail_bits op2); +bool eq_bits(const lbits op1, const lbits op2); +bool EQUAL(lbits)(const lbits op1, const lbits op2); +bool neq_bits(const lbits op1, const lbits op2); -void vector_subrange_sail_bits(sail_bits *rop, - const sail_bits op, +void vector_subrange_lbits(lbits *rop, + const lbits op, const sail_int n_mpz, const sail_int m_mpz); -void sail_truncate(sail_bits *rop, const sail_bits op, const sail_int len); +void sail_truncate(lbits *rop, const lbits op, const sail_int len); +void sail_truncateLSB(lbits *rop, const lbits op, const sail_int len); -mach_bits bitvector_access(const sail_bits op, const sail_int n_mpz); +fbits bitvector_access(const lbits op, const sail_int n_mpz); -void sail_unsigned(sail_int *rop, const sail_bits op); -void sail_signed(sail_int *rop, const sail_bits op); +void sail_unsigned(sail_int *rop, const lbits op); +void sail_signed(sail_int *rop, const lbits op); -void append(sail_bits *rop, const sail_bits op1, const sail_bits op2); +mach_int fast_unsigned(const fbits); -void replicate_bits(sail_bits *rop, const sail_bits op1, const sail_int op2); -mach_bits fast_replicate_bits(const mach_bits shift, const mach_bits v, const mach_int times); +void append(lbits *rop, const lbits op1, const lbits op2); -void get_slice_int(sail_bits *rop, const sail_int len_mpz, const sail_int n, const sail_int start_mpz); +sbits append_sf(const sbits, const fbits, const uint64_t); +sbits append_fs(const fbits, const uint64_t, const sbits); +sbits append_ss(const sbits, const sbits); + +void replicate_bits(lbits *rop, const lbits op1, const sail_int op2); +fbits fast_replicate_bits(const fbits shift, const fbits v, const mach_int times); + +void get_slice_int(lbits *rop, const sail_int len_mpz, const sail_int n, const sail_int start_mpz); void set_slice_int(sail_int *rop, const sail_int len_mpz, const sail_int n, const sail_int start_mpz, - const sail_bits slice); + const lbits slice); -void vector_update_subrange_sail_bits(sail_bits *rop, - const sail_bits op, +void vector_update_subrange_lbits(lbits *rop, + const lbits op, const sail_int n_mpz, const sail_int m_mpz, - const sail_bits slice); + const lbits slice); + +void slice(lbits *rop, const lbits op, const sail_int start_mpz, const sail_int len_mpz); -void slice(sail_bits *rop, const sail_bits op, const sail_int start_mpz, const sail_int len_mpz); +sbits sslice(const fbits op, const mach_int start, const mach_int len); -void set_slice(sail_bits *rop, +void set_slice(lbits *rop, const sail_int len_mpz, const sail_int slen_mpz, - const sail_bits op, + const lbits op, const sail_int start_mpz, - const sail_bits slice); + const lbits slice); -void shift_bits_left(sail_bits *rop, const sail_bits op1, const sail_bits op2); -void shift_bits_right(sail_bits *rop, const sail_bits op1, const sail_bits op2); -void shift_bits_right_arith(sail_bits *rop, const sail_bits op1, const sail_bits op2); +void shift_bits_left(lbits *rop, const lbits op1, const lbits op2); +void shift_bits_right(lbits *rop, const lbits op1, const lbits op2); +void shift_bits_right_arith(lbits *rop, const lbits op1, const lbits op2); -void shiftl(sail_bits *rop, const sail_bits op1, const sail_int op2); -void shiftr(sail_bits *rop, const sail_bits op1, const sail_int op2); +void shiftl(lbits *rop, const lbits op1, const sail_int op2); +void shiftr(lbits *rop, const lbits op1, const sail_int op2); -void reverse_endianness(sail_bits*, sail_bits); +void reverse_endianness(lbits*, lbits); + +bool eq_sbits(const sbits op1, const sbits op2); +bool neq_sbits(const sbits op1, const sbits op2); +sbits xor_sbits(const sbits op1, const sbits op2); /* ***** Sail reals ***** */ @@ -291,6 +320,7 @@ typedef mpq_t real; SAIL_BUILTIN_TYPE(real); void CREATE_OF(real, sail_string)(real *rop, const sail_string op); +void CONVERT_OF(real, sail_string)(real *rop, const sail_string op); void UNDEFINED(real)(real *rop, unit u); @@ -334,21 +364,21 @@ void opt_spc_matches_prefix(zoption *dst, sail_string s); /* ***** Printing ***** */ void string_of_int(sail_string *str, const sail_int i); -void string_of_sail_bits(sail_string *str, const sail_bits op); -void string_of_mach_bits(sail_string *str, const mach_bits op); -void decimal_string_of_sail_bits(sail_string *str, const sail_bits op); -void decimal_string_of_mach_bits(sail_string *str, const mach_bits op); +void string_of_lbits(sail_string *str, const lbits op); +void string_of_fbits(sail_string *str, const fbits op); +void decimal_string_of_lbits(sail_string *str, const lbits op); +void decimal_string_of_fbits(sail_string *str, const fbits op); /* * Utility function not callable from Sail! */ void fprint_bits(const sail_string pre, - const sail_bits op, + const lbits op, const sail_string post, FILE *stream); -unit print_bits(const sail_string str, const sail_bits op); -unit prerr_bits(const sail_string str, const sail_bits op); +unit print_bits(const sail_string str, const lbits op); +unit prerr_bits(const sail_string str, const lbits op); unit print(const sail_string str); unit print_endline(const sail_string str); @@ -364,3 +394,7 @@ unit sail_putchar(const sail_int op); /* ***** Misc ***** */ void get_time_ns(sail_int*, const unit); + +/* ***** ARM optimisations ***** */ + +void arm_align(lbits *, const lbits, const sail_int); diff --git a/lib/smt.sail b/lib/smt.sail index efcbe48c..c57f7bd1 100644 --- a/lib/smt.sail +++ b/lib/smt.sail @@ -9,7 +9,7 @@ val div = { lem: "integerDiv", c: "tdiv_int", coq: "div_with_eq" -} : forall 'n 'm. (atom('n), atom('m)) -> {'o, 'o = div('n, 'm). atom('o)} +} : forall 'n 'm. (atom('n), atom('m)) -> {'o, 'o == div('n, 'm). atom('o)} overload operator / = {div} @@ -19,7 +19,7 @@ val mod = { lem: "integerMod", c: "tmod_int", coq: "mod_with_eq" -} : forall 'n 'm. (atom('n), atom('m)) -> {'o, 'o = mod('n, 'm). atom('o)} +} : forall 'n 'm. (atom('n), atom('m)) -> {'o, 'o == mod('n, 'm). atom('o)} overload operator % = {mod} @@ -29,7 +29,7 @@ val abs_atom = { lem: "abs_int", c: "abs_int", coq: "abs_with_eq" -} : forall 'n. atom('n) -> {'o, 'o = abs_atom('n). atom('o)} +} : forall 'n. atom('n) -> {'o, 'o == abs_atom('n). atom('o)} $ifdef TEST diff --git a/lib/vector_dec.sail b/lib/vector_dec.sail index 9ac4d1a5..0e97c237 100644 --- a/lib/vector_dec.sail +++ b/lib/vector_dec.sail @@ -5,8 +5,6 @@ $include <flow.sail> type bits ('n : Int) = vector('n, dec, bit) -val eq_bit = { lem : "eq", _ : "eq_bit" } : (bit, bit) -> bool - val eq_bits = { ocaml: "eq_list", interpreter: "eq_list", @@ -39,6 +37,9 @@ val sail_sign_extend = "sign_extend" : forall 'n 'm, 'm >= 'n. (bits('n), atom(' val sail_zero_extend = "zero_extend" : forall 'n 'm, 'm >= 'n. (bits('n), atom('m)) -> bits('m) +/*! +THIS`(v, n)` truncates `v`, keeping only the _least_ significant `n` bits. + */ val truncate = { ocaml: "vector_truncate", interpreter: "vector_truncate", @@ -47,6 +48,16 @@ val truncate = { c: "sail_truncate" } : forall 'm 'n, 'm >= 0 & 'm <= 'n. (vector('n, dec, bit), atom('m)) -> vector('m, dec, bit) +/*! +THIS`(v, n)` truncates `v`, keeping only the _most_ significant `n` bits. + */ +val truncateLSB = { + ocaml: "vector_truncateLSB", + lem: "vector_truncateLSB", + coq: "vector_truncateLSB", + c: "sail_truncateLSB" +} : forall 'm 'n, 'm >= 0 & 'm <= 'n. (vector('n, dec, bit), atom('m)) -> vector('m, dec, bit) + val sail_mask : forall 'len 'v, 'len >= 0 & 'v >= 0. (atom('len), vector('v, dec, bit)) -> vector('len, dec, bit) function sail_mask(len, v) = if len <= length(v) then truncate(v, len) else sail_zero_extend(v, len) @@ -67,7 +78,7 @@ val bitvector_access = { lem: "access_vec_dec", coq: "access_vec_dec", c: "vector_access" -} : forall ('n : Int), 'n >= 0. (bits('n), int) -> bit +} : forall ('n : Int) ('m : Int), 0 <= 'm < 'n . (bits('n), int('m)) -> bit val plain_vector_access = { ocaml: "access", @@ -85,7 +96,7 @@ val bitvector_update = { lem: "update_vec_dec", coq: "update_vec_dec", c: "vector_update" -} : forall 'n, 'n >= 0. (bits('n), int, bit) -> bits('n) +} : forall 'n 'm, 0 <= 'm < 'n. (bits('n), atom('m), bit) -> bits('n) val plain_vector_update = { ocaml: "update", @@ -93,7 +104,7 @@ val plain_vector_update = { lem: "update_list_dec", coq: "vec_update_dec", c: "vector_update" -} : forall 'n ('a : Type). (vector('n, dec, 'a), int, 'a) -> vector('n, dec, 'a) +} : forall 'n 'm ('a : Type), 0 <= 'm < 'n. (vector('n, dec, 'a), atom('m), 'a) -> vector('n, dec, 'a) overload vector_update = {bitvector_update, plain_vector_update} @@ -130,7 +141,7 @@ val vector_update_subrange = { lem: "update_subrange_vec_dec", c: "vector_update_subrange", coq: "update_subrange_vec_dec" -} : forall 'n 'm 'o. (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n) +} : forall 'n 'm 'o, 0 <= 'o <= 'm < 'n. (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n) // Some ARM specific builtins @@ -146,6 +157,9 @@ val slice = "slice" : forall 'n 'm 'o, 0 <= 'o < 'm & 'o + 'n <= 'm & 0 <= 'n. val replicate_bits = "replicate_bits" : forall 'n 'm. (bits('n), atom('m)) -> bits('n * 'm) +/*! +converts a bit vector of length $n$ to an integer in the range $0$ to $2^n - 1$. + */ val unsigned = { ocaml: "uint", lem: "uint", @@ -155,6 +169,9 @@ val unsigned = { } : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1) /* We need a non-empty vector so that the range makes sense */ +/*! +converts a bit vector of length $n$ to an integer in the range $-2^{n-1}$ to $2^{n-1} - 1$ using twos-complement. + */ val signed = { c: "sail_signed", _: "sint" diff --git a/lib/vector_inc.sail b/lib/vector_inc.sail index b8e1b91f..daba99be 100644 --- a/lib/vector_inc.sail +++ b/lib/vector_inc.sail @@ -65,7 +65,7 @@ val bitvector_access = { lem: "access_vec_inc", coq: "access_vec_inc", c: "vector_access" -} : forall ('n : Int), 'n >= 0. (bits('n), int) -> bit +} : forall ('n : Int) ('m : Int), 0 <= 'm < 'n . (bits('n), int('m)) -> bit val plain_vector_access = { ocaml: "access", @@ -83,7 +83,7 @@ val bitvector_update = { lem: "update_vec_inc", coq: "update_vec_inc", c: "vector_update" -} : forall 'n, 'n >= 0. (bits('n), int, bit) -> bits('n) +} : forall 'n 'm, 0 <= 'm < 'n. (bits('n), atom('m), bit) -> bits('n) val plain_vector_update = { ocaml: "update", @@ -91,7 +91,7 @@ val plain_vector_update = { lem: "update_list_inc", coq: "update_list_inc", c: "vector_update" -} : forall 'n ('a : Type). (vector('n, inc, 'a), int, 'a) -> vector('n, inc, 'a) +} : forall 'n 'm ('a : Type), 0 <= 'm < 'n. (vector('n, inc, 'a), atom('m), 'a) -> vector('n, inc, 'a) overload vector_update = {bitvector_update, plain_vector_update} @@ -123,7 +123,7 @@ val vector_update_subrange = { lem: "update_subrange_vec_inc", c: "vector_update_subrange", coq: "update_subrange_vec_inc" -} : forall 'n 'm 'o. (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n) +} : forall 'n 'm 'o, 0 <= 'm <= 'o < 'n. (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n) // Some ARM specific builtins |
