summaryrefslogtreecommitdiff
path: root/lib/arith.sail
blob: 371a4a45b99556e0b77b442dbd9ce8b219e4d4d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
$ifndef _ARITH
$define _ARITH

$include <flow.sail>

// ***** Addition *****

/*! Addition on integers whose values are tracked in their types: `int('n) + int('m)` has type `int('n + 'm)`. */
val add_atom = {
  ocaml: "add_int",
  interpreter: "add_int",
  lem: "integerAdd",
  c: "add_int",
  coq: "Z.add"
} : forall 'n 'm. (int('n), int('m)) -> int('n + 'm)

/*! Addition on integers with no type-level information about their values. */
val add_int = {
  ocaml: "add_int",
  interpreter: "add_int",
  lem: "integerAdd",
  c: "add_int",
  coq: "Z.add"
} : (int, int) -> int

// The type-precise variant is listed ahead of the plain `int` fallback.
overload operator + = {add_atom, add_int}

// ***** Subtraction *****

/*! Subtraction on integers whose values are tracked in their types: `int('n) - int('m)` has type `int('n - 'm)`. */
val sub_atom = {
  ocaml: "sub_int",
  interpreter: "sub_int",
  lem: "integerMinus",
  c: "sub_int",
  coq: "Z.sub"
} : forall 'n 'm. (int('n), int('m)) -> int('n - 'm)

/*! Subtraction on integers with no type-level information about their values. */
val sub_int = {
  ocaml: "sub_int",
  interpreter: "sub_int",
  lem: "integerMinus",
  c: "sub_int",
  coq: "Z.sub"
} : (int, int) -> int

// The type-precise variant is listed ahead of the plain `int` fallback.
overload operator - = {sub_atom, sub_int}

/*!
Subtraction on naturals. The OCaml binding clamps a negative difference to
zero so the result stays a `nat`; backends without an explicit entry use the
external `sub_nat`.
NOTE(review): the Lem binding is plain `integerMinus`, which does not visibly
clamp at zero the way the OCaml binding does — confirm the Lem backend agrees
with the `nat` result type.
*/
val sub_nat = {
  lem: "integerMinus",
  ocaml: "(fun (x,y) -> let n = sub_int (x,y) in if Big_int.less_equal n Big_int.zero then Big_int.zero else n)",
  _: "sub_nat"
} : (nat, nat) -> nat

// ***** Negation *****

