summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast_util.ml3
-rw-r--r--src/ast_util.mli1
-rw-r--r--src/monomorphise.ml80
-rw-r--r--src/pretty_print_lem.ml223
-rw-r--r--src/rewrites.ml74
-rw-r--r--src/spec_analysis.ml387
-rw-r--r--src/type_check.ml160
-rw-r--r--src/type_check.mli6
8 files changed, 455 insertions, 479 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 074376a9..a70db3e0 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -428,6 +428,9 @@ let id_loc = function
let kid_loc = function
| Kid_aux (_, l) -> l
+let pat_loc = function
+ | P_aux (_, (l, _)) -> l
+
let def_loc = function
| DEF_kind (KD_aux (_, (l, _)))
| DEF_type (TD_aux (_, (l, _)))
diff --git a/src/ast_util.mli b/src/ast_util.mli
index ec12d44b..69d80ea7 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -164,6 +164,7 @@ val map_letbind_annot : ('a annot -> 'b annot) -> 'a letbind -> 'b letbind
val id_loc : id -> Parse_ast.l
val kid_loc : kid -> Parse_ast.l
+val pat_loc : 'a pat -> Parse_ast.l
val def_loc : 'a def -> Parse_ast.l
(* For debugging and error messages only: Not guaranteed to produce
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 4fa5d1d6..3ea156d0 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -54,8 +54,7 @@ open Ast_util
module Big_int = Nat_big_num
open Type_check
-let size_set_limit = 8
-let vector_split_limit = 4
+let size_set_limit = 16
let optmap v f =
match v with
@@ -1316,14 +1315,10 @@ let split_defs continue_anyway splits defs =
| Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp len,_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
(match len with
| Nexp_aux (Nexp_constant sz,_) ->
- if Big_int.to_int sz <= vector_split_limit then
- let lits = make_vectors (Big_int.to_int sz) in
- List.map (fun lit ->
- P_aux (P_lit lit,(l,annot)),
- [var,E_aux (E_lit lit,(new_l,annot))]) lits
- else
- cannot ("Refusing to split vector type of length " ^ Big_int.to_string sz ^
- " (above limit " ^ string_of_int vector_split_limit ^ ")")
+ let lits = make_vectors (Big_int.to_int sz) in
+ List.map (fun lit ->
+ P_aux (P_lit lit,(l,annot)),
+ [var,E_aux (E_lit lit,(new_l,annot))]) lits
| _ ->
cannot ("length not constant, " ^ string_of_nexp len)
)
@@ -1494,6 +1489,19 @@ let split_defs continue_anyway splits defs =
in p
in
+ let check_split_size lst l =
+ let size = List.length lst in
+ if size > size_set_limit then
+ let open Reporting_basic in
+ let error =
+ Err_general (l, "Case split is too large (" ^ string_of_int size ^
+ " > limit " ^ string_of_int size_set_limit ^ ")")
+ in if continue_anyway
+ then (print_error error; false)
+ else raise (Fatal_error error)
+ else true
+ in
+
let rec map_exp ((E_aux (e,annot)) as ea) =
let re e = E_aux (e,annot) in
match e with
@@ -1556,13 +1564,16 @@ let split_defs continue_anyway splits defs =
FE_aux (FE_Fexp (id,map_exp e),annot)
and map_pexp = function
| Pat_aux (Pat_exp (p,e),l) ->
+ let nosplit = [Pat_aux (Pat_exp (p,map_exp e),l)] in
(match map_pat p with
- | NoSplit -> [Pat_aux (Pat_exp (p,map_exp e),l)]
+ | NoSplit -> nosplit
| VarSplit patsubsts ->
- List.map (fun (pat',substs) ->
- let exp' = subst_exp substs e in
- Pat_aux (Pat_exp (pat', map_exp exp'),l))
- patsubsts
+ if check_split_size patsubsts (pat_loc p) then
+ List.map (fun (pat',substs) ->
+ let exp' = subst_exp substs e in
+ Pat_aux (Pat_exp (pat', map_exp exp'),l))
+ patsubsts
+ else nosplit
| ConstrSplit patnsubsts ->
List.map (fun (pat',nsubst) ->
let pat' = nexp_subst_pat nsubst pat' in
@@ -1570,14 +1581,17 @@ let split_defs continue_anyway splits defs =
Pat_aux (Pat_exp (pat', map_exp exp'),l)
) patnsubsts)
| Pat_aux (Pat_when (p,e1,e2),l) ->
+ let nosplit = [Pat_aux (Pat_when (p,map_exp e1,map_exp e2),l)] in
(match map_pat p with
- | NoSplit -> [Pat_aux (Pat_when (p,map_exp e1,map_exp e2),l)]
+ | NoSplit -> nosplit
| VarSplit patsubsts ->
- List.map (fun (pat',substs) ->
- let exp1' = subst_exp substs e1 in
- let exp2' = subst_exp substs e2 in
- Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l))
- patsubsts
+ if check_split_size patsubsts (pat_loc p) then
+ List.map (fun (pat',substs) ->
+ let exp1' = subst_exp substs e1 in
+ let exp2' = subst_exp substs e2 in
+ Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l))
+ patsubsts
+ else nosplit
| ConstrSplit patnsubsts ->
List.map (fun (pat',nsubst) ->
let pat' = nexp_subst_pat nsubst pat' in
@@ -1770,6 +1784,10 @@ let rewrite_size_parameters env (Defs defs) =
{ (compute_exp_alg KidSet.empty KidSet.union) with
e_aux = (fun ((s,e),annot) -> KidSet.union s (sizes_of_annot annot), E_aux (e,annot));
e_let = (fun ((sl,lb),(s2,e2)) -> KidSet.union sl (KidSet.diff s2 (tyvars_bound_in_lb lb)), E_let (lb,e2));
+ e_for = (fun (id,(s1,e1),(s2,e2),(s3,e3),ord,(s4,e4)) ->
+ let kid = mk_kid ("loop_" ^ string_of_id id) in
+ KidSet.union s1 (KidSet.union s2 (KidSet.union s3 (KidSet.remove kid s4))),
+ E_for (id,e1,e2,e3,ord,e4));
pat_exp = (fun ((sp,pat),(s,e)) -> KidSet.diff s (tyvars_bound_in_pat pat), Pat_exp (pat,e))}
pexp)
in
@@ -1786,7 +1804,7 @@ let rewrite_size_parameters env (Defs defs) =
| P_aux (P_tup ps,_) -> ps
| _ -> [pat]
in
- let to_change = List.map
+ let to_change = Util.map_filter
(fun kid ->
let check (P_aux (_,(_,Some (env,typ,_)))) =
match Env.expand_synonyms env typ with
@@ -1799,9 +1817,10 @@ let rewrite_size_parameters env (Defs defs) =
if Kid.compare kid kid' = 0 then Some kid else None
| _ -> None
in match findi check parameters with
- | None -> raise (Reporting_basic.err_general l
- ("Unable to find an argument for " ^ string_of_kid kid))
- | Some i -> i)
+ | None -> (Reporting_basic.print_error (Reporting_basic.Err_general (l,
+ ("Unable to find an argument for " ^ string_of_kid kid)));
+ None)
+ | Some i -> Some i)
(KidSet.elements expose_tyvars)
in
let ik_compare (i,k) (i',k') =
@@ -1858,7 +1877,7 @@ let rewrite_size_parameters env (Defs defs) =
let body = List.fold_left add_var_rebind body vars in
let guard = match guard with
| None -> None
- | Some exp -> Some (List.fold_left add_var_rebind body vars)
+ | Some exp -> Some (List.fold_left add_var_rebind exp vars)
in
pat,guard,body
| exception Not_found -> pat,guard,body
@@ -2752,6 +2771,15 @@ let rewrite_app env typ (id,args) =
| _ -> E_app (id,args)
+ else if is_id env (Id "UInt") id then
+ let is_slice = is_id env (Id "slice") in
+ match args with
+ | [E_aux (E_app (slice1, [vector1; start1; length1]),_)]
+ when is_slice slice1 && not (is_constant length1) ->
+ E_app (mk_id "UInt_slice", [vector1; start1; length1])
+
+ | _ -> E_app (id,args)
+
else E_app (id,args)
let rewrite_aux = function
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 0002f8cc..c776081c 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -62,6 +62,12 @@ open Pretty_print_common
let opt_sequential = ref false
let opt_mwords = ref false
+type context = {
+ early_ret : bool;
+ bound_nexps : NexpSet.t;
+}
+let empty_ctxt = { early_ret = false; bound_nexps = NexpSet.empty }
+
let print_to_from_interp_value = ref false
let langlebar = string "<|"
let ranglebar = string "|>"
@@ -326,12 +332,12 @@ let doc_typ_lem, doc_atomic_typ_lem =
length argument are checked for variables, and the latter only if it is
a bitvector; for other types of vectors, the length is not pretty-printed
in the type, and the start index is never pretty-printed in vector types. *)
-let rec contains_t_pp_var (Typ_aux (t,a) as typ) = match t with
+let rec contains_t_pp_var ctxt (Typ_aux (t,a) as typ) = match t with
| Typ_id _ -> false
| Typ_var _ -> true
| Typ_exist _ -> true
- | Typ_fn (t1,t2,_) -> contains_t_pp_var t1 || contains_t_pp_var t2
- | Typ_tup ts -> List.exists contains_t_pp_var ts
+ | Typ_fn (t1,t2,_) -> contains_t_pp_var ctxt t1 || contains_t_pp_var ctxt t2
+ | Typ_tup ts -> List.exists (contains_t_pp_var ctxt) ts
| Typ_app (c,targs) ->
if Ast_util.is_number typ then false
else if is_bitvector_typ typ then
@@ -340,23 +346,22 @@ let rec contains_t_pp_var (Typ_aux (t,a) as typ) = match t with
not (is_nexp_constant length ||
(!opt_mwords &&
match length with Nexp_aux (Nexp_var _,_) -> true | _ -> false))
- else List.exists contains_t_arg_pp_var targs
-and contains_t_arg_pp_var (Typ_arg_aux (targ, _)) = match targ with
- | Typ_arg_typ t -> contains_t_pp_var t
- | Typ_arg_nexp nexp -> not (is_nexp_constant (nexp_simp nexp))
+ else List.exists (contains_t_arg_pp_var ctxt) targs
+and contains_t_arg_pp_var ctxt (Typ_arg_aux (targ, _)) = match targ with
+ | Typ_arg_typ t -> contains_t_pp_var ctxt t
+ | Typ_arg_nexp nexp ->
+ let nexp = nexp_simp nexp in
+ not (is_nexp_constant nexp || NexpSet.mem nexp ctxt.bound_nexps)
| _ -> false
-let doc_tannot_lem eff typ =
- if contains_t_pp_var typ then empty
+let doc_tannot_lem ctxt eff typ =
+ if contains_t_pp_var ctxt typ then empty
else
let ta = doc_typ_lem typ in
if eff then string " : M " ^^ parens ta
else string " : " ^^ ta
-(* doc_lit_lem gets as an additional parameter the type information from the
- * expression around it: that's a hack, but how else can we distinguish between
- * undefined values of different types ? *)
-let doc_lit_lem in_pat (L_aux(lit,l)) a =
+let doc_lit_lem (L_aux(lit,l)) =
match lit with
| L_unit -> utf8string "()"
| L_zero -> utf8string "B0"
@@ -366,24 +371,12 @@ let doc_lit_lem in_pat (L_aux(lit,l)) a =
| L_num i ->
let ipp = Big_int.to_string i in
utf8string (
- if in_pat then "("^ipp^":nn)"
- else if Big_int.less i Big_int.zero then "((0"^ipp^"):ii)"
+ if Big_int.less i Big_int.zero then "((0"^ipp^"):ii)"
else "("^ipp^":ii)")
| L_hex n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0x" ^ n) ^ ")" (*shouldn't happen*)*)
| L_bin n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0b" ^ n) ^ ")" (*shouldn't happen*)*)
| L_undef ->
- (match a with
- | Some (_, (Typ_aux (t,_) as typ), _) ->
- (match t with
- | Typ_id (Id_aux (Id "bit", _))
- | Typ_app (Id_aux (Id "register", _),_) -> utf8string "UndefinedRegister 0"
- | Typ_id (Id_aux (Id "string", _)) -> utf8string "\"\""
- | _ ->
- let ta = if contains_t_pp_var typ then empty
- else doc_tannot_lem false typ in
- parens
- ((utf8string "(failwith \"undefined value of unsupported type\")") ^^ ta))
- | _ -> utf8string "(failwith \"undefined value of unsupported type\")")
+ utf8string "(failwith \"undefined value of unsupported type\")"
| L_string s -> utf8string ("\"" ^ s ^ "\"")
| L_real s ->
(* Lem does not support decimal syntax, so we translate a string
@@ -427,29 +420,30 @@ let doc_typquant_lem (TypQ_aux(tq,_)) vars_included typ = match tq with
machine words. Often these will be unnecessary, but this simple
approach will do for now. *)
-let rec typeclass_nexps (Typ_aux(t,_)) = match t with
-| Typ_id _
-| Typ_var _
- -> NexpSet.empty
-| Typ_fn (t1,t2,_) -> NexpSet.union (typeclass_nexps t1) (typeclass_nexps t2)
-| Typ_tup ts -> List.fold_left NexpSet.union NexpSet.empty (List.map typeclass_nexps ts)
-| Typ_app (Id_aux (Id "vector",_),
- [_;Typ_arg_aux (Typ_arg_nexp size_nexp,_);
- _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)])
-| Typ_app (Id_aux (Id "itself",_),
- [Typ_arg_aux (Typ_arg_nexp size_nexp,_)]) ->
- let size_nexp = nexp_simp size_nexp in
- if is_nexp_constant size_nexp then NexpSet.empty else
- NexpSet.singleton (orig_nexp size_nexp)
-| Typ_app _ -> NexpSet.empty
-| Typ_exist (kids,_,t) -> NexpSet.empty (* todo *)
+let rec typeclass_nexps (Typ_aux(t,_)) =
+ if !opt_mwords then
+ match t with
+ | Typ_id _
+ | Typ_var _
+ -> NexpSet.empty
+ | Typ_fn (t1,t2,_) -> NexpSet.union (typeclass_nexps t1) (typeclass_nexps t2)
+ | Typ_tup ts -> List.fold_left NexpSet.union NexpSet.empty (List.map typeclass_nexps ts)
+ | Typ_app (Id_aux (Id "vector",_),
+ [_;Typ_arg_aux (Typ_arg_nexp size_nexp,_);
+ _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)])
+ | Typ_app (Id_aux (Id "itself",_),
+ [Typ_arg_aux (Typ_arg_nexp size_nexp,_)]) ->
+ let size_nexp = nexp_simp size_nexp in
+ if is_nexp_constant size_nexp then NexpSet.empty else
+ NexpSet.singleton (orig_nexp size_nexp)
+ | Typ_app _ -> NexpSet.empty
+ | Typ_exist (kids,_,t) -> NexpSet.empty (* todo *)
+ else NexpSet.empty
let doc_typclasses_lem t =
- if !opt_mwords then
- let nexps = typeclass_nexps t in
- if NexpSet.is_empty nexps then (empty, NexpSet.empty) else
- (separate_map comma_sp (fun nexp -> string "Size " ^^ doc_nexp_lem nexp) (NexpSet.elements nexps) ^^ string " => ", nexps)
- else (empty, NexpSet.empty)
+ let nexps = typeclass_nexps t in
+ if NexpSet.is_empty nexps then (empty, NexpSet.empty) else
+ (separate_map comma_sp (fun nexp -> string "Size " ^^ doc_nexp_lem nexp) (NexpSet.elements nexps) ^^ string " => ", nexps)
let doc_typschm_lem quants (TypSchm_aux(TypSchm_ts(tq,t),_)) =
let pt = doc_typ_lem t in
@@ -470,44 +464,44 @@ let is_ctor env id = match Env.lookup_id id env with
(*Note: vector concatenation, literal vectors, indexed vectors, and record should
be removed prior to pp. The latter two have never yet been seen
*)
-let rec doc_pat_lem apat_needed (P_aux (p,(l,annot)) as pa) = match p with
+let rec doc_pat_lem ctxt apat_needed (P_aux (p,(l,annot)) as pa) = match p with
| P_app(id, ((_ :: _) as pats)) ->
let ppp = doc_unop (doc_id_lem_ctor id)
- (parens (separate_map comma (doc_pat_lem true) pats)) in
+ (parens (separate_map comma (doc_pat_lem ctxt true) pats)) in
if apat_needed then parens ppp else ppp
| P_app(id,[]) -> doc_id_lem_ctor id
- | P_lit lit -> doc_lit_lem true lit annot
+ | P_lit lit -> doc_lit_lem lit
| P_wild -> underscore
| P_id id ->
begin match id with
| Id_aux (Id "None",_) -> string "Nothing" (* workaround temporary issue *)
| _ -> doc_id_lem id end
- | P_var(p,kid) -> doc_pat_lem true p
- | P_as(p,id) -> parens (separate space [doc_pat_lem true p; string "as"; doc_id_lem id])
+ | P_var(p,kid) -> doc_pat_lem ctxt true p
+ | P_as(p,id) -> parens (separate space [doc_pat_lem ctxt true p; string "as"; doc_id_lem id])
| P_typ(Typ_aux (Typ_tup typs, _), P_aux (P_tup pats, _)) ->
(* Isabelle does not seem to like type-annotated tuple patterns;
it gives a syntax error. Avoid this by annotating the tuple elements instead *)
let doc_elem typ (P_aux (_, annot) as pat) =
- doc_pat_lem true (P_aux (P_typ (typ, pat), annot)) in
+ doc_pat_lem ctxt true (P_aux (P_typ (typ, pat), annot)) in
parens (separate comma_sp (List.map2 doc_elem typs pats))
| P_typ(typ,p) ->
- let doc_p = doc_pat_lem true p in
- if contains_t_pp_var typ then doc_p
+ let doc_p = doc_pat_lem ctxt true p in
+ if contains_t_pp_var ctxt typ then doc_p
else parens (doc_op colon doc_p (doc_typ_lem typ))
| P_vector pats ->
let ppp =
(separate space)
- [string "Vector";brackets (separate_map semi (doc_pat_lem true) pats);underscore;underscore] in
+ [string "Vector";brackets (separate_map semi (doc_pat_lem ctxt true) pats);underscore;underscore] in
if apat_needed then parens ppp else ppp
| P_vector_concat pats ->
raise (Reporting_basic.err_unreachable l
"vector concatenation patterns should have been removed before pretty-printing")
| P_tup pats ->
(match pats with
- | [p] -> doc_pat_lem apat_needed p
- | _ -> parens (separate_map comma_sp (doc_pat_lem false) pats))
- | P_list pats -> brackets (separate_map semi (doc_pat_lem false) pats) (*Never seen but easy in lem*)
- | P_cons (p,p') -> doc_op (string "::") (doc_pat_lem true p) (doc_pat_lem true p')
+ | [p] -> doc_pat_lem ctxt apat_needed p
+ | _ -> parens (separate_map comma_sp (doc_pat_lem ctxt false) pats))
+ | P_list pats -> brackets (separate_map semi (doc_pat_lem ctxt false) pats) (*Never seen but easy in lem*)
+ | P_cons (p,p') -> doc_op (string "::") (doc_pat_lem ctxt true p) (doc_pat_lem ctxt true p')
| P_record (_,_) -> empty (* TODO *)
let rec typ_needs_printed (Typ_aux (t,_) as typ) = match t with
@@ -539,13 +533,13 @@ let typ_id_of (Typ_aux (typ, l)) = match typ with
let prefix_recordtype = true
let report = Reporting_basic.err_unreachable
let doc_exp_lem, doc_let_lem =
- let rec top_exp (early_ret : bool) (aexp_needed : bool)
+ let rec top_exp (ctxt : context) (aexp_needed : bool)
(E_aux (e, (l,annot)) as full_exp) =
- let expY = top_exp early_ret true in
- let expN = top_exp early_ret false in
- let expV = top_exp early_ret in
+ let expY = top_exp ctxt true in
+ let expN = top_exp ctxt false in
+ let expV = top_exp ctxt in
let liftR doc =
- if early_ret && effectful (effect_of full_exp)
+ if ctxt.early_ret && effectful (effect_of full_exp)
then separate space [string "liftR"; parens (doc)]
else doc in
match e with
@@ -565,10 +559,10 @@ let doc_exp_lem, doc_let_lem =
doc_id_lem id in
liftR ((prefix 2 1)
(string "write_reg_field_range")
- (align (doc_lexp_deref_lem early_ret le ^/^
+ (align (doc_lexp_deref_lem ctxt le ^/^
field_ref ^/^ expY e2 ^/^ expY e3 ^/^ expY e)))
| _ ->
- let deref = doc_lexp_deref_lem early_ret le in
+ let deref = doc_lexp_deref_lem ctxt le in
liftR ((prefix 2 1)
(string "write_reg_range")
(align (deref ^/^ expY e2 ^/^ expY e3) ^/^ expY e)))
@@ -585,10 +579,10 @@ let doc_exp_lem, doc_let_lem =
let call = if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot fannot)) then "write_reg_field_bit" else "write_reg_field_pos" in
liftR ((prefix 2 1)
(string call)
- (align (doc_lexp_deref_lem early_ret le ^/^
+ (align (doc_lexp_deref_lem ctxt le ^/^
field_ref ^/^ expY e2 ^/^ expY e)))
| LEXP_aux (_, lannot) ->
- let deref = doc_lexp_deref_lem early_ret le in
+ let deref = doc_lexp_deref_lem ctxt le in
let call = if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot lannot)) then "write_reg_bit" else "write_reg_pos" in
liftR ((prefix 2 1) (string call)
(deref ^/^ expY e2 ^/^ expY e))
@@ -602,10 +596,10 @@ let doc_exp_lem, doc_let_lem =
string "set_field"*) in
liftR ((prefix 2 1)
(string "write_reg_field")
- (doc_lexp_deref_lem early_ret le ^^ space ^^
+ (doc_lexp_deref_lem ctxt le ^^ space ^^
field_ref ^/^ expY e))
| _ ->
- liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref_lem early_ret le ^/^ expY e)))
+ liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref_lem ctxt le ^/^ expY e)))
| E_vector_append(le,re) ->
raise (Reporting_basic.err_unreachable l
"E_vector_append should have been rewritten before pretty-printing")
@@ -621,7 +615,7 @@ let doc_exp_lem, doc_let_lem =
| E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) ->
raise (report l "E_for should have been removed till now")
| E_let(leb,e) ->
- let epp = let_exp early_ret leb ^^ space ^^ string "in" ^^ hardline ^^ expN e in
+ let epp = let_exp ctxt leb ^^ space ^^ string "in" ^^ hardline ^^ expN e in
if aexp_needed then parens epp else epp
| E_app(f,args) ->
begin match f with
@@ -676,8 +670,8 @@ let doc_exp_lem, doc_let_lem =
| [exp] ->
let epp = separate space [string "early_return"; expY exp] in
let aexp_needed, tepp =
- if contains_t_pp_var (typ_of exp) ||
- contains_t_pp_var (typ_of full_exp) then
+ if contains_t_pp_var ctxt (typ_of exp) ||
+ contains_t_pp_var ctxt (typ_of full_exp) then
aexp_needed, epp
else
let tannot = separate space [string "MR";
@@ -716,7 +710,7 @@ let doc_exp_lem, doc_let_lem =
let t = (*Env.base_typ_of (env_of full_exp)*) (typ_of full_exp) in
let eff = effect_of full_exp in
if typ_needs_printed (Env.base_typ_of (env_of full_exp) t)
- then (align epp ^^ (doc_tannot_lem (effectful eff) t), true)
+ then (align epp ^^ (doc_tannot_lem ctxt (effectful eff) t), true)
else (epp, aexp_needed) in
liftR (if aexp_needed then parens (align taepp) else taepp)
end
@@ -749,11 +743,11 @@ let doc_exp_lem, doc_let_lem =
if has_effect eff BE_rreg then
let epp = separate space [string "read_reg";doc_id_lem id] in
if is_bitvector_typ base_typ
- then liftR (parens (epp ^^ doc_tannot_lem true base_typ))
+ then liftR (parens (epp ^^ doc_tannot_lem ctxt true base_typ))
else liftR epp
else if is_ctor env id then doc_id_lem_ctor id
else doc_id_lem id
- | E_lit lit -> doc_lit_lem false lit annot
+ | E_lit lit -> doc_lit_lem lit
| E_cast(typ,e) ->
expV aexp_needed e
| E_tuple exps ->
@@ -767,7 +761,7 @@ let doc_exp_lem, doc_let_lem =
| _ -> raise (report l ("cannot get record type from annot " ^ string_of_annot annot ^ " of exp " ^ string_of_exp full_exp)) in
let epp = anglebars (space ^^ (align (separate_map
(semi_sp ^^ break 1)
- (doc_fexp early_ret recordtyp) fexps)) ^^ space) in
+ (doc_fexp ctxt recordtyp) fexps)) ^^ space) in
if aexp_needed then parens epp else epp
| E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) ->
let recordtyp = match annot with
@@ -776,7 +770,7 @@ let doc_exp_lem, doc_let_lem =
when Env.is_record tid env ->
tid
| _ -> raise (report l ("cannot get record type from annot " ^ string_of_annot annot ^ " of exp " ^ string_of_exp full_exp)) in
- anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp early_ret recordtyp) fexps))
+ anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps))
| E_vector exps ->
let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in
let (start, len, order, etyp) =
@@ -804,7 +798,7 @@ let doc_exp_lem, doc_let_lem =
let (epp,aexp_needed) =
if is_bit_typ etyp && !opt_mwords then
let bepp = string "vec_to_bvec" ^^ space ^^ parens (align epp) in
- (bepp ^^ doc_tannot_lem false t, true)
+ (bepp ^^ doc_tannot_lem ctxt false t, true)
else (epp,aexp_needed) in
if aexp_needed then parens (align epp) else epp
| E_vector_update(v,e1,e2) ->
@@ -835,15 +829,15 @@ let doc_exp_lem, doc_let_lem =
let epp =
group ((separate space [string "match"; only_integers e; string "with"]) ^/^
- (separate_map (break 1) (doc_case early_ret) pexps) ^/^
+ (separate_map (break 1) (doc_case ctxt) pexps) ^/^
(string "end")) in
if aexp_needed then parens (align epp) else align epp
| E_try (e, pexps) ->
if effectful (effect_of e) then
- let try_catch = if early_ret then "try_catchR" else "try_catch" in
+ let try_catch = if ctxt.early_ret then "try_catchR" else "try_catch" in
let epp =
group ((separate space [string try_catch; expY e; string "(function "]) ^/^
- (separate_map (break 1) (doc_case early_ret) pexps) ^/^
+ (separate_map (break 1) (doc_case ctxt) pexps) ^/^
(string "end)")) in
if aexp_needed then parens (align epp) else align epp
else
@@ -868,20 +862,20 @@ let doc_exp_lem, doc_let_lem =
(separate space [expV b e1; string ">>"]) ^^ hardline ^^ expN e2
| _ ->
(separate space [expV b e1; string ">>= fun";
- doc_pat_lem true pat;arrow]) ^^ hardline ^^ expN e2 in
+ doc_pat_lem ctxt true pat;arrow]) ^^ hardline ^^ expN e2 in
if aexp_needed then parens (align epp) else epp
| E_internal_return (e1) ->
separate space [string "return"; expY e1]
| E_sizeof nexp ->
(match nexp_simp nexp with
- | Nexp_aux (Nexp_constant i, _) -> doc_lit_lem false (L_aux (L_num i, l)) annot
+ | Nexp_aux (Nexp_constant i, _) -> doc_lit_lem (L_aux (L_num i, l))
| _ ->
raise (Reporting_basic.err_unreachable l
"pretty-printing non-constant sizeof expressions to Lem not supported"))
| E_return r ->
let ret_monad = if !opt_sequential then " : MR regstate" else " : MR" in
let ta =
- if contains_t_pp_var (typ_of full_exp) || contains_t_pp_var (typ_of r)
+ if contains_t_pp_var ctxt (typ_of full_exp) || contains_t_pp_var ctxt (typ_of r)
then empty
else separate space
[string ret_monad;
@@ -893,33 +887,33 @@ let doc_exp_lem, doc_let_lem =
| E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _ | E_internal_exp_user _ ->
raise (Reporting_basic.err_unreachable l
"unsupported internal expression encountered while pretty-printing")
- and let_exp early_ret (LB_aux(lb,_)) = match lb with
+ and let_exp ctxt (LB_aux(lb,_)) = match lb with
| LB_val(pat,e) ->
prefix 2 1
- (separate space [string "let"; doc_pat_lem true pat; equals])
- (top_exp early_ret false e)
+ (separate space [string "let"; doc_pat_lem ctxt true pat; equals])
+ (top_exp ctxt false e)
- and doc_fexp early_ret recordtyp (FE_aux(FE_Fexp(id,e),_)) =
+ and doc_fexp ctxt recordtyp (FE_aux(FE_Fexp(id,e),_)) =
let fname =
if prefix_recordtype
then (string (string_of_id recordtyp ^ "_")) ^^ doc_id_lem id
else doc_id_lem id in
- group (doc_op equals fname (top_exp early_ret true e))
+ group (doc_op equals fname (top_exp ctxt true e))
- and doc_case early_ret = function
+ and doc_case ctxt = function
| Pat_aux(Pat_exp(pat,e),_) ->
- group (prefix 3 1 (separate space [pipe; doc_pat_lem false pat;arrow])
- (group (top_exp early_ret false e)))
+ group (prefix 3 1 (separate space [pipe; doc_pat_lem ctxt false pat;arrow])
+ (group (top_exp ctxt false e)))
| Pat_aux(Pat_when(_,_,_),(l,_)) ->
raise (Reporting_basic.err_unreachable l
"guarded pattern expression should have been rewritten before pretty-printing")
- and doc_lexp_deref_lem early_ret ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with
+ and doc_lexp_deref_lem ctxt ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with
| LEXP_field (le,id) ->
- parens (separate empty [doc_lexp_deref_lem early_ret le;dot;doc_id_lem id])
+ parens (separate empty [doc_lexp_deref_lem ctxt le;dot;doc_id_lem id])
| LEXP_id id -> doc_id_lem id
| LEXP_cast (typ,id) -> doc_id_lem id
- | LEXP_tup lexps -> parens (separate_map comma_sp (doc_lexp_deref_lem early_ret) lexps)
+ | LEXP_tup lexps -> parens (separate_map comma_sp (doc_lexp_deref_lem ctxt) lexps)
| _ ->
raise (Reporting_basic.err_unreachable l ("doc_lexp_deref_lem: Unsupported lexp"))
(* expose doc_exp_lem and doc_let *)
@@ -963,7 +957,7 @@ let doc_typdef_lem (TD_aux(td, (l, annot))) = match td with
mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown),
[mk_typ_arg (Typ_arg_typ rectyp);
mk_typ_arg (Typ_arg_typ ftyp)])) in
- let rfannot = doc_tannot_lem false reftyp in
+ let rfannot = doc_tannot_lem empty_ctxt false reftyp in
let get, set =
string "rec_val" ^^ dot ^^ fname fid,
anglebars (space ^^ string "rec_val with " ^^
@@ -1215,17 +1209,20 @@ let doc_rec_lem (Rec_aux(r,_)) = match r with
let doc_tannot_opt_lem (Typ_annot_opt_aux(t,_)) = match t with
| Typ_annot_opt_some(tq,typ) -> (*doc_typquant_lem tq*) (doc_typ_lem typ)
-let doc_fun_body_lem exp =
- let early_ret =contains_early_return exp in
- let doc_exp = doc_exp_lem early_ret false exp in
- if early_ret
+let doc_fun_body_lem ctxt exp =
+ let doc_exp = doc_exp_lem ctxt false exp in
+ if ctxt.early_ret
then align (string "catch_early_return" ^//^ parens (doc_exp))
else doc_exp
-let doc_funcl_lem (FCL_aux(FCL_Funcl(id,pexp),_)) =
+let doc_funcl_lem (FCL_aux(FCL_Funcl(id, pexp), annot)) =
+ let typ = typ_of_annot annot in
let pat,guard,exp,(l,_) = destruct_pexp pexp in
+ let ctxt =
+ { early_ret = contains_early_return exp;
+ bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ) } in
let pats, bind = untuple_args_pat pat in
- let patspp = separate_map space (doc_pat_lem true) pats in
+ let patspp = separate_map space (doc_pat_lem ctxt true) pats in
let _ = match guard with
| None -> ()
| _ ->
@@ -1233,7 +1230,7 @@ let doc_funcl_lem (FCL_aux(FCL_Funcl(id,pexp),_)) =
"guarded pattern expression should have been rewritten before pretty-printing") in
group (prefix 3 1
(separate space [doc_id_lem id; patspp; equals])
- (doc_fun_body_lem (bind exp)))
+ (doc_fun_body_lem ctxt (bind exp)))
let get_id = function
| [] -> failwith "FD_function with empty list"
@@ -1247,8 +1244,8 @@ let doc_fundef_rhs_lem (FD_aux(FD_function(r, typa, efa, funcls),fannot) as fd)
let doc_mutrec_lem = function
| [] -> failwith "DEF_internal_mutrec with empty function list"
| fundefs ->
- string "let rec " ^^
- separate_map (hardline ^^ string "and ") doc_fundef_rhs_lem fundefs
+ string "let rec " ^^
+ separate_map (hardline ^^ string "and ") doc_fundef_rhs_lem fundefs
let rec doc_fundef_lem (FD_aux(FD_function(r, typa, efa, fcls),fannot) as fd) =
match fcls with
@@ -1301,13 +1298,13 @@ let rec doc_fundef_lem (FD_aux(FD_function(r, typa, efa, fcls),fannot) as fd) =
let named_args = if argspat = [] then [unit_pat] else named_argspat in
let doc_arg idx (P_aux (p,(l,a))) = match p with
| P_as (pat,id) -> doc_id_lem id
- | P_lit lit -> doc_lit_lem false lit a
+ | P_lit lit -> doc_lit_lem lit
| P_id id -> doc_id_lem id
| _ -> string ("arg" ^ string_of_int idx) in
let clauses =
clauses ^^ (break 1) ^^
(separate space
- [pipe;doc_pat_lem false named_pat;arrow;
+ [pipe;doc_pat_lem empty_ctxt false named_pat;arrow;
string aux_fname;
separate space (List.mapi doc_arg named_args)]) in
(already_used_fnames,auxiliary_functions,clauses)
@@ -1389,7 +1386,7 @@ let doc_regtype_fields (tname, (n1, n2, fields)) =
mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown),
[mk_typ_arg (Typ_arg_typ (mk_id_typ (mk_id tname)));
mk_typ_arg (Typ_arg_typ ftyp)])) in
- let rfannot = doc_tannot_lem false reftyp in
+ let rfannot = doc_tannot_lem empty_ctxt false reftyp in
doc_op equals
(concat [string "let "; parens (concat [string tname; underscore; doc_id_lem fid; rfannot])])
(concat [
@@ -1416,7 +1413,7 @@ let rec doc_def_lem regtypes def =
if is_field_accessor regtypes fdef then (doc_fdef, empty) else (empty, doc_fdef)
| DEF_internal_mutrec fundefs ->
(empty, doc_mutrec_lem fundefs ^/^ hardline)
- | DEF_val lbind -> (empty,group (doc_let_lem false lbind) ^/^ hardline)
+ | DEF_val lbind -> (empty,group (doc_let_lem empty_ctxt lbind) ^/^ hardline)
| DEF_scattered sdef -> failwith "doc_def_lem: shoulnd't have DEF_scattered at this point"
| DEF_kind _ -> (empty,empty)
@@ -1481,7 +1478,7 @@ let doc_regstate_lem registers =
E_record (FES_aux (FES_Fexps (List.map initreg registers, false), annot)),
(l, Some (Env.empty, mk_id_typ (mk_id "regstate"), no_effect)))
in
- doc_op equals (string "let initial_regstate") (doc_exp_lem false false exp)
+ doc_op equals (string "let initial_regstate") (doc_exp_lem empty_ctxt false exp)
else empty
in
doc_typdef_lem (TD_aux (regstate, annot)),
diff --git a/src/rewrites.ml b/src/rewrites.ml
index e4ac71cd..40772828 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -1087,7 +1087,7 @@ let remove_bitvector_pat (P_aux (_, (l, _)) as pat) =
; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot))
; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false))
} in
- let pat, env = bind_pat env
+ let pat, env = bind_pat_no_guard env
(strip_pat ((fold_pat name_bitvector_roots pat) false))
(pat_typ_of pat) in
@@ -1624,6 +1624,10 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base
rewriting of early returns
*)
let rewrite_defs_early_return (Defs defs) =
+ let is_unit (E_aux (exp, _)) = match exp with
+ | E_lit (L_aux (L_unit, _)) -> true
+ | _ -> false in
+
let is_return (E_aux (exp, _)) = match exp with
| E_return _ -> true
| _ -> false in
@@ -1632,7 +1636,35 @@ let rewrite_defs_early_return (Defs defs) =
| E_return e -> e
| _ -> exp in
- let e_block es =
+ let e_if (e1, e2, e3) =
+ if is_return e2 && is_return e3 then
+ let (E_aux (_, annot)) = get_return e2 in
+ E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot))
+ else E_if (e1, e2, e3) in
+
+ let rec e_block es =
+ (* If one of the branches of an if-expression in a block is an early
+ return, fold the rest of the block after the if-expression into the
+ other branch *)
+ let fold_if_return exp block = match exp with
+ | E_aux (E_if (c, t, (E_aux (_, annot) as e)), _) when is_return t ->
+ let annot = match block with
+ | [] -> annot
+ | _ -> let (E_aux (_, annot)) = Util.last block in annot
+ in
+ let block = if is_unit e then block else e :: block in
+ let e' = E_aux (e_block block, annot) in
+ [E_aux (e_if (c, t, e'), annot)]
+ | E_aux (E_if (c, (E_aux (_, annot) as t), e), _) when is_return e ->
+ let annot = match block with
+ | [] -> annot
+ | _ -> let (E_aux (_, annot)) = Util.last block in annot
+ in
+ let block = if is_unit t then block else t :: block in
+ let t' = E_aux (e_block block, annot) in
+ [E_aux (e_if (c, t', e), annot)]
+ | _ -> exp :: block in
+ let es = List.fold_right fold_if_return es [] in
match es with
| [E_aux (e, _)] -> e
| _ :: _ when is_return (Util.last es) ->
@@ -1640,12 +1672,6 @@ let rewrite_defs_early_return (Defs defs) =
E_return (E_aux (E_block (Util.butlast es @ [get_return e]), annot))
| _ -> E_block es in
- let e_if (e1, e2, e3) =
- if is_return e2 && is_return e3 then
- let (E_aux (_, annot)) = get_return e2 in
- E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot))
- else E_if (e1, e2, e3) in
-
let e_case (e, pes) =
let is_return_pexp (Pat_aux (pexp, _)) = match pexp with
| Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in
@@ -1660,6 +1686,17 @@ let rewrite_defs_early_return (Defs defs) =
then E_return (E_aux (E_case (e, List.map get_return_pexp pes), annot))
else E_case (e, pes) in
+ let e_let (lb, exp) =
+ let (E_aux (_, annot) as ret_exp) = get_return exp in
+ if is_return exp then E_return (E_aux (E_let (lb, ret_exp), annot))
+ else E_let (lb, exp) in
+
+ let e_internal_let (lexp, exp1, exp2) =
+ let (E_aux (_, annot) as ret_exp2) = get_return exp2 in
+ if is_return exp2 then
+ E_return (E_aux (E_var (lexp, exp1, ret_exp2), annot))
+ else E_var (lexp, exp1, exp2) in
+
let e_aux (exp, (l, annot)) =
let full_exp = propagate_exp_effect (E_aux (exp, (l, annot))) in
let env = env_of full_exp in
@@ -1674,14 +1711,18 @@ let rewrite_defs_early_return (Defs defs) =
let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pexp), a)) =
let pat,guard,exp,pannot = destruct_pexp pexp in
+ (* Try to pull out early returns as far as possible *)
+ let exp' =
+ fold_exp
+ { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case;
+ e_let = e_let; e_internal_let = e_internal_let }
+ exp in
+ (* Remove early return if we can pull it out completely, and rewrite
+ remaining early returns to "early_return" calls *)
let exp =
- exp
- (* Pull early returns out as far as possible *)
- |> fold_exp { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case }
- (* Remove singleton E_return *)
- |> get_return
- (* Fix effect annotations *)
- |> fold_exp { id_exp_alg with e_aux = e_aux } in
+ fold_exp
+ { id_exp_alg with e_aux = e_aux }
+ (if is_return exp' then get_return exp' else exp) in
let a = match a with
| (l, Some (env, typ, eff)) ->
(l, Some (env, typ, union_effects eff (effect_of exp)))
@@ -2141,7 +2182,7 @@ let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list
| [] -> k []
| exp :: exps -> f exp (fun exp -> mapCont f exps (fun exps -> k (exp :: exps)))
-let rewrite_defs_letbind_effects =
+let rewrite_defs_letbind_effects =
let rec value ((E_aux (exp_aux,_)) as exp) =
not (effectful exp || updates_vars exp)
@@ -2235,6 +2276,7 @@ let rewrite_defs_letbind_effects =
let exp =
if newreturn then
(* let typ = try typ_of exp with _ -> unit_typ in *)
+ let exp = annot_exp (E_cast (typ_of exp, exp)) l (env_of exp) (typ_of exp) in
annot_exp (E_internal_return exp) l (env_of exp) (typ_of exp)
else
exp in
diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml
index 23ce6663..371acfdc 100644
--- a/src/spec_analysis.ml
+++ b/src/spec_analysis.ml
@@ -64,133 +64,6 @@ let set_to_string n =
list_to_string (Nameset.elements n)
-(*Query a spec for its default order if one is provided. Assumes Inc if not *)
-(* let get_default_order_sp (DT_aux(spec,_)) =
- match spec with
- | DT_order (Ord_aux(o,_)) ->
- (match o with
- | Ord_inc -> Some {order = Oinc}
- | Ord_dec -> Some { order = Odec}
- | _ -> Some {order = Oinc})
- | _ -> None
-
-let get_default_order_def = function
- | DEF_default def_spec -> get_default_order_sp def_spec
- | _ -> None
-
-let rec default_order (Defs defs) =
- match defs with
- | [] -> { order = Oinc } (*When no order is specified, we assume that it's inc*)
- | def::defs ->
- match get_default_order_def def with
- | None -> default_order (Defs defs)
- | Some o -> o *)
-
-(*Is within range*)
-
-(* let check_in_range (candidate : big_int) (range : typ) : bool =
- match range.t with
- | Tapp("range", [TA_nexp min; TA_nexp max]) | Tabbrev(_,{t=Tapp("range", [TA_nexp min; TA_nexp max])}) ->
- let min,max =
- match min.nexp,max.nexp with
- | (Nconst min, Nconst max)
- | (Nconst min, N2n(_, Some max))
- | (N2n(_, Some min), Nconst max)
- | (N2n(_, Some min), N2n(_, Some max))
- -> min, max
- | (Nneg n, Nconst max) | (Nneg n, N2n(_, Some max))->
- (match n.nexp with
- | Nconst abs_min | N2n(_,Some abs_min) ->
- (Big_int.negate abs_min), max
- | _ -> assert false (*Put a better error message here*))
- | (Nconst min,Nneg n) | (N2n(_, Some min), Nneg n) ->
- (match n.nexp with
- | Nconst abs_max | N2n(_,Some abs_max) ->
- min, (Big_int.negate abs_max)
- | _ -> assert false (*Put a better error message here*))
- | (Nneg nmin, Nneg nmax) ->
- ((match nmin.nexp with
- | Nconst abs_min | N2n(_,Some abs_min) -> (Big_int.negate abs_min)
- | _ -> assert false (*Put a better error message here*)),
- (match nmax.nexp with
- | Nconst abs_max | N2n(_,Some abs_max) -> (Big_int.negate abs_max)
- | _ -> assert false (*Put a better error message here*)))
- | _ -> assert false
- in Big_int.less_equal min candidate && Big_int.less_equal candidate max
- | _ -> assert false
-
-(*Rmove me when switch to zarith*)
-let rec power_big_int b n =
- if Big_int.equal n Big_int.zero
- then (Big_int.of_int 1)
- else Big_int.mul b (power_big_int b (Big_int.sub n (Big_int.of_int 1)))
-
-let unpower_of_2 b =
- let two = Big_int.of_int 2 in
- let four = Big_int.of_int 4 in
- let eight = Big_int.of_int 8 in
- let sixteen = Big_int.of_int 16 in
- let thirty_two = Big_int.of_int 32 in
- let sixty_four = Big_int.of_int 64 in
- let onetwentyeight = Big_int.of_int 128 in
- let twofiftysix = Big_int.of_int 256 in
- let fivetwelve = Big_int.of_int 512 in
- let oneotwentyfour = Big_int.of_int 1024 in
- let to_the_sixteen = Big_int.of_int 65536 in
- let to_the_thirtytwo = Big_int.of_string "4294967296" in
- let to_the_sixtyfour = Big_int.of_string "18446744073709551616" in
- let ck i = Big_int.equal b i in
- if ck (Big_int.of_int 1) then Big_int.zero
- else if ck two then (Big_int.of_int 1)
- else if ck four then two
- else if ck eight then Big_int.of_int 3
- else if ck sixteen then four
- else if ck thirty_two then Big_int.of_int 5
- else if ck sixty_four then Big_int.of_int 6
- else if ck onetwentyeight then Big_int.of_int 7
- else if ck twofiftysix then eight
- else if ck fivetwelve then Big_int.of_int 9
- else if ck oneotwentyfour then Big_int.of_int 10
- else if ck to_the_sixteen then sixteen
- else if ck to_the_thirtytwo then thirty_two
- else if ck to_the_sixtyfour then sixty_four
- else let rec unpower b power =
- if Big_int.equal b (Big_int.of_int 1)
- then power
- else (unpower (Big_int.div b two) (Big_int.succ power)) in
- unpower b Big_int.zero
-
-let is_within_range candidate range constraints =
- let candidate_actual = match candidate.t with
- | Tabbrev(_,t) -> t
- | _ -> candidate in
- match candidate_actual.t with
- | Tapp("atom", [TA_nexp n]) ->
- (match n.nexp with
- | Nconst i | N2n(_,Some i) -> if check_in_range i range then Yes else No
- | _ -> Maybe)
- | Tapp("range", [TA_nexp bot; TA_nexp top]) ->
- (match bot.nexp,top.nexp with
- | Nconst b, Nconst t | Nconst b, N2n(_,Some t) | N2n(_, Some b), Nconst t | N2n(_,Some b), N2n(_, Some t) ->
- let at_least_in = check_in_range b range in
- let at_most_in = check_in_range t range in
- if at_least_in && at_most_in
- then Yes
- else if at_least_in || at_most_in
- then Maybe
- else No
- | _ -> Maybe)
- | Tapp("vector", [_; TA_nexp size ; _; _]) ->
- (match size.nexp with
- | Nconst i | N2n(_, Some i) ->
- if check_in_range (power_big_int (Big_int.of_int 2) i) range
- then Yes
- else No
- | _ -> Maybe)
- | _ -> Maybe
-
-let is_within_machine64 candidate constraints = is_within_range candidate int64_t constraints *)
-
(************************************************************************************************)
(*FV finding analysis: identifies the free variables of a function, expression, etc *)
@@ -313,9 +186,11 @@ let rec fv_of_exp consider_var bound used set (E_aux (e,(_,tannot))) : (Nameset.
fv_of_exp consider_var bound u set e
| E_app(id,es) ->
let us = conditional_add_exp bound used id in
+ let us = conditional_add_exp bound us (prepend_id "val:" id) in
list_fv bound us set es
| E_app_infix(l,id,r) ->
let us = conditional_add_exp bound used id in
+ let us = conditional_add_exp bound us (prepend_id "val:" id) in
list_fv bound us set [l;r]
| E_if(c,t,e) -> list_fv bound used set [c;t;e]
| E_for(id,from,to_,by,_,body) ->
@@ -464,8 +339,12 @@ let fv_of_fun consider_var (FD_aux (FD_function(rec_opt,tannot_opt,_,funcls),_)
| [] -> failwith "fv_of_fun fell off the end looking for the function name"
| FCL_aux(FCL_Funcl(id,_),_)::_ -> string_of_id id in
let base_bounds = match rec_opt with
+ (* Current Sail does not have syntax for declaring functions as recursive,
+ and type checker does not check whether functions are recursive, so
+ just always add a self-dependency of functions on themselves
| Rec_aux(Ast.Rec_rec,_) -> init_env fun_name
- | _ -> mt in
+ | _ -> mt*)
+ | _ -> init_env fun_name in
let base_bounds,ns_r = match tannot_opt with
| Typ_annot_opt_aux(Typ_annot_opt_some (typq, typ),_) ->
let bindings = if consider_var then typq_bindings typq else mt in
@@ -577,7 +456,9 @@ let fv_of_def consider_var consider_scatter_as_one all_defs = function
| DEF_val lebind -> ((fun (b,u,_) -> (b,u)) (fv_of_let consider_var mt mt mt lebind))
| DEF_spec vspec -> fv_of_vspec consider_var vspec
| DEF_fixity _ -> mt,mt
- | DEF_overload (id,ids) -> init_env (string_of_id id), List.fold_left (fun ns id -> Nameset.add (string_of_id id) ns) mt ids
+ | DEF_overload (id,ids) ->
+ init_env (string_of_id id),
+ List.fold_left (fun ns id -> Nameset.add ("val:" ^ string_of_id id) ns) mt ids
| DEF_default def -> mt,mt
| DEF_internal_mutrec fdefs ->
let fvs = List.map (fv_of_fun consider_var) fdefs in
@@ -590,129 +471,133 @@ let fv_of_def consider_var consider_scatter_as_one all_defs = function
let group_defs consider_scatter_as_one (Ast.Defs defs) =
List.map (fun d -> (fv_of_def false consider_scatter_as_one defs d,d)) defs
-(*******************************************************************************
- * Reorder defs take 2
-*)
-(*remove all of ns1 instances from ns2*)
-let remove_all ns1 ns2 =
- List.fold_right Nameset.remove (Nameset.elements ns1) ns2
-
-let remove_from_all_uses bs dbts =
- List.map (fun ((b,uses),d) -> (b,remove_all bs uses),d) dbts
-
-let remove_local_or_lib_vars dbts =
- let bound_in_dbts = List.fold_right (fun ((b,_),_) bounds -> Nameset.union b bounds) dbts mt in
- let is_bound_in_defs s = Nameset.mem s bound_in_dbts in
- let rec remove_from_uses = function
- | [] -> []
- | ((b,uses),d)::defs ->
- ((b,(Nameset.filter is_bound_in_defs uses)),d)::remove_from_uses defs in
- remove_from_uses dbts
+(*
+ * Sorting definitions, take 3
+ *)
+
+module Namemap = Map.Make(String)
+(* Nodes are labeled with strings. A graph is represented as a map associating
+ each node with its sucessors *)
+type graph = Nameset.t Namemap.t
+type node_idx = { index : int; root : int }
+
+(* Find strongly connected components using Tarjan's algorithm.
+ This algorithm also returns a topological sorting of the graph components. *)
+let scc ?(original_order : string list option) (g : graph) =
+ let components = ref [] in
+ let index = ref 0 in
+
+ let stack = ref [] in
+ let push v = (stack := v :: !stack) in
+ let pop () =
+ begin
+ let v = List.hd !stack in
+ stack := List.tl !stack;
+ v
+ end
+ in
-let compare_dbts ((_,u1),_) ((_,u2),_) = Pervasives.compare (Nameset.cardinal u1) (Nameset.cardinal u2)
+ let node_indices = Hashtbl.create (Namemap.cardinal g) in
+ let get_index v = (Hashtbl.find node_indices v).index in
+ let get_root v = (Hashtbl.find node_indices v).root in
+ let set_root v r =
+ Hashtbl.replace node_indices v { (Hashtbl.find node_indices v) with root = r } in
+
+ let rec visit_node v =
+ begin
+ Hashtbl.add node_indices v { index = !index; root = !index };
+ index := !index + 1;
+ push v;
+ if Namemap.mem v g then Nameset.iter (visit_edge v) (Namemap.find v g);
+ if get_root v = get_index v then (* v is the root of a SCC *)
+ begin
+ let component = ref [] in
+ let finished = ref false in
+ while not !finished do
+ let w = pop () in
+ component := w :: !component;
+ if String.compare v w = 0 then finished := true;
+ done;
+ components := !component :: !components;
+ end
+ end
+ and visit_edge v w =
+ if not (Hashtbl.mem node_indices w) then
+ begin
+ visit_node w;
+ if Hashtbl.mem node_indices w then set_root v (min (get_root v) (get_root w));
+ end else begin
+ if List.mem w !stack then set_root v (min (get_root v) (get_index w))
+ end
+ in
-let rec print_dependencies orig_queue work_queue names =
- match work_queue with
- | [] -> ()
- | ((binds,uses),_)::wq ->
- (if not(Nameset.is_empty(Nameset.inter names binds))
- then ((Printf.eprintf "binds of %s has uses of %s\n" (set_to_string binds) (set_to_string uses));
- print_dependencies orig_queue orig_queue uses));
- print_dependencies orig_queue wq names
-
-let merge_mutrecs defs =
- let merge_aux ((binds', uses'), def) ((binds, uses), fundefs) =
- let fundefs = match def with
- | DEF_fundef fundef -> fundef :: fundefs
- | DEF_internal_mutrec fundefs' -> fundefs' @ fundefs
- | _ ->
- (* let _ = Pretty_print_sail.pp_defs stderr (Defs [def]) in *)
- raise (Reporting_basic.err_unreachable (def_loc def)
- "Trying to merge non-function definition with mutually recursive functions") in
- (* let _ = Printf.eprintf " - Merging %s (using %s)\n" (set_to_string binds') (set_to_string uses') in *)
- ((Nameset.union binds' binds, Nameset.union uses' uses), fundefs) in
- let ((binds, uses), fundefs) = List.fold_right merge_aux defs ((mt, mt), []) in
- ((binds, uses), DEF_internal_mutrec fundefs)
-
-let rec topological_sort work_queue defs =
- match work_queue with
- | [] -> List.rev defs
- | ((binds,uses),def)::wq ->
- (*Assumes work queue given in sorted order, invariant mantained on appropriate recursive calls*)
- if (Nameset.cardinal uses = 0)
- then (*let _ = Printf.eprintf "Adding def that binds %s to definitions\n" (set_to_string binds) in*)
- topological_sort (remove_from_all_uses binds wq) (def::defs)
- else if not(Nameset.is_empty(Nameset.inter binds uses))
- then topological_sort (((binds,(remove_all binds uses)),def)::wq) defs
- else
- match List.stable_sort compare_dbts work_queue with (*We wait to sort until there are no 0 dependency nodes on top*)
- | [] -> failwith "sort shrunk the list???"
- | (((n,uses),def)::rest) as wq ->
- if (Nameset.cardinal uses = 0)
- then topological_sort wq defs
- else
- let _ = Printf.eprintf "Merging (potentially) mutually recursive definitions %s and %s\n" (set_to_string n) (set_to_string uses) in
- let is_used ((binds', uses'), def') = not(Nameset.is_empty(Nameset.inter binds' uses)) in
- let (used, rest) = List.partition is_used rest in
- let wq = merge_mutrecs (((n,uses),def)::used) :: rest in
- topological_sort wq defs
-
-let rec add_to_partial_order ((binds,uses),def) = function
- | [] ->
-(* let _ = Printf.eprintf "add_to_partial_order for def with bindings %s, uses %s.\n Eol case.\n" (set_to_string binds) (set_to_string uses) in*)
- [(binds,uses),def]
- | (((bf,uf),deff)::defs as full_defs) ->
- (*let _ = Printf.eprintf "add_to_partial_order for def with bindings %s, uses %s.\n None eol case. With first def binding %s, uses %s\n" (set_to_string binds) (set_to_string uses) (set_to_string bf) (set_to_string uf) in*)
- if Nameset.is_empty uses
- then ((binds,uses),def)::full_defs
- else if Nameset.subset binds uf (*deff relies on def, so def must be defined first*)
- then ((binds,uses),def)::((bf,(remove_all binds uf)),deff)::defs
- else if Nameset.subset bf uses (*def relies at least on deff, but maybe more, push in*)
- then ((bf,uf),deff)::(add_to_partial_order ((binds,(remove_all bf uses)),def) defs)
- else (*These two are unrelated but new def might need to go further in*)
- ((bf,uf),deff)::(add_to_partial_order ((binds,uses),def) defs)
-
-let rec gather_defs name already_included def_bind_triples =
- match def_bind_triples with
- | [] -> [],already_included,mt
- | ((binds,uses),def)::def_bind_triples ->
- let (defs,already_included,requires) = gather_defs name already_included def_bind_triples in
- let bound_names = Nameset.elements binds in
- if List.mem name already_included || List.exists (fun b -> List.mem b already_included) bound_names
- then (defs,already_included,requires)
- else
- let uses = List.fold_right Nameset.remove already_included uses in
- if Nameset.mem name binds
- then (def::defs,(bound_names@already_included), Nameset.remove name (Nameset.union uses requires))
- else (defs,already_included,requires)
-
-let rec gather_all names already_included def_bind_triples =
- let rec gather ns already_included defs reqs = match ns with
- | [] -> defs,already_included,reqs
- | name::ns ->
- if List.mem name already_included
- then gather ns already_included defs (Nameset.remove name reqs)
- else
- let (new_defs,already_included,new_reqs) = gather_defs name already_included def_bind_triples in
- gather ns already_included (new_defs@defs) (Nameset.remove name (Nameset.union new_reqs reqs))
+ let nodes = match original_order with
+ | Some nodes -> nodes
+ | None -> List.map fst (Namemap.bindings g)
in
- let (defs,already_included,reqs) = gather names already_included [] mt in
- if Nameset.is_empty reqs
- then defs
- else (gather_all (Nameset.elements reqs) already_included def_bind_triples)@defs
-
-let restrict_defs defs name_list =
- let defsno = gather_all name_list [] (group_defs false defs) in
- let rdbts = group_defs true (Defs defsno) in
- (*let partial_order =
- List.fold_left (fun po d -> add_to_partial_order d po) [] rdbts in
- let defs = List.map snd partial_order in*)
- let defs = topological_sort (List.sort compare_dbts (remove_local_or_lib_vars rdbts)) [] in
- Defs defs
-
-
-let top_sort_defs defs =
- let rdbts = group_defs true defs in
- let defs = topological_sort (List.stable_sort compare_dbts (remove_local_or_lib_vars rdbts)) [] in
- Defs defs
+ List.iter (fun v -> if not (Hashtbl.mem node_indices v) then visit_node v) nodes;
+ List.rev !components
+
+let add_def_to_graph (prelude, original_order, defset, graph) d =
+ let bound, used = fv_of_def false true [] d in
+ try
+ (* A definition may bind multiple identifiers, e.g. "let (x, y) = ...".
+ We add all identifiers to the dependency graph as a cycle. The actual
+ definition is attached to only one of the identifiers, so it will not
+ be duplicated in the final output. *)
+ let id = Nameset.choose bound in
+ let other_ids = Nameset.remove id bound in
+ let graph_id = Namemap.add id (Nameset.union used other_ids) graph in
+ let add_other_node id' g = Namemap.add id' (Nameset.singleton id) g in
+ prelude,
+ original_order @ [id],
+ Namemap.add id d defset,
+ Nameset.fold add_other_node other_ids graph_id
+ with
+ | Not_found ->
+ (* Some definitions do not bind any identifiers at all. This *should*
+ only happen for default bitvector order declarations, operator fixity
+ declarations, and comments. The sorting does not (currently) attempt
+ to preserve the positions of these AST nodes; they are collected
+ separately and placed at the beginning of the output. Comments are
+ currently ignored by the Lem and OCaml backends, anyway. For
+ default order and fixity declarations, this means that specifications
+ currently have to assume those declarations are moved to the
+ beginning when using a backend that requires topological sorting. *)
+ prelude @ [d], original_order, defset, graph
+
+let print_dot graph component : unit =
+ match component with
+ | root :: _ ->
+ print_endline ("// Dependency cycle including " ^ root);
+ print_endline ("digraph cycle_" ^ root ^ " {");
+ List.iter (fun caller ->
+ let print_edge callee = print_endline (" \"" ^ caller ^ "\" -> \"" ^ callee ^ "\";") in
+ Namemap.find caller graph
+ |> Nameset.filter (fun id -> List.mem id component)
+ |> Nameset.iter print_edge) component;
+ print_endline "}"
+ | [] -> ()
+
+let def_of_component graph defset comp =
+ let get_def id = if Namemap.mem id defset then [Namemap.find id defset] else [] in
+ match List.concat (List.map get_def comp) with
+ | [] -> []
+ | [def] -> [def]
+ | (def :: _) as defs ->
+ let get_fundefs = function
+ | DEF_fundef fundef -> [fundef]
+ | DEF_internal_mutrec fundefs -> fundefs
+ | _ ->
+ raise (Reporting_basic.err_unreachable (def_loc def)
+ "Trying to merge non-function definition with mutually recursive functions") in
+ let fundefs = List.concat (List.map get_fundefs defs) in
+ print_dot graph (List.map (fun fd -> string_of_id (id_of_fundef fd)) fundefs);
+ [DEF_internal_mutrec fundefs]
+
+let top_sort_defs (Defs defs) =
+ let prelude, original_order, defset, graph =
+ List.fold_left add_def_to_graph ([], [], Namemap.empty, Namemap.empty) defs in
+ let components = scc ~original_order:original_order graph in
+ Defs (prelude @ List.concat (List.map (def_of_component graph defset) components))
diff --git a/src/type_check.ml b/src/type_check.ml
index 30334783..04c16ad5 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -2029,16 +2029,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
annot_exp (E_case (inferred_exp, List.map (fun case -> check_case env inferred_typ case typ) cases)) typ
| E_try (exp, cases), _ ->
let checked_exp = crule check_exp env exp typ in
- let check_case pat typ = match pat with
- | Pat_aux (Pat_exp (pat, case), (l, _)) ->
- let tpat, env = bind_pat env pat exc_typ in
- Pat_aux (Pat_exp (tpat, crule check_exp env case typ), (l, None))
- | Pat_aux (Pat_when (pat, guard, case), (l, _)) ->
- let tpat, env = bind_pat env pat exc_typ in
- let checked_guard = check_exp env guard bool_typ in
- Pat_aux (Pat_when (tpat, checked_guard, crule check_exp env case typ), (l, None))
- in
- annot_exp (E_try (checked_exp, List.map (fun case -> check_case case typ) cases)) typ
+ annot_exp (E_try (checked_exp, List.map (fun case -> check_case env exc_typ case typ) cases)) typ
| E_cons (x, xs), _ ->
begin
match is_list (Env.expand_synonyms env typ) with
@@ -2093,11 +2084,11 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
| LB_val (P_aux (P_typ (ptyp, _), _) as pat, bind) ->
Env.wf_typ env ptyp;
let checked_bind = crule check_exp env bind ptyp in
- let tpat, env = bind_pat env pat ptyp in
+ let tpat, env = bind_pat_no_guard env pat ptyp in
annot_exp (E_let (LB_aux (LB_val (tpat, checked_bind), (let_loc, None)), crule check_exp env exp typ)) typ
| LB_val (pat, bind) ->
let inferred_bind = irule infer_exp env bind in
- let tpat, env = bind_pat env pat (typ_of inferred_bind) in
+ let tpat, env = bind_pat_no_guard env pat (typ_of inferred_bind) in
annot_exp (E_let (LB_aux (LB_val (tpat, inferred_bind), (let_loc, None)), crule check_exp env exp typ)) typ
end
| E_app_infix (x, op, y), _ ->
@@ -2159,7 +2150,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
| _ ->
let inferred_bind = irule infer_exp env bind in
inferred_bind, typ_of inferred_bind in
- let tpat, env = bind_pat env pat ptyp in
+ let tpat, env = bind_pat_no_guard env pat ptyp in
(* Propagate constraint assertions on the lhs of monadic binds to the rhs *)
let env = match bind_exp with
| E_aux (E_assert (E_aux (E_constraint nc, _), _), _) ->
@@ -2192,7 +2183,17 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ
and check_case env pat_typ pexp typ =
let pat,guard,case,((l,_) as annot) = destruct_pexp pexp in
match bind_pat env pat pat_typ with
- | tpat, env ->
+ | tpat, env, guards ->
+ let guard = match guard, guards with
+ | None, h::t -> Some (h,t)
+ | Some x, l -> Some (x,l)
+ | None, [] -> None
+ in
+ let guard = match guard with
+ | Some (h,t) ->
+ Some (List.fold_left (fun acc guard -> mk_exp (E_app_infix (acc, mk_id "&", guard))) h t)
+ | None -> None
+ in
let checked_guard, env' = match guard with
| None -> None, env
| Some guard ->
@@ -2202,6 +2203,7 @@ and check_case env pat_typ pexp typ =
in
let checked_case = crule check_exp env' case typ in
construct_pexp (tpat, checked_guard, checked_case, (l, None))
+ (* AA: Not sure if we still need this *)
| exception (Type_error _ as typ_exn) ->
match pat with
| P_aux (P_lit lit, _) ->
@@ -2283,29 +2285,34 @@ and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ =
try_casts casts
end
+and bind_pat_no_guard env (P_aux (_,(l,_)) as pat) typ =
+ match bind_pat env pat typ with
+ | _, _, _::_ -> typ_error l "Literal patterns not supported here"
+ | tpat, env, [] -> tpat, env
+
and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) =
typ_print ("Binding " ^ string_of_pat pat ^ " to " ^ string_of_typ typ);
let annot_pat pat typ = P_aux (pat, (l, Some (env, typ, no_effect))) in
let switch_typ pat typ = match pat with
- | (P_aux (pat_aux, (l, Some (env, _, eff)))) -> P_aux (pat_aux, (l, Some (env, typ, eff)))
- | _ -> failwith "Cannot switch type of unannotated pattern"
+ | P_aux (pat_aux, (l, Some (env, _, eff))) -> P_aux (pat_aux, (l, Some (env, typ, eff)))
+ | _ -> typ_error l "Cannot switch type for unannotated pattern"
in
- let bind_tuple_pat (tpats, env) pat typ =
- let tpat, env = bind_pat env pat typ in tpat :: tpats, env
+ let bind_tuple_pat (tpats, env, guards) pat typ =
+ let tpat, env, guards' = bind_pat env pat typ in tpat :: tpats, env, guards' @ guards
in
match pat_aux with
| P_id v ->
begin
match Env.lookup_id v env with
- | Local (Immutable, _) | Unbound -> annot_pat (P_id v) typ, Env.add_local v (Immutable, typ) env
+ | Local (Immutable, _) | Unbound -> annot_pat (P_id v) typ, Env.add_local v (Immutable, typ) env, []
| Local (Mutable, _) | Register _ ->
typ_error l ("Cannot shadow mutable local or register in switch statement pattern " ^ string_of_pat pat)
- | Enum enum -> subtyp l env enum typ; annot_pat (P_id v) typ, env
+ | Enum enum -> subtyp l env enum typ; annot_pat (P_id v) typ, env, []
| Union (typq, ctor_typ) ->
begin
try
let _ = unify l env ctor_typ typ in
- annot_pat (P_id v) typ, env
+ annot_pat (P_id v) typ, env, []
with
| Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m)
end
@@ -2318,34 +2325,34 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
let env = Env.add_typ_var kid BK_nat env in
let ex_typ = typ_subst_nexp kid' (Nexp_var kid) ex_typ in
let env = Env.add_constraint (nc_subst_nexp kid' (Nexp_var kid) nc) env in
- let typed_pat, env = bind_pat env pat ex_typ in
- annot_pat (P_var (typed_pat, kid)) typ, env
+ let typed_pat, env, guards = bind_pat env pat ex_typ in
+ annot_pat (P_var (typed_pat, kid)) typ, env, guards
| Some _, _ -> typ_error l ("Cannot bind type variable pattern against multiple argument existential")
| None, Typ_aux (Typ_id id, _) when Id.compare id (mk_id "int") == 0 ->
let env = Env.add_typ_var kid BK_nat env in
- let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in
- annot_pat (P_var (typed_pat, kid)) typ, env
+ let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in
+ annot_pat (P_var (typed_pat, kid)) typ, env, guards
| None, Typ_aux (Typ_id id, _) when Id.compare id (mk_id "nat") == 0 ->
let env = Env.add_typ_var kid BK_nat env in
let env = Env.add_constraint (nc_gt (nvar kid) (nint 0)) env in
- let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in
- annot_pat (P_var (typed_pat, kid)) typ, env
+ let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in
+ annot_pat (P_var (typed_pat, kid)) typ, env, guards
| None, Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp lo, _); Typ_arg_aux (Typ_arg_nexp hi, _)]), _)
when Id.compare id (mk_id "range") == 0 ->
let env = Env.add_typ_var kid BK_nat env in
let env = Env.add_constraint (nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi)) env in
- let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in
- annot_pat (P_var (typed_pat, kid)) typ, env
+ let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in
+ annot_pat (P_var (typed_pat, kid)) typ, env, guards
| None, _ -> typ_error l ("Cannot bind type variable against non existential or numeric type")
end
- | P_wild -> annot_pat P_wild typ, env
+ | P_wild -> annot_pat P_wild typ, env, []
| P_cons (hd_pat, tl_pat) ->
begin
match Env.expand_synonyms env typ with
| Typ_aux (Typ_app (f, [Typ_arg_aux (Typ_arg_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 ->
- let hd_pat, env = bind_pat env hd_pat ltyp in
- let tl_pat, env = bind_pat env tl_pat typ in
- annot_pat (P_cons (hd_pat, tl_pat)) typ, env
+ let hd_pat, env, hd_guards = bind_pat env hd_pat ltyp in
+ let tl_pat, env, tl_guards = bind_pat env tl_pat typ in
+ annot_pat (P_cons (hd_pat, tl_pat)) typ, env, hd_guards @ tl_guards
| _ -> typ_error l "Cannot match cons pattern against non-list type"
end
| P_list pats ->
@@ -2353,32 +2360,32 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
match Env.expand_synonyms env typ with
| Typ_aux (Typ_app (f, [Typ_arg_aux (Typ_arg_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 ->
let rec process_pats env = function
- | [] -> [], env
+ | [] -> [], env, []
| (pat :: pats) ->
- let pat', env = bind_pat env pat ltyp in
- let pats', env = process_pats env pats in
- pat' :: pats', env
+ let pat', env, guards = bind_pat env pat ltyp in
+ let pats', env, guards' = process_pats env pats in
+ pat' :: pats', env, guards @ guards'
in
- let pats, env = process_pats env pats in
- annot_pat (P_list pats) typ, env
+ let pats, env, guards = process_pats env pats in
+ annot_pat (P_list pats) typ, env, guards
| _ -> typ_error l ("Cannot match list pattern " ^ string_of_pat pat ^ " against non-list type " ^ string_of_typ typ)
end
| P_tup [] ->
begin
match Env.expand_synonyms env typ with
| Typ_aux (Typ_id typ_id, _) when string_of_id typ_id = "unit" ->
- annot_pat (P_tup []) typ, env
+ annot_pat (P_tup []) typ, env, []
| _ -> typ_error l "Cannot match unit pattern against non-unit type"
end
| P_tup pats ->
begin
match Env.expand_synonyms env typ with
| Typ_aux (Typ_tup typs, _) ->
- let tpats, env =
- try List.fold_left2 bind_tuple_pat ([], env) pats typs with
+ let tpats, env, guards =
+ try List.fold_left2 bind_tuple_pat ([], env, []) pats typs with
| Invalid_argument _ -> typ_error l "Tuple pattern and tuple type have different length"
in
- annot_pat (P_tup (List.rev tpats)) typ, env
+ annot_pat (P_tup (List.rev tpats)) typ, env, guards
| _ -> typ_error l "Cannot bind tuple pattern against non tuple type"
end
| P_app (f, pats) when Env.is_union_constructor f env ->
@@ -2402,11 +2409,11 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat)
else ();
let ret_typ' = subst_unifiers unifiers ret_typ in
- let tpats, env =
- try List.fold_left2 bind_tuple_pat ([], env) pats (untuple arg_typ') with
+ let tpats, env, guards =
+ try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with
| Invalid_argument _ -> typ_error l "Union constructor pattern arguments have incorrect length"
in
- annot_pat (P_app (f, List.rev tpats)) typ, env
+ annot_pat (P_app (f, List.rev tpats)) typ, env, guards
with
| Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m)
end
@@ -2415,12 +2422,19 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ)
| P_app (f, _) when not (Env.is_union_constructor f env) ->
typ_error l (string_of_id f ^ " is not a union constructor in pattern " ^ string_of_pat pat)
| P_as (pat, id) ->
- let (typed_pat, env) = bind_pat env pat typ in
- annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env
+ let (typed_pat, env, guards) = bind_pat env pat typ in
+ annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env, guards
| _ ->
- let (inferred_pat, env) = infer_pat env pat in
- subtyp l env (pat_typ_of inferred_pat) typ;
- switch_typ inferred_pat typ, env
+ let (inferred_pat, env, guards) = infer_pat env pat in
+ match subtyp l env (pat_typ_of inferred_pat) typ with
+ | () -> switch_typ inferred_pat typ, env, guards
+ | exception (Type_error _ as typ_exn) ->
+ match pat_aux with
+ | P_lit lit ->
+ let guard = mk_exp (E_app_infix (mk_exp (E_id (mk_id "p#")), mk_id "==", mk_exp (E_lit lit))) in
+ let (typed_pat, env, guards) = bind_pat env (mk_pat (P_id (mk_id "p#"))) typ in
+ typed_pat, env, guard::guards
+ | _ -> raise typ_exn
and infer_pat env (P_aux (pat_aux, (l, ())) as pat) =
let annot_pat pat typ = P_aux (pat, (l, Some (env, typ, no_effect))) in
@@ -2432,31 +2446,31 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) =
typ_error l ("Cannot infer identifier in pattern " ^ string_of_pat pat ^ " - try adding a type annotation")
| Local (Mutable, _) | Register _ ->
typ_error l ("Cannot shadow mutable local or register in switch statement pattern " ^ string_of_pat pat)
- | Enum enum -> annot_pat (P_id v) enum, env
+ | Enum enum -> annot_pat (P_id v) enum, env, []
end
| P_typ (typ_annot, pat) ->
Env.wf_typ env typ_annot;
- let (typed_pat, env) = bind_pat env pat typ_annot in
- annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env
+ let (typed_pat, env, guards) = bind_pat env pat typ_annot in
+ annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env, guards
| P_lit lit ->
- annot_pat (P_lit lit) (infer_lit env lit), env
+ annot_pat (P_lit lit) (infer_lit env lit), env, []
| P_vector (pat :: pats) ->
- let fold_pats (pats, env) pat =
- let typed_pat, env = bind_pat env pat bit_typ in
- pats @ [typed_pat], env
+ let fold_pats (pats, env, guards) pat =
+ let typed_pat, env, guards' = bind_pat env pat bit_typ in
+ pats @ [typed_pat], env, guards' @ guards
in
- let pats, env =
- List.fold_left fold_pats ([], env) (pat :: pats) in
+ let pats, env, guards = List.fold_left fold_pats ([], env, []) (pat :: pats) in
let len = nexp_simp (nint (List.length pats)) in
let etyp = pat_typ_of (List.hd pats) in
List.iter (fun pat -> typ_equality l env etyp (pat_typ_of pat)) pats;
- annot_pat (P_vector pats) (dvector_typ env len etyp), env
+ annot_pat (P_vector pats) (dvector_typ env len etyp), env, guards
| P_vector_concat (pat :: pats) ->
- let fold_pats (pats, env) pat =
- let inferred_pat, env = infer_pat env pat in
- pats @ [inferred_pat], env
+ let fold_pats (pats, env, guards) pat =
+ let inferred_pat, env, guards' = infer_pat env pat in
+ pats @ [inferred_pat], env, guards' @ guards
in
- let inferred_pats, env = List.fold_left fold_pats ([], env) (pat :: pats) in
+ let inferred_pats, env, guards =
+ List.fold_left fold_pats ([], env, []) (pat :: pats) in
let (len, _, vtyp) = destruct_vec_typ l env (pat_typ_of (List.hd inferred_pats)) in
let fold_len len pat =
let (len', _, vtyp') = destruct_vec_typ l env (pat_typ_of pat) in
@@ -2464,10 +2478,12 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) =
nsum len len'
in
let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_pats)) in
- annot_pat (P_vector_concat inferred_pats) (dvector_typ env len vtyp), env
+ annot_pat (P_vector_concat inferred_pats) (dvector_typ env len vtyp), env, guards
| P_as (pat, id) ->
- let (typed_pat, env) = infer_pat env pat in
- annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env
+ let (typed_pat, env, guards) = infer_pat env pat in
+ annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat),
+ Env.add_local id (Immutable, pat_typ_of typed_pat) env,
+ guards
| _ -> typ_error l ("Couldn't infer type of pattern " ^ string_of_pat pat)
and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as exp) =
@@ -2858,7 +2874,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
| _ ->
let inferred_bind = irule infer_exp env bind in
inferred_bind, typ_of inferred_bind in
- let tpat, env = bind_pat env pat ptyp in
+ let tpat, env = bind_pat_no_guard env pat ptyp in
(* Propagate constraint assertions on the lhs of monadic binds to the rhs *)
let env = match bind_exp with
| E_aux (E_assert (E_aux (E_constraint nc, _), _), _) ->
@@ -2876,7 +2892,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) =
| LB_val (pat, bind) ->
let inferred_bind = irule infer_exp env bind in
inferred_bind, pat, typ_of inferred_bind in
- let tpat, env = bind_pat env pat ptyp in
+ let tpat, env = bind_pat_no_guard env pat ptyp in
let inferred_exp = irule infer_exp env exp in
annot_exp (E_let (LB_aux (LB_val (tpat, bind_exp), (let_loc, None)), inferred_exp)) (typ_of inferred_exp)
| E_ref id when Env.is_mutable id env ->
@@ -3328,11 +3344,11 @@ let check_letdef env (LB_aux (letbind, (l, _))) =
match letbind with
| LB_val (P_aux (P_typ (typ_annot, pat), _), bind) ->
let checked_bind = crule check_exp env (strip_exp bind) typ_annot in
- let tpat, env = bind_pat env (strip_pat pat) typ_annot in
+ let tpat, env = bind_pat_no_guard env (strip_pat pat) typ_annot in
[DEF_val (LB_aux (LB_val (P_aux (P_typ (typ_annot, tpat), (l, Some (env, typ_annot, no_effect))), checked_bind), (l, None)))], env
| LB_val (pat, bind) ->
let inferred_bind = irule infer_exp env (strip_exp bind) in
- let tpat, env = bind_pat env (strip_pat pat) (typ_of inferred_bind) in
+ let tpat, env = bind_pat_no_guard env (strip_pat pat) (typ_of inferred_bind) in
[DEF_val (LB_aux (LB_val (tpat, inferred_bind), (l, None)))], env
end
diff --git a/src/type_check.mli b/src/type_check.mli
index e3daec75..3f43492f 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -199,7 +199,11 @@ val prove : Env.t -> n_constraint -> bool
val subtype_check : Env.t -> typ -> typ -> bool
-val bind_pat : Env.t -> unit pat -> typ -> tannot pat * Env.t
+val bind_pat : Env.t -> unit pat -> typ -> tannot pat * Env.t * unit Ast.exp list
+(* Variant that doesn't introduce new guards for literal patterns, but raises
+ a type error instead. This should always be safe to use on patterns that
+ have previously been type checked. *)
+val bind_pat_no_guard : Env.t -> unit pat -> typ -> tannot pat * Env.t
(* Partial functions: The expressions and patterns passed to these
functions must be guaranteed to have tannots of the form Some (env,