summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/prelude.sail177
-rw-r--r--src/type_check_new.ml105
2 files changed, 149 insertions, 133 deletions
diff --git a/lib/prelude.sail b/lib/prelude.sail
index 698a39a0..54d56c15 100644
--- a/lib/prelude.sail
+++ b/lib/prelude.sail
@@ -7,12 +7,14 @@ val cast forall Nat 'n, Nat 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [|0:2**'m
val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,dec,'a>, [|'n - 'l + 1:'n|]) -> 'a effect pure vector_access_dec
val forall Nat 'n, Nat 'l, Type 'a, 'l >= 0. (vector<'n,'l,inc,'a>, [|'n:'n + 'l - 1|]) -> 'a effect pure vector_access_inc
+overload vector_access [vector_access_inc; vector_access_dec]
+
(* Type safe vector subrange *)
val forall Nat 'n, Nat 'l, Nat 'm, Nat 'o, Type 'a, 'l >= 0, 'm <= 'o, 'o <= 'l.
(vector<'n,'l,inc,'a>, [:'m:], [:'o:]) -> vector<'m,'o - 'm,inc,'a> effect pure vector_subrange_inc
val forall Nat 'n, Nat 'l, Nat 'm, Nat 'o, Type 'a, 'n >= 'm, 'm >= 'o, 'o >= 'n - 'l + 1.
- (vector<'n,'l,dec,'a>, [:'m:], [:'o:]) -> vector<'m,('m - 'o) - 1,dec,'a> effect pure vector_subrange_dec
+ (vector<'n,'l,dec,'a>, [:'m:], [:'o:]) -> vector<'m,'m - 'o - 1,dec,'a> effect pure vector_subrange_dec
overload vector_subrange [vector_subrange_inc; vector_subrange_dec]
@@ -21,9 +23,7 @@ val forall Nat 'n1, Nat 'l1, Nat 'n2, Nat 'l2, Order 'o, Type 'a, 'l1 >= 0, 'l2
(vector<'n1,'l1,'o,'a>, vector<'n2,'l2,'o,'a>) -> vector<'n1,'l1 + 'l2,'o,'a> effect pure vector_append
(* Implicit register dereferencing *)
-val cast forall Type 'a. register<'a> -> 'a effect pure reg_deref
-
-overload vector_access [vector_access_inc; vector_access_dec]
+val cast forall Type 'a. register<'a> -> 'a effect {rreg} reg_deref
(* Bitvector duplication *)
val forall Nat 'n. (bit, [:'n:]) -> vector<'n - 1,'n,dec,bit> effect pure duplicate
@@ -64,139 +64,118 @@ val forall Nat 'n, Nat 'm, Order 'ord. vector<'n, 'm, 'ord, bit> -> bit effect p
(* Arithmetic *)
-val forall Nat 'n, Nat 'm.
- (atom<'n>, atom<'m>) -> atom<'n+'m> effect pure add
-
-val forall Nat 'n, Nat 'o, Nat 'p, Order 'ord.
- (vector<'o, 'n, 'ord, bit>, vector<'p, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure add_vec
+val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p.
+ ([|'n:'m|], [|'o:'p|]) -> [|'n + 'o:'m + 'p|] effect pure add
-val forall Nat 'n, Nat 'o, Nat 'p, Nat 'q, Order 'ord.
- (vector<'o, 'n, 'ord, bit>, vector<'p, 'n, 'ord, bit>) -> range<'q, 2**'n> effect pure add_vec_vec_range
+val (nat, nat) -> nat effect pure add_nat
-(* FIXME: the parser is broken for 2**... it's just been hacked to work for this common case *)
-val forall Nat 'n, Nat 'm, Nat 'o, Order 'ord, 'o <= 2** 'm - 1.
- (vector<'n, 'm, 'ord, bit>, atom<'o>) -> vector<'n, 'm, 'ord, bit> effect pure add_vec_range
+val (int, int) -> int effect pure add_int
-val forall Nat 'n, Nat 'o, Nat 'p, Order 'ord.
- (vector<'o, 'n, 'ord, bit>, vector<'p, 'n, 'ord, bit>) -> (vector<'o, 'n, 'ord, bit>, bit, bit) effect pure add_overflow_vec
+val forall Nat 'n, Nat 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure add_vec
-(* but it doesn't parse this
-val forall Nat 'n, Nat 'm, Nat 'o, Order 'ord, 'o <= 2** 'm - 1.
- (vector<'n, 'm, 'ord, bit>, atom<'o>) -> range<'o, 'o+2** 'm> effect pure add_vec_range_range
- *)
+val forall Nat 'n, Nat 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> (vector<'o, 'n, 'ord, bit>, bit, bit) effect pure add_overflow_vec
-val forall Nat 'n, Nat 'm, Nat 'o, Order 'ord, 'o <= 2** 'm - 1.
- (atom<'o>, vector<'n, 'm, 'ord, bit>) -> vector<'n, 'm, 'ord, bit> effect pure add_range_vec
+val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p.
+ ([|'n:'m|], [|'o:'p|]) -> [|'n - 'p:'m - 'o|] effect pure sub
-(* or this
-val forall Nat 'n, Nat 'm, Nat 'o, Order 'ord, 'o <= 2** 'm - 1.
- (atom<'o>, vector<'n, 'm, 'ord, bit>) -> range<'o, 'o+2**'m-1> effect pure add_range_vec_range
-*)
+val (int, int) -> int effect pure sub_int
-val forall Nat 'o, Nat 'p, Order 'ord.
- (vector<'o, 'p, 'ord, bit>, bit) -> vector<'o, 'p, 'ord, bit> effect pure add_vec_bit
+val forall Nat 'n, Nat 'm, Order 'ord.
+ (vector<'n,'m,'ord,bit>, int) -> vector<'n,'m,'ord,bit> effect pure sub_vec_int
-val forall Nat 'o, Nat 'p, Order 'ord.
- (bit, vector<'o, 'p, 'ord, bit>) -> vector<'o, 'p, 'ord, bit> effect pure add_bit_vec
+val forall Nat 'n, Nat 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> vector<'o, 'n, 'ord, bit> effect pure sub_vec
-val forall Nat 'n, Nat 'm. ([:'n:], [:'m:]) -> [:'n - 'm:] effect pure sub_exact
-val forall Nat 'n, Nat 'm, Nat 'o, 'o <= 'm - 'n. ([|'n:'m|], [:'o:]) -> [|'n:'m - 'o|] effect pure sub_range
-val forall Nat 'n, Nat 'm, Order 'ord. (vector<'n,'m,'ord,bit>, int) -> vector<'n,'m,'ord,bit> effect pure sub_bv
+val forall Nat 'n, Nat 'o, Order 'ord.
+ (vector<'o, 'n, 'ord, bit>, vector<'o, 'n, 'ord, bit>) -> (vector<'o, 'n, 'ord, bit>, bit, bit) effect pure sub_underflow_vec
overload (deinfix +) [
- add;
add_vec;
- add_vec_vec_range;
- add_vec_range;
add_overflow_vec;
- add_vec_range_range;
- add_range_vec;
- add_range_vec_range;
- add_vec_bit;
- add_bit_vec;
+ add;
+ add_nat;
+ add_int
]
overload (deinfix -) [
- sub_exact;
- sub_bv;
- sub_range;
+ sub_vec_int;
+ sub_vec;
+ sub_underflow_vec;
+ sub;
+ sub_int
]
-(* Equality *)
-
-(* Sail gives a bunch of overloads for equality, but apparantly also
-gives an equality and inequality for any type 'a, so why bother
-overloading? *)
-
-val forall Type 'a. ('a, 'a) -> bool effect pure eq
-val forall Type 'a. ('a, 'a) -> bool effect pure neq
-
-overload (deinfix ==) [eq]
-overload (deinfix !=) [neq]
-
(* Boolean operators *)
val bool -> bool effect pure bool_not
val (bool, bool) -> bool effect pure bool_or
val (bool, bool) -> bool effect pure bool_and
-overload ~ [bool_not]
-overload (deinfix &) [bool_and]
-overload (deinfix |) [bool_or]
+val forall Num 'n, Num 'm, Order 'ord.
+ vector<'n,'m,'ord,bit> -> vector<'n,'m,'ord,bit> effect pure bitwise_not
-(*
-val forall Nat 'n, Nat 'l, Nat 'm, Nat 'o, Type 'a, 'n >= 'm, 'm >= 'o, 'o >= 'n - 'l + 1. (vector<'n,'l,dec,'a>, [:'m:], [:'o:]) -> vector<'m,'m - 'o - 1,dec,'a> effect pure vector_subrange
+val forall Num 'n, Num 'm, Order 'ord.
+ (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> vector<'n,'m,'ord,bit> effect pure bitwise_and
-val forall Nat 'n, Nat 'l, Order 'ord. ([|0:1|], vector<'n,'l,'ord,bit>) -> bool effect pure vec_eq_01_left
-val forall Nat 'n, Nat 'l, Order 'ord. (vector<'n,'l,'ord,bit>, [|0:1|]) -> bool effect pure vec_eq_01_right
+val forall Num 'n, Num 'm, Order 'ord.
+ (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> vector<'n,'m,'ord,bit> effect pure bitwise_or
-val forall Nat 'n, Nat 'l, Order 'ord. [|0:1|] -> vector<'n,'l,'ord,bit> effect pure cast_01_to_vec
+overload ~ [bool_not; bitwise_not]
+overload (deinfix &) [bool_and; bitwise_and]
+overload (deinfix |) [bool_or; bitwise_or]
-val forall Nat 'n, Nat 'm, Order 'ord. vector<'n,'m,'ord,bit> -> [|0:2**'m - 1|] effect pure cast_vec_to_range
+(* Equality *)
-val forall Type 'a. register<'a> -> 'a effect pure reg_deref
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure eq_vec
-val forall Nat 'n, Nat 'l, Type 'a.
- (vector<'n,'l,dec,'a>, [|'n - 'l + 1:'n|], 'a) -> vector<'n,'l,dec,'a>
- effect pure vector_update_dec
+val forall Type 'a. ('a, 'a) -> bool effect pure eq
-val forall Nat 'n, Nat 'm, Nat 'o, Type 'a, 'o <= 'm.
- vector<'n,'m,dec,'a> -> vector<'o - 1,'o,dec,'a>
- effect pure mask_dec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure neq_vec
-val forall Nat 'n, Nat 'm, Nat 'o, Type 'a, 'o <= 'm.
- vector<'n,'m,inc,'a> -> vector<0,'o,inc,'a>
- effect pure mask_inc
+val forall Type 'a. ('a, 'a) -> bool effect pure neq
-val bool -> bool effect pure not
-val (bool, bool) -> bool effect pure bool_or
-val (bool, bool) -> bool effect pure bool_and
+function forall Num 'n, Num 'm, Order 'ord. bool neq_vec (v1, v2) = bool_not(eq_vec(v1, v2))
-val forall Nat 'n. vector<'n,'n,dec,bit> -> bool effect pure cast_dec_bv_to_bool
+overload (deinfix ==) [eq_vec; eq]
+overload (deinfix !=) [neq_vec; neq]
-val bit -> bool effect pure cast_bit_to_bool
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure gteq_vec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure gt_vec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lteq_vec
+val forall Num 'n, Num 'm, Order 'ord. (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> bool effect pure lt_vec
-val forall Nat 'n, Nat 'm. ([:'n:], [:'m:]) -> [:'n - 'm:] effect pure sub_exact
-val forall Nat 'n, Nat 'm, Nat 'o, 'o <= 'm - 'n. ([|'n:'m|], [:'o:]) -> [|'n:'m - 'o|] effect pure sub_range
-val forall Nat 'n, Nat 'm, Order 'ord. (vector<'n,'m,'ord,bit>, int) -> vector<'n,'m,'ord,bit> effect pure sub_bv
+val (int, int) -> bool effect pure gteq_int
+val (int, int) -> bool effect pure gt_int
+val (int, int) -> bool effect pure lteq_int
+val (int, int) -> bool effect pure lt_int
-val [:1:] -> bit effect pure cast_one_bit
-val forall Nat 'n, Order 'ord. [:1:] -> vector<'n,1,'ord,bit> effect pure cast_one_bv
-val [:0:] -> bit effect pure cast_zero_bit
-val forall Nat 'n, Order 'ord. [:0:] -> vector<'n,1,'ord,bit> effect pure cast_zero_bv
+val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lt_range_atom
+val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure lteq_range_atom
+val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gt_range_atom
+val forall Num 'n, Num 'm, Num 'o. ([|'n:'m|], [:'o:]) -> bool effect pure gteq_range_atom
+val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lt_atom_range
+val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure lteq_atom_range
+val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gt_atom_range
+val forall Num 'n, Num 'm, Num 'o. ([:'n:], [|'m:'o|]) -> bool effect pure gteq_atom_range
-val forall Type 'a. ('a, 'a) -> bool effect pure eq_anything
-val forall Type 'a. ('a, 'a) -> bool effect pure neq_anything
+overload (deinfix >=) [gteq_range_atom; gteq_atom_range; gteq_vec; gteq_int]
+overload (deinfix >) [gt_vec; gt_int]
+overload (deinfix <=) [lteq_range_atom; lteq_atom_range; lteq_vec; lteq_int]
+overload (deinfix <) [lt_vec; lt_int]
-val forall Nat 'n, Order 'ord. vector<'n,1,'ord,bit> -> bool effect pure cast_vec_bool
+val (int, int) -> int effect pure quotient
-val forall Nat 'n, Nat 'm, Nat 'o, Nat 'p, Order 'ord, 'm >= 'n.
- vector<'o,'n,'ord,bit> -> vector<'p,'m,'ord,bit> effect pure EXTS
+overload (deinfix quot) [quotient]
-val forall Nat 'n, Nat 'm, Order 'ord.
- (vector<'n,'m,'ord,bit>, vector<'n,'m,'ord,bit>) -> vector<'n,'m,'ord,bit>
- effect pure bv_add
+val forall Num 'n, Num 'm, Order 'ord, Type 'a. vector<'n,'m,'ord,'a> -> [:'m:] effect pure length
-val forall Nat 'n, Nat 'm, Nat 'o, 'n >= 'm - 1, 'o >= 'm - 1.
- vector<'n,'m,dec,bit> -> vector<'o,'m,dec,bit>
- effect pure ADJUST
-*)
+default Order dec
+
+val forall Nat 'W, 'W >= 1. bit[8 * 'W] -> bit[8 * 'W] effect pure reverse_endianness
+function rec forall Nat 'W, 'W >= 1. bit[8 * 'W] reverse_endianness ((bit[8 * 'W]) value) =
+{
+ ([:8 * 'W:]) width := length(value);
+ if width <= 8 then value
+ else value[7..0] : reverse_endianness(value[(width - 1) .. 8])
+}
diff --git a/src/type_check_new.ml b/src/type_check_new.ml
index 30bb97b0..6a1bc967 100644
--- a/src/type_check_new.ml
+++ b/src/type_check_new.ml
@@ -46,7 +46,7 @@ open Util
open Ast_util
open Big_int
-let debug = ref 0
+let debug = ref 1
let depth = ref 0
let rec indent n = match n with
@@ -957,6 +957,25 @@ let rec nexp_frees (Nexp_aux (nexp, l)) =
| Nexp_exp n -> nexp_frees n
| Nexp_neg n -> nexp_frees n
+let order_frees (Ord_aux (ord_aux, l)) =
+ match ord_aux with
+ | Ord_var kid -> KidSet.singleton kid
+ | _ -> KidSet.empty
+
+let rec typ_frees (Typ_aux (typ_aux, l)) =
+ match typ_aux with
+ | Typ_wild -> KidSet.empty
+ | Typ_id v -> KidSet.empty
+ | Typ_var kid -> KidSet.singleton kid
+ | Typ_tup typs -> List.fold_left KidSet.union KidSet.empty (List.map typ_frees typs)
+ | Typ_app (f, args) -> List.fold_left KidSet.union KidSet.empty (List.map typ_arg_frees args)
+and typ_arg_frees (Typ_arg_aux (typ_arg_aux, l)) =
+ match typ_arg_aux with
+ | Typ_arg_nexp n -> nexp_frees n
+ | Typ_arg_typ typ -> typ_frees typ
+ | Typ_arg_order ord -> order_frees ord
+ | Typ_arg_effect _ -> assert false
+
let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) =
match nexp1, nexp2 with
| Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 = 0
@@ -979,31 +998,39 @@ exception Unification_error of l * string;;
let unify_error l str = raise (Unification_error (l, str))
-let rec unify_nexps l (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) =
- typ_debug ("UNIFYING NEXPS " ^ string_of_nexp nexp1 ^ " AND " ^ string_of_nexp nexp2);
- match nexp_aux1 with
- | Nexp_id v -> unify_error l "Unimplemented Nexp_id in unify nexp"
- | Nexp_var kid -> Some (kid, nexp2)
- | Nexp_constant c1 ->
- begin
- match nexp_aux2 with
- | Nexp_constant c2 -> if c1 = c2 then None else unify_error l "Constants are not the same"
- | _ -> unify_error l "Unification error"
- end
- | Nexp_sum (n1a, n1b) ->
- if KidSet.is_empty (nexp_frees n1b)
- then unify_nexps l n1a (nminus nexp2 n1b)
- else
- if KidSet.is_empty (nexp_frees n1a)
- then unify_nexps l n1b (nminus nexp2 n1a)
- else unify_error l ("Both sides of Nat expression " ^ string_of_nexp nexp1
- ^ " contain free type variables so it cannot be unified with " ^ string_of_nexp nexp2)
- | Nexp_minus (n1a, n1b) ->
- if KidSet.is_empty (nexp_frees n1b)
- then unify_nexps l n1a (nsum nexp2 n1b)
- else unify_error l ("Cannot unify minus Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
-
- | _ -> unify_error l ("Cannot unify Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
+let rec unify_nexps l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) =
+ typ_debug ("UNIFYING NEXPS " ^ string_of_nexp nexp1 ^ " AND " ^ string_of_nexp nexp2 ^ " FOR GOALS " ^ string_of_list ", " string_of_kid (KidSet.elements goals));
+ if KidSet.is_empty (KidSet.inter (nexp_frees nexp1) goals)
+ then
+ begin
+ if prove env (NC_aux (NC_fixed (nexp1, nexp2), Parse_ast.Unknown))
+ then None
+ else unify_error l ("Nexp " ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2 ^ " are not equal")
+ end
+ else
+ match nexp_aux1 with
+ | Nexp_id v -> unify_error l "Unimplemented Nexp_id in unify nexp"
+ | Nexp_var kid when KidSet.mem kid goals -> Some (kid, nexp2)
+ | Nexp_constant c1 ->
+ begin
+ match nexp_aux2 with
+ | Nexp_constant c2 -> if c1 = c2 then None else unify_error l "Constants are not the same"
+ | _ -> unify_error l "Unification error"
+ end
+ | Nexp_sum (n1a, n1b) ->
+ if KidSet.is_empty (nexp_frees n1b)
+ then unify_nexps l env goals n1a (nminus nexp2 n1b)
+ else
+ if KidSet.is_empty (nexp_frees n1a)
+ then unify_nexps l env goals n1b (nminus nexp2 n1a)
+ else unify_error l ("Both sides of Nat expression " ^ string_of_nexp nexp1
+ ^ " contain free type variables so it cannot be unified with " ^ string_of_nexp nexp2)
+ | Nexp_minus (n1a, n1b) ->
+ if KidSet.is_empty (nexp_frees n1b)
+ then unify_nexps l env goals n1a (nsum nexp2 n1b)
+ else unify_error l ("Cannot unify minus Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
+
+ | _ -> unify_error l ("Cannot unify Nat expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
let string_of_uvar = function
| U_nexp n -> string_of_nexp n
@@ -1040,6 +1067,11 @@ let subst_args_unifiers unifiers typ_args =
List.fold_left subst_unifier typ_args (KBindings.bindings unifiers)
let unify l env typ1 typ2 =
+ typ_print ("Unify " ^ string_of_typ typ1 ^ " with " ^ string_of_typ typ2);
+ if not (KidSet.is_empty (KidSet.inter (typ_frees typ1) (typ_frees typ2)))
+ then unify_error l "Can only unify types with disjoint type variables"
+ else ();
+ let goals = typ_frees typ1 in
let merge_unifiers l kid uvar1 uvar2 =
match uvar1, uvar2 with
| Some (U_nexp n1), Some (U_nexp n2) ->
@@ -1058,7 +1090,8 @@ let unify l env typ1 typ2 =
| Typ_id v1, Typ_id v2 ->
if Id.compare v1 v2 = 0 then KBindings.empty
else unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2)
- | Typ_var kid, _ -> KBindings.singleton kid (U_typ typ2)
+ | Typ_var kid, _ when KidSet.mem kid goals -> KBindings.singleton kid (U_typ typ2)
+ | Typ_var kid1, Typ_var kid2 when Kid.compare kid1 kid2 = 0 -> KBindings.empty
| Typ_tup typs1, Typ_tup typs2 ->
begin
try List.fold_left (KBindings.merge (merge_unifiers l)) KBindings.empty (List.map2 (unify_typ l) typs1 typs2) with
@@ -1095,7 +1128,7 @@ let unify l env typ1 typ2 =
match typ_arg_aux1, typ_arg_aux2 with
| Typ_arg_nexp n1, Typ_arg_nexp n2 ->
begin
- match unify_nexps l (nexp_simp n1) (nexp_simp n2) with
+ match unify_nexps l env goals (nexp_simp n1) (nexp_simp n2) with
| Some (kid, unifier) -> KBindings.singleton kid (U_nexp unifier)
| None -> KBindings.empty
end
@@ -1279,7 +1312,7 @@ let irule r env exp =
incr depth;
try
let inferred_exp = r env exp in
- typ_print ("Infer " ^ string_of_exp exp ^ " => " ^ string_of_typ (typ_of inferred_exp));
+ typ_print ("Infer " ^ string_of_exp exp ^ " => " ^ string_of_typ (typ_of inferred_exp));
decr depth;
inferred_exp
with
@@ -1596,11 +1629,12 @@ and bind_lexp env (LEXP_aux (lexp_aux, (l, ())) as lexp) typ =
| _ -> typ_error l ("Unhandled l-expression")
and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
+ typ_print ("Inferring " ^ string_of_exp exp);
let annot_exp_effect exp typ eff = E_aux (exp, (l, Some (env, typ, eff))) in
let annot_exp exp typ = annot_exp_effect exp typ no_effect in
match exp_aux with
| E_nondet exps ->
- annot_exp (E_nondet (List.map (fun exp -> check_exp env exp unit_typ) exps)) unit_typ
+ annot_exp (E_nondet (List.map (fun exp -> crule check_exp env exp unit_typ) exps)) unit_typ
| E_id v ->
begin
match Env.lookup_id v env with
@@ -1618,10 +1652,11 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
end
| E_field (exp, field) ->
begin
- let inferred_exp = infer_exp env exp in
+ let inferred_exp = irule infer_exp env exp in
match Env.expand_synonyms env (typ_of inferred_exp) with
(* Accessing a (bit) field of a register *)
| Typ_aux (Typ_id regtyp, _) when Env.is_regtyp regtyp env ->
+ typ_print "REGTYP";
let base, top, ranges = Env.get_regtyp regtyp env in
let range, _ =
try List.find (fun (_, id) -> Id.compare id field = 0) ranges with
@@ -1640,6 +1675,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
(* Accessing a field of a record *)
| Typ_aux (Typ_id rectyp, _) as typ when Env.is_record rectyp env ->
begin
+ typ_print "RECTYP";
let inferred_acc = infer_funapp' l (Env.no_casts env) field (Env.get_accessor field env) [strip_exp inferred_exp] None in
match inferred_acc with
| E_aux (E_app (field, [inferred_exp]) ,_) -> annot_exp (E_field (inferred_exp, field)) (typ_of inferred_acc)
@@ -1696,8 +1732,8 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
in
annot_exp (E_vector (inferred_item :: checked_items)) vec_typ
| E_assert (test, msg) ->
- let checked_test = check_exp env test bool_typ in
- let checked_msg = check_exp env msg string_typ in
+ let checked_test = crule check_exp env test bool_typ in
+ let checked_msg = crule check_exp env msg string_typ in
annot_exp (E_assert (checked_test, checked_msg)) unit_typ
| _ -> typ_error l ("Cannot infer type of: " ^ string_of_exp exp)
@@ -1752,7 +1788,8 @@ and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
| None -> (quants, typs, ret_typ)
| Some rct ->
begin
- let unifiers = try unify l env ret_typ rct with Unification_error _ -> typ_debug "UERROR"; KBindings.empty in
+ typ_debug ("INSTANTIATE RETURN:" ^ string_of_typ ret_typ);
+ let unifiers = try unify l env ret_typ rct with Unification_error _ -> typ_debug "UERROR"; KBindings.empty in
typ_debug (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers));
let typs' = List.map (subst_unifiers unifiers) typs in
let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in