summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlasdair Armstrong2018-08-06 19:03:47 +0100
committerAlasdair Armstrong2018-08-06 19:03:47 +0100
commit6ff2e336cbf6ada9131f060bde6576b07bfe707b (patch)
tree14a4b98ce7890a3e04ad389e07c841e0f9106d8a /src
parente8213679de49e1fb14582e14ee0ec604732babef (diff)
More fixes for polymorphic data types
Diffstat (limited to 'src')
-rw-r--r--src/anf.ml28
-rw-r--r--src/anf.mli4
-rw-r--r--src/bytecode_util.ml49
-rw-r--r--src/c_backend.ml99
-rw-r--r--src/reporting_basic.ml3
5 files changed, 111 insertions, 72 deletions
diff --git a/src/anf.ml b/src/anf.ml
index 97565b2b..c686754f 100644
--- a/src/anf.ml
+++ b/src/anf.ml
@@ -119,8 +119,8 @@ and 'a apat_aux =
| AP_global of id * 'a
| AP_app of id * 'a apat
| AP_cons of 'a apat * 'a apat
- | AP_nil
- | AP_wild
+ | AP_nil of 'a
+ | AP_wild of 'a
and 'a aval =
| AV_lit of lit * 'a
@@ -133,7 +133,7 @@ and 'a aval =
| AV_C_fragment of fragment * 'a
(* Renaming variables in ANF expressions *)
-
+
let rec apat_bindings (AP_aux (apat_aux, _, _)) =
match apat_aux with
| AP_tup apats -> List.fold_left IdSet.union IdSet.empty (List.map apat_bindings apats)
@@ -141,8 +141,8 @@ let rec apat_bindings (AP_aux (apat_aux, _, _)) =
| AP_global (id, _) -> IdSet.empty
| AP_app (id, apat) -> apat_bindings apat
| AP_cons (apat1, apat2) -> IdSet.union (apat_bindings apat1) (apat_bindings apat2)
- | AP_nil -> IdSet.empty
- | AP_wild -> IdSet.empty
+ | AP_nil _ -> IdSet.empty
+ | AP_wild _ -> IdSet.empty
let rec apat_types (AP_aux (apat_aux, _, _)) =
let merge id b1 b2 =
@@ -158,8 +158,8 @@ let rec apat_types (AP_aux (apat_aux, _, _)) =
| AP_global (id, _) -> Bindings.empty
| AP_app (id, apat) -> apat_types apat
| AP_cons (apat1, apat2) -> (Bindings.merge merge) (apat_types apat1) (apat_types apat2)
- | AP_nil -> Bindings.empty
- | AP_wild -> Bindings.empty
+ | AP_nil _ -> Bindings.empty
+ | AP_wild _ -> Bindings.empty
let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) =
let apat_aux = match apat_aux with
@@ -169,8 +169,8 @@ let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) =
| AP_global (id, typ) -> AP_global (id, typ)
| AP_app (ctor, apat) -> AP_app (ctor, apat_rename from_id to_id apat)
| AP_cons (apat1, apat2) -> AP_cons (apat_rename from_id to_id apat1, apat_rename from_id to_id apat2)
- | AP_nil -> AP_nil
- | AP_wild -> AP_wild
+ | AP_nil typ -> AP_nil typ
+ | AP_wild typ -> AP_wild typ
in
AP_aux (apat_aux, env, l)
@@ -401,12 +401,12 @@ let rec pp_aexp (AE_aux (aexp, _, _)) =
and pp_apat (AP_aux (apat_aux, _, _)) =
match apat_aux with
- | AP_wild -> string "_"
+ | AP_wild _ -> string "_"
| AP_id (id, typ) -> pp_annot typ (pp_id id)
| AP_global (id, _) -> pp_id id
| AP_tup apats -> parens (separate_map (comma ^^ space) pp_apat apats)
| AP_app (id, apat) -> pp_id id ^^ parens (pp_apat apat)
- | AP_nil -> string "[||]"
+ | AP_nil _ -> string "[||]"
| AP_cons (hd_apat, tl_apat) -> pp_apat hd_apat ^^ string " :: " ^^ pp_apat tl_apat
and pp_cases cases = surround 2 0 lbrace (separate_map (comma ^^ hardline) pp_case cases) rbrace
@@ -458,19 +458,19 @@ let rec anf_pat ?global:(global=false) (P_aux (p_aux, annot) as pat) =
match p_aux with
| P_id id when global -> mk_apat (AP_global (id, pat_typ_of pat))
| P_id id -> mk_apat (AP_id (id, pat_typ_of pat))
- | P_wild -> mk_apat AP_wild
+ | P_wild -> mk_apat (AP_wild (pat_typ_of pat))
| P_tup pats -> mk_apat (AP_tup (List.map (fun pat -> anf_pat ~global:global pat) pats))
| P_app (id, [pat]) -> mk_apat (AP_app (id, anf_pat ~global:global pat))
| P_app (id, pats) -> mk_apat (AP_app (id, mk_apat (AP_tup (List.map (fun pat -> anf_pat ~global:global pat) pats))))
| P_typ (_, pat) -> anf_pat ~global:global pat
| P_var (pat, _) -> anf_pat ~global:global pat
| P_cons (hd_pat, tl_pat) -> mk_apat (AP_cons (anf_pat ~global:global hd_pat, anf_pat ~global:global tl_pat))
- | P_list pats -> List.fold_right (fun pat apat -> mk_apat (AP_cons (anf_pat ~global:global pat, apat))) pats (mk_apat AP_nil)
+ | P_list pats -> List.fold_right (fun pat apat -> mk_apat (AP_cons (anf_pat ~global:global pat, apat))) pats (mk_apat (AP_nil (pat_typ_of pat)))
| _ -> anf_error ~loc:(fst annot) ("Could not convert pattern to ANF: " ^ string_of_pat pat)
let rec apat_globals (AP_aux (aux, _, _)) =
match aux with
- | AP_nil | AP_wild | AP_id _ -> []
+ | AP_nil _ | AP_wild _ | AP_id _ -> []
| AP_global (id, typ) -> [(id, typ)]
| AP_tup apats -> List.concat (List.map apat_globals apats)
| AP_app (_, apat) -> apat_globals apat
diff --git a/src/anf.mli b/src/anf.mli
index 8ba83ccd..56c3b520 100644
--- a/src/anf.mli
+++ b/src/anf.mli
@@ -85,8 +85,8 @@ and 'a apat_aux =
| AP_global of id * 'a
| AP_app of id * 'a apat
| AP_cons of 'a apat * 'a apat
- | AP_nil
- | AP_wild
+ | AP_nil of 'a
+ | AP_wild of 'a
and 'a aval =
| AV_lit of lit * 'a
diff --git a/src/bytecode_util.ml b/src/bytecode_util.ml
index c3e61956..27086858 100644
--- a/src/bytecode_util.ml
+++ b/src/bytecode_util.ml
@@ -129,6 +129,7 @@ let rec frag_rename from_id to_id = function
| F_unary (op, f) -> F_unary (op, frag_rename from_id to_id f)
| F_field (f, field) -> F_field (frag_rename from_id to_id f, field)
| F_raw raw -> F_raw raw
+ | F_poly f -> F_poly (frag_rename from_id to_id f)
(**************************************************************************)
(* 1. Instruction pretty printer *)
@@ -192,6 +193,52 @@ and string_of_ctyp = function
| CT_ref ctyp -> "ref(" ^ string_of_ctyp ctyp ^ ")"
| CT_poly -> "*"
+let rec ctyp_equal ctyp1 ctyp2 =
+ match ctyp1, ctyp2 with
+ | CT_int, CT_int -> true
+ | CT_bits d1, CT_bits d2 -> d1 = d2
+ | CT_bits64 (m1, d1), CT_bits64 (m2, d2) -> m1 = m2 && d1 = d2
+ | CT_bit, CT_bit -> true
+ | CT_int64, CT_int64 -> true
+ | CT_unit, CT_unit -> true
+ | CT_bool, CT_bool -> true
+ | CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0
+ | CT_enum (id1, _), CT_enum (id2, _) -> Id.compare id1 id2 = 0
+ | CT_variant (id1, _), CT_variant (id2, _) -> Id.compare id1 id2 = 0
+ | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 ->
+ List.for_all2 ctyp_equal ctyps1 ctyps2
+ | CT_string, CT_string -> true
+ | CT_real, CT_real -> true
+ | CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> d1 = d2 && ctyp_equal ctyp1 ctyp2
+ | CT_list ctyp1, CT_list ctyp2 -> ctyp_equal ctyp1 ctyp2
+ | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_equal ctyp1 ctyp2
+ | _, _ -> false
+
+let rec ctyp_unify ctyp1 ctyp2 =
+ match ctyp1, ctyp2 with
+ | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 ->
+ List.concat (List.map2 ctyp_unify ctyps1 ctyps2)
+
+ | CT_vector (b1, ctyp1), CT_vector (b2, ctyp2) when b1 = b2 ->
+ ctyp_unify ctyp1 ctyp2
+
+ | CT_list ctyp1, CT_list ctyp2 -> ctyp_unify ctyp1 ctyp2
+
+ | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_unify ctyp1 ctyp2
+
+ | CT_poly, _ -> [ctyp2]
+
+ | _, _ when ctyp_equal ctyp1 ctyp2 -> []
+ | _, _ -> raise (Invalid_argument "ctyp_unify")
+
+let rec unpoly = function
+ | F_poly f -> unpoly f
+ | F_call (call, fs) -> F_call (call, List.map unpoly fs)
+ | F_field (f, field) -> F_field (unpoly f, field)
+ | F_op (f1, op, f2) -> F_op (unpoly f1, op, unpoly f2)
+ | F_unary (op, f) -> F_unary (op, unpoly f)
+ | f -> f
+
let rec is_polymorphic = function
| CT_int | CT_int64 | CT_bits _ | CT_bits64 _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string -> false
| CT_tup ctyps -> List.exists is_polymorphic ctyps
@@ -353,7 +400,7 @@ type dep_graph = NS.t NM.t
let rec fragment_deps = function
| F_id id | F_ref id -> NS.singleton (G_id id)
| F_lit _ -> NS.empty
- | F_field (frag, _) | F_unary (_, frag) -> fragment_deps frag
+ | F_field (frag, _) | F_unary (_, frag) | F_poly frag -> fragment_deps frag
| F_call (_, frags) -> List.fold_left NS.union NS.empty (List.map fragment_deps frags)
| F_op (frag1, _, frag2) -> NS.union (fragment_deps frag1) (fragment_deps frag2)
| F_current_exception -> NS.empty
diff --git a/src/c_backend.ml b/src/c_backend.ml
index f5c4a7fa..22527f4c 100644
--- a/src/c_backend.ml
+++ b/src/c_backend.ml
@@ -113,27 +113,6 @@ let initial_ctx env =
optimize_z3 = true;
}
-let rec ctyp_equal ctyp1 ctyp2 =
- match ctyp1, ctyp2 with
- | CT_int, CT_int -> true
- | CT_bits d1, CT_bits d2 -> d1 = d2
- | CT_bits64 (m1, d1), CT_bits64 (m2, d2) -> m1 = m2 && d1 = d2
- | CT_bit, CT_bit -> true
- | CT_int64, CT_int64 -> true
- | CT_unit, CT_unit -> true
- | CT_bool, CT_bool -> true
- | CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0
- | CT_enum (id1, _), CT_enum (id2, _) -> Id.compare id1 id2 = 0
- | CT_variant (id1, _), CT_variant (id2, _) -> Id.compare id1 id2 = 0
- | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 ->
- List.for_all2 ctyp_equal ctyps1 ctyps2
- | CT_string, CT_string -> true
- | CT_real, CT_real -> true
- | CT_vector (d1, ctyp1), CT_vector (d2, ctyp2) -> d1 = d2 && ctyp_equal ctyp1 ctyp2
- | CT_list ctyp1, CT_list ctyp2 -> ctyp_equal ctyp1 ctyp2
- | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_equal ctyp1 ctyp2
- | _, _ -> false
-
(** Convert a sail type into a C-type **)
let rec ctyp_of_typ ctx typ =
let Typ_aux (typ_aux, l) as typ = Env.expand_synonyms ctx.tc_env typ in
@@ -843,6 +822,13 @@ let compile_funcall l ctx id args typ =
end,
!cleanup
+let rec apat_ctyp ctx (AP_aux (apat, _, _)) =
+ match apat with
+ | AP_tup apats -> CT_tup (List.map (apat_ctyp ctx) apats)
+ | AP_global (_, typ) -> ctyp_of_typ ctx typ
+ | AP_cons (apat, _) -> CT_list (apat_ctyp ctx apat)
+ | AP_wild typ | AP_nil typ | AP_id (_, typ) -> ctyp_of_typ ctx typ
+
let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label =
let ctx = { ctx with local_env = env } in
match apat_aux, cval with
@@ -887,15 +873,23 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label =
| CT_variant (_, ctors) ->
let ctor_c_id = string_of_id ctor in
let ctor_ctyp = Bindings.find ctor (ctor_bindings ctors) in
+ let ctor_c_id =
+ if is_polymorphic ctor_ctyp then
+ let unification = ctyp_unify ctor_ctyp (apat_ctyp ctx apat) in
+ ctor_c_id ^ "_" ^ Util.string_of_list "_" (fun ctyp -> Util.zencode_string (string_of_ctyp ctyp)) unification
+ else
+ ctor_c_id
+ in
let instrs, cleanup, ctx = compile_match ctx apat ((F_field (frag, Util.zencode_string ctor_c_id), ctor_ctyp)) case_label in
- [ijump (F_op (F_field (frag, "kind"), "!=", F_lit (V_ctor_kind ctor_c_id)), CT_bool) case_label]
- @ instrs,
+ [icomment (string_of_ctyp (apat_ctyp ctx apat)); ijump (F_op (F_field (frag, "kind"), "!=", F_lit (V_ctor_kind ctor_c_id)), CT_bool) case_label]
+ @ instrs
+ @ [icomment (string_of_ctyp ctor_ctyp)],
cleanup,
ctx
| _ -> failwith "AP_app constructor with non-variant type"
end
- | AP_wild, _ -> [], [], ctx
+ | AP_wild _, _ -> [], [], ctx
| AP_cons (hd_apat, tl_apat), (frag, CT_list ctyp) ->
let hd_setup, hd_cleanup, ctx = compile_match ctx hd_apat (F_field (F_unary ("*", frag), "hd"), ctyp) case_label in
@@ -904,7 +898,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label =
| AP_cons _, (_, _) -> c_error "Tried to pattern match cons on non list type"
- | AP_nil, (frag, _) -> [ijump (F_op (frag, "!=", F_lit V_null), CT_bool) case_label], [], ctx
+ | AP_nil _, (frag, _) -> [ijump (F_op (frag, "!=", F_lit V_null), CT_bool) case_label], [], ctx
let unit_fragment = (F_lit V_unit, CT_unit)
@@ -1792,24 +1786,10 @@ let flatten_instrs ctx =
| cdef -> [cdef]
-let rec ctyp_unify ctyp1 ctyp2 =
- match ctyp1, ctyp2 with
- | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 ->
- List.concat (List.map2 ctyp_unify ctyps1 ctyps2)
-
- | CT_vector (b1, ctyp1), CT_vector (b2, ctyp2) when b1 = b2 ->
- ctyp_unify ctyp1 ctyp2
-
- | CT_list ctyp1, CT_list ctyp2 -> ctyp_unify ctyp1 ctyp2
-
- | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_unify ctyp1 ctyp2
-
- | CT_poly, _ -> [ctyp2]
+let rec specialize_variants ctx =
- | _, _ when ctyp_equal ctyp1 ctyp2 -> []
- | _, _ -> c_error "Unification failed"
+ let unifications = ref (Bindings.empty) in
-let rec specialize_variants ctx =
let specialize_constructor ctx ctor_id ctyp =
let ctyps = match ctyp with
| CT_tup ctyps -> ctyps
@@ -1818,10 +1798,17 @@ let rec specialize_variants ctx =
function
| I_aux (I_funcall (clexp, extern, id, cvals), aux) as instr when Id.compare id ctor_id = 0 ->
assert (List.length ctyps = List.length cvals);
- List.iter2 (fun cval ctyp -> print_endline (Pretty_print_sail.to_string (pp_cval cval) ^ " -> " ^ string_of_ctyp ctyp)) cvals ctyps;
+ List.iter2 (fun cval ctyp -> prerr_endline (Pretty_print_sail.to_string (pp_cval cval) ^ " -> " ^ string_of_ctyp ctyp)) cvals ctyps;
+
+ (* Work out how each call to a constructor in instantiated and add that to unifications *)
let unification = List.concat (List.map2 (fun cval ctyp -> ctyp_unify ctyp (cval_ctyp cval)) cvals ctyps) in
- List.iter (fun ctyp -> print_endline (string_of_ctyp ctyp)) unification;
- instr
+ let mono_id = append_id ctor_id ("_" ^ Util.string_of_list "_" (fun ctyp -> Util.zencode_string (string_of_ctyp ctyp)) unification) in
+ unifications := Bindings.add mono_id (CT_tup (List.map cval_ctyp cvals)) !unifications;
+
+ List.iter (fun ctyp -> prerr_endline (string_of_ctyp ctyp)) unification;
+ prerr_endline (string_of_id mono_id);
+
+ I_aux (I_funcall (clexp, extern, mono_id, List.map (fun (frag, ctyp) -> (unpoly frag, ctyp)) cvals), aux)
| instr -> instr
in
@@ -1829,18 +1816,24 @@ let rec specialize_variants ctx =
| (CDEF_type (CTD_variant (var_id, ctors)) as cdef) :: cdefs ->
let polymorphic_ctors = List.filter (fun (_, ctyp) -> is_polymorphic ctyp) ctors in
List.iter (fun (id, ctyp) -> prerr_endline (Printf.sprintf "%s : %s" (string_of_id id) (string_of_ctyp ctyp))) polymorphic_ctors;
- print_endline "=== CONSTRUCTORS ===";
+ prerr_endline "=== CONSTRUCTORS ===";
+
let cdefs =
List.fold_left (fun cdefs (ctor_id, ctyp) -> List.map (cdef_map_instr (specialize_constructor ctx ctor_id ctyp)) cdefs)
cdefs
polymorphic_ctors
in
- cdef :: specialize_variants ctx cdefs
+
+ let ctx = { ctx with variants = Bindings.add var_id !unifications ctx.variants } in
+
+ let cdefs, ctx = specialize_variants ctx cdefs in
+ CDEF_type (CTD_variant (var_id, (Bindings.bindings !unifications))) :: cdefs, ctx
| cdef :: cdefs ->
- cdef :: specialize_variants ctx cdefs
+ let cdefs, ctx = specialize_variants ctx cdefs in
+ cdef :: cdefs, ctx
- | [] -> []
+ | [] -> [], ctx
(*
(* When this optimization fires we know we have bytecode of the form
@@ -1910,7 +1903,6 @@ let optimize ctx cdefs =
let nothing cdefs = cdefs in
cdefs
|> (if !optimize_hoist_allocations then concatMap (hoist_allocations ctx) else nothing)
- |> specialize_variants ctx
(* |> (if !optimize_struct_updates then concatMap (fix_struct_updates ctx) else nothing) *)
(**************************************************************************)
@@ -2739,9 +2731,9 @@ let codegen_def ctx def =
let vectors = List.filter is_ct_vector (cdef_ctyps ctx def) in
let vectors = List.map (fun ctyp -> codegen_vector ctx (unvector ctyp)) vectors in
(* prerr_endline (Pretty_print_sail.to_string (pp_cdef def)); *)
- concat tups
+ concat vectors
^^ concat lists
- ^^ concat vectors
+ ^^ concat tups
^^ codegen_def' ctx def
let is_cdef_startup = function
@@ -2866,10 +2858,9 @@ let compile_ast ctx (Defs defs) =
let ctx = { ctx with tc_env = snd (Type_error.check ctx.tc_env (Defs [assert_vs; exit_vs])) } in
let chunks, ctx = List.fold_left (fun (chunks, ctx) def -> let defs, ctx = compile_def ctx def in defs :: chunks, ctx) ([], ctx) defs in
let cdefs = List.concat (List.rev chunks) in
-
- print_endline (Pretty_print_sail.to_string (separate_map (hardline ^^ hardline) pp_cdef cdefs));
-
+ let cdefs, ctx = specialize_variants ctx cdefs in
let cdefs = optimize ctx cdefs in
+ prerr_endline (Pretty_print_sail.to_string (separate_map (hardline ^^ hardline) pp_cdef cdefs));
(*
let cdefs = if !opt_trace then List.map (instrument_tracing ctx) cdefs else cdefs in
*)
diff --git a/src/reporting_basic.ml b/src/reporting_basic.ml
index 65acd4ac..985136c4 100644
--- a/src/reporting_basic.ml
+++ b/src/reporting_basic.ml
@@ -120,7 +120,8 @@ let print_code1 ff fname lnum1 cnum1 cnum2 =
Util.(Str.string_before (Str.string_after line cnum1) (cnum2 - cnum1) |> red_bg |> clear)
(Str.string_after line cnum2);
close_in in_chan
- with e -> (close_in_noerr in_chan; print_endline (Printexc.to_string e))
+ with e -> (close_in_noerr in_chan;
+ prerr_endline (Printf.sprintf "print_code1: %s %d %d %d %s" fname lnum1 cnum1 cnum2 (Printexc.to_string e)))
end
with _ -> ()