summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/coq/Sail2_values.v27
-rw-r--r--lib/smt.sail9
-rw-r--r--mips/prelude.sail39
3 files changed, 50 insertions, 25 deletions
diff --git a/lib/coq/Sail2_values.v b/lib/coq/Sail2_values.v
index 55d85b3a..229a9c09 100644
--- a/lib/coq/Sail2_values.v
+++ b/lib/coq/Sail2_values.v
@@ -852,7 +852,7 @@ Ltac solve_arithfact :=
reduce_list_lengths;
reduce_pow;
solve [apply ArithFact_mword; assumption
- | constructor; omega
+ | constructor; omega with Z
(* The datatypes hints give us some list handling, esp In *)
| constructor; auto with datatypes zbool zarith sail].
Hint Extern 0 (ArithFact _) => solve_arithfact : typeclass_instances.
@@ -1281,3 +1281,28 @@ Definition diafp_to_dia reginfo = function
end
*)
*)
+
+(* Arithmetic functions which return proofs that match the expected Sail
+ types in smt.sail. *)
+
+Definition div_with_eq n m : {o : Z & ArithFact (o = Z.quot n m)} := build_ex (Z.quot n m).
+Definition mod_with_eq n m : {o : Z & ArithFact (o = Z.rem n m)} := build_ex (Z.rem n m).
+Definition abs_with_eq n : {o : Z & ArithFact (o = Z.abs n)} := build_ex (Z.abs n).
+
+(* Similarly, for ranges (currently in MIPS) *)
+
+Definition add_range {n m o p} (l : {l & ArithFact (n <= l <= m)}) (r : {r & ArithFact (o <= r <= p)})
+ : {x & ArithFact (n+o <= x <= m+p)} :=
+ build_ex ((projT1 l) + (projT1 r)).
+Definition sub_range {n m o p} (l : {l & ArithFact (n <= l <= m)}) (r : {r & ArithFact (o <= r <= p)})
+ : {x & ArithFact (n-p <= x <= m-o)} :=
+ build_ex ((projT1 l) - (projT1 r)).
+Definition negate_range {n m} (l : {l : Z & ArithFact (n <= l <= m)})
+ : {x : Z & ArithFact ((- m) <= x <= (- n))} :=
+ build_ex (- (projT1 l)).
+
+Definition min_atom (a : Z) (b : Z) : {c : Z & ArithFact (c = a \/ c = b /\ c <= a /\ c <= b)} :=
+ build_ex (Z.min a b).
+Definition max_atom (a : Z) (b : Z) : {c : Z & ArithFact (c = a \/ c = b /\ c >= a /\ c >= b)} :=
+ build_ex (Z.max a b).
+
diff --git a/lib/smt.sail b/lib/smt.sail
index c9312819..efcbe48c 100644
--- a/lib/smt.sail
+++ b/lib/smt.sail
@@ -7,7 +7,8 @@ val div = {
smt: "div",
ocaml: "quotient",
lem: "integerDiv",
- c: "tdiv_int"
+ c: "tdiv_int",
+ coq: "div_with_eq"
} : forall 'n 'm. (atom('n), atom('m)) -> {'o, 'o = div('n, 'm). atom('o)}
overload operator / = {div}
@@ -16,7 +17,8 @@ val mod = {
smt: "mod",
ocaml: "modulus",
lem: "integerMod",
- c: "tmod_int"
+ c: "tmod_int",
+ coq: "mod_with_eq"
} : forall 'n 'm. (atom('n), atom('m)) -> {'o, 'o = mod('n, 'm). atom('o)}
overload operator % = {mod}
@@ -25,7 +27,8 @@ val abs_atom = {
smt : "abs",
ocaml: "abs_int",
lem: "abs_int",
- c: "abs_int"
+ c: "abs_int",
+ coq: "abs_with_eq"
} : forall 'n. atom('n) -> {'o, 'o = abs_atom('n). atom('o)}
$ifdef TEST
diff --git a/mips/prelude.sail b/mips/prelude.sail
index e2f1e0d4..2d164a79 100644
--- a/mips/prelude.sail
+++ b/mips/prelude.sail
@@ -22,7 +22,7 @@ val not_vec = {c:"not_bits", _:"not_vec"} : forall 'n. bits('n) -> bits('n)
overload ~ = {not_bool, not_vec}
-val not = "not" : bool -> bool
+val not = {coq:"negb", _:"not"} : bool -> bool
val neq_vec = {lem: "neq"} : forall 'n. (bits('n), bits('n)) -> bool
function neq_vec (x, y) = not_bool(eq_bits(x, y))
@@ -57,20 +57,20 @@ val putchar = {c:"sail_putchar", _:"putchar"} : forall ('a : Type). 'a -> unit
val concat_str = {lem: "stringAppend", _: "concat_str"} : (string, string) -> string
val string_of_int = "string_of_int" : int -> string
-
+/* Unused?
val DecStr : int -> string
val HexStr : int -> string
-
+*/
val BitStr = "string_of_bits" : forall 'n. bits('n) -> string
val xor_vec = {c: "xor_bits" , _: "xor_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n)
-val int_power = {ocaml: "int_power", lem: "pow"} : (int, int) -> int
+val int_power = {ocaml: "int_power", lem: "pow", coq: "Z.pow"} : (int, int) -> int
overload operator ^ = {xor_vec, int_power}
-val add_range = {ocaml: "add_int", lem: "integerAdd"} : forall 'n 'm 'o 'p.
+val add_range = {ocaml: "add_int", lem: "integerAdd", coq: "add_range"} : forall 'n 'm 'o 'p.
(range('n, 'm), range('o, 'p)) -> range('n + 'o, 'm + 'p)
val add_vec = "add_vec" : forall 'n. (bits('n), bits('n)) -> bits('n)
@@ -79,48 +79,45 @@ val add_vec_int = "add_vec_int" : forall 'n. (bits('n), int) -> bits('n)
overload operator + = {add_range, add_int, add_vec, add_vec_int}
-val sub_range = {ocaml: "sub_int", lem: "integerMinus"} : forall 'n 'm 'o 'p.
+val sub_range = {ocaml: "sub_int", lem: "integerMinus", coq: "sub_range"} : forall 'n 'm 'o 'p.
(range('n, 'm), range('o, 'p)) -> range('n - 'p, 'm - 'o)
val sub_vec = {c : "sub_bits", _:"sub_vec"} : forall 'n. (bits('n), bits('n)) -> bits('n)
val sub_vec_int = {c:"sub_bits_int", _: "sub_vec_int"} : forall 'n. (bits('n), int) -> bits('n)
-val negate_range = {ocaml: "minus_big_int", lem: "integerNegate"} : forall 'n 'm. range('n, 'm) -> range(- 'm, - 'n)
+val negate_range = {ocaml: "minus_big_int", lem: "integerNegate", coq: "negate_range"} : forall 'n 'm. range('n, 'm) -> range(- 'm, - 'n)
overload operator - = {sub_range, sub_int, sub_vec, sub_vec_int}
overload negate = {negate_range, negate_int}
-val mult_range = {ocaml: "mult", lem: "integerMult"} : forall 'n 'm 'o 'p.
- (range('n, 'm), range('o, 'p)) -> range('n * 'o, 'm * 'p)
-
-overload operator * = {mult_range, mult_int}
+overload operator * = {mult_int}
-val quotient_nat = {ocaml: "quotient", lem: "integerDiv"} : (nat, nat) -> nat
+val quotient_nat = {ocaml: "quotient", lem: "integerDiv", coq: "Z.div"} : (nat, nat) -> nat
-val quotient = {ocaml: "quotient", lem: "integerDiv"} : (int, int) -> int
+val quotient = {ocaml: "quotient", lem: "integerDiv", coq: "Z.mod"} : (int, int) -> int
overload operator / = {quotient_nat, quotient}
-val quot_round_zero = {ocaml: "quot_round_zero", lem: "hardware_quot", _ : "tdiv_int"} : (int, int) -> int
-val rem_round_zero = {ocaml: "rem_round_zero", lem: "hardware_mod", _ : "tmod_int"} : (int, int) -> int
+val quot_round_zero = {ocaml: "quot_round_zero", lem: "hardware_quot", coq: "Z.quot", _ : "tdiv_int"} : (int, int) -> int
+val rem_round_zero = {ocaml: "rem_round_zero", lem: "hardware_mod", coq: "Z.rem", _ : "tmod_int"} : (int, int) -> int
val modulus = {ocaml: "modulus", lem: "hardware_mod", _ : "tmod_int"} : forall 'n, 'n > 0 . (int, atom('n)) -> range(0, 'n - 1)
overload operator % = {modulus}
-val min_nat = {ocaml: "min_int", lem: "min"} : (nat, nat) -> nat
+val min_nat = {ocaml: "min_int", lem: "min", coq: "Z.min"} : (nat, nat) -> nat
-val min_int = {ocaml: "min_int", lem: "min"} : (int, int) -> int
+val min_int = {ocaml: "min_int", lem: "min", coq: "Z.min"} : (int, int) -> int
-val max_nat = {ocaml: "max_int", lem: "max"} : (nat, nat) -> nat
+val max_nat = {ocaml: "max_int", lem: "max", coq: "Z.max"} : (nat, nat) -> nat
-val max_int = {ocaml: "max_int", lem: "max"} : (int, int) -> int
+val max_int = {ocaml: "max_int", lem: "max", coq: "Z.max"} : (int, int) -> int
-val min_atom = {ocaml: "min_int", lem: "min"} : forall 'a 'b . (atom('a), atom('b)) -> {'c, ('c = 'a | 'c = 'b) & 'c <= 'a & 'c <= 'b . atom('c)}
+val min_atom = {ocaml: "min_int", lem: "min", coq: "min_atom"} : forall 'a 'b . (atom('a), atom('b)) -> {'c, ('c = 'a | 'c = 'b) & 'c <= 'a & 'c <= 'b . atom('c)}
-val max_atom = {ocaml: "max_int", lem: "max"} : forall 'a 'b . (atom('a), atom('b)) -> {'c, ('c = 'a | 'c = 'b) & 'c >= 'a & 'c >= 'b . atom('c)}
+val max_atom = {ocaml: "max_int", lem: "max", coq: "max_atom"} : forall 'a 'b . (atom('a), atom('b)) -> {'c, ('c = 'a | 'c = 'b) & 'c >= 'a & 'c >= 'b . atom('c)}
overload min = {min_atom, min_nat, min_int}