diff options
| author | Thomas Bauereiss | 2017-07-25 13:06:46 +0100 |
|---|---|---|
| committer | Thomas Bauereiss | 2017-07-25 14:06:30 +0100 |
| commit | 0ea787cbb87e5508040d53b06bd812abc5acbb96 (patch) | |
| tree | 5a1898ed30832d107078fb0f1871d360d366f802 /src | |
| parent | 5c306614427179282c8747a6fa6c34637c64ca68 (diff) | |
Add partial support for rewriting of sizeof expressions
Tries to extract values of nexps from the (type annotations of) parameters
passed to the function. This seems to correspond to the behaviour of the
previous typechecker.
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 23 | ||||
| -rw-r--r-- | src/ast_util.mli | 3 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 25 | ||||
| -rw-r--r-- | src/rewriter.ml | 199 | ||||
| -rw-r--r-- | src/type_check.ml | 30 | ||||
| -rw-r--r-- | src/type_check.mli | 4 |
6 files changed, 234 insertions, 50 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 9b3dee84..04ce5b07 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -355,6 +355,29 @@ module IdSet = Set.Make(Id) module KBindings = Map.Make(Kid) module KidSet = Set.Make(Kid) +let rec nexp_frees (Nexp_aux (nexp, l)) = + match nexp with + | Nexp_id _ -> raise (Reporting_basic.err_typ l "Unimplemented Nexp_id in nexp_frees") + | Nexp_var kid -> KidSet.singleton kid + | Nexp_constant _ -> KidSet.empty + | Nexp_times (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) + | Nexp_sum (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) + | Nexp_minus (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) + | Nexp_exp n -> nexp_frees n + | Nexp_neg n -> nexp_frees n + +let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = + match nexp1, nexp2 with + | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 = 0 + | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 = 0 + | Nexp_constant c1, Nexp_constant c2 -> c1 = c2 + | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b + | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b + | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b + | Nexp_exp n1, Nexp_exp n2 -> nexp_identical n1 n2 + | Nexp_neg n1, Nexp_neg n2 -> nexp_identical n1 n2 + | _, _ -> false + let rec is_number (Typ_aux (t,_)) = match t with | Typ_app (Id_aux (Id "range", _),_) diff --git a/src/ast_util.mli b/src/ast_util.mli index 0adda3ef..d7f68412 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -115,6 +115,9 @@ module Bindings : sig include Map.S with type key = id end +val nexp_frees : nexp -> KidSet.t +val nexp_identical : nexp -> nexp -> bool + val is_number : typ -> bool val is_vector_typ : typ -> bool val is_bit_typ : typ -> bool diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 6b7b8aca..9f479cdc 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -392,7 +392,7 @@ let doc_exp_lem, doc_let_lem = | _ -> (prefix 2 1) (string "write_reg") (doc_lexp_deref_lem regtypes le ^/^ expY e)) | E_vector_append(le,re) -> - let t = typ_of_annot (l,annot) in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let (call,ta,aexp_needed) = if is_bitvector_typ t then if not (contains_t_pp_var t) @@ -450,7 +450,7 @@ let doc_exp_lem, doc_let_lem = let call = if is_bitvector_typ t1 then "bvslice_raw" else "slice_raw" in let epp = separate space [string call;expY e1;expY e2;expY e3] in let (taepp,aexp_needed) = - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let eff = effect_of full_exp in if contains_bitvector_typ t && not (contains_t_pp_var t) then (align epp ^^ (doc_tannot_lem regtypes (effectful eff) t), true) @@ -504,7 +504,7 @@ let doc_exp_lem, doc_let_lem = | args -> parens (align (separate_map (comma ^^ break 0) (argpp false) args)) in let epp = align (call ^//^ argspp) in let (taepp,aexp_needed) = - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let eff = effect_of full_exp in if contains_bitvector_typ t && not (contains_t_pp_var t) then (align epp ^^ (doc_tannot_lem regtypes (effectful eff) t), true) @@ -523,7 +523,7 @@ let doc_exp_lem, doc_let_lem = separate space [string call;expY v;expY e] in if aexp_needed then parens (align epp) else epp | E_vector_subrange (v,e1,e2) -> - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let eff = effect_of full_exp in let (epp,aexp_needed) = if has_effect eff BE_rreg then @@ -543,7 +543,7 @@ let doc_exp_lem, doc_let_lem = let ft = typ_of_annot (l,fannot) in (match fannot with | Some(env, (Typ_aux (Typ_id tid, _)), _) when Env.is_regtyp tid env -> - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let field_f = string (if is_bit_typ t then "read_reg_bitfield" @@ -566,7 +566,7 @@ let doc_exp_lem, doc_let_lem = | E_block exps -> raise (report l "Blocks should have been removed till now.") | E_nondet exps -> raise (report l "Nondet blocks not supported.") | E_id id -> - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in (match annot with | Some (env, Typ_aux (Typ_id tid, _), eff) when Env.is_regtyp tid env -> if has_effect eff BE_rreg then @@ -613,6 +613,7 @@ let doc_exp_lem, doc_let_lem = | _ -> doc_id_lem id) | E_lit lit -> doc_lit_lem false lit annot | E_cast(typ,e) -> + let typ = Env.base_typ_of (env_of full_exp) typ in if is_vector_typ typ then let (start,_,_,_) = vector_typ_args_of typ in let call = @@ -670,10 +671,11 @@ let doc_exp_lem, doc_let_lem = | _ -> raise (report l "cannot get record type") in anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp regtypes recordtyp) fexps)) | E_vector exps -> - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let (start, len, order, etyp) = if is_vector_typ t then vector_typ_args_of t - else raise (Reporting_basic.err_unreachable l "E_vector of non-vector type") in + else raise (Reporting_basic.err_unreachable l + "E_vector of non-vector type") in (*match annot with | Base((_,t),_,_,_,_,_) -> match t.t with @@ -707,7 +709,7 @@ let doc_exp_lem, doc_let_lem = if aexp_needed then parens (align epp) else epp (* *) | E_vector_indexed (iexps, (Def_val_aux (default,(dl,dannot)))) -> - let t = typ_of full_exp in + let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let (start, len, order, etyp) = if is_vector_typ t then vector_typ_args_of t else raise (Reporting_basic.err_unreachable l "E_vector_indexed of non-vector type") in @@ -1285,8 +1287,9 @@ let doc_dec_lem (DEC_aux (reg,(l,annot))) = (match typ with | Typ_aux (Typ_app (r, [Typ_arg_aux (Typ_arg_typ rt, _)]), _) when string_of_id r = "register" && is_vector_typ rt -> - let (start, size, order, etyp) = vector_typ_args_of rt in - (match is_bit_typ etyp, start, size with + let env = env_of_annot (l,annot) in + let (start, size, order, etyp) = vector_typ_args_of (Env.base_typ_of env rt) in + (match is_bit_typ (Env.base_typ_of env etyp), start, size with | true, Nexp_aux (Nexp_constant start, _), Nexp_aux (Nexp_constant size, _) -> let o = if is_order_inc order then "true" else "false" in (doc_op equals) diff --git a/src/rewriter.ml b/src/rewriter.ml index ecde3e8a..560159d2 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -60,12 +60,6 @@ type 'a rewriters = { let (>>) f g = fun x -> g(f(x)) -let env_of_annot = function - | (_,Some(env,_,_)) -> env - | (l,None) -> Env.empty - -let env_of (E_aux (_,a)) = env_of_annot a - let effect_of_fpat (FP_aux (_,(_,a))) = effect_of_annot a let effect_of_lexp (LEXP_aux (_,(_,a))) = effect_of_annot a let effect_of_fexp (FE_aux (_,(_,a))) = effect_of_annot a @@ -573,15 +567,17 @@ let rewrite_defs_base rewriters (Defs defs) = | [] -> [] | d::ds -> (rewriters.rewrite_def rewriters d)::(rewrite ds) in Defs (rewrite defs) + +let rewriters_base = + {rewrite_exp = rewrite_exp; + rewrite_pat = rewrite_pat; + rewrite_let = rewrite_let; + rewrite_lexp = rewrite_lexp; + rewrite_fun = rewrite_fun; + rewrite_def = rewrite_def; + rewrite_defs = rewrite_defs_base} -let rewrite_defs (Defs defs) = rewrite_defs_base - {rewrite_exp = rewrite_exp; - rewrite_pat = rewrite_pat; - rewrite_let = rewrite_let; - rewrite_lexp = rewrite_lexp; - rewrite_fun = rewrite_fun; - rewrite_def = rewrite_def; - rewrite_defs = rewrite_defs_base} (Defs defs) +let rewrite_defs (Defs defs) = rewrite_defs_base rewriters_base (Defs defs) module Envmap = Finite_map.Fmap_map(String) @@ -860,7 +856,177 @@ let id_exp_alg = ; lB_aux = (fun (lb,annot) -> LB_aux (lb,annot)) ; pat_alg = id_pat_alg } - + +(* Folding algorithms for not only rewriting patterns/expressions, but also + computing some additional value. Usage: Pass default value (bot) and a + binary join operator as arguments, and specify the non-default cases of + rewriting/computation by overwriting fields of the record. + See rewrite_sizeof for examples. *) +let compute_pat_alg bot join = + let join_list vs = List.fold_left join bot vs in + let split_join f ps = let (vs,ps) = List.split ps in (join_list vs, f ps) in + { p_lit = (fun lit -> (bot, P_lit lit)) + ; p_wild = (bot, P_wild) + ; p_as = (fun ((v,pat),id) -> (v, P_as (pat,id))) + ; p_typ = (fun (typ,(v,pat)) -> (v, P_typ (typ,pat))) + ; p_id = (fun id -> (bot, P_id id)) + ; p_app = (fun (id,ps) -> split_join (fun ps -> P_app (id,ps)) ps) + ; p_record = (fun (ps,b) -> split_join (fun ps -> P_record (ps,b)) ps) + ; p_vector = split_join (fun ps -> P_vector ps) + ; p_vector_indexed = (fun ps -> + let (is,ps) = List.split ps in + let (vs,ps) = List.split ps in + (join_list vs, P_vector_indexed (List.combine is ps))) + ; p_vector_concat = split_join (fun ps -> P_vector_concat ps) + ; p_tup = split_join (fun ps -> P_tup ps) + ; p_list = split_join (fun ps -> P_list ps) + ; p_aux = (fun ((v,pat),annot) -> (v, P_aux (pat,annot))) + ; fP_aux = (fun ((v,fpat),annot) -> (v, FP_aux (fpat,annot))) + ; fP_Fpat = (fun (id,(v,pat)) -> (v, FP_Fpat (id,pat))) + } + +let compute_exp_alg bot join = + let join_list vs = List.fold_left join bot vs in + let split_join f es = let (vs,es) = List.split es in (join_list vs, f es) in + { e_block = split_join (fun es -> E_block es) + ; e_nondet = split_join (fun es -> E_nondet es) + ; e_id = (fun id -> (bot, E_id id)) + ; e_lit = (fun lit -> (bot, E_lit lit)) + ; e_cast = (fun (typ,(v,e)) -> (v, E_cast (typ,e))) + ; e_app = (fun (id,es) -> split_join (fun es -> E_app (id,es)) es) + ; e_app_infix = (fun ((v1,e1),id,(v2,e2)) -> (join v1 v2, E_app_infix (e1,id,e2))) + ; e_tuple = split_join (fun es -> E_tuple es) + ; e_if = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_if (e1,e2,e3))) + ; e_for = (fun (id,(v1,e1),(v2,e2),(v3,e3),order,(v4,e4)) -> + (join_list [v1;v2;v3;v4], E_for (id,e1,e2,e3,order,e4))) + ; e_vector = split_join (fun es -> E_vector es) + ; e_vector_indexed = (fun (es,(v2,opt2)) -> + let (is,es) = List.split es in + let (vs,es) = List.split es in + (join_list (vs @ [v2]), E_vector_indexed (List.combine is es,opt2))) + ; e_vector_access = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_access (e1,e2))) + ; e_vector_subrange = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_subrange (e1,e2,e3))) + ; e_vector_update = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_update (e1,e2,e3))) + ; e_vector_update_subrange = (fun ((v1,e1),(v2,e2),(v3,e3),(v4,e4)) -> (join_list [v1;v2;v3;v4], E_vector_update_subrange (e1,e2,e3,e4))) + ; e_vector_append = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_append (e1,e2))) + ; e_list = split_join (fun es -> E_list es) + ; e_cons = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_cons (e1,e2))) + ; e_record = (fun (vs,fexps) -> (vs, E_record fexps)) + ; e_record_update = (fun ((v1,e1),(vf,fexp)) -> (join v1 vf, E_record_update (e1,fexp))) + ; e_field = (fun ((v1,e1),id) -> (v1, E_field (e1,id))) + ; e_case = (fun ((v1,e1),pexps) -> + let (vps,pexps) = List.split pexps in + (join_list (v1::vps), E_case (e1,pexps))) + ; e_let = (fun ((vl,lb),(v2,e2)) -> (join vl v2, E_let (lb,e2))) + ; e_assign = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, E_assign (lexp,e2))) + ; e_sizeof = (fun nexp -> (bot, E_sizeof nexp)) + ; e_exit = (fun (v1,e1) -> (v1, E_exit (e1))) + ; e_return = (fun (v1,e1) -> (v1, E_return e1)) + ; e_assert = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_assert(e1,e2)) ) + ; e_internal_cast = (fun (a,(v1,e1)) -> (v1, E_internal_cast (a,e1))) + ; e_internal_exp = (fun a -> (bot, E_internal_exp a)) + ; e_internal_exp_user = (fun (a1,a2) -> (bot, E_internal_exp_user (a1,a2))) + ; e_internal_let = (fun ((vl, lexp), (v2,e2), (v3,e3)) -> + (join_list [vl;v2;v3], E_internal_let (lexp,e2,e3))) + ; e_internal_plet = (fun ((vp,pat), (v1,e1), (v2,e2)) -> + (join_list [vp;v1;v2], E_internal_plet (pat,e1,e2))) + ; e_internal_return = (fun (v,e) -> (v, E_internal_return e)) + ; e_aux = (fun ((v,e),annot) -> (v, E_aux (e,annot))) + ; lEXP_id = (fun id -> (bot, LEXP_id id)) + ; lEXP_memory = (fun (id,es) -> split_join (fun es -> LEXP_memory (id,es)) es) + ; lEXP_cast = (fun (typ,id) -> (bot, LEXP_cast (typ,id))) + ; lEXP_tup = split_join (fun tups -> LEXP_tup tups) + ; lEXP_vector = (fun ((vl,lexp),(v2,e2)) -> (join vl v2, LEXP_vector (lexp,e2))) + ; lEXP_vector_range = (fun ((vl,lexp),(v2,e2),(v3,e3)) -> + (join_list [vl;v2;v3], LEXP_vector_range (lexp,e2,e3))) + ; lEXP_field = (fun ((vl,lexp),id) -> (vl, LEXP_field (lexp,id))) + ; lEXP_aux = (fun ((vl,lexp),annot) -> (vl, LEXP_aux (lexp,annot))) + ; fE_Fexp = (fun (id,(v,e)) -> (v, FE_Fexp (id,e))) + ; fE_aux = (fun ((vf,fexp),annot) -> (vf, FE_aux (fexp,annot))) + ; fES_Fexps = (fun (fexps,b) -> + let (vs,fexps) = List.split fexps in + (join_list vs, FES_Fexps (fexps,b))) + ; fES_aux = (fun ((vf,fexp),annot) -> (vf, FES_aux (fexp,annot))) + ; def_val_empty = (bot, Def_val_empty) + ; def_val_dec = (fun (v,e) -> (v, Def_val_dec e)) + ; def_val_aux = (fun ((v,defval),aux) -> (v, Def_val_aux (defval,aux))) + ; pat_exp = (fun ((vp,pat),(v,e)) -> (join vp v, Pat_exp (pat,e))) + ; pat_aux = (fun ((v,pexp),a) -> (v, Pat_aux (pexp,a))) + ; lB_val_explicit = (fun (typ,(vp,pat),(v,e)) -> (join vp v, LB_val_explicit (typ,pat,e))) + ; lB_val_implicit = (fun ((vp,pat),(v,e)) -> (join vp v, LB_val_implicit (pat,e))) + ; lB_aux = (fun ((vl,lb),annot) -> (vl,LB_aux (lb,annot))) + ; pat_alg = compute_pat_alg bot join + } + +let rewrite_sizeof defs = + let sizeof_frees exp = + fst (fold_exp + { (compute_exp_alg KidSet.empty KidSet.union) with + e_sizeof = (fun nexp -> (nexp_frees nexp, E_sizeof nexp)) } + exp) in + + let nexps_from_params pat = + fst (fold_pat + { (compute_pat_alg [] (@)) with + p_aux = (fun ((v,pat),((l,_) as annot)) -> + let v' = match pat with + | P_id id | P_as (_, id) -> + let (Typ_aux (typ,_) as typ_aux) = typ_of_annot annot in + (match typ with + | Typ_app (atom, [Typ_arg_aux (Typ_arg_nexp nexp, _)]) + when string_of_id atom = "atom" -> + [nexp, E_id id] + | Typ_app (vector, _) when string_of_id vector = "vector" -> + let (_,len,_,_) = vector_typ_args_of typ_aux in + let exp = E_app + (Id_aux (Id "length", Parse_ast.Generated l), + [E_aux (E_id id, annot)]) in + [len, exp] + | _ -> []) + | _ -> [] in + (v @ v', P_aux (pat,annot)))} pat) in + + let rec e_sizeof nmap (Nexp_aux (nexp, l) as nexp_aux) = + try snd (List.find (fun (nexp,_) -> nexp_identical nexp nexp_aux) nmap) + with + | Not_found -> + let binop nexp1 op nexp2 = E_app_infix ( + E_aux (e_sizeof nmap nexp1, simple_annot l (atom_typ nexp1)), + Id_aux (Id op, Unknown), + E_aux (e_sizeof nmap nexp2, simple_annot l (atom_typ nexp2)) + ) in + (match nexp with + | Nexp_constant i -> E_lit (L_aux (L_num i, l)) + | Nexp_times (nexp1, nexp2) -> binop nexp1 "*" nexp2 + | Nexp_sum (nexp1, nexp2) -> binop nexp1 "+" nexp2 + | Nexp_minus (nexp1, nexp2) -> binop nexp1 "-" nexp2 + | _ -> E_sizeof nexp_aux) in + + let rewrite_sizeof_exp nmap rewriters exp = + let exp = rewriters_base.rewrite_exp rewriters exp in + fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp in + + let rewrite_sizeof_fun rewriters + (FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) = + let rewrite_funcl_body (FCL_aux (FCL_Funcl (id,pat,exp), annot)) = + let body_env = env_of exp in + let body_typ = typ_of exp in + let nmap = nexps_from_params pat in + let exp = + try check_exp body_env + (strip_exp (fold_exp { id_exp_alg with e_sizeof = e_sizeof nmap } exp)) + body_typ + with + | Type_error _ -> exp in + FCL_aux (FCL_Funcl (id,pat,exp), annot) in + let funcls = List.map rewrite_funcl_body funcls in + FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot) in + + rewrite_defs_base + { rewriters_base with + rewrite_exp = rewrite_sizeof_exp []; + rewrite_fun = rewrite_sizeof_fun } + defs let remove_vector_concat_pat pat = @@ -2398,7 +2564,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = simple_annot l (typ_of_annot annot)) in let pat = P_aux (P_id id, simple_annot pl (typ_of vexp)) in Added_vars (vexp,pat) - | _ -> raise (Reporting_basic.err_unreachable el "Unsupported l-exp")) + | _ -> Same_vars (E_aux (E_assign (lexp,vexp),annot))) | _ -> (* after rewrite_defs_letbind_effects this expression is pure and updates no variables: check n_exp_term and where it's used. *) @@ -2580,6 +2746,7 @@ let rewrite_defs_remove_e_assign = let rewrite_defs_lem = top_sort_defs >> + rewrite_sizeof >> rewrite_defs_remove_vector_concat >> rewrite_defs_remove_bitvector_pats >> rewrite_defs_exp_lift_assign >> diff --git a/src/type_check.ml b/src/type_check.ml index a7d44544..021ace42 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -1079,17 +1079,6 @@ let typ_equality l env typ1 typ2 = (* 4. Unification *) (**************************************************************************) -let rec nexp_frees (Nexp_aux (nexp, l)) = - match nexp with - | Nexp_id _ -> typ_error l "Unimplemented Nexp_id in nexp_frees" - | Nexp_var kid -> KidSet.singleton kid - | Nexp_constant _ -> KidSet.empty - | Nexp_times (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) - | Nexp_sum (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) - | Nexp_minus (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) - | Nexp_exp n -> nexp_frees n - | Nexp_neg n -> nexp_frees n - let order_frees (Ord_aux (ord_aux, l)) = match ord_aux with | Ord_var kid -> KidSet.singleton kid @@ -1109,18 +1098,6 @@ and typ_arg_frees (Typ_arg_aux (typ_arg_aux, l)) = | Typ_arg_order ord -> order_frees ord | Typ_arg_effect _ -> assert false -let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = - match nexp1, nexp2 with - | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 = 0 - | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 = 0 - | Nexp_constant c1, Nexp_constant c2 -> c1 = c2 - | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b - | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b - | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b - | Nexp_exp n1, Nexp_exp n2 -> nexp_identical n1 n2 - | Nexp_neg n1, Nexp_neg n2 -> nexp_identical n1 n2 - | _, _ -> false - let ord_identical (Ord_aux (ord1, _)) (Ord_aux (ord2, _)) = match ord1, ord2 with | Ord_var kid1, Ord_var kid2 -> Kid.compare kid1 kid2 = 0 @@ -1421,6 +1398,13 @@ let destructure_vec_typ l env typ = in destructure_vec_typ' l (Env.expand_synonyms env typ) + +let env_of_annot (l, tannot) = match tannot with + | Some (env, _, _) -> env + | None -> raise (Reporting_basic.err_unreachable l "no type annotation") + +let env_of (E_aux (_, (l, tannot))) = env_of_annot (l, tannot) + let typ_of_annot (l, tannot) = match tannot with | Some (_, typ, _) -> typ | None -> raise (Reporting_basic.err_unreachable l "no type annotation") diff --git a/src/type_check.mli b/src/type_check.mli index 723f796a..b956e3aa 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -185,6 +185,10 @@ val check_exp : Env.t -> unit exp -> typ -> tannot exp (* Partial functions: The expressions and patterns passed to these functions must be guaranteed to have tannots of the form Some (env, typ) for these to work. *) + +val env_of : tannot exp -> Env.t +val env_of_annot : Ast.l * tannot -> Env.t + val typ_of : tannot exp -> typ val typ_of_annot : Ast.l * tannot -> typ |