/*! Negation of an integer whose value is tracked in its type: `int('n)` becomes `int(- 'n)`. */
val negate_atom = {
  ocaml: "negate",
  interpreter: "negate",
  lem: "integerNegate",
  c: "neg_int",
  coq: "Z.opp"
} : forall 'n. int('n) -> int(- 'n)

/*! Negation of an integer with no type-level information about its value. */
val negate_int = {
  ocaml: "negate",
  interpreter: "negate",
  lem: "integerNegate",
  c: "neg_int",
  coq: "Z.opp"
} : int -> int

// The type-precise variant is listed ahead of the plain `int` fallback.
overload negate = {negate_atom, negate_int}

// ***** Multiplication *****

/*! Multiplication on integers whose values are tracked in their types: `int('n) * int('m)` has type `int('n * 'm)`. */
val mult_atom = {
  ocaml: "mult",
  interpreter: "mult",
  lem: "integerMult",
  c: "mult_int",
  coq: "Z.mul"
} : forall 'n 'm. (int('n), int('m)) -> int('n * 'm)

/*! Multiplication on integers with no type-level information about their values. */
val mult_int = {
  ocaml: "mult",
  interpreter: "mult",
  lem: "integerMult",
  c: "mult_int",
  coq: "Z.mul"
} : (int, int) -> int

// The type-precise variant is listed ahead of the plain `int` fallback.
overload operator * = {mult_atom, mult_int}

// Externs taking a string (presumably a message prefix) and an integer.
// Going by their names, `print_int` presumably writes to standard output and
// `prerr_int` to standard error — confirm against each backend's runtime.
val print_int = "print_int" : (string, int) -> unit

val prerr_int = "prerr_int" : (string, int) -> unit

// ***** Integer shifts *****

/*!
A common idiom in ASL is to take two bits of an opcode and convert them into a variable like
```
let elsize = shl_int(8, UInt(size))
```
The `_shl8` overload of `shl_int` ensures that, in this case, the typechecker knows the end result will be a value in the set `{8, 16, 32, 64}`.

Similarly, `_shl32` and `_shl1` cover shifting 32 and 1 by a small amount, so their results are likewise known powers of two.
*/
// Shift 8 left by 0..3: the result is one of 8, 16, 32 or 64.
val _shl8 = {
  c: "shl_mach_int",
  coq: "shl_int_8",
  _: "shl_int"
} : forall 'n, 0 <= 'n <= 3. (int(8), int('n)) -> {'m, 'm in {8, 16, 32, 64}. int('m)}

// Shift 32 left by 0 or 1: the result is 32 or 64.
val _shl32 = {
  c: "shl_mach_int",
  coq: "shl_int_32",
  _: "shl_int"
} : forall 'n, 'n in {0, 1}. (int(32), int('n)) -> {'m, 'm in {32, 64}. int('m)}

// Shift 1 left by 0..3: the result is one of 1, 2, 4 or 8.
val _shl1 = {
  c: "shl_mach_int",
  coq: "shl_int_1",
  _: "shl_int"
} : forall 'n, 0 <= 'n <= 3. (int(1), int('n)) -> {'m, 'm in {1, 2, 4, 8}. int('m)}

// Unrefined left shift for arbitrary integer operands.
val _shl_int = "shl_int" : (int, int) -> int

// The refined variants are listed ahead of the unrefined `int` fallback.
overload shl_int = {_shl1, _shl8, _shl32, _shl_int}

/*!
Right shift by exactly 1 of a value in `0..31`; the typechecker then knows
the result lies in `0..15`.
*/
val _shr32 = {
  c: "shr_mach_int",
  coq: "shr_int_32",
  _: "shr_int"
} : forall 'n, 0 <= 'n <= 31. (int('n), int(1)) -> {'m, 0 <= 'm <= 15. int('m)}

// Unrefined right shift for arbitrary integer operands.
val _shr_int = "shr_int" : (int, int) -> int

// The refined variant is listed ahead of the unrefined `int` fallback.
overload shr_int = {_shr32, _shr_int}

// ***** div and mod *****

/*! Truncating division: the quotient is rounded towards zero, so `tdiv_int(-7, 2)` is `-3`. */
val tdiv_int = {ocaml: "tdiv_int", interpreter: "tdiv_int", lem: "tdiv_int", c: "tdiv_int", coq: "Z.quot"} : (int, int) -> int

/*! Remainder of truncating division; it takes the sign of the dividend, so `_tmod_int(-7, 2)` is `-1`. */
val _tmod_int = {ocaml: "tmod_int", interpreter: "tmod_int", lem: "tmod_int", c: "tmod_int", coq: "Z.rem"} : (int, int) -> int

/*!
Remainder of truncating division when the dividend is non-negative and the
divisor is at least 1; the result is then known to be non-negative.
The truncating remainder takes the sign of the dividend (e.g. `Z.rem (-5) 3`
is `-2`), so a positive divisor alone does not make the result a `nat`: the
dividend must be non-negative too. Calls whose dividend may be negative fall
through to `_tmod_int` via the `tmod_int` overload.
*/
val _tmod_int_positive = {
  ocaml: "tmod_int",
  interpreter: "tmod_int",
  lem: "tmod_int",
  c: "tmod_int",
  coq: "Z.rem"
} : forall 'n, 'n >= 1. (nat, int('n)) -> nat

// The refined variant is listed first so it is used when both of its
// preconditions hold; otherwise the unrefined `_tmod_int` applies.
overload tmod_int = {_tmod_int_positive, _tmod_int}

// Floor division (rounds towards negative infinity), built on `tdiv_int`.
// Truncation and flooring can only disagree when the operands have strictly
// opposite signs. In that case the dividend is first moved one step towards
// zero and the truncated quotient is then decremented; otherwise the
// truncated quotient is returned unchanged.
function fdiv_int(n: int, m: int) -> int = {
  let opposite_signs : bool = (n < 0 & m > 0) | (n > 0 & m < 0);
  if opposite_signs then {
    let nudged : int = if n < 0 then n + 1 else n - 1;
    tdiv_int(nudged, m) - 1
  } else {
    tdiv_int(n, m)
  }
}

// Remainder of floor division, defined as `n - m * fdiv_int(n, m)`; for a
// non-zero `m` the result is zero or has the sign of `m`.
function fmod_int(n: int, m: int) -> int = {
  let q : int = fdiv_int(n, m);
  n - m * q
}

/*! Absolute value of an integer, with no type-level information about the result. */
val abs_int_plain = {smt: "abs", ocaml: "abs_int", interpreter: "abs_int", lem: "integerAbs", c: "abs_int", coq: "Z.abs"} : int -> int

// `abs_int` currently resolves only to the unrefined implementation.
overload abs_int = {abs_int_plain}

$endif