diff options
| author | Alasdair Armstrong | 2018-08-06 19:03:47 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2018-08-06 19:03:47 +0100 |
| commit | 6ff2e336cbf6ada9131f060bde6576b07bfe707b (patch) | |
| tree | 14a4b98ce7890a3e04ad389e07c841e0f9106d8a /src | |
| parent | e8213679de49e1fb14582e14ee0ec604732babef (diff) | |
More fixes for polymorphic data types
Diffstat (limited to 'src')
| -rw-r--r-- | src/anf.ml | 28 | ||||
| -rw-r--r-- | src/anf.mli | 4 | ||||
| -rw-r--r-- | src/bytecode_util.ml | 49 | ||||
| -rw-r--r-- | src/c_backend.ml | 99 | ||||
| -rw-r--r-- | src/reporting_basic.ml | 3 |
5 files changed, 111 insertions, 72 deletions
@@ -119,8 +119,8 @@ and 'a apat_aux = | AP_global of id * 'a | AP_app of id * 'a apat | AP_cons of 'a apat * 'a apat - | AP_nil - | AP_wild + | AP_nil of 'a + | AP_wild of 'a and 'a aval = | AV_lit of lit * 'a @@ -133,7 +133,7 @@ and 'a aval = | AV_C_fragment of fragment * 'a (* Renaming variables in ANF expressions *) - + let rec apat_bindings (AP_aux (apat_aux, _, _)) = match apat_aux with | AP_tup apats -> List.fold_left IdSet.union IdSet.empty (List.map apat_bindings apats) @@ -141,8 +141,8 @@ let rec apat_bindings (AP_aux (apat_aux, _, _)) = | AP_global (id, _) -> IdSet.empty | AP_app (id, apat) -> apat_bindings apat | AP_cons (apat1, apat2) -> IdSet.union (apat_bindings apat1) (apat_bindings apat2) - | AP_nil -> IdSet.empty - | AP_wild -> IdSet.empty + | AP_nil _ -> IdSet.empty + | AP_wild _ -> IdSet.empty let rec apat_types (AP_aux (apat_aux, _, _)) = let merge id b1 b2 = @@ -158,8 +158,8 @@ let rec apat_types (AP_aux (apat_aux, _, _)) = | AP_global (id, _) -> Bindings.empty | AP_app (id, apat) -> apat_types apat | AP_cons (apat1, apat2) -> (Bindings.merge merge) (apat_types apat1) (apat_types apat2) - | AP_nil -> Bindings.empty - | AP_wild -> Bindings.empty + | AP_nil _ -> Bindings.empty + | AP_wild _ -> Bindings.empty let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) = let apat_aux = match apat_aux with @@ -169,8 +169,8 @@ let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) = | AP_global (id, typ) -> AP_global (id, typ) | AP_app (ctor, apat) -> AP_app (ctor, apat_rename from_id to_id apat) | AP_cons (apat1, apat2) -> AP_cons (apat_rename from_id to_id apat1, apat_rename from_id to_id apat2) - | AP_nil -> AP_nil - | AP_wild -> AP_wild + | AP_nil typ -> AP_nil typ + | AP_wild typ -> AP_wild typ in AP_aux (apat_aux, env, l) @@ -401,12 +401,12 @@ let rec pp_aexp (AE_aux (aexp, _, _)) = and pp_apat (AP_aux (apat_aux, _, _)) = match apat_aux with - | AP_wild -> string "_" + | AP_wild _ -> string "_" | AP_id (id, typ) -> pp_annot typ (pp_id id) | AP_global (id, _) -> pp_id id | AP_tup apats -> parens (separate_map (comma ^^ space) pp_apat apats) | AP_app (id, apat) -> pp_id id ^^ parens (pp_apat apat) - | AP_nil -> string "[||]" + | AP_nil _ -> string "[||]" | AP_cons (hd_apat, tl_apat) -> pp_apat hd_apat ^^ string " :: " ^^ pp_apat tl_apat and pp_cases cases = surround 2 0 lbrace (separate_map (comma ^^ hardline) pp_case cases) rbrace @@ -458,19 +458,19 @@ let rec anf_pat ?global:(global=false) (P_aux (p_aux, annot) as pat) = match p_aux with | P_id id when global -> mk_apat (AP_global (id, pat_typ_of pat)) | P_id id -> mk_apat (AP_id (id, pat_typ_of pat)) - | P_wild -> mk_apat AP_wild + | P_wild -> mk_apat (AP_wild (pat_typ_of pat)) | P_tup pats -> mk_apat (AP_tup (List.map (fun pat -> anf_pat ~global:global pat) pats)) | P_app (id, [pat]) -> mk_apat (AP_app (id, anf_pat ~global:global pat)) | P_app (id, pats) -> mk_apat (AP_app (id, mk_apat (AP_tup (List.map (fun pat -> anf_pat ~global:global pat) pats)))) | P_typ (_, pat) -> anf_pat ~global:global pat | P_var (pat, _) -> anf_pat ~global:global pat | P_cons (hd_pat, tl_pat) -> mk_apat (AP_cons (anf_pat ~global:global hd_pat, anf_pat ~global:global tl_pat)) - | P_list pats -> List.fold_right (fun pat apat -> mk_apat (AP_cons (anf_pat ~global:global pat, apat))) pats (mk_apat AP_nil) + | P_list pats -> List.fold_right (fun pat apat -> mk_apat (AP_cons (anf_pat ~global:global pat, apat))) pats (mk_apat (AP_nil (pat_typ_of pat))) | _ -> anf_error ~loc:(fst annot) ("Could not convert pattern to ANF: " ^ string_of_pat pat) let rec apat_globals (AP_aux (aux, _, _)) = match aux with - | AP_nil | AP_wild | AP_id _ -> [] + | AP_nil _ | AP_wild _ | AP_id _ -> [] | AP_global (id, typ) -> [(id, typ)] | AP_tup apats -> List.concat (List.map apat_globals apats) | AP_app (_, apat) -> apat_globals apat diff --git a/src/anf.mli b/src/anf.mli index 8ba83ccd..56c3b520 100644 --- a/src/anf.mli +++ b/src/anf.mli @@ -85,8 +85,8 @@ and 'a apat_aux = | AP_global of id * 'a | AP_app of id * 'a apat | AP_cons of 'a apat * 'a apat - | AP_nil - | AP_wild + | AP_nil of 'a + | AP_wild of 'a and 'a aval = | AV_lit of lit * 'a diff --git a/src/bytecode_util.ml b/src/bytecode_util.ml index c3e61956..27086858 100644 --- a/src/bytecode_util.ml +++ b/src/bytecode_util.ml @@ -129,6 +129,7 @@ let rec frag_rename from_id to_id = function | F_unary (op, f) -> F_unary (op, frag_rename from_id to_id f) | F_field (f, field) -> F_field (frag_rename from_id to_id f, field) | F_raw raw -> F_raw raw + | F_poly f -> F_poly (frag_rename from_id to_id f) (**************************************************************************) (* 1. Instruction pretty printer *) @@ -192,6 +193,52 @@ and string_of_ctyp = function | CT_ref ctyp -> "ref(" ^ string_of_ctyp ctyp ^ ")" | CT_poly -> "*" +let rec ctyp_equal ctyp1 ctyp2 = + match ctyp1, ctyp2 with + | CT_int, CT_int -> true + | CT_bits d1, CT_bits d2 -> d1 = d2 + | CT_bits64 (m1, d1), CT_bits64 (m2, d2) -> m1 = m2 && d1 = d2 + | CT_bit, CT_bit -> true + | CT_int64, CT_int64 -> true + | CT_unit, CT_unit -> true + | CT_bool, CT_bool -> true + | CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0 + | CT_enum (id1, _), CT_enum (id2, _) -> Id.compare id1 id2 = 0 + | CT_variant (id1, _), CT_variant (id2, _) -> Id.compare id1 id2 = 0 + | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 -> + List.for_all2 ctyp_equal ctyps1 ctyps2 + | CT_string, CT_string -> true + | CT_real, CT_real -> true + | CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> d1 = d2 && ctyp_equal ctyp1 ctyp2 + | CT_list ctyp1, CT_list ctyp2 -> ctyp_equal ctyp1 ctyp2 + | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_equal ctyp1 ctyp2 + | _, _ -> false + +let rec ctyp_unify ctyp1 ctyp2 = + match ctyp1, ctyp2 with + | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 -> + List.concat (List.map2 ctyp_unify ctyps1 ctyps2) + + | CT_vector (b1, ctyp1), CT_vector (b2, ctyp2) when b1 = b2 -> + ctyp_unify ctyp1 ctyp2 + + | CT_list ctyp1, CT_list ctyp2 -> ctyp_unify ctyp1 ctyp2 + + | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_unify ctyp1 ctyp2 + + | CT_poly, _ -> [ctyp2] + + | _, _ when ctyp_equal ctyp1 ctyp2 -> [] + | _, _ -> raise (Invalid_argument "ctyp_unify") + +let rec unpoly = function + | F_poly f -> unpoly f + | F_call (call, fs) -> F_call (call, List.map unpoly fs) + | F_field (f, field) -> F_field (unpoly f, field) + | F_op (f1, op, f2) -> F_op (unpoly f1, op, unpoly f2) + | F_unary (op, f) -> F_unary (op, unpoly f) + | f -> f + let rec is_polymorphic = function | CT_int | CT_int64 | CT_bits _ | CT_bits64 _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string -> false | CT_tup ctyps -> List.exists is_polymorphic ctyps @@ -353,7 +400,7 @@ type dep_graph = NS.t NM.t let rec fragment_deps = function | F_id id | F_ref id -> NS.singleton (G_id id) | F_lit _ -> NS.empty - | F_field (frag, _) | F_unary (_, frag) -> fragment_deps frag + | F_field (frag, _) | F_unary (_, frag) | F_poly frag -> fragment_deps frag | F_call (_, frags) -> List.fold_left NS.union NS.empty (List.map fragment_deps frags) | F_op (frag1, _, frag2) -> NS.union (fragment_deps frag1) (fragment_deps frag2) | F_current_exception -> NS.empty diff --git a/src/c_backend.ml b/src/c_backend.ml index f5c4a7fa..22527f4c 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -113,27 +113,6 @@ let initial_ctx env = optimize_z3 = true; } -let rec ctyp_equal ctyp1 ctyp2 = - match ctyp1, ctyp2 with - | CT_int, CT_int -> true - | CT_bits d1, CT_bits d2 -> d1 = d2 - | CT_bits64 (m1, d1), CT_bits64 (m2, d2) -> m1 = m2 && d1 = d2 - | CT_bit, CT_bit -> true - | CT_int64, CT_int64 -> true - | CT_unit, CT_unit -> true - | CT_bool, CT_bool -> true - | CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0 - | CT_enum (id1, _), CT_enum (id2, _) -> Id.compare id1 id2 = 0 - | CT_variant (id1, _), CT_variant (id2, _) -> Id.compare id1 id2 = 0 - | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 -> - List.for_all2 ctyp_equal ctyps1 ctyps2 - | CT_string, CT_string -> true - | CT_real, CT_real -> true - | CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> d1 = d2 && ctyp_equal ctyp1 ctyp2 - | CT_list ctyp1, CT_list ctyp2 -> ctyp_equal ctyp1 ctyp2 - | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_equal ctyp1 ctyp2 - | _, _ -> false - (** Convert a sail type into a C-type **) let rec ctyp_of_typ ctx typ = let Typ_aux (typ_aux, l) as typ = Env.expand_synonyms ctx.tc_env typ in @@ -843,6 +822,13 @@ let compile_funcall l ctx id args typ = end, !cleanup +let rec apat_ctyp ctx (AP_aux (apat, _, _)) = + match apat with + | AP_tup apats -> CT_tup (List.map (apat_ctyp ctx) apats) + | AP_global (_, typ) -> ctyp_of_typ ctx typ + | AP_cons (apat, _) -> CT_list (apat_ctyp ctx apat) + | AP_wild typ | AP_nil typ | AP_id (_, typ) -> ctyp_of_typ ctx typ + let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = let ctx = { ctx with local_env = env } in match apat_aux, cval with @@ -887,15 +873,23 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = | CT_variant (_, ctors) -> let ctor_c_id = string_of_id ctor in let ctor_ctyp = Bindings.find ctor (ctor_bindings ctors) in + let ctor_c_id = + if is_polymorphic ctor_ctyp then + let unification = ctyp_unify ctor_ctyp (apat_ctyp ctx apat) in + ctor_c_id ^ "_" ^ Util.string_of_list "_" (fun ctyp -> Util.zencode_string (string_of_ctyp ctyp)) unification + else + ctor_c_id + in let instrs, cleanup, ctx = compile_match ctx apat ((F_field (frag, Util.zencode_string ctor_c_id), ctor_ctyp)) case_label in - [ijump (F_op (F_field (frag, "kind"), "!=", F_lit (V_ctor_kind ctor_c_id)), CT_bool) case_label] - @ instrs, + [icomment (string_of_ctyp (apat_ctyp ctx apat)); ijump (F_op (F_field (frag, "kind"), "!=", F_lit (V_ctor_kind ctor_c_id)), CT_bool) case_label] + @ instrs + @ [icomment (string_of_ctyp ctor_ctyp)], cleanup, ctx | _ -> failwith "AP_app constructor with non-variant type" end - | AP_wild, _ -> [], [], ctx + | AP_wild _, _ -> [], [], ctx | AP_cons (hd_apat, tl_apat), (frag, CT_list ctyp) -> let hd_setup, hd_cleanup, ctx = compile_match ctx hd_apat (F_field (F_unary ("*", frag), "hd"), ctyp) case_label in @@ -904,7 +898,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = | AP_cons _, (_, _) -> c_error "Tried to pattern match cons on non list type" - | AP_nil, (frag, _) -> [ijump (F_op (frag, "!=", F_lit V_null), CT_bool) case_label], [], ctx + | AP_nil _, (frag, _) -> [ijump (F_op (frag, "!=", F_lit V_null), CT_bool) case_label], [], ctx let unit_fragment = (F_lit V_unit, CT_unit) @@ -1792,24 +1786,10 @@ let flatten_instrs ctx = | cdef -> [cdef] -let rec ctyp_unify ctyp1 ctyp2 = - match ctyp1, ctyp2 with - | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 -> - List.concat (List.map2 ctyp_unify ctyps1 ctyps2) - - | CT_vector (b1, ctyp1), CT_vector (b2, ctyp2) when b1 = b2 -> - ctyp_unify ctyp1 ctyp2 - - | CT_list ctyp1, CT_list ctyp2 -> ctyp_unify ctyp1 ctyp2 - - | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_unify ctyp1 ctyp2 - - | CT_poly, _ -> [ctyp2] +let rec specialize_variants ctx = - | _, _ when ctyp_equal ctyp1 ctyp2 -> [] - | _, _ -> c_error "Unification failed" + let unifications = ref (Bindings.empty) in -let rec specialize_variants ctx = let specialize_constructor ctx ctor_id ctyp = let ctyps = match ctyp with | CT_tup ctyps -> ctyps @@ -1818,10 +1798,17 @@ let rec specialize_variants ctx = function | I_aux (I_funcall (clexp, extern, id, cvals), aux) as instr when Id.compare id ctor_id = 0 -> assert (List.length ctyps = List.length cvals); - List.iter2 (fun cval ctyp -> print_endline (Pretty_print_sail.to_string (pp_cval cval) ^ " -> " ^ string_of_ctyp ctyp)) cvals ctyps; + List.iter2 (fun cval ctyp -> prerr_endline (Pretty_print_sail.to_string (pp_cval cval) ^ " -> " ^ string_of_ctyp ctyp)) cvals ctyps; + + (* Work out how each call to a constructor in instantiated and add that to unifications *) let unification = List.concat (List.map2 (fun cval ctyp -> ctyp_unify ctyp (cval_ctyp cval)) cvals ctyps) in - List.iter (fun ctyp -> print_endline (string_of_ctyp ctyp)) unification; - instr + let mono_id = append_id ctor_id ("_" ^ Util.string_of_list "_" (fun ctyp -> Util.zencode_string (string_of_ctyp ctyp)) unification) in + unifications := Bindings.add mono_id (CT_tup (List.map cval_ctyp cvals)) !unifications; + + List.iter (fun ctyp -> prerr_endline (string_of_ctyp ctyp)) unification; + prerr_endline (string_of_id mono_id); + + I_aux (I_funcall (clexp, extern, mono_id, List.map (fun (frag, ctyp) -> (unpoly frag, ctyp)) cvals), aux) | instr -> instr in @@ -1829,18 +1816,24 @@ let rec specialize_variants ctx = | (CDEF_type (CTD_variant (var_id, ctors)) as cdef) :: cdefs -> let polymorphic_ctors = List.filter (fun (_, ctyp) -> is_polymorphic ctyp) ctors in List.iter (fun (id, ctyp) -> prerr_endline (Printf.sprintf "%s : %s" (string_of_id id) (string_of_ctyp ctyp))) polymorphic_ctors; - print_endline "=== CONSTRUCTORS ==="; + prerr_endline "=== CONSTRUCTORS ==="; + let cdefs = List.fold_left (fun cdefs (ctor_id, ctyp) -> List.map (cdef_map_instr (specialize_constructor ctx ctor_id ctyp)) cdefs) cdefs polymorphic_ctors in - cdef :: specialize_variants ctx cdefs + + let ctx = { ctx with variants = Bindings.add var_id !unifications ctx.variants } in + + let cdefs, ctx = specialize_variants ctx cdefs in + CDEF_type (CTD_variant (var_id, (Bindings.bindings !unifications))) :: cdefs, ctx | cdef :: cdefs -> - cdef :: specialize_variants ctx cdefs + let cdefs, ctx = specialize_variants ctx cdefs in + cdef :: cdefs, ctx - | [] -> [] + | [] -> [], ctx (* (* When this optimization fires we know we have bytecode of the form @@ -1910,7 +1903,6 @@ let optimize ctx cdefs = let nothing cdefs = cdefs in cdefs |> (if !optimize_hoist_allocations then concatMap (hoist_allocations ctx) else nothing) - |> specialize_variants ctx (* |> (if !optimize_struct_updates then concatMap (fix_struct_updates ctx) else nothing) *) (**************************************************************************) @@ -2739,9 +2731,9 @@ let codegen_def ctx def = let vectors = List.filter is_ct_vector (cdef_ctyps ctx def) in let vectors = List.map (fun ctyp -> codegen_vector ctx (unvector ctyp)) vectors in (* prerr_endline (Pretty_print_sail.to_string (pp_cdef def)); *) - concat tups + concat vectors ^^ concat lists - ^^ concat vectors + ^^ concat tups ^^ codegen_def' ctx def let is_cdef_startup = function @@ -2866,10 +2858,9 @@ let compile_ast ctx (Defs defs) = let ctx = { ctx with tc_env = snd (Type_error.check ctx.tc_env (Defs [assert_vs; exit_vs])) } in let chunks, ctx = List.fold_left (fun (chunks, ctx) def -> let defs, ctx = compile_def ctx def in defs :: chunks, ctx) ([], ctx) defs in let cdefs = List.concat (List.rev chunks) in - - print_endline (Pretty_print_sail.to_string (separate_map (hardline ^^ hardline) pp_cdef cdefs)); - + let cdefs, ctx = specialize_variants ctx cdefs in let cdefs = optimize ctx cdefs in + prerr_endline (Pretty_print_sail.to_string (separate_map (hardline ^^ hardline) pp_cdef cdefs)); (* let cdefs = if !opt_trace then List.map (instrument_tracing ctx) cdefs else cdefs in *) diff --git a/src/reporting_basic.ml b/src/reporting_basic.ml index 65acd4ac..985136c4 100644 --- a/src/reporting_basic.ml +++ b/src/reporting_basic.ml @@ -120,7 +120,8 @@ let print_code1 ff fname lnum1 cnum1 cnum2 = Util.(Str.string_before (Str.string_after line cnum1) (cnum2 - cnum1) |> red_bg |> clear) (Str.string_after line cnum2); close_in in_chan - with e -> (close_in_noerr in_chan; print_endline (Printexc.to_string e)) + with e -> (close_in_noerr in_chan; + prerr_endline (Printf.sprintf "print_code1: %s %d %d %d %s" fname lnum1 cnum1 cnum2 (Printexc.to_string e))) end with _ -> () |
