diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/type_check_new.ml | 105 |
1 files changed, 71 insertions, 34 deletions
diff --git a/src/type_check_new.ml b/src/type_check_new.ml index 30bb97b0..6a1bc967 100644 --- a/src/type_check_new.ml +++ b/src/type_check_new.ml @@ -46,7 +46,7 @@ open Util open Ast_util open Big_int -let debug = ref 0 +let debug = ref 1 let depth = ref 0 let rec indent n = match n with @@ -957,6 +957,25 @@ let rec nexp_frees (Nexp_aux (nexp, l)) = | Nexp_exp n -> nexp_frees n | Nexp_neg n -> nexp_frees n +let order_frees (Ord_aux (ord_aux, l)) = + match ord_aux with + | Ord_var kid -> KidSet.singleton kid + | _ -> KidSet.empty + +let rec typ_frees (Typ_aux (typ_aux, l)) = + match typ_aux with + | Typ_wild -> KidSet.empty + | Typ_id v -> KidSet.empty + | Typ_var kid -> KidSet.singleton kid + | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map typ_frees typs) + | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map typ_arg_frees args) +and typ_arg_frees (Typ_arg_aux (typ_arg_aux, l)) = + match typ_arg_aux with + | Typ_arg_nexp n -> nexp_frees n + | Typ_arg_typ typ -> typ_frees typ + | Typ_arg_order ord -> order_frees ord + | Typ_arg_effect _ -> assert false + let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = match nexp1, nexp2 with | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 = 0 @@ -979,31 +998,39 @@ exception Unification_error of l * string;; let unify_error l str = raise (Unification_error (l, str)) -let rec unify_nexps l (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) = - typ_debug ("UNIFYING NEXPS " ^ string_of_nexp nexp1 ^ " AND " ^ string_of_nexp nexp2); - match nexp_aux1 with - | Nexp_id v -> unify_error l "Unimplemented Nexp_id in unify nexp" - | Nexp_var kid -> Some (kid, nexp2) - | Nexp_constant c1 -> - begin - match nexp_aux2 with - | Nexp_constant c2 -> if c1 = c2 then None else unify_error l "Constants are not the same" - | _ -> unify_error l "Unification error" - end - | Nexp_sum (n1a, n1b) -> - if KidSet.is_empty (nexp_frees n1b) - then unify_nexps l n1a (nminus nexp2 n1b) - else - if KidSet.is_empty (nexp_frees n1a) - then unify_nexps l n1b (nminus nexp2 n1a) - else unify_error l ("Both sides of Nat expression " ^ string_of_nexp nexp1 - ^ " contain free type variables so it cannot be unified with " ^ string_of_nexp nexp2) - | Nexp_minus (n1a, n1b) -> - if KidSet.is_empty (nexp_frees n1b) - then unify_nexps l n1a (nsum nexp2 n1b) - else unify_error l ("Cannot unify minus Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) - - | _ -> unify_error l ("Cannot unify Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) +let rec unify_nexps l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) = + typ_debug ("UNIFYING NEXPS " ^ string_of_nexp nexp1 ^ " AND " ^ string_of_nexp nexp2 ^ " FOR GOALS " ^ string_of_list ", " string_of_kid (KidSet.elements goals)); + if KidSet.is_empty (KidSet.inter (nexp_frees nexp1) goals) + then + begin + if prove env (NC_aux (NC_fixed (nexp1, nexp2), Parse_ast.Unknown)) + then None + else unify_error l ("Nexp " ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2 ^ " are not equal") + end + else + match nexp_aux1 with + | Nexp_id v -> unify_error l "Unimplemented Nexp_id in unify nexp" + | Nexp_var kid when KidSet.mem kid goals -> Some (kid, nexp2) + | Nexp_constant c1 -> + begin + match nexp_aux2 with + | Nexp_constant c2 -> if c1 = c2 then None else unify_error l "Constants are not the same" + | _ -> unify_error l "Unification error" + end + | Nexp_sum (n1a, n1b) -> + if KidSet.is_empty (nexp_frees n1b) + then unify_nexps l env goals n1a (nminus nexp2 n1b) + else + if KidSet.is_empty (nexp_frees n1a) + then unify_nexps l env goals n1b (nminus nexp2 n1a) + else unify_error l ("Both sides of Nat expression " ^ string_of_nexp nexp1 + ^ " contain free type variables so it cannot be unified with " ^ string_of_nexp nexp2) + | Nexp_minus (n1a, n1b) -> + if KidSet.is_empty (nexp_frees n1b) + then unify_nexps l env goals n1a (nsum nexp2 n1b) + else unify_error l ("Cannot unify minus Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) + + | _ -> unify_error l ("Cannot unify Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) let string_of_uvar = function | U_nexp n -> string_of_nexp n @@ -1040,6 +1067,11 @@ let subst_args_unifiers unifiers typ_args = List.fold_left subst_unifier typ_args (KBindings.bindings unifiers) let unify l env typ1 typ2 = + typ_print ("Unify " ^ string_of_typ typ1 ^ " with " ^ string_of_typ typ2); + if not (KidSet.is_empty (KidSet.inter (typ_frees typ1) (typ_frees typ2))) + then unify_error l "Can only unify types with disjoint type variables" + else (); + let goals = typ_frees typ1 in let merge_unifiers l kid uvar1 uvar2 = match uvar1, uvar2 with | Some (U_nexp n1), Some (U_nexp n2) -> @@ -1058,7 +1090,8 @@ let unify l env typ1 typ2 = | Typ_id v1, Typ_id v2 -> if Id.compare v1 v2 = 0 then KBindings.empty else unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2) - | Typ_var kid, _ -> KBindings.singleton kid (U_typ typ2) + | Typ_var kid, _ when KidSet.mem kid goals -> KBindings.singleton kid (U_typ typ2) + | Typ_var kid1, Typ_var kid2 when Kid.compare kid1 kid2 = 0 -> KBindings.empty | Typ_tup typs1, Typ_tup typs2 -> begin try List.fold_left (KBindings.merge (merge_unifiers l)) KBindings.empty (List.map2 (unify_typ l) typs1 typs2) with @@ -1095,7 +1128,7 @@ let unify l env typ1 typ2 = match typ_arg_aux1, typ_arg_aux2 with | Typ_arg_nexp n1, Typ_arg_nexp n2 -> begin - match unify_nexps l (nexp_simp n1) (nexp_simp n2) with + match unify_nexps l env goals (nexp_simp n1) (nexp_simp n2) with | Some (kid, unifier) -> KBindings.singleton kid (U_nexp unifier) | None -> KBindings.empty end @@ -1279,7 +1312,7 @@ let irule r env exp = incr depth; try let inferred_exp = r env exp in - typ_print ("Infer " ^ string_of_exp exp ^ " => " ^ string_of_typ (typ_of inferred_exp)); + typ_print ("Infer " ^ string_of_exp exp ^ " => " ^ string_of_typ (typ_of inferred_exp)); decr depth; inferred_exp with @@ -1596,11 +1629,12 @@ and bind_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) typ = | _ -> typ_error l ("Unhandled l-expression") and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = + typ_print ("Inferring " ^ string_of_exp exp); let annot_exp_effect exp typ eff = E_aux (exp, (l, Some (env, typ, eff))) in let annot_exp exp typ = annot_exp_effect exp typ no_effect in match exp_aux with | E_nondet exps -> - annot_exp (E_nondet (List.map (fun exp -> check_exp env exp unit_typ) exps)) unit_typ + annot_exp (E_nondet (List.map (fun exp -> crule check_exp env exp unit_typ) exps)) unit_typ | E_id v -> begin match Env.lookup_id v env with @@ -1618,10 +1652,11 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = end | E_field (exp, field) -> begin - let inferred_exp = infer_exp env exp in + let inferred_exp = irule infer_exp env exp in match Env.expand_synonyms env (typ_of inferred_exp) with (* Accessing a (bit) field of a register *) | Typ_aux (Typ_id regtyp, _) when Env.is_regtyp regtyp env -> + typ_print "REGTYP"; let base, top, ranges = Env.get_regtyp regtyp env in let range, _ = try List.find (fun (_, id) -> Id.compare id field = 0) ranges with @@ -1640,6 +1675,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = (* Accessing a field of a record *) | Typ_aux (Typ_id rectyp, _) as typ when Env.is_record rectyp env -> begin + typ_print "RECTYP"; let inferred_acc = infer_funapp' l (Env.no_casts env) field (Env.get_accessor field env) [strip_exp inferred_exp] None in match inferred_acc with | E_aux (E_app (field, [inferred_exp]) ,_) -> annot_exp (E_field (inferred_exp, field)) (typ_of inferred_acc) @@ -1696,8 +1732,8 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = in annot_exp (E_vector (inferred_item :: checked_items)) vec_typ | E_assert (test, msg) -> - let checked_test = check_exp env test bool_typ in - let checked_msg = check_exp env msg string_typ in + let checked_test = crule check_exp env test bool_typ in + let checked_msg = crule check_exp env msg string_typ in annot_exp (E_assert (checked_test, checked_msg)) unit_typ | _ -> typ_error l ("Cannot infer type of: " ^ string_of_exp exp) @@ -1752,7 +1788,8 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ = | None -> (quants, typs, ret_typ) | Some rct -> begin - let unifiers = try unify l env ret_typ rct with Unification_error _ -> typ_debug "UERROR"; KBindings.empty in + typ_debug ("INSTANTIATE RETURN:" ^ string_of_typ ret_typ); + let unifiers = try unify l env ret_typ rct with Unification_error _ -> typ_debug "UERROR"; KBindings.empty in typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)); let typs' = List.map (subst_unifiers unifiers) typs in let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in |
