summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/c_backend.ml4
-rw-r--r--src/rewrites.ml2
-rw-r--r--src/specialize.ml47
3 files changed, 38 insertions, 15 deletions
diff --git a/src/c_backend.ml b/src/c_backend.ml
index d7d9b27f..6ab45574 100644
--- a/src/c_backend.ml
+++ b/src/c_backend.ml
@@ -837,7 +837,7 @@ let rec is_stack_ctyp ctyp = match ctyp with
| CT_uint64 _ | CT_int64 | CT_bit | CT_unit | CT_bool | CT_enum _ -> true
| CT_bv _ | CT_mpz | CT_real | CT_string | CT_list _ | CT_vector _ -> false
| CT_struct (_, fields) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) fields
- | CT_variant (_, ctors) -> List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) ctors
+ | CT_variant (_, ctors) -> false (* List.for_all (fun (_, ctyp) -> is_stack_ctyp ctyp) ctors *) (*FIXME*)
| CT_tup ctyps -> List.for_all is_stack_ctyp ctyps
| CT_ref ctyp -> is_stack_ctyp ctyp
@@ -1498,7 +1498,7 @@ let compile_funcall ctx id args typ =
let rec compile_match ctx apat cval case_label =
match apat, cval with
- | AP_id pid, (frag, ctyp) when is_ct_variant ctyp ->
+ | AP_id pid, (frag, ctyp) when Env.is_union_constructor pid ctx.tc_env ->
[ijump (F_op (F_field (frag, "kind"), "!=", F_lit (V_ctor_kind (string_of_id pid))), CT_bool) case_label],
[]
| AP_global (pid, _), _ -> [icopy (CL_id pid) cval], []
diff --git a/src/rewrites.ml b/src/rewrites.ml
index d6144755..af621b47 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -3036,6 +3036,8 @@ let rewrite_defs_c = [
("constraint", rewrite_constraint);
("trivial_sizeof", rewrite_trivial_sizeof);
("sizeof", rewrite_sizeof);
+ ("merge function clauses", merge_funcls);
+ ("recheck_defs", recheck_defs)
]
let rewrite_defs_interpreter = [
diff --git a/src/specialize.ml b/src/specialize.ml
index 0a4e0fbb..1b978fbc 100644
--- a/src/specialize.ml
+++ b/src/specialize.ml
@@ -150,8 +150,16 @@ let id_of_instantiation id instantiation =
let str = string_of_instantiation instantiation in
prepend_id (str ^ "#") id
+let rec variant_generic_typ id (Defs defs) =
+ match defs with
+ | DEF_type (TD_aux (TD_variant (id', _, typq, _, _), _)) :: _ when Id.compare id id' = 0 ->
+ mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq)))
+ | _ :: defs -> variant_generic_typ id (Defs defs)
+ | [] -> failwith ("No variant with id " ^ string_of_id id)
+
(* Returns a list of all the instantiations of a function id in an
- ast. *)
+ ast. Also works with union constructors, and searches for them in
+ patterns. *)
let rec instantiations_of id ast =
let instantiations = ref [] in
@@ -163,8 +171,30 @@ let rec instantiations_of id ast =
| exp -> exp
in
- let rewrite_exp = { id_exp_alg with e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) } in
- let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp) } ast in
+ (* We need to to check patterns in case id is a union constructor
+ that is never called like a function. *)
+ let inspect_pat = function
+ | P_aux (P_app (id', _), annot) as pat when Id.compare id id' = 0 ->
+ begin match Type_check.typ_of_annot annot with
+ | Typ_aux (Typ_app (variant_id, _), _) as typ ->
+ let open Type_check in
+ let instantiation, _, _ = unify (fst annot) (env_of_annot annot)
+ (variant_generic_typ variant_id ast)
+ typ
+ in
+ instantiations := fix_instantiation instantiation :: !instantiations;
+ pat
+ | Typ_aux (Typ_id variant_id, _) -> pat
+ | _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type")
+ end
+ | pat -> pat
+ in
+
+ let rewrite_pat = { id_pat_alg with p_aux = (fun (pat, annot) -> inspect_pat (P_aux (pat, annot))) } in
+ let rewrite_exp = { id_exp_alg with pat_alg = rewrite_pat;
+ e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) } in
+ let _ = rewrite_defs_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp rewrite_exp);
+ rewrite_pat = (fun _ -> fold_pat rewrite_pat)} ast in
!instantiations
@@ -329,7 +359,6 @@ let remove_unused_valspecs env ast =
match defs with
| def :: defs when is_fundef id def -> remove_unused (Defs defs) id
| def :: defs when is_valspec id def ->
- prerr_endline ("Removing: " ^ string_of_id id);
remove_unused (Defs defs) id
| DEF_overload (overload_id, overloads) :: defs ->
begin
@@ -379,13 +408,6 @@ let specialize_ids ids ast =
(***** Specialising polymorphic variant types, e.g. option *****)
-let rec variant_generic_typ id (Defs defs) =
- match defs with
- | DEF_type (TD_aux (TD_variant (id', _, typq, _, _), _)) :: _ ->
- mk_typ (Typ_app (id', List.map (fun kopt -> mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var (kopt_kid kopt))))) (quant_kopts typq)))
- | _ :: defs -> variant_generic_typ id (Defs defs)
- | [] -> failwith ("No variant with id " ^ string_of_id id)
-
let rewrite_polymorphic_constructors id ast =
let rewrite_e_aux = function
| E_aux (E_app (id', args), annot) as exp when Id.compare id id' = 0 ->
@@ -461,6 +483,5 @@ let rec specialize ast env =
if IdSet.is_empty ids then
specialize_variants ast env
else
- (prerr_endline (Util.string_of_list ", " string_of_id (IdSet.elements ids));
let ast, env = specialize_ids ids ast in
- specialize ast env)
+ specialize ast env