diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 23 | ||||
| -rw-r--r-- | src/ast_util.mli | 3 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 25 | ||||
| -rw-r--r-- | src/rewriter.ml | 199 | ||||
| -rw-r--r-- | src/type_check.ml | 30 | ||||
| -rw-r--r-- | src/type_check.mli | 4 |
6 files changed, 234 insertions, 50 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 9b3dee84..04ce5b07 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -355,6 +355,29 @@ module IdSet = Set.Make(Id) module KBindings = Map.Make(Kid) module KidSet = Set.Make(Kid) +let rec nexp_frees (Nexp_aux (nexp, l)) = + match nexp with + | Nexp_id _ -> raise (Reporting_basic.err_typ l "Unimplemented Nexp_id in nexp_frees") + | Nexp_var kid -> KidSet.singleton kid + | Nexp_constant _ -> KidSet.empty + | Nexp_times (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) + | Nexp_sum (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) + | Nexp_minus (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) + | Nexp_exp n -> nexp_frees n + | Nexp_neg n -> nexp_frees n + +let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = + match nexp1, nexp2 with + | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 = 0 + | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 = 0 + | Nexp_constant c1, Nexp_constant c2 -> c1 = c2 + | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b + | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b + | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b + | Nexp_exp n1, Nexp_exp n2 -> nexp_identical n1 n2 + | Nexp_neg n1, Nexp_neg n2 -> nexp_identical n1 n2 + | _, _ -> false + let rec is_number (Typ_aux (t,_)) = match t with | Typ_app (Id_aux (Id "range", _),_) diff --git a/src/ast_util.mli b/src/ast_util.mli index 0adda3ef..d7f68412 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -115,6 +115,9 @@ module Bindings : sig include Map.S with type key = id end +val nexp_frees : nexp -> KidSet.t +val nexp_identical : nexp -> nexp -> bool + val is_number : typ -> bool val is_vector_typ : typ -> bool val is_bit_typ : typ -> bool diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 6b7b8aca..9f479cdc 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -392,7 +392,7 @@ let doc_exp_lem, doc_let_lem = | _ -> (prefix 2 1) (string "write_reg") (doc_lexp_deref_lem regtypes le ^/^ expY e)) | E_vector_append(le,re) -> - let t = typ_of_annot (l,annot) in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let (call,ta,aexp_needed) = if is_bitvector_typ t then if not (contains_t_pp_var t) @@ -450,7 +450,7 @@ let doc_exp_lem, doc_let_lem = let call = if is_bitvector_typ t1 then "bvslice_raw" else "slice_raw" in let epp = separate space [string call;expY e1;expY e2;expY e3] in let (taepp,aexp_needed) = - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let eff = effect_of full_exp in if contains_bitvector_typ t && not (contains_t_pp_var t) then (align epp ^^ (doc_tannot_lem regtypes (effectful eff) t), true) @@ -504,7 +504,7 @@ let doc_exp_lem, doc_let_lem = | args -> parens (align (separate_map (comma ^^ break 0) (argpp false) args)) in let epp = align (call ^//^ argspp) in let (taepp,aexp_needed) = - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let eff = effect_of full_exp in if contains_bitvector_typ t && not (contains_t_pp_var t) then (align epp ^^ (doc_tannot_lem regtypes (effectful eff) t), true) @@ -523,7 +523,7 @@ let doc_exp_lem, doc_let_lem = separate space [string call;expY v;expY e] in if aexp_needed then parens (align epp) else epp | E_vector_subrange (v,e1,e2) -> - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let eff = effect_of full_exp in let (epp,aexp_needed) = if has_effect eff BE_rreg then @@ -543,7 +543,7 @@ let doc_exp_lem, doc_let_lem = let ft = typ_of_annot (l,fannot) in (match fannot with | Some(env, (Typ_aux (Typ_id tid, _)), _) when Env.is_regtyp tid env -> - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let field_f = string (if is_bit_typ t then "read_reg_bitfield" @@ -566,7 +566,7 @@ let doc_exp_lem, doc_let_lem = | E_block exps -> raise (report l "Blocks should have been removed till now.") | E_nondet exps -> raise (report l "Nondet blocks not supported.") | E_id id -> - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in (match annot with | Some (env, Typ_aux (Typ_id tid, _), eff) when Env.is_regtyp tid env -> if has_effect eff BE_rreg then @@ -613,6 +613,7 @@ let doc_exp_lem, doc_let_lem = | _ -> doc_id_lem id) | E_lit lit -> doc_lit_lem false lit annot | E_cast(typ,e) -> + let typ = Env.base_typ_of (env_of full_exp) typ in if is_vector_typ typ then let (start,_,_,_) = vector_typ_args_of typ in let call = @@ -670,10 +671,11 @@ let doc_exp_lem, doc_let_lem = | _ -> raise (report l "cannot get record type") in anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp regtypes recordtyp) fexps)) | E_vector exps -> - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let (start, len, order, etyp) = if is_vector_typ t then vector_typ_args_of t - else raise (Reporting_basic.err_unreachable l "E_vector of non-vector type") in + else raise (Reporting_basic.err_unreachable l + "E_vector of non-vector type") in (*match annot with | Base((_,t),_,_,_,_,_) -> match t.t with @@ -707,7 +709,7 @@ let doc_exp_lem, doc_let_lem = if aexp_needed then parens (align epp) else epp (* *) | E_vector_indexed (iexps, (Def_val_aux (default,(dl,dannot)))) -> - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let (start, len, order, etyp) = if is_vector_typ t then vector_typ_args_of t else raise (Reporting_basic.err_unreachable l "E_vector_indexed of non-vector type") in @@ -1285,8 +1287,9 @@ let doc_dec_lem (DEC_aux (reg,(l,annot))) = (match typ with | Typ_aux (Typ_app (r, [Typ_arg_aux (Typ_arg_typ rt, _)]), _) when string_of_id r = "register" && is_vector_typ rt -> - let (start, size, order, etyp) = vector_typ_args_of rt in - (match is_bit_typ etyp, start, size with + let env = env_of_annot (l,annot) in + let (start, size, order, etyp) = vector_typ_args_of (Env.base_typ_of env rt) in + (match is_bit_typ (Env.base_typ_of env etyp), start, size with | true, Nexp_aux (Nexp_constant start, _), Nexp_aux (Nexp_constant size, _) -> let o = if is_order_inc order then "true" else "false" in (doc_op equals) diff --git a/src/rewriter.ml b/src/rewriter.ml index ecde3e8a..560159d2 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -60,12 +60,6 @@ type 'a rewriters = { let (>>) f g = fun x -> g(f(x)) -let env_of_annot = function - | (_,Some(env,_,_)) -> env - | (l,None) -> Env.empty - -let env_of (E_aux (_,a)) = env_of_annot a - let effect_of_fpat (FP_aux (_,(_,a))) = effect_of_annot a let effect_of_lexp (LEXP_aux (_,(_,a))) = effect_of_annot a let effect_of_fexp (FE_aux (_,(_,a))) = effect_of_annot a @@ -573,15 +567,17 @@ let rewrite_defs_base rewriters (Defs defs) = | [] -> [] | d::ds -> (rewriters.rewrite_def rewriters d)::(rewrite ds) in Defs (rewrite defs) + +let rewriters_base = + {rewrite_exp = rewrite_exp; + rewrite_pat = rewrite_pat; + rewrite_let = rewrite_let; + rewrite_lexp = rewrite_lexp; + rewrite_fun = rewrite_fun; + rewrite_def = rewrite_def; + rewrite_defs = rewrite_defs_base} -let rewrite_defs (Defs defs) = rewrite_defs_base - {rewrite_exp = rewrite_exp; - rewrite_pat = rewrite_pat; - rewrite_let = rewrite_let; - rewrite_lexp = rewrite_lexp; - rewrite_fun = rewrite_fun; - rewrite_def = rewrite_def; - rewrite_defs = rewrite_defs_base} (Defs defs) +let rewrite_defs (Defs defs) = rewrite_defs_base rewriters_base (Defs defs) module Envmap = Finite_map.Fmap_map(String) @@ -860,7 +856,177 @@ let id_exp_alg = ; lB_aux = (fun (lb,annot) -> LB_aux (lb,annot)) ; pat_alg = id_pat_alg } - + +(* Folding algorithms for not only rewriting patterns/expressions, but also + computing some additional value. Usage: Pass default value (bot) and a + binary join operator as arguments, and specify the non-default cases of + rewriting/computation by overwriting fields of the record. + See rewrite_sizeof for examples. *) +let compute_pat_alg bot join = + let join_list vs = List.fold_left join bot vs in + let split_join f ps = let (vs,ps) = List.split ps in (join_list vs, f ps) in + { p_lit = (fun lit -> (bot, P_lit lit)) + ; p_wild = (bot, P_wild) + ; p_as = (fun ((v,pat),id) -> (v, P_as (pat,id))) + ; p_typ = (fun (typ,(v,pat)) -> (v, P_typ (typ,pat))) + ; p_id = (fun id -> (bot, P_id id)) + ; p_app = (fun (id,ps) -> split_join (fun ps -> P_app (id,ps)) ps) + ; p_record = (fun (ps,b) -> split_join (fun ps -> P_record (ps,b)) ps) + ; p_vector = split_join (fun ps -> P_vector ps) + ; p_vector_indexed = (fun ps -> + let (is,ps) = List.split ps in + let (vs,ps) = List.split ps in + (join_list vs, P_vector_indexed (List.combine is ps))) + ; p_vector_concat = split_join (fun ps -> P_vector_concat ps) + ; p_tup = split_join (fun ps -> P_tup ps) + ; p_list = split_join (fun ps -> P_list ps) + ; p_aux = (fun ((v,pat),annot) -> (v, P_aux (pat,annot))) + ; fP_aux = (fun ((v,fpat),annot) -> (v, FP_aux (fpat,annot))) + ; fP_Fpat = (fun (id,(v,pat)) -> (v, FP_Fpat (id,pat))) + } + +let compute_exp_alg bot join = + let join_list vs = List.fold_left join bot vs in + let split_join f es = let (vs,es) = List.split es in (join_list vs, f es) in + { e_block = split_join (fun es -> E_block es) + ; e_nondet = split_join (fun es -> E_nondet es) + ; e_id = (fun id -> (bot, E_id id)) + ; e_lit = (fun lit -> (bot, E_lit lit)) + ; e_cast = (fun (typ,(v,e)) -> (v, E_cast (typ,e))) + ; e_app = (fun (id,es) -> split_join (fun es -> E_app (id,es)) es) + ; e_app_infix = (fun ((v1,e1),id,(v2,e2)) -> (join v1 v2, E_app_infix (e1,id,e2))) + ; e_tuple = split_join (fun es -> E_tuple es) + ; e_if = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_if (e1,e2,e3))) + ; e_for = (fun (id,(v1,e1),(v2,e2),(v3,e3),order,(v4,e4)) -> + (join_list [v1;v2;v3;v4], E_for (id,e1,e2,e3,order,e4))) + ; e_vector = split_join (fun es -> E_vector es) + ; e_vector_indexed = (fun (es,(v2,opt2)) -> + let (is,es) = List.split es in + let (vs,es) = List.split es in + (join_list (vs @ [v2]), E_vector_indexed (List.combine is es,opt2))) + ; e_vector_access = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_access (e1,e2))) + ; e_vector_subrange = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_subrange (e1,e2,e3))) + ; e_vector_update = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_update (e1,e2,e3))) + ; e_vector_update_subrange = (fun ((v1,e1),(v2,e2),(v3,e3),(v4,e4)) -> (join_list [v1;v2;v3;v4], E_vector_update_subrange (e1,e2,e3,e4))) + ; e_vector_append = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_append (e1,e2))) + ; e_list = split_join (fun es -> E_list es) + ; e_cons = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_cons (e1,e2))) + ; e_record = (fun (vs,fexps) -> (vs, E_record fexps)) + ; e_record_update = (fun ((v1,e1),(vf,fexp)) -> (join v1 vf, E_record_update (e1,fexp))) + ; e_field = (fun ((v1,e1),id) -> (v1, E_field (e1,id))) + ; e_case = (fun ((v1,e1),pexps) -> + let (vps,pexps) = List.split pexps in + (join_list (v1::vps), E_case (e1,pexps))) + ; e_let = (fun ((vl,lb),(v2,e2)) -> (join vl v2, E_let (lb,e2))) + ; e_assign = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, E_assign (lexp,e2))) + ; e_sizeof = (fun nexp -> (bot, E_sizeof nexp)) + ; e_exit = (fun (v1,e1) -> (v1, E_exit (e1))) + ; e_return = (fun (v1,e1) -> (v1, E_return e1)) + ; e_assert = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_assert(e1,e2)) ) + ; e_internal_cast = (fun (a,(v1,e1)) -> (v1, E_internal_cast (a,e1))) + ; e_internal_exp = (fun a -> (bot, E_internal_exp a)) + ; e_internal_exp_user = (fun (a1,a2) -> (bot, E_internal_exp_user (a1,a2))) + ; e_internal_let = (fun ((vl, lexp), (v2,e2), (v3,e3)) -> + (join_list [vl;v2;v3], E_internal_let (lexp,e2,e3))) + ; e_internal_plet = (fun ((vp,pat), (v1,e1), (v2,e2)) -> + (join_list [vp;v1;v2], E_internal_plet (pat,e1,e2))) + ; e_internal_return = (fun (v,e) -> (v, E_internal_return e)) + ; e_aux = (fun ((v,e),annot) -> (v, E_aux (e,annot))) + ; lEXP_id = (fun id -> (bot, LEXP_id id)) + ; lEXP_memory = (fun (id,es) -> split_join (fun es -> LEXP_memory (id,es)) es) + ; lEXP_cast = (fun (typ,id) -> (bot, LEXP_cast (typ,id))) + ; lEXP_tup = split_join (fun tups -> LEXP_tup tups) + ; lEXP_vector = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, LEXP_vector (lexp,e2))) + ; lEXP_vector_range = (fun ((vl,lexp),(v2,e2),(v3,e3)) -> + (join_list [vl;v2;v3], LEXP_vector_range (lexp,e2,e3))) + ; lEXP_field = (fun ((vl,lexp),id) -> (vl, LEXP_field (lexp,id))) + ; lEXP_aux = (fun ((vl,lexp),annot) -> (vl, LEXP_aux (lexp,annot))) + ; fE_Fexp = (fun (id,(v,e)) -> (v, FE_Fexp (id,e))) + ; fE_aux = (fun ((vf,fexp),annot) -> (vf, FE_aux (fexp,annot))) + ; fES_Fexps = (fun (fexps,b) -> + let (vs,fexps) = List.split fexps in + (join_list vs, FES_Fexps (fexps,b))) + ; fES_aux = (fun ((vf,fexp),annot) -> (vf, FES_aux (fexp,annot))) + ; def_val_empty = (bot, Def_val_empty) + ; def_val_dec = (fun (v,e) -> (v, Def_val_dec e)) + ; def_val_aux = (fun ((v,defval),aux) -> (v, Def_val_aux (defval,aux))) + ; pat_exp = (fun ((vp,pat),(v,e)) -> (join vp v, Pat_exp (pat,e))) + ; pat_aux = (fun ((v,pexp),a) -> (v, Pat_aux (pexp,a))) + ; lB_val_explicit = (fun (typ,(vp,pat),(v,e)) -> (join vp v, LB_val_explicit (typ,pat,e))) + ; lB_val_implicit = (fun ((vp,pat),(v,e)) -> (join vp v, LB_val_implicit (pat,e))) + ; lB_aux = (fun ((vl,lb),annot) -> (vl,LB_aux (lb,annot))) + ; pat_alg = compute_pat_alg bot join + } + +let rewrite_sizeof defs = + let sizeof_frees exp = + fst (fold_exp + { (compute_exp_alg KidSet.empty KidSet.union) with + e_sizeof = (fun nexp -> (nexp_frees nexp, E_sizeof nexp)) } + exp) in + + let nexps_from_params pat = + fst (fold_pat + { (compute_pat_alg [] (@)) with + p_aux = (fun ((v,pat),((l,_) as annot)) -> + let v' = match pat with + | P_id id | P_as (_, id) -> + let (Typ_aux (typ,_) as typ_aux) = typ_of_annot annot in + (match typ with + | Typ_app (atom, [Typ_arg_aux (Typ_arg_nexp nexp, _)]) + when string_of_id atom = "atom" -> + [nexp, E_id id] + | Typ_app (vector, _) when string_of_id vector = "vector" -> + let (_,len,_,_) = vector_typ_args_of typ_aux in + let exp = E_app + (Id_aux (Id "length", Parse_ast.Generated l), + [E_aux (E_id id, annot)]) in + [len, exp] + | _ -> []) + | _ -> [] in + (v @ v', P_aux (pat,annot)))} pat) in + + let rec e_sizeof nmap (Nexp_aux (nexp, l) as nexp_aux) = + try snd (List.find (fun (nexp,_) -> nexp_identical nexp nexp_aux) nmap) + with + | Not_found -> + let binop nexp1 op nexp2 = E_app_infix ( + E_aux (e_sizeof nmap nexp1, simple_annot l (atom_typ nexp1)), + Id_aux (Id op, Unknown), + E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2)) + ) in + (match nexp with + | Nexp_constant i -> E_lit (L_aux (L_num i, l)) + | Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2 + | Nexp_sum (nexp1, nexp2) -> binop nexp1 "+" nexp2 + | Nexp_minus (nexp1, nexp2) -> binop nexp1 "-" nexp2 + | _ -> E_sizeof nexp_aux) in + + let rewrite_sizeof_exp nmap rewriters exp = + let exp = rewriters_base.rewrite_exp rewriters exp in + fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp in + + let rewrite_sizeof_fun rewriters + (FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) = + let rewrite_funcl_body (FCL_aux (FCL_Funcl (id,pat,exp), annot)) = + let body_env = env_of exp in + let body_typ = typ_of exp in + let nmap = nexps_from_params pat in + let exp = + try check_exp body_env + (strip_exp (fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp)) + body_typ + with + | Type_error _ -> exp in + FCL_aux (FCL_Funcl (id,pat,exp), annot) in + let funcls = List.map rewrite_funcl_body funcls in + FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot) in + + rewrite_defs_base + { rewriters_base with + rewrite_exp = rewrite_sizeof_exp []; + rewrite_fun = rewrite_sizeof_fun } + defs let remove_vector_concat_pat pat = @@ -2398,7 +2564,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = simple_annot l (typ_of_annot annot)) in let pat = P_aux (P_id id, simple_annot pl (typ_of vexp)) in Added_vars (vexp,pat) - | _ -> raise (Reporting_basic.err_unreachable el "Unsupported l-exp")) + | _ -> Same_vars (E_aux (E_assign (lexp,vexp),annot))) | _ -> (* after rewrite_defs_letbind_effects this expression is pure and updates no variables: check n_exp_term and where it's used. *) @@ -2580,6 +2746,7 @@ let rewrite_defs_remove_e_assign = let rewrite_defs_lem = top_sort_defs >> + rewrite_sizeof >> rewrite_defs_remove_vector_concat >> rewrite_defs_remove_bitvector_pats >> rewrite_defs_exp_lift_assign >> diff --git a/src/type_check.ml b/src/type_check.ml index a7d44544..021ace42 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -1079,17 +1079,6 @@ let typ_equality l env typ1 typ2 = (* 4. Unification *) (**************************************************************************) -let rec nexp_frees (Nexp_aux (nexp, l)) = - match nexp with - | Nexp_id _ -> typ_error l "Unimplemented Nexp_id in nexp_frees" - | Nexp_var kid -> KidSet.singleton kid - | Nexp_constant _ -> KidSet.empty - | Nexp_times (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) - | Nexp_sum (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) - | Nexp_minus (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) - | Nexp_exp n -> nexp_frees n - | Nexp_neg n -> nexp_frees n - let order_frees (Ord_aux (ord_aux, l)) = match ord_aux with | Ord_var kid -> KidSet.singleton kid @@ -1109,18 +1098,6 @@ and typ_arg_frees (Typ_arg_aux (typ_arg_aux, l)) = | Typ_arg_order ord -> order_frees ord | Typ_arg_effect _ -> assert false -let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = - match nexp1, nexp2 with - | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 = 0 - | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 = 0 - | Nexp_constant c1, Nexp_constant c2 -> c1 = c2 - | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b - | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b - | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b - | Nexp_exp n1, Nexp_exp n2 -> nexp_identical n1 n2 - | Nexp_neg n1, Nexp_neg n2 -> nexp_identical n1 n2 - | _, _ -> false - let ord_identical (Ord_aux (ord1, _)) (Ord_aux (ord2, _)) = match ord1, ord2 with | Ord_var kid1, Ord_var kid2 -> Kid.compare kid1 kid2 = 0 @@ -1421,6 +1398,13 @@ let destructure_vec_typ l env typ = in destructure_vec_typ' l (Env.expand_synonyms env typ) + +let env_of_annot (l, tannot) = match tannot with + | Some (env, _, _) -> env + | None -> raise (Reporting_basic.err_unreachable l "no type annotation") + +let env_of (E_aux (_, (l, tannot))) = env_of_annot (l, tannot) + let typ_of_annot (l, tannot) = match tannot with | Some (_, typ, _) -> typ | None -> raise (Reporting_basic.err_unreachable l "no type annotation") diff --git a/src/type_check.mli b/src/type_check.mli index 723f796a..b956e3aa 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -185,6 +185,10 @@ val check_exp : Env.t -> unit exp -> typ -> tannot exp (* Partial functions: The expressions and patterns passed to these functions must be guaranteed to have tannots of the form Some (env, typ) for these to work. *) + +val env_of : tannot exp -> Env.t +val env_of_annot : Ast.l * tannot -> Env.t + val typ_of : tannot exp -> typ val typ_of_annot : Ast.l * tannot -> typ |
