diff options
| -rw-r--r-- | lib/int128/rts.c | 529 | ||||
| -rw-r--r-- | lib/int128/rts.h | 108 | ||||
| -rw-r--r-- | lib/int128/sail.c | 1558 | ||||
| -rw-r--r-- | lib/int128/sail.h | 404 | ||||
| -rw-r--r-- | lib/sail.c | 5 | ||||
| -rw-r--r-- | lib/sail.h | 3 | ||||
| -rw-r--r-- | src/jib/c_backend.ml | 22 | ||||
| -rw-r--r-- | src/jib/c_backend.mli | 1 | ||||
| -rw-r--r-- | src/sail.ml | 3 |
9 files changed, 2623 insertions, 10 deletions
diff --git a/lib/int128/rts.c b/lib/int128/rts.c new file mode 100644 index 00000000..86c305c9 --- /dev/null +++ b/lib/int128/rts.c @@ -0,0 +1,529 @@ +#include<string.h> +#include<getopt.h> +#include<inttypes.h> + +#include"sail.h" +#include"rts.h" +#include"elf.h" + +static uint64_t g_elf_entry; +uint64_t g_cycle_count = 0; +static uint64_t g_cycle_limit; + +void sail_match_failure(sail_string msg) +{ + fprintf(stderr, "Pattern match failure in %s\n", msg); + exit(EXIT_FAILURE); +} + +unit sail_assert(bool b, sail_string msg) +{ + if (b) return UNIT; + fprintf(stderr, "Assertion failed: %s\n", msg); + exit(EXIT_FAILURE); +} + +unit sail_exit(unit u) +{ + fprintf(stderr, "[Sail] Exiting after %" PRIu64 " cycles\n", g_cycle_count); + exit(EXIT_SUCCESS); + return UNIT; +} + +static uint64_t g_verbosity = 0; + +fbits sail_get_verbosity(const unit u) +{ + return g_verbosity; +} + +bool g_sleeping = false; + +unit sleep_request(const unit u) +{ + fprintf(stderr, "Sail CPU model going to sleep\n"); + g_sleeping = true; + return UNIT; +} + +unit wakeup_request(const unit u) +{ + fprintf(stderr, "Sail CPU model waking up\n"); + g_sleeping = false; + return UNIT; +} + +bool sleeping(const unit u) +{ + return g_sleeping; +} + +/* ***** Sail memory builtins ***** */ + +/* + * We organise memory available to the sail model into a linked list + * of dynamically allocated MASK + 1 size blocks. + */ +struct block { + uint64_t block_id; + uint8_t *mem; + struct block *next; +}; + +struct block *sail_memory = NULL; + +struct tag_block { + uint64_t block_id; + bool *mem; + struct tag_block *next; +}; + +struct tag_block *sail_tags = NULL; + +/* + * Must be one less than a power of two. + */ +uint64_t MASK = 0xFFFFFFul; + +/* + * All sail vectors are at least 64-bits, but only the bottom 8 bits + * are used in the second argument. + */ +void write_mem(uint64_t address, uint64_t byte) +{ + uint64_t mask = address & ~MASK; + uint64_t offset = address & MASK; + + //if ((byte >= 97 && byte <= 122) || (byte >= 64 && byte <= 90) || (byte >= 48 && byte <= 57) || byte == 10 || byte == 32) { + // fprintf(stderr, "%c", (char) byte); + //} + + struct block *current = sail_memory; + + while (current != NULL) { + if (current->block_id == mask) { + current->mem[offset] = (uint8_t) byte; + return; + } else { + current = current->next; + } + } + + /* + * If we couldn't find a block matching the mask, allocate a new + * one, write the byte, and put it at the front of the block list. + */ + fprintf(stderr, "[Sail] Allocating new block 0x%" PRIx64 "\n", mask); + struct block *new_block = malloc(sizeof(struct block)); + new_block->block_id = mask; + new_block->mem = calloc(MASK + 1, sizeof(uint8_t)); + new_block->mem[offset] = (uint8_t) byte; + new_block->next = sail_memory; + sail_memory = new_block; +} + +uint64_t read_mem(uint64_t address) +{ + uint64_t mask = address & ~MASK; + uint64_t offset = address & MASK; + + struct block *current = sail_memory; + + while (current != NULL) { + if (current->block_id == mask) { + return (uint64_t) current->mem[offset]; + } else { + current = current->next; + } + } + + return 0x00; +} + +unit write_tag_bool(const uint64_t address, const bool tag) +{ + uint64_t mask = address & ~MASK; + uint64_t offset = address & MASK; + + struct tag_block *current = sail_tags; + + while (current != NULL) { + if (current->block_id == mask) { + current->mem[offset] = tag; + return UNIT; + } else { + current = current->next; + } + } + + /* + * If we couldn't find a block matching the mask, allocate a new + * one, write the byte, and put it at the front of the block list. + */ + fprintf(stderr, "[Sail] Allocating new tag block 0x%" PRIx64 "\n", mask); + struct tag_block *new_block = malloc(sizeof(struct tag_block)); + new_block->block_id = mask; + new_block->mem = calloc(MASK + 1, sizeof(bool)); + new_block->mem[offset] = tag; + new_block->next = sail_tags; + sail_tags = new_block; + + return UNIT; +} + +bool read_tag_bool(const uint64_t address) +{ + uint64_t mask = address & ~MASK; + uint64_t offset = address & MASK; + + struct tag_block *current = sail_tags; + + while (current != NULL) { + if (current->block_id == mask) { + return current->mem[offset]; + } else { + current = current->next; + } + } + + return false; +} + +void kill_mem() +{ + while (sail_memory != NULL) { + struct block *next = sail_memory->next; + + free(sail_memory->mem); + free(sail_memory); + + sail_memory = next; + } + + while (sail_tags != NULL) { + struct tag_block *next = sail_tags->next; + + free(sail_tags->mem); + free(sail_tags); + + sail_tags = next; + } +} + +// ***** Memory builtins ***** + +mpz_t write_buf; + +bool write_ram(const sail_int addr_size, // Either 32 or 64 + const sail_int data_size_mpz, // Number of bytes + const lbits hex_ram, // Currently unused + const lbits addr_bv, + const lbits data) +{ + uint64_t addr = mpz_get_ui(*addr_bv.bits); + uint64_t data_size = (uint64_t) data_size_mpz; + + if (data_size <= 8) { + uint64_t bytes = mpz_get_ui(*data.bits); + + for(uint64_t i = 0; i < data_size; ++i) { + write_mem(addr + i, bytes & 0xFF); + bytes >>= 8; + } + + return true; + } else { + mpz_set(write_buf, *data.bits); + + uint64_t byte; + for(uint64_t i = 0; i < data_size; ++i) { + // Take the 8 low bits of write_buf and write to addr. + byte = mpz_get_ui(write_buf) & 0xFF; + write_mem(addr + i, byte); + + // Then shift buf 8 bits right. + mpz_fdiv_q_2exp(write_buf, write_buf, 8); + } + + return true; + } +} + +sbits fast_read_ram(const int64_t data_size, + const uint64_t addr) +{ + uint64_t r = 0; + + uint64_t byte; + for(uint64_t i = (uint64_t) data_size; i > 0; --i) { + byte = read_mem(addr + (i - 1)); + r = r << 8; + r = r + byte; + } + sbits res = {.len = data_size * 8, .bits = r }; + return res; +} + +mpz_t read_buf; + +void read_ram(lbits *data, + const sail_int addr_size, + const sail_int data_size_mpz, + const lbits hex_ram, + const lbits addr_bv) +{ + uint64_t addr = mpz_get_ui(*addr_bv.bits); + uint64_t data_size = (uint64_t) data_size_mpz; + + if (data_size <= 8) { + uint64_t byte = 0; + + for(uint64_t i = data_size; i > 0; --i) { + byte = byte << 8; + byte += read_mem(addr + (i - 1)); + } + + mpz_set_ui(*data->bits, byte); + data->len = data_size * 8; + } else { + mpz_set_ui(*data->bits, 0); + data->len = data_size * 8; + + for(uint64_t i = data_size; i > 0; --i) { + mpz_set_ui(read_buf, read_mem(addr + (i - 1))); + mpz_mul_2exp(*data->bits, *data->bits, 8); + mpz_add(*data->bits, *data->bits, read_buf); + } + } +} + +unit load_raw(fbits addr, const sail_string file) +{ + FILE *fp = fopen(file, "r"); + + if (!fp) { + fprintf(stderr, "[Sail] Raw file %s could not be loaded\n", file); + exit(EXIT_FAILURE); + } + + uint64_t byte; + while ((byte = (uint64_t)fgetc(fp)) != EOF) { + write_mem(addr, byte); + addr++; + } + + return UNIT; +} + +void load_image(char *file) +{ + FILE *fp = fopen(file, "r"); + + if (!fp) { + fprintf(stderr, "[Sail] Image file %s could not be loaded\n", file); + exit(EXIT_FAILURE); + } + + char *addr = NULL; + char *data = NULL; + size_t len = 0; + + while (true) { + ssize_t addr_len = getline(&addr, &len, fp); + if (addr_len == -1) break; + ssize_t data_len = getline(&data, &len, fp); + if (data_len == -1) break; + + if (!strcmp(addr, "elf_entry\n")) { + if (sscanf(data, "%" PRIu64 "\n", &g_elf_entry) != 1) { + fprintf(stderr, "[Sail] Failed to parse elf_entry\n"); + exit(EXIT_FAILURE); + }; + fprintf(stderr, "[Sail] Elf entry point: %" PRIx64 "\n", g_elf_entry); + } else { + write_mem((uint64_t) atoll(addr), (uint64_t) atoll(data)); + } + } + + free(addr); + free(data); + fclose(fp); +} + +/* ***** ELF functions ***** */ + +sail_int elf_entry(const unit u) +{ + return (__int128) g_elf_entry; +} + +sail_int elf_tohost(const unit u) +{ + return (__int128) 0; +} + +/* ***** Cycle limit ***** */ + +/* NB Also increments cycle_count */ +bool cycle_limit_reached(const unit u) +{ + return ++g_cycle_count >= g_cycle_limit && g_cycle_limit != 0; +} + +unit cycle_count(const unit u) +{ + if (cycle_limit_reached(UNIT)) { + printf("\n[Sail] TIMEOUT: exceeded %" PRId64 " cycles\n", g_cycle_limit); + exit(EXIT_SUCCESS); + } + return UNIT; +} + +sail_int get_cycle_count(const unit u) +{ + return (__int128) g_cycle_count; +} + +/* ***** Argument Parsing ***** */ + +static struct option options[] = { + {"binary", required_argument, 0, 'b'}, + {"cyclelimit", required_argument, 0, 'l'}, + {"config", required_argument, 0, 'C'}, + {"elf", required_argument, 0, 'e'}, + {"entry", required_argument, 0, 'n'}, + {"image", required_argument, 0, 'i'}, + {"verbosity", required_argument, 0, 'v'}, + {"help", no_argument, 0, 'h'}, + {0, 0, 0, 0} +}; + +static void print_usage() +{ + struct option *opt = options; + while (opt->name) { + printf("\t -%c\t %s\n", (char)opt->val, opt->name); + opt++; + } + exit(EXIT_SUCCESS); +} + +int process_arguments(int argc, char *argv[]) +{ + int c; + bool elf_entry_set = false; + uint64_t elf_entry; + + while (true) { + int option_index = 0; + c = getopt_long(argc, argv, "e:n:i:b:l:C:h", options, &option_index); + + if (c == -1) break; + + switch (c) { + case 'C': { + char arg[100]; + uint64_t value; + if (sscanf(optarg, "%99[a-zA-Z0-9_-.]=0x%" PRIx64, arg, &value) == 2) { + // fprintf(stderr, "Got hex flag %s %" PRIx64 "\n", arg, value); + // do nothing + } else if (sscanf(optarg, "%99[a-zA-Z0-9_-.]=%" PRId64, arg, &value) == 2) { + // fprintf(stderr, "Got decimal flag %s %" PRIx64 "\n", arg, value); + // do nothing + } else { + fprintf(stderr, "Could not parse argument %s\n", optarg); +#ifdef HAVE_SETCONFIG + z__ListConfig(UNIT); +#endif + return -1; + }; +#ifdef HAVE_SETCONFIG + mpz_t s_value; + mpz_init_set_ui(s_value, value); + z__SetConfig(arg, s_value); + mpz_clear(s_value); +#else + fprintf(stderr, "Ignoring flag -C %s", optarg); +#endif + } + break; + + case 'b': ; + uint64_t addr; + char *file; + + if (!sscanf(optarg, "0x%" PRIx64 ",%ms", &addr, &file)) { + fprintf(stderr, "Could not parse argument %s\n", optarg); + return -1; + }; + + load_raw(addr, file); + free(file); + break; + + case 'i': + load_image(optarg); + break; + + case 'e': + load_elf(optarg, NULL, NULL); + break; + + case 'n': + if (!sscanf(optarg, "0x%" PRIx64, &elf_entry)) { + fprintf(stderr, "Could not parse address %s\n", optarg); + return -1; + } + elf_entry_set = true; + break; + + case 'l': + if (!sscanf(optarg, "%" PRId64, &g_cycle_limit)) { + fprintf(stderr, "Could not parse cycle limit %s\n", optarg); + return -1; + } + break; + + case 'v': + if (!sscanf(optarg, "0x%" PRIx64, &g_verbosity)) { + fprintf(stderr, "Could not parse verbosity flags %s\n", optarg); + return -1; + } + break; + + case 'h': + print_usage(); + break; + + default: + fprintf(stderr, "Unrecognized option %s\n", optarg); + print_usage(); + return -1; + } + } + + // assignment to g_elf_entry is deferred until the end of file so that an + // explicit command line flag will override the address read from the ELF + // file. + if (elf_entry_set) { + g_elf_entry = elf_entry; + } + + return 0; +} + +/* ***** Setup and cleanup functions for RTS ***** */ + +void setup_rts(void) +{ + mpz_init(write_buf); + mpz_init(read_buf); + return; +} + +void cleanup_rts(void) +{ + mpz_clear(write_buf); + mpz_clear(read_buf); + kill_mem(); +} diff --git a/lib/int128/rts.h b/lib/int128/rts.h new file mode 100644 index 00000000..4c5375ae --- /dev/null +++ b/lib/int128/rts.h @@ -0,0 +1,108 @@ +#pragma once + +#include<inttypes.h> +#include<stdlib.h> +#include<stdio.h> + +#include"sail.h" + +/* + * This function should be called whenever a pattern match failure + * occurs. Pattern match failures are always fatal. + */ +void sail_match_failure(sail_string msg); + +/* + * sail_assert implements the assert construct in Sail. If any + * assertion fails we immediately exit the model. + */ +unit sail_assert(bool b, sail_string msg); + +unit sail_exit(unit); + +/* + * sail_get_verbosity reads a 64-bit value that the C runtime allows you to set + * on the command line. + * The intention is that you can use individual bits to turn on/off different + * pieces of debugging output. + */ +fbits sail_get_verbosity(const unit u); + +/* + * Put processor to sleep until an external device calls wakeup_request(). + */ +unit sleep_request(const unit u); + +/* + * Stop processor sleeping. + * (Typically called when a device generates an interrupt.) + */ +unit wakeup_request(const unit u); + +/* + * Test whether processor is sleeping. + * (Typically used to disable execution of instructions.) + */ +bool sleeping(const unit u); + +/* ***** Memory builtins ***** */ + +void write_mem(uint64_t, uint64_t); +uint64_t read_mem(uint64_t); + +// These memory builtins are intended to match the semantics for the +// __ReadRAM and __WriteRAM functions in ASL. + +bool write_ram(const sail_int addr_size, // Either 32 or 64 + const sail_int data_size_mpz, // Number of bytes + const lbits hex_ram, // Currently unused + const lbits addr_bv, + const lbits data); + +void read_ram(lbits *data, + const sail_int addr_size, + const sail_int data_size_mpz, + const lbits hex_ram, + const lbits addr_bv); + +sbits fast_read_ram(const int64_t data_size, + const uint64_t addr_bv); + +unit write_tag_bool(const fbits, const bool); +bool read_tag_bool(const fbits); + +unit load_raw(fbits addr, const sail_string file); + +void load_image(char *); + +/* + * Functions for counting and limiting cycles + */ + +// increment cycle count and test if over limit +bool cycle_limit_reached(const unit); + +// increment cycle count and abort if over +unit cycle_count(const unit); + +// read cycle count +sail_int get_cycle_count(const unit); + +/* + * Functions to get info from ELF files. + */ + +sail_int elf_entry(const unit u); +sail_int elf_tohost(const unit u); + +int process_arguments(int, char**); + +/* + * setup_rts and cleanup_rts are responsible for calling setup_library + * and cleanup_library in sail.h. + */ +void setup_rts(void); +void cleanup_rts(void); + +unit z__SetConfig(sail_string, sail_int); +unit z__ListConfig(const unit u); diff --git a/lib/int128/sail.c b/lib/int128/sail.c new file mode 100644 index 00000000..6113c7f2 --- /dev/null +++ b/lib/int128/sail.c @@ -0,0 +1,1558 @@ +#define _GNU_SOURCE +#include<assert.h> +#include<inttypes.h> +#include<stdbool.h> +#include<stdio.h> +#include<stdlib.h> +#include<string.h> +#include<time.h> + +#include <x86intrin.h> + +#include"sail.h" + +void mpz_set_si128(mpz_t rop, __int128 op) +{ + mpz_set_si(rop, (int64_t) (op >> (__int128) 64)); + mpz_mul_2exp(rop, rop, 64); + mpz_add_ui(rop, rop, (uint64_t) op); +} + +void mpz_init_set_si128(mpz_t rop, __int128 op) +{ + mpz_init(rop); + mpz_set_si128(rop, op); +} + +bool EQUAL(unit)(const unit a, const unit b) +{ + return true; +} + +unit UNDEFINED(unit)(const unit u) +{ + return UNIT; +} + +unit skip(const unit u) +{ + return UNIT; +} + +/* ***** Sail bit type ***** */ + +bool eq_bit(const fbits a, const fbits b) +{ + return a == b; +} + +/* ***** Sail booleans ***** */ + +bool not(const bool b) { + return !b; +} + +bool EQUAL(bool)(const bool a, const bool b) { + return a == b; +} + +bool UNDEFINED(bool)(const unit u) { + return false; +} + +/* ***** Sail strings ***** */ + +void CREATE(sail_string)(sail_string *str) +{ + char *istr = (char *) malloc(1 * sizeof(char)); + istr[0] = '\0'; + *str = istr; +} + +void RECREATE(sail_string)(sail_string *str) +{ + free(*str); + char *istr = (char *) malloc(1 * sizeof(char)); + istr[0] = '\0'; + *str = istr; +} + +void COPY(sail_string)(sail_string *str1, const sail_string str2) +{ + size_t len = strlen(str2); + *str1 = realloc(*str1, len + 1); + *str1 = strcpy(*str1, str2); +} + +void KILL(sail_string)(sail_string *str) +{ + free(*str); +} + +void dec_str(sail_string *str, const sail_int n) +{ + if (INT64_MIN <= n && n <= INT64_MAX) { + int ret = asprintf(str, "%" PRId64, (int64_t) n); + if (ret == -1) { + printf("dec_str failed"); + exit(1); + } + } else { + printf("dec_str"); + exit(1); + } +} + +void hex_str(sail_string *str, const sail_int n) +{ + //free(*str); + //gmp_asprintf(str, "0x%Zx", n); +} + +bool eq_string(const sail_string str1, const sail_string str2) +{ + return strcmp(str1, str2) == 0; +} + +bool EQUAL(sail_string)(const sail_string str1, const sail_string str2) +{ + return strcmp(str1, str2) == 0; +} + +void undefined_string(sail_string *str, const unit u) {} + +void concat_str(sail_string *stro, const sail_string str1, const sail_string str2) +{ + *stro = realloc(*stro, strlen(str1) + strlen(str2) + 1); + (*stro)[0] = '\0'; + strcat(*stro, str1); + strcat(*stro, str2); +} + +bool string_startswith(sail_string s, sail_string prefix) +{ + return strstr(s, prefix) == s; +} + +sail_int string_length(sail_string s) +{ + return (sail_int) strlen(s); +} + +void string_drop(sail_string *dst, sail_string s, sail_int ns) +{ + size_t len = strlen(s); + mach_int n = CREATE_OF(mach_int, sail_int)(ns); + if (len >= n) { + *dst = realloc(*dst, (len - n) + 1); + memcpy(*dst, s + n, len - n); + (*dst)[len - n] = '\0'; + } else { + *dst = realloc(*dst, 1); + **dst = '\0'; + } +} + +void string_take(sail_string *dst, sail_string s, sail_int ns) +{ + size_t len = strlen(s); + mach_int n = CREATE_OF(mach_int, sail_int)(ns); + mach_int to_copy; + if (len <= n) { + to_copy = len; + } else { + to_copy = n; + } + *dst = realloc(*dst, to_copy + 1); + memcpy(*dst, s, to_copy); + *dst[to_copy] = '\0'; +} + +/* ***** Sail integers ***** */ + +uint64_t sail_int_get_ui(const sail_int op) +{ + return (uint64_t) op; +} + +bool EQUAL(mach_int)(const mach_int op1, const mach_int op2) +{ + return op1 == op2; +} + +sail_int CREATE_OF(sail_int, mach_int)(const mach_int op) +{ + return (sail_int) op; +} + +mach_int CREATE_OF(mach_int, sail_int)(const sail_int op) +{ + return (mach_int) op; +} + +mach_int CONVERT_OF(mach_int, sail_int)(const sail_int op) +{ + return (mach_int) op; +} + +sail_int CONVERT_OF(sail_int, mach_int)(const mach_int op) +{ + return (sail_int) op; +} + +sail_int CONVERT_OF(sail_int, sail_string)(const sail_string str) +{ + mpz_t tmp; + mpz_init(tmp); + mpz_set_str(tmp, str, 10); + uint64_t lo = mpz_get_ui(tmp); + mpz_div_2exp(tmp, tmp, 64); + uint64_t hi = mpz_get_ui(tmp); + mpz_clear(tmp); + + unsigned __int128 r = (((unsigned __int128) hi) << 64) + ((unsigned __int128) lo); + return (__int128) r; +} + +bool eq_int(const sail_int op1, const sail_int op2) +{ + return op1 == op2; +} + +bool EQUAL(sail_int)(const sail_int op1, const sail_int op2) +{ + return op1 == op2; +} + +bool lt(const sail_int op1, const sail_int op2) +{ + return op1 < op2; +} + +bool gt(const sail_int op1, const sail_int op2) +{ + return op1 > op2; +} + +bool lteq(const sail_int op1, const sail_int op2) +{ + return op1 <= op2; +} + +bool gteq(const sail_int op1, const sail_int op2) +{ + return op1 >= op2; +} + +sail_int shl_int(const sail_int op1, const sail_int op2) +{ + return op1 << op2; +} + +mach_int shl_mach_int(const mach_int op1, const mach_int op2) +{ + return op1 << op2; +} + +sail_int shr_int(const sail_int op1, const sail_int op2) +{ + return op1 >> op2; +} + +mach_int shr_mach_int(const mach_int op1, const mach_int op2) +{ + return op1 >> op2; +} + +sail_int undefined_int(const int n) +{ + return (__int128) n; +} + +sail_int undefined_range(const sail_int l, const sail_int u) +{ + return l; +} + +sail_int add_int(const sail_int op1, const sail_int op2) +{ + return op1 + op2; +} + +sail_int sub_int(const sail_int op1, const sail_int op2) +{ + return op1 - op2; +} + +sail_int sub_nat(const sail_int op1, const sail_int op2) +{ + sail_int rop = op1 - op2; + if (rop < 0) return (sail_int) 0; + return rop; +} + +sail_int mult_int(const sail_int op1, const sail_int op2) +{ + return op1 * op2; +} + +// FIXME: Make sure all division operators do the right thing +sail_int ediv_int(const sail_int op1, const sail_int op2) +{ + return op1 / op2; +} + +sail_int emod_int(const sail_int op1, const sail_int op2) +{ + return op1 % op2; +} + +sail_int tdiv_int(const sail_int op1, const sail_int op2) +{ + return op1 / op2; +} + +sail_int tmod_int(const sail_int op1, const sail_int op2) +{ + return op1 % op2; +} + +sail_int max_int(const sail_int op1, const sail_int op2) +{ + if (op1 < op2) { + return op2; + } else { + return op1; + } +} + +sail_int min_int(const sail_int op1, const sail_int op2) +{ + if (op1 > op2) { + return op2; + } else { + return op1; + } +} + +sail_int neg_int(const sail_int op) +{ + return -op; +} + +sail_int abs_int(const sail_int op) +{ + if (op < 0) { + return -op; + } else { + return op; + } +} + +sail_int pow_int(sail_int base, sail_int exp) +{ + sail_int result = 1; + while (true) + { + if (exp & 1) { + result *= base; + } + exp >>= 1; + if (!exp) { + break; + } + base *= base; + } + return result; +} + +sail_int pow2(const sail_int exp) +{ + return pow_int(2, exp); +} + +/* ***** Sail bitvectors ***** */ + +bool EQUAL(fbits)(const fbits op1, const fbits op2) +{ + return op1 == op2; +} + +void CREATE(lbits)(lbits *rop) +{ + rop->bits = malloc(sizeof(mpz_t)); + rop->len = 0; + mpz_init(*rop->bits); +} + +void RECREATE(lbits)(lbits *rop) +{ + rop->len = 0; + mpz_set_ui(*rop->bits, 0); +} + +void COPY(lbits)(lbits *rop, const lbits op) +{ + rop->len = op.len; + mpz_set(*rop->bits, *op.bits); +} + +void KILL(lbits)(lbits *rop) +{ + mpz_clear(*rop->bits); + free(rop->bits); +} + +void CREATE_OF(lbits, fbits)(lbits *rop, const uint64_t op, const uint64_t len, const bool direction) +{ + rop->bits = malloc(sizeof(mpz_t)); + rop->len = len; + mpz_init_set_ui(*rop->bits, op); +} + +fbits CREATE_OF(fbits, lbits)(const lbits op, const bool direction) +{ + return mpz_get_ui(*op.bits); +} + +sbits CREATE_OF(sbits, lbits)(const lbits op, const bool direction) +{ + sbits rop; + rop.bits = mpz_get_ui(*op.bits); + rop.len = op.len; + return rop; +} + +sbits CREATE_OF(sbits, fbits)(const fbits op, const uint64_t len, const bool direction) +{ + sbits rop; + rop.bits = op; + rop.len = len; + return rop; +} + +void RECREATE_OF(lbits, fbits)(lbits *rop, const uint64_t op, const uint64_t len, const bool direction) +{ + rop->len = len; + mpz_set_ui(*rop->bits, op); +} + +void CREATE_OF(lbits, sbits)(lbits *rop, const sbits op, const bool direction) +{ + rop->bits = malloc(sizeof(mpz_t)); + rop->len = op.len; + mpz_init_set_ui(*rop->bits, op.bits); +} + +void RECREATE_OF(lbits, sbits)(lbits *rop, const sbits op, const bool direction) +{ + rop->len = op.len; + mpz_set_ui(*rop->bits, op.bits); +} + +// Bitvector conversions + +fbits CONVERT_OF(fbits, lbits)(const lbits op, const bool direction) +{ + return mpz_get_ui(*op.bits); +} + +fbits CONVERT_OF(fbits, sbits)(const sbits op, const bool direction) +{ + return op.bits; +} + +void CONVERT_OF(lbits, fbits)(lbits *rop, const fbits op, const uint64_t len, const bool direction) +{ + rop->len = len; + // use safe_rshift to correctly handle the case when we have a 0-length vector. + mpz_set_ui(*rop->bits, op & safe_rshift(UINT64_MAX, 64 - len)); +} + +void CONVERT_OF(lbits, sbits)(lbits *rop, const sbits op, const bool direction) +{ + rop->len = op.len; + mpz_set_ui(*rop->bits, op.bits & safe_rshift(UINT64_MAX, 64 - op.len)); +} + +inline +sbits CONVERT_OF(sbits, fbits)(const fbits op, const uint64_t len, const bool direction) +{ + sbits rop; + rop.len = len; + rop.bits = op; + return rop; +} + +inline +sbits CONVERT_OF(sbits, lbits)(const lbits op, const bool direction) +{ + sbits rop; + rop.len = op.len; + rop.bits = mpz_get_ui(*op.bits); + return rop; +} + +void UNDEFINED(lbits)(lbits *rop, const sail_int len, const fbits bit) +{ + zeros(rop, len); +} + +fbits UNDEFINED(fbits)(const unit u) { return 0; } + +sbits undefined_sbits(void) +{ + sbits rop; + rop.bits = UINT64_C(0); + rop.len = UINT64_C(0); + return rop; +} + +fbits safe_rshift(const fbits x, const fbits n) +{ + if (n >= 64) { + return 0ul; + } else { + return x >> n; + } +} + +void normalize_lbits(lbits *rop) +{ + mpz_t tmp; + mpz_init(tmp); + + mpz_set_ui(tmp, 1); + mpz_mul_2exp(tmp, tmp, rop->len); + mpz_sub_ui(tmp, tmp, 1); + mpz_and(*rop->bits, *rop->bits, tmp); + + mpz_clear(tmp); +} + +void append_64(lbits *rop, const lbits op, const fbits chunk) +{ + rop->len = rop->len + 64ul; + mpz_mul_2exp(*rop->bits, *op.bits, 64ul); + mpz_add_ui(*rop->bits, *rop->bits, chunk); +} + +void add_bits(lbits *rop, const lbits op1, const lbits op2) +{ + rop->len = op1.len; + mpz_add(*rop->bits, *op1.bits, *op2.bits); + normalize_lbits(rop); +} + +void sub_bits(lbits *rop, const lbits op1, const lbits op2) +{ + assert(op1.len == op2.len); + rop->len = op1.len; + mpz_sub(*rop->bits, *op1.bits, *op2.bits); + normalize_lbits(rop); +} + +void add_bits_int(lbits *rop, const lbits op1, const sail_int op2) +{ + assert(op2 >= 0); + rop->len = op1.len; + mpz_add_ui(*rop->bits, *op1.bits, (uint64_t) op2); + normalize_lbits(rop); +} + +void sub_bits_int(lbits *rop, const lbits op1, const sail_int op2) +{ + assert(op2 >= 0); + rop->len = op1.len; + mpz_sub_ui(*rop->bits, *op1.bits, (uint64_t) op2); + normalize_lbits(rop); +} + +void and_bits(lbits *rop, const lbits op1, const lbits op2) +{ + assert(op1.len == op2.len); + rop->len = op1.len; + mpz_and(*rop->bits, *op1.bits, *op2.bits); +} + +void or_bits(lbits *rop, const lbits op1, const lbits op2) +{ + assert(op1.len == op2.len); + rop->len = op1.len; + mpz_ior(*rop->bits, *op1.bits, *op2.bits); +} + +void xor_bits(lbits *rop, const lbits op1, const lbits op2) +{ + assert(op1.len == op2.len); + rop->len = op1.len; + mpz_xor(*rop->bits, *op1.bits, *op2.bits); +} + +void not_bits(lbits *rop, const lbits op) +{ + rop->len = op.len; + mpz_set(*rop->bits, *op.bits); + for (mp_bitcnt_t i = 0; i < op.len; i++) { + mpz_combit(*rop->bits, i); + } +} + +void mults_vec(lbits *rop, const lbits op1, const lbits op2) +{ + return; +} + +void mult_vec(lbits *rop, const lbits op1, const lbits op2) +{ + rop->len = op1.len * 2; + mpz_mul(*rop->bits, *op1.bits, *op2.bits); + normalize_lbits(rop); /* necessary? */ +} + + +void zeros(lbits *rop, const sail_int op) +{ + rop->len = (mp_bitcnt_t) op; + mpz_set_ui(*rop->bits, 0); +} + +void zero_extend(lbits *rop, const lbits op, const sail_int len) +{ + assert(op.len <= (uint64_t) len); + rop->len = (uint64_t) len; + mpz_set(*rop->bits, *op.bits); +} + +fbits fast_zero_extend(const sbits op, const uint64_t n) +{ + return op.bits; +} + +void sign_extend(lbits *rop, const lbits op, const sail_int len) +{ + assert(op.len <= (uint64_t) len); + rop->len = (uint64_t) len; + if(mpz_tstbit(*op.bits, op.len - 1)) { + mpz_set(*rop->bits, *op.bits); + for(mp_bitcnt_t i = rop->len - 1; i >= op.len; i--) { + mpz_setbit(*rop->bits, i); + } + } else { + mpz_set(*rop->bits, *op.bits); + } +} + +fbits fast_sign_extend(const fbits op, const uint64_t n, const uint64_t m) +{ + uint64_t rop = op; + if (op & (UINT64_C(1) << (n - 1))) { + for (uint64_t i = m - 1; i >= n; i--) { + rop = rop | (UINT64_C(1) << i); + } + return rop; + } else { + return rop; + } +} + +fbits fast_sign_extend2(const sbits op, const uint64_t m) +{ + uint64_t rop = op.bits; + if (op.bits & (UINT64_C(1) << (op.len - 1))) { + for (uint64_t i = m - 1; i >= op.len; i--) { + rop = rop | (UINT64_C(1) << i); + } + return rop; + } else { + return rop; + } +} + +sail_int length_lbits(const lbits op) +{ + return (sail_int) op.len; +} + +bool eq_bits(const lbits op1, const lbits op2) +{ + assert(op1.len == op2.len); + for (mp_bitcnt_t i = 0; i < op1.len; i++) { + if (mpz_tstbit(*op1.bits, i) != mpz_tstbit(*op2.bits, i)) return false; + } + return true; +} + +bool EQUAL(lbits)(const lbits op1, const lbits op2) +{ + return eq_bits(op1, op2); +} + +bool neq_bits(const lbits op1, const lbits op2) +{ + assert(op1.len == op2.len); + for (mp_bitcnt_t i = 0; i < op1.len; i++) { + if (mpz_tstbit(*op1.bits, i) != mpz_tstbit(*op2.bits, i)) return true; + } + return false; +} + +void vector_subrange_lbits(lbits *rop, + const lbits op, + const sail_int n_mpz, + const sail_int m_mpz) +{ + uint64_t n = (uint64_t) n_mpz; + uint64_t m = (uint64_t) m_mpz; + + rop->len = n - (m - 1ul); + mpz_fdiv_q_2exp(*rop->bits, *op.bits, m); + normalize_lbits(rop); +} + +void sail_truncate(lbits *rop, const lbits op, const sail_int len) +{ + rop->len = (mp_bitcnt_t) len; + mpz_set(*rop->bits, *op.bits); + normalize_lbits(rop); +} + +void sail_truncateLSB(lbits *rop, const lbits op, const sail_int len) +{ + uint64_t rlen = (uint64_t) len; + assert(op.len >= rlen); + rop->len = rlen; + // similar to vector_subrange_lbits above -- right shift LSBs away + mpz_fdiv_q_2exp(*rop->bits, *op.bits, op.len - rlen); + normalize_lbits(rop); +} + +fbits bitvector_access(const lbits op, const sail_int n) +{ + return (fbits) mpz_tstbit(*op.bits, (uint64_t) n); +} + +sail_int sail_unsigned(const lbits op) +{ + return (sail_int) mpz_get_ui(*op.bits); +} + +sail_int sail_signed(const lbits op) +{ + if (op.len <= 64) { + uint64_t b = mpz_get_ui(*op.bits); + uint64_t sign_bit = UINT64_C(1) << (op.len - UINT64_C(1)); + if ((b & sign_bit) > 0) { + return ((sail_int) (b & ~sign_bit)) - ((sail_int) sign_bit); + } else { + return (sail_int) b; + } + } else if (op.len <= 128) { + uint64_t b_lo = mpz_get_ui(*op.bits); + mpz_t tmp; + mpz_init(tmp); + mpz_tdiv_q_2exp(tmp, *op.bits, 64); + uint64_t b_hi = mpz_get_ui(tmp); + mpz_clear(tmp); + uint64_t sign_bit = UINT64_C(1) << (op.len - UINT64_C(65)); + if (b_hi & sign_bit) { + unsigned __int128 b = b_hi & ~sign_bit; + b <<= 64; + b |= (unsigned __int128) b_lo; + unsigned __int128 sb = (unsigned __int128) sign_bit << (unsigned __int128) 64; + return (sail_int) b + (sail_int) (~sb + 1); + } else { + unsigned __int128 b = b_hi; + b <<= 64; + b |= (unsigned __int128) b_lo; + return (__int128) b; + } + } else { + printf("sail_signed >128\n"); + exit(1); + } +} + +mach_int fast_unsigned(const fbits op) +{ + return (mach_int) op; +} + +mach_int fast_signed(const fbits op, const uint64_t n) +{ + if (op & (UINT64_C(1) << (n - 1))) { + uint64_t rop = op & ~(UINT64_C(1) << (n - 1)); + return (mach_int) (rop - (UINT64_C(1) << (n - 1))); + } else { + return (mach_int) op; + } +} + +void append(lbits *rop, const lbits op1, const lbits op2) +{ + rop->len = op1.len + op2.len; + mpz_mul_2exp(*rop->bits, *op1.bits, op2.len); + mpz_ior(*rop->bits, *rop->bits, *op2.bits); +} + +sbits append_sf(const sbits op1, const fbits op2, const uint64_t len) +{ + sbits rop; + rop.bits = (op1.bits << len) | op2; + rop.len = op1.len + len; + return rop; +} + +sbits append_fs(const fbits op1, const uint64_t len, const sbits op2) +{ + sbits rop; + rop.bits = (op1 << op2.len) | op2.bits; + rop.len = len + op2.len; + return rop; +} + +sbits append_ss(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = (op1.bits << op2.len) | op2.bits; + rop.len = op1.len + op2.len; + return rop; +} + +void replicate_bits(lbits *rop, const lbits op1, const sail_int op2) +{ + uint64_t op2_ui = (uint64_t) op2; + rop->len = op1.len * op2_ui; + mpz_set_ui(*rop->bits, 0); + for (int i = 0; i < op2_ui; i++) { + mpz_mul_2exp(*rop->bits, *rop->bits, op1.len); + mpz_ior(*rop->bits, *rop->bits, *op1.bits); + } +} + +uint64_t fast_replicate_bits(const uint64_t shift, const uint64_t v, const int64_t times) +{ + uint64_t r = v; + for (int i = 1; i < times; ++i) { + r |= (r << shift); + } + return r; +} + +// Takes a slice of the (two's complement) binary representation of +// integer n, starting at bit start, and of length len. With the +// argument in the following order: +// +// get_slice_int(len, n, start) +// +// For example: +// +// get_slice_int(8, 1680, 4) = +// +// 11 0 +// V V +// get_slice_int(8, 0b0110_1001_0000, 4) = 0b0110_1001 +// <-------^ +// (8 bit) 4 +// +__attribute__((target ("bmi2"))) +void get_slice_int(lbits *rop, const sail_int len, const sail_int n, const sail_int start) +{ + assert(len <= 128); + + unsigned __int128 nbits = (unsigned __int128) (n >> start); + + if (len <= 64) { + mpz_set_ui(*rop->bits, _bzhi_u64((uint64_t) nbits, (uint64_t) len)); + rop->len = (uint64_t) len; + } else { + print("get_slice_int"); + exit(1); + } +} + +sail_int set_slice_int(const sail_int len, const sail_int n, const sail_int start, const lbits slice) +{ + printf("set_slice_int"); + exit(1); + return 0; +} + +void update_lbits(lbits *rop, const lbits op, const sail_int n_mpz, const uint64_t bit) +{ + uint64_t n = (uint64_t) n_mpz; + + mpz_set(*rop->bits, *op.bits); + rop->len = op.len; + + if (bit == UINT64_C(0)) { + mpz_clrbit(*rop->bits, n); + } else { + mpz_setbit(*rop->bits, n); + } +} + +void vector_update_subrange_lbits(lbits *rop, + const lbits op, + const sail_int n_mpz, + const sail_int m_mpz, + const lbits slice) +{ + uint64_t n = (uint64_t) n_mpz; + uint64_t m = (uint64_t) m_mpz; + + mpz_set(*rop->bits, *op.bits); + rop->len = op.len; + + for (uint64_t i = 0; i < n - (m - 1ul); i++) { + if (mpz_tstbit(*slice.bits, i)) { + mpz_setbit(*rop->bits, i + m); + } else { + mpz_clrbit(*rop->bits, i + m); + } + } +} + +fbits fast_update_subrange(const fbits op, + const mach_int n, + const mach_int m, + const fbits slice) +{ + fbits rop = op; + for (mach_int i = 0; i < n - (m - UINT64_C(1)); i++) { + uint64_t bit = UINT64_C(1) << ((uint64_t) i); + if (slice & bit) { + rop |= (bit << m); + } else { + rop &= ~(bit << m); + } + } + return rop; +} + +__attribute__((target ("bmi2"))) +void slice(lbits *rop, const lbits op, const sail_int start_big, const sail_int len_big) +{ + uint64_t start = (uint64_t) start_big; + uint64_t len = (uint64_t) len_big; + + if (len + start <= 64) { + mpz_set_ui(*rop->bits, _bzhi_u64(mpz_get_ui(*op.bits) >> start, len)); + rop->len = len; + } else { + mpz_set_ui(*rop->bits, 0); + rop->len = len; + + for (uint64_t i = 0; i < len; i++) { + if (mpz_tstbit(*op.bits, i + start)) mpz_setbit(*rop->bits, i); + } + } +} + +__attribute__((target ("bmi2"))) +sbits sslice(const fbits op, const mach_int start, const mach_int len) +{ + sbits rop; + rop.bits = _bzhi_u64(op >> start, (uint64_t) len); + rop.len = (uint64_t) len; + return rop; +} + +void set_slice(lbits *rop, + const sail_int len_mpz, + const sail_int slen_mpz, + const lbits op, + const sail_int start_mpz, + const lbits slice) +{ + uint64_t start = (uint64_t) start_mpz; + + mpz_set(*rop->bits, *op.bits); + rop->len = op.len; + + for (uint64_t i = 0; i < slice.len; i++) { + if (mpz_tstbit(*slice.bits, i)) { + mpz_setbit(*rop->bits, i + start); + } else { + mpz_clrbit(*rop->bits, i + start); + } + } +} + +void shift_bits_left(lbits *rop, const lbits op1, const lbits op2) +{ + rop->len = op1.len; + mpz_mul_2exp(*rop->bits, *op1.bits, mpz_get_ui(*op2.bits)); + normalize_lbits(rop); +} + +void shift_bits_right(lbits *rop, const lbits op1, const lbits op2) +{ + rop->len = op1.len; + mpz_tdiv_q_2exp(*rop->bits, *op1.bits, mpz_get_ui(*op2.bits)); +} + +/* FIXME */ +void shift_bits_right_arith(lbits *rop, const lbits op1, const lbits op2) +{ + rop->len = op1.len; + mp_bitcnt_t shift_amt = mpz_get_ui(*op2.bits); + mp_bitcnt_t sign_bit = op1.len - 1; + mpz_fdiv_q_2exp(*rop->bits, *op1.bits, shift_amt); + if(mpz_tstbit(*op1.bits, sign_bit) != 0) { + /* */ + for(; shift_amt > 0; shift_amt--) { + mpz_setbit(*rop->bits, sign_bit - shift_amt + 1); + } + } +} + +void shiftl(lbits *rop, const lbits op1, const sail_int op2) +{ + rop->len = op1.len; + mpz_mul_2exp(*rop->bits, *op1.bits, (uint64_t) op2); + normalize_lbits(rop); +} + +void shiftr(lbits *rop, const lbits op1, const sail_int op2) +{ + rop->len = op1.len; + mpz_tdiv_q_2exp(*rop->bits, *op1.bits, (uint64_t) op2); +} + +void reverse_endianness(lbits *rop, const lbits op) +{ + rop->len = op.len; + if (rop->len == 64ul) { + uint64_t x = mpz_get_ui(*op.bits); + x = (x & 0xFFFFFFFF00000000) >> 32 | (x & 0x00000000FFFFFFFF) << 32; + x = (x & 0xFFFF0000FFFF0000) >> 16 | (x & 0x0000FFFF0000FFFF) << 16; + x = (x & 0xFF00FF00FF00FF00) >> 8 | (x & 0x00FF00FF00FF00FF) << 8; + mpz_set_ui(*rop->bits, x); + } else if (rop->len == 32ul) { + uint64_t x = mpz_get_ui(*op.bits); + x = (x & 0xFFFF0000FFFF0000) >> 16 | (x & 0x0000FFFF0000FFFF) << 16; + x = (x & 0xFF00FF00FF00FF00) >> 8 | (x & 0x00FF00FF00FF00FF) << 8; + mpz_set_ui(*rop->bits, x); + } else if (rop->len == 16ul) { + uint64_t x = mpz_get_ui(*op.bits); + x = (x & 0xFF00FF00FF00FF00) >> 8 | (x & 0x00FF00FF00FF00FF) << 8; + mpz_set_ui(*rop->bits, x); + } else if (rop->len == 8ul) { + mpz_set(*rop->bits, *op.bits); + } else { + mpz_t tmp1; + mpz_t tmp2; + mpz_init(tmp1); + mpz_init(tmp2); + + /* For other numbers of bytes we reverse the bytes. + * XXX could use mpz_import/export for this. */ + mpz_set_ui(tmp1, 0xff); // byte mask + mpz_set_ui(*rop->bits, 0); // reset accumulator for result + for(mp_bitcnt_t byte = 0; byte < op.len; byte+=8) { + mpz_tdiv_q_2exp(tmp2, *op.bits, byte); // shift byte to bottom + mpz_and(tmp2, tmp2, tmp1); // and with mask + mpz_mul_2exp(*rop->bits, *rop->bits, 8); // shift result left 8 + mpz_ior(*rop->bits, *rop->bits, tmp2); // or byte into result + } + } +} + +bool eq_sbits(const sbits op1, const sbits op2) +{ + return op1.bits == op2.bits; +} + +bool neq_sbits(const sbits op1, const sbits op2) +{ + return op1.bits != op2.bits; +} + +__attribute__((target ("bmi2"))) +sbits not_sbits(const sbits op) +{ + sbits rop; + rop.bits = (~op.bits) & _bzhi_u64(UINT64_MAX, op.len); + rop.len = op.len; + return rop; +} + +sbits xor_sbits(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = op1.bits ^ op2.bits; + rop.len = op1.len; + return rop; +} + +sbits or_sbits(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = op1.bits | op2.bits; + rop.len = op1.len; + return rop; +} + +sbits and_sbits(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = op1.bits & op2.bits; + rop.len = op1.len; + return rop; +} + +__attribute__((target ("bmi2"))) +sbits add_sbits(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = (op1.bits + op2.bits) & _bzhi_u64(UINT64_MAX, op1.len); + rop.len = op1.len; + return rop; +} + +__attribute__((target ("bmi2"))) +sbits sub_sbits(const sbits op1, const sbits op2) +{ + sbits rop; + rop.bits = (op1.bits - op2.bits) & _bzhi_u64(UINT64_MAX, op1.len); + rop.len = op1.len; + return rop; +} + +/* ***** Sail Reals ***** */ + +void CREATE(real)(real *rop) +{ + mpq_init(*rop); +} + +void RECREATE(real)(real *rop) +{ + mpq_set_ui(*rop, 0, 1); +} + +void KILL(real)(real *rop) +{ + mpq_clear(*rop); +} + +void COPY(real)(real *rop, const real op) +{ + mpq_set(*rop, op); +} + +void UNDEFINED(real)(real *rop, unit u) +{ + mpq_set_ui(*rop, 0, 1); +} + +void neg_real(real *rop, const real op) +{ + mpq_neg(*rop, op); +} + +void mult_real(real *rop, const real op1, const real op2) { + mpq_mul(*rop, op1, op2); +} + +void sub_real(real *rop, const real op1, const real op2) +{ + mpq_sub(*rop, op1, op2); +} + +void add_real(real *rop, const real op1, const real op2) +{ + mpq_add(*rop, op1, op2); +} + +void div_real(real *rop, const real op1, const real op2) +{ + mpq_div(*rop, op1, op2); +} + +#define SQRT_PRECISION 30 + +/* + * sqrt_real first checks if op has the form n/1 - in which case, if n + * is a perfect square (i.e. it's square root is an integer), then it + * will return the exact square root using mpz_sqrt. If that's not the + * case we use the Babylonian method to calculate the square root to + * SQRT_PRECISION decimal places. + */ +void sqrt_real(real *rop, const real op) +{ + /* First check if op is a perfect square and use mpz_sqrt if so */ + if (mpz_cmp_ui(mpq_denref(op), 1) == 0 && mpz_perfect_square_p(mpq_numref(op))) { + mpz_sqrt(mpq_numref(*rop), mpq_numref(op)); + mpz_set_ui(mpq_denref(*rop), 1); + return; + } + + mpq_t tmp; + mpz_t tmp_z; + mpq_t p; /* previous estimate, p */ + mpq_t n; /* next estimate, n */ + /* convergence is the precision (in decimal places) we want to reach as a fraction 1/(10^precision) */ + mpq_t convergence; + + mpq_init(tmp); + mpz_init(tmp_z); + mpq_init(p); + mpq_init(n); + mpq_init(convergence); + + /* calculate an initial guess using mpz_sqrt */ + mpz_cdiv_q(tmp_z, mpq_numref(op), mpq_denref(op)); + mpz_sqrt(tmp_z, tmp_z); + mpq_set_z(p, tmp_z); + + /* initialise convergence based on SQRT_PRECISION */ + mpz_set_ui(tmp_z, 10); + mpz_pow_ui(tmp_z, tmp_z, SQRT_PRECISION); + mpz_set_ui(mpq_numref(convergence), 1); + mpq_set_den(convergence, tmp_z); + + while (true) { + // n = (p + op / p) / 2 + mpq_div(tmp, op, p); + mpq_add(tmp, tmp, p); + mpq_div_2exp(n, tmp, 1); + + /* calculate the difference between n and p */ + mpq_sub(tmp, p, n); + mpq_abs(tmp, tmp); + + /* if the difference is small enough, return */ + if (mpq_cmp(tmp, convergence) < 0) { + mpq_set(*rop, n); + break; + } + + mpq_swap(n, p); + } + + mpq_clear(tmp); + mpz_clear(tmp_z); + mpq_clear(p); + mpq_clear(n); + mpq_clear(convergence); +} + +void abs_real(real *rop, const real op) +{ + mpq_abs(*rop, op); +} + +sail_int round_up(const real op) +{ + mpz_t rop; + mpz_init(rop); + mpz_cdiv_q(rop, mpq_numref(op), mpq_denref(op)); + sail_int r = mpz_get_si(rop); + mpz_clear(rop); + return r; +} + +sail_int round_down(const real op) +{ + mpz_t rop; + mpz_init(rop); + mpz_fdiv_q(rop, mpq_numref(op), mpq_denref(op)); + sail_int r = mpz_get_si(rop); + mpz_clear(rop); + return r; +} + +void to_real(real *rop, const sail_int op) +{ + mpz_t op_mpz; + mpz_init_set_si128(op_mpz, op); + + mpq_set_z(*rop, op_mpz); + mpq_canonicalize(*rop); + + mpz_clear(op_mpz); +} + +bool EQUAL(real)(const real op1, const real op2) +{ + return mpq_cmp(op1, op2) == 0; +} + +bool lt_real(const real op1, const real op2) +{ + return mpq_cmp(op1, op2) < 0; +} + +bool gt_real(const real op1, const real op2) +{ + return mpq_cmp(op1, op2) > 0; +} + +bool lteq_real(const real op1, const real op2) +{ + return mpq_cmp(op1, op2) <= 0; +} + +bool gteq_real(const real op1, const real op2) +{ + return mpq_cmp(op1, op2) >= 0; +} + +void real_power(real *rop, const real base, const sail_int exp) +{ + int64_t exp_si = (int64_t) exp; + + mpz_set_ui(mpq_numref(*rop), 1); + mpz_set_ui(mpq_denref(*rop), 1); + + real b; + mpq_init(b); + mpq_set(b, base); + int64_t pexp = llabs(exp_si); + while (pexp != 0) { + // invariant: rop * b^pexp == base^abs(exp) + if (pexp & 1) { // b^(e+1) = b * b^e + mpq_mul(*rop, *rop, b); + pexp -= 1; + } else { // b^(2e) = (b*b)^e + mpq_mul(b, b, b); + pexp >>= 1; + } + } + if (exp_si < 0) { + mpq_inv(*rop, *rop); + } + mpq_clear(b); +} + +void CREATE_OF(real, sail_string)(real *rop, const sail_string op) +{ + mpq_init(*rop); + CONVERT_OF(real, sail_string)(rop, op); +} + +void CONVERT_OF(real, sail_string)(real *rop, const sail_string op) +{ + int decimal; + int total; + mpz_t tmp1; + mpz_t tmp2; + mpz_t tmp3; + mpq_t tmp_real; + mpz_init(tmp1); + mpz_init(tmp2); + mpz_init(tmp3); + mpq_init(tmp_real); + + gmp_sscanf(op, "%Zd.%n%Zd%n", tmp1, &decimal, tmp2, &total); + + int len = total - decimal; + mpz_ui_pow_ui(tmp3, 10, len); + mpz_set(mpq_numref(*rop), tmp2); + mpz_set(mpq_denref(*rop), tmp3); + mpq_canonicalize(*rop); + mpz_set(mpq_numref(tmp_real), tmp1); + mpz_set_ui(mpq_denref(tmp_real), 1); + mpq_add(*rop, *rop, tmp_real); + + mpz_clear(tmp1); + mpz_clear(tmp2); + mpz_clear(tmp3); + mpq_clear(tmp_real); +} + +unit print_real(const sail_string str, const real op) +{ + gmp_printf("%s%Qd\n", str, op); + return UNIT; +} + +unit prerr_real(const sail_string str, const real op) +{ + gmp_fprintf(stderr, "%s%Qd\n", str, op); + return UNIT; +} + +void random_real(real *rop, const unit u) +{ + if (rand() & 1) { + mpz_set_si(mpq_numref(*rop), rand()); + } else { + mpz_set_si(mpq_numref(*rop), -rand()); + } + mpz_set_si(mpq_denref(*rop), rand()); + mpq_canonicalize(*rop); +} + +/* ***** Printing functions ***** */ + +void string_of_int(sail_string *str, const sail_int i) +{ + free(*str); + //gmp_asprintf(str, "%Zd", i); +} + +/* asprintf is a GNU extension, but it should exist on BSD */ +void string_of_fbits(sail_string *str, const fbits op) +{ + free(*str); + int bytes = asprintf(str, "0x%" PRIx64, op); + if (bytes == -1) { + fprintf(stderr, "Could not print bits 0x%" PRIx64 "\n", op); + } +} + +void string_of_lbits(sail_string *str, const lbits op) +{ + free(*str); + if ((op.len % 4) == 0) { + gmp_asprintf(str, "0x%*0ZX", op.len / 4, *op.bits); + } else { + *str = (char *) malloc((op.len + 3) * sizeof(char)); + (*str)[0] = '0'; + (*str)[1] = 'b'; + for (int i = 1; i <= op.len; ++i) { + (*str)[i + 1] = mpz_tstbit(*op.bits, op.len - i) + 0x30; + } + (*str)[op.len + 2] = '\0'; + } +} + +void decimal_string_of_fbits(sail_string *str, const fbits op) +{ + free(*str); + int bytes = asprintf(str, "%" PRId64, op); + if (bytes == -1) { + fprintf(stderr, "Could not print bits %" PRId64 "\n", op); + } +} + +void decimal_string_of_lbits(sail_string *str, const lbits op) +{ + free(*str); + gmp_asprintf(str, "%Z", *op.bits); +} + +void fprint_bits(const sail_string pre, + const lbits op, + const sail_string post, + FILE *stream) +{ + fputs(pre, stream); + + if (op.len % 4 == 0) { + fputs("0x", stream); + mpz_t buf; + mpz_init_set(buf, *op.bits); + + char *hex = malloc((op.len / 4) * sizeof(char)); + + for (int i = 0; i < op.len / 4; ++i) { + char c = (char) ((0xF & mpz_get_ui(buf)) + 0x30); + hex[i] = (c < 0x3A) ? c : c + 0x7; + mpz_fdiv_q_2exp(buf, buf, 4); + } + + for (int i = op.len / 4; i > 0; --i) { + fputc(hex[i - 1], stream); + } + + free(hex); + mpz_clear(buf); + } else { + fputs("0b", stream); + for (int i = op.len; i > 0; --i) { + fputc(mpz_tstbit(*op.bits, i - 1) + 0x30, stream); + } + } + + fputs(post, stream); +} + +unit print_bits(const sail_string str, const lbits op) +{ + fprint_bits(str, op, "\n", stdout); + return UNIT; +} + +unit prerr_bits(const sail_string str, const lbits op) +{ + fprint_bits(str, op, "\n", stderr); + return UNIT; +} + +unit print(const sail_string str) +{ + printf("%s", str); + return UNIT; +} + +unit print_endline(const sail_string str) +{ + printf("%s\n", str); + return UNIT; +} + +unit prerr(const sail_string str) +{ + fprintf(stderr, "%s", str); + return UNIT; +} + +unit prerr_endline(const sail_string str) +{ + fprintf(stderr, "%s\n", str); + return UNIT; +} + +unit print_int(const sail_string str, const sail_int op) +{ + mpz_t op_mpz; + mpz_init_set_si128(op_mpz, op); + + fputs(str, stdout); + mpz_out_str(stdout, 10, op_mpz); + putchar('\n'); + + mpz_clear(op_mpz); + return UNIT; +} + +unit prerr_int(const sail_string str, const sail_int op) +{ + fputs(str, stderr); + //mpz_out_str(stderr, 10, op); + fputs("\n", stderr); + return UNIT; +} + +unit sail_putchar(const sail_int op) +{ + char c = (char) op; + putchar(c); + fflush(stdout); + return UNIT; +} + +sail_int get_time_ns(const unit u) +{ + struct timespec t; + clock_gettime(CLOCK_REALTIME, &t); + __int128 rop = (__int128) t.tv_sec; + rop *= 1000000000; + rop += (__int128) t.tv_nsec; + return rop; +} + +// Monomorphisation +sail_int make_the_value(const sail_int op) +{ + return op; +} + +sail_int size_itself_int(const sail_int op) +{ + return op; +} diff --git a/lib/int128/sail.h b/lib/int128/sail.h new file mode 100644 index 00000000..25db5926 --- /dev/null +++ b/lib/int128/sail.h @@ -0,0 +1,404 @@ +#pragma once + +#include<inttypes.h> +#include<stdlib.h> +#include<stdio.h> +#include<stdbool.h> +#include<gmp.h> + +#include<time.h> + +/* + * Called by the RTS to initialise and clear any library state. + */ +void setup_library(void); +void cleanup_library(void); + +/* + * The Sail compiler expects functions to follow a specific naming + * convention for allocation, deallocation, and (deep)-copying. These + * macros implement this naming convention. + */ +#define CREATE(type) create_ ## type +#define RECREATE(type) recreate_ ## type +#define CREATE_OF(type1, type2) create_ ## type1 ## _of_ ## type2 +#define RECREATE_OF(type1, type2) recreate_ ## type1 ## _of_ ## type2 +#define CONVERT_OF(type1, type2) convert_ ## type1 ## _of_ ## type2 +#define COPY(type) copy_ ## type +#define KILL(type) kill_ ## type +#define UNDEFINED(type) undefined_ ## type +#define EQUAL(type) eq_ ## type + +#define SAIL_BUILTIN_TYPE(type)\ + void create_ ## type(type *);\ + void recreate_ ## type(type *);\ + void copy_ ## type(type *, const type);\ + void kill_ ## type(type *); + +/* ***** Sail unit type ***** */ + +typedef int unit; + +#define UNIT 0 + +unit UNDEFINED(unit)(const unit); +bool EQUAL(unit)(const unit, const unit); + +unit skip(const unit); + +/* ***** Sail booleans ***** */ + +/* + * and_bool and or_bool are special-cased by the compiler to ensure + * short-circuiting evaluation. + */ +bool not(const bool); +bool EQUAL(bool)(const bool, const bool); +bool UNDEFINED(bool)(const unit); + +/* ***** Sail strings ***** */ + +/* + * Sail strings are just C strings. + */ +typedef char *sail_string; + +SAIL_BUILTIN_TYPE(sail_string); + +void undefined_string(sail_string *str, const unit u); + +bool eq_string(const sail_string, const sail_string); +bool EQUAL(sail_string)(const sail_string, const sail_string); + +void concat_str(sail_string *stro, const sail_string str1, const sail_string str2); +bool string_startswith(sail_string s, sail_string prefix); + +/* ***** Sail integers ***** */ + +typedef int64_t mach_int; + +bool EQUAL(mach_int)(const mach_int, const mach_int); + +typedef __int128 sail_int; + +uint64_t sail_int_get_ui(const sail_int); + +void dec_str(sail_string *str, const sail_int n); +void hex_str(sail_string *str, const sail_int n); + +#define SAIL_INT_FUNCTION(fname, rtype, ...) rtype fname(__VA_ARGS__) + +// SAIL_BUILTIN_TYPE(sail_int); + +sail_int CREATE_OF(sail_int, mach_int)(const mach_int); +mach_int CREATE_OF(mach_int, sail_int)(const sail_int); + +mach_int CONVERT_OF(mach_int, sail_int)(const sail_int); +sail_int CONVERT_OF(sail_int, mach_int)(const mach_int); +sail_int CONVERT_OF(sail_int, sail_string)(const sail_string); + +/* + * Comparison operators for integers + */ +bool eq_int(const sail_int, const sail_int); +bool EQUAL(sail_int)(const sail_int, const sail_int); + +bool lt(const sail_int, const sail_int); +bool gt(const sail_int, const sail_int); +bool lteq(const sail_int, const sail_int); +bool gteq(const sail_int, const sail_int); + +/* + * Left and right shift for integers + */ +mach_int shl_mach_int(const mach_int, const mach_int); +mach_int shr_mach_int(const mach_int, const mach_int); +SAIL_INT_FUNCTION(shl_int, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(shr_int, sail_int, const sail_int, const sail_int); + +/* + * undefined_int and undefined_range can't use the UNDEFINED(TYPE) + * macro, because they're slightly magical. They take extra parameters + * to ensure that no undefined int can violate any type-guaranteed + * constraints. + */ +SAIL_INT_FUNCTION(undefined_int, sail_int, const int); +SAIL_INT_FUNCTION(undefined_range, sail_int, const sail_int, const sail_int); + +/* + * Arithmetic operations in integers. We include functions for both + * truncating towards zero, and rounding towards -infinity (floor) as + * fdiv/fmod and tdiv/tmod respectively. + */ +SAIL_INT_FUNCTION(add_int, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(sub_int, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(sub_nat, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(mult_int, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(ediv_int, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(emod_int, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(tdiv_int, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(tmod_int, sail_int, const sail_int, const sail_int); +//SAIL_INT_FUNCTION(fdiv_int, sail_int, const sail_int, const sail_int); +//SAIL_INT_FUNCTION(fmod_int, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(max_int, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(min_int, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(neg_int, sail_int, const sail_int); +SAIL_INT_FUNCTION(abs_int, sail_int, const sail_int); +SAIL_INT_FUNCTION(pow_int, sail_int, const sail_int, const sail_int); +SAIL_INT_FUNCTION(pow2, sail_int, const sail_int); + +SAIL_INT_FUNCTION(make_the_value, sail_int, const sail_int); +SAIL_INT_FUNCTION(size_itself_int, sail_int, const sail_int); + +/* ***** Sail bitvectors ***** */ + +typedef uint64_t fbits; + +bool eq_bit(const fbits a, const fbits b); + +bool EQUAL(fbits)(const fbits, const fbits); + +typedef struct { + uint64_t len; + uint64_t bits; +} sbits; + +typedef struct { + mp_bitcnt_t len; + mpz_t *bits; +} lbits; + +// For backwards compatability +typedef uint64_t mach_bits; +typedef lbits sail_bits; + +SAIL_BUILTIN_TYPE(lbits); + +void CREATE_OF(lbits, fbits)(lbits *, + const fbits op, + const uint64_t len, + const bool direction); + +void RECREATE_OF(lbits, fbits)(lbits *, + const fbits op, + const uint64_t len, + const bool direction); + +void CREATE_OF(lbits, sbits)(lbits *, + const sbits op, + const bool direction); + +void RECREATE_OF(lbits, sbits)(lbits *, + const sbits op, + const bool direction); + +sbits CREATE_OF(sbits, lbits)(const lbits op, const bool direction); +fbits CREATE_OF(fbits, lbits)(const lbits op, const bool direction); +sbits CREATE_OF(sbits, fbits)(const fbits op, const uint64_t len, const bool direction); + +/* Bitvector conversions */ + +fbits CONVERT_OF(fbits, lbits)(const lbits, const bool); +fbits CONVERT_OF(fbits, sbits)(const sbits, const bool); + +void CONVERT_OF(lbits, fbits)(lbits *, const fbits, const uint64_t, const bool); +void CONVERT_OF(lbits, sbits)(lbits *, const sbits, const bool); + +sbits CONVERT_OF(sbits, fbits)(const fbits, const uint64_t, const bool); +sbits CONVERT_OF(sbits, lbits)(const lbits, const bool); + +void UNDEFINED(lbits)(lbits *, const sail_int len, const fbits bit); +fbits UNDEFINED(fbits)(const unit); + +sbits undefined_sbits(void); + +/* + * Wrapper around >> operator to avoid UB when shift amount is greater + * than or equal to 64. + */ +fbits safe_rshift(const fbits, const fbits); + +/* + * Used internally to construct large bitvector literals. + */ +void append_64(lbits *rop, const lbits op, const fbits chunk); + +void add_bits(lbits *rop, const lbits op1, const lbits op2); +void sub_bits(lbits *rop, const lbits op1, const lbits op2); + +void add_bits_int(lbits *rop, const lbits op1, const sail_int op2); +void sub_bits_int(lbits *rop, const lbits op1, const sail_int op2); + +void and_bits(lbits *rop, const lbits op1, const lbits op2); +void or_bits(lbits *rop, const lbits op1, const lbits op2); +void xor_bits(lbits *rop, const lbits op1, const lbits op2); +void not_bits(lbits *rop, const lbits op); + +void mults_vec(lbits *rop, const lbits op1, const lbits op2); +void mult_vec(lbits *rop, const lbits op1, const lbits op2); + +void zeros(lbits *rop, const sail_int op); + +void zero_extend(lbits *rop, const lbits op, const sail_int len); +fbits fast_zero_extend(const sbits op, const uint64_t n); +void sign_extend(lbits *rop, const lbits op, const sail_int len); +fbits fast_sign_extend(const fbits op, const uint64_t n, const uint64_t m); +fbits fast_sign_extend2(const sbits op, const uint64_t m); + +sail_int length_lbits(const lbits op); + +bool eq_bits(const lbits op1, const lbits op2); +bool EQUAL(lbits)(const lbits op1, const lbits op2); +bool neq_bits(const lbits op1, const lbits op2); + +void vector_subrange_lbits(lbits *rop, + const lbits op, + const sail_int n_mpz, + const sail_int m_mpz); + +void sail_truncate(lbits *rop, const lbits op, const sail_int len); +void sail_truncateLSB(lbits *rop, const lbits op, const sail_int len); + +fbits bitvector_access(const lbits op, const sail_int n_mpz); + +sail_int sail_unsigned(const lbits op); +sail_int sail_signed(const lbits op); + +mach_int fast_signed(const fbits, const uint64_t); +mach_int fast_unsigned(const fbits); + +void append(lbits *rop, const lbits op1, const lbits op2); + +sbits append_sf(const sbits, const fbits, const uint64_t); +sbits append_fs(const fbits, const uint64_t, const sbits); +sbits append_ss(const sbits, const sbits); + +void replicate_bits(lbits *rop, const lbits op1, const sail_int op2); +fbits fast_replicate_bits(const fbits shift, const fbits v, const mach_int times); + +void get_slice_int(lbits *rop, const sail_int len_mpz, const sail_int n, const sail_int start_mpz); + +sail_int set_slice_int(const sail_int, const sail_int, const sail_int, const lbits); + +void update_lbits(lbits *rop, const lbits op, const sail_int n_mpz, const uint64_t bit); + +void vector_update_subrange_lbits(lbits *rop, + const lbits op, + const sail_int n_mpz, + const sail_int m_mpz, + const lbits slice); + +fbits fast_update_subrange(const fbits op, + const mach_int n, + const mach_int m, + const fbits slice); + +void slice(lbits *rop, const lbits op, const sail_int start_mpz, const sail_int len_mpz); + +sbits sslice(const fbits op, const mach_int start, const mach_int len); + +void set_slice(lbits *rop, + const sail_int len_mpz, + const sail_int slen_mpz, + const lbits op, + const sail_int start_mpz, + const lbits slice); + + +void shift_bits_left(lbits *rop, const lbits op1, const lbits op2); +void shift_bits_right(lbits *rop, const lbits op1, const lbits op2); +void shift_bits_right_arith(lbits *rop, const lbits op1, const lbits op2); + +void shiftl(lbits *rop, const lbits op1, const sail_int op2); +void shiftr(lbits *rop, const lbits op1, const sail_int op2); + +void reverse_endianness(lbits*, lbits); + +bool eq_sbits(const sbits op1, const sbits op2); +bool neq_sbits(const sbits op1, const sbits op2); +sbits not_sbits(const sbits op); +sbits xor_sbits(const sbits op1, const sbits op2); +sbits or_sbits(const sbits op1, const sbits op2); +sbits and_sbits(const sbits op1, const sbits op2); +sbits add_sbits(const sbits op1, const sbits op2); +sbits sub_sbits(const sbits op1, const sbits op2); + +/* ***** Sail reals ***** */ + +typedef mpq_t real; + +SAIL_BUILTIN_TYPE(real); + +void CREATE_OF(real, sail_string)(real *rop, const sail_string op); +void CONVERT_OF(real, sail_string)(real *rop, const sail_string op); + +void UNDEFINED(real)(real *rop, unit u); + +void neg_real(real *rop, const real op); + +void mult_real(real *rop, const real op1, const real op2); +void sub_real(real *rop, const real op1, const real op2); +void add_real(real *rop, const real op1, const real op2); +void div_real(real *rop, const real op1, const real op2); + +void sqrt_real(real *rop, const real op); +void abs_real(real *rop, const real op); + +SAIL_INT_FUNCTION(round_up, sail_int, const real); +SAIL_INT_FUNCTION(round_down, sail_int, const real); + +void to_real(real *rop, const sail_int op); + +bool EQUAL(real)(const real op1, const real op2); + +bool lt_real(const real op1, const real op2); +bool gt_real(const real op1, const real op2); +bool lteq_real(const real op1, const real op2); +bool gteq_real(const real op1, const real op2); + +void real_power(real *rop, const real base, const sail_int exp); + +unit print_real(const sail_string, const real); +unit prerr_real(const sail_string, const real); + +void random_real(real *rop, unit); + +/* ***** String utilities ***** */ + +SAIL_INT_FUNCTION(string_length, sail_int, sail_string); +void string_drop(sail_string *dst, sail_string s, sail_int len); +void string_take(sail_string *dst, sail_string s, sail_int len); + +/* ***** Printing ***** */ + +void string_of_int(sail_string *str, const sail_int i); +void string_of_lbits(sail_string *str, const lbits op); +void string_of_fbits(sail_string *str, const fbits op); +void decimal_string_of_lbits(sail_string *str, const lbits op); +void decimal_string_of_fbits(sail_string *str, const fbits op); + +/* + * Utility function not callable from Sail! + */ +void fprint_bits(const sail_string pre, + const lbits op, + const sail_string post, + FILE *stream); + +unit print_bits(const sail_string str, const lbits op); +unit prerr_bits(const sail_string str, const lbits op); + +unit print(const sail_string str); +unit print_endline(const sail_string str); + +unit prerr(const sail_string str); +unit prerr_endline(const sail_string str); + +unit print_int(const sail_string str, const sail_int op); +unit prerr_int(const sail_string str, const sail_int op); + +unit sail_putchar(const sail_int op); + +/* ***** Misc ***** */ + +sail_int get_time_ns(const unit); @@ -176,6 +176,11 @@ void string_take(sail_string *dst, sail_string s, sail_int ns) /* ***** Sail integers ***** */ +uint64_t sail_int_get_ui(const mpz_t op) +{ + return mpz_get_ui(op); +} + inline bool EQUAL(mach_int)(const mach_int op1, const mach_int op2) { @@ -84,6 +84,8 @@ bool EQUAL(mach_int)(const mach_int, const mach_int); typedef mpz_t sail_int; +uint64_t sail_int_get_ui(const sail_int); + #define SAIL_INT_FUNCTION(fname, rtype, ...) void fname(rtype*, __VA_ARGS__) SAIL_BUILTIN_TYPE(sail_int); @@ -321,7 +323,6 @@ void set_slice(lbits *rop, const sail_int start_mpz, const lbits slice); - void shift_bits_left(lbits *rop, const lbits op1, const lbits op2); void shift_bits_right(lbits *rop, const lbits op1, const lbits op2); void shift_bits_right_arith(lbits *rop, const lbits op1, const lbits op2); diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml index cb02d17d..d21d219c 100644 --- a/src/jib/c_backend.ml +++ b/src/jib/c_backend.ml @@ -86,6 +86,7 @@ let optimize_primops = ref false let optimize_hoist_allocations = ref false let optimize_struct_updates = ref false let optimize_alias = ref false +let optimize_int128 = ref false let c_debug str = if !c_verbosity > 0 then prerr_endline (Lazy.force str) else () @@ -192,7 +193,9 @@ let rec ctyp_of_typ ctx typ = let rec is_stack_ctyp ctyp = match ctyp with | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_enum _ -> true | CT_fint n -> n <= 64 - | CT_lbits _ | CT_lint | CT_real | CT_string | CT_list _ | CT_vector _ -> false + | CT_lint when !optimize_int128 -> true + | CT_lint -> false + | CT_lbits _ | CT_real | CT_string | CT_list _ | CT_vector _ -> false | CT_struct (_, fields) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) fields | CT_variant (_, ctors) -> false (* List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) ctors *) (* FIXME *) | CT_tup ctyps -> List.for_all is_stack_ctyp ctyps @@ -1493,6 +1496,7 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = | CT_unit -> "UNIT", [] | CT_bit -> "UINT64_C(0)", [] | CT_fint _ -> "INT64_C(0xdeadc0de)", [] + | CT_lint when !optimize_int128 -> "((sail_int) 0xdeadc0de)", [] | CT_fbits _ -> "UINT64_C(0xdeadc0de)", [] | CT_sbits _ -> "undefined_sbits()", [] | CT_bool -> "false", [] @@ -1898,8 +1902,8 @@ let codegen_vector ctx (direction, ctyp) = ^^ string "}" in let vector_update = - string (Printf.sprintf "static void vector_update_%s(%s *rop, %s op, mpz_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) - ^^ string " int m = mpz_get_ui(n);\n" + string (Printf.sprintf "static void vector_update_%s(%s *rop, %s op, sail_int n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) + ^^ string " int m = sail_int_get_ui(n);\n" ^^ string " if (rop->data == op.data) {\n" ^^ string (if is_stack_ctyp ctyp then " rop->data[m] = elem;\n" @@ -1924,13 +1928,13 @@ let codegen_vector ctx (direction, ctyp) = in let vector_access = if is_stack_ctyp ctyp then - string (Printf.sprintf "static %s vector_access_%s(%s op, mpz_t n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id)) - ^^ string " int m = mpz_get_ui(n);\n" + string (Printf.sprintf "static %s vector_access_%s(%s op, sail_int n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id)) + ^^ string " int m = sail_int_get_ui(n);\n" ^^ string " return op.data[m];\n" ^^ string "}" else - string (Printf.sprintf "static void vector_access_%s(%s *rop, %s op, mpz_t n) {\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) - ^^ string " int m = mpz_get_ui(n);\n" + string (Printf.sprintf "static void vector_access_%s(%s *rop, %s op, sail_int n) {\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) + ^^ string " int m = sail_int_get_ui(n);\n" ^^ string (Printf.sprintf " COPY(%s)(rop, op.data[m]);\n" (sgen_ctyp_name ctyp)) ^^ string "}" in @@ -1946,8 +1950,8 @@ let codegen_vector ctx (direction, ctyp) = ^^ string "}" in let vector_undefined = - string (Printf.sprintf "static void undefined_vector_%s(%s *rop, mpz_t len, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) - ^^ string (Printf.sprintf " rop->len = mpz_get_ui(len);\n") + string (Printf.sprintf "static void undefined_vector_%s(%s *rop, sail_int len, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) + ^^ string (Printf.sprintf " rop->len = sail_int_get_ui(len);\n") ^^ string (Printf.sprintf " rop->data = malloc((rop->len) * sizeof(%s));\n" (sgen_ctyp ctyp)) ^^ string " for (int i = 0; i < (rop->len); i++) {\n" ^^ string (if is_stack_ctyp ctyp then diff --git a/src/jib/c_backend.mli b/src/jib/c_backend.mli index 3e8c426b..4628691d 100644 --- a/src/jib/c_backend.mli +++ b/src/jib/c_backend.mli @@ -100,6 +100,7 @@ val optimize_primops : bool ref val optimize_hoist_allocations : bool ref val optimize_struct_updates : bool ref val optimize_alias : bool ref +val optimize_int128 : bool ref (** Convert a typ to a IR ctyp *) val ctyp_of_typ : Jib_compile.ctx -> Ast.typ -> ctyp diff --git a/src/sail.ml b/src/sail.ml index 3c277fab..c2b2ed65 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -196,6 +196,9 @@ let options = Arg.align ([ ( "-Oconstant_fold", Arg.Set Constant_fold.optimize_constant_fold, " apply constant folding optimizations"); + ( "-Oint128", + Arg.Set C_backend.optimize_int128, + " use 128-bit integers rather than GMP arbitrary precision integers"); ( "-Oaarch64_fast", Arg.Set Jib_compile.optimize_aarch64_fast_struct, " apply ARMv8.5 specific optimizations (potentially unsound in general)"); |
