diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/constant_propagation_mutrec.ml | 42 |
1 files changed, 30 insertions, 12 deletions
diff --git a/src/constant_propagation_mutrec.ml b/src/constant_propagation_mutrec.ml index 6cc6d28c..03d8e154 100644 --- a/src/constant_propagation_mutrec.ml +++ b/src/constant_propagation_mutrec.ml @@ -97,7 +97,8 @@ let generate_fun_id id args = that will be propagated in *) let generate_val_spec env id args l annot = match Env.get_val_spec_orig id env with - | tq, Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) -> + | tq, (Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) as fn_typ) -> + (* Get instantiation of type variables at call site *) let orig_ksubst (kid, typ_arg) = match typ_arg with | A_aux ((A_nexp _ | A_bool _), _) -> (orig_kid kid, typ_arg) @@ -110,21 +111,38 @@ let generate_val_spec env id args l annot = |> List.map orig_ksubst |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty in + (* Apply instantiation to original function type. Also collect the + type variables in the new type together their kinds for the new + val spec. *) + let kopts_of_typ env typ = + tyvars_of_typ typ |> KidSet.elements + |> List.map (fun kid -> mk_kopt (Env.get_typ_var kid env) kid) + |> KOptSet.of_list + in let ret_typ' = KBindings.fold typ_subst ksubsts ret_typ in - let arg_typs' = - List.map (KBindings.fold typ_subst ksubsts) arg_typs - |> List.map2 (fun arg typ -> if is_const_exp arg then [] else [typ]) args - |> List.concat - |> function [] -> [unit_typ] | typs -> typs + let (arg_typs', kopts') = + List.fold_right2 (fun arg typ (arg_typs', kopts') -> + if is_const_exp arg then + (arg_typs', kopts') + else + let typ' = KBindings.fold typ_subst ksubsts typ in + let arg_kopts = kopts_of_typ (env_of arg) typ' in + (typ' :: arg_typs', KOptSet.union arg_kopts kopts')) + args arg_typs ([], kopts_of_typ (env_of_tannot annot) ret_typ') in + let arg_typs' = if arg_typs' = [] then [unit_typ] else arg_typs' in let typ' = mk_typ (Typ_fn (arg_typs', ret_typ', eff)) in - let tyvars = tyvars_of_typ typ' in - let tq' = - quant_items tq |> - List.filter (fun qi -> KidSet.subset (tyvars_of_quant_item qi) tyvars) |> - mk_typquant + (* Construct new val spec *) + let constraints' = + quant_split tq |> snd + |> List.map (KBindings.fold constraint_subst ksubsts) + |> List.filter (fun nc -> KidSet.subset (tyvars_of_constraint nc) (tyvars_of_typ typ')) + in + let quant_items' = + List.map mk_qi_kopt (KOptSet.elements kopts') @ + List.map mk_qi_nc constraints' in - let typschm = mk_typschm tq' typ' in + let typschm = mk_typschm (mk_typquant quant_items') typ' in mk_val_spec (VS_val_spec (typschm, generate_fun_id id args, [], false)), ksubsts | _, Typ_aux (_, l) -> |
