summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/constant_propagation_mutrec.ml42
1 files changed, 30 insertions, 12 deletions
diff --git a/src/constant_propagation_mutrec.ml b/src/constant_propagation_mutrec.ml
index 6cc6d28c..03d8e154 100644
--- a/src/constant_propagation_mutrec.ml
+++ b/src/constant_propagation_mutrec.ml
@@ -97,7 +97,8 @@ let generate_fun_id id args =
that will be propagated in *)
let generate_val_spec env id args l annot =
match Env.get_val_spec_orig id env with
- | tq, Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) ->
+ | tq, (Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) as fn_typ) ->
+ (* Get instantiation of type variables at call site *)
let orig_ksubst (kid, typ_arg) =
match typ_arg with
| A_aux ((A_nexp _ | A_bool _), _) -> (orig_kid kid, typ_arg)
@@ -110,21 +111,38 @@ let generate_val_spec env id args l annot =
|> List.map orig_ksubst
|> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty
in
+ (* Apply instantiation to original function type. Also collect the
+ type variables in the new type together their kinds for the new
+ val spec. *)
+ let kopts_of_typ env typ =
+ tyvars_of_typ typ |> KidSet.elements
+ |> List.map (fun kid -> mk_kopt (Env.get_typ_var kid env) kid)
+ |> KOptSet.of_list
+ in
let ret_typ' = KBindings.fold typ_subst ksubsts ret_typ in
- let arg_typs' =
- List.map (KBindings.fold typ_subst ksubsts) arg_typs
- |> List.map2 (fun arg typ -> if is_const_exp arg then [] else [typ]) args
- |> List.concat
- |> function [] -> [unit_typ] | typs -> typs
+ let (arg_typs', kopts') =
+ List.fold_right2 (fun arg typ (arg_typs', kopts') ->
+ if is_const_exp arg then
+ (arg_typs', kopts')
+ else
+ let typ' = KBindings.fold typ_subst ksubsts typ in
+ let arg_kopts = kopts_of_typ (env_of arg) typ' in
+ (typ' :: arg_typs', KOptSet.union arg_kopts kopts'))
+ args arg_typs ([], kopts_of_typ (env_of_tannot annot) ret_typ')
in
+ let arg_typs' = if arg_typs' = [] then [unit_typ] else arg_typs' in
let typ' = mk_typ (Typ_fn (arg_typs', ret_typ', eff)) in
- let tyvars = tyvars_of_typ typ' in
- let tq' =
- quant_items tq |>
- List.filter (fun qi -> KidSet.subset (tyvars_of_quant_item qi) tyvars) |>
- mk_typquant
+ (* Construct new val spec *)
+ let constraints' =
+ quant_split tq |> snd
+ |> List.map (KBindings.fold constraint_subst ksubsts)
+ |> List.filter (fun nc -> KidSet.subset (tyvars_of_constraint nc) (tyvars_of_typ typ'))
+ in
+ let quant_items' =
+ List.map mk_qi_kopt (KOptSet.elements kopts') @
+ List.map mk_qi_nc constraints'
in
- let typschm = mk_typschm tq' typ' in
+ let typschm = mk_typschm (mk_typquant quant_items') typ' in
mk_val_spec (VS_val_spec (typschm, generate_fun_id id args, [], false)),
ksubsts
| _, Typ_aux (_, l) ->