diff options
Diffstat (limited to 'src/rewrites.ml')
| -rw-r--r-- | src/rewrites.ml | 128 |
1 files changed, 62 insertions, 66 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index d8cb5a5d..bc9792ef 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2043,60 +2043,74 @@ let rewrite_split_fun_constr_pats fun_name env (Defs defs) = let clauses, aux_funs = List.fold_left (fun (clauses, aux_funs) (FCL_aux (FCL_Funcl (id, pexp), fannot) as clause) -> - let pat, guard, exp, annot = destruct_pexp pexp in - match pat with - | P_aux (P_app (constr_id, args), pannot) -> - let argstup_typ = tuple_typ (List.map typ_of_pat args) in - let pannot' = swaptyp argstup_typ pannot in - let pat' = - match args with - | [arg] -> arg - | _ -> P_aux (P_tup args, pannot') - in - let pexp' = construct_pexp (pat', guard, exp, annot) in - let aux_fun_id = prepend_id (fun_name ^ "_") constr_id in - let aux_funcl = FCL_aux (FCL_Funcl (aux_fun_id, pexp'), pannot') in - begin - try - let aux_clauses = Bindings.find aux_fun_id aux_funs in - clauses, - Bindings.add aux_fun_id (aux_clauses @ [aux_funcl]) aux_funs - with Not_found -> - let argpats, argexps = List.split (List.mapi - (fun idx (P_aux (_,a) as pat) -> - let id = match pat_var pat with - | Some id -> id - | None -> mk_id ("arg" ^ string_of_int idx) - in - P_aux (P_id id, a), E_aux (E_id id, a)) - args) - in - let pexp = construct_pexp - (P_aux (P_app (constr_id, argpats), pannot), - None, - E_aux (E_app (aux_fun_id, argexps), annot), - annot) - in - clauses @ [FCL_aux (FCL_Funcl (id, pexp), fannot)], - Bindings.add aux_fun_id [aux_funcl] aux_funs - end - | _ -> clauses @ [clause], aux_funs) + let pat, guard, exp, annot = destruct_pexp pexp in + match pat with + | P_aux (P_app (constr_id, args), pannot) -> + let ctor_typq, ctor_typ = Env.get_union_id constr_id env in + let args = match args with [P_aux (P_tup args, _)] -> args | _ -> args in + let argstup_typ = tuple_typ (List.map typ_of_pat args) in + let pannot' = swaptyp argstup_typ pannot in + let pat' = + match args with + | [arg] -> arg + | _ -> P_aux (P_tup args, pannot') + in + let pexp' = construct_pexp (pat', guard, exp, annot) in + let aux_fun_id = prepend_id (fun_name ^ "_") constr_id in + let aux_funcl = FCL_aux (FCL_Funcl (aux_fun_id, pexp'), pannot') in + begin + try + let aux_clauses = Bindings.find aux_fun_id aux_funs in + clauses, + Bindings.add aux_fun_id (aux_clauses @ [(aux_funcl, ctor_typq, ctor_typ)]) aux_funs + with Not_found -> + let argpats, argexps = List.split (List.mapi + (fun idx (P_aux (_,a) as pat) -> + let id = match pat_var pat with + | Some id -> id + | None -> mk_id ("arg" ^ string_of_int idx) + in + P_aux (P_id id, a), E_aux (E_id id, a)) + args) + in + let pexp = construct_pexp + (P_aux (P_app (constr_id, argpats), pannot), + None, + E_aux (E_app (aux_fun_id, argexps), annot), + annot) + in + clauses @ [FCL_aux (FCL_Funcl (id, pexp), fannot)], + Bindings.add aux_fun_id [(aux_funcl, ctor_typq, ctor_typ)] aux_funs + end + | _ -> clauses @ [clause], aux_funs) ([], Bindings.empty) clauses in - let add_aux_def id funcls defs = - let env, args_typ, ret_typ = match funcls with - | FCL_aux (FCL_Funcl (_, pexp), _) :: _ -> + let add_aux_def id aux_funs defs = + let funcls = List.map (fun (fcl, _, _) -> fcl) aux_funs in + let env, quants, args_typ, ret_typ = match aux_funs with + | (FCL_aux (FCL_Funcl (_, pexp), _), ctor_typq, ctor_typ) :: _ -> let pat, _, exp, _ = destruct_pexp pexp in - env_of exp, typ_of_pat pat, typ_of exp + let ctor_quants args_typ = + List.filter (fun qi -> KOptSet.subset (kopts_of_quant_item qi) (kopts_of_typ args_typ)) + (quant_items ctor_typq) + in + begin match ctor_typ with + | Typ_aux (Typ_fn ([Typ_aux (Typ_exist (kopts, nc, args_typ), _)], _, _), _) -> + env_of exp, ctor_quants args_typ @ List.map mk_qi_kopt kopts @ [mk_qi_nc nc], args_typ, typ_of exp + | Typ_aux (Typ_fn ([args_typ], _, _), _) -> env_of exp, ctor_quants args_typ, args_typ, typ_of exp + | _ -> + raise (Reporting.err_unreachable l __POS__ + ("Union constructor has non-function type: " ^ string_of_typ ctor_typ)) + end | _ -> raise (Reporting.err_unreachable l __POS__ - "rewrite_split_fun_constr_pats: empty auxiliary function") + "rewrite_split_fun_constr_pats: empty auxiliary function") in let eff = List.fold_left - (fun eff (FCL_aux (FCL_Funcl (_, pexp), _)) -> - let _, _, exp, _ = destruct_pexp pexp in - union_effects eff (effect_of exp)) - no_effect funcls + (fun eff (FCL_aux (FCL_Funcl (_, pexp), _)) -> + let _, _, exp, _ = destruct_pexp pexp in + union_effects eff (effect_of exp)) + no_effect funcls in let fun_typ = (* Because we got the argument type from a pattern we need to @@ -2107,27 +2121,9 @@ let rewrite_split_fun_constr_pats fun_name env (Defs defs) = | _ -> function_typ [args_typ] ret_typ eff in - let quant_new_kopts qis = - let quant_kopts = List.fold_left KOptSet.union KOptSet.empty (List.map kopts_of_quant_item qis) in - let typ_kopts = kopts_of_typ fun_typ in - let new_kopts = KOptSet.diff typ_kopts quant_kopts in - List.map mk_qi_kopt (KOptSet.elements new_kopts) - in - let typquant = match typquant with - | TypQ_aux (TypQ_tq qis, l) -> - let qis = - List.filter - (fun qi -> KOptSet.subset (kopts_of_quant_item qi) (kopts_of_typ fun_typ)) - qis - @ quant_new_kopts qis - in - TypQ_aux (TypQ_tq qis, l) - | _ -> - TypQ_aux (TypQ_tq (List.map mk_qi_kopt (KOptSet.elements (kopts_of_typ fun_typ))), l) - in let val_spec = VS_aux (VS_val_spec - (mk_typschm typquant fun_typ, id, (fun _ -> None), false), + (mk_typschm (mk_typquant quants) fun_typ, id, (fun _ -> None), false), (Parse_ast.Unknown, empty_tannot)) in let fundef = FD_aux (FD_function (r_o, t_o, e_o, funcls), fdannot) in |
