summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlasdair Armstrong2018-12-07 21:53:29 +0000
committerAlasdair Armstrong2018-12-07 21:53:29 +0000
commit2c25110ad2f5e636239ba65a2154aae79ffa253c (patch)
tree51cdd81ea260dacd0faa1aed476ae95a2f3cc322
parent25ab845211e3df24386a0573b517a01dab879b03 (diff)
Working on better flow typing for ASL
On a new branch because it's completely broken everything for now
-rw-r--r--editors/sail2-mode.el2
-rw-r--r--language/sail.ott30
-rw-r--r--src/ast_util.ml168
-rw-r--r--src/ast_util.mli42
-rw-r--r--src/constraint.ml229
-rw-r--r--src/constraint.mli34
-rw-r--r--src/initial_check.ml22
-rw-r--r--src/lexer.mll1
-rw-r--r--src/ocaml_backend.ml9
-rw-r--r--src/parse_ast.ml3
-rw-r--r--src/parser.mly12
-rw-r--r--src/pretty_print_coq.ml3
-rw-r--r--src/pretty_print_lem.ml3
-rw-r--r--src/pretty_print_sail.ml10
-rw-r--r--src/rewriter.ml2
-rw-r--r--src/rewrites.ml7
-rw-r--r--src/spec_analysis.ml4
-rw-r--r--src/state.ml10
-rw-r--r--src/type_check.ml799
-rw-r--r--src/type_check.mli19
-rw-r--r--src/type_error.ml7
21 files changed, 532 insertions, 884 deletions
diff --git a/editors/sail2-mode.el b/editors/sail2-mode.el
index e0081fb6..b7998357 100644
--- a/editors/sail2-mode.el
+++ b/editors/sail2-mode.el
@@ -12,7 +12,7 @@
"mapping" "where" "with"))
(defconst sail2-kinds
- '("Int" "Type" "Order" "inc" "dec"
+ '("Int" "Type" "Order" "Bool" "inc" "dec"
"barr" "depend" "rreg" "wreg" "rmem" "rmemt" "wmv" "wmvt" "eamem" "wmem"
"exmem" "undef" "unspec" "nondet" "escape" "configuration"))
diff --git a/language/sail.ott b/language/sail.ott
index a0b02a1c..b35e64d3 100644
--- a/language/sail.ott
+++ b/language/sail.ott
@@ -175,6 +175,7 @@ kind :: 'K_' ::=
| Type :: :: type {{ com kind of types }}
| Int :: :: int {{ com kind of natural number size expressions }}
| Order :: :: order {{ com kind of vector order specifications }}
+ | Bool :: :: bool {{ com kind of constraints }}
nexp :: 'Nexp_' ::=
{{ com numeric expression, of kind Int }}
@@ -273,22 +274,25 @@ typ :: 'Typ_' ::=
typ_arg :: 'Typ_arg_' ::=
{{ com type constructor arguments of all kinds }}
{{ aux _ l }}
- | nexp :: :: nexp
- | typ :: :: typ
- | order :: :: order
+ | nexp :: :: nexp
+ | typ :: :: typ
+ | order :: :: order
+ | n_constraint :: :: bool
n_constraint :: 'NC_' ::=
{{ com constraint over kind Int }}
{{ aux _ l }}
- | nexp = nexp' :: :: equal
- | nexp >= nexp' :: :: bounded_ge
- | nexp '<=' nexp' :: :: bounded_le
- | nexp != nexp' :: :: not_equal
- | kid 'IN' { num1 , ... , numn } :: :: set
- | n_constraint \/ n_constraint' :: :: or
- | n_constraint /\ n_constraint' :: :: and
- | true :: :: true
- | false :: :: false
+ | nexp = nexp' :: :: equal
+ | nexp >= nexp' :: :: bounded_ge
+ | nexp '<=' nexp' :: :: bounded_le
+ | nexp != nexp' :: :: not_equal
+ | kid 'IN' { num1 , ... , numn } :: :: set
+ | n_constraint \/ n_constraint' :: :: or
+ | n_constraint /\ n_constraint' :: :: and
+ | id ( typ_arg0 , ... , typ_argn ) :: :: app
+ | kid :: :: var
+ | true :: :: true
+ | false :: :: false
% Note only id on the left and constants on the right in a
% finite-set-bound, as we don't think we need anything more
@@ -366,7 +370,7 @@ type_def {{ ocaml 'a type_def }} {{ lem type_def 'a }} :: 'TD_' ::=
type_def_aux :: 'TD_' ::=
{{ com type definition body }}
- | typedef id name_scm_opt = typschm :: :: abbrev
+ | type id typquant = typ_arg :: :: abbrev
{{ com type abbreviation }} {{ texlong }}
| typedef id name_scm_opt = const struct typquant { typ1 id1 ; ... ; typn idn semi_opt } :: :: record
{{ com struct type definition }} {{ texlong }}
diff --git a/src/ast_util.ml b/src/ast_util.ml
index e9153f7a..788008d1 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -370,23 +370,16 @@ let nc_lt n1 n2 = nc_lteq (nsum n1 (nint 1)) n2
let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nint 1))
let nc_and nc1 nc2 = mk_nc (NC_and (nc1, nc2))
let nc_or nc1 nc2 = mk_nc (NC_or (nc1, nc2))
+let nc_var kid = mk_nc (NC_var kid)
let nc_true = mk_nc NC_true
let nc_false = mk_nc NC_false
-let rec nc_negate (NC_aux (nc, l)) =
- match nc with
- | NC_bounded_ge (n1, n2) -> nc_lt n1 n2
- | NC_bounded_le (n1, n2) -> nc_gt n1 n2
- | NC_equal (n1, n2) -> nc_neq n1 n2
- | NC_not_equal (n1, n2) -> nc_eq n1 n2
- | NC_and (n1, n2) -> mk_nc (NC_or (nc_negate n1, nc_negate n2))
- | NC_or (n1, n2) -> mk_nc (NC_and (nc_negate n1, nc_negate n2))
- | NC_false -> mk_nc NC_true
- | NC_true -> mk_nc NC_false
- | NC_set (kid, []) -> nc_false
- | NC_set (kid, [int]) -> nc_neq (nvar kid) (nconstant int)
- | NC_set (kid, int :: ints) ->
- mk_nc (NC_and (nc_neq (nvar kid) (nconstant int), nc_negate (mk_nc (NC_set (kid, ints)))))
+let arg_nexp ?loc:(l=Parse_ast.Unknown) n = Typ_arg_aux (Typ_arg_nexp n, l)
+let arg_order ?loc:(l=Parse_ast.Unknown) ord = Typ_arg_aux (Typ_arg_order ord, l)
+let arg_typ ?loc:(l=Parse_ast.Unknown) typ = Typ_arg_aux (Typ_arg_typ typ, l)
+let arg_bool ?loc:(l=Parse_ast.Unknown) nc = Typ_arg_aux (Typ_arg_bool nc, l)
+
+let nc_not nc = mk_nc (NC_app (mk_id "not", [arg_bool nc]))
let mk_typschm typq typ = TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown)
@@ -437,6 +430,7 @@ let unaux_nexp (Nexp_aux (nexp, _)) = nexp
let unaux_order (Ord_aux (ord, _)) = ord
let unaux_typ (Typ_aux (typ, _)) = typ
let unaux_kind (K_aux (k, _)) = k
+let unaux_constraint (NC_aux (nc, _)) = nc
let rec map_exp_annot f (E_aux (exp, annot)) = E_aux (map_exp_annot_aux f exp, f annot)
and map_exp_annot_aux f = function
@@ -628,6 +622,7 @@ let string_of_kind_aux = function
| K_type -> "Type"
| K_int -> "Int"
| K_order -> "Order"
+ | K_bool -> "Bool"
let string_of_kind (K_aux (k, _)) = string_of_kind_aux k
@@ -680,6 +675,7 @@ and string_of_typ_arg_aux = function
| Typ_arg_nexp n -> string_of_nexp n
| Typ_arg_typ typ -> string_of_typ typ
| Typ_arg_order o -> string_of_order o
+ | Typ_arg_bool nc -> string_of_n_constraint nc
and string_of_n_constraint = function
| NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2
| NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2
@@ -691,6 +687,8 @@ and string_of_n_constraint = function
"(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")"
| NC_aux (NC_set (kid, ns), _) ->
string_of_kid kid ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}"
+ | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ string_of_list ", " string_of_typ_arg args ^ ")"
+ | NC_aux (NC_var v, _) -> string_of_kid v
| NC_aux (NC_true, _) -> "true"
| NC_aux (NC_false, _) -> "false"
@@ -1142,10 +1140,13 @@ let rec tyvars_of_constraint (NC_aux (nc, _)) =
| NC_or (nc1, nc2)
| NC_and (nc1, nc2) ->
KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2)
+ | NC_app (id, args) ->
+ List.fold_left (fun s t -> KidSet.union s (tyvars_of_typ_arg t)) KidSet.empty args
+ | NC_var kid -> KidSet.singleton kid
| NC_true
| NC_false -> KidSet.empty
-let rec tyvars_of_typ (Typ_aux (t,_)) =
+and tyvars_of_typ (Typ_aux (t,_)) =
match t with
| Typ_internal_unknown -> KidSet.empty
| Typ_id _ -> KidSet.empty
@@ -1166,6 +1167,7 @@ and tyvars_of_typ_arg (Typ_arg_aux (ta,_)) =
| Typ_arg_nexp nexp -> tyvars_of_nexp nexp
| Typ_arg_typ typ -> tyvars_of_typ typ
| Typ_arg_order _ -> KidSet.empty
+ | Typ_arg_bool nc -> tyvars_of_constraint nc
let tyvars_of_quant_item (QI_aux (qi, _)) = match qi with
| QI_id (KOpt_aux ((KOpt_none kid | KOpt_kind (_, kid)), _)) ->
@@ -1547,10 +1549,26 @@ let unique l =
(* 1. Substitutions *)
(**************************************************************************)
+let order_subst_aux sv subst = function
+ | Ord_var kid ->
+ begin match subst with
+ | Typ_arg_aux (Typ_arg_order ord, _) when Kid.compare kid sv = 0 ->
+ unaux_order ord
+ | _ -> Ord_var kid
+ end
+ | Ord_inc -> Ord_inc
+ | Ord_dec -> Ord_dec
+
+let order_subst sv subst (Ord_aux (ord, l)) = Ord_aux (order_subst_aux sv subst ord, l)
+
let rec nexp_subst sv subst (Nexp_aux (nexp, l)) = Nexp_aux (nexp_subst_aux sv subst nexp, l)
and nexp_subst_aux sv subst = function
- | Nexp_id v -> Nexp_id v
- | Nexp_var kid -> if Kid.compare kid sv = 0 then subst else Nexp_var kid
+ | Nexp_var kid ->
+ begin match subst with
+ | Typ_arg_aux (Typ_arg_nexp n, _) when Kid.compare kid sv = 0 -> unaux_nexp n
+ | _ -> Nexp_var kid
+ end
+ | Nexp_id id -> Nexp_id id
| Nexp_constant c -> Nexp_constant c
| Nexp_times (nexp1, nexp2) -> Nexp_times (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2)
| Nexp_sum (nexp1, nexp2) -> Nexp_sum (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2)
@@ -1564,100 +1582,68 @@ let rec nexp_set_to_or l subst = function
| [int] -> NC_equal (subst, nconstant int)
| (int :: ints) -> NC_or (mk_nc (NC_equal (subst, nconstant int)), mk_nc (nexp_set_to_or l subst ints))
-let rec nc_subst_nexp sv subst (NC_aux (nc, l)) = NC_aux (nc_subst_nexp_aux l sv subst nc, l)
-and nc_subst_nexp_aux l sv subst = function
+let rec constraint_subst sv subst (NC_aux (nc, l)) = NC_aux (constraint_subst_aux l sv subst nc, l)
+and constraint_subst_aux l sv subst = function
| NC_equal (n1, n2) -> NC_equal (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_bounded_ge (n1, n2) -> NC_bounded_ge (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_bounded_le (n1, n2) -> NC_bounded_le (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_not_equal (n1, n2) -> NC_not_equal (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_set (kid, ints) as set_nc ->
- if Kid.compare kid sv = 0
- then nexp_set_to_or l (mk_nexp subst) ints
- else set_nc
- | NC_or (nc1, nc2) -> NC_or (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2)
- | NC_and (nc1, nc2) -> NC_and (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2)
+ begin match subst with
+ | Typ_arg_aux (Typ_arg_nexp n, _) when Kid.compare kid sv = 0 ->
+ nexp_set_to_or l n ints
+ | _ -> set_nc
+ end
+ | NC_or (nc1, nc2) -> NC_or (constraint_subst sv subst nc1, constraint_subst sv subst nc2)
+ | NC_and (nc1, nc2) -> NC_and (constraint_subst sv subst nc1, constraint_subst sv subst nc2)
+ | NC_app (id, args) -> NC_app (id, List.map (typ_arg_subst sv subst) args)
+ | NC_var kid ->
+ begin match subst with
+ | Typ_arg_aux (Typ_arg_bool nc, _) when Kid.compare kid sv = 0 ->
+ unaux_constraint nc
+ | _ -> NC_var kid
+ end
| NC_false -> NC_false
| NC_true -> NC_true
-let rec typ_subst_nexp sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_nexp_aux sv subst typ, l)
-and typ_subst_nexp_aux sv subst = function
+and typ_subst sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_aux sv subst typ, l)
+and typ_subst_aux sv subst = function
| Typ_internal_unknown -> Typ_internal_unknown
| Typ_id v -> Typ_id v
- | Typ_var kid -> Typ_var kid
- | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_nexp sv subst) arg_typs, typ_subst_nexp sv subst ret_typ, effs)
- | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_nexp sv subst typ1, typ_subst_nexp sv subst typ2)
- | Typ_tup typs -> Typ_tup (List.map (typ_subst_nexp sv subst) typs)
- | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_nexp sv subst) args)
+ | Typ_var kid ->
+ begin match subst with
+ | Typ_arg_aux (Typ_arg_typ typ, _) when Kid.compare kid sv = 0 ->
+ unaux_typ typ
+ | _ -> Typ_var kid
+ end
+ | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst sv subst) arg_typs, typ_subst sv subst ret_typ, effs)
+ | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst sv subst typ1, typ_subst sv subst typ2)
+ | Typ_tup typs -> Typ_tup (List.map (typ_subst sv subst) typs)
+ | Typ_app (f, args) -> Typ_app (f, List.map (typ_arg_subst sv subst) args)
| Typ_exist (kids, nc, typ) when KidSet.mem sv (KidSet.of_list kids) -> Typ_exist (kids, nc, typ)
- | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc_subst_nexp sv subst nc, typ_subst_nexp sv subst typ)
-and typ_subst_arg_nexp sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_nexp_aux sv subst arg, l)
-and typ_subst_arg_nexp_aux sv subst = function
- | Typ_arg_nexp nexp -> Typ_arg_nexp (nexp_subst sv subst nexp)
- | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_nexp sv subst typ)
- | Typ_arg_order ord -> Typ_arg_order ord
+ | Typ_exist (kids, nc, typ) -> Typ_exist (kids, constraint_subst sv subst nc, typ_subst sv subst typ)
-let rec typ_subst_typ sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_typ_aux sv subst typ, l)
-and typ_subst_typ_aux sv subst = function
- | Typ_internal_unknown -> Typ_internal_unknown
- | Typ_id v -> Typ_id v
- | Typ_var kid -> if Kid.compare kid sv = 0 then subst else Typ_var kid
- | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_typ sv subst) arg_typs, typ_subst_typ sv subst ret_typ, effs)
- | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_typ sv subst typ1, typ_subst_typ sv subst typ2)
- | Typ_tup typs -> Typ_tup (List.map (typ_subst_typ sv subst) typs)
- | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_typ sv subst) args)
- | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, typ_subst_typ sv subst typ)
-and typ_subst_arg_typ sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_typ_aux sv subst arg, l)
-and typ_subst_arg_typ_aux sv subst = function
- | Typ_arg_nexp nexp -> Typ_arg_nexp nexp
- | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_typ sv subst typ)
- | Typ_arg_order ord -> Typ_arg_order ord
-
-let order_subst_aux sv subst = function
- | Ord_var kid -> if Kid.compare kid sv = 0 then subst else Ord_var kid
- | Ord_inc -> Ord_inc
- | Ord_dec -> Ord_dec
-
-let order_subst sv subst (Ord_aux (ord, l)) = Ord_aux (order_subst_aux sv subst ord, l)
-
-let rec typ_subst_order sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_order_aux sv subst typ, l)
-and typ_subst_order_aux sv subst = function
- | Typ_internal_unknown -> Typ_internal_unknown
- | Typ_id v -> Typ_id v
- | Typ_var kid -> Typ_var kid
- | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_order sv subst) arg_typs, typ_subst_order sv subst ret_typ, effs)
- | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_order sv subst typ1, typ_subst_order sv subst typ2)
- | Typ_tup typs -> Typ_tup (List.map (typ_subst_order sv subst) typs)
- | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_order sv subst) args)
- | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, typ_subst_order sv subst typ)
-and typ_subst_arg_order sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_order_aux sv subst arg, l)
-and typ_subst_arg_order_aux sv subst = function
- | Typ_arg_nexp nexp -> Typ_arg_nexp nexp
- | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_order sv subst typ)
+and typ_arg_subst sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_arg_subst_aux sv subst arg, l)
+and typ_arg_subst_aux sv subst = function
+ | Typ_arg_nexp nexp -> Typ_arg_nexp (nexp_subst sv subst nexp)
+ | Typ_arg_typ typ -> Typ_arg_typ (typ_subst sv subst typ)
| Typ_arg_order ord -> Typ_arg_order (order_subst sv subst ord)
+ | Typ_arg_bool nc -> Typ_arg_bool (constraint_subst sv subst nc)
-let rec typ_subst_kid sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_kid_aux sv subst typ, l)
-and typ_subst_kid_aux sv subst = function
- | Typ_internal_unknown -> Typ_internal_unknown
- | Typ_id v -> Typ_id v
- | Typ_var kid -> if Kid.compare kid sv = 0 then Typ_var subst else Typ_var kid
- | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_kid sv subst) arg_typs, typ_subst_kid sv subst ret_typ, effs)
- | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_kid sv subst typ1, typ_subst_kid sv subst typ2)
- | Typ_tup typs -> Typ_tup (List.map (typ_subst_kid sv subst) typs)
- | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_kid sv subst) args)
- | Typ_exist (kids, nc, typ) when KidSet.mem sv (KidSet.of_list kids) -> Typ_exist (kids, nc, typ)
- | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc_subst_nexp sv (Nexp_var subst) nc, typ_subst_kid sv subst typ)
-and typ_subst_arg_kid sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_kid_aux sv subst arg, l)
-and typ_subst_arg_kid_aux sv subst = function
- | Typ_arg_nexp nexp -> Typ_arg_nexp (nexp_subst sv (Nexp_var subst) nexp)
- | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_kid sv subst typ)
- | Typ_arg_order ord -> Typ_arg_order (order_subst sv (Ord_var subst) ord)
+let subst_kid subst sv v x =
+ x
+ |> subst sv (mk_typ_arg (Typ_arg_bool (nc_var v)))
+ |> subst sv (mk_typ_arg (Typ_arg_nexp (nvar v)))
+ |> subst sv (mk_typ_arg (Typ_arg_order (Ord_aux (Ord_var v, Parse_ast.Unknown))))
+ |> subst sv (mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var v))))
let quant_item_subst_kid_aux sv subst = function
| QI_id (KOpt_aux (KOpt_none kid, l)) as qid ->
if Kid.compare kid sv = 0 then QI_id (KOpt_aux (KOpt_none subst, l)) else qid
| QI_id (KOpt_aux (KOpt_kind (k, kid), l)) as qid ->
if Kid.compare kid sv = 0 then QI_id (KOpt_aux (KOpt_kind (k, subst), l)) else qid
- | QI_const nc -> QI_const (nc_subst_nexp sv (Nexp_var subst) nc)
+ | QI_const nc ->
+ QI_const (subst_kid constraint_subst sv subst nc)
let quant_item_subst_kid sv subst (QI_aux (quant, l)) = QI_aux (quant_item_subst_kid_aux sv subst quant, l)
diff --git a/src/ast_util.mli b/src/ast_util.mli
index 19fc017d..73ab4a01 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -97,6 +97,7 @@ val unaux_nexp : nexp -> nexp_aux
val unaux_order : order -> order_aux
val unaux_typ : typ -> typ_aux
val unaux_kind : kind -> kind_aux
+val unaux_constraint : n_constraint -> n_constraint_aux
val untyp_pat : 'a pat -> 'a pat * typ option
val uncast_exp : 'a exp -> 'a exp * typ option
@@ -154,7 +155,7 @@ val ntimes : nexp -> nexp -> nexp
val npow2 : nexp -> nexp
val nvar : kid -> nexp
val napp : id -> nexp list -> nexp
-val nid : id -> nexp (* NOTE: Nexp_id's don't do anything currently *)
+val nid : id -> nexp
(* Numeric constraint builders *)
val nc_eq : nexp -> nexp -> n_constraint
@@ -165,15 +166,16 @@ val nc_lt : nexp -> nexp -> n_constraint
val nc_gt : nexp -> nexp -> n_constraint
val nc_and : n_constraint -> n_constraint -> n_constraint
val nc_or : n_constraint -> n_constraint -> n_constraint
+val nc_not : n_constraint -> n_constraint
val nc_true : n_constraint
val nc_false : n_constraint
val nc_set : kid -> Big_int.num list -> n_constraint
val nc_int_set : kid -> int list -> n_constraint
-(* Negate a n_constraint. Note that there's no NC_not constructor, so
- this flips all the inequalites a the n_constraint leaves and uses
- de-morgans to switch and to or and vice versa. *)
-val nc_negate : n_constraint -> n_constraint
+val arg_nexp : ?loc:l -> nexp -> typ_arg
+val arg_order : ?loc:l -> order -> typ_arg
+val arg_typ : ?loc:l -> typ -> typ_arg
+val arg_bool : ?loc:l -> n_constraint -> typ_arg
(* Functions for working with type quantifiers *)
val quant_add : quant_item -> typquant -> typquant
@@ -203,7 +205,6 @@ val def_loc : 'a def -> Parse_ast.l
(* For debugging and error messages only: Not guaranteed to produce
parseable SAIL, or even print all language constructs! *)
-(* TODO: replace with existing pretty-printer *)
val string_of_id : id -> string
val string_of_kid : kid -> string
val string_of_base_effect_aux : base_effect_aux -> string
@@ -382,27 +383,16 @@ val unique : l -> l
(** Substitutions *)
-(* The function X_subst_Y substitutes a Y into something of type X, if
- X = Y then the function is just X_subst. Substitutions are always
- unwrapped from their aux constructors. *)
-val nexp_subst : kid -> nexp_aux -> nexp -> nexp
-val nc_subst_nexp : kid -> nexp_aux -> n_constraint -> n_constraint
-val order_subst : kid -> order_aux -> order -> order
+(* The function X_subst substitutes a type argument into something of
+ type X. The type of the type argument determines which kind of type
+ variables willb e replaced *)
+val nexp_subst : kid -> typ_arg -> nexp -> nexp
+val constraint_subst : kid -> typ_arg -> n_constraint -> n_constraint
+val order_subst : kid -> typ_arg -> order -> order
+val typ_subst : kid -> typ_arg -> typ -> typ
+val typ_arg_subst : kid -> typ_arg -> typ_arg -> typ_arg
-(* kid must be Int-kinded *)
-val typ_subst_nexp : kid -> nexp_aux -> typ -> typ
-val typ_subst_arg_nexp : kid -> nexp_aux -> typ_arg -> typ_arg
-
-(* kid must be Type-kinded *)
-val typ_subst_typ : kid -> typ_aux -> typ -> typ
-val typ_subst_arg_typ : kid -> typ_aux -> typ_arg -> typ_arg
-
-(* kid must be Order-kinded *)
-val typ_subst_order : kid -> order_aux -> typ -> typ
-val typ_subst_arg_order : kid -> order_aux -> typ_arg -> typ_arg
-
-val typ_subst_kid : kid -> kid -> typ -> typ
-val typ_subst_arg_kid : kid -> kid -> typ_arg -> typ_arg
+val subst_kid : (kid -> typ_arg -> 'a -> 'a) -> kid -> kid -> 'a -> 'a
val quant_item_subst_kid : kid -> kid -> quant_item -> quant_item
val typquant_subst_kid : kid -> kid -> typquant -> typquant
diff --git a/src/constraint.ml b/src/constraint.ml
index cf861423..460e8c76 100644
--- a/src/constraint.ml
+++ b/src/constraint.ml
@@ -49,86 +49,10 @@
(**************************************************************************)
module Big_int = Nat_big_num
+open Ast
+open Ast_util
open Util
-(* ===== Integer Constraints ===== *)
-
-type nexp_op = string
-
-type nexp =
- | NFun of (nexp_op * nexp list)
- | N2n of nexp
- | NConstant of Big_int.num
- | NVar of int
-
-let big_int_op : nexp_op -> (Big_int.num -> Big_int.num -> Big_int.num) option = function
- | "+" -> Some Big_int.add
- | "-" -> Some Big_int.sub
- | "*" -> Some Big_int.mul
- | _ -> None
-
-let rec arith constr =
- let constr' = match constr with
- | NFun (op, [x; y]) -> NFun (op, [arith x; arith y])
- | N2n c -> N2n (arith c)
- | c -> c
- in
- match constr' with
- | NFun (op, [NConstant x; NConstant y]) as c ->
- begin
- match big_int_op op with
- | Some op -> NConstant (op x y)
- | None -> c
- end
- | N2n (NConstant x) -> NConstant (Big_int.pow_int_positive 2 (Big_int.to_int x))
- | c -> c
-
-(* ===== Boolean Constraints ===== *)
-
-type constraint_bool_op = And | Or
-
-type constraint_compare_op = Gt | Lt | GtEq | LtEq | Eq | NEq
-
-let negate_comparison = function
- | Gt -> LtEq
- | Lt -> GtEq
- | GtEq -> Lt
- | LtEq -> Gt
- | Eq -> NEq
- | NEq -> Eq
-
-type 'a constraint_bool =
- | BFun of (constraint_bool_op * 'a constraint_bool * 'a constraint_bool)
- | Not of 'a constraint_bool
- | CFun of (constraint_compare_op * 'a * 'a)
- | Forall of (int list * 'a constraint_bool)
- | Boolean of bool
-
-let rec pairs (xs : 'a list) (ys : 'a list) : ('a * 'b) list =
- match xs with
- | [] -> []
- | (x :: xs) -> List.map (fun y -> (x, y)) ys @ pairs xs ys
-
-(* Get a set of variables from a constraint *)
-module IntSet = Set.Make(
- struct
- let compare = Pervasives.compare
- type t = int
- end)
-
-let rec nexp_vars : nexp -> IntSet.t = function
- | NConstant _ -> IntSet.empty
- | NVar v -> IntSet.singleton v
- | NFun (_, xs) -> List.fold_left IntSet.union IntSet.empty (List.map nexp_vars xs)
- | N2n x -> nexp_vars x
-
-let rec constraint_vars : nexp constraint_bool -> IntSet.t = function
- | BFun (_, x, y) -> IntSet.union (constraint_vars x) (constraint_vars y)
- | Not x -> constraint_vars x
- | CFun (_, x, y) -> IntSet.union (nexp_vars x) (nexp_vars y)
- | Forall (vars, x) -> IntSet.diff (constraint_vars x) (IntSet.of_list vars)
- | Boolean _ -> IntSet.empty
-
(* SMTLIB v2.0 format is based on S-expressions so we have a
lightweight representation of those here. *)
type sexpr = List of (sexpr list) | Atom of string
@@ -139,47 +63,67 @@ let rec pp_sexpr : sexpr -> string = function
| List xs -> "(" ^ string_of_list " " pp_sexpr xs ^ ")"
| Atom x -> x
-let var_decs constr =
- constraint_vars constr
- |> IntSet.elements
- |> List.map (fun var -> sfun "declare-const" [Atom ("v" ^ string_of_int var); Atom "Int"])
- |> string_of_list "\n" pp_sexpr
+let zencode_kid kid = Util.zencode_string (string_of_kid kid)
-let cop_sexpr op x y =
- match op with
- | Gt -> sfun ">" [x; y]
- | Lt -> sfun "<" [x; y]
- | GtEq -> sfun ">=" [x; y]
- | LtEq -> sfun "<=" [x; y]
- | Eq -> sfun "=" [x; y]
- | NEq -> sfun "not" [sfun "=" [x; y]]
+(** Each non-Type/Order kind in Sail mapes to a type in the SMT solver *)
+let smt_type l = function
+ | K_int -> Atom "Int"
+ | K_bool -> Atom "Bool"
+ | _ -> raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kinded variable to SMT solver")
-let rec sexpr_of_nexp = function
- | NFun (op, xs) -> sfun op (List.map sexpr_of_nexp xs)
- | N2n x -> sfun "^" [Atom "2"; sexpr_of_nexp x]
- | NConstant c -> Atom (Big_int.to_string c) (* CHECK: do we do negative constants right? *)
- | NVar var -> Atom ("v" ^ string_of_int var)
+let smt_var v = Atom ("v" ^ zencode_kid v)
-let rec sexpr_of_constraint = function
- | BFun (And, x, y) -> sfun "and" [sexpr_of_constraint x; sexpr_of_constraint y]
- | BFun (Or, x, y) -> sfun "or" [sexpr_of_constraint x; sexpr_of_constraint y]
- | Not x -> sfun "not" [sexpr_of_constraint x]
- | CFun (op, x, y) -> cop_sexpr op (sexpr_of_nexp (arith x)) (sexpr_of_nexp (arith y))
- | Forall (vars, x) ->
- sfun "forall" [List (List.map (fun v -> List [Atom ("v" ^ string_of_int v); Atom "Int"]) vars); sexpr_of_constraint x]
- | Boolean true -> Atom "true"
- | Boolean false -> Atom "false"
+(** var_decs outputs the list of variables to be used by the SMT
+ solver in SMTLIB v2.0 format. It takes a kind_aux KBindings, as
+ returned by Type_check.get_typ_vars *)
+let var_decs l (vars : kind_aux KBindings.t) : string = vars
+ |> KBindings.bindings
+ |> List.map (fun (v, k) -> sfun "declare-const" [smt_var v; smt_type l k])
+ |> string_of_list "\n" pp_sexpr
-let smtlib_of_constraints ?get_model:(get_model=false) constr : string =
+let rec smt_nexp (Nexp_aux (aux, l) : nexp) : sexpr =
+ match aux with
+ | Nexp_id id -> Atom (Util.zencode_string (string_of_id id))
+ | Nexp_var v -> smt_var v
+ | Nexp_constant c -> Atom (Big_int.to_string c)
+ | Nexp_app (id, nexps) -> sfun (string_of_id id) (List.map smt_nexp nexps)
+ | Nexp_times (nexp1, nexp2) -> sfun "*" [smt_nexp nexp1; smt_nexp nexp2]
+ | Nexp_sum (nexp1, nexp2) -> sfun "+" [smt_nexp nexp1; smt_nexp nexp2]
+ | Nexp_minus (nexp1, nexp2) -> sfun "-" [smt_nexp nexp1; smt_nexp nexp2]
+ | Nexp_exp nexp -> sfun "^" [Atom "2"; smt_nexp nexp]
+ | Nexp_neg nexp -> sfun "-" [smt_nexp nexp]
+
+let rec smt_constraint (NC_aux (aux, l) : n_constraint) : sexpr =
+ match aux with
+ | NC_equal (nexp1, nexp2) -> sfun "=" [smt_nexp nexp1; smt_nexp nexp2]
+ | NC_bounded_le (nexp1, nexp2) -> sfun "<=" [smt_nexp nexp1; smt_nexp nexp2]
+ | NC_bounded_ge (nexp1, nexp2) -> sfun ">=" [smt_nexp nexp1; smt_nexp nexp2]
+ | NC_not_equal (nexp1, nexp2) -> sfun "not" [sfun "=" [smt_nexp nexp1; smt_nexp nexp2]]
+ | NC_set (v, ints) ->
+ sfun "or" (List.map (fun i -> sfun "=" [smt_var v; Atom (Big_int.to_string i)]) ints)
+ | NC_or (nc1, nc2) -> sfun "or" [smt_constraint nc1; smt_constraint nc2]
+ | NC_and (nc1, nc2) -> sfun "and" [smt_constraint nc1; smt_constraint nc2]
+ | NC_app (id, args) ->
+ sfun (string_of_id id) (List.map smt_typ_arg args)
+ | NC_true -> Atom "true"
+ | NC_false -> Atom "false"
+ | NC_var v -> smt_var v
+
+and smt_typ_arg (Typ_arg_aux (aux, l) : typ_arg) : sexpr =
+ match aux with
+ | Typ_arg_nexp nexp -> smt_nexp nexp
+ | Typ_arg_bool nc -> smt_constraint nc
+ | _ ->
+ raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kind to SMT function")
+
+let smtlib_of_constraints ?get_model:(get_model=false) l vars constr : string =
"(push)\n"
- ^ var_decs constr ^ "\n"
- ^ pp_sexpr (sfun "define-fun" [Atom "constraint"; List []; Atom "Bool"; sexpr_of_constraint constr])
+ ^ var_decs l vars ^ "\n"
+ ^ pp_sexpr (sfun "define-fun" [Atom "constraint"; List []; Atom "Bool"; smt_constraint constr])
^ "\n(assert constraint)\n(check-sat)"
^ (if get_model then "\n(get-model)" else "")
^ "\n(pop)"
-type t = nexp constraint_bool
-
type smt_result = Unknown | Sat | Unsat
module DigestMap = Map.Make(Digest)
@@ -219,9 +163,9 @@ let save_digests () =
DigestMap.iter output !known_problems;
close_out out_chan
-let call_z3' constraints : smt_result =
+let call_z3' l vars constraints : smt_result =
let problems = [constraints] in
- let z3_file = smtlib_of_constraints constraints in
+ let z3_file = smtlib_of_constraints l vars constraints in
(* prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" z3_file); *)
@@ -261,15 +205,15 @@ let call_z3' constraints : smt_result =
else (known_problems := DigestMap.add digest Unknown !known_problems; Unknown)
end
-let call_z3 constraints =
+let call_z3 l vars constraints =
let t = Profile.start_z3 () in
- let result = call_z3' constraints in
+ let result = call_z3' l vars constraints in
Profile.finish_z3 t;
result
-let rec solve_z3 constraints var =
+let rec solve_z3 l vars constraints var =
let problems = [constraints] in
- let z3_file = smtlib_of_constraints ~get_model:true constraints in
+ let z3_file = smtlib_of_constraints ~get_model:true l vars constraints in
(* prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" z3_file); *)
@@ -289,62 +233,13 @@ let rec solve_z3 constraints var =
let z3_output = String.concat " " (input_all z3_chan) in
let _ = Unix.close_process_in z3_chan in
Sys.remove input_file;
- let regexp = {|(define-fun v|} ^ string_of_int var ^ {| () Int[ ]+\([0-9]+\))|} in
+ let regexp = {|(define-fun v|} ^ Util.zencode_string (string_of_kid var) ^ {| () Int[ ]+\([0-9]+\))|} in
try
let _ = Str.search_forward (Str.regexp regexp) z3_output 0 in
let result = Big_int.of_string (Str.matched_group 1 z3_output) in
- begin match call_z3 (BFun (And, constraints, CFun (NEq, NConstant result, NVar var))) with
+ begin match call_z3 l vars (nc_and constraints (nc_neq (nconstant result) (nvar var))) with
| Unsat -> Some result
| _ -> None
end
with
Not_found -> None
-
-let string_of constr = smtlib_of_constraints constr
-
-(* ===== Abstract API for building constraints ===== *)
-
-(* These functions are exported from constraint.mli, and ensure that
- the internal representation of constraints remains opaque. *)
-
-let implies (x : t) (y : t) : t =
- BFun (Or, Not x, y)
-
-let conj (x : t) (y : t) : t =
- BFun (And, x, y)
-
-let disj (x : t) (y : t) : t =
- BFun (Or, x, y)
-
-let forall (vars : int list) (x : t) : t =
- if vars = [] then x else Forall (vars, x)
-
-let negate (x : t) : t = Not x
-
-let literal (b : bool) : t = Boolean b
-
-let lt x y : t = CFun (Lt, x, y)
-
-let lteq x y : t = CFun (LtEq, x, y)
-
-let gt x y : t = CFun (Gt, x, y)
-
-let gteq x y : t = CFun (GtEq, x, y)
-
-let eq x y : t = CFun (Eq, x, y)
-
-let neq x y : t = CFun (NEq, x, y)
-
-let pow2 x : nexp = N2n x
-
-let add x y : nexp = NFun ("+", [x; y])
-
-let sub x y : nexp = NFun ("-", [x; y])
-
-let mult x y : nexp = NFun ("*", [x; y])
-
-let app f xs : nexp = NFun (f, xs)
-
-let constant (x : Big_int.num) : nexp = NConstant x
-
-let variable (v : int) : nexp = NVar v
diff --git a/src/constraint.mli b/src/constraint.mli
index df9c8b3a..51088245 100644
--- a/src/constraint.mli
+++ b/src/constraint.mli
@@ -49,40 +49,14 @@
(**************************************************************************)
module Big_int = Nat_big_num
-
-type nexp
-type t
+open Ast
+open Ast_util
type smt_result = Unknown | Sat | Unsat
val load_digests : unit -> unit
val save_digests : unit -> unit
-val call_z3 : t -> smt_result
-
-val solve_z3 : t -> int -> Big_int.num option
-
-val string_of : t -> string
-
-val implies : t -> t -> t
-val conj : t -> t -> t
-val disj : t -> t -> t
-val negate : t -> t
-val literal : bool -> t
-val forall : int list -> t -> t
-
-val lt : nexp -> nexp -> t
-val lteq : nexp -> nexp -> t
-val gt : nexp -> nexp -> t
-val gteq : nexp -> nexp -> t
-val eq : nexp -> nexp -> t
-val neq : nexp -> nexp -> t
-
-val pow2 : nexp -> nexp
-val add : nexp -> nexp -> nexp
-val sub : nexp -> nexp -> nexp
-val mult : nexp -> nexp -> nexp
-val app : string -> nexp list -> nexp
+val call_z3 : l -> kind_aux KBindings.t -> n_constraint -> smt_result
-val constant : Big_int.num -> nexp
-val variable : int -> nexp
+val solve_z3 : l -> kind_aux KBindings.t -> n_constraint -> kid -> Big_int.num option
diff --git a/src/initial_check.ml b/src/initial_check.ml
index b57e6b17..e84f655c 100644
--- a/src/initial_check.ml
+++ b/src/initial_check.ml
@@ -85,6 +85,7 @@ let to_ast_kind (P.K_aux (k, l)) =
| P.K_type -> K_aux (K_type, l)
| P.K_int -> K_aux (K_int, l)
| P.K_order -> K_aux (K_order, l)
+ | P.K_bool -> K_aux (K_bool, l)
let to_ast_id (P.Id_aux(id, l)) =
if string_contains (string_of_parse_id_aux id) '#' && not (!opt_magic_hash) then
@@ -143,6 +144,8 @@ let rec to_ast_typ ctx (P.ATyp_aux (aux, l)) =
| P.ATyp_tup typs -> Typ_tup (List.map (to_ast_typ ctx) typs)
| P.ATyp_app (P.Id_aux (P.Id "int", il), [n]) ->
Typ_app (Id_aux (Id "atom", il), [to_ast_typ_arg ctx n K_int])
+ | P.ATyp_app (P.Id_aux (P.Id "bool", il), [n]) ->
+ Typ_app (Id_aux (Id "atom_bool", il), [to_ast_typ_arg ctx n K_bool])
| P.ATyp_app (id, args) ->
let id = to_ast_id id in
begin match Bindings.find_opt id ctx.type_constructors with
@@ -166,6 +169,7 @@ and to_ast_typ_arg ctx (ATyp_aux (_, l) as atyp) = function
| K_type -> Typ_arg_aux (Typ_arg_typ (to_ast_typ ctx atyp), l)
| K_int -> Typ_arg_aux (Typ_arg_nexp (to_ast_nexp ctx atyp), l)
| K_order -> Typ_arg_aux (Typ_arg_order (to_ast_order ctx atyp), l)
+ | K_bool -> Typ_arg_aux (Typ_arg_bool (to_ast_constraint ctx atyp), l)
and to_ast_nexp ctx (P.ATyp_aux (aux, l)) =
let aux = match aux with
@@ -203,6 +207,17 @@ and to_ast_constraint ctx (P.ATyp_aux (aux, l) as atyp) =
| "|" -> NC_or (to_ast_constraint ctx t1, to_ast_constraint ctx t2)
| _ -> raise (Reporting.err_typ l ("Invalid operator in constraint"))
end
+ | P.ATyp_app (id, args) ->
+ let id = to_ast_id id in
+ begin match Bindings.find_opt id ctx.type_constructors with
+ | None -> raise (Reporting.err_typ l (sprintf "Could not find type constructor %s" (string_of_id id)))
+ | Some kinds when List.length args <> List.length kinds ->
+ raise (Reporting.err_typ l (sprintf "%s : %s -> Bool expected %d arguments, given %d"
+ (string_of_id id) (format_kind_aux_list kinds)
+ (List.length kinds) (List.length args)))
+ | Some kinds -> NC_app (id, List.map2 (to_ast_typ_arg ctx) args kinds)
+ end
+ | P.ATyp_var v -> NC_var (to_ast_var v)
| P.ATyp_lit (P.L_aux (P.L_true, _)) -> NC_true
| P.ATyp_lit (P.L_aux (P.L_false, _)) -> NC_false
| P.ATyp_nset (id, bounds) -> NC_set (to_ast_var id, bounds)
@@ -490,11 +505,12 @@ let add_constructor id typq ctx =
let to_ast_typedef ctx (P.TD_aux (aux, l) : P.type_def) : unit type_def ctx_out =
let aux, ctx = match aux with
- | P.TD_abbrev (id, namescm_opt, P.TypSchm_aux (P.TypSchm_ts (typq, typ), l)) ->
+ | P.TD_abbrev (id, typq, kind, typ_arg) ->
let id = to_ast_id id in
let typq, typq_ctx = to_ast_typquant ctx typq in
- let typ = to_ast_typ typq_ctx typ in
- TD_abbrev (id, to_ast_namescm namescm_opt, TypSchm_aux (TypSchm_ts (typq, typ), l)),
+ let kind = to_ast_kind kind in
+ let typ_arg = to_ast_typ_arg typq_ctx typ_arg (unaux_kind kind) in
+ TD_abbrev (id, typq, typ_arg),
add_constructor id typq ctx
| P.TD_record (id, namescm_opt, typq, fields, _) ->
diff --git a/src/lexer.mll b/src/lexer.mll
index f5a982eb..57580e7a 100644
--- a/src/lexer.mll
+++ b/src/lexer.mll
@@ -140,6 +140,7 @@ let kw_table =
("ref", (fun _ -> Ref));
("Int", (fun x -> Int));
("Order", (fun x -> Order));
+ ("Bool", (fun x -> Bool));
("pure", (fun x -> Pure));
("register", (fun x -> Register));
("return", (fun x -> Return));
diff --git a/src/ocaml_backend.ml b/src/ocaml_backend.ml
index f3a3fa54..a3d47814 100644
--- a/src/ocaml_backend.ml
+++ b/src/ocaml_backend.ml
@@ -602,7 +602,7 @@ let ocaml_typedef ctx (TD_aux (td_aux, _)) =
^//^ (bar ^^ space ^^ ocaml_enum ctx ids))
^^ ocaml_def_end
^^ ocaml_string_of_enum ctx id ids
- | TD_abbrev (id, _, TypSchm_aux (TypSchm_ts (typq, typ), _)) ->
+ | TD_abbrev (id, typq, Typ_arg_aux (Typ_arg_typ typ, _)) ->
separate space [string "type"; ocaml_typquant typq; zencode ctx id; equals; ocaml_typ ctx typ]
^^ ocaml_def_end
^^ ocaml_string_of_abbrev ctx id typq typ
@@ -706,7 +706,7 @@ let ocaml_pp_generators ctx defs orig_types required =
-> required
and add_req_from_td required (TD_aux (td,(l,_))) =
match td with
- | TD_abbrev (_, _, TypSchm_aux (TypSchm_ts (_,typ),_)) ->
+ | TD_abbrev (_, _, Typ_arg_aux (Typ_arg_typ typ, _)) ->
add_req_from_typ required typ
| TD_record (_, _, _, fields, _) ->
List.fold_left (fun req (typ,_) -> add_req_from_typ req typ) required fields
@@ -723,10 +723,11 @@ let ocaml_pp_generators ctx defs orig_types required =
match Bindings.find id typemap with
| TD_aux (td,_) ->
(match td with
- | TD_abbrev (_,_,TypSchm_aux (TypSchm_ts (tqs,typ),_)) -> tqs
+ | TD_abbrev (_,tqs,Typ_arg_aux (Typ_arg_typ _, _)) -> tqs
| TD_record (_,_,tqs,_,_) -> tqs
| TD_variant (_,_,tqs,_,_) -> tqs
| TD_enum _ -> TypQ_aux (TypQ_no_forall,Unknown)
+ | TD_abbrev (_, _, _) -> assert false
| TD_bitfield _ -> assert false)
| exception Not_found ->
Bindings.find id Type_check.Env.builtin_typs
@@ -844,7 +845,7 @@ let ocaml_pp_generators ctx defs orig_types required =
let tqs, body, constructors, builders =
let TD_aux (td,(l,_)) = Bindings.find id typemap in
match td with
- | TD_abbrev (_,_,TypSchm_aux (TypSchm_ts (tqs,typ),_)) ->
+ | TD_abbrev (_,tqs,Typ_arg_aux (Typ_arg_typ typ, _)) ->
tqs, gen_type typ, None, None
| TD_variant (_,_,tqs,variants,_) ->
tqs,
diff --git a/src/parse_ast.ml b/src/parse_ast.ml
index 204389f9..c57daa26 100644
--- a/src/parse_ast.ml
+++ b/src/parse_ast.ml
@@ -74,6 +74,7 @@ kind_aux = (* base kind *)
K_type (* kind of types *)
| K_int (* kind of natural number size expressions *)
| K_order (* kind of vector order specifications *)
+ | K_bool (* kind of constraints *)
type
@@ -443,7 +444,7 @@ fundef_aux = (* Function definition *)
type
type_def_aux = (* Type definition body *)
- TD_abbrev of id * name_scm_opt * typschm (* type abbreviation *)
+ TD_abbrev of id * typquant * kind * atyp (* type abbreviation *)
| TD_record of id * name_scm_opt * typquant * ((atyp * id)) list * bool (* struct type definition *)
| TD_variant of id * name_scm_opt * typquant * (type_union) list * bool (* union type definition *)
| TD_enum of id * name_scm_opt * (id) list * bool (* enumeration type definition *)
diff --git a/src/parser.mly b/src/parser.mly
index bb5aa5f1..fa36591c 100644
--- a/src/parser.mly
+++ b/src/parser.mly
@@ -175,7 +175,7 @@ let rec desugar_rchain chain s e =
/*Terminals with no content*/
%token And As Assert Bitzero Bitone By Match Clause Dec Default Effect End Op Where
-%token Enum Else False Forall Foreach Overload Function_ Mapping If_ In Inc Let_ Int Order Cast
+%token Enum Else False Forall Foreach Overload Function_ Mapping If_ In Inc Let_ Int Order Bool Cast
%token Pure Register Return Scattered Sizeof Struct Then True TwoCaret TYPE Typedef
%token Undefined Union Newtype With Val Constant Constraint Throw Try Catch Exit Bitfield
%token Barr Depend Rreg Wreg Rmem Rmemt Wmem Wmv Wmvt Eamem Exmem Undef Unspec Nondet Escape
@@ -563,6 +563,8 @@ kind:
{ K_aux (K_type, loc $startpos $endpos) }
| Order
{ K_aux (K_order, loc $startpos $endpos) }
+ | Bool
+ { K_aux (K_bool, loc $startpos $endpos) }
kopt:
| Lparen kid Colon kind Rparen
@@ -1154,9 +1156,13 @@ typaram:
type_def:
| Typedef id typaram Eq typ
- { mk_td (TD_abbrev ($2, mk_namesectn, mk_typschm $3 $5 $startpos($3) $endpos)) $startpos $endpos }
+ { mk_td (TD_abbrev ($2, $3, K_aux (K_type, Parse_ast.Unknown), $5)) $startpos $endpos }
| Typedef id Eq typ
- { mk_td (TD_abbrev ($2, mk_namesectn, mk_typschm mk_typqn $4 $startpos($4) $endpos)) $startpos $endpos }
+ { mk_td (TD_abbrev ($2, mk_typqn, K_aux (K_type, Parse_ast.Unknown), $4)) $startpos $endpos }
+ | Typedef id typaram MinusGt kind Eq typ
+ { mk_td (TD_abbrev ($2, $3, $5, $7)) $startpos $endpos }
+ | Typedef id Colon kind Eq typ
+ { mk_td (TD_abbrev ($2, mk_typqn, $4, $6)) $startpos $endpos }
| Struct id Eq Lcurly struct_fields Rcurly
{ mk_td (TD_record ($2, mk_namesectn, TypQ_aux (TypQ_tq [], loc $endpos($2) $startpos($3)), $5, false)) $startpos $endpos }
| Struct id typaram Eq Lcurly struct_fields Rcurly
diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml
index 7cc61507..50a97fa8 100644
--- a/src/pretty_print_coq.ml
+++ b/src/pretty_print_coq.ml
@@ -1711,7 +1711,8 @@ let rec doc_range (BF_aux(r,_)) = match r with
| BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2)
let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with
- | TD_abbrev(id,nm,(TypSchm_aux (TypSchm_ts (typq, _), _) as typschm)) ->
+ | TD_abbrev(id,typq,Typ_arg_aux (Typ_arg_typ typ, _)) ->
+ let typschm = TypSchm_aux (TypSchm_ts (typq, typ), l) in
doc_op coloneq
(separate space [string "Definition"; doc_id_type id;
doc_typquant_items empty_ctxt parens typq;
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index be790a1c..e5613961 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -1006,7 +1006,8 @@ let rec doc_range_lem (BF_aux(r,_)) = match r with
| BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2)
let doc_typdef_lem (TD_aux(td, (l, annot))) = match td with
- | TD_abbrev(id,nm,(TypSchm_aux (TypSchm_ts (typq, _), _) as typschm)) ->
+ | TD_abbrev(id,typq,Typ_arg_aux (Typ_arg_typ typ, _)) ->
+ let typschm = TypSchm_aux (TypSchm_ts (typq, typ), l) in
doc_op equals
(separate space [string "type"; doc_id_lem_type id; doc_typquant_items_lem None typq])
(doc_typschm_lem false typschm)
diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml
index 5201744b..7fb67a06 100644
--- a/src/pretty_print_sail.ml
+++ b/src/pretty_print_sail.ml
@@ -259,7 +259,7 @@ let doc_typschm ?(simple=false) (TypSchm_aux (TypSchm_ts (typq, typ), _)) = doc_
let doc_typschm_typ (TypSchm_aux (TypSchm_ts (TypQ_aux (tq_aux, _), typ), _)) = doc_typ typ
-let doc_typschm_quants (TypSchm_aux (TypSchm_ts (TypQ_aux (tq_aux, _), typ), _)) =
+let doc_typquant (TypQ_aux (tq_aux, _)) =
match tq_aux with
| TypQ_no_forall -> None
| TypQ_tq [] -> None
@@ -562,13 +562,13 @@ let doc_field (typ, id) =
let doc_union (Tu_aux (Tu_ty_id (typ, id), l)) = separate space [doc_id id; colon; doc_typ typ]
let doc_typdef (TD_aux(td,_)) = match td with
- | TD_abbrev (id, _, typschm) ->
+ | TD_abbrev (id, typq, typ_arg) ->
begin
- match doc_typschm_quants typschm with
+ match doc_typquant typq with
| Some qdoc ->
- doc_op equals (concat [string "type"; space; doc_id id; qdoc]) (doc_typschm_typ typschm)
+ doc_op equals (concat [string "type"; space; doc_id id; qdoc]) (doc_typ_arg typ_arg)
| None ->
- doc_op equals (concat [string "type"; space; doc_id id]) (doc_typschm_typ typschm)
+ doc_op equals (concat [string "type"; space; doc_id id]) (doc_typ_arg typ_arg)
end
| TD_enum (id, _, ids, _) ->
separate space [string "enum"; doc_id id; equals; surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_id ids) rbrace]
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 77070025..200121c0 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -94,7 +94,7 @@ let lookup_generated_kid env kid =
let generated_kids typ = KidSet.filter is_kid_generated (tyvars_of_typ typ)
let resolve_generated_kids env typ =
- let subst_kid kid typ = typ_subst_kid kid (lookup_generated_kid env kid) typ in
+ let subst_kid kid typ = subst_kid typ_subst kid (lookup_generated_kid env kid) typ in
KidSet.fold subst_kid (generated_kids typ) typ
let rec remove_p_typ = function
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 0ead9670..82228206 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -2300,9 +2300,10 @@ let rewrite_constraint =
let rewrite_type_union_typs rw_typ (Tu_aux (Tu_ty_id (typ, id), annot)) =
Tu_aux (Tu_ty_id (rw_typ typ, id), annot)
-let rewrite_type_def_typs rw_typ rw_typquant rw_typschm (TD_aux (td, annot)) =
+let rewrite_type_def_typs rw_typ rw_typquant (TD_aux (td, annot)) =
match td with
- | TD_abbrev (id, nso, typschm) -> TD_aux (TD_abbrev (id, nso, rw_typschm typschm), annot)
+ | TD_abbrev (id, typq, Typ_arg_aux (Typ_arg_typ typ, l)) ->
+ TD_aux (TD_abbrev (id, rw_typquant typq, Typ_arg_aux (Typ_arg_typ (rw_typ typ), l)), annot)
| TD_record (id, nso, typq, typ_ids, flag) ->
TD_aux (TD_record (id, nso, rw_typquant typq, List.map (fun (typ, id) -> (rw_typ typ, id)) typ_ids, flag), annot)
| TD_variant (id, nso, typq, tus, flag) ->
@@ -2396,7 +2397,7 @@ let rewrite_simple_types (Defs defs) =
in
let simple_def = function
| DEF_spec vs -> DEF_spec (simple_vs vs)
- | DEF_type td -> DEF_type (rewrite_type_def_typs simple_typ simple_typquant simple_typschm td)
+ | DEF_type td -> DEF_type (rewrite_type_def_typs simple_typ simple_typquant td)
| DEF_reg_dec ds -> DEF_reg_dec (rewrite_dec_spec_typs simple_typ ds)
| def -> def
in
diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml
index 9453e999..84fa8235 100644
--- a/src/spec_analysis.ml
+++ b/src/spec_analysis.ml
@@ -309,7 +309,9 @@ let fv_of_kind_def consider_var (KD_aux(k,_)) = match k with
| KD_nabbrev(_,id,_,nexp) -> init_env (string_of_id id), fv_of_nexp consider_var mt mt nexp
let fv_of_type_def consider_var (TD_aux(t,_)) = match t with
- | TD_abbrev(id,_,typschm) -> init_env (string_of_id id), snd (fv_of_typschm consider_var mt mt typschm)
+ | TD_abbrev(id,typq,Typ_arg_aux(Typ_arg_typ typ, l)) ->
+ let typschm = TypSchm_aux (TypSchm_ts (typq,typ), l) in
+ init_env (string_of_id id), snd (fv_of_typschm consider_var mt mt typschm)
| TD_record(id,_,typq,tids,_) ->
let binds = init_env (string_of_id id) in
let bounds = if consider_var then typq_bindings typq else mt in
diff --git a/src/state.ml b/src/state.ml
index 31f5c7eb..00f81bf4 100644
--- a/src/state.ml
+++ b/src/state.ml
@@ -127,15 +127,9 @@ let generate_initial_regstate defs =
| Typ_exist (_, _, typ) -> lookup_init_val vals typ
| _ -> raise Not_found
in
- (* Helper functions to instantiate type arguments *)
- let typ_subst_targ kid (Typ_arg_aux (arg, _)) typ = match arg with
- | Typ_arg_nexp (Nexp_aux (nexp, _)) -> typ_subst_nexp kid nexp typ
- | Typ_arg_typ (Typ_aux (typ', _)) -> typ_subst_typ kid typ' typ
- | Typ_arg_order (Ord_aux (ord, _)) -> typ_subst_order kid ord typ
- in
let typ_subst_quant_item typ (QI_aux (qi, _)) arg = match qi with
| QI_id (KOpt_aux ((KOpt_none kid | KOpt_kind (_, kid)), _)) ->
- typ_subst_targ kid arg typ
+ typ_subst kid arg typ
| _ -> typ
in
let typ_subst_typquant tq args typ =
@@ -152,7 +146,7 @@ let generate_initial_regstate defs =
string_of_id id1 ^ " (" ^ lookup_init_val vals typ1 ^ ")"
in
Bindings.add id init_val vals
- | TD_abbrev (id, _, TypSchm_aux (TypSchm_ts (tq, typ), _)) ->
+ | TD_abbrev (id, tq, Typ_arg_aux (Typ_arg_typ typ, _)) ->
let init_val args = lookup_init_val vals (typ_subst_typquant tq args typ) in
Bindings.add id init_val vals
| TD_record (id, _, tq, fields, _) ->
diff --git a/src/type_check.ml b/src/type_check.ml
index f204a558..2c10b8ae 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -167,6 +167,8 @@ and strip_n_constraint_aux = function
| NC_set (kid, nums) -> NC_set (strip_kid kid, nums)
| NC_or (nc1, nc2) -> NC_or (strip_n_constraint nc1, strip_n_constraint nc2)
| NC_and (nc1, nc2) -> NC_and (strip_n_constraint nc1, strip_n_constraint nc2)
+ | NC_var kid -> NC_var (strip_kid kid)
+ | NC_app (id, args) -> NC_app (strip_id id, List.map strip_typ_arg args)
| NC_true -> NC_true
| NC_false -> NC_false
and strip_n_constraint = function
@@ -177,6 +179,7 @@ and strip_typ_arg_aux = function
| Typ_arg_nexp nexp -> Typ_arg_nexp (strip_nexp nexp)
| Typ_arg_typ typ -> Typ_arg_typ (strip_typ typ)
| Typ_arg_order ord -> Typ_arg_order (strip_order ord)
+ | Typ_arg_bool nc -> Typ_arg_bool (strip_n_constraint nc)
and strip_order = function
| Ord_aux (ord_aux, _) -> Ord_aux (strip_order_aux ord_aux, Parse_ast.Unknown)
and strip_order_aux = function
@@ -256,9 +259,8 @@ module Env : sig
val add_typ_var : l -> kid -> kind_aux -> t -> t
val get_ret_typ : t -> typ option
val add_ret_typ : typ -> t -> t
- val add_typ_synonym : id -> (t -> typ_arg list -> typ) -> t -> t
- val get_typ_synonym : id -> t -> t -> typ_arg list -> typ
- val add_constraint_synonym : id -> kid list -> n_constraint -> t -> t
+ val add_typ_synonym : id -> (t -> typ_arg list -> typ_arg) -> t -> t
+ val get_typ_synonym : id -> t -> t -> typ_arg list -> typ_arg
val add_num_def : id -> nexp -> t -> t
val get_num_def : id -> t -> nexp
val add_overloads : id -> id list -> t -> t
@@ -282,6 +284,7 @@ module Env : sig
val lookup_id : ?raw:bool -> id -> t -> typ lvar
val fresh_kid : ?kid:kid -> t -> kid
val expand_synonyms : t -> typ -> typ
+ val expand_constraint_synonyms : t -> n_constraint -> n_constraint
val canonicalize : t -> typ -> typ
val base_typ_of : t -> typ -> typ
val add_smt_op : id -> string -> t -> t
@@ -320,7 +323,7 @@ end = struct
variants : (typquant * type_union list) Bindings.t;
mappings : (typquant * typ * typ) Bindings.t;
typ_vars : (Ast.l * kind_aux) KBindings.t;
- typ_synonyms : (t -> typ_arg list -> typ) Bindings.t;
+ typ_synonyms : (t -> typ_arg list -> typ_arg) Bindings.t;
num_defs : nexp Bindings.t;
overloads : (id list) Bindings.t;
flow : (typ -> typ) Bindings.t;
@@ -329,7 +332,6 @@ end = struct
accessors : (typquant * typ) Bindings.t;
externs : (string -> string option) Bindings.t;
smt_ops : string Bindings.t;
- constraint_synonyms : (kid list * n_constraint) Bindings.t;
casts : id list;
allow_casts : bool;
allow_bindings : bool;
@@ -359,7 +361,6 @@ end = struct
accessors = Bindings.empty;
externs = Bindings.empty;
smt_ops = Bindings.empty;
- constraint_synonyms = Bindings.empty;
casts = [];
allow_bindings = true;
allow_casts = true;
@@ -465,8 +466,8 @@ end = struct
let kopts, ncs = quant_split typq in
let rec subst_args kopts args =
match kopts, args with
- | kopt :: kopts, Typ_arg_aux (Typ_arg_nexp arg, _) :: args when is_nat_kopt kopt ->
- List.map (nc_subst_nexp (kopt_kid kopt) (unaux_nexp arg)) (subst_args kopts args)
+ | kopt :: kopts, (Typ_arg_aux (Typ_arg_nexp _, _) as arg) :: args when is_nat_kopt kopt ->
+ List.map (constraint_subst (kopt_kid kopt) arg) (subst_args kopts args)
| kopt :: kopts, Typ_arg_aux (Typ_arg_typ arg, _) :: args when is_typ_kopt kopt ->
subst_args kopts args
| kopt :: kopts, Typ_arg_aux (Typ_arg_order arg, _) :: args when is_order_kopt kopt ->
@@ -480,28 +481,41 @@ end = struct
then ()
else typ_error (id_loc id) ("Could not prove " ^ string_of_list ", " string_of_n_constraint ncs ^ " for type constructor " ^ string_of_id id)
- let rec expand_synonyms env (Typ_aux (typ, l) as t) =
+ let rec expand_constraint_synonyms env (NC_aux (aux, l) as nc) =
+ match aux with
+ | NC_or (nc1, nc2) -> NC_aux (NC_or (expand_constraint_synonyms env nc1, expand_constraint_synonyms env nc2), l)
+ | NC_and (nc1, nc2) -> NC_aux (NC_and (expand_constraint_synonyms env nc1, expand_constraint_synonyms env nc2), l)
+ | NC_app (id, args) ->
+ (try
+ begin match Bindings.find id env.typ_synonyms env args with
+ | Typ_arg_aux (Typ_arg_bool nc, _) -> expand_constraint_synonyms env nc
+ | _ -> typ_error l ("Expected Type when expanding synonym " ^ string_of_id id)
+ end
+ with Not_found -> NC_aux (NC_app (id, List.map (expand_synonyms_arg env) args), l))
+ | NC_true | NC_false | NC_equal _ | NC_not_equal _ | NC_bounded_le _ | NC_bounded_ge _ | NC_var _ | NC_set _ -> nc
+
+ and expand_synonyms env (Typ_aux (typ, l) as t) =
match typ with
| Typ_internal_unknown -> Typ_aux (Typ_internal_unknown, l)
| Typ_tup typs -> Typ_aux (Typ_tup (List.map (expand_synonyms env) typs), l)
| Typ_fn (arg_typs, ret_typ, effs) -> Typ_aux (Typ_fn (List.map (expand_synonyms env) arg_typs, expand_synonyms env ret_typ, effs), l)
| Typ_bidir (typ1, typ2) -> Typ_aux (Typ_bidir (expand_synonyms env typ1, expand_synonyms env typ2), l)
| Typ_app (id, args) ->
- begin
- try
- let synonym = Bindings.find id env.typ_synonyms in
- expand_synonyms env (synonym env args)
- with
- | Not_found -> Typ_aux (Typ_app (id, List.map (expand_synonyms_arg env) args), l)
- end
+ (try
+ begin match Bindings.find id env.typ_synonyms env args with
+ | Typ_arg_aux (Typ_arg_typ typ, _) -> expand_synonyms env typ
+ | _ -> typ_error l ("Expected Type when expanding synonym " ^ string_of_id id)
+ end
+ with
+ | Not_found -> Typ_aux (Typ_app (id, List.map (expand_synonyms_arg env) args), l))
| Typ_id id ->
- begin
- try
- let synonym = Bindings.find id env.typ_synonyms in
- expand_synonyms env (synonym env [])
- with
- | Not_found -> Typ_aux (Typ_id id, l)
- end
+ (try
+ begin match Bindings.find id env.typ_synonyms env [] with
+ | Typ_arg_aux (Typ_arg_typ typ, _) -> expand_synonyms env typ
+ | _ -> typ_error l ("Expected Type when expanding synonym " ^ string_of_id id)
+ end
+ with
+ | Not_found -> Typ_aux (Typ_id id, l))
| Typ_exist (kids, nc, typ) ->
(* When expanding an existential synonym we need to take care
to add the type variables and constraints to the
@@ -524,8 +538,8 @@ end = struct
let env = List.fold_left add_typ_var env kids in
let kids = List.map rename_kid kids in
- let nc = List.fold_left (fun nc kid -> nc_subst_nexp kid (Nexp_var (prepend_kid "syn#" kid)) nc) nc !rebindings in
- let typ = List.fold_left (fun typ kid -> typ_subst_nexp kid (Nexp_var (prepend_kid "syn#" kid)) typ) typ !rebindings in
+ let nc = List.fold_left (fun nc kid -> constraint_subst kid (arg_nexp (nvar (prepend_kid "syn#" kid))) nc) nc !rebindings in
+ let typ = List.fold_left (fun typ kid -> typ_subst kid (arg_nexp (nvar (prepend_kid "syn#" kid))) typ) typ !rebindings in
typ_debug (lazy ("Synonym existential: {" ^ string_of_list " " string_of_kid kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}"));
let env = { env with constraints = nc :: env.constraints } in
Typ_aux (Typ_exist (kids, nc, expand_synonyms env typ), l)
@@ -533,6 +547,7 @@ end = struct
and expand_synonyms_arg env (Typ_arg_aux (typ_arg, l)) =
match typ_arg with
| Typ_arg_typ typ -> Typ_arg_aux (Typ_arg_typ (expand_synonyms env typ), l)
+ | Typ_arg_bool nc -> Typ_arg_aux (Typ_arg_bool (expand_constraint_synonyms env nc), l)
| arg -> Typ_arg_aux (arg, l)
(** Map over all nexps in a type - excluding those in existential constraints **)
@@ -547,7 +562,7 @@ end = struct
| Typ_app (id, args) -> Typ_aux (Typ_app (id, List.map (map_nexps_arg f) args), l)
and map_nexps_arg f (Typ_arg_aux (arg_aux, l) as arg) =
match arg_aux with
- | Typ_arg_order _ | Typ_arg_typ _ -> arg
+ | Typ_arg_order _ | Typ_arg_typ _ | Typ_arg_bool _ -> arg
| Typ_arg_nexp n -> Typ_arg_aux (Typ_arg_nexp (f n), l)
let canonical env typ =
@@ -600,7 +615,6 @@ end = struct
(* Check if a type, order, n-expression or constraint is
well-formed. Throws a type error if the type is badly formed. *)
let rec wf_typ ?exs:(exs=KidSet.empty) env typ =
- typ_debug (lazy ("well-formed " ^ string_of_typ typ));
let (Typ_aux (typ_aux, l)) = expand_synonyms env typ in
match typ_aux with
| Typ_id id when bound_typ_id env id ->
@@ -637,8 +651,8 @@ end = struct
| Typ_arg_nexp nexp -> wf_nexp ~exs:exs env nexp
| Typ_arg_typ typ -> wf_typ ~exs:exs env typ
| Typ_arg_order ord -> wf_order env ord
+ | Typ_arg_bool nc -> wf_constraint ~exs:exs env nc
and wf_nexp ?exs:(exs=KidSet.empty) env (Nexp_aux (nexp_aux, l) as nexp) =
- typ_debug (lazy ("well-formed nexp " ^ string_of_nexp nexp));
match nexp_aux with
| Nexp_id _ -> ()
| Nexp_var kid when KidSet.mem kid exs -> ()
@@ -671,22 +685,29 @@ end = struct
end
| Ord_inc | Ord_dec -> ()
and wf_constraint ?exs:(exs=KidSet.empty) env (NC_aux (nc_aux, l) as nc) =
- typ_debug (lazy ("well-formed constraint " ^ string_of_n_constraint nc));
match nc_aux with
| NC_equal (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2
| NC_not_equal (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2
| NC_bounded_ge (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2
| NC_bounded_le (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2
| NC_set (kid, _) when KidSet.mem kid exs -> ()
- | NC_set (kid, _) -> begin
- match get_typ_var kid env with
- | K_int -> ()
- | kind -> typ_error l ("Set constraint is badly formed, "
- ^ string_of_kid kid ^ " has kind "
- ^ string_of_kind_aux kind ^ " but should have kind Int")
- end
+ | NC_set (kid, _) ->
+ begin match get_typ_var kid env with
+ | K_int -> ()
+ | kind -> typ_error l ("Set constraint is badly formed, "
+ ^ string_of_kid kid ^ " has kind "
+ ^ string_of_kind_aux kind ^ " but should have kind Int")
+ end
| NC_or (nc1, nc2) -> wf_constraint ~exs:exs env nc1; wf_constraint ~exs:exs env nc2
| NC_and (nc1, nc2) -> wf_constraint ~exs:exs env nc1; wf_constraint ~exs:exs env nc2
+ | NC_app (id, args) -> List.iter (wf_typ_arg ~exs:exs env) args
+ | NC_var kid ->
+ begin match get_typ_var kid env with
+ | K_bool -> ()
+ | kind -> typ_error l ("Set constraint is badly formed, "
+ ^ string_of_kid kid ^ " has kind "
+ ^ string_of_kind_aux kind ^ " but should have kind Bool")
+ end
| NC_true | NC_false -> ()
let counter = ref 0
@@ -699,7 +720,7 @@ end = struct
let freshen_kid env kid (typq, typ) =
let fresh = fresh_kid ~kid:kid env in
if KidSet.mem kid (KidSet.of_list (List.map kopt_kid (quant_kopts typq))) then
- (typquant_subst_kid kid fresh typq, typ_subst_kid kid fresh typ)
+ (typquant_subst_kid kid fresh typq, subst_kid typ_subst kid fresh typ)
else
(typq, typ)
@@ -733,8 +754,8 @@ end = struct
match expand_synonyms env typ with
| Typ_aux (Typ_exist (kids, nc, typ), _) ->
let fresh_kids = List.map (fun kid -> (kid, fresh_existential ~name:(string_of_id (id_of_kid kid)) ())) kids in
- let nc = List.fold_left (fun nc (kid, fresh) -> nc_subst_nexp kid (Nexp_var fresh) nc) nc fresh_kids in
- let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst_nexp kid (Nexp_var fresh) typ) typ fresh_kids in
+ let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_nexp (nvar fresh)) nc) nc fresh_kids in
+ let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_nexp (nvar fresh)) typ) typ fresh_kids in
Some (List.map snd fresh_kids, nc, typ)
| _ -> None
@@ -1067,16 +1088,6 @@ end = struct
let get_typ_synonym id env = Bindings.find id env.typ_synonyms
- let add_constraint_synonym id kids nc env =
- if Bindings.mem id env.constraint_synonyms
- then typ_error (id_loc id) ("Constraint synonym " ^ string_of_id id ^ " already exists")
- else
- begin
- typ_print (lazy (adding ^ "constraint synonym " ^ string_of_id id));
- wf_constraint ~exs:(KidSet.of_list kids) env nc;
- { env with constraint_synonyms = Bindings.add id (kids, nc) env.constraint_synonyms }
- end
-
let get_default_order env =
match env.default_order with
| None -> typ_error Parse_ast.Unknown ("No default order has been set")
@@ -1154,8 +1165,8 @@ let destruct_exist env typ =
match Env.expand_synonyms env typ with
| Typ_aux (Typ_exist (kids, nc, typ), _) ->
let fresh_kids = List.map (fun kid -> (kid, fresh_existential ~name:(string_of_id (id_of_kid kid)) ())) kids in
- let nc = List.fold_left (fun nc (kid, fresh) -> nc_subst_nexp kid (Nexp_var fresh) nc) nc fresh_kids in
- let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst_nexp kid (Nexp_var fresh) typ) typ fresh_kids in
+ let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_nexp (nvar fresh)) nc) nc fresh_kids in
+ let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_nexp (nvar fresh)) typ) typ fresh_kids in
Some (List.map snd fresh_kids, nc, typ)
| _ -> None
@@ -1282,65 +1293,18 @@ this is equivalent to
which is then a problem we can feed to the constraint solver expecting unsat.
*)
-let rec nexp_constraint env var_of (Nexp_aux (nexp, l)) =
- match nexp with
- | Nexp_id v -> nexp_constraint env var_of (Env.get_num_def v env)
- | Nexp_var kid -> Constraint.variable (var_of kid)
- | Nexp_constant c -> Constraint.constant c
- | Nexp_times (nexp1, nexp2) -> Constraint.mult (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2)
- | Nexp_sum (nexp1, nexp2) -> Constraint.add (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2)
- | Nexp_minus (nexp1, nexp2) -> Constraint.sub (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2)
- | Nexp_exp nexp -> Constraint.pow2 (nexp_constraint env var_of nexp)
- | Nexp_neg nexp -> Constraint.sub (Constraint.constant (Big_int.of_int 0)) (nexp_constraint env var_of nexp)
- | Nexp_app (id, nexps) -> Constraint.app (Env.get_smt_op id env) (List.map (nexp_constraint env var_of) nexps)
-
-let rec nc_constraint env var_of (NC_aux (nc, l)) =
- match nc with
- | NC_equal (nexp1, nexp2) -> Constraint.eq (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2)
- | NC_not_equal (nexp1, nexp2) -> Constraint.neq (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2)
- | NC_bounded_ge (nexp1, nexp2) -> Constraint.gteq (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2)
- | NC_bounded_le (nexp1, nexp2) -> Constraint.lteq (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2)
- | NC_set (_, []) -> Constraint.literal false
- | NC_set (kid, (int :: ints)) ->
- List.fold_left Constraint.disj
- (Constraint.eq (nexp_constraint env var_of (nvar kid)) (Constraint.constant int))
- (List.map (fun i -> Constraint.eq (nexp_constraint env var_of (nvar kid)) (Constraint.constant i)) ints)
- | NC_or (nc1, nc2) -> Constraint.disj (nc_constraint env var_of nc1) (nc_constraint env var_of nc2)
- | NC_and (nc1, nc2) -> Constraint.conj (nc_constraint env var_of nc1) (nc_constraint env var_of nc2)
- | NC_false -> Constraint.literal false
- | NC_true -> Constraint.literal true
-
-let rec nc_constraints env var_of ncs =
- match ncs with
- | [] -> Constraint.literal true
- | [nc] -> nc_constraint env var_of nc
- | (nc :: ncs) ->
- Constraint.conj (nc_constraint env var_of nc) (nc_constraints env var_of ncs)
-
-let prove_z3' env constr =
- let module Bindings = Map.Make(Kid) in
- let bindings = ref Bindings.empty in
- let fresh_var kid =
- let n = Bindings.cardinal !bindings in
- bindings := Bindings.add kid n !bindings;
- n
- in
- let var_of kid =
- try Bindings.find kid !bindings with
- | Not_found -> fresh_var kid
- in
- let constr = Constraint.conj (nc_constraints env var_of (Env.get_constraints env)) (constr var_of) in
- match Constraint.call_z3 constr with
+let prove_z3 env (NC_aux (_, l) as nc) =
+ let vars = Env.get_typ_vars env in
+ let vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) vars in
+ let ncs = Env.get_constraints env in
+ match Constraint.call_z3 l vars (List.fold_left nc_and (nc_not nc) ncs) with
| Constraint.Unsat -> typ_debug (lazy "unsat"); true
| Constraint.Sat -> typ_debug (lazy "sat"); false
| Constraint.Unknown -> typ_debug (lazy "unknown"); false
-let prove_z3 env nc =
- typ_print (lazy (Util.("Prove " |> red |> clear) ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_n_constraint nc));
- prove_z3' env (fun var_of -> Constraint.negate (nc_constraint env var_of nc))
+let solve env nexp = failwith "WIP"
-let solve env nexp =
- typ_print (lazy ("Solve " ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_nexp nexp ^ " = ?"));
+ (* typ_print (lazy ("Solve " ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_nexp nexp ^ " = ?"));
match nexp with
| Nexp_aux (Nexp_constant n,_) -> Some n
| _ ->
@@ -1359,8 +1323,11 @@ let solve env nexp =
(nc_constraint env var_of (nc_eq (nvar (mk_kid "solve#")) nexp))
in
Constraint.solve_z3 constr (var_of (mk_kid "solve#"))
+ *)
-let prove env (NC_aux (nc_aux, _) as nc) =
+let prove env nc =
+ typ_print (lazy (Util.("Prove " |> red |> clear) ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_n_constraint nc));
+ let (NC_aux (nc_aux, _) as nc) = Env.expand_constraint_synonyms env nc in
let compare_const f (Nexp_aux (n1, _)) (Nexp_aux (n2, _)) =
match n1, n2 with
| Nexp_constant c1, Nexp_constant c2 when f c1 c2 -> true
@@ -1499,61 +1466,91 @@ let typ_identical env typ1 typ2 =
in
typ_identical' (Env.expand_synonyms env typ1) (Env.expand_synonyms env typ2)
-type uvar =
- | U_nexp of nexp
- | U_order of order
- | U_typ of typ
+exception Unification_error of l * string;;
+
+let unify_error l str = raise (Unification_error (l, str))
+
+let merge_unifiers l kid uvar1 uvar2 =
+ match uvar1, uvar2 with
+ | Some (Typ_arg_aux (Typ_arg_nexp n1, _)), Some (Typ_arg_aux (Typ_arg_nexp n2, _)) ->
+ if nexp_identical n1 n2 then
+ Some (arg_nexp n1)
+ else
+ unify_error l ("Multiple non-identical unifiers for " ^ string_of_kid kid
+ ^ ": " ^ string_of_nexp n1 ^ " and " ^ string_of_nexp n2)
+ | Some _, Some _ -> unify_error l "Multiple non-identical non-nexp unifiers"
+ | None, Some u2 -> Some u2
+ | Some u1, None -> Some u1
+ | None, None -> None
+
+let merge_uvars l unifiers1 unifiers2 =
+ KBindings.merge (merge_unifiers l) unifiers1 unifiers2
+
+let rec unify_typ l env goals (Typ_aux (aux1, _) as typ1) (Typ_aux (aux2, _) as typ2) =
+ match aux1, aux2 with
+ | Typ_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_typ typ2)
-let uvar_subst_nexp sv subst = function
- | U_nexp nexp -> U_nexp (nexp_subst sv subst nexp)
- | U_typ typ -> U_typ (typ_subst_nexp sv subst typ)
- | U_order ord -> U_order ord
+ | Typ_app (range, [Typ_arg_aux (Typ_arg_nexp n1, _); Typ_arg_aux (Typ_arg_nexp n2, _)]),
+ Typ_app (atom, [Typ_arg_aux (Typ_arg_nexp m, _)])
+ when string_of_id range = "range" && string_of_id atom = "atom" ->
+ merge_uvars l (unify_nexp l env goals n1 m) (unify_nexp l env goals n2 m)
-let uvar_subst_typ sv subst = function
- | U_nexp nexp -> U_nexp nexp
- | U_typ typ -> U_typ (typ_subst_typ sv subst typ)
- | U_order ord -> U_order ord
+ | Typ_app (id1, args1), Typ_app (id2, args2) when List.length args1 = List.length args2 && Id.compare id1 id2 = 0 ->
+ List.fold_left (merge_uvars l) KBindings.empty (List.map2 (unify_typ_arg l env goals) args1 args2)
-let uvar_subst_order sv subst = function
- | U_nexp nexp -> U_nexp nexp
- | U_typ typ -> U_typ (typ_subst_order sv subst typ)
- | U_order ord -> U_order (order_subst sv subst ord)
+ | Typ_id id1, Typ_id id2 when Id.compare id1 id2 = 0 -> KBindings.empty
-exception Unification_error of l * string;;
+ | Typ_tup typs1, Typ_tup typs2 when List.length typs1 = List.length typs2 ->
+ List.fold_left (merge_uvars l) KBindings.empty (List.map2 (unify_typ l env goals) typs1 typs2)
-let unify_error l str = raise (Unification_error (l, str))
+ | _, _ -> unify_error l ("Cound not unify " ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2)
+
+and unify_typ_arg l env goals (Typ_arg_aux (aux1, _) as typ_arg1) (Typ_arg_aux (aux2, _) as typ_arg2) =
+ match aux1, aux2 with
+ | Typ_arg_typ typ1, Typ_arg_typ typ2 -> unify_typ l env goals typ1 typ2
+ | Typ_arg_nexp nexp1, Typ_arg_nexp nexp2 -> unify_nexp l env goals nexp1 nexp2
+ | Typ_arg_order ord1, Typ_arg_order ord2 -> unify_order l goals ord1 ord2
+ | _, _ -> unify_error l ("Could not unify type arguments " ^ string_of_typ_arg typ_arg1 ^ " and " ^ string_of_typ_arg typ_arg2)
+
+and unify_order l goals (Ord_aux (aux1, _) as ord1) (Ord_aux (aux2, _) as ord2) =
+ typ_print (lazy (Util.("Unify order " |> magenta |> clear) ^ string_of_order ord1 ^ " and " ^ string_of_order ord2));
+ match aux1, aux2 with
+ | Ord_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_order ord2)
+ | Ord_inc, Ord_inc -> KBindings.empty
+ | Ord_dec, Ord_dec -> KBindings.empty
+ | _, _ -> unify_error l ("Cound not unify " ^ string_of_order ord1 ^ " and " ^ string_of_order ord2)
-let rec unify_nexps l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) =
+and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) =
typ_debug (lazy ("UNIFYING NEXPS " ^ string_of_nexp nexp1 ^ " AND " ^ string_of_nexp nexp2 ^ " FOR GOALS " ^ string_of_list ", " string_of_kid (KidSet.elements goals)));
if KidSet.is_empty (KidSet.inter (nexp_frees nexp1) goals)
then
begin
if prove env (NC_aux (NC_equal (nexp1, nexp2), Parse_ast.Unknown))
- then None
+ then KBindings.empty
else unify_error l ("Nexp " ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2 ^ " are not equal")
end
else
match nexp_aux1 with
| Nexp_id v -> unify_error l "Unimplemented Nexp_id in unify nexp"
- | Nexp_var kid when KidSet.mem kid goals -> Some (kid, nexp2)
+ | Nexp_var kid when KidSet.mem kid goals -> KBindings.singleton kid (arg_nexp nexp2)
| Nexp_constant c1 ->
begin
match nexp_aux2 with
- | Nexp_constant c2 -> if c1 = c2 then None else unify_error l "Constants are not the same"
+ | Nexp_constant c2 -> if c1 = c2 then KBindings.empty else unify_error l "Constants are not the same"
| _ -> unify_error l "Unification error"
end
| Nexp_sum (n1a, n1b) ->
if KidSet.is_empty (nexp_frees n1b)
- then unify_nexps l env goals n1a (nminus nexp2 n1b)
+ then unify_nexp l env goals n1a (nminus nexp2 n1b)
else
if KidSet.is_empty (nexp_frees n1a)
- then unify_nexps l env goals n1b (nminus nexp2 n1a)
+ then unify_nexp l env goals n1b (nminus nexp2 n1a)
else unify_error l ("Both sides of Int expression " ^ string_of_nexp nexp1
^ " contain free type variables so it cannot be unified with " ^ string_of_nexp nexp2)
| Nexp_minus (n1a, n1b) ->
if KidSet.is_empty (nexp_frees n1b)
- then unify_nexps l env goals n1a (nsum nexp2 n1b)
- else unify_error l ("Cannot unify minus Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
+ then unify_nexp l env goals n1a (nsum nexp2 n1b)
+ else unify_error l ("Cannot unify minus Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
| Nexp_times (n1a, n1b) ->
(* If we have SMT operations div and mod, then we can use the
property that
@@ -1564,20 +1561,20 @@ let rec unify_nexps l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (ne
if Env.have_smt_op (mk_id "div") env && Env.have_smt_op (mk_id "mod") env then
let valid n c = prove env (nc_eq (napp (mk_id "mod") [n; c]) (nint 0)) && prove env (nc_neq c (nint 0)) in
if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then
- unify_nexps l env goals n1a (napp (mk_id "div") [nexp2; n1b])
+ unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b])
else if KidSet.is_empty (nexp_frees n1a) && valid nexp2 n1a then
- unify_nexps l env goals n1b (napp (mk_id "div") [nexp2; n1a])
+ unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a])
else unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
else if KidSet.is_empty (nexp_frees n1a) then
begin
match nexp_aux2 with
| Nexp_times (n2a, n2b) when prove env (NC_aux (NC_equal (n1a, n2a), Parse_ast.Unknown)) ->
- unify_nexps l env goals n1b n2b
+ unify_nexp l env goals n1b n2b
| Nexp_constant c2 ->
begin
match n1a with
| Nexp_aux (Nexp_constant c1,_) when Big_int.equal (Big_int.modulus c2 c1) Big_int.zero ->
- unify_nexps l env goals n1b (mk_nexp (Nexp_constant (Big_int.div c2 c1)))
+ unify_nexp l env goals n1b (nconstant (Big_int.div c2 c1))
| _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
end
| _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
@@ -1586,148 +1583,30 @@ let rec unify_nexps l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (ne
begin
match nexp_aux2 with
| Nexp_times (n2a, n2b) when prove env (NC_aux (NC_equal (n1b, n2b), Parse_ast.Unknown)) ->
- unify_nexps l env goals n1a n2a
+ unify_nexp l env goals n1a n2a
| _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
end
else unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
| _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
-let string_of_uvar = function
- | U_nexp n -> string_of_nexp n
- | U_order o -> string_of_order o
- | U_typ typ -> string_of_typ typ
-
-let unify_order l (Ord_aux (ord_aux1, _) as ord1) (Ord_aux (ord_aux2, _) as ord2) =
- typ_debug (lazy ("UNIFYING ORDERS " ^ string_of_order ord1 ^ " AND " ^ string_of_order ord2));
- match ord_aux1, ord_aux2 with
- | Ord_var kid, _ -> KBindings.singleton kid (U_order ord2)
- | Ord_inc, Ord_inc -> KBindings.empty
- | Ord_dec, Ord_dec -> KBindings.empty
- | _, _ -> unify_error l (string_of_order ord1 ^ " cannot be unified with " ^ string_of_order ord2)
+let unify l env typ1 typ2 goals =
+ typ_print (lazy (Util.("Unify " |> magenta |> clear) ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2));
+ let typ1, typ2 = Env.expand_synonyms env typ1, Env.expand_synonyms env typ2 in
+ unify_typ l env goals typ1 typ2
let subst_unifiers unifiers typ =
- let subst_unifier typ (kid, uvar) =
- match uvar with
- | U_nexp nexp -> typ_subst_nexp kid (unaux_nexp nexp) typ
- | U_order ord -> typ_subst_order kid (unaux_order ord) typ
- | U_typ subst -> typ_subst_typ kid (unaux_typ subst) typ
- in
- List.fold_left subst_unifier typ (KBindings.bindings unifiers)
-
-let subst_args_unifiers unifiers typ_args =
- let subst_unifier typ_args (kid, uvar) =
- match uvar with
- | U_nexp nexp -> List.map (typ_subst_arg_nexp kid (unaux_nexp nexp)) typ_args
- | U_order ord -> List.map (typ_subst_arg_order kid (unaux_order ord)) typ_args
- | U_typ subst -> List.map (typ_subst_arg_typ kid (unaux_typ subst)) typ_args
- in
- List.fold_left subst_unifier typ_args (KBindings.bindings unifiers)
-
-let subst_uvar_unifiers unifiers uvar =
- let subst_unifier uvar' (kid, uvar) =
- match uvar with
- | U_nexp nexp -> uvar_subst_nexp kid (unaux_nexp nexp) uvar'
- | U_order ord -> uvar_subst_order kid (unaux_order ord) uvar'
- | U_typ subst -> uvar_subst_typ kid (unaux_typ subst) uvar'
- in
- List.fold_left subst_unifier uvar (KBindings.bindings unifiers)
-
-let merge_unifiers l kid uvar1 uvar2 =
- match uvar1, uvar2 with
- | Some (U_nexp n1), Some (U_nexp n2) ->
- if nexp_identical n1 n2 then Some (U_nexp n1)
- else unify_error l ("Multiple non-identical unifiers for " ^ string_of_kid kid
- ^ ": " ^ string_of_nexp n1 ^ " and " ^ string_of_nexp n2)
- | Some _, Some _ -> unify_error l "Multiple non-identical non-nexp unifiers"
- | None, Some u2 -> Some u2
- | Some u1, None -> Some u1
- | None, None -> None
+ List.fold_left (fun typ (v, arg) -> typ_subst v arg typ) typ (KBindings.bindings unifiers)
-let rec unify l env typ1 typ2 =
- typ_print (lazy (Util.("Unify " |> magenta |> clear) ^ string_of_typ typ1 ^ " with " ^ string_of_typ typ2));
- let goals = KidSet.inter (KidSet.diff (typ_frees typ1) (typ_frees typ2)) (typ_frees typ1) in
-
- let rec unify_typ l (Typ_aux (typ1_aux, _) as typ1) (Typ_aux (typ2_aux, _) as typ2) =
- typ_debug (lazy ("UNIFYING TYPES " ^ string_of_typ typ1 ^ " AND " ^ string_of_typ typ2));
- match typ1_aux, typ2_aux with
- | Typ_internal_unknown, _
- | _, Typ_internal_unknown when Env.allow_unknowns env -> KBindings.empty
- | Typ_id v1, Typ_id v2 ->
- if Id.compare v1 v2 = 0 then KBindings.empty
- else unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2)
- | Typ_id v1, Typ_app (f2, []) ->
- if Id.compare v1 f2 = 0 then KBindings.empty
- else unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2)
- | Typ_app (f1, []), Typ_id v2 ->
- if Id.compare f1 v2 = 0 then KBindings.empty
- else unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2)
- | Typ_var kid, _ when KidSet.mem kid goals -> KBindings.singleton kid (U_typ typ2)
- | Typ_var kid1, Typ_var kid2 when Kid.compare kid1 kid2 = 0 -> KBindings.empty
- | Typ_tup typs1, Typ_tup typs2 ->
- begin
- try List.fold_left (KBindings.merge (merge_unifiers l)) KBindings.empty (List.map2 (unify_typ l) typs1 typs2) with
- | Invalid_argument _ -> unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2
- ^ " tuple type is of different length")
- end
- | Typ_app (f1, [arg1]), Typ_app (f2, [arg2a; arg2b])
- when Id.compare (mk_id "atom") f1 = 0 && Id.compare (mk_id "range") f2 = 0 ->
- unify_typ_arg_list 0 KBindings.empty [] [] [arg1; arg1] [arg2a; arg2b]
- | Typ_app (f1, [arg1a; arg1b]), Typ_app (f2, [arg2])
- when Id.compare (mk_id "range") f1 = 0 && Id.compare (mk_id "atom") f2 = 0 ->
- unify_typ_arg_list 0 KBindings.empty [] [] [arg1a; arg1b] [arg2; arg2]
- | Typ_app (f1, args1), Typ_app (f2, args2) when Id.compare f1 f2 = 0 ->
- unify_typ_arg_list 0 KBindings.empty [] [] args1 args2
- | _, _ -> unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2)
-
- and unify_typ_arg_list unified acc uargs1 uargs2 args1 args2 =
- match args1, args2 with
- | [], [] when unified = 0 && List.length uargs1 > 0 ->
- unify_error l "Could not unify arg lists" (*FIXME improve error *)
- | [], [] when unified > 0 && List.length uargs1 > 0 -> unify_typ_arg_list 0 acc [] [] uargs1 uargs2
- | [], [] when List.length uargs1 = 0 -> acc
- | (a1 :: a1s), (a2 :: a2s) ->
- begin
- let unifiers, success =
- try unify_typ_args l a1 a2, true with
- | Unification_error _ -> KBindings.empty, false
- in
- let a1s = subst_args_unifiers unifiers a1s in
- let a2s = subst_args_unifiers unifiers a2s in
- let uargs1 = subst_args_unifiers unifiers uargs1 in
- let uargs2 = subst_args_unifiers unifiers uargs2 in
- if success
- then unify_typ_arg_list (unified + 1) (KBindings.merge (merge_unifiers l) unifiers acc) uargs1 uargs2 a1s a2s
- else unify_typ_arg_list unified acc (a1 :: uargs1) (a2 :: uargs2) a1s a2s
- end
- | _, _ -> unify_error l "Cannot unify type lists of different length"
-
- and unify_typ_args l (Typ_arg_aux (typ_arg_aux1, _) as typ_arg1) (Typ_arg_aux (typ_arg_aux2, _) as typ_arg2) =
- match typ_arg_aux1, typ_arg_aux2 with
- | Typ_arg_nexp n1, Typ_arg_nexp n2 ->
- begin
- match unify_nexps l env goals (nexp_simp n1) (nexp_simp n2) with
- | Some (kid, unifier) -> KBindings.singleton kid (U_nexp (nexp_simp unifier))
- | None -> KBindings.empty
- end
- | Typ_arg_typ typ1, Typ_arg_typ typ2 -> unify_typ l typ1 typ2
- | Typ_arg_order ord1, Typ_arg_order ord2 -> unify_order l ord1 ord2
- | _, _ -> unify_error l (string_of_typ_arg typ_arg1 ^ " cannot be unified with type argument " ^ string_of_typ_arg typ_arg2)
- in
-
- match destruct_exist env typ2 with
- | Some (kids, nc, typ2) ->
- let typ1, typ2 = Env.expand_synonyms env typ1, Env.expand_synonyms env typ2 in
- let (unifiers, _, _) = unify l env typ1 typ2 in
- typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)));
- unifiers, kids, Some nc
- | None ->
- let typ1, typ2 = Env.expand_synonyms env typ1, Env.expand_synonyms env typ2 in
- unify_typ l typ1 typ2, [], None
+let instantiate_quant (v, arg) (QI_aux (aux, l) as qi) =
+ match aux with
+ | QI_id kopt when Kid.compare (kopt_kid kopt) v = 0 ->
+ typ_debug (lazy ("Instantiated " ^ string_of_quant_item qi));
+ None
+ | QI_id _ -> Some qi
+ | QI_const nc -> Some (QI_aux (QI_const (constraint_subst v arg nc), l))
-let merge_uvars l unifiers1 unifiers2 =
- try KBindings.merge (merge_unifiers l) unifiers1 unifiers2
- with
- | Unification_error (_, m) -> typ_error l ("Could not merge unification variables: " ^ m)
+let instantiate_quants quants unifier =
+ List.map (instantiate_quant unifier) quants |> Util.option_these
(**************************************************************************)
(* 3.5. Subtyping with existentials *)
@@ -1750,16 +1629,6 @@ let destruct_atom_kid env typ =
when string_of_id f = "range" && Kid.compare kid1 kid2 = 0 -> Some kid1
| _ -> None
-let nc_subst_uvar kid uvar nc =
- match uvar with
- | U_nexp nexp -> nc_subst_nexp kid (unaux_nexp nexp) nc
- | _ -> nc
-
-let uv_nexp_constraint env (kid, uvar) =
- match uvar with
- | U_nexp nexp -> Env.add_constraint (nc_eq (nvar kid) nexp) env
- | _ -> env
-
(* The kid_order function takes a set of Int-kinded kids, and returns
a list of those kids in the order they appear in a type, as well as
a set containing all the kids that did not occur in the type. We
@@ -1809,8 +1678,8 @@ let rec alpha_equivalent env typ1 typ2 =
| Typ_exist (kids, nc, typ) ->
let (kids, _) = kid_order (KidSet.of_list kids) typ in
let kids = List.map (fun kid -> (kid, new_kid ())) kids in
- let nc = List.fold_left (fun nc (kid, nk) -> nc_subst_nexp kid (Nexp_var nk) nc) nc kids in
- let typ = List.fold_left (fun nc (kid, nk) -> typ_subst_nexp kid (Nexp_var nk) nc) typ kids in
+ let nc = List.fold_left (fun nc (kid, nk) -> constraint_subst kid (arg_nexp (nvar nk)) nc) nc kids in
+ let typ = List.fold_left (fun nc (kid, nk) -> typ_subst kid (arg_nexp (nvar nk)) nc) typ kids in
let kids = List.map snd kids in
Typ_exist (kids, nc, typ)
| Typ_app (id, args) ->
@@ -1836,6 +1705,11 @@ let unwrap_exist env typ =
| Some (kids, nc, typ) -> (kids, nc, typ)
| None -> ([], nc_true, typ)
+let unifier_constraint env (v, arg) =
+ match arg with
+ | Typ_arg_aux (Typ_arg_nexp nexp, _) -> Env.add_constraint (nc_eq (nvar v) nexp) env
+ | _ -> env
+
let rec subtyp l env (Typ_aux (typ_aux1, _) as typ1) (Typ_aux (typ_aux2, _) as typ2) =
typ_print (lazy (("Subtype " |> Util.green |> Util.clear) ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2));
match typ_aux1, typ_aux2 with
@@ -1854,12 +1728,9 @@ let rec subtyp l env (Typ_aux (typ_aux1, _) as typ1) (Typ_aux (typ_aux2, _) as t
let env = add_existential l kids1 nc1 env in
let env = add_typ_vars l (KidSet.elements (KidSet.inter (nexp_frees nexp2) (KidSet.of_list kids2))) env in
let kids2 = KidSet.elements (KidSet.diff (KidSet.of_list kids2) (nexp_frees nexp2)) in
+ if not (kids2 = []) then typ_error l "Universally quantified constraint generated" else ();
let env = Env.add_constraint (nc_eq nexp1 nexp2) env in
- let constr var_of =
- Constraint.forall (List.map var_of kids2)
- (nc_constraint env var_of (nc_negate nc2))
- in
- if prove_z3' env constr then ()
+ if prove env nc2 then ()
else typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env))
| _, _ ->
match destruct_exist env typ1, unwrap_exist env (Env.canonicalize env typ2) with
@@ -1869,24 +1740,14 @@ let rec subtyp l env (Typ_aux (typ_aux1, _) as typ1) (Typ_aux (typ_aux2, _) as t
typ_debug (lazy "Subtype check with unification");
let env = add_typ_vars l kids env in
let kids' = KidSet.elements (KidSet.diff (KidSet.of_list kids) (typ_frees typ2)) in
- let unifiers, existential_kids, existential_nc =
- try unify l env typ2 typ1 with
+ if not (kids' = []) then typ_error l "Universally quantified constraint generated" else ();
+ let unifiers =
+ try unify l env typ2 typ1 (KidSet.of_list kids) with
| Unification_error (_, m) -> typ_error l m
in
- let nc = List.fold_left (fun nc (kid, uvar) -> nc_subst_uvar kid uvar nc) nc (KBindings.bindings unifiers) in
- let env = List.fold_left uv_nexp_constraint env (KBindings.bindings unifiers) in
- let env = match existential_kids, existential_nc with
- | [], None -> env
- | _, Some enc ->
- let env = List.fold_left (fun env kid -> Env.add_typ_var l kid K_int env) env existential_kids in
- Env.add_constraint enc env
- | _, None -> assert false (* Cannot have existential_kids without existential_nc *)
- in
- let constr var_of =
- Constraint.forall (List.map var_of kids')
- (nc_constraint env var_of (nc_negate nc))
- in
- if prove_z3' env constr then ()
+ let nc = List.fold_left (fun nc (kid, uvar) -> constraint_subst kid uvar nc) nc (KBindings.bindings unifiers) in
+ let env = List.fold_left unifier_constraint env (KBindings.bindings unifiers) in
+ if prove env nc then ()
else typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env))
let typ_equality l env typ1 typ2 =
@@ -1967,40 +1828,13 @@ let is_typ_kid kid = function
| KOpt_aux (KOpt_kind (K_aux (K_type, _), kid'), _) -> Kid.compare kid kid' = 0
| _ -> false
-let rec instantiate_quants quants kid uvar = match quants with
- | [] -> []
- | ((QI_aux (QI_id kinded_id, _) as quant) :: quants) ->
- typ_debug (lazy ("instantiating quant " ^ string_of_quant_item quant));
- begin
- match uvar with
- | U_nexp nexp ->
- if is_nat_kid kid kinded_id
- then instantiate_quants quants kid uvar
- else quant :: instantiate_quants quants kid uvar
- | U_order ord ->
- if is_order_kid kid kinded_id
- then instantiate_quants quants kid uvar
- else quant :: instantiate_quants quants kid uvar
- | U_typ typ ->
- if is_typ_kid kid kinded_id
- then instantiate_quants quants kid uvar
- else quant :: instantiate_quants quants kid uvar
- end
- | ((QI_aux (QI_const nc, l)) :: quants) ->
- begin
- match uvar with
- | U_nexp nexp ->
- QI_aux (QI_const (nc_subst_nexp kid (unaux_nexp nexp) nc), l) :: instantiate_quants quants kid uvar
- | _ -> (QI_aux (QI_const nc, l)) :: instantiate_quants quants kid uvar
- end
-
let instantiate_simple_equations =
let rec find_eqs kid (NC_aux (nc,_)) =
match nc with
| NC_equal (Nexp_aux (Nexp_var kid',_), nexp)
when Kid.compare kid kid' == 0 &&
not (KidSet.mem kid (nexp_frees nexp)) ->
- [U_nexp nexp]
+ [arg_nexp nexp]
| NC_and (nexp1, nexp2) ->
find_eqs kid nexp1 @ find_eqs kid nexp2
| _ -> []
@@ -2275,7 +2109,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
| ((E_aux (E_if (cond, (E_aux (E_throw _, _) | E_aux (E_block [E_aux (E_throw _, _)], _)), _), _) as exp) :: exps) ->
let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in
let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in
- let env = add_opt_constraint (option_map nc_negate (assert_constraint env false cond')) env in
+ let env = add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env in
texp :: check_block l env exps typ
| (exp :: exps) ->
let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in
@@ -2318,7 +2152,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
in
let check_fexp (FE_aux (FE_Fexp (field, exp), (l, ()))) =
let (typq, rectyp_q, field_typ, _) = Env.get_accessor rectyp_id field env in
- let unifiers, _, _ (* FIXME *) = try unify l env rectyp_q typ with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in
+ let unifiers = try unify l env rectyp_q typ (tyvars_of_typ rectyp_q) with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in
let field_typ' = subst_unifiers unifiers field_typ in
let checked_exp = crule check_exp env exp field_typ' in
FE_aux (FE_Fexp (field, checked_exp), (l, None))
@@ -2333,7 +2167,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
in
let check_fexp (FE_aux (FE_Fexp (field, exp), (l, ()))) =
let (typq, rectyp_q, field_typ, _) = Env.get_accessor rectyp_id field env in
- let unifiers, _, _ (* FIXME *) = try unify l env rectyp_q typ with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in
+ let unifiers = try unify l env rectyp_q typ (tyvars_of_typ rectyp_q) with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in
let field_typ' = subst_unifiers unifiers field_typ in
let checked_exp = crule check_exp env exp field_typ' in
FE_aux (FE_Fexp (field, checked_exp), (l, None))
@@ -2413,7 +2247,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
| E_if (cond, then_branch, else_branch), _ ->
let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in
let then_branch' = crule check_exp (add_opt_constraint (assert_constraint env true cond') env) then_branch typ in
- let else_branch' = crule check_exp (add_opt_constraint (option_map nc_negate (assert_constraint env false cond')) env) else_branch typ in
+ let else_branch' = crule check_exp (add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env) else_branch typ in
annot_exp (E_if (cond', then_branch', else_branch')) typ
| E_exit exp, _ ->
let checked_exp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in
@@ -2563,7 +2397,7 @@ and type_coercion env (E_aux (_, (l, _)) as annotated_exp) typ =
required that exp_typ unifies with typ. Returns the annotated
coercion as with type_coercion and also a set of unifiers, or
throws a unification error *)
-and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ =
+and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ =
let strip exp_aux = strip_exp (E_aux (exp_aux, (Parse_ast.Unknown, None))) in
let annot_exp exp typ' = E_aux (exp, (l, Some ((env, typ', no_effect), Some typ))) in
let switch_typ exp typ = match exp with
@@ -2576,8 +2410,8 @@ and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ =
typ_print (lazy ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " for unification"));
try
let inferred_cast = irule infer_exp (Env.no_casts env) (strip (E_app (cast, [annotated_exp]))) in
- let ityp = typ_of inferred_cast in
- annot_exp (E_cast (ityp, inferred_cast)) ityp, unify l env typ ityp
+ let ityp, env = bind_existential l (typ_of inferred_cast) env in
+ inferred_cast, unify l env typ ityp goals, env
with
| Type_error (_, err) -> try_casts casts
| Unification_error (_, err) -> try_casts casts
@@ -2586,7 +2420,8 @@ and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ =
begin
try
typ_debug (lazy "PERFORMING COERCING UNIFICATION");
- annotated_exp, unify l env typ (typ_of annotated_exp)
+ let atyp, env = bind_existential l (typ_of annotated_exp) env in
+ annotated_exp, unify l env typ atyp goals, env
with
| Unification_error (_, m) when Env.allow_casts env ->
let casts = filter_casts env (typ_of annotated_exp) typ (Env.get_casts env) in
@@ -2708,11 +2543,11 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
| Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) ->
begin
try
+ let goals = List.fold_left KidSet.union KidSet.empty (List.map tyvars_of_quant_item quants) in
typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ));
- let unifiers, _, _ (* FIXME! *) = unify l env ret_typ typ in
- typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)));
+ let unifiers = unify l env ret_typ typ goals in
let arg_typ' = subst_unifiers unifiers arg_typ in
- let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
+ let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
if (match quants' with [] -> false | _ -> true)
then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat)
else ();
@@ -2742,12 +2577,10 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
try
typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for pattern " ^ string_of_typ typ));
- let unifiers, _, _ (* FIXME! *) = unify l env typ2 typ in
-
- typ_debug (lazy ("unifiers: " ^ string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)));
-
+ (* FIXME: There's no obvious goals here *)
+ let unifiers = unify l env typ2 typ (tyvars_of_typ typ2) in
let arg_typ' = subst_unifiers unifiers typ1 in
- let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
+ let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
if (match quants' with [] -> false | _ -> true)
then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat)
else ();
@@ -2763,10 +2596,9 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
try
typ_debug (lazy "Unifying mapping forwards failed, trying backwards.");
typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for pattern " ^ string_of_typ typ));
- let unifiers, _, _ (* FIXME! *) = unify l env typ1 typ in
- typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)));
+ let unifiers = unify l env typ1 typ (tyvars_of_typ typ1) in
let arg_typ' = subst_unifiers unifiers typ2 in
- let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
+ let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
if (match quants' with [] -> false | _ -> true)
then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat)
else ();
@@ -2955,7 +2787,7 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as
| Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _) when Env.is_record rectyp_id env ->
let eff = if is_register then mk_effect [BE_wreg] else no_effect in
let (typq, rectyp_q, field_typ, _) = Env.get_accessor rectyp_id field env in
- let unifiers, _, _ (* FIXME *) = try unify l env rectyp_q regtyp with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in
+ let unifiers = try unify l env rectyp_q regtyp (tyvars_of_typ rectyp_q) with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in
let field_typ' = subst_unifiers unifiers field_typ in
let checked_exp = crule check_exp env exp field_typ' in
annot_assign (annot_lexp (LEXP_field (annot_lexp_effect inferred_flexp regtyp eff, field)) field_typ') checked_exp, env
@@ -3190,7 +3022,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
in
let check_fexp (FE_aux (FE_Fexp (field, exp), (l, ()))) =
let (typq, rectyp_q, field_typ, _) = Env.get_accessor rectyp_id field env in
- let unifiers, _, _ (* FIXME *) = try unify l env rectyp_q typ with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in
+ let unifiers = try unify l env rectyp_q typ (tyvars_of_typ rectyp_q) with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in
let field_typ' = subst_unifiers unifiers field_typ in
let inferred_exp = crule check_exp env exp field_typ' in
FE_aux (FE_Fexp (field, inferred_exp), (l, None))
@@ -3261,7 +3093,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
| E_if (cond, then_branch, else_branch) ->
let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in
let then_branch' = irule infer_exp (add_opt_constraint (assert_constraint env true cond') env) then_branch in
- let else_branch' = crule check_exp (add_opt_constraint (option_map nc_negate (assert_constraint env false cond')) env) else_branch (typ_of then_branch') in
+ let else_branch' = crule check_exp (add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env) else_branch (typ_of then_branch') in
annot_exp (E_if (cond', then_branch', else_branch')) (typ_of then_branch')
| E_vector_access (v, n) -> infer_exp env (E_aux (E_app (mk_id "vector_access", [v; n]), (l, ())))
| E_vector_update (v, n, exp) -> infer_exp env (E_aux (E_app (mk_id "vector_update", [v; n; exp]), (l, ())))
@@ -3337,160 +3169,104 @@ and instantiation_of_without_type (E_aux (exp_aux, (l, _)) as exp) =
| E_app (f, xs) -> snd (infer_funapp' l (Env.no_casts env) f (Env.get_val_spec f env) (List.map strip_exp xs) None)
| _ -> invalid_arg ("instantiation_of expected application, got " ^ string_of_exp exp)
-and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
- let annot_exp exp typ eff = E_aux (exp, (l, Some ((env, typ, eff), ret_ctx_typ))) in
- let switch_annot env typ = function
- | (E_aux (exp, (l, Some (_, _, eff)))) -> E_aux (exp, (l, Some (env, typ, eff)))
- | _ -> failwith "Cannot switch annot for unannotated function"
- in
- let all_unifiers = ref KBindings.empty in
- let ex_goal = ref None in
- let prove_goal env = match !ex_goal with
- | Some goal when prove env goal -> ()
- | Some goal -> typ_error l ("Could not prove existential goal: " ^ string_of_n_constraint goal)
- | None -> ()
- in
+and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ =
+ let annot_exp exp typ eff = E_aux (exp, (l, Some ((env, typ, eff), expected_ret_typ))) in
+ let is_bound env kid = KBindings.mem kid (Env.get_typ_vars env) in
+
+ (* First we record all the type variables when we start checking the
+ application, so we can distinguish them from existentials
+ introduced by instantiating function arguments later. *)
let universals = Env.get_typ_vars env in
let universal_constraints = Env.get_constraints env in
- let is_bound kid env = KBindings.mem kid (Env.get_typ_vars env) in
- let rec number n = function
- | [] -> []
- | (x :: xs) -> (n, x) :: number (n + 1) xs
- in
- let solve_quant env = function
- | QI_aux (QI_id _, _) -> false
- | QI_aux (QI_const nc, _) -> prove env nc
- in
- let record_unifiers unifiers =
- let previous_unifiers = !all_unifiers in
- let updated_unifiers = KBindings.map (subst_uvar_unifiers unifiers) previous_unifiers in
- all_unifiers := merge_uvars l updated_unifiers unifiers;
- in
- let rec instantiate env quants typs ret_typ args =
- match typs, args with
- | (utyps, []), (uargs, []) ->
- begin
- typ_debug (lazy ("Got unresolved args: " ^ string_of_list ", " (fun (_, exp) -> string_of_exp exp) uargs));
- if List.for_all (solve_quant env) quants
- then
- let iuargs = List.map2 (fun utyp (n, uarg) -> (n, crule check_exp env uarg utyp)) utyps uargs in
- (iuargs, ret_typ, env)
- else typ_raise l (Err_unresolved_quants (f, quants, Env.get_locals env, Env.get_constraints env))
- end
- | (utyps, (typ :: typs)), (uargs, ((n, arg) :: args))
- when List.for_all (fun kid -> is_bound kid env) (KidSet.elements (typ_frees typ)) ->
- begin
- let carg = crule check_exp env arg typ in
- let (iargs, ret_typ', env) = instantiate env quants (utyps, typs) ret_typ (uargs, args) in
- ((n, carg) :: iargs, ret_typ', env)
- end
- | (utyps, (typ :: typs)), (uargs, ((n, arg) :: args)) ->
- begin
- typ_debug (lazy ("INSTANTIATE: " ^ string_of_exp arg ^ " with " ^ string_of_typ typ));
- let iarg = irule infer_exp env arg in
- typ_debug (lazy ("INFER: " ^ string_of_exp arg ^ " type " ^ string_of_typ (typ_of iarg)));
- try
- (* If we get an existential when instantiating, we prepend
- the identifier of the exisitential with the tag argN# to
- denote that it was bound by the Nth argument to the
- function. *)
- let ex_tag = "arg" ^ string_of_int n ^ "#" in
- let iarg, (unifiers, ex_kids, ex_nc) = type_coercion_unify env iarg typ in
- typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)));
- typ_debug (lazy ("EX KIDS: " ^ string_of_list ", " string_of_kid ex_kids));
- let env = match ex_kids, ex_nc with
- | [], None -> env
- | _, Some enc ->
- let enc = List.fold_left (fun nc kid -> nc_subst_nexp kid (Nexp_var (prepend_kid ex_tag kid)) nc) enc ex_kids in
- let env = List.fold_left (fun env kid -> Env.add_typ_var l (prepend_kid ex_tag kid) K_int env) env ex_kids in
- Env.add_constraint enc env
- | _, None -> assert false (* Cannot have ex_kids without ex_nc *)
- in
- let tag_unifier uvar = List.fold_left (fun uvar kid -> uvar_subst_nexp kid (Nexp_var (prepend_kid ex_tag kid)) uvar) uvar ex_kids in
- let unifiers = KBindings.map tag_unifier unifiers in
- record_unifiers unifiers;
- let utyps' = List.map (subst_unifiers unifiers) utyps in
- let typs' = List.map (subst_unifiers unifiers) typs in
- let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
- let ret_typ' = subst_unifiers unifiers ret_typ in
- let (iargs, ret_typ'', env) = instantiate env quants' (utyps', typs') ret_typ' (uargs, args) in
- ((n, iarg) :: iargs, ret_typ'', env)
- with
- | Unification_error (l, str) ->
- typ_print (lazy ("Unification error: " ^ str));
- instantiate env quants (typ :: utyps, typs) ret_typ ((n, arg) :: uargs, args)
- end
- | (_, []), _ -> typ_error l ("Function " ^ string_of_id f ^ " applied to too many arguments")
- | _, (_, []) -> typ_error l ("Function " ^ string_of_id f ^ " not applied to enough arguments")
- in
- let instantiate_ret env quants typs ret_typ =
- match ret_ctx_typ with
- | None -> (quants, typs, ret_typ, env)
- | Some rct when is_exist (Env.expand_synonyms env rct) -> (quants, typs, ret_typ, env)
- | Some rct ->
- begin
- typ_debug (lazy ("RCT is " ^ string_of_typ rct));
- typ_debug (lazy ("INSTANTIATE RETURN:" ^ string_of_typ ret_typ));
- let unifiers, ex_kids, ex_nc =
- try unify l env ret_typ rct with
- | Unification_error _ -> typ_debug (lazy "UERROR"); KBindings.empty, [], None
- in
- typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)));
- if ex_kids = [] then () else (typ_debug (lazy ("EX GOAL: " ^ string_of_option string_of_n_constraint ex_nc)); ex_goal := ex_nc);
- record_unifiers unifiers;
- let env = List.fold_left (fun env kid -> Env.add_typ_var l kid K_int env) env ex_kids in
- let typs' = List.map (subst_unifiers unifiers) typs in
- let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
- let ret_typ' =
- match ex_nc with
- | None -> subst_unifiers unifiers ret_typ
- | Some nc -> mk_typ (Typ_exist (ex_kids, nc, subst_unifiers unifiers ret_typ))
- in
- (quants', typs', ret_typ', env)
- end
- in
+
let quants, typ_args, typ_ret, eff =
match Env.expand_synonyms env f_typ with
- | Typ_aux (Typ_fn (typ_args, typ_ret, eff), _) -> quant_items typq, typ_args, typ_ret, eff
+ | Typ_aux (Typ_fn (typ_args, typ_ret, eff), _) -> ref (quant_items typq), typ_args, ref typ_ret, eff
| _ -> typ_error l (string_of_typ f_typ ^ " is not a function type")
in
- let unifiers = instantiate_simple_equations quants in
- typ_debug (lazy "Instantiating from equations");
- typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers))); all_unifiers := unifiers;
- let typ_args = List.map (subst_unifiers unifiers) typ_args in
- let quants = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
- let typ_ret = subst_unifiers unifiers typ_ret in
- let quants, typ_args, typ_ret, env =
- instantiate_ret env quants typ_args typ_ret
+
+ typ_debug (lazy ("Quantifiers " ^ Util.string_of_list ", " string_of_quant_item !quants));
+
+ if not (List.length typ_args = List.length xs) then
+ typ_error l (Printf.sprintf "Function %s applied to %d args, expected %d" (string_of_id f) (List.length xs) (List.length typ_args))
+ else ();
+
+ let instantiate_quant (v, arg) (QI_aux (aux, l) as qi) =
+ match aux with
+ | QI_id kopt when Kid.compare (kopt_kid kopt) v = 0 -> None
+ | QI_id _ -> Some qi
+ | QI_const nc -> Some (QI_aux (QI_const (constraint_subst v arg nc), l))
in
- let (xs_instantiated, typ_ret, env) = instantiate env quants ([], typ_args) typ_ret ([], number 0 xs) in
- let xs_reordered = List.map snd (List.sort (fun (n, _) (m, _) -> compare n m) xs_instantiated) in
- prove_goal env;
+ let typ_args = match expected_ret_typ with
+ | None -> typ_args
+ | Some expect when is_exist (Env.expand_synonyms env expect) || is_exist !typ_ret -> typ_args
+ | Some expect ->
+ let goals = List.fold_left KidSet.union KidSet.empty (List.map tyvars_of_quant_item !quants) in
+ try
+ let unifiers = unify l env !typ_ret expect goals |> KBindings.bindings in
+ typ_debug (lazy (Util.("Unifiers " |> magenta |> clear)
+ ^ Util.string_of_list ", " (fun (v, arg) -> string_of_kid v ^ " => " ^ string_of_typ_arg arg) unifiers));
+ List.iter (fun unifier -> quants := instantiate_quants !quants unifier) unifiers;
+ List.iter (fun (v, arg) -> typ_ret := typ_subst v arg !typ_ret) unifiers;
+ List.map (fun typ -> List.fold_left (fun typ (v, arg) -> typ_subst v arg typ) typ unifiers) typ_args
+ with Unification_error _ -> typ_args
+ in
+
+ (* We now iterate throught the function arguments, checking them and
+ instantiating quantifiers. *)
+ let instantiate env arg typ remaining_typs =
+ if KidSet.for_all (is_bound env) (tyvars_of_typ typ) then
+ crule check_exp env arg typ, remaining_typs, env
+ else
+ let goals = List.fold_left KidSet.union KidSet.empty (List.map tyvars_of_quant_item !quants) in
+ let inferred_arg = irule infer_exp env arg in
+ let inferred_arg, unifiers, env =
+ try type_coercion_unify env goals inferred_arg typ with
+ | Unification_error (l, m) -> typ_error l m
+ in
+ let unifiers = KBindings.bindings unifiers in
+ typ_debug (lazy (Util.("Unifiers " |> magenta |> clear)
+ ^ Util.string_of_list ", " (fun (v, arg) -> string_of_kid v ^ " => " ^ string_of_typ_arg arg) unifiers));
+ List.iter (fun unifier -> quants := instantiate_quants !quants unifier) unifiers;
+ List.iter (fun (v, arg) -> typ_ret := typ_subst v arg !typ_ret) unifiers;
+ let remaining_typs =
+ List.map (fun typ -> List.fold_left (fun typ (v, arg) -> typ_subst v arg typ) typ unifiers) remaining_typs
+ in
+ inferred_arg, remaining_typs, env
+ in
+ let fold_instantiate (xs, args, env) x =
+ match args with
+ | arg :: remaining_args ->
+ let x, remaining_args, env = instantiate env x arg remaining_args in
+ (x :: xs, remaining_args, env)
+ | [] -> raise (Reporting.err_unreachable l __POS__ "Empty arguments during instantiation")
+ in
+ let xs, _, env = List.fold_left fold_instantiate ([], typ_args, env) xs in
+ let xs = List.rev xs in
+
+ if not (List.for_all (solve_quant env) !quants) then
+ typ_raise l (Err_unresolved_quants (f, !quants, Env.get_locals env, Env.get_constraints env))
+ else ();
let ty_vars = List.map fst (KBindings.bindings (Env.get_typ_vars env)) in
let existentials = List.filter (fun kid -> not (KBindings.mem kid universals)) ty_vars in
let num_new_ncs = List.length (Env.get_constraints env) - List.length universal_constraints in
- let ex_constraints = take num_new_ncs (Env.get_constraints env) in
+ let ex_constraints = take num_new_ncs (Env.get_constraints env) in
typ_debug (lazy ("Existentials: " ^ string_of_list ", " string_of_kid existentials));
typ_debug (lazy ("Existential constraints: " ^ string_of_list ", " string_of_n_constraint ex_constraints));
let typ_ret =
- if KidSet.is_empty (KidSet.of_list existentials) || KidSet.is_empty (typ_frees typ_ret)
- then (typ_debug (lazy "Returning Existential"); typ_ret)
- else mk_typ (Typ_exist (existentials, List.fold_left nc_and nc_true ex_constraints, typ_ret))
+ if KidSet.is_empty (KidSet.of_list existentials) || KidSet.is_empty (typ_frees !typ_ret)
+ then !typ_ret
+ else mk_typ (Typ_exist (existentials, List.fold_left nc_and nc_true ex_constraints, !typ_ret))
in
let typ_ret = simp_typ typ_ret in
- let exp = annot_exp (E_app (f, xs_reordered)) typ_ret eff in
- typ_debug (lazy ("RETURNING: " ^ string_of_typ (typ_of exp)));
- match ret_ctx_typ with
- | None ->
- exp, !all_unifiers
- | Some rct ->
- let exp = type_coercion env exp rct in
- typ_debug (lazy ("RETURNING AFTER COERCION " ^ string_of_typ (typ_of exp)));
- exp, !all_unifiers
+ let exp = annot_exp (E_app (f, xs)) typ_ret eff in
+ typ_debug (lazy ("RETURNING: " ^ string_of_exp exp));
+
+ exp, KBindings.empty (* FIXME *)
and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, ())) as mpat) (Typ_aux (typ_aux, _) as typ) =
let (Typ_aux (typ_aux, _) as typ), env = bind_existential l typ env in
@@ -3589,10 +3365,9 @@ and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, ())) as mpat) (
begin
try
typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for mapping-pattern " ^ string_of_typ typ));
- let unifiers, _, _ (* FIXME! *) = unify l env ret_typ typ in
- typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)));
+ let unifiers = unify l env ret_typ typ (tyvars_of_typ ret_typ) in
let arg_typ' = subst_unifiers unifiers arg_typ in
- let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
+ let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
if (match quants' with [] -> false | _ -> true)
then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat)
else ();
@@ -3620,10 +3395,9 @@ and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, ())) as mpat) (
begin
try
typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for mapping-pattern " ^ string_of_typ typ));
- let unifiers, _, _ (* FIXME! *) = unify l env typ2 typ in
- typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)));
+ let unifiers = unify l env typ2 typ (tyvars_of_typ typ2) in
let arg_typ' = subst_unifiers unifiers typ1 in
- let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
+ let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
if (match quants' with [] -> false | _ -> true)
then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat)
else ();
@@ -3638,10 +3412,9 @@ and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, ())) as mpat) (
try
typ_debug (lazy "Unifying mapping forwards failed, trying backwards.");
typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for mapping-pattern " ^ string_of_typ typ));
- let unifiers, _, _ (* FIXME! *) = unify l env typ1 typ in
- typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers)));
+ let unifiers = unify l env typ1 typ (tyvars_of_typ typ1) in
let arg_typ' = subst_unifiers unifiers typ2 in
- let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in
+ let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
if (match quants' with [] -> false | _ -> true)
then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat)
else ();
@@ -4420,30 +4193,33 @@ let check_type_union env variant typq (Tu_aux (tu, l)) =
|> Env.add_val_spec v (typq, typ')
(* FIXME: This code is duplicated with general kind-checking code in environment, can they be merged? *)
-let mk_synonym typq typ =
+let mk_synonym typq typ_arg =
let kopts, ncs = quant_split typq in
let rec subst_args kopts args =
match kopts, args with
| kopt :: kopts, Typ_arg_aux (Typ_arg_nexp arg, _) :: args when is_nat_kopt kopt ->
- let typ, ncs = subst_args kopts args in
- typ_subst_nexp (kopt_kid kopt) (unaux_nexp arg) typ,
- List.map (nc_subst_nexp (kopt_kid kopt) (unaux_nexp arg)) ncs
+ let typ_arg, ncs = subst_args kopts args in
+ typ_arg_subst (kopt_kid kopt) (arg_nexp arg) typ_arg,
+ List.map (constraint_subst (kopt_kid kopt) (arg_nexp arg)) ncs
| kopt :: kopts, Typ_arg_aux (Typ_arg_typ arg, _) :: args when is_typ_kopt kopt ->
- let typ, ncs = subst_args kopts args in
- typ_subst_typ (kopt_kid kopt) (unaux_typ arg) typ, ncs
+ let typ_arg, ncs = subst_args kopts args in
+ typ_arg_subst (kopt_kid kopt) (arg_typ arg) typ_arg, ncs
| kopt :: kopts, Typ_arg_aux (Typ_arg_order arg, _) :: args when is_order_kopt kopt ->
- let typ, ncs = subst_args kopts args in
- typ_subst_order (kopt_kid kopt) (unaux_order arg) typ, ncs
- | [], [] -> typ, ncs
+ let typ_arg, ncs = subst_args kopts args in
+ typ_arg_subst (kopt_kid kopt) (arg_order arg) typ_arg, ncs
+ | kopt :: kopts, Typ_arg_aux (Typ_arg_bool arg, _) :: args when is_order_kopt kopt ->
+ let typ_arg, ncs = subst_args kopts args in
+ typ_arg_subst (kopt_kid kopt) (arg_bool arg) typ_arg, ncs
+ | [], [] -> typ_arg, ncs
| _, Typ_arg_aux (_, l) :: _ -> typ_error l "Synonym applied to bad arguments"
| _, _ -> typ_error Parse_ast.Unknown "Synonym applied to bad arguments"
in
fun env args ->
- let typ, ncs = subst_args kopts args in
+ let typ_arg, ncs = subst_args kopts args in
if List.for_all (prove env) ncs
- then typ
+ then typ_arg
else typ_error Parse_ast.Unknown ("Could not prove constraints " ^ string_of_list ", " string_of_n_constraint ncs
- ^ " in type synonym " ^ string_of_typ typ
+ ^ " in type synonym " ^ string_of_typ_arg typ_arg
^ " with " ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env))
let check_kinddef env (KD_aux (kdef, (l, _))) =
@@ -4458,8 +4234,11 @@ let rec check_typedef : 'a. Env.t -> 'a type_def -> (tannot def) list * Env.t =
fun env (TD_aux (tdef, (l, _))) ->
let td_err () = raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "Unimplemented Typedef") in
match tdef with
- | TD_abbrev (id, nmscm, (TypSchm_aux (TypSchm_ts (typq, typ), _))) ->
- [DEF_type (TD_aux (tdef, (l, None)))], Env.add_typ_synonym id (mk_synonym typq typ) env
+ | TD_abbrev (id, typq, (Typ_arg_aux (Typ_arg_typ _, _) as typ_arg)) ->
+ [DEF_type (TD_aux (tdef, (l, None)))], Env.add_typ_synonym id (mk_synonym typq typ_arg) env
+ (* For type synonyms for non-Type kinds we omit them from the AST *)
+ | TD_abbrev (id, typq, typ_arg) ->
+ [], Env.add_typ_synonym id (mk_synonym typq typ_arg) env
| TD_record (id, nmscm, typq, fields, _) ->
[DEF_type (TD_aux (tdef, (l, None)))], Env.add_record id typq fields env
| TD_variant (id, nmscm, typq, arms, _) ->
diff --git a/src/type_check.mli b/src/type_check.mli
index f08272de..7dc2da30 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -149,7 +149,7 @@ module Env : sig
won't throw any exceptions. *)
val get_ret_typ : t -> typ option
- val get_typ_synonym : id -> t -> (t -> typ_arg list -> typ)
+ val get_typ_synonym : id -> t -> (t -> typ_arg list -> typ_arg)
val get_overloads : id -> t -> id list
@@ -357,28 +357,21 @@ val destruct_numeric : Env.t -> typ -> (kid list * n_constraint * nexp) option
val destruct_vector : Env.t -> typ -> (nexp * order * typ) option
-type uvar =
- | U_nexp of nexp
- | U_order of order
- | U_typ of typ
+val subst_unifiers : typ_arg KBindings.t -> typ -> typ
-val string_of_uvar : uvar -> string
-
-val subst_unifiers : uvar KBindings.t -> typ -> typ
-
-val unify : l -> Env.t -> typ -> typ -> uvar KBindings.t * kid list * n_constraint option
+val unify : l -> Env.t -> typ -> typ -> typ_arg KBindings.t * kid list * n_constraint option
val alpha_equivalent : Env.t -> typ -> typ -> bool
(** Throws Invalid_argument if the argument is not a E_app expression *)
-val instantiation_of : tannot exp -> uvar KBindings.t
+val instantiation_of : tannot exp -> typ_arg KBindings.t
(** Doesn't use the type of the expression when calculating instantiations.
May fail if the arguments aren't sufficient to calculate all unifiers. *)
-val instantiation_of_without_type : tannot exp -> uvar KBindings.t
+val instantiation_of_without_type : tannot exp -> typ_arg KBindings.t
(* Type variable instantiations that inference will extract from constraints *)
-val instantiate_simple_equations : quant_item list -> uvar KBindings.t
+val instantiate_simple_equations : quant_item list -> typ_arg KBindings.t
val propagate_exp_effect : tannot exp -> tannot exp
diff --git a/src/type_error.ml b/src/type_error.ml
index 7551970f..0fa238ed 100644
--- a/src/type_error.ml
+++ b/src/type_error.ml
@@ -198,13 +198,16 @@ let rec analyze_unresolved_quant locals ncs = function
empty
let rec pp_type_error = function
- | Err_no_casts (exp, typ_from, typ_to, trigger, _) ->
+ | Err_no_casts (exp, typ_from, typ_to, trigger, reasons) ->
let coercion =
group (string "Tried performing type coercion from" ^/^ Pretty_print_sail.doc_typ typ_from
^/^ string "to" ^/^ Pretty_print_sail.doc_typ typ_to
^/^ string "on" ^/^ Pretty_print_sail.doc_exp exp)
in
- coercion ^^ hardline ^^ (string "Failed because" ^/^ pp_type_error trigger)
+ coercion ^^ hardline
+ ^^ (string "Coercion failed because:" ^//^ pp_type_error trigger)
+ ^^ hardline
+ ^^ (string "Possible reasons:" ^//^ separate_map hardline pp_type_error reasons)
| Err_no_overloading (id, errs) ->
string ("No overloadings for " ^ string_of_id id ^ ", tried:") ^//^