diff options
| author | Jon French | 2019-04-15 16:18:18 +0100 |
|---|---|---|
| committer | Jon French | 2019-04-15 16:18:18 +0100 |
| commit | a9f0b829507e9882efdb59cce4d83ea7e87f5f71 (patch) | |
| tree | 11cde6c1918bc15f4dda9a8e40afd4a1fe912a0a /lib/mono_rewrites.sail | |
| parent | 0f6fd188ca232cb539592801fcbb873d59611d81 (diff) | |
| parent | 57443173923e87f33713c99dbab9eba7e3db0660 (diff) | |
Merge branch 'sail2' into rmem_interpreter
Diffstat (limited to 'lib/mono_rewrites.sail')
| -rw-r--r-- | lib/mono_rewrites.sail | 108 |
1 files changed, 55 insertions, 53 deletions
diff --git a/lib/mono_rewrites.sail b/lib/mono_rewrites.sail index 9e4010a0..5e20fc71 100644 --- a/lib/mono_rewrites.sail +++ b/lib/mono_rewrites.sail @@ -1,23 +1,12 @@ -/* Definitions for use with the -mono_rewrites option */ - -/* External definitions not in the usual asl prelude */ - -infix 6 << - -val shiftleft = "shiftl" : forall 'n ('ord : Order). - (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure - -overload operator << = {shiftleft} - -infix 6 >> +$ifndef _MONO_REWRITES +$define _MONO_REWRITES -val shiftright = "shiftr" : forall 'n ('ord : Order). - (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure +/* Definitions for use with the -mono_rewrites option */ -overload operator >> = {shiftright} +$include <arith.sail> +$include <vector_dec.sail> -val arith_shiftright = "arith_shiftr" : forall 'n ('ord : Order). - (vector('n, 'ord, bit), int) -> vector('n, 'ord, bit) effect pure +/* External definitions not in the usual asl prelude */ val extzv = "extz_vec" : forall 'n 'm. (implicit('m), vector('n, dec, bit)) -> vector('m, dec, bit) effect pure @@ -30,23 +19,18 @@ val bitvector_cast_out = "zeroExtend" : forall 'n. bits('n) -> bits('n) effect p /* Definitions for the rewrites */ -val slice_mask : forall 'n, 'n >= 0. (implicit('n), int, int) -> bits('n) effect pure -function slice_mask(n,i,l) = - let one : bits('n) = extzv(n, 0b1) in - ((one << l) - one) << i - val is_zero_subrange : forall 'n, 'n >= 0. (bits('n), int, int) -> bool effect pure function is_zero_subrange (xs, i, j) = { - (xs & slice_mask(j, i-j+1)) == extzv(0b0) + (xs & slice_mask(j, i-j+1)) == extzv([bitzero] : bits(1)) } val is_zeros_slice : forall 'n, 'n >= 0. (bits('n), int, int) -> bool effect pure function is_zeros_slice (xs, i, l) = { - (xs & slice_mask(i, l)) == extzv(0b0) + (xs & slice_mask(i, l)) == extzv([bitzero] : bits(1)) } val is_ones_subrange : forall 'n, 'n >= 0. @@ -69,17 +53,17 @@ val slice_slice_concat : forall 'n 'm 'r, 'n >= 0 & 'm >= 0 & 'r >= 0. (implicit('r), bits('n), int, int, bits('m), int, int) -> bits('r) effect pure function slice_slice_concat (r, xs, i, l, ys, i', l') = { - let xs = (xs & slice_mask(i,l)) >> i in - let ys = (ys & slice_mask(i',l')) >> i' in - extzv(r, xs) << l' | extzv(r, ys) + let xs = sail_shiftright(xs & slice_mask(i,l), i) in + let ys = sail_shiftright(ys & slice_mask(i',l'), i') in + sail_shiftleft(extzv(r, xs), l') | extzv(r, ys) } val slice_zeros_concat : forall 'n 'p 'q, 'n >= 0 & 'p + 'q >= 0. (bits('n), int, atom('p), atom('q)) -> bits('p + 'q) effect pure function slice_zeros_concat (xs, i, l, l') = { - let xs = (xs & slice_mask(i,l)) >> i in - extzv(l + l', xs) << l' + let xs = sail_shiftright(xs & slice_mask(i,l), i) in + sail_shiftleft(extzv(l + l', xs), l') } /* Assumes initial vectors are of equal size */ @@ -88,8 +72,8 @@ val subrange_subrange_eq : forall 'n, 'n >= 0. (bits('n), int, int, bits('n), int, int) -> bool effect pure function subrange_subrange_eq (xs, i, j, ys, i', j') = { - let xs = (xs & slice_mask(j,i-j+1)) >> j in - let ys = (ys & slice_mask(j',i'-j'+1)) >> j' in + let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in + let ys = sail_shiftright(ys & slice_mask(j',i'-j'+1), j') in xs == ys } @@ -97,25 +81,25 @@ val subrange_subrange_concat : forall 'n 'o 'p 'm 'q 'r 's, 's >= 0 & 'n >= 0 & (implicit('s), bits('n), atom('o), atom('p), bits('m), atom('q), atom('r)) -> bits('s) effect pure function subrange_subrange_concat (s, xs, i, j, ys, i', j') = { - let xs = (xs & slice_mask(j,i-j+1)) >> j in - let ys = (ys & slice_mask(j',i'-j'+1)) >> j' in - extzv(s, xs) << (i' - j' + 1) | extzv(s, ys) + let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in + let ys = sail_shiftright(ys & slice_mask(j',i'-j'+1), j) in + sail_shiftleft(extzv(s, xs), i' - j' + 1) | extzv(s, ys) } val place_subrange : forall 'n 'm, 'n >= 0 & 'm >= 0. (implicit('m), bits('n), int, int, int) -> bits('m) effect pure function place_subrange(m,xs,i,j,shift) = { - let xs = (xs & slice_mask(j,i-j+1)) >> j in - extzv(m, xs) << shift + let xs = sail_shiftright(xs & slice_mask(j,i-j+1), j) in + sail_shiftleft(extzv(m, xs), shift) } val place_slice : forall 'n 'm, 'n >= 0 & 'm >= 0. (implicit('m), bits('n), int, int, int) -> bits('m) effect pure function place_slice(m,xs,i,l,shift) = { - let xs = (xs & slice_mask(i,l)) >> i in - extzv(m, xs) << shift + let xs = sail_shiftright(xs & slice_mask(i,l), i) in + sail_shiftleft(extzv(m, xs), shift) } val set_slice_zeros : forall 'n, 'n >= 0. @@ -123,14 +107,14 @@ val set_slice_zeros : forall 'n, 'n >= 0. function set_slice_zeros(n, xs, i, l) = { let ys : bits('n) = slice_mask(n, i, l) in - xs & ~(ys) + xs & not_vec(ys) } val zext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0. (implicit('m), bits('n), int, int) -> bits('m) effect pure function zext_slice(m,xs,i,l) = { - let xs = (xs & slice_mask(i,l)) >> i in + let xs = sail_shiftright(xs & slice_mask(i,l), i) in extzv(m, xs) } @@ -138,7 +122,7 @@ val sext_slice : forall 'n 'm, 'n >= 0 & 'm >= 0. (implicit('m), bits('n), int, int) -> bits('m) effect pure function sext_slice(m,xs,i,l) = { - let xs = arith_shiftright(((xs & slice_mask(i,l)) << ('n - i - l)), 'n - l) in + let xs = sail_arith_shiftright(sail_shiftleft((xs & slice_mask(i,l)), ('n - i - l)), 'n - l) in extsv(m, xs) } @@ -146,7 +130,7 @@ val place_slice_signed : forall 'n 'm, 'n >= 0 & 'm >= 0. (implicit('m), bits('n), int, int, int) -> bits('m) effect pure function place_slice_signed(m,xs,i,l,shift) = { - sext_slice(m, xs, i, l) << shift + sail_shiftleft(sext_slice(m, xs, i, l), shift) } /* This has different names in the aarch64 prelude (UInt) and the other @@ -157,28 +141,46 @@ val _builtin_unsigned = { lem: "uint", interpreter: "uint", c: "sail_uint" -} : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1) +} : forall 'n. bits('n) -> {'m, 0 <= 'm < 2 ^ 'n. int('m)} + +/* There are different implementation choices for division and remainder, but + they agree on positive values. We use this here to give more precise return + types for unsigned_slice and unsigned_subrange. */ -val unsigned_slice : forall 'n, 'n >= 0. - (bits('n), int, int) -> int effect pure +val _builtin_mod_nat = { + smt: "mod", + ocaml: "modulus", + lem: "integerMod", + c: "tmod_int", + coq: "Z.rem" +} : forall 'n 'm, 'n >= 0 & 'm >= 0. (int('n), int('m)) -> {'r, 0 <= 'r < 'm. int('r)} + +/* Below we need the fact that 2 ^ 'n >= 0, so we axiomatise it in the return + type of pow2, as SMT solvers tend to have problems with exponentiation. */ +val _builtin_pow2 = "pow2" : forall 'n, 'n >= 0. int('n) -> {'m, 'm == 2 ^ 'n & 'm >= 0. int('m)} + +val unsigned_slice : forall 'n 'l, 'n >= 0 & 'l >= 0. + (bits('n), int, int('l)) -> {'m, 0 <= 'm < 2 ^ 'l. int('m)} effect pure function unsigned_slice(xs,i,l) = { - let xs = (xs & slice_mask(i,l)) >> i in - _builtin_unsigned(xs) + let xs = sail_shiftright(xs & slice_mask(i,l), i) in + _builtin_mod_nat(_builtin_unsigned(xs), _builtin_pow2(l)) } -val unsigned_subrange : forall 'n, 'n >= 0. - (bits('n), int, int) -> int effect pure +val unsigned_subrange : forall 'n 'i 'j, 'n >= 0 & ('i - 'j) >= 0. + (bits('n), int('i), int('j)) -> {'m, 0 <= 'm < 2 ^ ('i - 'j + 1). int('m)} effect pure function unsigned_subrange(xs,i,j) = { - let xs = (xs & slice_mask(j,i-j+1)) >> i in - _builtin_unsigned(xs) + let xs = sail_shiftright(xs & slice_mask(j,i-j+1), i) in + _builtin_mod_nat(_builtin_unsigned(xs), _builtin_pow2(i - j + 1)) } val zext_ones : forall 'n, 'n >= 0. (implicit('n), int) -> bits('n) effect pure function zext_ones(n, m) = { - let v : bits('n) = extsv(0b1) in - v >> (n - m) + let v : bits('n) = extsv([bitone] : bits(1)) in + sail_shiftright(v, n - m) } + +$endif |
