diff options
| author | Alasdair Armstrong | 2018-08-06 20:27:01 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2018-08-06 20:34:11 +0100 |
| commit | 0cb1e506866873f8886baf7631878ed956f1e8f5 (patch) | |
| tree | 0e0d76c627c318ccbef100e65001bd60c38f62fe | |
| parent | d334535562953959c965ccace6392b0d87d1fb89 (diff) | |
Cast each argument to a polymorphic constructor into it's most general type
| -rw-r--r-- | lib/sail.c | 4 | ||||
| -rw-r--r-- | lib/sail.h | 4 | ||||
| -rw-r--r-- | src/bytecode_util.ml | 19 | ||||
| -rw-r--r-- | src/c_backend.ml | 42 |
4 files changed, 57 insertions, 12 deletions
@@ -390,12 +390,12 @@ void RECREATE_OF(sail_bits, mach_bits)(sail_bits *rop, const uint64_t op, const mpz_set_ui(*rop->bits, op); } -mach_bits CONVERT_OF(mach_bits, sail_bits)(const sail_bits op) +mach_bits CONVERT_OF(mach_bits, sail_bits)(const sail_bits op, const bool direction) { return mpz_get_ui(*op.bits); } -void CONVERT_OF(sail_bits, mach_bits)(sail_bits *rop, const mach_bits op, const uint64_t len) +void CONVERT_OF(sail_bits, mach_bits)(sail_bits *rop, const mach_bits op, const uint64_t len, const bool direction) { rop->len = len; // use safe_rshift to correctly handle the case when we have a 0-length vector. @@ -186,8 +186,8 @@ void RECREATE_OF(sail_bits, mach_bits)(sail_bits *, const mach_bits len, const bool direction); -mach_bits CONVERT_OF(mach_bits, sail_bits)(const sail_bits); -void CONVERT_OF(sail_bits, mach_bits)(sail_bits *, const mach_bits, const uint64_t); +mach_bits CONVERT_OF(mach_bits, sail_bits)(const sail_bits, const bool); +void CONVERT_OF(sail_bits, mach_bits)(sail_bits *, const mach_bits, const uint64_t, const bool); void UNDEFINED(sail_bits)(sail_bits *, const sail_int len, const mach_bits bit); mach_bits UNDEFINED(mach_bits)(const unit); diff --git a/src/bytecode_util.ml b/src/bytecode_util.ml index 27086858..ed042c51 100644 --- a/src/bytecode_util.ml +++ b/src/bytecode_util.ml @@ -231,6 +231,25 @@ let rec ctyp_unify ctyp1 ctyp2 = | _, _ when ctyp_equal ctyp1 ctyp2 -> [] | _, _ -> raise (Invalid_argument "ctyp_unify") +let rec ctyp_suprema = function + | CT_int -> CT_int + | CT_bits d -> CT_bits d + | CT_bits64 (_, d) -> CT_bits d + | CT_int64 -> CT_int + | CT_unit -> CT_unit + | CT_bool -> CT_bool + | CT_real -> CT_real + | CT_bit -> CT_bit + | CT_tup ctyps -> CT_tup (List.map ctyp_suprema ctyps) + | CT_string -> CT_string + | CT_enum (id, ids) -> CT_enum (id, ids) + | CT_struct (id, ctors) -> CT_struct (id, List.map (fun (id, ctyp) -> (id, ctyp_suprema ctyp)) ctors) + | CT_variant (id, ctors) -> CT_variant (id, List.map (fun (id, ctyp) -> (id, ctyp_suprema ctyp)) ctors) + | CT_vector (d, ctyp) -> CT_vector (d, ctyp_suprema ctyp) + | CT_list ctyp -> CT_list (ctyp_suprema ctyp) + | CT_ref ctyp -> CT_ref (ctyp_suprema ctyp) + | CT_poly -> CT_poly + let rec unpoly = function | F_poly f -> unpoly f | F_call (call, fs) -> F_call (call, List.map unpoly fs) diff --git a/src/c_backend.ml b/src/c_backend.ml index e1fde4dc..adbc7cf6 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -875,7 +875,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = let ctor_ctyp = Bindings.find ctor (ctor_bindings ctors) in let ctor_c_id = if is_polymorphic ctor_ctyp then - let unification = ctyp_unify ctor_ctyp (apat_ctyp ctx apat) in + let unification = List.map ctyp_suprema (ctyp_unify ctor_ctyp (apat_ctyp ctx apat)) in ctor_c_id ^ "_" ^ Util.string_of_list "_" (fun ctyp -> Util.zencode_string (string_of_ctyp ctyp)) unification else ctor_c_id @@ -1441,7 +1441,7 @@ let fix_exception_block ?return:(return=None) ctx instrs = rewrite_exception [] instrs @ [ilabel end_block_label] | Some ctyp -> rewrite_exception [] instrs @ [ilabel end_block_label; iundefined ctyp] - + let rec map_try_block f (I_aux (instr, aux)) = let instr = match instr with | I_decl _ | I_reset _ | I_init _ | I_reinit _ -> instr @@ -1805,14 +1805,40 @@ let rec specialize_variants ctx = List.iter2 (fun cval ctyp -> prerr_endline (Pretty_print_sail.to_string (pp_cval cval) ^ " -> " ^ string_of_ctyp ctyp)) cvals ctyps; (* Work out how each call to a constructor in instantiated and add that to unifications *) - let unification = List.concat (List.map2 (fun cval ctyp -> ctyp_unify ctyp (cval_ctyp cval)) cvals ctyps) in + let unification = List.concat (List.map2 (fun cval ctyp -> List.map ctyp_suprema (ctyp_unify ctyp (cval_ctyp cval))) cvals ctyps) in let mono_id = append_id ctor_id ("_" ^ Util.string_of_list "_" (fun ctyp -> Util.zencode_string (string_of_ctyp ctyp)) unification) in - unifications := Bindings.add mono_id (mk_tuple (List.map cval_ctyp cvals)) !unifications; + unifications := Bindings.add mono_id (ctyp_suprema (mk_tuple (List.map cval_ctyp cvals))) !unifications; List.iter (fun ctyp -> prerr_endline (string_of_ctyp ctyp)) unification; prerr_endline (string_of_id mono_id); - I_aux (I_funcall (clexp, extern, mono_id, List.map (fun (frag, ctyp) -> (unpoly frag, ctyp)) cvals), aux) + (* We need to case each cval to it's ctyp_suprema in order to put it in the most general constructor *) + let casts = + let cast_to_suprema (frag, ctyp) = + let suprema = ctyp_suprema ctyp in + if ctyp_equal ctyp suprema then + [], (unpoly frag, ctyp), [] + else + let gs = gensym () in + [idecl suprema gs; + icopy (CL_id (gs, suprema)) (unpoly frag, ctyp)], + (F_id gs, suprema), + [iclear suprema gs] + in + List.map cast_to_suprema cvals + in + let setup = List.concat (List.map (fun (setup, _, _) -> setup) casts) in + let cvals = List.map (fun (_, cval, _) -> cval) casts in + let cleanup = List.concat (List.map (fun (_, _, cleanup) -> cleanup) casts) in + + let mk_funcall instr = + if List.length setup = 0 then + instr + else + iblock (setup @ [instr] @ cleanup) + in + + mk_funcall (I_aux (I_funcall (clexp, extern, mono_id, cvals), aux)) | instr -> instr in @@ -1841,7 +1867,7 @@ let rec specialize_variants ctx = let remove_poly (I_aux (instr, aux)) = match instr with | I_copy (clexp, (frag, ctyp)) when is_polymorphic ctyp -> - I_aux (I_copy (clexp, (frag, clexp_ctyp clexp)), aux) + I_aux (I_copy (clexp, (frag, ctyp_suprema (clexp_ctyp clexp))), aux) | instr -> I_aux (instr, aux) in let cdef = cdef_map_instr remove_poly cdef in @@ -2023,10 +2049,10 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = else if is_stack_ctyp lctyp then string (Printf.sprintf " %s = CONVERT_OF(%s, %s)(%s);" - (sgen_clexp_pure clexp) (sgen_ctyp_name lctyp) (sgen_ctyp_name rctyp) (sgen_cval cval)) + (sgen_clexp_pure clexp) (sgen_ctyp_name lctyp) (sgen_ctyp_name rctyp) (sgen_cval_param cval)) else string (Printf.sprintf " CONVERT_OF(%s, %s)(%s, %s);" - (sgen_ctyp_name lctyp) (sgen_ctyp_name rctyp) (sgen_clexp clexp) (sgen_cval cval)) + (sgen_ctyp_name lctyp) (sgen_ctyp_name rctyp) (sgen_clexp clexp) (sgen_cval_param cval)) | I_jump (cval, label) -> string (Printf.sprintf " if (%s) goto %s;" (sgen_cval cval) label) |
