diff options
| -rw-r--r-- | lib/generic_equality.sail | 16 | ||||
| -rw-r--r-- | lib/sail.c | 10 | ||||
| -rw-r--r-- | lib/sail.h | 2 | ||||
| -rw-r--r-- | src/jib/c_backend.ml | 50 | ||||
| -rw-r--r-- | src/specialize.ml | 18 | ||||
| -rwxr-xr-x | test/c/run_tests.py | 2 | ||||
| -rw-r--r-- | test/c/split.expect | 1 | ||||
| -rw-r--r-- | test/c/split.sail | 30 |
8 files changed, 125 insertions, 4 deletions
diff --git a/lib/generic_equality.sail b/lib/generic_equality.sail new file mode 100644 index 00000000..22bd778e --- /dev/null +++ b/lib/generic_equality.sail @@ -0,0 +1,16 @@ +$ifndef _GENERIC_EQUALITY +$define _GENERIC_EQUALITY + +$include <flow.sail> + +val eq_anything = {ocaml: "(fun (x, y) -> x = y)", interpreter: "eq_anything", lem: "eq", coq: "generic_eq", c: "eq_anything"} : forall ('a : Type). ('a, 'a) -> bool + +overload operator == = {eq_anything} + +val neq_anything : forall ('a : Type). ('a, 'a) -> bool + +function neq_anything(x, y) = not_bool(eq_anything(x, y)) + +overload operator != = {neq_anything} + +$endif @@ -461,6 +461,11 @@ bool EQUAL(fbits)(const fbits op1, const fbits op2) return op1 == op2; } +bool EQUAL(ref_fbits)(const fbits *op1, const fbits *op2) +{ + return *op1 == *op2; +} + void CREATE(lbits)(lbits *rop) { rop->bits = sail_malloc(sizeof(mpz_t)); @@ -791,6 +796,11 @@ bool EQUAL(lbits)(const lbits op1, const lbits op2) return eq_bits(op1, op2); } +bool EQUAL(ref_lbits)(const lbits *op1, const lbits *op2) +{ + return eq_bits(*op1, *op2); +} + bool neq_bits(const lbits op1, const lbits op2) { assert(op1.len == op2.len); @@ -184,6 +184,7 @@ static inline bool bit_to_bool(const fbits a) } bool EQUAL(fbits)(const fbits, const fbits); +bool EQUAL(ref_fbits)(const fbits*, const fbits*); typedef struct { uint64_t len; @@ -277,6 +278,7 @@ void count_leading_zeros(sail_int *rop, const lbits op); bool eq_bits(const lbits op1, const lbits op2); bool EQUAL(lbits)(const lbits op1, const lbits op2); +bool EQUAL(ref_lbits)(const lbits *op1, const lbits *op2); bool neq_bits(const lbits op1, const lbits op2); void vector_subrange_lbits(lbits *rop, diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml index 88b01d87..a875405d 100644 --- a/src/jib/c_backend.ml +++ b/src/jib/c_backend.ml @@ -1374,6 +1374,44 @@ let rec codegen_conversion l clexp cval = | CT_ref ctyp_to, ctyp_from -> codegen_conversion l (CL_addr clexp) cval + | CT_vector (_, ctyp_elem_to), CT_vector (_, ctyp_elem_from) -> + let i = ngensym () in + let from = ngensym () in + let into = ngensym () in + ksprintf string " KILL(%s)(%s);" (sgen_ctyp_name ctyp_to) (sgen_clexp clexp) ^^ hardline + ^^ ksprintf string " internal_vector_init_%s(%s, %s.len);" (sgen_ctyp_name ctyp_to) (sgen_clexp clexp) (sgen_cval cval) ^^ hardline + ^^ ksprintf string " for (int %s = 0; %s < %s.len; %s++) {" (sgen_name i) (sgen_name i) (sgen_cval cval) (sgen_name i) ^^ hardline + ^^ (if is_stack_ctyp ctyp_elem_from then + ksprintf string " %s %s = %s.data[%s];" (sgen_ctyp ctyp_elem_from) (sgen_name from) (sgen_cval cval) (sgen_name i) + else + ksprintf string " %s %s;" (sgen_ctyp ctyp_elem_from) (sgen_name from) ^^ hardline + ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from) ^^ hardline + ^^ ksprintf string " COPY(%s)(&%s, %s.data[%s]);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from) (sgen_cval cval) (sgen_name i) + ) + ^^ hardline + ^^ ksprintf string " %s %s;" (sgen_ctyp ctyp_elem_to) (sgen_name into) + ^^ (if is_stack_ctyp ctyp_elem_to then + empty + else + hardline ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp_elem_to) (sgen_name into) + ) + ^^ nest 2 (hardline + ^^ codegen_conversion l (CL_id (into, ctyp_elem_to)) (V_id (from, ctyp_elem_from))) + ^^ hardline + ^^ (if is_stack_ctyp ctyp_elem_to then + ksprintf string " %s.data[%s] = %s;" (sgen_clexp_pure clexp) (sgen_name i) (sgen_name into) + else + ksprintf string " COPY(%s)(&((%s)->data[%s]), %s);" (sgen_ctyp_name ctyp_elem_to) (sgen_clexp clexp) (sgen_name i) (sgen_name into) + ^^ hardline ^^ ksprintf string " KILL(%s)(&%s);" (sgen_ctyp_name ctyp_elem_to) (sgen_name into) + ) + ^^ (if is_stack_ctyp ctyp_elem_from then + empty + else + hardline ^^ ksprintf string " KILL(%s)(&%s);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from) + ) + ^^ hardline + ^^ string " }" + (* If we have to convert between tuple types, convert the fields individually. *) | CT_tup ctyps_to, CT_tup ctyps_from when List.length ctyps_to = List.length ctyps_from -> let len = List.length ctyps_to in @@ -2019,6 +2057,17 @@ let codegen_vector ctx (direction, ctyp) = ^^ string " }\n" ^^ string "}" in + let vector_equal = + let open Printf in + ksprintf string "static bool EQUAL(%s)(const %s op1, const %s op2) {\n" (sgen_id id) (sgen_id id) (sgen_id id) + ^^ string " if (op1.len != op2.len) return false;\n" + ^^ string " bool result = true;" + ^^ string " for (int i = 0; i < op1.len; i++) {\n" + ^^ ksprintf string " result &= EQUAL(%s)(op1.data[i], op2.data[i]);" (sgen_ctyp_name ctyp) + ^^ string " }\n" + ^^ ksprintf string " return result;\n" + ^^ string "}" + in begin generated := IdSet.add id !generated; vector_typedef ^^ twice hardline @@ -2028,6 +2077,7 @@ let codegen_vector ctx (direction, ctyp) = ^^ vector_access ^^ twice hardline ^^ vector_set ^^ twice hardline ^^ vector_update ^^ twice hardline + ^^ vector_equal ^^ twice hardline ^^ internal_vector_update ^^ twice hardline ^^ internal_vector_init ^^ twice hardline end diff --git a/src/specialize.ml b/src/specialize.ml index bbf74f46..cfd80cce 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -381,15 +381,27 @@ let specialize_id_valspec spec instantiations id ast = let spec_ids = ref IdSet.empty in let specialize_instance instantiation = - let safe_instantiation, reverse = safe_instantiation instantiation in - (* Replace the polymorphic type variables in the type with their concrete instantiation. *) - let typ = remove_implicit (Type_check.subst_unifiers reverse (Type_check.subst_unifiers safe_instantiation typ)) in + let uninstantiated = quant_kopts typq |> List.map kopt_kid |> List.filter (fun v -> not (KBindings.mem v instantiation)) |> KidSet.of_list in (* Collect any new type variables introduced by the instantiation *) let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_frees |> collect_kids in let int_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_int_frees |> collect_kids in + let typq, typ = + List.fold_left (fun (typq, typ) free -> + if KidSet.mem free uninstantiated then + let fresh_v = prepend_kid "o#" free in + typquant_subst_kid free fresh_v typq, subst_kid typ_subst free fresh_v typ + else + typq, typ + ) (typq, typ) (typ_frees @ int_frees) + in + + let safe_instantiation, reverse = safe_instantiation instantiation in + (* Replace the polymorphic type variables in the type with their concrete instantiation. *) + let typ = remove_implicit (Type_check.subst_unifiers reverse (Type_check.subst_unifiers safe_instantiation typ)) in + (* Remove type variables from the type quantifier. *) let kopts, constraints = quant_split typq in let constraints = instantiate_constraints safe_instantiation constraints in diff --git a/test/c/run_tests.py b/test/c/run_tests.py index 28a4d28d..a88de411 100755 --- a/test/c/run_tests.py +++ b/test/c/run_tests.py @@ -115,7 +115,7 @@ def test_lem(name): xml = '<testsuites>\n' -xml += test_c2('unoptimized C', '', '', True) +#xml += test_c2('unoptimized C', '', '', True) xml += test_c('unoptimized C', '', '', True) xml += test_c('optimized C', '-O2', '-O', True) xml += test_c('constant folding', '', '-Oconstant_fold', True) diff --git a/test/c/split.expect b/test/c/split.expect new file mode 100644 index 00000000..9766475a --- /dev/null +++ b/test/c/split.expect @@ -0,0 +1 @@ +ok diff --git a/test/c/split.sail b/test/c/split.sail new file mode 100644 index 00000000..8c994e80 --- /dev/null +++ b/test/c/split.sail @@ -0,0 +1,30 @@ +default Order dec + +$include <prelude.sail> +$include <generic_equality.sail> +$include <string.sail> + +val split : forall 'n 'm, 'n * 'm == 64 & 'n in {1, 2, 4, 8}. (int('n), int('m), bits(64)) -> vector('n, dec, bits('m)) effect {undef} + +function split(n, m, bv) = { + var result: vector('n, dec, bits('m)) = undefined; + + foreach (i from 0 to (n - 1)) { + result[i] = sail_shiftright(bv, i * m)[m - 1 .. 0] + }; + + result +} + +val main : unit -> unit effect {escape, undef} + +function main() = { + assert(split(8, 8, 0xAAAABBBBCCCCDDDD) == [0xAA, 0xAA, 0xBB, 0xBB, 0xCC, 0xCC, 0xDD, 0xDD]); + assert(split(4, 16, 0xAAAABBBBCCCCDDDD) == [0xAAAA, 0xBBBB, 0xCCCC, 0xDDDD]); + assert(split(2, 32, 0xAAAABBBBCCCCDDDD) == [0xAAAABBBB, 0xCCCCDDDD]); + assert(split(1, 64, 0xAAAABBBBCCCCDDDD) == [0xAAAABBBBCCCCDDDD]); + + assert(split(4, 16, 0xAAAABBBBCCCCDDDD) != [0xDDDD, 0xCCCC, 0xBBBB, 0xAAAA]); + + print_endline("ok"); +} |
