summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlasdair Armstrong2018-08-06 20:27:01 +0100
committerAlasdair Armstrong2018-08-06 20:34:11 +0100
commit0cb1e506866873f8886baf7631878ed956f1e8f5 (patch)
tree0e0d76c627c318ccbef100e65001bd60c38f62fe
parentd334535562953959c965ccace6392b0d87d1fb89 (diff)
Cast each argument to a polymorphic constructor into it's most general type
-rw-r--r--lib/sail.c4
-rw-r--r--lib/sail.h4
-rw-r--r--src/bytecode_util.ml19
-rw-r--r--src/c_backend.ml42
4 files changed, 57 insertions, 12 deletions
diff --git a/lib/sail.c b/lib/sail.c
index 38c8c273..3223dc14 100644
--- a/lib/sail.c
+++ b/lib/sail.c
@@ -390,12 +390,12 @@ void RECREATE_OF(sail_bits, mach_bits)(sail_bits *rop, const uint64_t op, const
mpz_set_ui(*rop->bits, op);
}
-mach_bits CONVERT_OF(mach_bits, sail_bits)(const sail_bits op)
+mach_bits CONVERT_OF(mach_bits, sail_bits)(const sail_bits op, const bool direction)
{
return mpz_get_ui(*op.bits);
}
-void CONVERT_OF(sail_bits, mach_bits)(sail_bits *rop, const mach_bits op, const uint64_t len)
+void CONVERT_OF(sail_bits, mach_bits)(sail_bits *rop, const mach_bits op, const uint64_t len, const bool direction)
{
rop->len = len;
// use safe_rshift to correctly handle the case when we have a 0-length vector.
diff --git a/lib/sail.h b/lib/sail.h
index 9ce3ec6b..afff2c65 100644
--- a/lib/sail.h
+++ b/lib/sail.h
@@ -186,8 +186,8 @@ void RECREATE_OF(sail_bits, mach_bits)(sail_bits *,
const mach_bits len,
const bool direction);
-mach_bits CONVERT_OF(mach_bits, sail_bits)(const sail_bits);
-void CONVERT_OF(sail_bits, mach_bits)(sail_bits *, const mach_bits, const uint64_t);
+mach_bits CONVERT_OF(mach_bits, sail_bits)(const sail_bits, const bool);
+void CONVERT_OF(sail_bits, mach_bits)(sail_bits *, const mach_bits, const uint64_t, const bool);
void UNDEFINED(sail_bits)(sail_bits *, const sail_int len, const mach_bits bit);
mach_bits UNDEFINED(mach_bits)(const unit);
diff --git a/src/bytecode_util.ml b/src/bytecode_util.ml
index 27086858..ed042c51 100644
--- a/src/bytecode_util.ml
+++ b/src/bytecode_util.ml
@@ -231,6 +231,25 @@ let rec ctyp_unify ctyp1 ctyp2 =
| _, _ when ctyp_equal ctyp1 ctyp2 -> []
| _, _ -> raise (Invalid_argument "ctyp_unify")
+let rec ctyp_suprema = function
+ | CT_int -> CT_int
+ | CT_bits d -> CT_bits d
+ | CT_bits64 (_, d) -> CT_bits d
+ | CT_int64 -> CT_int
+ | CT_unit -> CT_unit
+ | CT_bool -> CT_bool
+ | CT_real -> CT_real
+ | CT_bit -> CT_bit
+ | CT_tup ctyps -> CT_tup (List.map ctyp_suprema ctyps)
+ | CT_string -> CT_string
+ | CT_enum (id, ids) -> CT_enum (id, ids)
+ | CT_struct (id, ctors) -> CT_struct (id, List.map (fun (id, ctyp) -> (id, ctyp_suprema ctyp)) ctors)
+ | CT_variant (id, ctors) -> CT_variant (id, List.map (fun (id, ctyp) -> (id, ctyp_suprema ctyp)) ctors)
+ | CT_vector (d, ctyp) -> CT_vector (d, ctyp_suprema ctyp)
+ | CT_list ctyp -> CT_list (ctyp_suprema ctyp)
+ | CT_ref ctyp -> CT_ref (ctyp_suprema ctyp)
+ | CT_poly -> CT_poly
+
let rec unpoly = function
| F_poly f -> unpoly f
| F_call (call, fs) -> F_call (call, List.map unpoly fs)
diff --git a/src/c_backend.ml b/src/c_backend.ml
index e1fde4dc..adbc7cf6 100644
--- a/src/c_backend.ml
+++ b/src/c_backend.ml
@@ -875,7 +875,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label =
let ctor_ctyp = Bindings.find ctor (ctor_bindings ctors) in
let ctor_c_id =
if is_polymorphic ctor_ctyp then
- let unification = ctyp_unify ctor_ctyp (apat_ctyp ctx apat) in
+ let unification = List.map ctyp_suprema (ctyp_unify ctor_ctyp (apat_ctyp ctx apat)) in
ctor_c_id ^ "_" ^ Util.string_of_list "_" (fun ctyp -> Util.zencode_string (string_of_ctyp ctyp)) unification
else
ctor_c_id
@@ -1441,7 +1441,7 @@ let fix_exception_block ?return:(return=None) ctx instrs =
rewrite_exception [] instrs @ [ilabel end_block_label]
| Some ctyp ->
rewrite_exception [] instrs @ [ilabel end_block_label; iundefined ctyp]
-
+
let rec map_try_block f (I_aux (instr, aux)) =
let instr = match instr with
| I_decl _ | I_reset _ | I_init _ | I_reinit _ -> instr
@@ -1805,14 +1805,40 @@ let rec specialize_variants ctx =
List.iter2 (fun cval ctyp -> prerr_endline (Pretty_print_sail.to_string (pp_cval cval) ^ " -> " ^ string_of_ctyp ctyp)) cvals ctyps;
(* Work out how each call to a constructor in instantiated and add that to unifications *)
- let unification = List.concat (List.map2 (fun cval ctyp -> ctyp_unify ctyp (cval_ctyp cval)) cvals ctyps) in
+ let unification = List.concat (List.map2 (fun cval ctyp -> List.map ctyp_suprema (ctyp_unify ctyp (cval_ctyp cval))) cvals ctyps) in
let mono_id = append_id ctor_id ("_" ^ Util.string_of_list "_" (fun ctyp -> Util.zencode_string (string_of_ctyp ctyp)) unification) in
- unifications := Bindings.add mono_id (mk_tuple (List.map cval_ctyp cvals)) !unifications;
+ unifications := Bindings.add mono_id (ctyp_suprema (mk_tuple (List.map cval_ctyp cvals))) !unifications;
List.iter (fun ctyp -> prerr_endline (string_of_ctyp ctyp)) unification;
prerr_endline (string_of_id mono_id);
- I_aux (I_funcall (clexp, extern, mono_id, List.map (fun (frag, ctyp) -> (unpoly frag, ctyp)) cvals), aux)
+ (* We need to case each cval to it's ctyp_suprema in order to put it in the most general constructor *)
+ let casts =
+ let cast_to_suprema (frag, ctyp) =
+ let suprema = ctyp_suprema ctyp in
+ if ctyp_equal ctyp suprema then
+ [], (unpoly frag, ctyp), []
+ else
+ let gs = gensym () in
+ [idecl suprema gs;
+ icopy (CL_id (gs, suprema)) (unpoly frag, ctyp)],
+ (F_id gs, suprema),
+ [iclear suprema gs]
+ in
+ List.map cast_to_suprema cvals
+ in
+ let setup = List.concat (List.map (fun (setup, _, _) -> setup) casts) in
+ let cvals = List.map (fun (_, cval, _) -> cval) casts in
+ let cleanup = List.concat (List.map (fun (_, _, cleanup) -> cleanup) casts) in
+
+ let mk_funcall instr =
+ if List.length setup = 0 then
+ instr
+ else
+ iblock (setup @ [instr] @ cleanup)
+ in
+
+ mk_funcall (I_aux (I_funcall (clexp, extern, mono_id, cvals), aux))
| instr -> instr
in
@@ -1841,7 +1867,7 @@ let rec specialize_variants ctx =
let remove_poly (I_aux (instr, aux)) =
match instr with
| I_copy (clexp, (frag, ctyp)) when is_polymorphic ctyp ->
- I_aux (I_copy (clexp, (frag, clexp_ctyp clexp)), aux)
+ I_aux (I_copy (clexp, (frag, ctyp_suprema (clexp_ctyp clexp))), aux)
| instr -> I_aux (instr, aux)
in
let cdef = cdef_map_instr remove_poly cdef in
@@ -2023,10 +2049,10 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) =
else
if is_stack_ctyp lctyp then
string (Printf.sprintf " %s = CONVERT_OF(%s, %s)(%s);"
- (sgen_clexp_pure clexp) (sgen_ctyp_name lctyp) (sgen_ctyp_name rctyp) (sgen_cval cval))
+ (sgen_clexp_pure clexp) (sgen_ctyp_name lctyp) (sgen_ctyp_name rctyp) (sgen_cval_param cval))
else
string (Printf.sprintf " CONVERT_OF(%s, %s)(%s, %s);"
- (sgen_ctyp_name lctyp) (sgen_ctyp_name rctyp) (sgen_clexp clexp) (sgen_cval cval))
+ (sgen_ctyp_name lctyp) (sgen_ctyp_name rctyp) (sgen_clexp clexp) (sgen_cval_param cval))
| I_jump (cval, label) ->
string (Printf.sprintf " if (%s) goto %s;" (sgen_cval cval) label)