summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--riscv/prelude.sail299
-rw-r--r--riscv/riscv_types.sail174
2 files changed, 473 insertions, 0 deletions
diff --git a/riscv/prelude.sail b/riscv/prelude.sail
new file mode 100644
index 00000000..b7c1c6a1
--- /dev/null
+++ b/riscv/prelude.sail
@@ -0,0 +1,299 @@
+default Order dec
+
+/* type bits ('n : Int) = vector('n - 1, 'n, dec, bit) */
+type bits ('n : Int) = vector('n, dec, bit)
+
+infix 4 ==
+
+val eq_atom = "eq_int" : forall 'n 'm. (atom('n), atom('m)) -> bool
+val lteq_atom = "lteq" : forall 'n 'm. (atom('n), atom('m)) -> bool
+val gteq_atom = "gteq" : forall 'n 'm. (atom('n), atom('m)) -> bool
+val lt_atom = "lt" : forall 'n 'm. (atom('n), atom('m)) -> bool
+val gt_atom = "gt" : forall 'n 'm. (atom('n), atom('m)) -> bool
+
+val eq_int = "eq_int" : (int, int) -> bool
+
+val eq_vec = "eq_list" : forall 'n. (bits('n), bits('n)) -> bool
+
+val eq_string = "eq_string" : (string, string) -> bool
+
+val eq_real = "eq_real" : (real, real) -> bool
+
+val eq_anything = "(fun (x, y) -> x = y)" : forall ('a : Type). ('a, 'a) -> bool
+
+val length = "length" : forall 'n ('a : Type). vector('n, dec, 'a) -> atom('n)
+
+val "reg_deref" : forall ('a : Type). register('a) -> 'a effect {rreg}
+
+overload operator == = {eq_atom, eq_int, eq_vec, eq_string, eq_real, eq_anything}
+
+val vector_subrange = "subrange" : forall ('n : Int) ('m : Int) ('o : Int), 'o <= 'm <= 'n.
+ (bits('n), atom('m), atom('o)) -> bits('m - ('o - 1))
+
+val vector_access = "access" : forall ('n : Int) ('m : Int) ('a : Type), 0 <= 'm < 'n.
+ (vector('n, dec, 'a), atom('m)) -> 'a
+
+val vector_update = "update" : forall 'n ('a : Type).
+ (vector('n, dec, 'a), int, 'a) -> vector('n, dec, 'a)
+
+val vector_update_subrange = "update_subrange" : forall 'n 'm 'o.
+ (bits('n), atom('m), atom('o), bits('m - ('o - 1))) -> bits('n)
+
+val vcons : forall ('n : Int) ('a : Type).
+ ('a, vector('n, dec, 'a)) -> vector('n + 1, dec, 'a)
+
+val append = "append" : forall ('n : Int) ('m : Int) ('a : Type).
+ (vector('n, dec, 'a), vector('m, dec, 'a)) -> vector('n + 'm, dec, 'a)
+
+val not_bool = "not" : bool -> bool
+
+val not_vec = "not_vec" : forall 'n. bits('n) -> bits('n)
+
+overload ~ = {not_bool, not_vec}
+
+val neq_atom : forall 'n 'm. (atom('n), atom('m)) -> bool
+
+function neq_atom (x, y) = not_bool(eq_atom(x, y))
+
+val neq_int : (int, int) -> bool
+
+function neq_int (x, y) = not_bool(eq_int(x, y))
+
+val neq_vec : forall 'n. (bits('n), bits('n)) -> bool
+
+function neq_vec (x, y) = not_bool(eq_vec(x, y))
+
+val neq_anything : forall ('a : Type). ('a, 'a) -> bool
+
+function neq_anything (x, y) = not_bool(x == y)
+
+overload operator != = {neq_atom, neq_int, neq_vec, neq_anything}
+
+val and_bool = "and_bool" : (bool, bool) -> bool
+
+val builtin_and_vec = "and_vec" : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+val and_vec : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+function and_vec (xs, ys) = builtin_and_vec(xs, ys)
+
+overload operator & = {and_bool, and_vec}
+
+val or_bool = "or_bool" : (bool, bool) -> bool
+
+val builtin_or_vec = "or_vec" : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+val or_vec : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+function or_vec (xs, ys) = builtin_or_vec(xs, ys)
+
+overload operator | = {or_bool, or_vec}
+
+val unsigned = "uint" : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1)
+
+val signed = "sint" : forall 'n. bits('n) -> range(- (2 ^ ('n - 1)), 2 ^ ('n - 1) - 1)
+
+val hex_slice = "hex_slice" : forall 'n 'm. (string, atom('n), atom('m)) -> bits('n - 'm)
+
+val __SetSlice_bits = "set_slice" : forall 'n 'm.
+ (atom('n), atom('m), bits('n), int, bits('m)) -> bits('n)
+
+val __SetSlice_int = "set_slice_int" : forall 'w. (atom('w), int, int, bits('w)) -> int
+
+val __raw_SetSlice_int : forall 'w. (atom('w), int, int, bits('w)) -> int
+
+val __raw_GetSlice_int = "get_slice_int" : forall 'w. (atom('w), int, int) -> bits('w)
+
+val __GetSlice_int : forall 'n. (atom('n), int, int) -> bits('n)
+
+function __GetSlice_int (n, m, o) = __raw_GetSlice_int(n, m, o)
+
+val __raw_SetSlice_bits : forall 'n 'w.
+ (atom('n), atom('w), bits('n), int, bits('w)) -> bits('n)
+
+val __raw_GetSlice_bits : forall 'n 'w.
+ (atom('n), atom('w), bits('n), int) -> bits('w)
+
+val __ShiftLeft : forall 'm. (bits('m), int) -> bits('m)
+
+val __SignExtendSlice : forall 'm. (bits('m), int, int) -> bits('m)
+
+val __ZeroExtendSlice : forall 'm. (bits('m), int, int) -> bits('m)
+
+val cast cast_unit_vec : bit -> bits(1)
+
+val print = "prerr_endline" : string -> unit
+
+val putchar = "putchar" : forall ('a : Type). 'a -> unit
+
+val concat_str = "concat_str" : (string, string) -> string
+
+val DecStr : int -> string
+
+val HexStr : int -> string
+
+val BitStr = "string_of_bits" : forall 'n. bits('n) -> string
+
+val xor_vec = "xor_vec" : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+val int_power : (int, int) -> int
+
+val real_power = "real_power" : (real, int) -> real
+
+overload operator ^ = {xor_vec, int_power, real_power}
+
+val add_range = "add" : forall 'n 'm 'o 'p.
+ (range('n, 'm), range('o, 'p)) -> range('n + 'o, 'm + 'p)
+
+val add_int = "add" : (int, int) -> int
+
+val add_vec = "add_vec" : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+val add_vec_int = "add_vec_int" : forall 'n. (bits('n), int) -> bits('n)
+
+val add_real = "add_real" : (real, real) -> real
+
+overload operator + = {add_range, add_int, add_vec, add_vec_int, add_real}
+
+val sub_range = "sub" : forall 'n 'm 'o 'p.
+ (range('n, 'm), range('o, 'p)) -> range('n - 'p, 'm - 'o)
+
+val sub_int = "sub" : (int, int) -> int
+
+val "sub_vec" : forall 'n. (bits('n), bits('n)) -> bits('n)
+
+val "sub_vec_int" : forall 'n. (bits('n), int) -> bits('n)
+
+val "sub_real" : (real, real) -> real
+
+val negate_range = "minus_big_int" : forall 'n 'm. range('n, 'm) -> range(- 'm, - 'n)
+
+val negate_int = "minus_big_int" : int -> int
+
+val negate_real = "Num.minus_num" : real -> real
+
+overload operator - = {sub_range, sub_int, sub_vec, sub_vec_int, sub_real}
+
+overload negate = {negate_range, negate_int, negate_real}
+
+val mult_range = "mult" : forall 'n 'm 'o 'p.
+ (range('n, 'm), range('o, 'p)) -> range('n * 'o, 'm * 'p)
+
+val mult_int = "mult" : (int, int) -> int
+
+val mult_real = "mult_real" : (real, real) -> real
+
+overload operator * = {mult_range, mult_int, mult_real}
+
+val Sqrt = "sqrt_real" : real -> real
+
+val gteq_int = "gteq" : (int, int) -> bool
+
+val gteq_real = "gteq_real" : (real, real) -> bool
+
+overload operator >= = {gteq_atom, gteq_int, gteq_real}
+
+val lteq_int = "lteq" : (int, int) -> bool
+
+val lteq_real = "lteq_real" : (real, real) -> bool
+
+overload operator <= = {lteq_atom, lteq_int, lteq_real}
+
+val gt_int = "gt" : (int, int) -> bool
+
+val gt_real = "gt_real" : (real, real) -> bool
+
+overload operator > = {gt_atom, gt_int, gt_real}
+
+val lt_int = "lt" : (int, int) -> bool
+
+val lt_real = "lt_real" : (real, real) -> bool
+
+overload operator < = {lt_atom, lt_int, lt_real}
+
+val RoundDown = "round_down" : real -> int
+
+val RoundUp = "round_up" : real -> int
+
+val abs_int = "abs_int" : int -> int
+
+val abs_real = "abs_real" : real -> real
+
+overload abs = {abs_int, abs_real}
+
+val quotient_nat = "quotient" : (nat, nat) -> nat
+
+val quotient_real = "quotient_real" : (real, real) -> real
+
+val quotient = "quotient" : (int, int) -> int
+
+infixl 7 /
+
+overload operator / = {quotient_nat, quotient, quotient_real}
+
+val modulus = "modulus" : (int, int) -> int
+
+infixl 7 %
+
+overload operator % = {modulus}
+
+val Real = "Num.num_of_big_int" : int -> real
+
+val shl_int = "shl_int" : (int, int) -> int
+
+val shr_int = "shr_int" : (int, int) -> int
+
+val min_nat = "min_int" : (nat, nat) -> nat
+
+val min_int = "min_int" : (int, int) -> int
+
+val max_nat = "max_int" : (nat, nat) -> nat
+
+val max_int = "max_int" : (int, int) -> int
+
+overload min = {min_nat, min_int}
+
+overload max = {max_nat, max_int}
+
+val __WriteRAM = "write_ram" : forall 'n 'm.
+ (atom('m), atom('n), bits('m), bits('m), bits(8 * 'n)) -> unit
+
+val __TraceMemoryWrite : forall 'n 'm.
+ (atom('n), bits('m), bits(8 * 'n)) -> unit
+
+val __ReadRAM = "read_ram" : forall 'n 'm.
+ (atom('m), atom('n), bits('m), bits('m)) -> bits(8 * 'n)
+
+val __TraceMemoryRead : forall 'n 'm. (atom('n), bits('m), bits(8 * 'n)) -> unit
+
+val replicate_bits = "replicate_bits" : forall 'n 'm. (bits('n), atom('m)) -> bits('n * 'm)
+
+val cast ex_nat : nat -> {'n, 'n >= 0. atom('n)}
+
+function ex_nat 'n = n
+
+val cast ex_int : int -> {'n, true. atom('n)}
+
+function ex_int 'n = n
+
+val ex_range : forall 'n 'm.
+ range('n, 'm) -> {'o, 'n <= 'o & 'o <= 'm. atom('o)}
+
+val coerce_int_nat : int -> nat effect {escape}
+
+function coerce_int_nat 'x = {
+ assert(constraint('x >= 0));
+ x
+}
+
+val slice = "slice" : forall ('n : Int) ('m : Int), 'm >= 0 & 'n >= 0.
+ (bits('m), int, atom('n)) -> bits('n)
+
+val pow2 = "pow2" : forall 'n. atom('n) -> atom(2 ^ 'n)
+
+val print_int = "print_int" : (string, int) -> unit
+
+union exception = {
+ Error_not_implemented : string,
+ Error_misaligned_access,
+} \ No newline at end of file
diff --git a/riscv/riscv_types.sail b/riscv/riscv_types.sail
new file mode 100644
index 00000000..a0120e06
--- /dev/null
+++ b/riscv/riscv_types.sail
@@ -0,0 +1,174 @@
+val not_implemented : forall ('a : Type). string -> 'a effect {escape}
+
+function not_implemented message = throw(Error_not_implemented(message))
+
+type regval = bits(64)
+
+type regno ('n : Int), 0 <= 'n < 32 = atom('n)
+
+/* register x0 : regval is hard-wired zero */
+register x1 : regval
+register x2 : regval
+register x3 : regval
+register x4 : regval
+register x5 : regval
+register x6 : regval
+register x7 : regval
+register x8 : regval
+register x9 : regval
+register x10 : regval
+register x11 : regval
+register x12 : regval
+register x13 : regval
+register x14 : regval
+register x15 : regval
+register x16 : regval
+register x17 : regval
+register x18 : regval
+register x19 : regval
+register x20 : regval
+register x21 : regval
+register x22 : regval
+register x23 : regval
+register x24 : regval
+register x25 : regval
+register x26 : regval
+register x27 : regval
+register x28 : regval
+register x29 : regval
+register x30 : regval
+register x31 : regval
+
+register PC : bits(64)
+register nextPC : bits(64)
+
+let GPRs : vector(31, dec, register(regval)) =
+ [ ref x31, ref x30, ref x29, ref x28,
+ ref x27, ref x26, ref x25, ref x24,
+ ref x23, ref x22, ref x21, ref x20,
+ ref x19, ref x18, ref x17, ref x16,
+ ref x15, ref x14, ref x13, ref x12,
+ ref x11, ref x10, ref x9, ref x8,
+ ref x7, ref x6, ref x5, ref x4,
+ ref x3, ref x2, ref x1 /* ref x0 */
+ ]
+
+/* Getters and setters for registers */
+val rGPR : forall 'n, 0 <= 'n < 32. regno('n) -> regval effect {rreg}
+
+function rGPR 0 = 0x0000000000000000
+and rGPR (r if r > 0) = reg_deref(GPRs[r - 1])
+
+val wGPR : forall 'n, 1 <= 'n < 32. (regno('n), regval) -> unit effect {wreg}
+
+function wGPR (r, v) =
+ if (r != 0) then (*GPRs[r - 1]) = v else ()
+
+function check_alignment (addr : bits(64), width : atom('n)) -> forall 'n. unit =
+ if unsigned(addr) % width != 0 then throw(Error_misaligned_access) else ()
+
+val "MEMr" : forall 'n. (bits(64), atom('n)) -> bits(8 * 'n) effect {rmem}
+val "MEMr_acquire" : forall 'n. (bits(64), atom('n)) -> bits(8 * 'n) effect {rmem}
+val "MEMr_strong_acquire" : forall 'n. (bits(64), atom('n)) -> bits(8 * 'n) effect {rmem}
+val "MEMr_reserved" : forall 'n. (bits(64), atom('n)) -> bits(8 * 'n) effect {rmem}
+val "MEMr_reserved_acquire" : forall 'n. (bits(64), atom('n)) -> bits(8 * 'n) effect {rmem}
+val "MEMr_reserved_strong_acquire" : forall 'n. (bits(64), atom('n)) -> bits(8 * 'n) effect {rmem}
+
+val mem_read : forall 'n. (bits(64), atom('n), bool, bool, bool) -> bits(8 * 'n) effect {rmem, escape}
+
+function mem_read (addr, width, aq, rl, res) = {
+ if aq | res then check_alignment(addr, width);
+
+ match (aq, rl, res) {
+ (false, false, false) => MEMr(addr, width),
+ (true, false, false) => MEMr_acquire(addr, width),
+ (false, false, true) => MEMr_reserved(addr, width),
+ (true, false, true) => MEMr_reserved_acquire(addr, width),
+ (false, true, false) => throw(Error_not_implemented("load.rl")),
+ (true, true, false) => MEMr_strong_acquire(addr, width),
+ (false, true, true) => throw(Error_not_implemented("lr.rl")),
+ (true, true, true) => MEMr_reserved_strong_acquire(addr, width)
+ }
+}
+
+val "MEMea" : forall 'n. (bits(64), atom('n)) -> unit effect {eamem}
+val "MEMea_release" : forall 'n. (bits(64), atom('n)) -> unit effect {eamem}
+val "MEMea_strong_release" : forall 'n. (bits(64), atom('n)) -> unit effect {eamem}
+val "MEMea_conditional" : forall 'n. (bits(64), atom('n)) -> unit effect {eamem}
+val "MEMea_conditional_release" : forall 'n. (bits(64), atom('n)) -> unit effect {eamem}
+val "MEMea_conditional_strong_release" : forall 'n. (bits(64), atom('n)) -> unit effect {eamem}
+
+val mem_write_ea : forall 'n. (bits(64), atom('n), bool, bool, bool) -> unit effect {eamem, escape}
+
+function mem_write_ea (addr, width, aq, rl, con) = {
+ if rl | con then check_alignment(addr, width);
+
+ match (aq, rl, con) {
+ (false, false, false) => MEMea(addr, width),
+ (false, true, false) => MEMea_release(addr, width),
+ (false, false, true) => MEMea_conditional(addr, width),
+ (false, true , true) => MEMea_conditional_release(addr, width),
+ (true, false, false) => throw(Error_not_implemented("store.aq")),
+ (true, true, false) => MEMea_strong_release(addr, width),
+ (true, false, true) => throw(Error_not_implemented("sc.aq")),
+ (true, true , true) => MEMea_conditional_strong_release(addr, width)
+ }
+}
+
+val "MEMval" : forall 'n. (bits(64), atom('n), bits(8 * 'n)) -> unit effect {wmv}
+val "MEMval_release" : forall 'n. (bits(64), atom('n), bits(8 * 'n)) -> unit effect {wmv}
+val "MEMval_strong_release" : forall 'n. (bits(64), atom('n), bits(8 * 'n)) -> unit effect {wmv}
+val "MEMval_conditional" : forall 'n. (bits(64), atom('n), bits(8 * 'n)) -> unit effect {wmv}
+val "MEMval_conditional_release" : forall 'n. (bits(64), atom('n), bits(8 * 'n)) -> unit effect {wmv}
+val "MEMval_conditional_strong_release" : forall 'n. (bits(64), atom('n), bits(8 * 'n)) -> unit effect {wmv}
+
+val mem_write_value : forall 'n. (bits(64), atom('n), bits(8 * 'n), bool, bool, bool) -> unit effect {wmv, escape}
+
+function mem_write_value (addr, width, value, aq, rl, con) = {
+ if rl | con then check_alignment(addr, width);
+
+ match (aq, rl, con) {
+ (false, false, false) => MEMval(addr, width, value),
+ (false, true, false) => MEMval_release(addr, width, value),
+ (false, false, true) => MEMval_conditional(addr, width, value),
+ (false, true, true) => MEMval_conditional_release(addr, width, value),
+ (true, false, false) => throw(Error_not_implemented("store.aq")),
+ (true, true, false) => MEMval_strong_release(addr, width, value),
+ (true, false, true) => throw(Error_not_implemented("sc.aq")),
+ (true, true, true) => MEMval_conditional_strong_release(addr, width, value)
+ }
+}
+
+val "speculate_conditional_success" : unit -> bool effect {exmem}
+
+val "MEM_fence_rw_rw" : unit -> unit effect {barr}
+val "MEM_fence_r_rw" : unit -> unit effect {barr}
+val "MEM_fence_r_r" : unit -> unit effect {barr}
+val "MEM_fence_rw_w" : unit -> unit effect {barr}
+val "MEM_fence_w_w" : unit -> unit effect {barr}
+val "MEM_fence_i" : unit -> unit effect {barr}
+
+enum uop = {RISCV_LUI, RISCV_AUIPC} /* upper immediate ops */
+enum bop = {RISCV_BEQ, RISCV_BNE, RISCV_BLT, RISCV_BGE, RISCV_BLTU, RISCV_BGEU} /* branch ops */
+enum iop = {RISCV_ADDI, RISCV_SLTI, RISCV_SLTIU, RISCV_XORI, RISCV_ORI, RISCV_ANDI} /* immediate ops */
+enum sop = {RISCV_SLLI, RISCV_SRLI, RISCV_SRAI} /* shift ops */
+enum rop = {RISCV_ADD, RISCV_SUB, RISCV_SLL, RISCV_SLT, RISCV_SLTU, RISCV_XOR, RISCV_SRL, RISCV_SRA, RISCV_OR, RISCV_AND} /* reg-reg ops */
+enum ropw = {RISCV_ADDW, RISCV_SUBW, RISCV_SLLW, RISCV_SRLW, RISCV_SRAW} /* reg-reg 32-bit ops */
+enum amoop = {AMOSWAP, AMOADD, AMOXOR, AMOAND, AMOOR, AMOMIN, AMOMAX, AMOMINU, AMOMAXU} /* AMO ops */
+
+enum word_width = {BYTE, HALF, WORD, DOUBLE}
+
+/********************************************************************/
+
+/* Ideally these would be sail builtin */
+
+/*
+function (bit[64]) shift_right_arith64 ((bit[64]) v, (bit[6]) shift) =
+ let (bit[128]) v128 = EXTS(v) in
+ (v128 >> shift)[63..0]
+
+function (bit[32]) shift_right_arith32 ((bit[32]) v, (bit[5]) shift) =
+ let (bit[64]) v64 = EXTS(v) in
+ (v64 >> shift)[31..0]
+
+*/