summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBrian Campbell2018-03-29 18:31:57 +0100
committerBrian Campbell2018-04-04 14:45:00 +0100
commitccbfed3dda64c0e92ae9e07e507e63a65b4c8318 (patch)
treee95b5f2bfff389826f003f26bf7d3305d06481e2 /src
parentb20b624c6b82d8fa27396c4c3abefdf52741e6bc (diff)
Use simple equations in function specifications to instantiate tyvars
Allows the type checker to deal with val foo : forall 'm 'n, 'n = 8 * 'm. atom('m) -> bits('n) for example
Diffstat (limited to 'src')
-rw-r--r--src/type_check.ml76
1 files changed, 71 insertions, 5 deletions
diff --git a/src/type_check.ml b/src/type_check.ml
index 689a7338..06fbed0a 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -1365,6 +1365,18 @@ let uvar_subst_nexp sv subst = function
| U_effect eff -> U_effect eff
| U_order ord -> U_order ord
+let uvar_subst_typ sv subst = function
+ | U_nexp nexp -> U_nexp nexp
+ | U_typ typ -> U_typ (typ_subst_typ sv subst typ)
+ | U_effect eff -> U_effect eff
+ | U_order ord -> U_order ord
+
+let uvar_subst_order sv subst = function
+ | U_nexp nexp -> U_nexp nexp
+ | U_typ typ -> U_typ (typ_subst_order sv subst typ)
+ | U_effect eff -> U_effect eff
+ | U_order ord -> U_order (order_subst sv subst ord)
+
exception Unification_error of l * string;;
let unify_error l str = raise (Unification_error (l, str))
@@ -1472,6 +1484,16 @@ let subst_args_unifiers unifiers typ_args =
in
List.fold_left subst_unifier typ_args (KBindings.bindings unifiers)
+let subst_uvar_unifiers unifiers uvar =
+ let subst_unifier uvar' (kid, uvar) =
+ match uvar with
+ | U_nexp nexp -> uvar_subst_nexp kid (unaux_nexp nexp) uvar'
+ | U_order ord -> uvar_subst_order kid (unaux_order ord) uvar'
+ | U_typ subst -> uvar_subst_typ kid (unaux_typ subst) uvar'
+ | _ -> typ_error Parse_ast.Unknown "Cannot subst unifier"
+ in
+ List.fold_left subst_unifier uvar (KBindings.bindings unifiers)
+
let merge_unifiers l kid uvar1 uvar2 =
match uvar1, uvar2 with
| Some (U_nexp n1), Some (U_nexp n2) ->
@@ -1802,6 +1824,36 @@ let rec instantiate_quants quants kid uvar = match quants with
| _ -> (QI_aux (QI_const nc, l)) :: instantiate_quants quants kid uvar
end
+let instantiate_simple_equations =
+ let rec find_eqs kid (NC_aux (nc,_)) =
+ match nc with
+ | NC_equal (Nexp_aux (Nexp_var kid',_), nexp)
+ when Kid.compare kid kid' == 0 &&
+ not (KidSet.mem kid (nexp_frees nexp)) ->
+ [U_nexp nexp]
+ | NC_and (nexp1, nexp2) ->
+ find_eqs kid nexp1 @ find_eqs kid nexp2
+ | _ -> []
+ in
+ let rec find_eqs_quant kid (QI_aux (qi,_)) =
+ match qi with
+ | QI_id _ -> []
+ | QI_const nc -> find_eqs kid nc
+ in
+ let rec inst_from_eq = function
+ | [] -> KBindings.empty
+ | (QI_aux (QI_id kinded_kid, _) as quant) :: quants ->
+ let kid = kopt_kid kinded_kid in
+ let insts_tl = inst_from_eq quants in
+ begin
+ match List.concat (List.map (find_eqs_quant kid) quants) with
+ | [] -> insts_tl
+ | h::_ -> KBindings.add kid h insts_tl
+ end
+ | quant :: quants ->
+ inst_from_eq quants
+in inst_from_eq
+
let destruct_vec_typ l env typ =
let destruct_vec_typ' l = function
| Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp n1, _);
@@ -3017,6 +3069,11 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
| QI_aux (QI_id _, _) -> false
| QI_aux (QI_const nc, _) -> prove env nc
in
+ let record_unifiers unifiers =
+ let previous_unifiers = !all_unifiers in
+ let updated_unifiers = KBindings.map (subst_uvar_unifiers unifiers) previous_unifiers in
+ all_unifiers := merge_uvars l updated_unifiers unifiers;
+ in
let rec instantiate env quants typs ret_typ args =
match typs, args with
| (utyps, []), (uargs, []) ->
@@ -3060,7 +3117,7 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
in
let tag_unifier uvar = List.fold_left (fun uvar kid -> uvar_subst_nexp kid (Nexp_var (prepend_kid ex_tag kid)) uvar) uvar ex_kids in
let unifiers = KBindings.map tag_unifier unifiers in
- all_unifiers := merge_uvars l !all_unifiers unifiers;
+ record_unifiers unifiers;
let utyps' = List.map (subst_unifiers unifiers) utyps in
let typs' = List.map (subst_unifiers unifiers) typs in
let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
@@ -3089,7 +3146,7 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
in
typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers));
if ex_kids = [] then () else (typ_debug ("EX GOAL: " ^ string_of_option string_of_n_constraint ex_nc); ex_goal := ex_nc);
- all_unifiers := merge_uvars l !all_unifiers unifiers;
+ record_unifiers unifiers;
let env = List.fold_left (fun env kid -> Env.add_typ_var kid BK_nat env) env ex_kids in
let typs' = List.map (subst_unifiers unifiers) typs in
let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
@@ -3101,14 +3158,23 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
(quants', typs', ret_typ', env)
end
in
- let (quants, typ_args, typ_ret, env), eff =
+ let quants, typ_args, typ_ret, eff =
match Env.expand_synonyms env f_typ with
| Typ_aux (Typ_fn (Typ_aux (Typ_tup typ_args, _), typ_ret, eff), _) ->
- instantiate_ret env (quant_items typq) typ_args typ_ret, eff
+ quant_items typq, typ_args, typ_ret, eff
| Typ_aux (Typ_fn (typ_arg, typ_ret, eff), _) ->
- instantiate_ret env (quant_items typq) [typ_arg] typ_ret, eff
+ quant_items typq, [typ_arg], typ_ret, eff
| _ -> typ_error l (string_of_typ f_typ ^ " is not a function type")
in
+ let unifiers = instantiate_simple_equations quants in
+ typ_debug "Instantiating from equations";
+ typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)); all_unifiers := unifiers;
+ let typ_args = List.map (subst_unifiers unifiers) typ_args in
+ let quants = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
+ let typ_ret = subst_unifiers unifiers typ_ret in
+ let quants, typ_args, typ_ret, env =
+ instantiate_ret env quants typ_args typ_ret
+ in
let (xs_instantiated, typ_ret, env) = instantiate env quants ([], typ_args) typ_ret ([], number 0 xs) in
let xs_reordered = List.map snd (List.sort (fun (n, _) (m, _) -> compare n m) xs_instantiated) in