summaryrefslogtreecommitdiff
path: root/src/rewrites.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewrites.ml')
-rw-r--r--src/rewrites.ml128
1 files changed, 62 insertions, 66 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index d8cb5a5d..bc9792ef 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -2043,60 +2043,74 @@ let rewrite_split_fun_constr_pats fun_name env (Defs defs) =
let clauses, aux_funs =
List.fold_left
(fun (clauses, aux_funs) (FCL_aux (FCL_Funcl (id, pexp), fannot) as clause) ->
- let pat, guard, exp, annot = destruct_pexp pexp in
- match pat with
- | P_aux (P_app (constr_id, args), pannot) ->
- let argstup_typ = tuple_typ (List.map typ_of_pat args) in
- let pannot' = swaptyp argstup_typ pannot in
- let pat' =
- match args with
- | [arg] -> arg
- | _ -> P_aux (P_tup args, pannot')
- in
- let pexp' = construct_pexp (pat', guard, exp, annot) in
- let aux_fun_id = prepend_id (fun_name ^ "_") constr_id in
- let aux_funcl = FCL_aux (FCL_Funcl (aux_fun_id, pexp'), pannot') in
- begin
- try
- let aux_clauses = Bindings.find aux_fun_id aux_funs in
- clauses,
- Bindings.add aux_fun_id (aux_clauses @ [aux_funcl]) aux_funs
- with Not_found ->
- let argpats, argexps = List.split (List.mapi
- (fun idx (P_aux (_,a) as pat) ->
- let id = match pat_var pat with
- | Some id -> id
- | None -> mk_id ("arg" ^ string_of_int idx)
- in
- P_aux (P_id id, a), E_aux (E_id id, a))
- args)
- in
- let pexp = construct_pexp
- (P_aux (P_app (constr_id, argpats), pannot),
- None,
- E_aux (E_app (aux_fun_id, argexps), annot),
- annot)
- in
- clauses @ [FCL_aux (FCL_Funcl (id, pexp), fannot)],
- Bindings.add aux_fun_id [aux_funcl] aux_funs
- end
- | _ -> clauses @ [clause], aux_funs)
+ let pat, guard, exp, annot = destruct_pexp pexp in
+ match pat with
+ | P_aux (P_app (constr_id, args), pannot) ->
+ let ctor_typq, ctor_typ = Env.get_union_id constr_id env in
+ let args = match args with [P_aux (P_tup args, _)] -> args | _ -> args in
+ let argstup_typ = tuple_typ (List.map typ_of_pat args) in
+ let pannot' = swaptyp argstup_typ pannot in
+ let pat' =
+ match args with
+ | [arg] -> arg
+ | _ -> P_aux (P_tup args, pannot')
+ in
+ let pexp' = construct_pexp (pat', guard, exp, annot) in
+ let aux_fun_id = prepend_id (fun_name ^ "_") constr_id in
+ let aux_funcl = FCL_aux (FCL_Funcl (aux_fun_id, pexp'), pannot') in
+ begin
+ try
+ let aux_clauses = Bindings.find aux_fun_id aux_funs in
+ clauses,
+ Bindings.add aux_fun_id (aux_clauses @ [(aux_funcl, ctor_typq, ctor_typ)]) aux_funs
+ with Not_found ->
+ let argpats, argexps = List.split (List.mapi
+ (fun idx (P_aux (_,a) as pat) ->
+ let id = match pat_var pat with
+ | Some id -> id
+ | None -> mk_id ("arg" ^ string_of_int idx)
+ in
+ P_aux (P_id id, a), E_aux (E_id id, a))
+ args)
+ in
+ let pexp = construct_pexp
+ (P_aux (P_app (constr_id, argpats), pannot),
+ None,
+ E_aux (E_app (aux_fun_id, argexps), annot),
+ annot)
+ in
+ clauses @ [FCL_aux (FCL_Funcl (id, pexp), fannot)],
+ Bindings.add aux_fun_id [(aux_funcl, ctor_typq, ctor_typ)] aux_funs
+ end
+ | _ -> clauses @ [clause], aux_funs)
([], Bindings.empty) clauses
in
- let add_aux_def id funcls defs =
- let env, args_typ, ret_typ = match funcls with
- | FCL_aux (FCL_Funcl (_, pexp), _) :: _ ->
+ let add_aux_def id aux_funs defs =
+ let funcls = List.map (fun (fcl, _, _) -> fcl) aux_funs in
+ let env, quants, args_typ, ret_typ = match aux_funs with
+ | (FCL_aux (FCL_Funcl (_, pexp), _), ctor_typq, ctor_typ) :: _ ->
let pat, _, exp, _ = destruct_pexp pexp in
- env_of exp, typ_of_pat pat, typ_of exp
+ let ctor_quants args_typ =
+ List.filter (fun qi -> KOptSet.subset (kopts_of_quant_item qi) (kopts_of_typ args_typ))
+ (quant_items ctor_typq)
+ in
+ begin match ctor_typ with
+ | Typ_aux (Typ_fn ([Typ_aux (Typ_exist (kopts, nc, args_typ), _)], _, _), _) ->
+ env_of exp, ctor_quants args_typ @ List.map mk_qi_kopt kopts @ [mk_qi_nc nc], args_typ, typ_of exp
+ | Typ_aux (Typ_fn ([args_typ], _, _), _) -> env_of exp, ctor_quants args_typ, args_typ, typ_of exp
+ | _ ->
+ raise (Reporting.err_unreachable l __POS__
+ ("Union constructor has non-function type: " ^ string_of_typ ctor_typ))
+ end
| _ ->
raise (Reporting.err_unreachable l __POS__
- "rewrite_split_fun_constr_pats: empty auxiliary function")
+ "rewrite_split_fun_constr_pats: empty auxiliary function")
in
let eff = List.fold_left
- (fun eff (FCL_aux (FCL_Funcl (_, pexp), _)) ->
- let _, _, exp, _ = destruct_pexp pexp in
- union_effects eff (effect_of exp))
- no_effect funcls
+ (fun eff (FCL_aux (FCL_Funcl (_, pexp), _)) ->
+ let _, _, exp, _ = destruct_pexp pexp in
+ union_effects eff (effect_of exp))
+ no_effect funcls
in
let fun_typ =
(* Because we got the argument type from a pattern we need to
@@ -2107,27 +2121,9 @@ let rewrite_split_fun_constr_pats fun_name env (Defs defs) =
| _ ->
function_typ [args_typ] ret_typ eff
in
- let quant_new_kopts qis =
- let quant_kopts = List.fold_left KOptSet.union KOptSet.empty (List.map kopts_of_quant_item qis) in
- let typ_kopts = kopts_of_typ fun_typ in
- let new_kopts = KOptSet.diff typ_kopts quant_kopts in
- List.map mk_qi_kopt (KOptSet.elements new_kopts)
- in
- let typquant = match typquant with
- | TypQ_aux (TypQ_tq qis, l) ->
- let qis =
- List.filter
- (fun qi -> KOptSet.subset (kopts_of_quant_item qi) (kopts_of_typ fun_typ))
- qis
- @ quant_new_kopts qis
- in
- TypQ_aux (TypQ_tq qis, l)
- | _ ->
- TypQ_aux (TypQ_tq (List.map mk_qi_kopt (KOptSet.elements (kopts_of_typ fun_typ))), l)
- in
let val_spec =
VS_aux (VS_val_spec
- (mk_typschm typquant fun_typ, id, (fun _ -> None), false),
+ (mk_typschm (mk_typquant quants) fun_typ, id, (fun _ -> None), false),
(Parse_ast.Unknown, empty_tannot))
in
let fundef = FD_aux (FD_function (r_o, t_o, e_o, funcls), fdannot) in