From 2c25110ad2f5e636239ba65a2154aae79ffa253c Mon Sep 17 00:00:00 2001 From: Alasdair Armstrong Date: Fri, 7 Dec 2018 21:53:29 +0000 Subject: Working on better flow typing for ASL On a new branch because it's completely broken everything for now --- editors/sail2-mode.el | 2 +- language/sail.ott | 30 +- src/ast_util.ml | 168 +++++----- src/ast_util.mli | 42 +-- src/constraint.ml | 229 ++++---------- src/constraint.mli | 34 +- src/initial_check.ml | 22 +- src/lexer.mll | 1 + src/ocaml_backend.ml | 9 +- src/parse_ast.ml | 3 +- src/parser.mly | 12 +- src/pretty_print_coq.ml | 3 +- src/pretty_print_lem.ml | 3 +- src/pretty_print_sail.ml | 10 +- src/rewriter.ml | 2 +- src/rewrites.ml | 7 +- src/spec_analysis.ml | 4 +- src/state.ml | 10 +- src/type_check.ml | 799 +++++++++++++++++------------------------------ src/type_check.mli | 19 +- src/type_error.ml | 7 +- 21 files changed, 532 insertions(+), 884 deletions(-) diff --git a/editors/sail2-mode.el b/editors/sail2-mode.el index e0081fb6..b7998357 100644 --- a/editors/sail2-mode.el +++ b/editors/sail2-mode.el @@ -12,7 +12,7 @@ "mapping" "where" "with")) (defconst sail2-kinds - '("Int" "Type" "Order" "inc" "dec" + '("Int" "Type" "Order" "Bool" "inc" "dec" "barr" "depend" "rreg" "wreg" "rmem" "rmemt" "wmv" "wmvt" "eamem" "wmem" "exmem" "undef" "unspec" "nondet" "escape" "configuration")) diff --git a/language/sail.ott b/language/sail.ott index a0b02a1c..b35e64d3 100644 --- a/language/sail.ott +++ b/language/sail.ott @@ -175,6 +175,7 @@ kind :: 'K_' ::= | Type :: :: type {{ com kind of types }} | Int :: :: int {{ com kind of natural number size expressions }} | Order :: :: order {{ com kind of vector order specifications }} + | Bool :: :: bool {{ com kind of constraints }} nexp :: 'Nexp_' ::= {{ com numeric expression, of kind Int }} @@ -273,22 +274,25 @@ typ :: 'Typ_' ::= typ_arg :: 'Typ_arg_' ::= {{ com type constructor arguments of all kinds }} {{ aux _ l }} - | nexp :: :: nexp - | typ :: :: typ - | order :: :: order + | nexp :: :: nexp + | typ :: :: typ + | order :: :: order + | n_constraint :: :: bool n_constraint :: 'NC_' ::= {{ com constraint over kind Int }} {{ aux _ l }} - | nexp = nexp' :: :: equal - | nexp >= nexp' :: :: bounded_ge - | nexp '<=' nexp' :: :: bounded_le - | nexp != nexp' :: :: not_equal - | kid 'IN' { num1 , ... , numn } :: :: set - | n_constraint \/ n_constraint' :: :: or - | n_constraint /\ n_constraint' :: :: and - | true :: :: true - | false :: :: false + | nexp = nexp' :: :: equal + | nexp >= nexp' :: :: bounded_ge + | nexp '<=' nexp' :: :: bounded_le + | nexp != nexp' :: :: not_equal + | kid 'IN' { num1 , ... , numn } :: :: set + | n_constraint \/ n_constraint' :: :: or + | n_constraint /\ n_constraint' :: :: and + | id ( typ_arg0 , ... , typ_argn ) :: :: app + | kid :: :: var + | true :: :: true + | false :: :: false % Note only id on the left and constants on the right in a % finite-set-bound, as we don't think we need anything more @@ -366,7 +370,7 @@ type_def {{ ocaml 'a type_def }} {{ lem type_def 'a }} :: 'TD_' ::= type_def_aux :: 'TD_' ::= {{ com type definition body }} - | typedef id name_scm_opt = typschm :: :: abbrev + | type id typquant = typ_arg :: :: abbrev {{ com type abbreviation }} {{ texlong }} | typedef id name_scm_opt = const struct typquant { typ1 id1 ; ... ; typn idn semi_opt } :: :: record {{ com struct type definition }} {{ texlong }} diff --git a/src/ast_util.ml b/src/ast_util.ml index e9153f7a..788008d1 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -370,23 +370,16 @@ let nc_lt n1 n2 = nc_lteq (nsum n1 (nint 1)) n2 let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nint 1)) let nc_and nc1 nc2 = mk_nc (NC_and (nc1, nc2)) let nc_or nc1 nc2 = mk_nc (NC_or (nc1, nc2)) +let nc_var kid = mk_nc (NC_var kid) let nc_true = mk_nc NC_true let nc_false = mk_nc NC_false -let rec nc_negate (NC_aux (nc, l)) = - match nc with - | NC_bounded_ge (n1, n2) -> nc_lt n1 n2 - | NC_bounded_le (n1, n2) -> nc_gt n1 n2 - | NC_equal (n1, n2) -> nc_neq n1 n2 - | NC_not_equal (n1, n2) -> nc_eq n1 n2 - | NC_and (n1, n2) -> mk_nc (NC_or (nc_negate n1, nc_negate n2)) - | NC_or (n1, n2) -> mk_nc (NC_and (nc_negate n1, nc_negate n2)) - | NC_false -> mk_nc NC_true - | NC_true -> mk_nc NC_false - | NC_set (kid, []) -> nc_false - | NC_set (kid, [int]) -> nc_neq (nvar kid) (nconstant int) - | NC_set (kid, int :: ints) -> - mk_nc (NC_and (nc_neq (nvar kid) (nconstant int), nc_negate (mk_nc (NC_set (kid, ints))))) +let arg_nexp ?loc:(l=Parse_ast.Unknown) n = Typ_arg_aux (Typ_arg_nexp n, l) +let arg_order ?loc:(l=Parse_ast.Unknown) ord = Typ_arg_aux (Typ_arg_order ord, l) +let arg_typ ?loc:(l=Parse_ast.Unknown) typ = Typ_arg_aux (Typ_arg_typ typ, l) +let arg_bool ?loc:(l=Parse_ast.Unknown) nc = Typ_arg_aux (Typ_arg_bool nc, l) + +let nc_not nc = mk_nc (NC_app (mk_id "not", [arg_bool nc])) let mk_typschm typq typ = TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown) @@ -437,6 +430,7 @@ let unaux_nexp (Nexp_aux (nexp, _)) = nexp let unaux_order (Ord_aux (ord, _)) = ord let unaux_typ (Typ_aux (typ, _)) = typ let unaux_kind (K_aux (k, _)) = k +let unaux_constraint (NC_aux (nc, _)) = nc let rec map_exp_annot f (E_aux (exp, annot)) = E_aux (map_exp_annot_aux f exp, f annot) and map_exp_annot_aux f = function @@ -628,6 +622,7 @@ let string_of_kind_aux = function | K_type -> "Type" | K_int -> "Int" | K_order -> "Order" + | K_bool -> "Bool" let string_of_kind (K_aux (k, _)) = string_of_kind_aux k @@ -680,6 +675,7 @@ and string_of_typ_arg_aux = function | Typ_arg_nexp n -> string_of_nexp n | Typ_arg_typ typ -> string_of_typ typ | Typ_arg_order o -> string_of_order o + | Typ_arg_bool nc -> string_of_n_constraint nc and string_of_n_constraint = function | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2 | NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2 @@ -691,6 +687,8 @@ and string_of_n_constraint = function "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" | NC_aux (NC_set (kid, ns), _) -> string_of_kid kid ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}" + | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ string_of_list ", " string_of_typ_arg args ^ ")" + | NC_aux (NC_var v, _) -> string_of_kid v | NC_aux (NC_true, _) -> "true" | NC_aux (NC_false, _) -> "false" @@ -1142,10 +1140,13 @@ let rec tyvars_of_constraint (NC_aux (nc, _)) = | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2) + | NC_app (id, args) -> + List.fold_left (fun s t -> KidSet.union s (tyvars_of_typ_arg t)) KidSet.empty args + | NC_var kid -> KidSet.singleton kid | NC_true | NC_false -> KidSet.empty -let rec tyvars_of_typ (Typ_aux (t,_)) = +and tyvars_of_typ (Typ_aux (t,_)) = match t with | Typ_internal_unknown -> KidSet.empty | Typ_id _ -> KidSet.empty @@ -1166,6 +1167,7 @@ and tyvars_of_typ_arg (Typ_arg_aux (ta,_)) = | Typ_arg_nexp nexp -> tyvars_of_nexp nexp | Typ_arg_typ typ -> tyvars_of_typ typ | Typ_arg_order _ -> KidSet.empty + | Typ_arg_bool nc -> tyvars_of_constraint nc let tyvars_of_quant_item (QI_aux (qi, _)) = match qi with | QI_id (KOpt_aux ((KOpt_none kid | KOpt_kind (_, kid)), _)) -> @@ -1547,10 +1549,26 @@ let unique l = (* 1. Substitutions *) (**************************************************************************) +let order_subst_aux sv subst = function + | Ord_var kid -> + begin match subst with + | Typ_arg_aux (Typ_arg_order ord, _) when Kid.compare kid sv = 0 -> + unaux_order ord + | _ -> Ord_var kid + end + | Ord_inc -> Ord_inc + | Ord_dec -> Ord_dec + +let order_subst sv subst (Ord_aux (ord, l)) = Ord_aux (order_subst_aux sv subst ord, l) + let rec nexp_subst sv subst (Nexp_aux (nexp, l)) = Nexp_aux (nexp_subst_aux sv subst nexp, l) and nexp_subst_aux sv subst = function - | Nexp_id v -> Nexp_id v - | Nexp_var kid -> if Kid.compare kid sv = 0 then subst else Nexp_var kid + | Nexp_var kid -> + begin match subst with + | Typ_arg_aux (Typ_arg_nexp n, _) when Kid.compare kid sv = 0 -> unaux_nexp n + | _ -> Nexp_var kid + end + | Nexp_id id -> Nexp_id id | Nexp_constant c -> Nexp_constant c | Nexp_times (nexp1, nexp2) -> Nexp_times (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2) | Nexp_sum (nexp1, nexp2) -> Nexp_sum (nexp_subst sv subst nexp1, nexp_subst sv subst nexp2) @@ -1564,100 +1582,68 @@ let rec nexp_set_to_or l subst = function | [int] -> NC_equal (subst, nconstant int) | (int :: ints) -> NC_or (mk_nc (NC_equal (subst, nconstant int)), mk_nc (nexp_set_to_or l subst ints)) -let rec nc_subst_nexp sv subst (NC_aux (nc, l)) = NC_aux (nc_subst_nexp_aux l sv subst nc, l) -and nc_subst_nexp_aux l sv subst = function +let rec constraint_subst sv subst (NC_aux (nc, l)) = NC_aux (constraint_subst_aux l sv subst nc, l) +and constraint_subst_aux l sv subst = function | NC_equal (n1, n2) -> NC_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_ge (n1, n2) -> NC_bounded_ge (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_le (n1, n2) -> NC_bounded_le (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_not_equal (n1, n2) -> NC_not_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_set (kid, ints) as set_nc -> - if Kid.compare kid sv = 0 - then nexp_set_to_or l (mk_nexp subst) ints - else set_nc - | NC_or (nc1, nc2) -> NC_or (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2) - | NC_and (nc1, nc2) -> NC_and (nc_subst_nexp sv subst nc1, nc_subst_nexp sv subst nc2) + begin match subst with + | Typ_arg_aux (Typ_arg_nexp n, _) when Kid.compare kid sv = 0 -> + nexp_set_to_or l n ints + | _ -> set_nc + end + | NC_or (nc1, nc2) -> NC_or (constraint_subst sv subst nc1, constraint_subst sv subst nc2) + | NC_and (nc1, nc2) -> NC_and (constraint_subst sv subst nc1, constraint_subst sv subst nc2) + | NC_app (id, args) -> NC_app (id, List.map (typ_arg_subst sv subst) args) + | NC_var kid -> + begin match subst with + | Typ_arg_aux (Typ_arg_bool nc, _) when Kid.compare kid sv = 0 -> + unaux_constraint nc + | _ -> NC_var kid + end | NC_false -> NC_false | NC_true -> NC_true -let rec typ_subst_nexp sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_nexp_aux sv subst typ, l) -and typ_subst_nexp_aux sv subst = function +and typ_subst sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_aux sv subst typ, l) +and typ_subst_aux sv subst = function | Typ_internal_unknown -> Typ_internal_unknown | Typ_id v -> Typ_id v - | Typ_var kid -> Typ_var kid - | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_nexp sv subst) arg_typs, typ_subst_nexp sv subst ret_typ, effs) - | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_nexp sv subst typ1, typ_subst_nexp sv subst typ2) - | Typ_tup typs -> Typ_tup (List.map (typ_subst_nexp sv subst) typs) - | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_nexp sv subst) args) + | Typ_var kid -> + begin match subst with + | Typ_arg_aux (Typ_arg_typ typ, _) when Kid.compare kid sv = 0 -> + unaux_typ typ + | _ -> Typ_var kid + end + | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst sv subst) arg_typs, typ_subst sv subst ret_typ, effs) + | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst sv subst typ1, typ_subst sv subst typ2) + | Typ_tup typs -> Typ_tup (List.map (typ_subst sv subst) typs) + | Typ_app (f, args) -> Typ_app (f, List.map (typ_arg_subst sv subst) args) | Typ_exist (kids, nc, typ) when KidSet.mem sv (KidSet.of_list kids) -> Typ_exist (kids, nc, typ) - | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc_subst_nexp sv subst nc, typ_subst_nexp sv subst typ) -and typ_subst_arg_nexp sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_nexp_aux sv subst arg, l) -and typ_subst_arg_nexp_aux sv subst = function - | Typ_arg_nexp nexp -> Typ_arg_nexp (nexp_subst sv subst nexp) - | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_nexp sv subst typ) - | Typ_arg_order ord -> Typ_arg_order ord + | Typ_exist (kids, nc, typ) -> Typ_exist (kids, constraint_subst sv subst nc, typ_subst sv subst typ) -let rec typ_subst_typ sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_typ_aux sv subst typ, l) -and typ_subst_typ_aux sv subst = function - | Typ_internal_unknown -> Typ_internal_unknown - | Typ_id v -> Typ_id v - | Typ_var kid -> if Kid.compare kid sv = 0 then subst else Typ_var kid - | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_typ sv subst) arg_typs, typ_subst_typ sv subst ret_typ, effs) - | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_typ sv subst typ1, typ_subst_typ sv subst typ2) - | Typ_tup typs -> Typ_tup (List.map (typ_subst_typ sv subst) typs) - | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_typ sv subst) args) - | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, typ_subst_typ sv subst typ) -and typ_subst_arg_typ sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_typ_aux sv subst arg, l) -and typ_subst_arg_typ_aux sv subst = function - | Typ_arg_nexp nexp -> Typ_arg_nexp nexp - | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_typ sv subst typ) - | Typ_arg_order ord -> Typ_arg_order ord - -let order_subst_aux sv subst = function - | Ord_var kid -> if Kid.compare kid sv = 0 then subst else Ord_var kid - | Ord_inc -> Ord_inc - | Ord_dec -> Ord_dec - -let order_subst sv subst (Ord_aux (ord, l)) = Ord_aux (order_subst_aux sv subst ord, l) - -let rec typ_subst_order sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_order_aux sv subst typ, l) -and typ_subst_order_aux sv subst = function - | Typ_internal_unknown -> Typ_internal_unknown - | Typ_id v -> Typ_id v - | Typ_var kid -> Typ_var kid - | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_order sv subst) arg_typs, typ_subst_order sv subst ret_typ, effs) - | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_order sv subst typ1, typ_subst_order sv subst typ2) - | Typ_tup typs -> Typ_tup (List.map (typ_subst_order sv subst) typs) - | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_order sv subst) args) - | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc, typ_subst_order sv subst typ) -and typ_subst_arg_order sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_order_aux sv subst arg, l) -and typ_subst_arg_order_aux sv subst = function - | Typ_arg_nexp nexp -> Typ_arg_nexp nexp - | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_order sv subst typ) +and typ_arg_subst sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_arg_subst_aux sv subst arg, l) +and typ_arg_subst_aux sv subst = function + | Typ_arg_nexp nexp -> Typ_arg_nexp (nexp_subst sv subst nexp) + | Typ_arg_typ typ -> Typ_arg_typ (typ_subst sv subst typ) | Typ_arg_order ord -> Typ_arg_order (order_subst sv subst ord) + | Typ_arg_bool nc -> Typ_arg_bool (constraint_subst sv subst nc) -let rec typ_subst_kid sv subst (Typ_aux (typ, l)) = Typ_aux (typ_subst_kid_aux sv subst typ, l) -and typ_subst_kid_aux sv subst = function - | Typ_internal_unknown -> Typ_internal_unknown - | Typ_id v -> Typ_id v - | Typ_var kid -> if Kid.compare kid sv = 0 then Typ_var subst else Typ_var kid - | Typ_fn (arg_typs, ret_typ, effs) -> Typ_fn (List.map (typ_subst_kid sv subst) arg_typs, typ_subst_kid sv subst ret_typ, effs) - | Typ_bidir (typ1, typ2) -> Typ_bidir (typ_subst_kid sv subst typ1, typ_subst_kid sv subst typ2) - | Typ_tup typs -> Typ_tup (List.map (typ_subst_kid sv subst) typs) - | Typ_app (f, args) -> Typ_app (f, List.map (typ_subst_arg_kid sv subst) args) - | Typ_exist (kids, nc, typ) when KidSet.mem sv (KidSet.of_list kids) -> Typ_exist (kids, nc, typ) - | Typ_exist (kids, nc, typ) -> Typ_exist (kids, nc_subst_nexp sv (Nexp_var subst) nc, typ_subst_kid sv subst typ) -and typ_subst_arg_kid sv subst (Typ_arg_aux (arg, l)) = Typ_arg_aux (typ_subst_arg_kid_aux sv subst arg, l) -and typ_subst_arg_kid_aux sv subst = function - | Typ_arg_nexp nexp -> Typ_arg_nexp (nexp_subst sv (Nexp_var subst) nexp) - | Typ_arg_typ typ -> Typ_arg_typ (typ_subst_kid sv subst typ) - | Typ_arg_order ord -> Typ_arg_order (order_subst sv (Ord_var subst) ord) +let subst_kid subst sv v x = + x + |> subst sv (mk_typ_arg (Typ_arg_bool (nc_var v))) + |> subst sv (mk_typ_arg (Typ_arg_nexp (nvar v))) + |> subst sv (mk_typ_arg (Typ_arg_order (Ord_aux (Ord_var v, Parse_ast.Unknown)))) + |> subst sv (mk_typ_arg (Typ_arg_typ (mk_typ (Typ_var v)))) let quant_item_subst_kid_aux sv subst = function | QI_id (KOpt_aux (KOpt_none kid, l)) as qid -> if Kid.compare kid sv = 0 then QI_id (KOpt_aux (KOpt_none subst, l)) else qid | QI_id (KOpt_aux (KOpt_kind (k, kid), l)) as qid -> if Kid.compare kid sv = 0 then QI_id (KOpt_aux (KOpt_kind (k, subst), l)) else qid - | QI_const nc -> QI_const (nc_subst_nexp sv (Nexp_var subst) nc) + | QI_const nc -> + QI_const (subst_kid constraint_subst sv subst nc) let quant_item_subst_kid sv subst (QI_aux (quant, l)) = QI_aux (quant_item_subst_kid_aux sv subst quant, l) diff --git a/src/ast_util.mli b/src/ast_util.mli index 19fc017d..73ab4a01 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -97,6 +97,7 @@ val unaux_nexp : nexp -> nexp_aux val unaux_order : order -> order_aux val unaux_typ : typ -> typ_aux val unaux_kind : kind -> kind_aux +val unaux_constraint : n_constraint -> n_constraint_aux val untyp_pat : 'a pat -> 'a pat * typ option val uncast_exp : 'a exp -> 'a exp * typ option @@ -154,7 +155,7 @@ val ntimes : nexp -> nexp -> nexp val npow2 : nexp -> nexp val nvar : kid -> nexp val napp : id -> nexp list -> nexp -val nid : id -> nexp (* NOTE: Nexp_id's don't do anything currently *) +val nid : id -> nexp (* Numeric constraint builders *) val nc_eq : nexp -> nexp -> n_constraint @@ -165,15 +166,16 @@ val nc_lt : nexp -> nexp -> n_constraint val nc_gt : nexp -> nexp -> n_constraint val nc_and : n_constraint -> n_constraint -> n_constraint val nc_or : n_constraint -> n_constraint -> n_constraint +val nc_not : n_constraint -> n_constraint val nc_true : n_constraint val nc_false : n_constraint val nc_set : kid -> Big_int.num list -> n_constraint val nc_int_set : kid -> int list -> n_constraint -(* Negate a n_constraint. Note that there's no NC_not constructor, so - this flips all the inequalites a the n_constraint leaves and uses - de-morgans to switch and to or and vice versa. *) -val nc_negate : n_constraint -> n_constraint +val arg_nexp : ?loc:l -> nexp -> typ_arg +val arg_order : ?loc:l -> order -> typ_arg +val arg_typ : ?loc:l -> typ -> typ_arg +val arg_bool : ?loc:l -> n_constraint -> typ_arg (* Functions for working with type quantifiers *) val quant_add : quant_item -> typquant -> typquant @@ -203,7 +205,6 @@ val def_loc : 'a def -> Parse_ast.l (* For debugging and error messages only: Not guaranteed to produce parseable SAIL, or even print all language constructs! *) -(* TODO: replace with existing pretty-printer *) val string_of_id : id -> string val string_of_kid : kid -> string val string_of_base_effect_aux : base_effect_aux -> string @@ -382,27 +383,16 @@ val unique : l -> l (** Substitutions *) -(* The function X_subst_Y substitutes a Y into something of type X, if - X = Y then the function is just X_subst. Substitutions are always - unwrapped from their aux constructors. *) -val nexp_subst : kid -> nexp_aux -> nexp -> nexp -val nc_subst_nexp : kid -> nexp_aux -> n_constraint -> n_constraint -val order_subst : kid -> order_aux -> order -> order +(* The function X_subst substitutes a type argument into something of + type X. The type of the type argument determines which kind of type + variables willb e replaced *) +val nexp_subst : kid -> typ_arg -> nexp -> nexp +val constraint_subst : kid -> typ_arg -> n_constraint -> n_constraint +val order_subst : kid -> typ_arg -> order -> order +val typ_subst : kid -> typ_arg -> typ -> typ +val typ_arg_subst : kid -> typ_arg -> typ_arg -> typ_arg -(* kid must be Int-kinded *) -val typ_subst_nexp : kid -> nexp_aux -> typ -> typ -val typ_subst_arg_nexp : kid -> nexp_aux -> typ_arg -> typ_arg - -(* kid must be Type-kinded *) -val typ_subst_typ : kid -> typ_aux -> typ -> typ -val typ_subst_arg_typ : kid -> typ_aux -> typ_arg -> typ_arg - -(* kid must be Order-kinded *) -val typ_subst_order : kid -> order_aux -> typ -> typ -val typ_subst_arg_order : kid -> order_aux -> typ_arg -> typ_arg - -val typ_subst_kid : kid -> kid -> typ -> typ -val typ_subst_arg_kid : kid -> kid -> typ_arg -> typ_arg +val subst_kid : (kid -> typ_arg -> 'a -> 'a) -> kid -> kid -> 'a -> 'a val quant_item_subst_kid : kid -> kid -> quant_item -> quant_item val typquant_subst_kid : kid -> kid -> typquant -> typquant diff --git a/src/constraint.ml b/src/constraint.ml index cf861423..460e8c76 100644 --- a/src/constraint.ml +++ b/src/constraint.ml @@ -49,86 +49,10 @@ (**************************************************************************) module Big_int = Nat_big_num +open Ast +open Ast_util open Util -(* ===== Integer Constraints ===== *) - -type nexp_op = string - -type nexp = - | NFun of (nexp_op * nexp list) - | N2n of nexp - | NConstant of Big_int.num - | NVar of int - -let big_int_op : nexp_op -> (Big_int.num -> Big_int.num -> Big_int.num) option = function - | "+" -> Some Big_int.add - | "-" -> Some Big_int.sub - | "*" -> Some Big_int.mul - | _ -> None - -let rec arith constr = - let constr' = match constr with - | NFun (op, [x; y]) -> NFun (op, [arith x; arith y]) - | N2n c -> N2n (arith c) - | c -> c - in - match constr' with - | NFun (op, [NConstant x; NConstant y]) as c -> - begin - match big_int_op op with - | Some op -> NConstant (op x y) - | None -> c - end - | N2n (NConstant x) -> NConstant (Big_int.pow_int_positive 2 (Big_int.to_int x)) - | c -> c - -(* ===== Boolean Constraints ===== *) - -type constraint_bool_op = And | Or - -type constraint_compare_op = Gt | Lt | GtEq | LtEq | Eq | NEq - -let negate_comparison = function - | Gt -> LtEq - | Lt -> GtEq - | GtEq -> Lt - | LtEq -> Gt - | Eq -> NEq - | NEq -> Eq - -type 'a constraint_bool = - | BFun of (constraint_bool_op * 'a constraint_bool * 'a constraint_bool) - | Not of 'a constraint_bool - | CFun of (constraint_compare_op * 'a * 'a) - | Forall of (int list * 'a constraint_bool) - | Boolean of bool - -let rec pairs (xs : 'a list) (ys : 'a list) : ('a * 'b) list = - match xs with - | [] -> [] - | (x :: xs) -> List.map (fun y -> (x, y)) ys @ pairs xs ys - -(* Get a set of variables from a constraint *) -module IntSet = Set.Make( - struct - let compare = Pervasives.compare - type t = int - end) - -let rec nexp_vars : nexp -> IntSet.t = function - | NConstant _ -> IntSet.empty - | NVar v -> IntSet.singleton v - | NFun (_, xs) -> List.fold_left IntSet.union IntSet.empty (List.map nexp_vars xs) - | N2n x -> nexp_vars x - -let rec constraint_vars : nexp constraint_bool -> IntSet.t = function - | BFun (_, x, y) -> IntSet.union (constraint_vars x) (constraint_vars y) - | Not x -> constraint_vars x - | CFun (_, x, y) -> IntSet.union (nexp_vars x) (nexp_vars y) - | Forall (vars, x) -> IntSet.diff (constraint_vars x) (IntSet.of_list vars) - | Boolean _ -> IntSet.empty - (* SMTLIB v2.0 format is based on S-expressions so we have a lightweight representation of those here. *) type sexpr = List of (sexpr list) | Atom of string @@ -139,47 +63,67 @@ let rec pp_sexpr : sexpr -> string = function | List xs -> "(" ^ string_of_list " " pp_sexpr xs ^ ")" | Atom x -> x -let var_decs constr = - constraint_vars constr - |> IntSet.elements - |> List.map (fun var -> sfun "declare-const" [Atom ("v" ^ string_of_int var); Atom "Int"]) - |> string_of_list "\n" pp_sexpr +let zencode_kid kid = Util.zencode_string (string_of_kid kid) -let cop_sexpr op x y = - match op with - | Gt -> sfun ">" [x; y] - | Lt -> sfun "<" [x; y] - | GtEq -> sfun ">=" [x; y] - | LtEq -> sfun "<=" [x; y] - | Eq -> sfun "=" [x; y] - | NEq -> sfun "not" [sfun "=" [x; y]] +(** Each non-Type/Order kind in Sail mapes to a type in the SMT solver *) +let smt_type l = function + | K_int -> Atom "Int" + | K_bool -> Atom "Bool" + | _ -> raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kinded variable to SMT solver") -let rec sexpr_of_nexp = function - | NFun (op, xs) -> sfun op (List.map sexpr_of_nexp xs) - | N2n x -> sfun "^" [Atom "2"; sexpr_of_nexp x] - | NConstant c -> Atom (Big_int.to_string c) (* CHECK: do we do negative constants right? *) - | NVar var -> Atom ("v" ^ string_of_int var) +let smt_var v = Atom ("v" ^ zencode_kid v) -let rec sexpr_of_constraint = function - | BFun (And, x, y) -> sfun "and" [sexpr_of_constraint x; sexpr_of_constraint y] - | BFun (Or, x, y) -> sfun "or" [sexpr_of_constraint x; sexpr_of_constraint y] - | Not x -> sfun "not" [sexpr_of_constraint x] - | CFun (op, x, y) -> cop_sexpr op (sexpr_of_nexp (arith x)) (sexpr_of_nexp (arith y)) - | Forall (vars, x) -> - sfun "forall" [List (List.map (fun v -> List [Atom ("v" ^ string_of_int v); Atom "Int"]) vars); sexpr_of_constraint x] - | Boolean true -> Atom "true" - | Boolean false -> Atom "false" +(** var_decs outputs the list of variables to be used by the SMT + solver in SMTLIB v2.0 format. It takes a kind_aux KBindings, as + returned by Type_check.get_typ_vars *) +let var_decs l (vars : kind_aux KBindings.t) : string = vars + |> KBindings.bindings + |> List.map (fun (v, k) -> sfun "declare-const" [smt_var v; smt_type l k]) + |> string_of_list "\n" pp_sexpr -let smtlib_of_constraints ?get_model:(get_model=false) constr : string = +let rec smt_nexp (Nexp_aux (aux, l) : nexp) : sexpr = + match aux with + | Nexp_id id -> Atom (Util.zencode_string (string_of_id id)) + | Nexp_var v -> smt_var v + | Nexp_constant c -> Atom (Big_int.to_string c) + | Nexp_app (id, nexps) -> sfun (string_of_id id) (List.map smt_nexp nexps) + | Nexp_times (nexp1, nexp2) -> sfun "*" [smt_nexp nexp1; smt_nexp nexp2] + | Nexp_sum (nexp1, nexp2) -> sfun "+" [smt_nexp nexp1; smt_nexp nexp2] + | Nexp_minus (nexp1, nexp2) -> sfun "-" [smt_nexp nexp1; smt_nexp nexp2] + | Nexp_exp nexp -> sfun "^" [Atom "2"; smt_nexp nexp] + | Nexp_neg nexp -> sfun "-" [smt_nexp nexp] + +let rec smt_constraint (NC_aux (aux, l) : n_constraint) : sexpr = + match aux with + | NC_equal (nexp1, nexp2) -> sfun "=" [smt_nexp nexp1; smt_nexp nexp2] + | NC_bounded_le (nexp1, nexp2) -> sfun "<=" [smt_nexp nexp1; smt_nexp nexp2] + | NC_bounded_ge (nexp1, nexp2) -> sfun ">=" [smt_nexp nexp1; smt_nexp nexp2] + | NC_not_equal (nexp1, nexp2) -> sfun "not" [sfun "=" [smt_nexp nexp1; smt_nexp nexp2]] + | NC_set (v, ints) -> + sfun "or" (List.map (fun i -> sfun "=" [smt_var v; Atom (Big_int.to_string i)]) ints) + | NC_or (nc1, nc2) -> sfun "or" [smt_constraint nc1; smt_constraint nc2] + | NC_and (nc1, nc2) -> sfun "and" [smt_constraint nc1; smt_constraint nc2] + | NC_app (id, args) -> + sfun (string_of_id id) (List.map smt_typ_arg args) + | NC_true -> Atom "true" + | NC_false -> Atom "false" + | NC_var v -> smt_var v + +and smt_typ_arg (Typ_arg_aux (aux, l) : typ_arg) : sexpr = + match aux with + | Typ_arg_nexp nexp -> smt_nexp nexp + | Typ_arg_bool nc -> smt_constraint nc + | _ -> + raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kind to SMT function") + +let smtlib_of_constraints ?get_model:(get_model=false) l vars constr : string = "(push)\n" - ^ var_decs constr ^ "\n" - ^ pp_sexpr (sfun "define-fun" [Atom "constraint"; List []; Atom "Bool"; sexpr_of_constraint constr]) + ^ var_decs l vars ^ "\n" + ^ pp_sexpr (sfun "define-fun" [Atom "constraint"; List []; Atom "Bool"; smt_constraint constr]) ^ "\n(assert constraint)\n(check-sat)" ^ (if get_model then "\n(get-model)" else "") ^ "\n(pop)" -type t = nexp constraint_bool - type smt_result = Unknown | Sat | Unsat module DigestMap = Map.Make(Digest) @@ -219,9 +163,9 @@ let save_digests () = DigestMap.iter output !known_problems; close_out out_chan -let call_z3' constraints : smt_result = +let call_z3' l vars constraints : smt_result = let problems = [constraints] in - let z3_file = smtlib_of_constraints constraints in + let z3_file = smtlib_of_constraints l vars constraints in (* prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" z3_file); *) @@ -261,15 +205,15 @@ let call_z3' constraints : smt_result = else (known_problems := DigestMap.add digest Unknown !known_problems; Unknown) end -let call_z3 constraints = +let call_z3 l vars constraints = let t = Profile.start_z3 () in - let result = call_z3' constraints in + let result = call_z3' l vars constraints in Profile.finish_z3 t; result -let rec solve_z3 constraints var = +let rec solve_z3 l vars constraints var = let problems = [constraints] in - let z3_file = smtlib_of_constraints ~get_model:true constraints in + let z3_file = smtlib_of_constraints ~get_model:true l vars constraints in (* prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" z3_file); *) @@ -289,62 +233,13 @@ let rec solve_z3 constraints var = let z3_output = String.concat " " (input_all z3_chan) in let _ = Unix.close_process_in z3_chan in Sys.remove input_file; - let regexp = {|(define-fun v|} ^ string_of_int var ^ {| () Int[ ]+\([0-9]+\))|} in + let regexp = {|(define-fun v|} ^ Util.zencode_string (string_of_kid var) ^ {| () Int[ ]+\([0-9]+\))|} in try let _ = Str.search_forward (Str.regexp regexp) z3_output 0 in let result = Big_int.of_string (Str.matched_group 1 z3_output) in - begin match call_z3 (BFun (And, constraints, CFun (NEq, NConstant result, NVar var))) with + begin match call_z3 l vars (nc_and constraints (nc_neq (nconstant result) (nvar var))) with | Unsat -> Some result | _ -> None end with Not_found -> None - -let string_of constr = smtlib_of_constraints constr - -(* ===== Abstract API for building constraints ===== *) - -(* These functions are exported from constraint.mli, and ensure that - the internal representation of constraints remains opaque. *) - -let implies (x : t) (y : t) : t = - BFun (Or, Not x, y) - -let conj (x : t) (y : t) : t = - BFun (And, x, y) - -let disj (x : t) (y : t) : t = - BFun (Or, x, y) - -let forall (vars : int list) (x : t) : t = - if vars = [] then x else Forall (vars, x) - -let negate (x : t) : t = Not x - -let literal (b : bool) : t = Boolean b - -let lt x y : t = CFun (Lt, x, y) - -let lteq x y : t = CFun (LtEq, x, y) - -let gt x y : t = CFun (Gt, x, y) - -let gteq x y : t = CFun (GtEq, x, y) - -let eq x y : t = CFun (Eq, x, y) - -let neq x y : t = CFun (NEq, x, y) - -let pow2 x : nexp = N2n x - -let add x y : nexp = NFun ("+", [x; y]) - -let sub x y : nexp = NFun ("-", [x; y]) - -let mult x y : nexp = NFun ("*", [x; y]) - -let app f xs : nexp = NFun (f, xs) - -let constant (x : Big_int.num) : nexp = NConstant x - -let variable (v : int) : nexp = NVar v diff --git a/src/constraint.mli b/src/constraint.mli index df9c8b3a..51088245 100644 --- a/src/constraint.mli +++ b/src/constraint.mli @@ -49,40 +49,14 @@ (**************************************************************************) module Big_int = Nat_big_num - -type nexp -type t +open Ast +open Ast_util type smt_result = Unknown | Sat | Unsat val load_digests : unit -> unit val save_digests : unit -> unit -val call_z3 : t -> smt_result - -val solve_z3 : t -> int -> Big_int.num option - -val string_of : t -> string - -val implies : t -> t -> t -val conj : t -> t -> t -val disj : t -> t -> t -val negate : t -> t -val literal : bool -> t -val forall : int list -> t -> t - -val lt : nexp -> nexp -> t -val lteq : nexp -> nexp -> t -val gt : nexp -> nexp -> t -val gteq : nexp -> nexp -> t -val eq : nexp -> nexp -> t -val neq : nexp -> nexp -> t - -val pow2 : nexp -> nexp -val add : nexp -> nexp -> nexp -val sub : nexp -> nexp -> nexp -val mult : nexp -> nexp -> nexp -val app : string -> nexp list -> nexp +val call_z3 : l -> kind_aux KBindings.t -> n_constraint -> smt_result -val constant : Big_int.num -> nexp -val variable : int -> nexp +val solve_z3 : l -> kind_aux KBindings.t -> n_constraint -> kid -> Big_int.num option diff --git a/src/initial_check.ml b/src/initial_check.ml index b57e6b17..e84f655c 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -85,6 +85,7 @@ let to_ast_kind (P.K_aux (k, l)) = | P.K_type -> K_aux (K_type, l) | P.K_int -> K_aux (K_int, l) | P.K_order -> K_aux (K_order, l) + | P.K_bool -> K_aux (K_bool, l) let to_ast_id (P.Id_aux(id, l)) = if string_contains (string_of_parse_id_aux id) '#' && not (!opt_magic_hash) then @@ -143,6 +144,8 @@ let rec to_ast_typ ctx (P.ATyp_aux (aux, l)) = | P.ATyp_tup typs -> Typ_tup (List.map (to_ast_typ ctx) typs) | P.ATyp_app (P.Id_aux (P.Id "int", il), [n]) -> Typ_app (Id_aux (Id "atom", il), [to_ast_typ_arg ctx n K_int]) + | P.ATyp_app (P.Id_aux (P.Id "bool", il), [n]) -> + Typ_app (Id_aux (Id "atom_bool", il), [to_ast_typ_arg ctx n K_bool]) | P.ATyp_app (id, args) -> let id = to_ast_id id in begin match Bindings.find_opt id ctx.type_constructors with @@ -166,6 +169,7 @@ and to_ast_typ_arg ctx (ATyp_aux (_, l) as atyp) = function | K_type -> Typ_arg_aux (Typ_arg_typ (to_ast_typ ctx atyp), l) | K_int -> Typ_arg_aux (Typ_arg_nexp (to_ast_nexp ctx atyp), l) | K_order -> Typ_arg_aux (Typ_arg_order (to_ast_order ctx atyp), l) + | K_bool -> Typ_arg_aux (Typ_arg_bool (to_ast_constraint ctx atyp), l) and to_ast_nexp ctx (P.ATyp_aux (aux, l)) = let aux = match aux with @@ -203,6 +207,17 @@ and to_ast_constraint ctx (P.ATyp_aux (aux, l) as atyp) = | "|" -> NC_or (to_ast_constraint ctx t1, to_ast_constraint ctx t2) | _ -> raise (Reporting.err_typ l ("Invalid operator in constraint")) end + | P.ATyp_app (id, args) -> + let id = to_ast_id id in + begin match Bindings.find_opt id ctx.type_constructors with + | None -> raise (Reporting.err_typ l (sprintf "Could not find type constructor %s" (string_of_id id))) + | Some kinds when List.length args <> List.length kinds -> + raise (Reporting.err_typ l (sprintf "%s : %s -> Bool expected %d arguments, given %d" + (string_of_id id) (format_kind_aux_list kinds) + (List.length kinds) (List.length args))) + | Some kinds -> NC_app (id, List.map2 (to_ast_typ_arg ctx) args kinds) + end + | P.ATyp_var v -> NC_var (to_ast_var v) | P.ATyp_lit (P.L_aux (P.L_true, _)) -> NC_true | P.ATyp_lit (P.L_aux (P.L_false, _)) -> NC_false | P.ATyp_nset (id, bounds) -> NC_set (to_ast_var id, bounds) @@ -490,11 +505,12 @@ let add_constructor id typq ctx = let to_ast_typedef ctx (P.TD_aux (aux, l) : P.type_def) : unit type_def ctx_out = let aux, ctx = match aux with - | P.TD_abbrev (id, namescm_opt, P.TypSchm_aux (P.TypSchm_ts (typq, typ), l)) -> + | P.TD_abbrev (id, typq, kind, typ_arg) -> let id = to_ast_id id in let typq, typq_ctx = to_ast_typquant ctx typq in - let typ = to_ast_typ typq_ctx typ in - TD_abbrev (id, to_ast_namescm namescm_opt, TypSchm_aux (TypSchm_ts (typq, typ), l)), + let kind = to_ast_kind kind in + let typ_arg = to_ast_typ_arg typq_ctx typ_arg (unaux_kind kind) in + TD_abbrev (id, typq, typ_arg), add_constructor id typq ctx | P.TD_record (id, namescm_opt, typq, fields, _) -> diff --git a/src/lexer.mll b/src/lexer.mll index f5a982eb..57580e7a 100644 --- a/src/lexer.mll +++ b/src/lexer.mll @@ -140,6 +140,7 @@ let kw_table = ("ref", (fun _ -> Ref)); ("Int", (fun x -> Int)); ("Order", (fun x -> Order)); + ("Bool", (fun x -> Bool)); ("pure", (fun x -> Pure)); ("register", (fun x -> Register)); ("return", (fun x -> Return)); diff --git a/src/ocaml_backend.ml b/src/ocaml_backend.ml index f3a3fa54..a3d47814 100644 --- a/src/ocaml_backend.ml +++ b/src/ocaml_backend.ml @@ -602,7 +602,7 @@ let ocaml_typedef ctx (TD_aux (td_aux, _)) = ^//^ (bar ^^ space ^^ ocaml_enum ctx ids)) ^^ ocaml_def_end ^^ ocaml_string_of_enum ctx id ids - | TD_abbrev (id, _, TypSchm_aux (TypSchm_ts (typq, typ), _)) -> + | TD_abbrev (id, typq, Typ_arg_aux (Typ_arg_typ typ, _)) -> separate space [string "type"; ocaml_typquant typq; zencode ctx id; equals; ocaml_typ ctx typ] ^^ ocaml_def_end ^^ ocaml_string_of_abbrev ctx id typq typ @@ -706,7 +706,7 @@ let ocaml_pp_generators ctx defs orig_types required = -> required and add_req_from_td required (TD_aux (td,(l,_))) = match td with - | TD_abbrev (_, _, TypSchm_aux (TypSchm_ts (_,typ),_)) -> + | TD_abbrev (_, _, Typ_arg_aux (Typ_arg_typ typ, _)) -> add_req_from_typ required typ | TD_record (_, _, _, fields, _) -> List.fold_left (fun req (typ,_) -> add_req_from_typ req typ) required fields @@ -723,10 +723,11 @@ let ocaml_pp_generators ctx defs orig_types required = match Bindings.find id typemap with | TD_aux (td,_) -> (match td with - | TD_abbrev (_,_,TypSchm_aux (TypSchm_ts (tqs,typ),_)) -> tqs + | TD_abbrev (_,tqs,Typ_arg_aux (Typ_arg_typ _, _)) -> tqs | TD_record (_,_,tqs,_,_) -> tqs | TD_variant (_,_,tqs,_,_) -> tqs | TD_enum _ -> TypQ_aux (TypQ_no_forall,Unknown) + | TD_abbrev (_, _, _) -> assert false | TD_bitfield _ -> assert false) | exception Not_found -> Bindings.find id Type_check.Env.builtin_typs @@ -844,7 +845,7 @@ let ocaml_pp_generators ctx defs orig_types required = let tqs, body, constructors, builders = let TD_aux (td,(l,_)) = Bindings.find id typemap in match td with - | TD_abbrev (_,_,TypSchm_aux (TypSchm_ts (tqs,typ),_)) -> + | TD_abbrev (_,tqs,Typ_arg_aux (Typ_arg_typ typ, _)) -> tqs, gen_type typ, None, None | TD_variant (_,_,tqs,variants,_) -> tqs, diff --git a/src/parse_ast.ml b/src/parse_ast.ml index 204389f9..c57daa26 100644 --- a/src/parse_ast.ml +++ b/src/parse_ast.ml @@ -74,6 +74,7 @@ kind_aux = (* base kind *) K_type (* kind of types *) | K_int (* kind of natural number size expressions *) | K_order (* kind of vector order specifications *) + | K_bool (* kind of constraints *) type @@ -443,7 +444,7 @@ fundef_aux = (* Function definition *) type type_def_aux = (* Type definition body *) - TD_abbrev of id * name_scm_opt * typschm (* type abbreviation *) + TD_abbrev of id * typquant * kind * atyp (* type abbreviation *) | TD_record of id * name_scm_opt * typquant * ((atyp * id)) list * bool (* struct type definition *) | TD_variant of id * name_scm_opt * typquant * (type_union) list * bool (* union type definition *) | TD_enum of id * name_scm_opt * (id) list * bool (* enumeration type definition *) diff --git a/src/parser.mly b/src/parser.mly index bb5aa5f1..fa36591c 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -175,7 +175,7 @@ let rec desugar_rchain chain s e = /*Terminals with no content*/ %token And As Assert Bitzero Bitone By Match Clause Dec Default Effect End Op Where -%token Enum Else False Forall Foreach Overload Function_ Mapping If_ In Inc Let_ Int Order Cast +%token Enum Else False Forall Foreach Overload Function_ Mapping If_ In Inc Let_ Int Order Bool Cast %token Pure Register Return Scattered Sizeof Struct Then True TwoCaret TYPE Typedef %token Undefined Union Newtype With Val Constant Constraint Throw Try Catch Exit Bitfield %token Barr Depend Rreg Wreg Rmem Rmemt Wmem Wmv Wmvt Eamem Exmem Undef Unspec Nondet Escape @@ -563,6 +563,8 @@ kind: { K_aux (K_type, loc $startpos $endpos) } | Order { K_aux (K_order, loc $startpos $endpos) } + | Bool + { K_aux (K_bool, loc $startpos $endpos) } kopt: | Lparen kid Colon kind Rparen @@ -1154,9 +1156,13 @@ typaram: type_def: | Typedef id typaram Eq typ - { mk_td (TD_abbrev ($2, mk_namesectn, mk_typschm $3 $5 $startpos($3) $endpos)) $startpos $endpos } + { mk_td (TD_abbrev ($2, $3, K_aux (K_type, Parse_ast.Unknown), $5)) $startpos $endpos } | Typedef id Eq typ - { mk_td (TD_abbrev ($2, mk_namesectn, mk_typschm mk_typqn $4 $startpos($4) $endpos)) $startpos $endpos } + { mk_td (TD_abbrev ($2, mk_typqn, K_aux (K_type, Parse_ast.Unknown), $4)) $startpos $endpos } + | Typedef id typaram MinusGt kind Eq typ + { mk_td (TD_abbrev ($2, $3, $5, $7)) $startpos $endpos } + | Typedef id Colon kind Eq typ + { mk_td (TD_abbrev ($2, mk_typqn, $4, $6)) $startpos $endpos } | Struct id Eq Lcurly struct_fields Rcurly { mk_td (TD_record ($2, mk_namesectn, TypQ_aux (TypQ_tq [], loc $endpos($2) $startpos($3)), $5, false)) $startpos $endpos } | Struct id typaram Eq Lcurly struct_fields Rcurly diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index 7cc61507..50a97fa8 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -1711,7 +1711,8 @@ let rec doc_range (BF_aux(r,_)) = match r with | BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2) let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with - | TD_abbrev(id,nm,(TypSchm_aux (TypSchm_ts (typq, _), _) as typschm)) -> + | TD_abbrev(id,typq,Typ_arg_aux (Typ_arg_typ typ, _)) -> + let typschm = TypSchm_aux (TypSchm_ts (typq, typ), l) in doc_op coloneq (separate space [string "Definition"; doc_id_type id; doc_typquant_items empty_ctxt parens typq; diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index be790a1c..e5613961 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -1006,7 +1006,8 @@ let rec doc_range_lem (BF_aux(r,_)) = match r with | BF_concat(ir1,ir2) -> (doc_range ir1) ^^ comma ^^ (doc_range ir2) let doc_typdef_lem (TD_aux(td, (l, annot))) = match td with - | TD_abbrev(id,nm,(TypSchm_aux (TypSchm_ts (typq, _), _) as typschm)) -> + | TD_abbrev(id,typq,Typ_arg_aux (Typ_arg_typ typ, _)) -> + let typschm = TypSchm_aux (TypSchm_ts (typq, typ), l) in doc_op equals (separate space [string "type"; doc_id_lem_type id; doc_typquant_items_lem None typq]) (doc_typschm_lem false typschm) diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index 5201744b..7fb67a06 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -259,7 +259,7 @@ let doc_typschm ?(simple=false) (TypSchm_aux (TypSchm_ts (typq, typ), _)) = doc_ let doc_typschm_typ (TypSchm_aux (TypSchm_ts (TypQ_aux (tq_aux, _), typ), _)) = doc_typ typ -let doc_typschm_quants (TypSchm_aux (TypSchm_ts (TypQ_aux (tq_aux, _), typ), _)) = +let doc_typquant (TypQ_aux (tq_aux, _)) = match tq_aux with | TypQ_no_forall -> None | TypQ_tq [] -> None @@ -562,13 +562,13 @@ let doc_field (typ, id) = let doc_union (Tu_aux (Tu_ty_id (typ, id), l)) = separate space [doc_id id; colon; doc_typ typ] let doc_typdef (TD_aux(td,_)) = match td with - | TD_abbrev (id, _, typschm) -> + | TD_abbrev (id, typq, typ_arg) -> begin - match doc_typschm_quants typschm with + match doc_typquant typq with | Some qdoc -> - doc_op equals (concat [string "type"; space; doc_id id; qdoc]) (doc_typschm_typ typschm) + doc_op equals (concat [string "type"; space; doc_id id; qdoc]) (doc_typ_arg typ_arg) | None -> - doc_op equals (concat [string "type"; space; doc_id id]) (doc_typschm_typ typschm) + doc_op equals (concat [string "type"; space; doc_id id]) (doc_typ_arg typ_arg) end | TD_enum (id, _, ids, _) -> separate space [string "enum"; doc_id id; equals; surround 2 0 lbrace (separate_map (comma ^^ break 1) doc_id ids) rbrace] diff --git a/src/rewriter.ml b/src/rewriter.ml index 77070025..200121c0 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -94,7 +94,7 @@ let lookup_generated_kid env kid = let generated_kids typ = KidSet.filter is_kid_generated (tyvars_of_typ typ) let resolve_generated_kids env typ = - let subst_kid kid typ = typ_subst_kid kid (lookup_generated_kid env kid) typ in + let subst_kid kid typ = subst_kid typ_subst kid (lookup_generated_kid env kid) typ in KidSet.fold subst_kid (generated_kids typ) typ let rec remove_p_typ = function diff --git a/src/rewrites.ml b/src/rewrites.ml index 0ead9670..82228206 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2300,9 +2300,10 @@ let rewrite_constraint = let rewrite_type_union_typs rw_typ (Tu_aux (Tu_ty_id (typ, id), annot)) = Tu_aux (Tu_ty_id (rw_typ typ, id), annot) -let rewrite_type_def_typs rw_typ rw_typquant rw_typschm (TD_aux (td, annot)) = +let rewrite_type_def_typs rw_typ rw_typquant (TD_aux (td, annot)) = match td with - | TD_abbrev (id, nso, typschm) -> TD_aux (TD_abbrev (id, nso, rw_typschm typschm), annot) + | TD_abbrev (id, typq, Typ_arg_aux (Typ_arg_typ typ, l)) -> + TD_aux (TD_abbrev (id, rw_typquant typq, Typ_arg_aux (Typ_arg_typ (rw_typ typ), l)), annot) | TD_record (id, nso, typq, typ_ids, flag) -> TD_aux (TD_record (id, nso, rw_typquant typq, List.map (fun (typ, id) -> (rw_typ typ, id)) typ_ids, flag), annot) | TD_variant (id, nso, typq, tus, flag) -> @@ -2396,7 +2397,7 @@ let rewrite_simple_types (Defs defs) = in let simple_def = function | DEF_spec vs -> DEF_spec (simple_vs vs) - | DEF_type td -> DEF_type (rewrite_type_def_typs simple_typ simple_typquant simple_typschm td) + | DEF_type td -> DEF_type (rewrite_type_def_typs simple_typ simple_typquant td) | DEF_reg_dec ds -> DEF_reg_dec (rewrite_dec_spec_typs simple_typ ds) | def -> def in diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml index 9453e999..84fa8235 100644 --- a/src/spec_analysis.ml +++ b/src/spec_analysis.ml @@ -309,7 +309,9 @@ let fv_of_kind_def consider_var (KD_aux(k,_)) = match k with | KD_nabbrev(_,id,_,nexp) -> init_env (string_of_id id), fv_of_nexp consider_var mt mt nexp let fv_of_type_def consider_var (TD_aux(t,_)) = match t with - | TD_abbrev(id,_,typschm) -> init_env (string_of_id id), snd (fv_of_typschm consider_var mt mt typschm) + | TD_abbrev(id,typq,Typ_arg_aux(Typ_arg_typ typ, l)) -> + let typschm = TypSchm_aux (TypSchm_ts (typq,typ), l) in + init_env (string_of_id id), snd (fv_of_typschm consider_var mt mt typschm) | TD_record(id,_,typq,tids,_) -> let binds = init_env (string_of_id id) in let bounds = if consider_var then typq_bindings typq else mt in diff --git a/src/state.ml b/src/state.ml index 31f5c7eb..00f81bf4 100644 --- a/src/state.ml +++ b/src/state.ml @@ -127,15 +127,9 @@ let generate_initial_regstate defs = | Typ_exist (_, _, typ) -> lookup_init_val vals typ | _ -> raise Not_found in - (* Helper functions to instantiate type arguments *) - let typ_subst_targ kid (Typ_arg_aux (arg, _)) typ = match arg with - | Typ_arg_nexp (Nexp_aux (nexp, _)) -> typ_subst_nexp kid nexp typ - | Typ_arg_typ (Typ_aux (typ', _)) -> typ_subst_typ kid typ' typ - | Typ_arg_order (Ord_aux (ord, _)) -> typ_subst_order kid ord typ - in let typ_subst_quant_item typ (QI_aux (qi, _)) arg = match qi with | QI_id (KOpt_aux ((KOpt_none kid | KOpt_kind (_, kid)), _)) -> - typ_subst_targ kid arg typ + typ_subst kid arg typ | _ -> typ in let typ_subst_typquant tq args typ = @@ -152,7 +146,7 @@ let generate_initial_regstate defs = string_of_id id1 ^ " (" ^ lookup_init_val vals typ1 ^ ")" in Bindings.add id init_val vals - | TD_abbrev (id, _, TypSchm_aux (TypSchm_ts (tq, typ), _)) -> + | TD_abbrev (id, tq, Typ_arg_aux (Typ_arg_typ typ, _)) -> let init_val args = lookup_init_val vals (typ_subst_typquant tq args typ) in Bindings.add id init_val vals | TD_record (id, _, tq, fields, _) -> diff --git a/src/type_check.ml b/src/type_check.ml index f204a558..2c10b8ae 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -167,6 +167,8 @@ and strip_n_constraint_aux = function | NC_set (kid, nums) -> NC_set (strip_kid kid, nums) | NC_or (nc1, nc2) -> NC_or (strip_n_constraint nc1, strip_n_constraint nc2) | NC_and (nc1, nc2) -> NC_and (strip_n_constraint nc1, strip_n_constraint nc2) + | NC_var kid -> NC_var (strip_kid kid) + | NC_app (id, args) -> NC_app (strip_id id, List.map strip_typ_arg args) | NC_true -> NC_true | NC_false -> NC_false and strip_n_constraint = function @@ -177,6 +179,7 @@ and strip_typ_arg_aux = function | Typ_arg_nexp nexp -> Typ_arg_nexp (strip_nexp nexp) | Typ_arg_typ typ -> Typ_arg_typ (strip_typ typ) | Typ_arg_order ord -> Typ_arg_order (strip_order ord) + | Typ_arg_bool nc -> Typ_arg_bool (strip_n_constraint nc) and strip_order = function | Ord_aux (ord_aux, _) -> Ord_aux (strip_order_aux ord_aux, Parse_ast.Unknown) and strip_order_aux = function @@ -256,9 +259,8 @@ module Env : sig val add_typ_var : l -> kid -> kind_aux -> t -> t val get_ret_typ : t -> typ option val add_ret_typ : typ -> t -> t - val add_typ_synonym : id -> (t -> typ_arg list -> typ) -> t -> t - val get_typ_synonym : id -> t -> t -> typ_arg list -> typ - val add_constraint_synonym : id -> kid list -> n_constraint -> t -> t + val add_typ_synonym : id -> (t -> typ_arg list -> typ_arg) -> t -> t + val get_typ_synonym : id -> t -> t -> typ_arg list -> typ_arg val add_num_def : id -> nexp -> t -> t val get_num_def : id -> t -> nexp val add_overloads : id -> id list -> t -> t @@ -282,6 +284,7 @@ module Env : sig val lookup_id : ?raw:bool -> id -> t -> typ lvar val fresh_kid : ?kid:kid -> t -> kid val expand_synonyms : t -> typ -> typ + val expand_constraint_synonyms : t -> n_constraint -> n_constraint val canonicalize : t -> typ -> typ val base_typ_of : t -> typ -> typ val add_smt_op : id -> string -> t -> t @@ -320,7 +323,7 @@ end = struct variants : (typquant * type_union list) Bindings.t; mappings : (typquant * typ * typ) Bindings.t; typ_vars : (Ast.l * kind_aux) KBindings.t; - typ_synonyms : (t -> typ_arg list -> typ) Bindings.t; + typ_synonyms : (t -> typ_arg list -> typ_arg) Bindings.t; num_defs : nexp Bindings.t; overloads : (id list) Bindings.t; flow : (typ -> typ) Bindings.t; @@ -329,7 +332,6 @@ end = struct accessors : (typquant * typ) Bindings.t; externs : (string -> string option) Bindings.t; smt_ops : string Bindings.t; - constraint_synonyms : (kid list * n_constraint) Bindings.t; casts : id list; allow_casts : bool; allow_bindings : bool; @@ -359,7 +361,6 @@ end = struct accessors = Bindings.empty; externs = Bindings.empty; smt_ops = Bindings.empty; - constraint_synonyms = Bindings.empty; casts = []; allow_bindings = true; allow_casts = true; @@ -465,8 +466,8 @@ end = struct let kopts, ncs = quant_split typq in let rec subst_args kopts args = match kopts, args with - | kopt :: kopts, Typ_arg_aux (Typ_arg_nexp arg, _) :: args when is_nat_kopt kopt -> - List.map (nc_subst_nexp (kopt_kid kopt) (unaux_nexp arg)) (subst_args kopts args) + | kopt :: kopts, (Typ_arg_aux (Typ_arg_nexp _, _) as arg) :: args when is_nat_kopt kopt -> + List.map (constraint_subst (kopt_kid kopt) arg) (subst_args kopts args) | kopt :: kopts, Typ_arg_aux (Typ_arg_typ arg, _) :: args when is_typ_kopt kopt -> subst_args kopts args | kopt :: kopts, Typ_arg_aux (Typ_arg_order arg, _) :: args when is_order_kopt kopt -> @@ -480,28 +481,41 @@ end = struct then () else typ_error (id_loc id) ("Could not prove " ^ string_of_list ", " string_of_n_constraint ncs ^ " for type constructor " ^ string_of_id id) - let rec expand_synonyms env (Typ_aux (typ, l) as t) = + let rec expand_constraint_synonyms env (NC_aux (aux, l) as nc) = + match aux with + | NC_or (nc1, nc2) -> NC_aux (NC_or (expand_constraint_synonyms env nc1, expand_constraint_synonyms env nc2), l) + | NC_and (nc1, nc2) -> NC_aux (NC_and (expand_constraint_synonyms env nc1, expand_constraint_synonyms env nc2), l) + | NC_app (id, args) -> + (try + begin match Bindings.find id env.typ_synonyms env args with + | Typ_arg_aux (Typ_arg_bool nc, _) -> expand_constraint_synonyms env nc + | _ -> typ_error l ("Expected Type when expanding synonym " ^ string_of_id id) + end + with Not_found -> NC_aux (NC_app (id, List.map (expand_synonyms_arg env) args), l)) + | NC_true | NC_false | NC_equal _ | NC_not_equal _ | NC_bounded_le _ | NC_bounded_ge _ | NC_var _ | NC_set _ -> nc + + and expand_synonyms env (Typ_aux (typ, l) as t) = match typ with | Typ_internal_unknown -> Typ_aux (Typ_internal_unknown, l) | Typ_tup typs -> Typ_aux (Typ_tup (List.map (expand_synonyms env) typs), l) | Typ_fn (arg_typs, ret_typ, effs) -> Typ_aux (Typ_fn (List.map (expand_synonyms env) arg_typs, expand_synonyms env ret_typ, effs), l) | Typ_bidir (typ1, typ2) -> Typ_aux (Typ_bidir (expand_synonyms env typ1, expand_synonyms env typ2), l) | Typ_app (id, args) -> - begin - try - let synonym = Bindings.find id env.typ_synonyms in - expand_synonyms env (synonym env args) - with - | Not_found -> Typ_aux (Typ_app (id, List.map (expand_synonyms_arg env) args), l) - end + (try + begin match Bindings.find id env.typ_synonyms env args with + | Typ_arg_aux (Typ_arg_typ typ, _) -> expand_synonyms env typ + | _ -> typ_error l ("Expected Type when expanding synonym " ^ string_of_id id) + end + with + | Not_found -> Typ_aux (Typ_app (id, List.map (expand_synonyms_arg env) args), l)) | Typ_id id -> - begin - try - let synonym = Bindings.find id env.typ_synonyms in - expand_synonyms env (synonym env []) - with - | Not_found -> Typ_aux (Typ_id id, l) - end + (try + begin match Bindings.find id env.typ_synonyms env [] with + | Typ_arg_aux (Typ_arg_typ typ, _) -> expand_synonyms env typ + | _ -> typ_error l ("Expected Type when expanding synonym " ^ string_of_id id) + end + with + | Not_found -> Typ_aux (Typ_id id, l)) | Typ_exist (kids, nc, typ) -> (* When expanding an existential synonym we need to take care to add the type variables and constraints to the @@ -524,8 +538,8 @@ end = struct let env = List.fold_left add_typ_var env kids in let kids = List.map rename_kid kids in - let nc = List.fold_left (fun nc kid -> nc_subst_nexp kid (Nexp_var (prepend_kid "syn#" kid)) nc) nc !rebindings in - let typ = List.fold_left (fun typ kid -> typ_subst_nexp kid (Nexp_var (prepend_kid "syn#" kid)) typ) typ !rebindings in + let nc = List.fold_left (fun nc kid -> constraint_subst kid (arg_nexp (nvar (prepend_kid "syn#" kid))) nc) nc !rebindings in + let typ = List.fold_left (fun typ kid -> typ_subst kid (arg_nexp (nvar (prepend_kid "syn#" kid))) typ) typ !rebindings in typ_debug (lazy ("Synonym existential: {" ^ string_of_list " " string_of_kid kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ ^ "}")); let env = { env with constraints = nc :: env.constraints } in Typ_aux (Typ_exist (kids, nc, expand_synonyms env typ), l) @@ -533,6 +547,7 @@ end = struct and expand_synonyms_arg env (Typ_arg_aux (typ_arg, l)) = match typ_arg with | Typ_arg_typ typ -> Typ_arg_aux (Typ_arg_typ (expand_synonyms env typ), l) + | Typ_arg_bool nc -> Typ_arg_aux (Typ_arg_bool (expand_constraint_synonyms env nc), l) | arg -> Typ_arg_aux (arg, l) (** Map over all nexps in a type - excluding those in existential constraints **) @@ -547,7 +562,7 @@ end = struct | Typ_app (id, args) -> Typ_aux (Typ_app (id, List.map (map_nexps_arg f) args), l) and map_nexps_arg f (Typ_arg_aux (arg_aux, l) as arg) = match arg_aux with - | Typ_arg_order _ | Typ_arg_typ _ -> arg + | Typ_arg_order _ | Typ_arg_typ _ | Typ_arg_bool _ -> arg | Typ_arg_nexp n -> Typ_arg_aux (Typ_arg_nexp (f n), l) let canonical env typ = @@ -600,7 +615,6 @@ end = struct (* Check if a type, order, n-expression or constraint is well-formed. Throws a type error if the type is badly formed. *) let rec wf_typ ?exs:(exs=KidSet.empty) env typ = - typ_debug (lazy ("well-formed " ^ string_of_typ typ)); let (Typ_aux (typ_aux, l)) = expand_synonyms env typ in match typ_aux with | Typ_id id when bound_typ_id env id -> @@ -637,8 +651,8 @@ end = struct | Typ_arg_nexp nexp -> wf_nexp ~exs:exs env nexp | Typ_arg_typ typ -> wf_typ ~exs:exs env typ | Typ_arg_order ord -> wf_order env ord + | Typ_arg_bool nc -> wf_constraint ~exs:exs env nc and wf_nexp ?exs:(exs=KidSet.empty) env (Nexp_aux (nexp_aux, l) as nexp) = - typ_debug (lazy ("well-formed nexp " ^ string_of_nexp nexp)); match nexp_aux with | Nexp_id _ -> () | Nexp_var kid when KidSet.mem kid exs -> () @@ -671,22 +685,29 @@ end = struct end | Ord_inc | Ord_dec -> () and wf_constraint ?exs:(exs=KidSet.empty) env (NC_aux (nc_aux, l) as nc) = - typ_debug (lazy ("well-formed constraint " ^ string_of_n_constraint nc)); match nc_aux with | NC_equal (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 | NC_not_equal (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 | NC_bounded_ge (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 | NC_bounded_le (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 | NC_set (kid, _) when KidSet.mem kid exs -> () - | NC_set (kid, _) -> begin - match get_typ_var kid env with - | K_int -> () - | kind -> typ_error l ("Set constraint is badly formed, " - ^ string_of_kid kid ^ " has kind " - ^ string_of_kind_aux kind ^ " but should have kind Int") - end + | NC_set (kid, _) -> + begin match get_typ_var kid env with + | K_int -> () + | kind -> typ_error l ("Set constraint is badly formed, " + ^ string_of_kid kid ^ " has kind " + ^ string_of_kind_aux kind ^ " but should have kind Int") + end | NC_or (nc1, nc2) -> wf_constraint ~exs:exs env nc1; wf_constraint ~exs:exs env nc2 | NC_and (nc1, nc2) -> wf_constraint ~exs:exs env nc1; wf_constraint ~exs:exs env nc2 + | NC_app (id, args) -> List.iter (wf_typ_arg ~exs:exs env) args + | NC_var kid -> + begin match get_typ_var kid env with + | K_bool -> () + | kind -> typ_error l ("Set constraint is badly formed, " + ^ string_of_kid kid ^ " has kind " + ^ string_of_kind_aux kind ^ " but should have kind Bool") + end | NC_true | NC_false -> () let counter = ref 0 @@ -699,7 +720,7 @@ end = struct let freshen_kid env kid (typq, typ) = let fresh = fresh_kid ~kid:kid env in if KidSet.mem kid (KidSet.of_list (List.map kopt_kid (quant_kopts typq))) then - (typquant_subst_kid kid fresh typq, typ_subst_kid kid fresh typ) + (typquant_subst_kid kid fresh typq, subst_kid typ_subst kid fresh typ) else (typq, typ) @@ -733,8 +754,8 @@ end = struct match expand_synonyms env typ with | Typ_aux (Typ_exist (kids, nc, typ), _) -> let fresh_kids = List.map (fun kid -> (kid, fresh_existential ~name:(string_of_id (id_of_kid kid)) ())) kids in - let nc = List.fold_left (fun nc (kid, fresh) -> nc_subst_nexp kid (Nexp_var fresh) nc) nc fresh_kids in - let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst_nexp kid (Nexp_var fresh) typ) typ fresh_kids in + let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_nexp (nvar fresh)) nc) nc fresh_kids in + let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_nexp (nvar fresh)) typ) typ fresh_kids in Some (List.map snd fresh_kids, nc, typ) | _ -> None @@ -1067,16 +1088,6 @@ end = struct let get_typ_synonym id env = Bindings.find id env.typ_synonyms - let add_constraint_synonym id kids nc env = - if Bindings.mem id env.constraint_synonyms - then typ_error (id_loc id) ("Constraint synonym " ^ string_of_id id ^ " already exists") - else - begin - typ_print (lazy (adding ^ "constraint synonym " ^ string_of_id id)); - wf_constraint ~exs:(KidSet.of_list kids) env nc; - { env with constraint_synonyms = Bindings.add id (kids, nc) env.constraint_synonyms } - end - let get_default_order env = match env.default_order with | None -> typ_error Parse_ast.Unknown ("No default order has been set") @@ -1154,8 +1165,8 @@ let destruct_exist env typ = match Env.expand_synonyms env typ with | Typ_aux (Typ_exist (kids, nc, typ), _) -> let fresh_kids = List.map (fun kid -> (kid, fresh_existential ~name:(string_of_id (id_of_kid kid)) ())) kids in - let nc = List.fold_left (fun nc (kid, fresh) -> nc_subst_nexp kid (Nexp_var fresh) nc) nc fresh_kids in - let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst_nexp kid (Nexp_var fresh) typ) typ fresh_kids in + let nc = List.fold_left (fun nc (kid, fresh) -> constraint_subst kid (arg_nexp (nvar fresh)) nc) nc fresh_kids in + let typ = List.fold_left (fun typ (kid, fresh) -> typ_subst kid (arg_nexp (nvar fresh)) typ) typ fresh_kids in Some (List.map snd fresh_kids, nc, typ) | _ -> None @@ -1282,65 +1293,18 @@ this is equivalent to which is then a problem we can feed to the constraint solver expecting unsat. *) -let rec nexp_constraint env var_of (Nexp_aux (nexp, l)) = - match nexp with - | Nexp_id v -> nexp_constraint env var_of (Env.get_num_def v env) - | Nexp_var kid -> Constraint.variable (var_of kid) - | Nexp_constant c -> Constraint.constant c - | Nexp_times (nexp1, nexp2) -> Constraint.mult (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2) - | Nexp_sum (nexp1, nexp2) -> Constraint.add (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2) - | Nexp_minus (nexp1, nexp2) -> Constraint.sub (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2) - | Nexp_exp nexp -> Constraint.pow2 (nexp_constraint env var_of nexp) - | Nexp_neg nexp -> Constraint.sub (Constraint.constant (Big_int.of_int 0)) (nexp_constraint env var_of nexp) - | Nexp_app (id, nexps) -> Constraint.app (Env.get_smt_op id env) (List.map (nexp_constraint env var_of) nexps) - -let rec nc_constraint env var_of (NC_aux (nc, l)) = - match nc with - | NC_equal (nexp1, nexp2) -> Constraint.eq (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2) - | NC_not_equal (nexp1, nexp2) -> Constraint.neq (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2) - | NC_bounded_ge (nexp1, nexp2) -> Constraint.gteq (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2) - | NC_bounded_le (nexp1, nexp2) -> Constraint.lteq (nexp_constraint env var_of nexp1) (nexp_constraint env var_of nexp2) - | NC_set (_, []) -> Constraint.literal false - | NC_set (kid, (int :: ints)) -> - List.fold_left Constraint.disj - (Constraint.eq (nexp_constraint env var_of (nvar kid)) (Constraint.constant int)) - (List.map (fun i -> Constraint.eq (nexp_constraint env var_of (nvar kid)) (Constraint.constant i)) ints) - | NC_or (nc1, nc2) -> Constraint.disj (nc_constraint env var_of nc1) (nc_constraint env var_of nc2) - | NC_and (nc1, nc2) -> Constraint.conj (nc_constraint env var_of nc1) (nc_constraint env var_of nc2) - | NC_false -> Constraint.literal false - | NC_true -> Constraint.literal true - -let rec nc_constraints env var_of ncs = - match ncs with - | [] -> Constraint.literal true - | [nc] -> nc_constraint env var_of nc - | (nc :: ncs) -> - Constraint.conj (nc_constraint env var_of nc) (nc_constraints env var_of ncs) - -let prove_z3' env constr = - let module Bindings = Map.Make(Kid) in - let bindings = ref Bindings.empty in - let fresh_var kid = - let n = Bindings.cardinal !bindings in - bindings := Bindings.add kid n !bindings; - n - in - let var_of kid = - try Bindings.find kid !bindings with - | Not_found -> fresh_var kid - in - let constr = Constraint.conj (nc_constraints env var_of (Env.get_constraints env)) (constr var_of) in - match Constraint.call_z3 constr with +let prove_z3 env (NC_aux (_, l) as nc) = + let vars = Env.get_typ_vars env in + let vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) vars in + let ncs = Env.get_constraints env in + match Constraint.call_z3 l vars (List.fold_left nc_and (nc_not nc) ncs) with | Constraint.Unsat -> typ_debug (lazy "unsat"); true | Constraint.Sat -> typ_debug (lazy "sat"); false | Constraint.Unknown -> typ_debug (lazy "unknown"); false -let prove_z3 env nc = - typ_print (lazy (Util.("Prove " |> red |> clear) ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_n_constraint nc)); - prove_z3' env (fun var_of -> Constraint.negate (nc_constraint env var_of nc)) +let solve env nexp = failwith "WIP" -let solve env nexp = - typ_print (lazy ("Solve " ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_nexp nexp ^ " = ?")); + (* typ_print (lazy ("Solve " ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_nexp nexp ^ " = ?")); match nexp with | Nexp_aux (Nexp_constant n,_) -> Some n | _ -> @@ -1359,8 +1323,11 @@ let solve env nexp = (nc_constraint env var_of (nc_eq (nvar (mk_kid "solve#")) nexp)) in Constraint.solve_z3 constr (var_of (mk_kid "solve#")) + *) -let prove env (NC_aux (nc_aux, _) as nc) = +let prove env nc = + typ_print (lazy (Util.("Prove " |> red |> clear) ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_n_constraint nc)); + let (NC_aux (nc_aux, _) as nc) = Env.expand_constraint_synonyms env nc in let compare_const f (Nexp_aux (n1, _)) (Nexp_aux (n2, _)) = match n1, n2 with | Nexp_constant c1, Nexp_constant c2 when f c1 c2 -> true @@ -1499,61 +1466,91 @@ let typ_identical env typ1 typ2 = in typ_identical' (Env.expand_synonyms env typ1) (Env.expand_synonyms env typ2) -type uvar = - | U_nexp of nexp - | U_order of order - | U_typ of typ +exception Unification_error of l * string;; + +let unify_error l str = raise (Unification_error (l, str)) + +let merge_unifiers l kid uvar1 uvar2 = + match uvar1, uvar2 with + | Some (Typ_arg_aux (Typ_arg_nexp n1, _)), Some (Typ_arg_aux (Typ_arg_nexp n2, _)) -> + if nexp_identical n1 n2 then + Some (arg_nexp n1) + else + unify_error l ("Multiple non-identical unifiers for " ^ string_of_kid kid + ^ ": " ^ string_of_nexp n1 ^ " and " ^ string_of_nexp n2) + | Some _, Some _ -> unify_error l "Multiple non-identical non-nexp unifiers" + | None, Some u2 -> Some u2 + | Some u1, None -> Some u1 + | None, None -> None + +let merge_uvars l unifiers1 unifiers2 = + KBindings.merge (merge_unifiers l) unifiers1 unifiers2 + +let rec unify_typ l env goals (Typ_aux (aux1, _) as typ1) (Typ_aux (aux2, _) as typ2) = + match aux1, aux2 with + | Typ_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_typ typ2) -let uvar_subst_nexp sv subst = function - | U_nexp nexp -> U_nexp (nexp_subst sv subst nexp) - | U_typ typ -> U_typ (typ_subst_nexp sv subst typ) - | U_order ord -> U_order ord + | Typ_app (range, [Typ_arg_aux (Typ_arg_nexp n1, _); Typ_arg_aux (Typ_arg_nexp n2, _)]), + Typ_app (atom, [Typ_arg_aux (Typ_arg_nexp m, _)]) + when string_of_id range = "range" && string_of_id atom = "atom" -> + merge_uvars l (unify_nexp l env goals n1 m) (unify_nexp l env goals n2 m) -let uvar_subst_typ sv subst = function - | U_nexp nexp -> U_nexp nexp - | U_typ typ -> U_typ (typ_subst_typ sv subst typ) - | U_order ord -> U_order ord + | Typ_app (id1, args1), Typ_app (id2, args2) when List.length args1 = List.length args2 && Id.compare id1 id2 = 0 -> + List.fold_left (merge_uvars l) KBindings.empty (List.map2 (unify_typ_arg l env goals) args1 args2) -let uvar_subst_order sv subst = function - | U_nexp nexp -> U_nexp nexp - | U_typ typ -> U_typ (typ_subst_order sv subst typ) - | U_order ord -> U_order (order_subst sv subst ord) + | Typ_id id1, Typ_id id2 when Id.compare id1 id2 = 0 -> KBindings.empty -exception Unification_error of l * string;; + | Typ_tup typs1, Typ_tup typs2 when List.length typs1 = List.length typs2 -> + List.fold_left (merge_uvars l) KBindings.empty (List.map2 (unify_typ l env goals) typs1 typs2) -let unify_error l str = raise (Unification_error (l, str)) + | _, _ -> unify_error l ("Cound not unify " ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2) + +and unify_typ_arg l env goals (Typ_arg_aux (aux1, _) as typ_arg1) (Typ_arg_aux (aux2, _) as typ_arg2) = + match aux1, aux2 with + | Typ_arg_typ typ1, Typ_arg_typ typ2 -> unify_typ l env goals typ1 typ2 + | Typ_arg_nexp nexp1, Typ_arg_nexp nexp2 -> unify_nexp l env goals nexp1 nexp2 + | Typ_arg_order ord1, Typ_arg_order ord2 -> unify_order l goals ord1 ord2 + | _, _ -> unify_error l ("Could not unify type arguments " ^ string_of_typ_arg typ_arg1 ^ " and " ^ string_of_typ_arg typ_arg2) + +and unify_order l goals (Ord_aux (aux1, _) as ord1) (Ord_aux (aux2, _) as ord2) = + typ_print (lazy (Util.("Unify order " |> magenta |> clear) ^ string_of_order ord1 ^ " and " ^ string_of_order ord2)); + match aux1, aux2 with + | Ord_var v, _ when KidSet.mem v goals -> KBindings.singleton v (arg_order ord2) + | Ord_inc, Ord_inc -> KBindings.empty + | Ord_dec, Ord_dec -> KBindings.empty + | _, _ -> unify_error l ("Cound not unify " ^ string_of_order ord1 ^ " and " ^ string_of_order ord2) -let rec unify_nexps l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) = +and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_aux2, _) as nexp2) = typ_debug (lazy ("UNIFYING NEXPS " ^ string_of_nexp nexp1 ^ " AND " ^ string_of_nexp nexp2 ^ " FOR GOALS " ^ string_of_list ", " string_of_kid (KidSet.elements goals))); if KidSet.is_empty (KidSet.inter (nexp_frees nexp1) goals) then begin if prove env (NC_aux (NC_equal (nexp1, nexp2), Parse_ast.Unknown)) - then None + then KBindings.empty else unify_error l ("Nexp " ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2 ^ " are not equal") end else match nexp_aux1 with | Nexp_id v -> unify_error l "Unimplemented Nexp_id in unify nexp" - | Nexp_var kid when KidSet.mem kid goals -> Some (kid, nexp2) + | Nexp_var kid when KidSet.mem kid goals -> KBindings.singleton kid (arg_nexp nexp2) | Nexp_constant c1 -> begin match nexp_aux2 with - | Nexp_constant c2 -> if c1 = c2 then None else unify_error l "Constants are not the same" + | Nexp_constant c2 -> if c1 = c2 then KBindings.empty else unify_error l "Constants are not the same" | _ -> unify_error l "Unification error" end | Nexp_sum (n1a, n1b) -> if KidSet.is_empty (nexp_frees n1b) - then unify_nexps l env goals n1a (nminus nexp2 n1b) + then unify_nexp l env goals n1a (nminus nexp2 n1b) else if KidSet.is_empty (nexp_frees n1a) - then unify_nexps l env goals n1b (nminus nexp2 n1a) + then unify_nexp l env goals n1b (nminus nexp2 n1a) else unify_error l ("Both sides of Int expression " ^ string_of_nexp nexp1 ^ " contain free type variables so it cannot be unified with " ^ string_of_nexp nexp2) | Nexp_minus (n1a, n1b) -> if KidSet.is_empty (nexp_frees n1b) - then unify_nexps l env goals n1a (nsum nexp2 n1b) - else unify_error l ("Cannot unify minus Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) + then unify_nexp l env goals n1a (nsum nexp2 n1b) + else unify_error l ("Cannot unify minus Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) | Nexp_times (n1a, n1b) -> (* If we have SMT operations div and mod, then we can use the property that @@ -1564,20 +1561,20 @@ let rec unify_nexps l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (ne if Env.have_smt_op (mk_id "div") env && Env.have_smt_op (mk_id "mod") env then let valid n c = prove env (nc_eq (napp (mk_id "mod") [n; c]) (nint 0)) && prove env (nc_neq c (nint 0)) in if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then - unify_nexps l env goals n1a (napp (mk_id "div") [nexp2; n1b]) + unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b]) else if KidSet.is_empty (nexp_frees n1a) && valid nexp2 n1a then - unify_nexps l env goals n1b (napp (mk_id "div") [nexp2; n1a]) + unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a]) else unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) else if KidSet.is_empty (nexp_frees n1a) then begin match nexp_aux2 with | Nexp_times (n2a, n2b) when prove env (NC_aux (NC_equal (n1a, n2a), Parse_ast.Unknown)) -> - unify_nexps l env goals n1b n2b + unify_nexp l env goals n1b n2b | Nexp_constant c2 -> begin match n1a with | Nexp_aux (Nexp_constant c1,_) when Big_int.equal (Big_int.modulus c2 c1) Big_int.zero -> - unify_nexps l env goals n1b (mk_nexp (Nexp_constant (Big_int.div c2 c1))) + unify_nexp l env goals n1b (nconstant (Big_int.div c2 c1)) | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) end | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) @@ -1586,148 +1583,30 @@ let rec unify_nexps l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (ne begin match nexp_aux2 with | Nexp_times (n2a, n2b) when prove env (NC_aux (NC_equal (n1b, n2b), Parse_ast.Unknown)) -> - unify_nexps l env goals n1a n2a + unify_nexp l env goals n1a n2a | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) end else unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) -let string_of_uvar = function - | U_nexp n -> string_of_nexp n - | U_order o -> string_of_order o - | U_typ typ -> string_of_typ typ - -let unify_order l (Ord_aux (ord_aux1, _) as ord1) (Ord_aux (ord_aux2, _) as ord2) = - typ_debug (lazy ("UNIFYING ORDERS " ^ string_of_order ord1 ^ " AND " ^ string_of_order ord2)); - match ord_aux1, ord_aux2 with - | Ord_var kid, _ -> KBindings.singleton kid (U_order ord2) - | Ord_inc, Ord_inc -> KBindings.empty - | Ord_dec, Ord_dec -> KBindings.empty - | _, _ -> unify_error l (string_of_order ord1 ^ " cannot be unified with " ^ string_of_order ord2) +let unify l env typ1 typ2 goals = + typ_print (lazy (Util.("Unify " |> magenta |> clear) ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2)); + let typ1, typ2 = Env.expand_synonyms env typ1, Env.expand_synonyms env typ2 in + unify_typ l env goals typ1 typ2 let subst_unifiers unifiers typ = - let subst_unifier typ (kid, uvar) = - match uvar with - | U_nexp nexp -> typ_subst_nexp kid (unaux_nexp nexp) typ - | U_order ord -> typ_subst_order kid (unaux_order ord) typ - | U_typ subst -> typ_subst_typ kid (unaux_typ subst) typ - in - List.fold_left subst_unifier typ (KBindings.bindings unifiers) - -let subst_args_unifiers unifiers typ_args = - let subst_unifier typ_args (kid, uvar) = - match uvar with - | U_nexp nexp -> List.map (typ_subst_arg_nexp kid (unaux_nexp nexp)) typ_args - | U_order ord -> List.map (typ_subst_arg_order kid (unaux_order ord)) typ_args - | U_typ subst -> List.map (typ_subst_arg_typ kid (unaux_typ subst)) typ_args - in - List.fold_left subst_unifier typ_args (KBindings.bindings unifiers) - -let subst_uvar_unifiers unifiers uvar = - let subst_unifier uvar' (kid, uvar) = - match uvar with - | U_nexp nexp -> uvar_subst_nexp kid (unaux_nexp nexp) uvar' - | U_order ord -> uvar_subst_order kid (unaux_order ord) uvar' - | U_typ subst -> uvar_subst_typ kid (unaux_typ subst) uvar' - in - List.fold_left subst_unifier uvar (KBindings.bindings unifiers) - -let merge_unifiers l kid uvar1 uvar2 = - match uvar1, uvar2 with - | Some (U_nexp n1), Some (U_nexp n2) -> - if nexp_identical n1 n2 then Some (U_nexp n1) - else unify_error l ("Multiple non-identical unifiers for " ^ string_of_kid kid - ^ ": " ^ string_of_nexp n1 ^ " and " ^ string_of_nexp n2) - | Some _, Some _ -> unify_error l "Multiple non-identical non-nexp unifiers" - | None, Some u2 -> Some u2 - | Some u1, None -> Some u1 - | None, None -> None + List.fold_left (fun typ (v, arg) -> typ_subst v arg typ) typ (KBindings.bindings unifiers) -let rec unify l env typ1 typ2 = - typ_print (lazy (Util.("Unify " |> magenta |> clear) ^ string_of_typ typ1 ^ " with " ^ string_of_typ typ2)); - let goals = KidSet.inter (KidSet.diff (typ_frees typ1) (typ_frees typ2)) (typ_frees typ1) in - - let rec unify_typ l (Typ_aux (typ1_aux, _) as typ1) (Typ_aux (typ2_aux, _) as typ2) = - typ_debug (lazy ("UNIFYING TYPES " ^ string_of_typ typ1 ^ " AND " ^ string_of_typ typ2)); - match typ1_aux, typ2_aux with - | Typ_internal_unknown, _ - | _, Typ_internal_unknown when Env.allow_unknowns env -> KBindings.empty - | Typ_id v1, Typ_id v2 -> - if Id.compare v1 v2 = 0 then KBindings.empty - else unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2) - | Typ_id v1, Typ_app (f2, []) -> - if Id.compare v1 f2 = 0 then KBindings.empty - else unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2) - | Typ_app (f1, []), Typ_id v2 -> - if Id.compare f1 v2 = 0 then KBindings.empty - else unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2) - | Typ_var kid, _ when KidSet.mem kid goals -> KBindings.singleton kid (U_typ typ2) - | Typ_var kid1, Typ_var kid2 when Kid.compare kid1 kid2 = 0 -> KBindings.empty - | Typ_tup typs1, Typ_tup typs2 -> - begin - try List.fold_left (KBindings.merge (merge_unifiers l)) KBindings.empty (List.map2 (unify_typ l) typs1 typs2) with - | Invalid_argument _ -> unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2 - ^ " tuple type is of different length") - end - | Typ_app (f1, [arg1]), Typ_app (f2, [arg2a; arg2b]) - when Id.compare (mk_id "atom") f1 = 0 && Id.compare (mk_id "range") f2 = 0 -> - unify_typ_arg_list 0 KBindings.empty [] [] [arg1; arg1] [arg2a; arg2b] - | Typ_app (f1, [arg1a; arg1b]), Typ_app (f2, [arg2]) - when Id.compare (mk_id "range") f1 = 0 && Id.compare (mk_id "atom") f2 = 0 -> - unify_typ_arg_list 0 KBindings.empty [] [] [arg1a; arg1b] [arg2; arg2] - | Typ_app (f1, args1), Typ_app (f2, args2) when Id.compare f1 f2 = 0 -> - unify_typ_arg_list 0 KBindings.empty [] [] args1 args2 - | _, _ -> unify_error l (string_of_typ typ1 ^ " cannot be unified with " ^ string_of_typ typ2) - - and unify_typ_arg_list unified acc uargs1 uargs2 args1 args2 = - match args1, args2 with - | [], [] when unified = 0 && List.length uargs1 > 0 -> - unify_error l "Could not unify arg lists" (*FIXME improve error *) - | [], [] when unified > 0 && List.length uargs1 > 0 -> unify_typ_arg_list 0 acc [] [] uargs1 uargs2 - | [], [] when List.length uargs1 = 0 -> acc - | (a1 :: a1s), (a2 :: a2s) -> - begin - let unifiers, success = - try unify_typ_args l a1 a2, true with - | Unification_error _ -> KBindings.empty, false - in - let a1s = subst_args_unifiers unifiers a1s in - let a2s = subst_args_unifiers unifiers a2s in - let uargs1 = subst_args_unifiers unifiers uargs1 in - let uargs2 = subst_args_unifiers unifiers uargs2 in - if success - then unify_typ_arg_list (unified + 1) (KBindings.merge (merge_unifiers l) unifiers acc) uargs1 uargs2 a1s a2s - else unify_typ_arg_list unified acc (a1 :: uargs1) (a2 :: uargs2) a1s a2s - end - | _, _ -> unify_error l "Cannot unify type lists of different length" - - and unify_typ_args l (Typ_arg_aux (typ_arg_aux1, _) as typ_arg1) (Typ_arg_aux (typ_arg_aux2, _) as typ_arg2) = - match typ_arg_aux1, typ_arg_aux2 with - | Typ_arg_nexp n1, Typ_arg_nexp n2 -> - begin - match unify_nexps l env goals (nexp_simp n1) (nexp_simp n2) with - | Some (kid, unifier) -> KBindings.singleton kid (U_nexp (nexp_simp unifier)) - | None -> KBindings.empty - end - | Typ_arg_typ typ1, Typ_arg_typ typ2 -> unify_typ l typ1 typ2 - | Typ_arg_order ord1, Typ_arg_order ord2 -> unify_order l ord1 ord2 - | _, _ -> unify_error l (string_of_typ_arg typ_arg1 ^ " cannot be unified with type argument " ^ string_of_typ_arg typ_arg2) - in - - match destruct_exist env typ2 with - | Some (kids, nc, typ2) -> - let typ1, typ2 = Env.expand_synonyms env typ1, Env.expand_synonyms env typ2 in - let (unifiers, _, _) = unify l env typ1 typ2 in - typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers))); - unifiers, kids, Some nc - | None -> - let typ1, typ2 = Env.expand_synonyms env typ1, Env.expand_synonyms env typ2 in - unify_typ l typ1 typ2, [], None +let instantiate_quant (v, arg) (QI_aux (aux, l) as qi) = + match aux with + | QI_id kopt when Kid.compare (kopt_kid kopt) v = 0 -> + typ_debug (lazy ("Instantiated " ^ string_of_quant_item qi)); + None + | QI_id _ -> Some qi + | QI_const nc -> Some (QI_aux (QI_const (constraint_subst v arg nc), l)) -let merge_uvars l unifiers1 unifiers2 = - try KBindings.merge (merge_unifiers l) unifiers1 unifiers2 - with - | Unification_error (_, m) -> typ_error l ("Could not merge unification variables: " ^ m) +let instantiate_quants quants unifier = + List.map (instantiate_quant unifier) quants |> Util.option_these (**************************************************************************) (* 3.5. Subtyping with existentials *) @@ -1750,16 +1629,6 @@ let destruct_atom_kid env typ = when string_of_id f = "range" && Kid.compare kid1 kid2 = 0 -> Some kid1 | _ -> None -let nc_subst_uvar kid uvar nc = - match uvar with - | U_nexp nexp -> nc_subst_nexp kid (unaux_nexp nexp) nc - | _ -> nc - -let uv_nexp_constraint env (kid, uvar) = - match uvar with - | U_nexp nexp -> Env.add_constraint (nc_eq (nvar kid) nexp) env - | _ -> env - (* The kid_order function takes a set of Int-kinded kids, and returns a list of those kids in the order they appear in a type, as well as a set containing all the kids that did not occur in the type. We @@ -1809,8 +1678,8 @@ let rec alpha_equivalent env typ1 typ2 = | Typ_exist (kids, nc, typ) -> let (kids, _) = kid_order (KidSet.of_list kids) typ in let kids = List.map (fun kid -> (kid, new_kid ())) kids in - let nc = List.fold_left (fun nc (kid, nk) -> nc_subst_nexp kid (Nexp_var nk) nc) nc kids in - let typ = List.fold_left (fun nc (kid, nk) -> typ_subst_nexp kid (Nexp_var nk) nc) typ kids in + let nc = List.fold_left (fun nc (kid, nk) -> constraint_subst kid (arg_nexp (nvar nk)) nc) nc kids in + let typ = List.fold_left (fun nc (kid, nk) -> typ_subst kid (arg_nexp (nvar nk)) nc) typ kids in let kids = List.map snd kids in Typ_exist (kids, nc, typ) | Typ_app (id, args) -> @@ -1836,6 +1705,11 @@ let unwrap_exist env typ = | Some (kids, nc, typ) -> (kids, nc, typ) | None -> ([], nc_true, typ) +let unifier_constraint env (v, arg) = + match arg with + | Typ_arg_aux (Typ_arg_nexp nexp, _) -> Env.add_constraint (nc_eq (nvar v) nexp) env + | _ -> env + let rec subtyp l env (Typ_aux (typ_aux1, _) as typ1) (Typ_aux (typ_aux2, _) as typ2) = typ_print (lazy (("Subtype " |> Util.green |> Util.clear) ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2)); match typ_aux1, typ_aux2 with @@ -1854,12 +1728,9 @@ let rec subtyp l env (Typ_aux (typ_aux1, _) as typ1) (Typ_aux (typ_aux2, _) as t let env = add_existential l kids1 nc1 env in let env = add_typ_vars l (KidSet.elements (KidSet.inter (nexp_frees nexp2) (KidSet.of_list kids2))) env in let kids2 = KidSet.elements (KidSet.diff (KidSet.of_list kids2) (nexp_frees nexp2)) in + if not (kids2 = []) then typ_error l "Universally quantified constraint generated" else (); let env = Env.add_constraint (nc_eq nexp1 nexp2) env in - let constr var_of = - Constraint.forall (List.map var_of kids2) - (nc_constraint env var_of (nc_negate nc2)) - in - if prove_z3' env constr then () + if prove env nc2 then () else typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env)) | _, _ -> match destruct_exist env typ1, unwrap_exist env (Env.canonicalize env typ2) with @@ -1869,24 +1740,14 @@ let rec subtyp l env (Typ_aux (typ_aux1, _) as typ1) (Typ_aux (typ_aux2, _) as t typ_debug (lazy "Subtype check with unification"); let env = add_typ_vars l kids env in let kids' = KidSet.elements (KidSet.diff (KidSet.of_list kids) (typ_frees typ2)) in - let unifiers, existential_kids, existential_nc = - try unify l env typ2 typ1 with + if not (kids' = []) then typ_error l "Universally quantified constraint generated" else (); + let unifiers = + try unify l env typ2 typ1 (KidSet.of_list kids) with | Unification_error (_, m) -> typ_error l m in - let nc = List.fold_left (fun nc (kid, uvar) -> nc_subst_uvar kid uvar nc) nc (KBindings.bindings unifiers) in - let env = List.fold_left uv_nexp_constraint env (KBindings.bindings unifiers) in - let env = match existential_kids, existential_nc with - | [], None -> env - | _, Some enc -> - let env = List.fold_left (fun env kid -> Env.add_typ_var l kid K_int env) env existential_kids in - Env.add_constraint enc env - | _, None -> assert false (* Cannot have existential_kids without existential_nc *) - in - let constr var_of = - Constraint.forall (List.map var_of kids') - (nc_constraint env var_of (nc_negate nc)) - in - if prove_z3' env constr then () + let nc = List.fold_left (fun nc (kid, uvar) -> constraint_subst kid uvar nc) nc (KBindings.bindings unifiers) in + let env = List.fold_left unifier_constraint env (KBindings.bindings unifiers) in + if prove env nc then () else typ_raise l (Err_subtype (typ1, typ2, Env.get_constraints env, Env.get_typ_var_locs env)) let typ_equality l env typ1 typ2 = @@ -1967,40 +1828,13 @@ let is_typ_kid kid = function | KOpt_aux (KOpt_kind (K_aux (K_type, _), kid'), _) -> Kid.compare kid kid' = 0 | _ -> false -let rec instantiate_quants quants kid uvar = match quants with - | [] -> [] - | ((QI_aux (QI_id kinded_id, _) as quant) :: quants) -> - typ_debug (lazy ("instantiating quant " ^ string_of_quant_item quant)); - begin - match uvar with - | U_nexp nexp -> - if is_nat_kid kid kinded_id - then instantiate_quants quants kid uvar - else quant :: instantiate_quants quants kid uvar - | U_order ord -> - if is_order_kid kid kinded_id - then instantiate_quants quants kid uvar - else quant :: instantiate_quants quants kid uvar - | U_typ typ -> - if is_typ_kid kid kinded_id - then instantiate_quants quants kid uvar - else quant :: instantiate_quants quants kid uvar - end - | ((QI_aux (QI_const nc, l)) :: quants) -> - begin - match uvar with - | U_nexp nexp -> - QI_aux (QI_const (nc_subst_nexp kid (unaux_nexp nexp) nc), l) :: instantiate_quants quants kid uvar - | _ -> (QI_aux (QI_const nc, l)) :: instantiate_quants quants kid uvar - end - let instantiate_simple_equations = let rec find_eqs kid (NC_aux (nc,_)) = match nc with | NC_equal (Nexp_aux (Nexp_var kid',_), nexp) when Kid.compare kid kid' == 0 && not (KidSet.mem kid (nexp_frees nexp)) -> - [U_nexp nexp] + [arg_nexp nexp] | NC_and (nexp1, nexp2) -> find_eqs kid nexp1 @ find_eqs kid nexp2 | _ -> [] @@ -2275,7 +2109,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ | ((E_aux (E_if (cond, (E_aux (E_throw _, _) | E_aux (E_block [E_aux (E_throw _, _)], _)), _), _) as exp) :: exps) -> let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in - let env = add_opt_constraint (option_map nc_negate (assert_constraint env false cond')) env in + let env = add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env in texp :: check_block l env exps typ | (exp :: exps) -> let texp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in @@ -2318,7 +2152,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ in let check_fexp (FE_aux (FE_Fexp (field, exp), (l, ()))) = let (typq, rectyp_q, field_typ, _) = Env.get_accessor rectyp_id field env in - let unifiers, _, _ (* FIXME *) = try unify l env rectyp_q typ with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in + let unifiers = try unify l env rectyp_q typ (tyvars_of_typ rectyp_q) with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in let field_typ' = subst_unifiers unifiers field_typ in let checked_exp = crule check_exp env exp field_typ' in FE_aux (FE_Fexp (field, checked_exp), (l, None)) @@ -2333,7 +2167,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ in let check_fexp (FE_aux (FE_Fexp (field, exp), (l, ()))) = let (typq, rectyp_q, field_typ, _) = Env.get_accessor rectyp_id field env in - let unifiers, _, _ (* FIXME *) = try unify l env rectyp_q typ with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in + let unifiers = try unify l env rectyp_q typ (tyvars_of_typ rectyp_q) with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in let field_typ' = subst_unifiers unifiers field_typ in let checked_exp = crule check_exp env exp field_typ' in FE_aux (FE_Fexp (field, checked_exp), (l, None)) @@ -2413,7 +2247,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ | E_if (cond, then_branch, else_branch), _ -> let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in let then_branch' = crule check_exp (add_opt_constraint (assert_constraint env true cond') env) then_branch typ in - let else_branch' = crule check_exp (add_opt_constraint (option_map nc_negate (assert_constraint env false cond')) env) else_branch typ in + let else_branch' = crule check_exp (add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env) else_branch typ in annot_exp (E_if (cond', then_branch', else_branch')) typ | E_exit exp, _ -> let checked_exp = crule check_exp env exp (mk_typ (Typ_id (mk_id "unit"))) in @@ -2563,7 +2397,7 @@ and type_coercion env (E_aux (_, (l, _)) as annotated_exp) typ = required that exp_typ unifies with typ. Returns the annotated coercion as with type_coercion and also a set of unifiers, or throws a unification error *) -and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ = +and type_coercion_unify env goals (E_aux (_, (l, _)) as annotated_exp) typ = let strip exp_aux = strip_exp (E_aux (exp_aux, (Parse_ast.Unknown, None))) in let annot_exp exp typ' = E_aux (exp, (l, Some ((env, typ', no_effect), Some typ))) in let switch_typ exp typ = match exp with @@ -2576,8 +2410,8 @@ and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ = typ_print (lazy ("Casting with " ^ string_of_id cast ^ " expression " ^ string_of_exp annotated_exp ^ " for unification")); try let inferred_cast = irule infer_exp (Env.no_casts env) (strip (E_app (cast, [annotated_exp]))) in - let ityp = typ_of inferred_cast in - annot_exp (E_cast (ityp, inferred_cast)) ityp, unify l env typ ityp + let ityp, env = bind_existential l (typ_of inferred_cast) env in + inferred_cast, unify l env typ ityp goals, env with | Type_error (_, err) -> try_casts casts | Unification_error (_, err) -> try_casts casts @@ -2586,7 +2420,8 @@ and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ = begin try typ_debug (lazy "PERFORMING COERCING UNIFICATION"); - annotated_exp, unify l env typ (typ_of annotated_exp) + let atyp, env = bind_existential l (typ_of annotated_exp) env in + annotated_exp, unify l env typ atyp goals, env with | Unification_error (_, m) when Env.allow_casts env -> let casts = filter_casts env (typ_of annotated_exp) typ (Env.get_casts env) in @@ -2708,11 +2543,11 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) | Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) -> begin try + let goals = List.fold_left KidSet.union KidSet.empty (List.map tyvars_of_quant_item quants) in typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for pattern " ^ string_of_typ typ)); - let unifiers, _, _ (* FIXME! *) = unify l env ret_typ typ in - typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers))); + let unifiers = unify l env ret_typ typ goals in let arg_typ' = subst_unifiers unifiers arg_typ in - let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in + let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in if (match quants' with [] -> false | _ -> true) then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat) else (); @@ -2742,12 +2577,10 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) try typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for pattern " ^ string_of_typ typ)); - let unifiers, _, _ (* FIXME! *) = unify l env typ2 typ in - - typ_debug (lazy ("unifiers: " ^ string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers))); - + (* FIXME: There's no obvious goals here *) + let unifiers = unify l env typ2 typ (tyvars_of_typ typ2) in let arg_typ' = subst_unifiers unifiers typ1 in - let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in + let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in if (match quants' with [] -> false | _ -> true) then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat) else (); @@ -2763,10 +2596,9 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) try typ_debug (lazy "Unifying mapping forwards failed, trying backwards."); typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for pattern " ^ string_of_typ typ)); - let unifiers, _, _ (* FIXME! *) = unify l env typ1 typ in - typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers))); + let unifiers = unify l env typ1 typ (tyvars_of_typ typ1) in let arg_typ' = subst_unifiers unifiers typ2 in - let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in + let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in if (match quants' with [] -> false | _ -> true) then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat) else (); @@ -2955,7 +2787,7 @@ and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as | Typ_aux (Typ_id rectyp_id, _) | Typ_aux (Typ_app (rectyp_id, _), _) when Env.is_record rectyp_id env -> let eff = if is_register then mk_effect [BE_wreg] else no_effect in let (typq, rectyp_q, field_typ, _) = Env.get_accessor rectyp_id field env in - let unifiers, _, _ (* FIXME *) = try unify l env rectyp_q regtyp with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in + let unifiers = try unify l env rectyp_q regtyp (tyvars_of_typ rectyp_q) with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in let field_typ' = subst_unifiers unifiers field_typ in let checked_exp = crule check_exp env exp field_typ' in annot_assign (annot_lexp (LEXP_field (annot_lexp_effect inferred_flexp regtyp eff, field)) field_typ') checked_exp, env @@ -3190,7 +3022,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = in let check_fexp (FE_aux (FE_Fexp (field, exp), (l, ()))) = let (typq, rectyp_q, field_typ, _) = Env.get_accessor rectyp_id field env in - let unifiers, _, _ (* FIXME *) = try unify l env rectyp_q typ with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in + let unifiers = try unify l env rectyp_q typ (tyvars_of_typ rectyp_q) with Unification_error (l, m) -> typ_error l ("Unification error: " ^ m) in let field_typ' = subst_unifiers unifiers field_typ in let inferred_exp = crule check_exp env exp field_typ' in FE_aux (FE_Fexp (field, inferred_exp), (l, None)) @@ -3261,7 +3093,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = | E_if (cond, then_branch, else_branch) -> let cond' = crule check_exp env cond (mk_typ (Typ_id (mk_id "bool"))) in let then_branch' = irule infer_exp (add_opt_constraint (assert_constraint env true cond') env) then_branch in - let else_branch' = crule check_exp (add_opt_constraint (option_map nc_negate (assert_constraint env false cond')) env) else_branch (typ_of then_branch') in + let else_branch' = crule check_exp (add_opt_constraint (option_map nc_not (assert_constraint env false cond')) env) else_branch (typ_of then_branch') in annot_exp (E_if (cond', then_branch', else_branch')) (typ_of then_branch') | E_vector_access (v, n) -> infer_exp env (E_aux (E_app (mk_id "vector_access", [v; n]), (l, ()))) | E_vector_update (v, n, exp) -> infer_exp env (E_aux (E_app (mk_id "vector_update", [v; n; exp]), (l, ()))) @@ -3337,160 +3169,104 @@ and instantiation_of_without_type (E_aux (exp_aux, (l, _)) as exp) = | E_app (f, xs) -> snd (infer_funapp' l (Env.no_casts env) f (Env.get_val_spec f env) (List.map strip_exp xs) None) | _ -> invalid_arg ("instantiation_of expected application, got " ^ string_of_exp exp) -and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ = - let annot_exp exp typ eff = E_aux (exp, (l, Some ((env, typ, eff), ret_ctx_typ))) in - let switch_annot env typ = function - | (E_aux (exp, (l, Some (_, _, eff)))) -> E_aux (exp, (l, Some (env, typ, eff))) - | _ -> failwith "Cannot switch annot for unannotated function" - in - let all_unifiers = ref KBindings.empty in - let ex_goal = ref None in - let prove_goal env = match !ex_goal with - | Some goal when prove env goal -> () - | Some goal -> typ_error l ("Could not prove existential goal: " ^ string_of_n_constraint goal) - | None -> () - in +and infer_funapp' l env f (typq, f_typ) xs expected_ret_typ = + let annot_exp exp typ eff = E_aux (exp, (l, Some ((env, typ, eff), expected_ret_typ))) in + let is_bound env kid = KBindings.mem kid (Env.get_typ_vars env) in + + (* First we record all the type variables when we start checking the + application, so we can distinguish them from existentials + introduced by instantiating function arguments later. *) let universals = Env.get_typ_vars env in let universal_constraints = Env.get_constraints env in - let is_bound kid env = KBindings.mem kid (Env.get_typ_vars env) in - let rec number n = function - | [] -> [] - | (x :: xs) -> (n, x) :: number (n + 1) xs - in - let solve_quant env = function - | QI_aux (QI_id _, _) -> false - | QI_aux (QI_const nc, _) -> prove env nc - in - let record_unifiers unifiers = - let previous_unifiers = !all_unifiers in - let updated_unifiers = KBindings.map (subst_uvar_unifiers unifiers) previous_unifiers in - all_unifiers := merge_uvars l updated_unifiers unifiers; - in - let rec instantiate env quants typs ret_typ args = - match typs, args with - | (utyps, []), (uargs, []) -> - begin - typ_debug (lazy ("Got unresolved args: " ^ string_of_list ", " (fun (_, exp) -> string_of_exp exp) uargs)); - if List.for_all (solve_quant env) quants - then - let iuargs = List.map2 (fun utyp (n, uarg) -> (n, crule check_exp env uarg utyp)) utyps uargs in - (iuargs, ret_typ, env) - else typ_raise l (Err_unresolved_quants (f, quants, Env.get_locals env, Env.get_constraints env)) - end - | (utyps, (typ :: typs)), (uargs, ((n, arg) :: args)) - when List.for_all (fun kid -> is_bound kid env) (KidSet.elements (typ_frees typ)) -> - begin - let carg = crule check_exp env arg typ in - let (iargs, ret_typ', env) = instantiate env quants (utyps, typs) ret_typ (uargs, args) in - ((n, carg) :: iargs, ret_typ', env) - end - | (utyps, (typ :: typs)), (uargs, ((n, arg) :: args)) -> - begin - typ_debug (lazy ("INSTANTIATE: " ^ string_of_exp arg ^ " with " ^ string_of_typ typ)); - let iarg = irule infer_exp env arg in - typ_debug (lazy ("INFER: " ^ string_of_exp arg ^ " type " ^ string_of_typ (typ_of iarg))); - try - (* If we get an existential when instantiating, we prepend - the identifier of the exisitential with the tag argN# to - denote that it was bound by the Nth argument to the - function. *) - let ex_tag = "arg" ^ string_of_int n ^ "#" in - let iarg, (unifiers, ex_kids, ex_nc) = type_coercion_unify env iarg typ in - typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers))); - typ_debug (lazy ("EX KIDS: " ^ string_of_list ", " string_of_kid ex_kids)); - let env = match ex_kids, ex_nc with - | [], None -> env - | _, Some enc -> - let enc = List.fold_left (fun nc kid -> nc_subst_nexp kid (Nexp_var (prepend_kid ex_tag kid)) nc) enc ex_kids in - let env = List.fold_left (fun env kid -> Env.add_typ_var l (prepend_kid ex_tag kid) K_int env) env ex_kids in - Env.add_constraint enc env - | _, None -> assert false (* Cannot have ex_kids without ex_nc *) - in - let tag_unifier uvar = List.fold_left (fun uvar kid -> uvar_subst_nexp kid (Nexp_var (prepend_kid ex_tag kid)) uvar) uvar ex_kids in - let unifiers = KBindings.map tag_unifier unifiers in - record_unifiers unifiers; - let utyps' = List.map (subst_unifiers unifiers) utyps in - let typs' = List.map (subst_unifiers unifiers) typs in - let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in - let ret_typ' = subst_unifiers unifiers ret_typ in - let (iargs, ret_typ'', env) = instantiate env quants' (utyps', typs') ret_typ' (uargs, args) in - ((n, iarg) :: iargs, ret_typ'', env) - with - | Unification_error (l, str) -> - typ_print (lazy ("Unification error: " ^ str)); - instantiate env quants (typ :: utyps, typs) ret_typ ((n, arg) :: uargs, args) - end - | (_, []), _ -> typ_error l ("Function " ^ string_of_id f ^ " applied to too many arguments") - | _, (_, []) -> typ_error l ("Function " ^ string_of_id f ^ " not applied to enough arguments") - in - let instantiate_ret env quants typs ret_typ = - match ret_ctx_typ with - | None -> (quants, typs, ret_typ, env) - | Some rct when is_exist (Env.expand_synonyms env rct) -> (quants, typs, ret_typ, env) - | Some rct -> - begin - typ_debug (lazy ("RCT is " ^ string_of_typ rct)); - typ_debug (lazy ("INSTANTIATE RETURN:" ^ string_of_typ ret_typ)); - let unifiers, ex_kids, ex_nc = - try unify l env ret_typ rct with - | Unification_error _ -> typ_debug (lazy "UERROR"); KBindings.empty, [], None - in - typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers))); - if ex_kids = [] then () else (typ_debug (lazy ("EX GOAL: " ^ string_of_option string_of_n_constraint ex_nc)); ex_goal := ex_nc); - record_unifiers unifiers; - let env = List.fold_left (fun env kid -> Env.add_typ_var l kid K_int env) env ex_kids in - let typs' = List.map (subst_unifiers unifiers) typs in - let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in - let ret_typ' = - match ex_nc with - | None -> subst_unifiers unifiers ret_typ - | Some nc -> mk_typ (Typ_exist (ex_kids, nc, subst_unifiers unifiers ret_typ)) - in - (quants', typs', ret_typ', env) - end - in + let quants, typ_args, typ_ret, eff = match Env.expand_synonyms env f_typ with - | Typ_aux (Typ_fn (typ_args, typ_ret, eff), _) -> quant_items typq, typ_args, typ_ret, eff + | Typ_aux (Typ_fn (typ_args, typ_ret, eff), _) -> ref (quant_items typq), typ_args, ref typ_ret, eff | _ -> typ_error l (string_of_typ f_typ ^ " is not a function type") in - let unifiers = instantiate_simple_equations quants in - typ_debug (lazy "Instantiating from equations"); - typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers))); all_unifiers := unifiers; - let typ_args = List.map (subst_unifiers unifiers) typ_args in - let quants = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in - let typ_ret = subst_unifiers unifiers typ_ret in - let quants, typ_args, typ_ret, env = - instantiate_ret env quants typ_args typ_ret + + typ_debug (lazy ("Quantifiers " ^ Util.string_of_list ", " string_of_quant_item !quants)); + + if not (List.length typ_args = List.length xs) then + typ_error l (Printf.sprintf "Function %s applied to %d args, expected %d" (string_of_id f) (List.length xs) (List.length typ_args)) + else (); + + let instantiate_quant (v, arg) (QI_aux (aux, l) as qi) = + match aux with + | QI_id kopt when Kid.compare (kopt_kid kopt) v = 0 -> None + | QI_id _ -> Some qi + | QI_const nc -> Some (QI_aux (QI_const (constraint_subst v arg nc), l)) in - let (xs_instantiated, typ_ret, env) = instantiate env quants ([], typ_args) typ_ret ([], number 0 xs) in - let xs_reordered = List.map snd (List.sort (fun (n, _) (m, _) -> compare n m) xs_instantiated) in - prove_goal env; + let typ_args = match expected_ret_typ with + | None -> typ_args + | Some expect when is_exist (Env.expand_synonyms env expect) || is_exist !typ_ret -> typ_args + | Some expect -> + let goals = List.fold_left KidSet.union KidSet.empty (List.map tyvars_of_quant_item !quants) in + try + let unifiers = unify l env !typ_ret expect goals |> KBindings.bindings in + typ_debug (lazy (Util.("Unifiers " |> magenta |> clear) + ^ Util.string_of_list ", " (fun (v, arg) -> string_of_kid v ^ " => " ^ string_of_typ_arg arg) unifiers)); + List.iter (fun unifier -> quants := instantiate_quants !quants unifier) unifiers; + List.iter (fun (v, arg) -> typ_ret := typ_subst v arg !typ_ret) unifiers; + List.map (fun typ -> List.fold_left (fun typ (v, arg) -> typ_subst v arg typ) typ unifiers) typ_args + with Unification_error _ -> typ_args + in + + (* We now iterate throught the function arguments, checking them and + instantiating quantifiers. *) + let instantiate env arg typ remaining_typs = + if KidSet.for_all (is_bound env) (tyvars_of_typ typ) then + crule check_exp env arg typ, remaining_typs, env + else + let goals = List.fold_left KidSet.union KidSet.empty (List.map tyvars_of_quant_item !quants) in + let inferred_arg = irule infer_exp env arg in + let inferred_arg, unifiers, env = + try type_coercion_unify env goals inferred_arg typ with + | Unification_error (l, m) -> typ_error l m + in + let unifiers = KBindings.bindings unifiers in + typ_debug (lazy (Util.("Unifiers " |> magenta |> clear) + ^ Util.string_of_list ", " (fun (v, arg) -> string_of_kid v ^ " => " ^ string_of_typ_arg arg) unifiers)); + List.iter (fun unifier -> quants := instantiate_quants !quants unifier) unifiers; + List.iter (fun (v, arg) -> typ_ret := typ_subst v arg !typ_ret) unifiers; + let remaining_typs = + List.map (fun typ -> List.fold_left (fun typ (v, arg) -> typ_subst v arg typ) typ unifiers) remaining_typs + in + inferred_arg, remaining_typs, env + in + let fold_instantiate (xs, args, env) x = + match args with + | arg :: remaining_args -> + let x, remaining_args, env = instantiate env x arg remaining_args in + (x :: xs, remaining_args, env) + | [] -> raise (Reporting.err_unreachable l __POS__ "Empty arguments during instantiation") + in + let xs, _, env = List.fold_left fold_instantiate ([], typ_args, env) xs in + let xs = List.rev xs in + + if not (List.for_all (solve_quant env) !quants) then + typ_raise l (Err_unresolved_quants (f, !quants, Env.get_locals env, Env.get_constraints env)) + else (); let ty_vars = List.map fst (KBindings.bindings (Env.get_typ_vars env)) in let existentials = List.filter (fun kid -> not (KBindings.mem kid universals)) ty_vars in let num_new_ncs = List.length (Env.get_constraints env) - List.length universal_constraints in - let ex_constraints = take num_new_ncs (Env.get_constraints env) in + let ex_constraints = take num_new_ncs (Env.get_constraints env) in typ_debug (lazy ("Existentials: " ^ string_of_list ", " string_of_kid existentials)); typ_debug (lazy ("Existential constraints: " ^ string_of_list ", " string_of_n_constraint ex_constraints)); let typ_ret = - if KidSet.is_empty (KidSet.of_list existentials) || KidSet.is_empty (typ_frees typ_ret) - then (typ_debug (lazy "Returning Existential"); typ_ret) - else mk_typ (Typ_exist (existentials, List.fold_left nc_and nc_true ex_constraints, typ_ret)) + if KidSet.is_empty (KidSet.of_list existentials) || KidSet.is_empty (typ_frees !typ_ret) + then !typ_ret + else mk_typ (Typ_exist (existentials, List.fold_left nc_and nc_true ex_constraints, !typ_ret)) in let typ_ret = simp_typ typ_ret in - let exp = annot_exp (E_app (f, xs_reordered)) typ_ret eff in - typ_debug (lazy ("RETURNING: " ^ string_of_typ (typ_of exp))); - match ret_ctx_typ with - | None -> - exp, !all_unifiers - | Some rct -> - let exp = type_coercion env exp rct in - typ_debug (lazy ("RETURNING AFTER COERCION " ^ string_of_typ (typ_of exp))); - exp, !all_unifiers + let exp = annot_exp (E_app (f, xs)) typ_ret eff in + typ_debug (lazy ("RETURNING: " ^ string_of_exp exp)); + + exp, KBindings.empty (* FIXME *) and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, ())) as mpat) (Typ_aux (typ_aux, _) as typ) = let (Typ_aux (typ_aux, _) as typ), env = bind_existential l typ env in @@ -3589,10 +3365,9 @@ and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, ())) as mpat) ( begin try typ_debug (lazy ("Unifying " ^ string_of_bind (typq, ctor_typ) ^ " for mapping-pattern " ^ string_of_typ typ)); - let unifiers, _, _ (* FIXME! *) = unify l env ret_typ typ in - typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers))); + let unifiers = unify l env ret_typ typ (tyvars_of_typ ret_typ) in let arg_typ' = subst_unifiers unifiers arg_typ in - let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in + let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in if (match quants' with [] -> false | _ -> true) then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat) else (); @@ -3620,10 +3395,9 @@ and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, ())) as mpat) ( begin try typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for mapping-pattern " ^ string_of_typ typ)); - let unifiers, _, _ (* FIXME! *) = unify l env typ2 typ in - typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers))); + let unifiers = unify l env typ2 typ (tyvars_of_typ typ2) in let arg_typ' = subst_unifiers unifiers typ1 in - let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in + let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in if (match quants' with [] -> false | _ -> true) then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat) else (); @@ -3638,10 +3412,9 @@ and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, ())) as mpat) ( try typ_debug (lazy "Unifying mapping forwards failed, trying backwards."); typ_debug (lazy ("Unifying " ^ string_of_bind (typq, mapping_typ) ^ " for mapping-pattern " ^ string_of_typ typ)); - let unifiers, _, _ (* FIXME! *) = unify l env typ1 typ in - typ_debug (lazy (string_of_list ", " (fun (kid, uvar) -> string_of_kid kid ^ " => " ^ string_of_uvar uvar) (KBindings.bindings unifiers))); + let unifiers = unify l env typ1 typ (tyvars_of_typ typ1) in let arg_typ' = subst_unifiers unifiers typ2 in - let quants' = List.fold_left (fun qs (kid, uvar) -> instantiate_quants qs kid uvar) quants (KBindings.bindings unifiers) in + let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in if (match quants' with [] -> false | _ -> true) then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in mapping-pattern " ^ string_of_mpat mpat) else (); @@ -4420,30 +4193,33 @@ let check_type_union env variant typq (Tu_aux (tu, l)) = |> Env.add_val_spec v (typq, typ') (* FIXME: This code is duplicated with general kind-checking code in environment, can they be merged? *) -let mk_synonym typq typ = +let mk_synonym typq typ_arg = let kopts, ncs = quant_split typq in let rec subst_args kopts args = match kopts, args with | kopt :: kopts, Typ_arg_aux (Typ_arg_nexp arg, _) :: args when is_nat_kopt kopt -> - let typ, ncs = subst_args kopts args in - typ_subst_nexp (kopt_kid kopt) (unaux_nexp arg) typ, - List.map (nc_subst_nexp (kopt_kid kopt) (unaux_nexp arg)) ncs + let typ_arg, ncs = subst_args kopts args in + typ_arg_subst (kopt_kid kopt) (arg_nexp arg) typ_arg, + List.map (constraint_subst (kopt_kid kopt) (arg_nexp arg)) ncs | kopt :: kopts, Typ_arg_aux (Typ_arg_typ arg, _) :: args when is_typ_kopt kopt -> - let typ, ncs = subst_args kopts args in - typ_subst_typ (kopt_kid kopt) (unaux_typ arg) typ, ncs + let typ_arg, ncs = subst_args kopts args in + typ_arg_subst (kopt_kid kopt) (arg_typ arg) typ_arg, ncs | kopt :: kopts, Typ_arg_aux (Typ_arg_order arg, _) :: args when is_order_kopt kopt -> - let typ, ncs = subst_args kopts args in - typ_subst_order (kopt_kid kopt) (unaux_order arg) typ, ncs - | [], [] -> typ, ncs + let typ_arg, ncs = subst_args kopts args in + typ_arg_subst (kopt_kid kopt) (arg_order arg) typ_arg, ncs + | kopt :: kopts, Typ_arg_aux (Typ_arg_bool arg, _) :: args when is_order_kopt kopt -> + let typ_arg, ncs = subst_args kopts args in + typ_arg_subst (kopt_kid kopt) (arg_bool arg) typ_arg, ncs + | [], [] -> typ_arg, ncs | _, Typ_arg_aux (_, l) :: _ -> typ_error l "Synonym applied to bad arguments" | _, _ -> typ_error Parse_ast.Unknown "Synonym applied to bad arguments" in fun env args -> - let typ, ncs = subst_args kopts args in + let typ_arg, ncs = subst_args kopts args in if List.for_all (prove env) ncs - then typ + then typ_arg else typ_error Parse_ast.Unknown ("Could not prove constraints " ^ string_of_list ", " string_of_n_constraint ncs - ^ " in type synonym " ^ string_of_typ typ + ^ " in type synonym " ^ string_of_typ_arg typ_arg ^ " with " ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env)) let check_kinddef env (KD_aux (kdef, (l, _))) = @@ -4458,8 +4234,11 @@ let rec check_typedef : 'a. Env.t -> 'a type_def -> (tannot def) list * Env.t = fun env (TD_aux (tdef, (l, _))) -> let td_err () = raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "Unimplemented Typedef") in match tdef with - | TD_abbrev (id, nmscm, (TypSchm_aux (TypSchm_ts (typq, typ), _))) -> - [DEF_type (TD_aux (tdef, (l, None)))], Env.add_typ_synonym id (mk_synonym typq typ) env + | TD_abbrev (id, typq, (Typ_arg_aux (Typ_arg_typ _, _) as typ_arg)) -> + [DEF_type (TD_aux (tdef, (l, None)))], Env.add_typ_synonym id (mk_synonym typq typ_arg) env + (* For type synonyms for non-Type kinds we omit them from the AST *) + | TD_abbrev (id, typq, typ_arg) -> + [], Env.add_typ_synonym id (mk_synonym typq typ_arg) env | TD_record (id, nmscm, typq, fields, _) -> [DEF_type (TD_aux (tdef, (l, None)))], Env.add_record id typq fields env | TD_variant (id, nmscm, typq, arms, _) -> diff --git a/src/type_check.mli b/src/type_check.mli index f08272de..7dc2da30 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -149,7 +149,7 @@ module Env : sig won't throw any exceptions. *) val get_ret_typ : t -> typ option - val get_typ_synonym : id -> t -> (t -> typ_arg list -> typ) + val get_typ_synonym : id -> t -> (t -> typ_arg list -> typ_arg) val get_overloads : id -> t -> id list @@ -357,28 +357,21 @@ val destruct_numeric : Env.t -> typ -> (kid list * n_constraint * nexp) option val destruct_vector : Env.t -> typ -> (nexp * order * typ) option -type uvar = - | U_nexp of nexp - | U_order of order - | U_typ of typ +val subst_unifiers : typ_arg KBindings.t -> typ -> typ -val string_of_uvar : uvar -> string - -val subst_unifiers : uvar KBindings.t -> typ -> typ - -val unify : l -> Env.t -> typ -> typ -> uvar KBindings.t * kid list * n_constraint option +val unify : l -> Env.t -> typ -> typ -> typ_arg KBindings.t * kid list * n_constraint option val alpha_equivalent : Env.t -> typ -> typ -> bool (** Throws Invalid_argument if the argument is not a E_app expression *) -val instantiation_of : tannot exp -> uvar KBindings.t +val instantiation_of : tannot exp -> typ_arg KBindings.t (** Doesn't use the type of the expression when calculating instantiations. May fail if the arguments aren't sufficient to calculate all unifiers. *) -val instantiation_of_without_type : tannot exp -> uvar KBindings.t +val instantiation_of_without_type : tannot exp -> typ_arg KBindings.t (* Type variable instantiations that inference will extract from constraints *) -val instantiate_simple_equations : quant_item list -> uvar KBindings.t +val instantiate_simple_equations : quant_item list -> typ_arg KBindings.t val propagate_exp_effect : tannot exp -> tannot exp diff --git a/src/type_error.ml b/src/type_error.ml index 7551970f..0fa238ed 100644 --- a/src/type_error.ml +++ b/src/type_error.ml @@ -198,13 +198,16 @@ let rec analyze_unresolved_quant locals ncs = function empty let rec pp_type_error = function - | Err_no_casts (exp, typ_from, typ_to, trigger, _) -> + | Err_no_casts (exp, typ_from, typ_to, trigger, reasons) -> let coercion = group (string "Tried performing type coercion from" ^/^ Pretty_print_sail.doc_typ typ_from ^/^ string "to" ^/^ Pretty_print_sail.doc_typ typ_to ^/^ string "on" ^/^ Pretty_print_sail.doc_exp exp) in - coercion ^^ hardline ^^ (string "Failed because" ^/^ pp_type_error trigger) + coercion ^^ hardline + ^^ (string "Coercion failed because:" ^//^ pp_type_error trigger) + ^^ hardline + ^^ (string "Possible reasons:" ^//^ separate_map hardline pp_type_error reasons) | Err_no_overloading (id, errs) -> string ("No overloadings for " ^ string_of_id id ^ ", tried:") ^//^ -- cgit v1.2.3