summaryrefslogtreecommitdiff
path: root/src/constant_propagation_mutrec.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/constant_propagation_mutrec.ml')
-rw-r--r--src/constant_propagation_mutrec.ml57
1 files changed, 38 insertions, 19 deletions
diff --git a/src/constant_propagation_mutrec.ml b/src/constant_propagation_mutrec.ml
index 683cc6f3..03d8e154 100644
--- a/src/constant_propagation_mutrec.ml
+++ b/src/constant_propagation_mutrec.ml
@@ -97,7 +97,8 @@ let generate_fun_id id args =
that will be propagated in *)
let generate_val_spec env id args l annot =
match Env.get_val_spec_orig id env with
- | tq, Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) ->
+ | tq, (Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) as fn_typ) ->
+ (* Get instantiation of type variables at call site *)
let orig_ksubst (kid, typ_arg) =
match typ_arg with
| A_aux ((A_nexp _ | A_bool _), _) -> (orig_kid kid, typ_arg)
@@ -110,27 +111,44 @@ let generate_val_spec env id args l annot =
|> List.map orig_ksubst
|> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty
in
+ (* Apply instantiation to original function type. Also collect the
+ type variables in the new type together their kinds for the new
+ val spec. *)
+ let kopts_of_typ env typ =
+ tyvars_of_typ typ |> KidSet.elements
+ |> List.map (fun kid -> mk_kopt (Env.get_typ_var kid env) kid)
+ |> KOptSet.of_list
+ in
let ret_typ' = KBindings.fold typ_subst ksubsts ret_typ in
- let arg_typs' =
- List.map (KBindings.fold typ_subst ksubsts) arg_typs
- |> List.map2 (fun arg typ -> if is_const_exp arg then [] else [typ]) args
- |> List.concat
- |> function [] -> [unit_typ] | typs -> typs
+ let (arg_typs', kopts') =
+ List.fold_right2 (fun arg typ (arg_typs', kopts') ->
+ if is_const_exp arg then
+ (arg_typs', kopts')
+ else
+ let typ' = KBindings.fold typ_subst ksubsts typ in
+ let arg_kopts = kopts_of_typ (env_of arg) typ' in
+ (typ' :: arg_typs', KOptSet.union arg_kopts kopts'))
+ args arg_typs ([], kopts_of_typ (env_of_tannot annot) ret_typ')
in
+ let arg_typs' = if arg_typs' = [] then [unit_typ] else arg_typs' in
let typ' = mk_typ (Typ_fn (arg_typs', ret_typ', eff)) in
- let tyvars = tyvars_of_typ typ' in
- let tq' =
- quant_items tq |>
- List.filter (fun qi -> KidSet.subset (tyvars_of_quant_item qi) tyvars) |>
- mk_typquant
+ (* Construct new val spec *)
+ let constraints' =
+ quant_split tq |> snd
+ |> List.map (KBindings.fold constraint_subst ksubsts)
+ |> List.filter (fun nc -> KidSet.subset (tyvars_of_constraint nc) (tyvars_of_typ typ'))
+ in
+ let quant_items' =
+ List.map mk_qi_kopt (KOptSet.elements kopts') @
+ List.map mk_qi_nc constraints'
in
- let typschm = mk_typschm tq' typ' in
- mk_val_spec (VS_val_spec (typschm, generate_fun_id id args, (fun _ -> None), false)),
+ let typschm = mk_typschm (mk_typquant quant_items') typ' in
+ mk_val_spec (VS_val_spec (typschm, generate_fun_id id args, [], false)),
ksubsts
| _, Typ_aux (_, l) ->
raise (Reporting.err_unreachable l __POS__ "Function val spec is not a function type")
-let const_prop defs substs ksubsts exp =
+let const_prop target defs substs ksubsts exp =
(* Constant_propagation currently only supports nexps for kid substitutions *)
let nexp_substs =
KBindings.bindings ksubsts
@@ -139,6 +157,7 @@ let const_prop defs substs ksubsts exp =
|> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty
in
Constant_propagation.const_prop
+ target
(Defs defs)
(Constant_propagation.referenced_vars exp)
(substs, nexp_substs)
@@ -147,7 +166,7 @@ let const_prop defs substs ksubsts exp =
|> fst
(* Propagate constant arguments into function clause pexp *)
-let prop_args_pexp defs ksubsts args pexp =
+let prop_args_pexp target defs ksubsts args pexp =
let pat, guard, exp, annot = destruct_pexp pexp in
let pats = match pat with
| P_aux (P_tup pats, _) -> pats
@@ -164,14 +183,14 @@ let prop_args_pexp defs ksubsts args pexp =
else (pat :: pats, substs)
in
let pats, substs = List.fold_right2 match_arg args pats ([], Bindings.empty) in
- let exp' = const_prop defs substs ksubsts exp in
+ let exp' = const_prop target defs substs ksubsts exp in
let pat' = match pats with
| [pat] -> pat
| _ -> P_aux (P_tup pats, (Parse_ast.Unknown, empty_tannot))
in
construct_pexp (pat', guard, exp', annot)
-let rewrite_defs env (Defs defs) =
+let rewrite_defs target env (Defs defs) =
let rec rewrite = function
| [] -> []
| DEF_internal_mutrec mutrecs :: ds ->
@@ -194,7 +213,7 @@ let rewrite_defs env (Defs defs) =
let valspec, ksubsts = generate_val_spec env id args l annot in
let const_prop_funcl (FCL_aux (FCL_Funcl (_, pexp), (l, _))) =
let pexp' =
- prop_args_pexp defs ksubsts args pexp
+ prop_args_pexp target defs ksubsts args pexp
|> rewrite_pexp
|> strip_pexp
in
@@ -215,7 +234,7 @@ let rewrite_defs env (Defs defs) =
let pexp' =
if List.exists (fun id' -> Id.compare id id' = 0) !targets then
let pat, guard, body, annot = destruct_pexp pexp in
- let body' = const_prop defs Bindings.empty KBindings.empty body in
+ let body' = const_prop target defs Bindings.empty KBindings.empty body in
rewrite_pexp (construct_pexp (pat, guard, recheck_exp body', annot))
else pexp
in FCL_aux (FCL_Funcl (id, pexp'), a)