summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorThomas Bauereiss2019-03-15 14:44:38 +0000
committerThomas Bauereiss2019-03-15 18:47:30 +0000
commit541c1880d31a47302fea48725bd7247d374828d6 (patch)
tree24d72fe0ae7d79ce361c93ae101fd83d4a6a3b5a /lib
parent7da62ecee7d9eb2d16d42f9f8c5a5910b0950849 (diff)
Make mono_rewrites less dependant on ASL prelude
... so that it can be more easily used for other specs. Also add some functions to vector_dec.sail to support this.
Diffstat (limited to 'lib')
-rw-r--r--lib/mono_rewrites.sail108
-rw-r--r--lib/vector_dec.sail60
2 files changed, 108 insertions, 60 deletions
diff --git a/lib/mono_rewrites.sail b/lib/mono_rewrites.sail
index 9e4010a0..5e20fc71 100644
--- a/lib/mono_rewrites.sail
+++ b/lib/mono_rewrites.sail
@@ -1,23 +1,12 @@
-/* Definitions for use with the -mono_rewrites option */
-
-/* External definitions not in the usual asl prelude */
-
-infix 6 <<
-
-val shiftleft = "shiftl" : forall 'n ('ord : Order).
- (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
-
-overload operator << = {shiftleft}
-
-infix 6 >>
+$ifndef _MONO_REWRITES
+$define _MONO_REWRITES
-val shiftright = "shiftr" : forall 'n ('ord : Order).
- (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
+/* Definitions for use with the -mono_rewrites option */
-overload operator >> = {shiftright}
+$include <arith.sail>
+$include <vector_dec.sail>
-val arith_shiftright = "arith_shiftr" : forall 'n ('ord : Order).
- (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
+/* External definitions not in the usual asl prelude */
val extzv = "extz_vec" : forall 'n 'm. (implicit('m), vector('n, dec, bit)) -> vector('m, dec, bit) effect pure
@@ -30,23 +19,18 @@ val bitvector_cast_out = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect p
/* Definitions for the rewrites */
-val slice_mask : forall 'n, 'n >= 0. (implicit('n), int, int) -> bits('n) effect pure
-function slice_mask(n,i,l) =
- let one : bits('n) = extzv(n, 0b1) in
- ((one << l) - one) << i
-
val is_zero_subrange : forall 'n, 'n >= 0.
(bits('n), int, int) -> bool effect pure
function is_zero_subrange (xs, i, j) = {
- (xs & slice_mask(j, i-j+1)) == extzv(0b0)
+ (xs & slice_mask(j, i-j+1)) == extzv([bitzero] : bits(1))
}
val is_zeros_slice : forall 'n, 'n >= 0.
(bits('n), int, int) -> bool effect pure
function is_zeros_slice (xs, i, l) = {
- (xs & slice_mask(i, l)) == extzv(0b0)
+ (xs & slice_mask(i, l)) == extzv([bitzero] : bits(1))
}
val is_ones_subrange : forall 'n, 'n >= 0.
@@ -69,17 +53,17 @@ val slice_slice_concat : forall 'n 'm 'r, 'n >= 0 & 'm >= 0 & 'r >= 0.
(implicit('r), bits('n), int, int, bits('m), int, int) -> bits('r) effect pure
function slice_slice_concat (r, xs, i, l, ys, i', l') = {
- let xs = (xs & slice_mask(i,l)) >> i in
- let ys = (ys & slice_mask(i',l')) >> i' in
- extzv(r, xs) << l' | extzv(r, ys)
+ let xs = sail_shiftright(xs & slice_mask(i,l), i) in
+ let ys = sail_shiftright(ys & slice_mask(i',l'), i') in
+ sail_shiftleft(extzv(r, xs), l') | extzv(r, ys)
}
val slice_zeros_concat : forall 'n 'p 'q, 'n >= 0 & 'p + 'q >= 0.
(bits('n), int, atom('p), atom('q)) -> bits('p + 'q) effect pure
function slice_zeros_concat (xs, i, l, l') = {
- let xs = (xs & slice_mask(i,l)) >> i in
- extzv(l + l', xs) << l'
+ let xs = sail_shiftright(xs & slice_mask(i,l), i) in
+ sail_shiftleft(extzv(l + l', xs), l')
}
/* Assumes initial vectors are of equal size */
@@ -88,8 +72,8 @@ val subrange_subrange_eq : forall 'n, 'n >= 0.
(bits('n), int, int, bits('n), int, int) -> bool effect pure
function subrange_subrange_eq (xs, i, j, ys, i', j') = {
- let xs = (xs & slice_mask(j,i-j+1)) >> j in
- let ys = (ys & slice_mask(j',i'-j'+1)) >> j' in
+ let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in
+ let ys = sail_shiftright(ys & slice_mask(j',i'-j'+1), j') in
xs == ys
}
@@ -97,25 +81,25 @@ val subrange_subrange_concat : forall 'n 'o 'p 'm 'q 'r 's, 's >= 0 & 'n >= 0 &
(implicit('s), bits('n), atom('o), atom('p), bits('m), atom('q), atom('r)) -> bits('s) effect pure
function subrange_subrange_concat (s, xs, i, j, ys, i', j') = {
- let xs = (xs & slice_mask(j,i-j+1)) >> j in
- let ys = (ys & slice_mask(j',i'-j'+1)) >> j' in
- extzv(s, xs) << (i' - j' + 1) | extzv(s, ys)
+ let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in
+ let ys = sail_shiftright(ys & slice_mask(j',i'-j'+1), j) in
+ sail_shiftleft(extzv(s, xs), i' - j' + 1) | extzv(s, ys)
}
val place_subrange : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int, int) -> bits('m) effect pure
function place_subrange(m,xs,i,j,shift) = {
- let xs = (xs & slice_mask(j,i-j+1)) >> j in
- extzv(m, xs) << shift
+ let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in
+ sail_shiftleft(extzv(m, xs), shift)
}
val place_slice : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int, int) -> bits('m) effect pure
function place_slice(m,xs,i,l,shift) = {
- let xs = (xs & slice_mask(i,l)) >> i in
- extzv(m, xs) << shift
+ let xs = sail_shiftright(xs & slice_mask(i,l), i) in
+ sail_shiftleft(extzv(m, xs), shift)
}
val set_slice_zeros : forall 'n, 'n >= 0.
@@ -123,14 +107,14 @@ val set_slice_zeros : forall 'n, 'n >= 0.
function set_slice_zeros(n, xs, i, l) = {
let ys : bits('n) = slice_mask(n, i, l) in
- xs & ~(ys)
+ xs & not_vec(ys)
}
val zext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int) -> bits('m) effect pure
function zext_slice(m,xs,i,l) = {
- let xs = (xs & slice_mask(i,l)) >> i in
+ let xs = sail_shiftright(xs & slice_mask(i,l), i) in
extzv(m, xs)
}
@@ -138,7 +122,7 @@ val sext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int) -> bits('m) effect pure
function sext_slice(m,xs,i,l) = {
- let xs = arith_shiftright(((xs & slice_mask(i,l)) << ('n - i - l)), 'n - l) in
+ let xs = sail_arith_shiftright(sail_shiftleft((xs & slice_mask(i,l)), ('n - i - l)), 'n - l) in
extsv(m, xs)
}
@@ -146,7 +130,7 @@ val place_slice_signed : forall 'n 'm, 'n >= 0 & 'm >= 0.
(implicit('m), bits('n), int, int, int) -> bits('m) effect pure
function place_slice_signed(m,xs,i,l,shift) = {
- sext_slice(m, xs, i, l) << shift
+ sail_shiftleft(sext_slice(m, xs, i, l), shift)
}
/* This has different names in the aarch64 prelude (UInt) and the other
@@ -157,28 +141,46 @@ val _builtin_unsigned = {
lem: "uint",
interpreter: "uint",
c: "sail_uint"
-} : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1)
+} : forall 'n. bits('n) -> {'m, 0 <= 'm < 2 ^ 'n. int('m)}
+
+/* There are different implementation choices for division and remainder, but
+ they agree on positive values. We use this here to give more precise return
+ types for unsigned_slice and unsigned_subrange. */
-val unsigned_slice : forall 'n, 'n >= 0.
- (bits('n), int, int) -> int effect pure
+val _builtin_mod_nat = {
+ smt: "mod",
+ ocaml: "modulus",
+ lem: "integerMod",
+ c: "tmod_int",
+ coq: "Z.rem"
+} : forall 'n 'm, 'n >= 0 & 'm >= 0. (int('n), int('m)) -> {'r, 0 <= 'r < 'm. int('r)}
+
+/* Below we need the fact that 2 ^ 'n >= 0, so we axiomatise it in the return
+ type of pow2, as SMT solvers tend to have problems with exponentiation. */
+val _builtin_pow2 = "pow2" : forall 'n, 'n >= 0. int('n) -> {'m, 'm == 2 ^ 'n & 'm >= 0. int('m)}
+
+val unsigned_slice : forall 'n 'l, 'n >= 0 & 'l >= 0.
+ (bits('n), int, int('l)) -> {'m, 0 <= 'm < 2 ^ 'l. int('m)} effect pure
function unsigned_slice(xs,i,l) = {
- let xs = (xs & slice_mask(i,l)) >> i in
- _builtin_unsigned(xs)
+ let xs = sail_shiftright(xs & slice_mask(i,l), i) in
+ _builtin_mod_nat(_builtin_unsigned(xs), _builtin_pow2(l))
}
-val unsigned_subrange : forall 'n, 'n >= 0.
- (bits('n), int, int) -> int effect pure
+val unsigned_subrange : forall 'n 'i 'j, 'n >= 0 & ('i - 'j) >= 0.
+ (bits('n), int('i), int('j)) -> {'m, 0 <= 'm < 2 ^ ('i - 'j + 1). int('m)} effect pure
function unsigned_subrange(xs,i,j) = {
- let xs = (xs & slice_mask(j,i-j+1)) >> i in
- _builtin_unsigned(xs)
+ let xs = sail_shiftright(xs & slice_mask(j,i-j+1), i) in
+ _builtin_mod_nat(_builtin_unsigned(xs), _builtin_pow2(i - j + 1))
}
val zext_ones : forall 'n, 'n >= 0. (implicit('n), int) -> bits('n) effect pure
function zext_ones(n, m) = {
- let v : bits('n) = extsv(0b1) in
- v >> (n - m)
+ let v : bits('n) = extsv([bitone] : bits(1)) in
+ sail_shiftright(v, n - m)
}
+
+$endif
diff --git a/lib/vector_dec.sail b/lib/vector_dec.sail
index 8c6426d4..9eea3112 100644
--- a/lib/vector_dec.sail
+++ b/lib/vector_dec.sail
@@ -14,6 +14,16 @@ val eq_bits = {
overload operator == = {eq_bit, eq_bits}
+val neq_bits = {
+ lem: "neq_vec",
+ c: "neq_bits",
+ coq: "neq_vec"
+} : forall 'n. (vector('n, dec, bit), vector('n, dec, bit)) -> bool
+
+function neq_bits(x, y) = not_bool(eq_bits(x, y))
+
+overload operator != = {neq_bits}
+
val bitvector_length = {coq: "length_mword", _:"length"} : forall 'n. bits('n) -> atom('n)
val vector_length = {
@@ -25,8 +35,6 @@ val vector_length = {
overload length = {bitvector_length, vector_length}
-val sail_zeros = "zeros" : forall 'n. atom('n) -> bits('n)
-
val "print_bits" : forall 'n. (string, bits('n)) -> unit
val "prerr_bits" : forall 'n. (string, bits('n)) -> unit
@@ -117,6 +125,23 @@ val add_bits_int = {
overload operator + = {add_bits, add_bits_int}
+val sub_bits = {
+ ocaml: "sub_vec",
+ lem: "sub_vec",
+ c: "sub_bits",
+ coq: "sub_vec"
+} : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+val not_vec = {c: "not_bits", _: "not_vec"} : forall 'n. bits('n) -> bits('n)
+
+val and_vec = {lem: "and_vec", c: "and_bits", coq: "and_vec", ocaml: "and_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+overload operator & = {and_vec}
+
+val or_vec = {lem: "or_vec", c: "or_bits", coq: "or_vec", ocaml: "or_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+overload operator | = {or_vec}
+
val vector_subrange = {
ocaml: "subrange",
lem: "subrange_vec_dec",
@@ -132,8 +157,34 @@ val vector_update_subrange = {
coq: "update_subrange_vec_dec"
} : forall 'n 'm 'o, 0 <= 'o <= 'm < 'n. (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n)
+val sail_shiftleft = "shiftl" : forall 'n ('ord : Order).
+ (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
+
+val sail_shiftright = "shiftr" : forall 'n ('ord : Order).
+ (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
+
+val sail_arith_shiftright = "arith_shiftr" : forall 'n ('ord : Order).
+ (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure
+
+val sail_zeros = "zeros" : forall 'n. atom('n) -> bits('n)
+
+val sail_ones : forall 'n. atom('n) -> bits('n)
+
+function sail_ones(n) = not_vec(sail_zeros(n))
+
// Some ARM specific builtins
+val replicate_bits = "replicate_bits" : forall 'n 'm. (bits('n), atom('m)) -> bits('n * 'm)
+
+val slice_mask : forall 'n, 'n >= 0. (implicit('n), int, int) -> bits('n) effect pure
+function slice_mask(n,i,l) =
+ if l >= n then {
+ sail_ones(n)
+ } else {
+ let one : bits('n) = sail_mask(n, [bitone] : bits(1)) in
+ sail_shiftleft(sub_bits(sail_shiftleft(one, l), one), i)
+ }
+
val get_slice_int = "get_slice_int" : forall 'w. (atom('w), int, int) -> bits('w)
val set_slice_int = "set_slice_int" : forall 'w. (atom('w), int, int, bits('w)) -> int
@@ -141,11 +192,6 @@ val set_slice_int = "set_slice_int" : forall 'w. (atom('w), int, int, bits('w))
val set_slice_bits = "set_slice" : forall 'n 'm.
(atom('n), atom('m), bits('n), int, bits('m)) -> bits('n)
-val slice = "slice" : forall 'n 'm 'o, 0 <= 'o < 'm & 'o + 'n <= 'm & 0 <= 'n.
- (bits('m), atom('o), atom('n)) -> bits('n)
-
-val replicate_bits = "replicate_bits" : forall 'n 'm. (bits('n), atom('m)) -> bits('n * 'm)
-
/*!
converts a bit vector of length $n$ to an integer in the range $0$ to $2^n - 1$.
*/