diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 3 | ||||
| -rw-r--r-- | src/ast_util.mli | 1 | ||||
| -rw-r--r-- | src/monomorphise.ml | 80 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 223 | ||||
| -rw-r--r-- | src/rewrites.ml | 74 | ||||
| -rw-r--r-- | src/spec_analysis.ml | 387 | ||||
| -rw-r--r-- | src/type_check.ml | 160 | ||||
| -rw-r--r-- | src/type_check.mli | 6 |
8 files changed, 455 insertions, 479 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 074376a9..a70db3e0 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -428,6 +428,9 @@ let id_loc = function let kid_loc = function | Kid_aux (_, l) -> l +let pat_loc = function + | P_aux (_, (l, _)) -> l + let def_loc = function | DEF_kind (KD_aux (_, (l, _))) | DEF_type (TD_aux (_, (l, _))) diff --git a/src/ast_util.mli b/src/ast_util.mli index ec12d44b..69d80ea7 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -164,6 +164,7 @@ val map_letbind_annot : ('a annot -> 'b annot) -> 'a letbind -> 'b letbind val id_loc : id -> Parse_ast.l val kid_loc : kid -> Parse_ast.l +val pat_loc : 'a pat -> Parse_ast.l val def_loc : 'a def -> Parse_ast.l (* For debugging and error messages only: Not guaranteed to produce diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 4fa5d1d6..3ea156d0 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -54,8 +54,7 @@ open Ast_util module Big_int = Nat_big_num open Type_check -let size_set_limit = 8 -let vector_split_limit = 4 +let size_set_limit = 16 let optmap v f = match v with @@ -1316,14 +1315,10 @@ let split_defs continue_anyway splits defs = | Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp len,_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) -> (match len with | Nexp_aux (Nexp_constant sz,_) -> - if Big_int.to_int sz <= vector_split_limit then - let lits = make_vectors (Big_int.to_int sz) in - List.map (fun lit -> - P_aux (P_lit lit,(l,annot)), - [var,E_aux (E_lit lit,(new_l,annot))]) lits - else - cannot ("Refusing to split vector type of length " ^ Big_int.to_string sz ^ - " (above limit " ^ string_of_int vector_split_limit ^ ")") + let lits = make_vectors (Big_int.to_int sz) in + List.map (fun lit -> + P_aux (P_lit lit,(l,annot)), + [var,E_aux (E_lit lit,(new_l,annot))]) lits | _ -> cannot ("length not constant, " ^ string_of_nexp len) ) @@ -1494,6 +1489,19 @@ let split_defs continue_anyway splits defs = in p in + let check_split_size lst l = + let size = List.length lst in + if size > size_set_limit then + let open Reporting_basic in + let error = + Err_general (l, "Case split is too large (" ^ string_of_int size ^ + " > limit " ^ string_of_int size_set_limit ^ ")") + in if continue_anyway + then (print_error error; false) + else raise (Fatal_error error) + else true + in + let rec map_exp ((E_aux (e,annot)) as ea) = let re e = E_aux (e,annot) in match e with @@ -1556,13 +1564,16 @@ let split_defs continue_anyway splits defs = FE_aux (FE_Fexp (id,map_exp e),annot) and map_pexp = function | Pat_aux (Pat_exp (p,e),l) -> + let nosplit = [Pat_aux (Pat_exp (p,map_exp e),l)] in (match map_pat p with - | NoSplit -> [Pat_aux (Pat_exp (p,map_exp e),l)] + | NoSplit -> nosplit | VarSplit patsubsts -> - List.map (fun (pat',substs) -> - let exp' = subst_exp substs e in - Pat_aux (Pat_exp (pat', map_exp exp'),l)) - patsubsts + if check_split_size patsubsts (pat_loc p) then + List.map (fun (pat',substs) -> + let exp' = subst_exp substs e in + Pat_aux (Pat_exp (pat', map_exp exp'),l)) + patsubsts + else nosplit | ConstrSplit patnsubsts -> List.map (fun (pat',nsubst) -> let pat' = nexp_subst_pat nsubst pat' in @@ -1570,14 +1581,17 @@ let split_defs continue_anyway splits defs = Pat_aux (Pat_exp (pat', map_exp exp'),l) ) patnsubsts) | Pat_aux (Pat_when (p,e1,e2),l) -> + let nosplit = [Pat_aux (Pat_when (p,map_exp e1,map_exp e2),l)] in (match map_pat p with - | NoSplit -> [Pat_aux (Pat_when (p,map_exp e1,map_exp e2),l)] + | NoSplit -> nosplit | VarSplit patsubsts -> - List.map (fun (pat',substs) -> - let exp1' = subst_exp substs e1 in - let exp2' = subst_exp substs e2 in - Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)) - patsubsts + if check_split_size patsubsts (pat_loc p) then + List.map (fun (pat',substs) -> + let exp1' = subst_exp substs e1 in + let exp2' = subst_exp substs e2 in + Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)) + patsubsts + else nosplit | ConstrSplit patnsubsts -> List.map (fun (pat',nsubst) -> let pat' = nexp_subst_pat nsubst pat' in @@ -1770,6 +1784,10 @@ let rewrite_size_parameters env (Defs defs) = { (compute_exp_alg KidSet.empty KidSet.union) with e_aux = (fun ((s,e),annot) -> KidSet.union s (sizes_of_annot annot), E_aux (e,annot)); e_let = (fun ((sl,lb),(s2,e2)) -> KidSet.union sl (KidSet.diff s2 (tyvars_bound_in_lb lb)), E_let (lb,e2)); + e_for = (fun (id,(s1,e1),(s2,e2),(s3,e3),ord,(s4,e4)) -> + let kid = mk_kid ("loop_" ^ string_of_id id) in + KidSet.union s1 (KidSet.union s2 (KidSet.union s3 (KidSet.remove kid s4))), + E_for (id,e1,e2,e3,ord,e4)); pat_exp = (fun ((sp,pat),(s,e)) -> KidSet.diff s (tyvars_bound_in_pat pat), Pat_exp (pat,e))} pexp) in @@ -1786,7 +1804,7 @@ let rewrite_size_parameters env (Defs defs) = | P_aux (P_tup ps,_) -> ps | _ -> [pat] in - let to_change = List.map + let to_change = Util.map_filter (fun kid -> let check (P_aux (_,(_,Some (env,typ,_)))) = match Env.expand_synonyms env typ with @@ -1799,9 +1817,10 @@ let rewrite_size_parameters env (Defs defs) = if Kid.compare kid kid' = 0 then Some kid else None | _ -> None in match findi check parameters with - | None -> raise (Reporting_basic.err_general l - ("Unable to find an argument for " ^ string_of_kid kid)) - | Some i -> i) + | None -> (Reporting_basic.print_error (Reporting_basic.Err_general (l, + ("Unable to find an argument for " ^ string_of_kid kid))); + None) + | Some i -> Some i) (KidSet.elements expose_tyvars) in let ik_compare (i,k) (i',k') = @@ -1858,7 +1877,7 @@ let rewrite_size_parameters env (Defs defs) = let body = List.fold_left add_var_rebind body vars in let guard = match guard with | None -> None - | Some exp -> Some (List.fold_left add_var_rebind body vars) + | Some exp -> Some (List.fold_left add_var_rebind exp vars) in pat,guard,body | exception Not_found -> pat,guard,body @@ -2752,6 +2771,15 @@ let rewrite_app env typ (id,args) = | _ -> E_app (id,args) + else if is_id env (Id "UInt") id then + let is_slice = is_id env (Id "slice") in + match args with + | [E_aux (E_app (slice1, [vector1; start1; length1]),_)] + when is_slice slice1 && not (is_constant length1) -> + E_app (mk_id "UInt_slice", [vector1; start1; length1]) + + | _ -> E_app (id,args) + else E_app (id,args) let rewrite_aux = function diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 0002f8cc..c776081c 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -62,6 +62,12 @@ open Pretty_print_common let opt_sequential = ref false let opt_mwords = ref false +type context = { + early_ret : bool; + bound_nexps : NexpSet.t; +} +let empty_ctxt = { early_ret = false; bound_nexps = NexpSet.empty } + let print_to_from_interp_value = ref false let langlebar = string "<|" let ranglebar = string "|>" @@ -326,12 +332,12 @@ let doc_typ_lem, doc_atomic_typ_lem = length argument are checked for variables, and the latter only if it is a bitvector; for other types of vectors, the length is not pretty-printed in the type, and the start index is never pretty-printed in vector types. *) -let rec contains_t_pp_var (Typ_aux (t,a) as typ) = match t with +let rec contains_t_pp_var ctxt (Typ_aux (t,a) as typ) = match t with | Typ_id _ -> false | Typ_var _ -> true | Typ_exist _ -> true - | Typ_fn (t1,t2,_) -> contains_t_pp_var t1 || contains_t_pp_var t2 - | Typ_tup ts -> List.exists contains_t_pp_var ts + | Typ_fn (t1,t2,_) -> contains_t_pp_var ctxt t1 || contains_t_pp_var ctxt t2 + | Typ_tup ts -> List.exists (contains_t_pp_var ctxt) ts | Typ_app (c,targs) -> if Ast_util.is_number typ then false else if is_bitvector_typ typ then @@ -340,23 +346,22 @@ let rec contains_t_pp_var (Typ_aux (t,a) as typ) = match t with not (is_nexp_constant length || (!opt_mwords && match length with Nexp_aux (Nexp_var _,_) -> true | _ -> false)) - else List.exists contains_t_arg_pp_var targs -and contains_t_arg_pp_var (Typ_arg_aux (targ, _)) = match targ with - | Typ_arg_typ t -> contains_t_pp_var t - | Typ_arg_nexp nexp -> not (is_nexp_constant (nexp_simp nexp)) + else List.exists (contains_t_arg_pp_var ctxt) targs +and contains_t_arg_pp_var ctxt (Typ_arg_aux (targ, _)) = match targ with + | Typ_arg_typ t -> contains_t_pp_var ctxt t + | Typ_arg_nexp nexp -> + let nexp = nexp_simp nexp in + not (is_nexp_constant nexp || NexpSet.mem nexp ctxt.bound_nexps) | _ -> false -let doc_tannot_lem eff typ = - if contains_t_pp_var typ then empty +let doc_tannot_lem ctxt eff typ = + if contains_t_pp_var ctxt typ then empty else let ta = doc_typ_lem typ in if eff then string " : M " ^^ parens ta else string " : " ^^ ta -(* doc_lit_lem gets as an additional parameter the type information from the - * expression around it: that's a hack, but how else can we distinguish between - * undefined values of different types ? *) -let doc_lit_lem in_pat (L_aux(lit,l)) a = +let doc_lit_lem (L_aux(lit,l)) = match lit with | L_unit -> utf8string "()" | L_zero -> utf8string "B0" @@ -366,24 +371,12 @@ let doc_lit_lem in_pat (L_aux(lit,l)) a = | L_num i -> let ipp = Big_int.to_string i in utf8string ( - if in_pat then "("^ipp^":nn)" - else if Big_int.less i Big_int.zero then "((0"^ipp^"):ii)" + if Big_int.less i Big_int.zero then "((0"^ipp^"):ii)" else "("^ipp^":ii)") | L_hex n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0x" ^ n) ^ ")" (*shouldn't happen*)*) | L_bin n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0b" ^ n) ^ ")" (*shouldn't happen*)*) | L_undef -> - (match a with - | Some (_, (Typ_aux (t,_) as typ), _) -> - (match t with - | Typ_id (Id_aux (Id "bit", _)) - | Typ_app (Id_aux (Id "register", _),_) -> utf8string "UndefinedRegister 0" - | Typ_id (Id_aux (Id "string", _)) -> utf8string "\"\"" - | _ -> - let ta = if contains_t_pp_var typ then empty - else doc_tannot_lem false typ in - parens - ((utf8string "(failwith \"undefined value of unsupported type\")") ^^ ta)) - | _ -> utf8string "(failwith \"undefined value of unsupported type\")") + utf8string "(failwith \"undefined value of unsupported type\")" | L_string s -> utf8string ("\"" ^ s ^ "\"") | L_real s -> (* Lem does not support decimal syntax, so we translate a string @@ -427,29 +420,30 @@ let doc_typquant_lem (TypQ_aux(tq,_)) vars_included typ = match tq with machine words. Often these will be unnecessary, but this simple approach will do for now. *) -let rec typeclass_nexps (Typ_aux(t,_)) = match t with -| Typ_id _ -| Typ_var _ - -> NexpSet.empty -| Typ_fn (t1,t2,_) -> NexpSet.union (typeclass_nexps t1) (typeclass_nexps t2) -| Typ_tup ts -> List.fold_left NexpSet.union NexpSet.empty (List.map typeclass_nexps ts) -| Typ_app (Id_aux (Id "vector",_), - [_;Typ_arg_aux (Typ_arg_nexp size_nexp,_); - _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) -| Typ_app (Id_aux (Id "itself",_), - [Typ_arg_aux (Typ_arg_nexp size_nexp,_)]) -> - let size_nexp = nexp_simp size_nexp in - if is_nexp_constant size_nexp then NexpSet.empty else - NexpSet.singleton (orig_nexp size_nexp) -| Typ_app _ -> NexpSet.empty -| Typ_exist (kids,_,t) -> NexpSet.empty (* todo *) +let rec typeclass_nexps (Typ_aux(t,_)) = + if !opt_mwords then + match t with + | Typ_id _ + | Typ_var _ + -> NexpSet.empty + | Typ_fn (t1,t2,_) -> NexpSet.union (typeclass_nexps t1) (typeclass_nexps t2) + | Typ_tup ts -> List.fold_left NexpSet.union NexpSet.empty (List.map typeclass_nexps ts) + | Typ_app (Id_aux (Id "vector",_), + [_;Typ_arg_aux (Typ_arg_nexp size_nexp,_); + _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) + | Typ_app (Id_aux (Id "itself",_), + [Typ_arg_aux (Typ_arg_nexp size_nexp,_)]) -> + let size_nexp = nexp_simp size_nexp in + if is_nexp_constant size_nexp then NexpSet.empty else + NexpSet.singleton (orig_nexp size_nexp) + | Typ_app _ -> NexpSet.empty + | Typ_exist (kids,_,t) -> NexpSet.empty (* todo *) + else NexpSet.empty let doc_typclasses_lem t = - if !opt_mwords then - let nexps = typeclass_nexps t in - if NexpSet.is_empty nexps then (empty, NexpSet.empty) else - (separate_map comma_sp (fun nexp -> string "Size " ^^ doc_nexp_lem nexp) (NexpSet.elements nexps) ^^ string " => ", nexps) - else (empty, NexpSet.empty) + let nexps = typeclass_nexps t in + if NexpSet.is_empty nexps then (empty, NexpSet.empty) else + (separate_map comma_sp (fun nexp -> string "Size " ^^ doc_nexp_lem nexp) (NexpSet.elements nexps) ^^ string " => ", nexps) let doc_typschm_lem quants (TypSchm_aux(TypSchm_ts(tq,t),_)) = let pt = doc_typ_lem t in @@ -470,44 +464,44 @@ let is_ctor env id = match Env.lookup_id id env with (*Note: vector concatenation, literal vectors, indexed vectors, and record should be removed prior to pp. The latter two have never yet been seen *) -let rec doc_pat_lem apat_needed (P_aux (p,(l,annot)) as pa) = match p with +let rec doc_pat_lem ctxt apat_needed (P_aux (p,(l,annot)) as pa) = match p with | P_app(id, ((_ :: _) as pats)) -> let ppp = doc_unop (doc_id_lem_ctor id) - (parens (separate_map comma (doc_pat_lem true) pats)) in + (parens (separate_map comma (doc_pat_lem ctxt true) pats)) in if apat_needed then parens ppp else ppp | P_app(id,[]) -> doc_id_lem_ctor id - | P_lit lit -> doc_lit_lem true lit annot + | P_lit lit -> doc_lit_lem lit | P_wild -> underscore | P_id id -> begin match id with | Id_aux (Id "None",_) -> string "Nothing" (* workaround temporary issue *) | _ -> doc_id_lem id end - | P_var(p,kid) -> doc_pat_lem true p - | P_as(p,id) -> parens (separate space [doc_pat_lem true p; string "as"; doc_id_lem id]) + | P_var(p,kid) -> doc_pat_lem ctxt true p + | P_as(p,id) -> parens (separate space [doc_pat_lem ctxt true p; string "as"; doc_id_lem id]) | P_typ(Typ_aux (Typ_tup typs, _), P_aux (P_tup pats, _)) -> (* Isabelle does not seem to like type-annotated tuple patterns; it gives a syntax error. Avoid this by annotating the tuple elements instead *) let doc_elem typ (P_aux (_, annot) as pat) = - doc_pat_lem true (P_aux (P_typ (typ, pat), annot)) in + doc_pat_lem ctxt true (P_aux (P_typ (typ, pat), annot)) in parens (separate comma_sp (List.map2 doc_elem typs pats)) | P_typ(typ,p) -> - let doc_p = doc_pat_lem true p in - if contains_t_pp_var typ then doc_p + let doc_p = doc_pat_lem ctxt true p in + if contains_t_pp_var ctxt typ then doc_p else parens (doc_op colon doc_p (doc_typ_lem typ)) | P_vector pats -> let ppp = (separate space) - [string "Vector";brackets (separate_map semi (doc_pat_lem true) pats);underscore;underscore] in + [string "Vector";brackets (separate_map semi (doc_pat_lem ctxt true) pats);underscore;underscore] in if apat_needed then parens ppp else ppp | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l "vector concatenation patterns should have been removed before pretty-printing") | P_tup pats -> (match pats with - | [p] -> doc_pat_lem apat_needed p - | _ -> parens (separate_map comma_sp (doc_pat_lem false) pats)) - | P_list pats -> brackets (separate_map semi (doc_pat_lem false) pats) (*Never seen but easy in lem*) - | P_cons (p,p') -> doc_op (string "::") (doc_pat_lem true p) (doc_pat_lem true p') + | [p] -> doc_pat_lem ctxt apat_needed p + | _ -> parens (separate_map comma_sp (doc_pat_lem ctxt false) pats)) + | P_list pats -> brackets (separate_map semi (doc_pat_lem ctxt false) pats) (*Never seen but easy in lem*) + | P_cons (p,p') -> doc_op (string "::") (doc_pat_lem ctxt true p) (doc_pat_lem ctxt true p') | P_record (_,_) -> empty (* TODO *) let rec typ_needs_printed (Typ_aux (t,_) as typ) = match t with @@ -539,13 +533,13 @@ let typ_id_of (Typ_aux (typ, l)) = match typ with let prefix_recordtype = true let report = Reporting_basic.err_unreachable let doc_exp_lem, doc_let_lem = - let rec top_exp (early_ret : bool) (aexp_needed : bool) + let rec top_exp (ctxt : context) (aexp_needed : bool) (E_aux (e, (l,annot)) as full_exp) = - let expY = top_exp early_ret true in - let expN = top_exp early_ret false in - let expV = top_exp early_ret in + let expY = top_exp ctxt true in + let expN = top_exp ctxt false in + let expV = top_exp ctxt in let liftR doc = - if early_ret && effectful (effect_of full_exp) + if ctxt.early_ret && effectful (effect_of full_exp) then separate space [string "liftR"; parens (doc)] else doc in match e with @@ -565,10 +559,10 @@ let doc_exp_lem, doc_let_lem = doc_id_lem id in liftR ((prefix 2 1) (string "write_reg_field_range") - (align (doc_lexp_deref_lem early_ret le ^/^ + (align (doc_lexp_deref_lem ctxt le ^/^ field_ref ^/^ expY e2 ^/^ expY e3 ^/^ expY e))) | _ -> - let deref = doc_lexp_deref_lem early_ret le in + let deref = doc_lexp_deref_lem ctxt le in liftR ((prefix 2 1) (string "write_reg_range") (align (deref ^/^ expY e2 ^/^ expY e3) ^/^ expY e))) @@ -585,10 +579,10 @@ let doc_exp_lem, doc_let_lem = let call = if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot fannot)) then "write_reg_field_bit" else "write_reg_field_pos" in liftR ((prefix 2 1) (string call) - (align (doc_lexp_deref_lem early_ret le ^/^ + (align (doc_lexp_deref_lem ctxt le ^/^ field_ref ^/^ expY e2 ^/^ expY e))) | LEXP_aux (_, lannot) -> - let deref = doc_lexp_deref_lem early_ret le in + let deref = doc_lexp_deref_lem ctxt le in let call = if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot lannot)) then "write_reg_bit" else "write_reg_pos" in liftR ((prefix 2 1) (string call) (deref ^/^ expY e2 ^/^ expY e)) @@ -602,10 +596,10 @@ let doc_exp_lem, doc_let_lem = string "set_field"*) in liftR ((prefix 2 1) (string "write_reg_field") - (doc_lexp_deref_lem early_ret le ^^ space ^^ + (doc_lexp_deref_lem ctxt le ^^ space ^^ field_ref ^/^ expY e)) | _ -> - liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref_lem early_ret le ^/^ expY e))) + liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref_lem ctxt le ^/^ expY e))) | E_vector_append(le,re) -> raise (Reporting_basic.err_unreachable l "E_vector_append should have been rewritten before pretty-printing") @@ -621,7 +615,7 @@ let doc_exp_lem, doc_let_lem = | E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) -> raise (report l "E_for should have been removed till now") | E_let(leb,e) -> - let epp = let_exp early_ret leb ^^ space ^^ string "in" ^^ hardline ^^ expN e in + let epp = let_exp ctxt leb ^^ space ^^ string "in" ^^ hardline ^^ expN e in if aexp_needed then parens epp else epp | E_app(f,args) -> begin match f with @@ -676,8 +670,8 @@ let doc_exp_lem, doc_let_lem = | [exp] -> let epp = separate space [string "early_return"; expY exp] in let aexp_needed, tepp = - if contains_t_pp_var (typ_of exp) || - contains_t_pp_var (typ_of full_exp) then + if contains_t_pp_var ctxt (typ_of exp) || + contains_t_pp_var ctxt (typ_of full_exp) then aexp_needed, epp else let tannot = separate space [string "MR"; @@ -716,7 +710,7 @@ let doc_exp_lem, doc_let_lem = let t = (*Env.base_typ_of (env_of full_exp)*) (typ_of full_exp) in let eff = effect_of full_exp in if typ_needs_printed (Env.base_typ_of (env_of full_exp) t) - then (align epp ^^ (doc_tannot_lem (effectful eff) t), true) + then (align epp ^^ (doc_tannot_lem ctxt (effectful eff) t), true) else (epp, aexp_needed) in liftR (if aexp_needed then parens (align taepp) else taepp) end @@ -749,11 +743,11 @@ let doc_exp_lem, doc_let_lem = if has_effect eff BE_rreg then let epp = separate space [string "read_reg";doc_id_lem id] in if is_bitvector_typ base_typ - then liftR (parens (epp ^^ doc_tannot_lem true base_typ)) + then liftR (parens (epp ^^ doc_tannot_lem ctxt true base_typ)) else liftR epp else if is_ctor env id then doc_id_lem_ctor id else doc_id_lem id - | E_lit lit -> doc_lit_lem false lit annot + | E_lit lit -> doc_lit_lem lit | E_cast(typ,e) -> expV aexp_needed e | E_tuple exps -> @@ -767,7 +761,7 @@ let doc_exp_lem, doc_let_lem = | _ -> raise (report l ("cannot get record type from annot " ^ string_of_annot annot ^ " of exp " ^ string_of_exp full_exp)) in let epp = anglebars (space ^^ (align (separate_map (semi_sp ^^ break 1) - (doc_fexp early_ret recordtyp) fexps)) ^^ space) in + (doc_fexp ctxt recordtyp) fexps)) ^^ space) in if aexp_needed then parens epp else epp | E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) -> let recordtyp = match annot with @@ -776,7 +770,7 @@ let doc_exp_lem, doc_let_lem = when Env.is_record tid env -> tid | _ -> raise (report l ("cannot get record type from annot " ^ string_of_annot annot ^ " of exp " ^ string_of_exp full_exp)) in - anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp early_ret recordtyp) fexps)) + anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps)) | E_vector exps -> let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let (start, len, order, etyp) = @@ -804,7 +798,7 @@ let doc_exp_lem, doc_let_lem = let (epp,aexp_needed) = if is_bit_typ etyp && !opt_mwords then let bepp = string "vec_to_bvec" ^^ space ^^ parens (align epp) in - (bepp ^^ doc_tannot_lem false t, true) + (bepp ^^ doc_tannot_lem ctxt false t, true) else (epp,aexp_needed) in if aexp_needed then parens (align epp) else epp | E_vector_update(v,e1,e2) -> @@ -835,15 +829,15 @@ let doc_exp_lem, doc_let_lem = let epp = group ((separate space [string "match"; only_integers e; string "with"]) ^/^ - (separate_map (break 1) (doc_case early_ret) pexps) ^/^ + (separate_map (break 1) (doc_case ctxt) pexps) ^/^ (string "end")) in if aexp_needed then parens (align epp) else align epp | E_try (e, pexps) -> if effectful (effect_of e) then - let try_catch = if early_ret then "try_catchR" else "try_catch" in + let try_catch = if ctxt.early_ret then "try_catchR" else "try_catch" in let epp = group ((separate space [string try_catch; expY e; string "(function "]) ^/^ - (separate_map (break 1) (doc_case early_ret) pexps) ^/^ + (separate_map (break 1) (doc_case ctxt) pexps) ^/^ (string "end)")) in if aexp_needed then parens (align epp) else align epp else @@ -868,20 +862,20 @@ let doc_exp_lem, doc_let_lem = (separate space [expV b e1; string ">>"]) ^^ hardline ^^ expN e2 | _ -> (separate space [expV b e1; string ">>= fun"; - doc_pat_lem true pat;arrow]) ^^ hardline ^^ expN e2 in + doc_pat_lem ctxt true pat;arrow]) ^^ hardline ^^ expN e2 in if aexp_needed then parens (align epp) else epp | E_internal_return (e1) -> separate space [string "return"; expY e1] | E_sizeof nexp -> (match nexp_simp nexp with - | Nexp_aux (Nexp_constant i, _) -> doc_lit_lem false (L_aux (L_num i, l)) annot + | Nexp_aux (Nexp_constant i, _) -> doc_lit_lem (L_aux (L_num i, l)) | _ -> raise (Reporting_basic.err_unreachable l "pretty-printing non-constant sizeof expressions to Lem not supported")) | E_return r -> let ret_monad = if !opt_sequential then " : MR regstate" else " : MR" in let ta = - if contains_t_pp_var (typ_of full_exp) || contains_t_pp_var (typ_of r) + if contains_t_pp_var ctxt (typ_of full_exp) || contains_t_pp_var ctxt (typ_of r) then empty else separate space [string ret_monad; @@ -893,33 +887,33 @@ let doc_exp_lem, doc_let_lem = | E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _ | E_internal_exp_user _ -> raise (Reporting_basic.err_unreachable l "unsupported internal expression encountered while pretty-printing") - and let_exp early_ret (LB_aux(lb,_)) = match lb with + and let_exp ctxt (LB_aux(lb,_)) = match lb with | LB_val(pat,e) -> prefix 2 1 - (separate space [string "let"; doc_pat_lem true pat; equals]) - (top_exp early_ret false e) + (separate space [string "let"; doc_pat_lem ctxt true pat; equals]) + (top_exp ctxt false e) - and doc_fexp early_ret recordtyp (FE_aux(FE_Fexp(id,e),_)) = + and doc_fexp ctxt recordtyp (FE_aux(FE_Fexp(id,e),_)) = let fname = if prefix_recordtype then (string (string_of_id recordtyp ^ "_")) ^^ doc_id_lem id else doc_id_lem id in - group (doc_op equals fname (top_exp early_ret true e)) + group (doc_op equals fname (top_exp ctxt true e)) - and doc_case early_ret = function + and doc_case ctxt = function | Pat_aux(Pat_exp(pat,e),_) -> - group (prefix 3 1 (separate space [pipe; doc_pat_lem false pat;arrow]) - (group (top_exp early_ret false e))) + group (prefix 3 1 (separate space [pipe; doc_pat_lem ctxt false pat;arrow]) + (group (top_exp ctxt false e))) | Pat_aux(Pat_when(_,_,_),(l,_)) -> raise (Reporting_basic.err_unreachable l "guarded pattern expression should have been rewritten before pretty-printing") - and doc_lexp_deref_lem early_ret ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with + and doc_lexp_deref_lem ctxt ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with | LEXP_field (le,id) -> - parens (separate empty [doc_lexp_deref_lem early_ret le;dot;doc_id_lem id]) + parens (separate empty [doc_lexp_deref_lem ctxt le;dot;doc_id_lem id]) | LEXP_id id -> doc_id_lem id | LEXP_cast (typ,id) -> doc_id_lem id - | LEXP_tup lexps -> parens (separate_map comma_sp (doc_lexp_deref_lem early_ret) lexps) + | LEXP_tup lexps -> parens (separate_map comma_sp (doc_lexp_deref_lem ctxt) lexps) | _ -> raise (Reporting_basic.err_unreachable l ("doc_lexp_deref_lem: Unsupported lexp")) (* expose doc_exp_lem and doc_let *) @@ -963,7 +957,7 @@ let doc_typdef_lem (TD_aux(td, (l, annot))) = match td with mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown), [mk_typ_arg (Typ_arg_typ rectyp); mk_typ_arg (Typ_arg_typ ftyp)])) in - let rfannot = doc_tannot_lem false reftyp in + let rfannot = doc_tannot_lem empty_ctxt false reftyp in let get, set = string "rec_val" ^^ dot ^^ fname fid, anglebars (space ^^ string "rec_val with " ^^ @@ -1215,17 +1209,20 @@ let doc_rec_lem (Rec_aux(r,_)) = match r with let doc_tannot_opt_lem (Typ_annot_opt_aux(t,_)) = match t with | Typ_annot_opt_some(tq,typ) -> (*doc_typquant_lem tq*) (doc_typ_lem typ) -let doc_fun_body_lem exp = - let early_ret =contains_early_return exp in - let doc_exp = doc_exp_lem early_ret false exp in - if early_ret +let doc_fun_body_lem ctxt exp = + let doc_exp = doc_exp_lem ctxt false exp in + if ctxt.early_ret then align (string "catch_early_return" ^//^ parens (doc_exp)) else doc_exp -let doc_funcl_lem (FCL_aux(FCL_Funcl(id,pexp),_)) = +let doc_funcl_lem (FCL_aux(FCL_Funcl(id, pexp), annot)) = + let typ = typ_of_annot annot in let pat,guard,exp,(l,_) = destruct_pexp pexp in + let ctxt = + { early_ret = contains_early_return exp; + bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ) } in let pats, bind = untuple_args_pat pat in - let patspp = separate_map space (doc_pat_lem true) pats in + let patspp = separate_map space (doc_pat_lem ctxt true) pats in let _ = match guard with | None -> () | _ -> @@ -1233,7 +1230,7 @@ let doc_funcl_lem (FCL_aux(FCL_Funcl(id,pexp),_)) = "guarded pattern expression should have been rewritten before pretty-printing") in group (prefix 3 1 (separate space [doc_id_lem id; patspp; equals]) - (doc_fun_body_lem (bind exp))) + (doc_fun_body_lem ctxt (bind exp))) let get_id = function | [] -> failwith "FD_function with empty list" @@ -1247,8 +1244,8 @@ let doc_fundef_rhs_lem (FD_aux(FD_function(r, typa, efa, funcls),fannot) as fd) let doc_mutrec_lem = function | [] -> failwith "DEF_internal_mutrec with empty function list" | fundefs -> - string "let rec " ^^ - separate_map (hardline ^^ string "and ") doc_fundef_rhs_lem fundefs + string "let rec " ^^ + separate_map (hardline ^^ string "and ") doc_fundef_rhs_lem fundefs let rec doc_fundef_lem (FD_aux(FD_function(r, typa, efa, fcls),fannot) as fd) = match fcls with @@ -1301,13 +1298,13 @@ let rec doc_fundef_lem (FD_aux(FD_function(r, typa, efa, fcls),fannot) as fd) = let named_args = if argspat = [] then [unit_pat] else named_argspat in let doc_arg idx (P_aux (p,(l,a))) = match p with | P_as (pat,id) -> doc_id_lem id - | P_lit lit -> doc_lit_lem false lit a + | P_lit lit -> doc_lit_lem lit | P_id id -> doc_id_lem id | _ -> string ("arg" ^ string_of_int idx) in let clauses = clauses ^^ (break 1) ^^ (separate space - [pipe;doc_pat_lem false named_pat;arrow; + [pipe;doc_pat_lem empty_ctxt false named_pat;arrow; string aux_fname; separate space (List.mapi doc_arg named_args)]) in (already_used_fnames,auxiliary_functions,clauses) @@ -1389,7 +1386,7 @@ let doc_regtype_fields (tname, (n1, n2, fields)) = mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown), [mk_typ_arg (Typ_arg_typ (mk_id_typ (mk_id tname))); mk_typ_arg (Typ_arg_typ ftyp)])) in - let rfannot = doc_tannot_lem false reftyp in + let rfannot = doc_tannot_lem empty_ctxt false reftyp in doc_op equals (concat [string "let "; parens (concat [string tname; underscore; doc_id_lem fid; rfannot])]) (concat [ @@ -1416,7 +1413,7 @@ let rec doc_def_lem regtypes def = if is_field_accessor regtypes fdef then (doc_fdef, empty) else (empty, doc_fdef) | DEF_internal_mutrec fundefs -> (empty, doc_mutrec_lem fundefs ^/^ hardline) - | DEF_val lbind -> (empty,group (doc_let_lem false lbind) ^/^ hardline) + | DEF_val lbind -> (empty,group (doc_let_lem empty_ctxt lbind) ^/^ hardline) | DEF_scattered sdef -> failwith "doc_def_lem: shoulnd't have DEF_scattered at this point" | DEF_kind _ -> (empty,empty) @@ -1481,7 +1478,7 @@ let doc_regstate_lem registers = E_record (FES_aux (FES_Fexps (List.map initreg registers, false), annot)), (l, Some (Env.empty, mk_id_typ (mk_id "regstate"), no_effect))) in - doc_op equals (string "let initial_regstate") (doc_exp_lem false false exp) + doc_op equals (string "let initial_regstate") (doc_exp_lem empty_ctxt false exp) else empty in doc_typdef_lem (TD_aux (regstate, annot)), diff --git a/src/rewrites.ml b/src/rewrites.ml index e4ac71cd..40772828 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1087,7 +1087,7 @@ let remove_bitvector_pat (P_aux (_, (l, _)) as pat) = ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) ; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false)) } in - let pat, env = bind_pat env + let pat, env = bind_pat_no_guard env (strip_pat ((fold_pat name_bitvector_roots pat) false)) (pat_typ_of pat) in @@ -1624,6 +1624,10 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base rewriting of early returns *) let rewrite_defs_early_return (Defs defs) = + let is_unit (E_aux (exp, _)) = match exp with + | E_lit (L_aux (L_unit, _)) -> true + | _ -> false in + let is_return (E_aux (exp, _)) = match exp with | E_return _ -> true | _ -> false in @@ -1632,7 +1636,35 @@ let rewrite_defs_early_return (Defs defs) = | E_return e -> e | _ -> exp in - let e_block es = + let e_if (e1, e2, e3) = + if is_return e2 && is_return e3 then + let (E_aux (_, annot)) = get_return e2 in + E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot)) + else E_if (e1, e2, e3) in + + let rec e_block es = + (* If one of the branches of an if-expression in a block is an early + return, fold the rest of the block after the if-expression into the + other branch *) + let fold_if_return exp block = match exp with + | E_aux (E_if (c, t, (E_aux (_, annot) as e)), _) when is_return t -> + let annot = match block with + | [] -> annot + | _ -> let (E_aux (_, annot)) = Util.last block in annot + in + let block = if is_unit e then block else e :: block in + let e' = E_aux (e_block block, annot) in + [E_aux (e_if (c, t, e'), annot)] + | E_aux (E_if (c, (E_aux (_, annot) as t), e), _) when is_return e -> + let annot = match block with + | [] -> annot + | _ -> let (E_aux (_, annot)) = Util.last block in annot + in + let block = if is_unit t then block else t :: block in + let t' = E_aux (e_block block, annot) in + [E_aux (e_if (c, t', e), annot)] + | _ -> exp :: block in + let es = List.fold_right fold_if_return es [] in match es with | [E_aux (e, _)] -> e | _ :: _ when is_return (Util.last es) -> @@ -1640,12 +1672,6 @@ let rewrite_defs_early_return (Defs defs) = E_return (E_aux (E_block (Util.butlast es @ [get_return e]), annot)) | _ -> E_block es in - let e_if (e1, e2, e3) = - if is_return e2 && is_return e3 then - let (E_aux (_, annot)) = get_return e2 in - E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot)) - else E_if (e1, e2, e3) in - let e_case (e, pes) = let is_return_pexp (Pat_aux (pexp, _)) = match pexp with | Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in @@ -1660,6 +1686,17 @@ let rewrite_defs_early_return (Defs defs) = then E_return (E_aux (E_case (e, List.map get_return_pexp pes), annot)) else E_case (e, pes) in + let e_let (lb, exp) = + let (E_aux (_, annot) as ret_exp) = get_return exp in + if is_return exp then E_return (E_aux (E_let (lb, ret_exp), annot)) + else E_let (lb, exp) in + + let e_internal_let (lexp, exp1, exp2) = + let (E_aux (_, annot) as ret_exp2) = get_return exp2 in + if is_return exp2 then + E_return (E_aux (E_var (lexp, exp1, ret_exp2), annot)) + else E_var (lexp, exp1, exp2) in + let e_aux (exp, (l, annot)) = let full_exp = propagate_exp_effect (E_aux (exp, (l, annot))) in let env = env_of full_exp in @@ -1674,14 +1711,18 @@ let rewrite_defs_early_return (Defs defs) = let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pexp), a)) = let pat,guard,exp,pannot = destruct_pexp pexp in + (* Try to pull out early returns as far as possible *) + let exp' = + fold_exp + { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case; + e_let = e_let; e_internal_let = e_internal_let } + exp in + (* Remove early return if we can pull it out completely, and rewrite + remaining early returns to "early_return" calls *) let exp = - exp - (* Pull early returns out as far as possible *) - |> fold_exp { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case } - (* Remove singleton E_return *) - |> get_return - (* Fix effect annotations *) - |> fold_exp { id_exp_alg with e_aux = e_aux } in + fold_exp + { id_exp_alg with e_aux = e_aux } + (if is_return exp' then get_return exp' else exp) in let a = match a with | (l, Some (env, typ, eff)) -> (l, Some (env, typ, union_effects eff (effect_of exp))) @@ -2141,7 +2182,7 @@ let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list | [] -> k [] | exp :: exps -> f exp (fun exp -> mapCont f exps (fun exps -> k (exp :: exps))) -let rewrite_defs_letbind_effects = +let rewrite_defs_letbind_effects = let rec value ((E_aux (exp_aux,_)) as exp) = not (effectful exp || updates_vars exp) @@ -2235,6 +2276,7 @@ let rewrite_defs_letbind_effects = let exp = if newreturn then (* let typ = try typ_of exp with _ -> unit_typ in *) + let exp = annot_exp (E_cast (typ_of exp, exp)) l (env_of exp) (typ_of exp) in annot_exp (E_internal_return exp) l (env_of exp) (typ_of exp) else exp in diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml index 23ce6663..371acfdc 100644 --- a/src/spec_analysis.ml +++ b/src/spec_analysis.ml @@ -64,133 +64,6 @@ let set_to_string n = list_to_string (Nameset.elements n) -(*Query a spec for its default order if one is provided. Assumes Inc if not *) -(* let get_default_order_sp (DT_aux(spec,_)) = - match spec with - | DT_order (Ord_aux(o,_)) -> - (match o with - | Ord_inc -> Some {order = Oinc} - | Ord_dec -> Some { order = Odec} - | _ -> Some {order = Oinc}) - | _ -> None - -let get_default_order_def = function - | DEF_default def_spec -> get_default_order_sp def_spec - | _ -> None - -let rec default_order (Defs defs) = - match defs with - | [] -> { order = Oinc } (*When no order is specified, we assume that it's inc*) - | def::defs -> - match get_default_order_def def with - | None -> default_order (Defs defs) - | Some o -> o *) - -(*Is within range*) - -(* let check_in_range (candidate : big_int) (range : typ) : bool = - match range.t with - | Tapp("range", [TA_nexp min; TA_nexp max]) | Tabbrev(_,{t=Tapp("range", [TA_nexp min; TA_nexp max])}) -> - let min,max = - match min.nexp,max.nexp with - | (Nconst min, Nconst max) - | (Nconst min, N2n(_, Some max)) - | (N2n(_, Some min), Nconst max) - | (N2n(_, Some min), N2n(_, Some max)) - -> min, max - | (Nneg n, Nconst max) | (Nneg n, N2n(_, Some max))-> - (match n.nexp with - | Nconst abs_min | N2n(_,Some abs_min) -> - (Big_int.negate abs_min), max - | _ -> assert false (*Put a better error message here*)) - | (Nconst min,Nneg n) | (N2n(_, Some min), Nneg n) -> - (match n.nexp with - | Nconst abs_max | N2n(_,Some abs_max) -> - min, (Big_int.negate abs_max) - | _ -> assert false (*Put a better error message here*)) - | (Nneg nmin, Nneg nmax) -> - ((match nmin.nexp with - | Nconst abs_min | N2n(_,Some abs_min) -> (Big_int.negate abs_min) - | _ -> assert false (*Put a better error message here*)), - (match nmax.nexp with - | Nconst abs_max | N2n(_,Some abs_max) -> (Big_int.negate abs_max) - | _ -> assert false (*Put a better error message here*))) - | _ -> assert false - in Big_int.less_equal min candidate && Big_int.less_equal candidate max - | _ -> assert false - -(*Rmove me when switch to zarith*) -let rec power_big_int b n = - if Big_int.equal n Big_int.zero - then (Big_int.of_int 1) - else Big_int.mul b (power_big_int b (Big_int.sub n (Big_int.of_int 1))) - -let unpower_of_2 b = - let two = Big_int.of_int 2 in - let four = Big_int.of_int 4 in - let eight = Big_int.of_int 8 in - let sixteen = Big_int.of_int 16 in - let thirty_two = Big_int.of_int 32 in - let sixty_four = Big_int.of_int 64 in - let onetwentyeight = Big_int.of_int 128 in - let twofiftysix = Big_int.of_int 256 in - let fivetwelve = Big_int.of_int 512 in - let oneotwentyfour = Big_int.of_int 1024 in - let to_the_sixteen = Big_int.of_int 65536 in - let to_the_thirtytwo = Big_int.of_string "4294967296" in - let to_the_sixtyfour = Big_int.of_string "18446744073709551616" in - let ck i = Big_int.equal b i in - if ck (Big_int.of_int 1) then Big_int.zero - else if ck two then (Big_int.of_int 1) - else if ck four then two - else if ck eight then Big_int.of_int 3 - else if ck sixteen then four - else if ck thirty_two then Big_int.of_int 5 - else if ck sixty_four then Big_int.of_int 6 - else if ck onetwentyeight then Big_int.of_int 7 - else if ck twofiftysix then eight - else if ck fivetwelve then Big_int.of_int 9 - else if ck oneotwentyfour then Big_int.of_int 10 - else if ck to_the_sixteen then sixteen - else if ck to_the_thirtytwo then thirty_two - else if ck to_the_sixtyfour then sixty_four - else let rec unpower b power = - if Big_int.equal b (Big_int.of_int 1) - then power - else (unpower (Big_int.div b two) (Big_int.succ power)) in - unpower b Big_int.zero - -let is_within_range candidate range constraints = - let candidate_actual = match candidate.t with - | Tabbrev(_,t) -> t - | _ -> candidate in - match candidate_actual.t with - | Tapp("atom", [TA_nexp n]) -> - (match n.nexp with - | Nconst i | N2n(_,Some i) -> if check_in_range i range then Yes else No - | _ -> Maybe) - | Tapp("range", [TA_nexp bot; TA_nexp top]) -> - (match bot.nexp,top.nexp with - | Nconst b, Nconst t | Nconst b, N2n(_,Some t) | N2n(_, Some b), Nconst t | N2n(_,Some b), N2n(_, Some t) -> - let at_least_in = check_in_range b range in - let at_most_in = check_in_range t range in - if at_least_in && at_most_in - then Yes - else if at_least_in || at_most_in - then Maybe - else No - | _ -> Maybe) - | Tapp("vector", [_; TA_nexp size ; _; _]) -> - (match size.nexp with - | Nconst i | N2n(_, Some i) -> - if check_in_range (power_big_int (Big_int.of_int 2) i) range - then Yes - else No - | _ -> Maybe) - | _ -> Maybe - -let is_within_machine64 candidate constraints = is_within_range candidate int64_t constraints *) - (************************************************************************************************) (*FV finding analysis: identifies the free variables of a function, expression, etc *) @@ -313,9 +186,11 @@ let rec fv_of_exp consider_var bound used set (E_aux (e,(_,tannot))) : (Nameset. fv_of_exp consider_var bound u set e | E_app(id,es) -> let us = conditional_add_exp bound used id in + let us = conditional_add_exp bound us (prepend_id "val:" id) in list_fv bound us set es | E_app_infix(l,id,r) -> let us = conditional_add_exp bound used id in + let us = conditional_add_exp bound us (prepend_id "val:" id) in list_fv bound us set [l;r] | E_if(c,t,e) -> list_fv bound used set [c;t;e] | E_for(id,from,to_,by,_,body) -> @@ -464,8 +339,12 @@ let fv_of_fun consider_var (FD_aux (FD_function(rec_opt,tannot_opt,_,funcls),_) | [] -> failwith "fv_of_fun fell off the end looking for the function name" | FCL_aux(FCL_Funcl(id,_),_)::_ -> string_of_id id in let base_bounds = match rec_opt with + (* Current Sail does not have syntax for declaring functions as recursive, + and type checker does not check whether functions are recursive, so + just always add a self-dependency of functions on themselves | Rec_aux(Ast.Rec_rec,_) -> init_env fun_name - | _ -> mt in + | _ -> mt*) + | _ -> init_env fun_name in let base_bounds,ns_r = match tannot_opt with | Typ_annot_opt_aux(Typ_annot_opt_some (typq, typ),_) -> let bindings = if consider_var then typq_bindings typq else mt in @@ -577,7 +456,9 @@ let fv_of_def consider_var consider_scatter_as_one all_defs = function | DEF_val lebind -> ((fun (b,u,_) -> (b,u)) (fv_of_let consider_var mt mt mt lebind)) | DEF_spec vspec -> fv_of_vspec consider_var vspec | DEF_fixity _ -> mt,mt - | DEF_overload (id,ids) -> init_env (string_of_id id), List.fold_left (fun ns id -> Nameset.add (string_of_id id) ns) mt ids + | DEF_overload (id,ids) -> + init_env (string_of_id id), + List.fold_left (fun ns id -> Nameset.add ("val:" ^ string_of_id id) ns) mt ids | DEF_default def -> mt,mt | DEF_internal_mutrec fdefs -> let fvs = List.map (fv_of_fun consider_var) fdefs in @@ -590,129 +471,133 @@ let fv_of_def consider_var consider_scatter_as_one all_defs = function let group_defs consider_scatter_as_one (Ast.Defs defs) = List.map (fun d -> (fv_of_def false consider_scatter_as_one defs d,d)) defs -(******************************************************************************* - * Reorder defs take 2 -*) -(*remove all of ns1 instances from ns2*) -let remove_all ns1 ns2 = - List.fold_right Nameset.remove (Nameset.elements ns1) ns2 - -let remove_from_all_uses bs dbts = - List.map (fun ((b,uses),d) -> (b,remove_all bs uses),d) dbts - -let remove_local_or_lib_vars dbts = - let bound_in_dbts = List.fold_right (fun ((b,_),_) bounds -> Nameset.union b bounds) dbts mt in - let is_bound_in_defs s = Nameset.mem s bound_in_dbts in - let rec remove_from_uses = function - | [] -> [] - | ((b,uses),d)::defs -> - ((b,(Nameset.filter is_bound_in_defs uses)),d)::remove_from_uses defs in - remove_from_uses dbts +(* + * Sorting definitions, take 3 + *) + +module Namemap = Map.Make(String) +(* Nodes are labeled with strings. A graph is represented as a map associating + each node with its sucessors *) +type graph = Nameset.t Namemap.t +type node_idx = { index : int; root : int } + +(* Find strongly connected components using Tarjan's algorithm. + This algorithm also returns a topological sorting of the graph components. *) +let scc ?(original_order : string list option) (g : graph) = + let components = ref [] in + let index = ref 0 in + + let stack = ref [] in + let push v = (stack := v :: !stack) in + let pop () = + begin + let v = List.hd !stack in + stack := List.tl !stack; + v + end + in -let compare_dbts ((_,u1),_) ((_,u2),_) = Pervasives.compare (Nameset.cardinal u1) (Nameset.cardinal u2) + let node_indices = Hashtbl.create (Namemap.cardinal g) in + let get_index v = (Hashtbl.find node_indices v).index in + let get_root v = (Hashtbl.find node_indices v).root in + let set_root v r = + Hashtbl.replace node_indices v { (Hashtbl.find node_indices v) with root = r } in + + let rec visit_node v = + begin + Hashtbl.add node_indices v { index = !index; root = !index }; + index := !index + 1; + push v; + if Namemap.mem v g then Nameset.iter (visit_edge v) (Namemap.find v g); + if get_root v = get_index v then (* v is the root of a SCC *) + begin + let component = ref [] in + let finished = ref false in + while not !finished do + let w = pop () in + component := w :: !component; + if String.compare v w = 0 then finished := true; + done; + components := !component :: !components; + end + end + and visit_edge v w = + if not (Hashtbl.mem node_indices w) then + begin + visit_node w; + if Hashtbl.mem node_indices w then set_root v (min (get_root v) (get_root w)); + end else begin + if List.mem w !stack then set_root v (min (get_root v) (get_index w)) + end + in -let rec print_dependencies orig_queue work_queue names = - match work_queue with - | [] -> () - | ((binds,uses),_)::wq -> - (if not(Nameset.is_empty(Nameset.inter names binds)) - then ((Printf.eprintf "binds of %s has uses of %s\n" (set_to_string binds) (set_to_string uses)); - print_dependencies orig_queue orig_queue uses)); - print_dependencies orig_queue wq names - -let merge_mutrecs defs = - let merge_aux ((binds', uses'), def) ((binds, uses), fundefs) = - let fundefs = match def with - | DEF_fundef fundef -> fundef :: fundefs - | DEF_internal_mutrec fundefs' -> fundefs' @ fundefs - | _ -> - (* let _ = Pretty_print_sail.pp_defs stderr (Defs [def]) in *) - raise (Reporting_basic.err_unreachable (def_loc def) - "Trying to merge non-function definition with mutually recursive functions") in - (* let _ = Printf.eprintf " - Merging %s (using %s)\n" (set_to_string binds') (set_to_string uses') in *) - ((Nameset.union binds' binds, Nameset.union uses' uses), fundefs) in - let ((binds, uses), fundefs) = List.fold_right merge_aux defs ((mt, mt), []) in - ((binds, uses), DEF_internal_mutrec fundefs) - -let rec topological_sort work_queue defs = - match work_queue with - | [] -> List.rev defs - | ((binds,uses),def)::wq -> - (*Assumes work queue given in sorted order, invariant mantained on appropriate recursive calls*) - if (Nameset.cardinal uses = 0) - then (*let _ = Printf.eprintf "Adding def that binds %s to definitions\n" (set_to_string binds) in*) - topological_sort (remove_from_all_uses binds wq) (def::defs) - else if not(Nameset.is_empty(Nameset.inter binds uses)) - then topological_sort (((binds,(remove_all binds uses)),def)::wq) defs - else - match List.stable_sort compare_dbts work_queue with (*We wait to sort until there are no 0 dependency nodes on top*) - | [] -> failwith "sort shrunk the list???" - | (((n,uses),def)::rest) as wq -> - if (Nameset.cardinal uses = 0) - then topological_sort wq defs - else - let _ = Printf.eprintf "Merging (potentially) mutually recursive definitions %s and %s\n" (set_to_string n) (set_to_string uses) in - let is_used ((binds', uses'), def') = not(Nameset.is_empty(Nameset.inter binds' uses)) in - let (used, rest) = List.partition is_used rest in - let wq = merge_mutrecs (((n,uses),def)::used) :: rest in - topological_sort wq defs - -let rec add_to_partial_order ((binds,uses),def) = function - | [] -> -(* let _ = Printf.eprintf "add_to_partial_order for def with bindings %s, uses %s.\n Eol case.\n" (set_to_string binds) (set_to_string uses) in*) - [(binds,uses),def] - | (((bf,uf),deff)::defs as full_defs) -> - (*let _ = Printf.eprintf "add_to_partial_order for def with bindings %s, uses %s.\n None eol case. With first def binding %s, uses %s\n" (set_to_string binds) (set_to_string uses) (set_to_string bf) (set_to_string uf) in*) - if Nameset.is_empty uses - then ((binds,uses),def)::full_defs - else if Nameset.subset binds uf (*deff relies on def, so def must be defined first*) - then ((binds,uses),def)::((bf,(remove_all binds uf)),deff)::defs - else if Nameset.subset bf uses (*def relies at least on deff, but maybe more, push in*) - then ((bf,uf),deff)::(add_to_partial_order ((binds,(remove_all bf uses)),def) defs) - else (*These two are unrelated but new def might need to go further in*) - ((bf,uf),deff)::(add_to_partial_order ((binds,uses),def) defs) - -let rec gather_defs name already_included def_bind_triples = - match def_bind_triples with - | [] -> [],already_included,mt - | ((binds,uses),def)::def_bind_triples -> - let (defs,already_included,requires) = gather_defs name already_included def_bind_triples in - let bound_names = Nameset.elements binds in - if List.mem name already_included || List.exists (fun b -> List.mem b already_included) bound_names - then (defs,already_included,requires) - else - let uses = List.fold_right Nameset.remove already_included uses in - if Nameset.mem name binds - then (def::defs,(bound_names@already_included), Nameset.remove name (Nameset.union uses requires)) - else (defs,already_included,requires) - -let rec gather_all names already_included def_bind_triples = - let rec gather ns already_included defs reqs = match ns with - | [] -> defs,already_included,reqs - | name::ns -> - if List.mem name already_included - then gather ns already_included defs (Nameset.remove name reqs) - else - let (new_defs,already_included,new_reqs) = gather_defs name already_included def_bind_triples in - gather ns already_included (new_defs@defs) (Nameset.remove name (Nameset.union new_reqs reqs)) + let nodes = match original_order with + | Some nodes -> nodes + | None -> List.map fst (Namemap.bindings g) in - let (defs,already_included,reqs) = gather names already_included [] mt in - if Nameset.is_empty reqs - then defs - else (gather_all (Nameset.elements reqs) already_included def_bind_triples)@defs - -let restrict_defs defs name_list = - let defsno = gather_all name_list [] (group_defs false defs) in - let rdbts = group_defs true (Defs defsno) in - (*let partial_order = - List.fold_left (fun po d -> add_to_partial_order d po) [] rdbts in - let defs = List.map snd partial_order in*) - let defs = topological_sort (List.sort compare_dbts (remove_local_or_lib_vars rdbts)) [] in - Defs defs - - -let top_sort_defs defs = - let rdbts = group_defs true defs in - let defs = topological_sort (List.stable_sort compare_dbts (remove_local_or_lib_vars rdbts)) [] in - Defs defs + List.iter (fun v -> if not (Hashtbl.mem node_indices v) then visit_node v) nodes; + List.rev !components + +let add_def_to_graph (prelude, original_order, defset, graph) d = + let bound, used = fv_of_def false true [] d in + try + (* A definition may bind multiple identifiers, e.g. "let (x, y) = ...". + We add all identifiers to the dependency graph as a cycle. The actual + definition is attached to only one of the identifiers, so it will not + be duplicated in the final output. *) + let id = Nameset.choose bound in + let other_ids = Nameset.remove id bound in + let graph_id = Namemap.add id (Nameset.union used other_ids) graph in + let add_other_node id' g = Namemap.add id' (Nameset.singleton id) g in + prelude, + original_order @ [id], + Namemap.add id d defset, + Nameset.fold add_other_node other_ids graph_id + with + | Not_found -> + (* Some definitions do not bind any identifiers at all. This *should* + only happen for default bitvector order declarations, operator fixity + declarations, and comments. The sorting does not (currently) attempt + to preserve the positions of these AST nodes; they are collected + separately and placed at the beginning of the output. Comments are + currently ignored by the Lem and OCaml backends, anyway. For + default order and fixity declarations, this means that specifications + currently have to assume those declarations are moved to the + beginning when using a backend that requires topological sorting. *) + prelude @ [d], original_order, defset, graph + +let print_dot graph component : unit = + match component with + | root :: _ -> + print_endline ("// Dependency cycle including " ^ root); + print_endline ("digraph cycle_" ^ root ^ " {"); + List.iter (fun caller -> + let print_edge callee = print_endline (" \"" ^ caller ^ "\" -> \"" ^ callee ^ "\";") in + Namemap.find caller graph + |> Nameset.filter (fun id -> List.mem id component) + |> Nameset.iter print_edge) component; + print_endline "}" + | [] -> () + +let def_of_component graph defset comp = + let get_def id = if Namemap.mem id defset then [Namemap.find id defset] else [] in + match List.concat (List.map get_def comp) with + | [] -> [] + | [def] -> [def] + | (def :: _) as defs -> + let get_fundefs = function + | DEF_fundef fundef -> [fundef] + | DEF_internal_mutrec fundefs -> fundefs + | _ -> + raise (Reporting_basic.err_unreachable (def_loc def) + "Trying to merge non-function definition with mutually recursive functions") in + let fundefs = List.concat (List.map get_fundefs defs) in + print_dot graph (List.map (fun fd -> string_of_id (id_of_fundef fd)) fundefs); + [DEF_internal_mutrec fundefs] + +let top_sort_defs (Defs defs) = + let prelude, original_order, defset, graph = + List.fold_left add_def_to_graph ([], [], Namemap.empty, Namemap.empty) defs in + let components = scc ~original_order:original_order graph in + Defs (prelude @ List.concat (List.map (def_of_component graph defset) components)) diff --git a/src/type_check.ml b/src/type_check.ml index 30334783..04c16ad5 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -2029,16 +2029,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ annot_exp (E_case (inferred_exp, List.map (fun case -> check_case env inferred_typ case typ) cases)) typ | E_try (exp, cases), _ -> let checked_exp = crule check_exp env exp typ in - let check_case pat typ = match pat with - | Pat_aux (Pat_exp (pat, case), (l, _)) -> - let tpat, env = bind_pat env pat exc_typ in - Pat_aux (Pat_exp (tpat, crule check_exp env case typ), (l, None)) - | Pat_aux (Pat_when (pat, guard, case), (l, _)) -> - let tpat, env = bind_pat env pat exc_typ in - let checked_guard = check_exp env guard bool_typ in - Pat_aux (Pat_when (tpat, checked_guard, crule check_exp env case typ), (l, None)) - in - annot_exp (E_try (checked_exp, List.map (fun case -> check_case case typ) cases)) typ + annot_exp (E_try (checked_exp, List.map (fun case -> check_case env exc_typ case typ) cases)) typ | E_cons (x, xs), _ -> begin match is_list (Env.expand_synonyms env typ) with @@ -2093,11 +2084,11 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ | LB_val (P_aux (P_typ (ptyp, _), _) as pat, bind) -> Env.wf_typ env ptyp; let checked_bind = crule check_exp env bind ptyp in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in annot_exp (E_let (LB_aux (LB_val (tpat, checked_bind), (let_loc, None)), crule check_exp env exp typ)) typ | LB_val (pat, bind) -> let inferred_bind = irule infer_exp env bind in - let tpat, env = bind_pat env pat (typ_of inferred_bind) in + let tpat, env = bind_pat_no_guard env pat (typ_of inferred_bind) in annot_exp (E_let (LB_aux (LB_val (tpat, inferred_bind), (let_loc, None)), crule check_exp env exp typ)) typ end | E_app_infix (x, op, y), _ -> @@ -2159,7 +2150,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ | _ -> let inferred_bind = irule infer_exp env bind in inferred_bind, typ_of inferred_bind in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in (* Propagate constraint assertions on the lhs of monadic binds to the rhs *) let env = match bind_exp with | E_aux (E_assert (E_aux (E_constraint nc, _), _), _) -> @@ -2192,7 +2183,17 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ and check_case env pat_typ pexp typ = let pat,guard,case,((l,_) as annot) = destruct_pexp pexp in match bind_pat env pat pat_typ with - | tpat, env -> + | tpat, env, guards -> + let guard = match guard, guards with + | None, h::t -> Some (h,t) + | Some x, l -> Some (x,l) + | None, [] -> None + in + let guard = match guard with + | Some (h,t) -> + Some (List.fold_left (fun acc guard -> mk_exp (E_app_infix (acc, mk_id "&", guard))) h t) + | None -> None + in let checked_guard, env' = match guard with | None -> None, env | Some guard -> @@ -2202,6 +2203,7 @@ and check_case env pat_typ pexp typ = in let checked_case = crule check_exp env' case typ in construct_pexp (tpat, checked_guard, checked_case, (l, None)) + (* AA: Not sure if we still need this *) | exception (Type_error _ as typ_exn) -> match pat with | P_aux (P_lit lit, _) -> @@ -2283,29 +2285,34 @@ and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ = try_casts casts end +and bind_pat_no_guard env (P_aux (_,(l,_)) as pat) typ = + match bind_pat env pat typ with + | _, _, _::_ -> typ_error l "Literal patterns not supported here" + | tpat, env, [] -> tpat, env + and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) = typ_print ("Binding " ^ string_of_pat pat ^ " to " ^ string_of_typ typ); let annot_pat pat typ = P_aux (pat, (l, Some (env, typ, no_effect))) in let switch_typ pat typ = match pat with - | (P_aux (pat_aux, (l, Some (env, _, eff)))) -> P_aux (pat_aux, (l, Some (env, typ, eff))) - | _ -> failwith "Cannot switch type of unannotated pattern" + | P_aux (pat_aux, (l, Some (env, _, eff))) -> P_aux (pat_aux, (l, Some (env, typ, eff))) + | _ -> typ_error l "Cannot switch type for unannotated pattern" in - let bind_tuple_pat (tpats, env) pat typ = - let tpat, env = bind_pat env pat typ in tpat :: tpats, env + let bind_tuple_pat (tpats, env, guards) pat typ = + let tpat, env, guards' = bind_pat env pat typ in tpat :: tpats, env, guards' @ guards in match pat_aux with | P_id v -> begin match Env.lookup_id v env with - | Local (Immutable, _) | Unbound -> annot_pat (P_id v) typ, Env.add_local v (Immutable, typ) env + | Local (Immutable, _) | Unbound -> annot_pat (P_id v) typ, Env.add_local v (Immutable, typ) env, [] | Local (Mutable, _) | Register _ -> typ_error l ("Cannot shadow mutable local or register in switch statement pattern " ^ string_of_pat pat) - | Enum enum -> subtyp l env enum typ; annot_pat (P_id v) typ, env + | Enum enum -> subtyp l env enum typ; annot_pat (P_id v) typ, env, [] | Union (typq, ctor_typ) -> begin try let _ = unify l env ctor_typ typ in - annot_pat (P_id v) typ, env + annot_pat (P_id v) typ, env, [] with | Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m) end @@ -2318,34 +2325,34 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) let env = Env.add_typ_var kid BK_nat env in let ex_typ = typ_subst_nexp kid' (Nexp_var kid) ex_typ in let env = Env.add_constraint (nc_subst_nexp kid' (Nexp_var kid) nc) env in - let typed_pat, env = bind_pat env pat ex_typ in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat ex_typ in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | Some _, _ -> typ_error l ("Cannot bind type variable pattern against multiple argument existential") | None, Typ_aux (Typ_id id, _) when Id.compare id (mk_id "int") == 0 -> let env = Env.add_typ_var kid BK_nat env in - let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | None, Typ_aux (Typ_id id, _) when Id.compare id (mk_id "nat") == 0 -> let env = Env.add_typ_var kid BK_nat env in let env = Env.add_constraint (nc_gt (nvar kid) (nint 0)) env in - let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | None, Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp lo, _); Typ_arg_aux (Typ_arg_nexp hi, _)]), _) when Id.compare id (mk_id "range") == 0 -> let env = Env.add_typ_var kid BK_nat env in let env = Env.add_constraint (nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi)) env in - let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | None, _ -> typ_error l ("Cannot bind type variable against non existential or numeric type") end - | P_wild -> annot_pat P_wild typ, env + | P_wild -> annot_pat P_wild typ, env, [] | P_cons (hd_pat, tl_pat) -> begin match Env.expand_synonyms env typ with | Typ_aux (Typ_app (f, [Typ_arg_aux (Typ_arg_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> - let hd_pat, env = bind_pat env hd_pat ltyp in - let tl_pat, env = bind_pat env tl_pat typ in - annot_pat (P_cons (hd_pat, tl_pat)) typ, env + let hd_pat, env, hd_guards = bind_pat env hd_pat ltyp in + let tl_pat, env, tl_guards = bind_pat env tl_pat typ in + annot_pat (P_cons (hd_pat, tl_pat)) typ, env, hd_guards @ tl_guards | _ -> typ_error l "Cannot match cons pattern against non-list type" end | P_list pats -> @@ -2353,32 +2360,32 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) match Env.expand_synonyms env typ with | Typ_aux (Typ_app (f, [Typ_arg_aux (Typ_arg_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> let rec process_pats env = function - | [] -> [], env + | [] -> [], env, [] | (pat :: pats) -> - let pat', env = bind_pat env pat ltyp in - let pats', env = process_pats env pats in - pat' :: pats', env + let pat', env, guards = bind_pat env pat ltyp in + let pats', env, guards' = process_pats env pats in + pat' :: pats', env, guards @ guards' in - let pats, env = process_pats env pats in - annot_pat (P_list pats) typ, env + let pats, env, guards = process_pats env pats in + annot_pat (P_list pats) typ, env, guards | _ -> typ_error l ("Cannot match list pattern " ^ string_of_pat pat ^ " against non-list type " ^ string_of_typ typ) end | P_tup [] -> begin match Env.expand_synonyms env typ with | Typ_aux (Typ_id typ_id, _) when string_of_id typ_id = "unit" -> - annot_pat (P_tup []) typ, env + annot_pat (P_tup []) typ, env, [] | _ -> typ_error l "Cannot match unit pattern against non-unit type" end | P_tup pats -> begin match Env.expand_synonyms env typ with | Typ_aux (Typ_tup typs, _) -> - let tpats, env = - try List.fold_left2 bind_tuple_pat ([], env) pats typs with + let tpats, env, guards = + try List.fold_left2 bind_tuple_pat ([], env, []) pats typs with | Invalid_argument _ -> typ_error l "Tuple pattern and tuple type have different length" in - annot_pat (P_tup (List.rev tpats)) typ, env + annot_pat (P_tup (List.rev tpats)) typ, env, guards | _ -> typ_error l "Cannot bind tuple pattern against non tuple type" end | P_app (f, pats) when Env.is_union_constructor f env -> @@ -2402,11 +2409,11 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat) else (); let ret_typ' = subst_unifiers unifiers ret_typ in - let tpats, env = - try List.fold_left2 bind_tuple_pat ([], env) pats (untuple arg_typ') with + let tpats, env, guards = + try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with | Invalid_argument _ -> typ_error l "Union constructor pattern arguments have incorrect length" in - annot_pat (P_app (f, List.rev tpats)) typ, env + annot_pat (P_app (f, List.rev tpats)) typ, env, guards with | Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m) end @@ -2415,12 +2422,19 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) | P_app (f, _) when not (Env.is_union_constructor f env) -> typ_error l (string_of_id f ^ " is not a union constructor in pattern " ^ string_of_pat pat) | P_as (pat, id) -> - let (typed_pat, env) = bind_pat env pat typ in - annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env + let (typed_pat, env, guards) = bind_pat env pat typ in + annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env, guards | _ -> - let (inferred_pat, env) = infer_pat env pat in - subtyp l env (pat_typ_of inferred_pat) typ; - switch_typ inferred_pat typ, env + let (inferred_pat, env, guards) = infer_pat env pat in + match subtyp l env (pat_typ_of inferred_pat) typ with + | () -> switch_typ inferred_pat typ, env, guards + | exception (Type_error _ as typ_exn) -> + match pat_aux with + | P_lit lit -> + let guard = mk_exp (E_app_infix (mk_exp (E_id (mk_id "p#")), mk_id "==", mk_exp (E_lit lit))) in + let (typed_pat, env, guards) = bind_pat env (mk_pat (P_id (mk_id "p#"))) typ in + typed_pat, env, guard::guards + | _ -> raise typ_exn and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = let annot_pat pat typ = P_aux (pat, (l, Some (env, typ, no_effect))) in @@ -2432,31 +2446,31 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = typ_error l ("Cannot infer identifier in pattern " ^ string_of_pat pat ^ " - try adding a type annotation") | Local (Mutable, _) | Register _ -> typ_error l ("Cannot shadow mutable local or register in switch statement pattern " ^ string_of_pat pat) - | Enum enum -> annot_pat (P_id v) enum, env + | Enum enum -> annot_pat (P_id v) enum, env, [] end | P_typ (typ_annot, pat) -> Env.wf_typ env typ_annot; - let (typed_pat, env) = bind_pat env pat typ_annot in - annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env + let (typed_pat, env, guards) = bind_pat env pat typ_annot in + annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env, guards | P_lit lit -> - annot_pat (P_lit lit) (infer_lit env lit), env + annot_pat (P_lit lit) (infer_lit env lit), env, [] | P_vector (pat :: pats) -> - let fold_pats (pats, env) pat = - let typed_pat, env = bind_pat env pat bit_typ in - pats @ [typed_pat], env + let fold_pats (pats, env, guards) pat = + let typed_pat, env, guards' = bind_pat env pat bit_typ in + pats @ [typed_pat], env, guards' @ guards in - let pats, env = - List.fold_left fold_pats ([], env) (pat :: pats) in + let pats, env, guards = List.fold_left fold_pats ([], env, []) (pat :: pats) in let len = nexp_simp (nint (List.length pats)) in let etyp = pat_typ_of (List.hd pats) in List.iter (fun pat -> typ_equality l env etyp (pat_typ_of pat)) pats; - annot_pat (P_vector pats) (dvector_typ env len etyp), env + annot_pat (P_vector pats) (dvector_typ env len etyp), env, guards | P_vector_concat (pat :: pats) -> - let fold_pats (pats, env) pat = - let inferred_pat, env = infer_pat env pat in - pats @ [inferred_pat], env + let fold_pats (pats, env, guards) pat = + let inferred_pat, env, guards' = infer_pat env pat in + pats @ [inferred_pat], env, guards' @ guards in - let inferred_pats, env = List.fold_left fold_pats ([], env) (pat :: pats) in + let inferred_pats, env, guards = + List.fold_left fold_pats ([], env, []) (pat :: pats) in let (len, _, vtyp) = destruct_vec_typ l env (pat_typ_of (List.hd inferred_pats)) in let fold_len len pat = let (len', _, vtyp') = destruct_vec_typ l env (pat_typ_of pat) in @@ -2464,10 +2478,12 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = nsum len len' in let len = nexp_simp (List.fold_left fold_len len (List.tl inferred_pats)) in - annot_pat (P_vector_concat inferred_pats) (dvector_typ env len vtyp), env + annot_pat (P_vector_concat inferred_pats) (dvector_typ env len vtyp), env, guards | P_as (pat, id) -> - let (typed_pat, env) = infer_pat env pat in - annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env + let (typed_pat, env, guards) = infer_pat env pat in + annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), + Env.add_local id (Immutable, pat_typ_of typed_pat) env, + guards | _ -> typ_error l ("Couldn't infer type of pattern " ^ string_of_pat pat) and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as exp) = @@ -2858,7 +2874,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = | _ -> let inferred_bind = irule infer_exp env bind in inferred_bind, typ_of inferred_bind in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in (* Propagate constraint assertions on the lhs of monadic binds to the rhs *) let env = match bind_exp with | E_aux (E_assert (E_aux (E_constraint nc, _), _), _) -> @@ -2876,7 +2892,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = | LB_val (pat, bind) -> let inferred_bind = irule infer_exp env bind in inferred_bind, pat, typ_of inferred_bind in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in let inferred_exp = irule infer_exp env exp in annot_exp (E_let (LB_aux (LB_val (tpat, bind_exp), (let_loc, None)), inferred_exp)) (typ_of inferred_exp) | E_ref id when Env.is_mutable id env -> @@ -3328,11 +3344,11 @@ let check_letdef env (LB_aux (letbind, (l, _))) = match letbind with | LB_val (P_aux (P_typ (typ_annot, pat), _), bind) -> let checked_bind = crule check_exp env (strip_exp bind) typ_annot in - let tpat, env = bind_pat env (strip_pat pat) typ_annot in + let tpat, env = bind_pat_no_guard env (strip_pat pat) typ_annot in [DEF_val (LB_aux (LB_val (P_aux (P_typ (typ_annot, tpat), (l, Some (env, typ_annot, no_effect))), checked_bind), (l, None)))], env | LB_val (pat, bind) -> let inferred_bind = irule infer_exp env (strip_exp bind) in - let tpat, env = bind_pat env (strip_pat pat) (typ_of inferred_bind) in + let tpat, env = bind_pat_no_guard env (strip_pat pat) (typ_of inferred_bind) in [DEF_val (LB_aux (LB_val (tpat, inferred_bind), (l, None)))], env end diff --git a/src/type_check.mli b/src/type_check.mli index e3daec75..3f43492f 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -199,7 +199,11 @@ val prove : Env.t -> n_constraint -> bool val subtype_check : Env.t -> typ -> typ -> bool -val bind_pat : Env.t -> unit pat -> typ -> tannot pat * Env.t +val bind_pat : Env.t -> unit pat -> typ -> tannot pat * Env.t * unit Ast.exp list +(* Variant that doesn't introduce new guards for literal patterns, but raises + a type error instead. This should always be safe to use on patterns that + have previously been type checked. *) +val bind_pat_no_guard : Env.t -> unit pat -> typ -> tannot pat * Env.t (* Partial functions: The expressions and patterns passed to these functions must be guaranteed to have tannots of the form Some (env, |
