summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlasdair2021-01-05 14:36:21 +0000
committerAlasdair2021-01-05 14:36:21 +0000
commit1ac7d1b3ddb0cc1aeff4964559dbf92e0addf057 (patch)
tree936ab9d1a1ef775f713d94742feefac6d04769b2
parent8b2a3fa0eae0f49b78c0c5f845d3824d21f98df3 (diff)
Fix some cases when monomorphising vectors containing variable-length bitvectors
-rw-r--r--lib/generic_equality.sail16
-rw-r--r--lib/sail.c10
-rw-r--r--lib/sail.h2
-rw-r--r--src/jib/c_backend.ml50
-rw-r--r--src/specialize.ml18
-rwxr-xr-xtest/c/run_tests.py2
-rw-r--r--test/c/split.expect1
-rw-r--r--test/c/split.sail30
8 files changed, 125 insertions, 4 deletions
diff --git a/lib/generic_equality.sail b/lib/generic_equality.sail
new file mode 100644
index 00000000..22bd778e
--- /dev/null
+++ b/lib/generic_equality.sail
@@ -0,0 +1,16 @@
+$ifndef _GENERIC_EQUALITY
+$define _GENERIC_EQUALITY
+
+$include <flow.sail>
+
+val eq_anything = {ocaml: "(fun (x, y) -> x = y)", interpreter: "eq_anything", lem: "eq", coq: "generic_eq", c: "eq_anything"} : forall ('a : Type). ('a, 'a) -> bool
+
+overload operator == = {eq_anything}
+
+val neq_anything : forall ('a : Type). ('a, 'a) -> bool
+
+function neq_anything(x, y) = not_bool(eq_anything(x, y))
+
+overload operator != = {neq_anything}
+
+$endif
diff --git a/lib/sail.c b/lib/sail.c
index 94065f0a..11a6c2d8 100644
--- a/lib/sail.c
+++ b/lib/sail.c
@@ -461,6 +461,11 @@ bool EQUAL(fbits)(const fbits op1, const fbits op2)
return op1 == op2;
}
+bool EQUAL(ref_fbits)(const fbits *op1, const fbits *op2)
+{
+ return *op1 == *op2;
+}
+
void CREATE(lbits)(lbits *rop)
{
rop->bits = sail_malloc(sizeof(mpz_t));
@@ -791,6 +796,11 @@ bool EQUAL(lbits)(const lbits op1, const lbits op2)
return eq_bits(op1, op2);
}
+bool EQUAL(ref_lbits)(const lbits *op1, const lbits *op2)
+{
+ return eq_bits(*op1, *op2);
+}
+
bool neq_bits(const lbits op1, const lbits op2)
{
assert(op1.len == op2.len);
diff --git a/lib/sail.h b/lib/sail.h
index 0c9b18b5..6de37b0a 100644
--- a/lib/sail.h
+++ b/lib/sail.h
@@ -184,6 +184,7 @@ static inline bool bit_to_bool(const fbits a)
}
bool EQUAL(fbits)(const fbits, const fbits);
+bool EQUAL(ref_fbits)(const fbits*, const fbits*);
typedef struct {
uint64_t len;
@@ -277,6 +278,7 @@ void count_leading_zeros(sail_int *rop, const lbits op);
bool eq_bits(const lbits op1, const lbits op2);
bool EQUAL(lbits)(const lbits op1, const lbits op2);
+bool EQUAL(ref_lbits)(const lbits *op1, const lbits *op2);
bool neq_bits(const lbits op1, const lbits op2);
void vector_subrange_lbits(lbits *rop,
diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml
index 88b01d87..a875405d 100644
--- a/src/jib/c_backend.ml
+++ b/src/jib/c_backend.ml
@@ -1374,6 +1374,44 @@ let rec codegen_conversion l clexp cval =
| CT_ref ctyp_to, ctyp_from ->
codegen_conversion l (CL_addr clexp) cval
+ | CT_vector (_, ctyp_elem_to), CT_vector (_, ctyp_elem_from) ->
+ let i = ngensym () in
+ let from = ngensym () in
+ let into = ngensym () in
+ ksprintf string " KILL(%s)(%s);" (sgen_ctyp_name ctyp_to) (sgen_clexp clexp) ^^ hardline
+ ^^ ksprintf string " internal_vector_init_%s(%s, %s.len);" (sgen_ctyp_name ctyp_to) (sgen_clexp clexp) (sgen_cval cval) ^^ hardline
+ ^^ ksprintf string " for (int %s = 0; %s < %s.len; %s++) {" (sgen_name i) (sgen_name i) (sgen_cval cval) (sgen_name i) ^^ hardline
+ ^^ (if is_stack_ctyp ctyp_elem_from then
+ ksprintf string " %s %s = %s.data[%s];" (sgen_ctyp ctyp_elem_from) (sgen_name from) (sgen_cval cval) (sgen_name i)
+ else
+ ksprintf string " %s %s;" (sgen_ctyp ctyp_elem_from) (sgen_name from) ^^ hardline
+ ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from) ^^ hardline
+ ^^ ksprintf string " COPY(%s)(&%s, %s.data[%s]);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from) (sgen_cval cval) (sgen_name i)
+ )
+ ^^ hardline
+ ^^ ksprintf string " %s %s;" (sgen_ctyp ctyp_elem_to) (sgen_name into)
+ ^^ (if is_stack_ctyp ctyp_elem_to then
+ empty
+ else
+ hardline ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp_elem_to) (sgen_name into)
+ )
+ ^^ nest 2 (hardline
+ ^^ codegen_conversion l (CL_id (into, ctyp_elem_to)) (V_id (from, ctyp_elem_from)))
+ ^^ hardline
+ ^^ (if is_stack_ctyp ctyp_elem_to then
+ ksprintf string " %s.data[%s] = %s;" (sgen_clexp_pure clexp) (sgen_name i) (sgen_name into)
+ else
+ ksprintf string " COPY(%s)(&((%s)->data[%s]), %s);" (sgen_ctyp_name ctyp_elem_to) (sgen_clexp clexp) (sgen_name i) (sgen_name into)
+ ^^ hardline ^^ ksprintf string " KILL(%s)(&%s);" (sgen_ctyp_name ctyp_elem_to) (sgen_name into)
+ )
+ ^^ (if is_stack_ctyp ctyp_elem_from then
+ empty
+ else
+ hardline ^^ ksprintf string " KILL(%s)(&%s);" (sgen_ctyp_name ctyp_elem_from) (sgen_name from)
+ )
+ ^^ hardline
+ ^^ string " }"
+
(* If we have to convert between tuple types, convert the fields individually. *)
| CT_tup ctyps_to, CT_tup ctyps_from when List.length ctyps_to = List.length ctyps_from ->
let len = List.length ctyps_to in
@@ -2019,6 +2057,17 @@ let codegen_vector ctx (direction, ctyp) =
^^ string " }\n"
^^ string "}"
in
+ let vector_equal =
+ let open Printf in
+ ksprintf string "static bool EQUAL(%s)(const %s op1, const %s op2) {\n" (sgen_id id) (sgen_id id) (sgen_id id)
+ ^^ string " if (op1.len != op2.len) return false;\n"
+ ^^ string " bool result = true;"
+ ^^ string " for (int i = 0; i < op1.len; i++) {\n"
+ ^^ ksprintf string " result &= EQUAL(%s)(op1.data[i], op2.data[i]);" (sgen_ctyp_name ctyp)
+ ^^ string " }\n"
+ ^^ ksprintf string " return result;\n"
+ ^^ string "}"
+ in
begin
generated := IdSet.add id !generated;
vector_typedef ^^ twice hardline
@@ -2028,6 +2077,7 @@ let codegen_vector ctx (direction, ctyp) =
^^ vector_access ^^ twice hardline
^^ vector_set ^^ twice hardline
^^ vector_update ^^ twice hardline
+ ^^ vector_equal ^^ twice hardline
^^ internal_vector_update ^^ twice hardline
^^ internal_vector_init ^^ twice hardline
end
diff --git a/src/specialize.ml b/src/specialize.ml
index bbf74f46..cfd80cce 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -381,15 +381,27 @@ let specialize_id_valspec spec instantiations id ast =
let spec_ids = ref IdSet.empty in
let specialize_instance instantiation =
- let safe_instantiation, reverse = safe_instantiation instantiation in
- (* Replace the polymorphic type variables in the type with their concrete instantiation. *)
- let typ = remove_implicit (Type_check.subst_unifiers reverse (Type_check.subst_unifiers safe_instantiation typ)) in
+ let uninstantiated = quant_kopts typq |> List.map kopt_kid |> List.filter (fun v -> not (KBindings.mem v instantiation)) |> KidSet.of_list in
(* Collect any new type variables introduced by the instantiation *)
let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in
let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_frees |> collect_kids in
let int_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_int_frees |> collect_kids in
+ let typq, typ =
+ List.fold_left (fun (typq, typ) free ->
+ if KidSet.mem free uninstantiated then
+ let fresh_v = prepend_kid "o#" free in
+ typquant_subst_kid free fresh_v typq, subst_kid typ_subst free fresh_v typ
+ else
+ typq, typ
+ ) (typq, typ) (typ_frees @ int_frees)
+ in
+
+ let safe_instantiation, reverse = safe_instantiation instantiation in
+ (* Replace the polymorphic type variables in the type with their concrete instantiation. *)
+ let typ = remove_implicit (Type_check.subst_unifiers reverse (Type_check.subst_unifiers safe_instantiation typ)) in
+
(* Remove type variables from the type quantifier. *)
let kopts, constraints = quant_split typq in
let constraints = instantiate_constraints safe_instantiation constraints in
diff --git a/test/c/run_tests.py b/test/c/run_tests.py
index 28a4d28d..a88de411 100755
--- a/test/c/run_tests.py
+++ b/test/c/run_tests.py
@@ -115,7 +115,7 @@ def test_lem(name):
xml = '<testsuites>\n'
-xml += test_c2('unoptimized C', '', '', True)
+#xml += test_c2('unoptimized C', '', '', True)
xml += test_c('unoptimized C', '', '', True)
xml += test_c('optimized C', '-O2', '-O', True)
xml += test_c('constant folding', '', '-Oconstant_fold', True)
diff --git a/test/c/split.expect b/test/c/split.expect
new file mode 100644
index 00000000..9766475a
--- /dev/null
+++ b/test/c/split.expect
@@ -0,0 +1 @@
+ok
diff --git a/test/c/split.sail b/test/c/split.sail
new file mode 100644
index 00000000..8c994e80
--- /dev/null
+++ b/test/c/split.sail
@@ -0,0 +1,30 @@
+default Order dec
+
+$include <prelude.sail>
+$include <generic_equality.sail>
+$include <string.sail>
+
+val split : forall 'n 'm, 'n * 'm == 64 & 'n in {1, 2, 4, 8}. (int('n), int('m), bits(64)) -> vector('n, dec, bits('m)) effect {undef}
+
+function split(n, m, bv) = {
+ var result: vector('n, dec, bits('m)) = undefined;
+
+ foreach (i from 0 to (n - 1)) {
+ result[i] = sail_shiftright(bv, i * m)[m - 1 .. 0]
+ };
+
+ result
+}
+
+val main : unit -> unit effect {escape, undef}
+
+function main() = {
+ assert(split(8, 8, 0xAAAABBBBCCCCDDDD) == [0xAA, 0xAA, 0xBB, 0xBB, 0xCC, 0xCC, 0xDD, 0xDD]);
+ assert(split(4, 16, 0xAAAABBBBCCCCDDDD) == [0xAAAA, 0xBBBB, 0xCCCC, 0xDDDD]);
+ assert(split(2, 32, 0xAAAABBBBCCCCDDDD) == [0xAAAABBBB, 0xCCCCDDDD]);
+ assert(split(1, 64, 0xAAAABBBBCCCCDDDD) == [0xAAAABBBBCCCCDDDD]);
+
+ assert(split(4, 16, 0xAAAABBBBCCCCDDDD) != [0xDDDD, 0xCCCC, 0xBBBB, 0xAAAA]);
+
+ print_endline("ok");
+}