diff options
Diffstat (limited to 'src/rewriter.ml')
| -rw-r--r-- | src/rewriter.ml | 304 |
1 files changed, 159 insertions, 145 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml index 1411a3bf..f099427d 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -1029,6 +1029,14 @@ let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp = | Nexp_neg nexp -> mk_exp (E_app (mk_id "negate_range", [split_nexp nexp])) | _ -> mk_exp (E_sizeof nexp) in + let rec rewrite_nexp_ids env (Nexp_aux (nexp, l) as nexp_aux) = match nexp with + | Nexp_id id -> rewrite_nexp_ids env (Env.get_num_def id env) + | Nexp_times (nexp1, nexp2) -> Nexp_aux (Nexp_times (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) + | Nexp_sum (nexp1, nexp2) -> Nexp_aux (Nexp_sum (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) + | Nexp_minus (nexp1, nexp2) -> Nexp_aux (Nexp_minus (rewrite_nexp_ids env nexp1, rewrite_nexp_ids env nexp2), l) + | Nexp_exp nexp -> Nexp_aux (Nexp_exp (rewrite_nexp_ids env nexp), l) + | Nexp_neg nexp -> Nexp_aux (Nexp_neg (rewrite_nexp_ids env nexp), l) + | _ -> nexp_aux in let rec rewrite_e_aux split_sizeof (E_aux (e_aux, (l, _)) as orig_exp) = let env = env_of orig_exp in match e_aux with @@ -1036,17 +1044,21 @@ let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp = E_aux (E_lit (L_aux (L_num c, l)), (l, Some (env, atom_typ nexp, no_effect))) | E_sizeof nexp -> begin - let locals = Env.get_locals env in - let exps = Bindings.bindings locals - |> List.map (extract_typ_var l env nexp) - |> List.map (fun opt -> match opt with Some x -> [x] | None -> []) - |> List.concat - in - match exps with - | (exp :: _) -> exp - | [] when split_sizeof -> - fold_exp (rewrite_e_sizeof false) (check_exp env (split_nexp nexp) (typ_of orig_exp)) - | [] -> orig_exp + match simplify_nexp (rewrite_nexp_ids (env_of orig_exp) nexp) with + | Nexp_aux (Nexp_constant c, _) -> + E_aux (E_lit (L_aux (L_num c, l)), (l, Some (env, atom_typ nexp, no_effect))) + | _ -> + let locals = Env.get_locals env in + let exps = Bindings.bindings locals + |> List.map (extract_typ_var l env nexp) + |> List.map (fun opt -> match opt with Some x -> [x] | None -> []) + |> List.concat + in + match exps with + | (exp :: _) -> exp + | [] when split_sizeof -> + fold_exp (rewrite_e_sizeof false) (check_exp env (split_nexp nexp) (typ_of orig_exp)) + | [] -> orig_exp end | _ -> orig_exp and rewrite_e_sizeof split_sizeof = @@ -1244,17 +1256,20 @@ let rewrite_sizeof (Defs defs) = let kid_pats = List.map kid_pat (KidSet.elements nvars) in let kid_nmap = List.map (fun kid -> (nvar kid, kid_eaux kid)) (KidSet.elements nvars) in let rewrite_funcl_params (FCL_aux (FCL_Funcl (id, pat, exp), annot) as funcl) = - let rec rewrite_pat (P_aux (pat,(l,_)) as paux) = + let rec rewrite_pat (P_aux (pat, ((l, _) as pannot)) as paux) = + let penv = env_of_annot pannot in + let peff = effect_of_annot (snd pannot) in if KidSet.is_empty nvars then paux else match pat_typ_of paux with - | Typ_aux (Typ_tup _, _) -> + | Typ_aux (Typ_tup typs, _) -> + let ptyp' = Typ_aux (Typ_tup (kid_typs @ typs), l) in (match pat with | P_tup pats -> - P_aux (P_tup (kid_pats @ pats), (l, None)) - | P_wild -> paux + P_aux (P_tup (kid_pats @ pats), (l, Some (penv, ptyp', peff))) + | P_wild -> P_aux (pat, (l, Some (penv, ptyp', peff))) | P_typ (Typ_aux (Typ_tup typs, l), pat) -> P_aux (P_typ (Typ_aux (Typ_tup (kid_typs @ typs), l), - rewrite_pat pat), (l, None)) + rewrite_pat pat), (l, Some (penv, ptyp', peff))) | P_as (_, id) | P_id id -> (* adding parameters here would change the type of id; we should remove the P_as/P_id here and add a let-binding to the body *) @@ -1263,13 +1278,15 @@ let rewrite_sizeof (Defs defs) = | _ -> raise (Reporting_basic.err_unreachable l "unexpected pattern while rewriting function parameters for sizeof expressions")) - | _ -> P_aux (P_tup (kid_pats @ [paux]), (l, None)) in + | ptyp -> + let ptyp' = Typ_aux (Typ_tup (kid_typs @ [ptyp]), l) in + P_aux (P_tup (kid_pats @ [paux]), (l, Some (penv, ptyp', peff))) in let exp' = fold_exp { id_exp_alg with e_sizeof = e_sizeof kid_nmap } exp in FCL_aux (FCL_Funcl (id, rewrite_pat pat, exp'), annot) in let funcls = List.map rewrite_funcl_params funcls in (nvars, FD_aux (FD_function (rec_opt,tannot,eff,funcls),annot)) in - let rewrite_sizeof_fundef (params_map, defs) = function + let rewrite_sizeof_def (params_map, defs) = function | DEF_fundef fd as def -> let (nvars, fd') = rewrite_sizeof_fun params_map fd in let id = id_of_fundef fd in @@ -1277,6 +1294,17 @@ let rewrite_sizeof (Defs defs) = if KidSet.is_empty nvars then params_map else Bindings.add id nvars params_map in (params_map', defs @ [DEF_fundef fd']) + | DEF_val (LB_aux (lb, annot)) -> + begin + let lb' = match lb with + | LB_val_explicit (typschm, pat, exp) -> + let exp' = fst (fold_exp { copy_exp_alg with e_aux = e_app_aux params_map } exp) in + LB_val_explicit (typschm, pat, exp') + | LB_val_implicit (pat, exp) -> + let exp' = fst (fold_exp { copy_exp_alg with e_aux = e_app_aux params_map } exp) in + LB_val_implicit (pat, exp') in + (params_map, defs @ [DEF_val (LB_aux (lb', annot))]) + end | def -> (params_map, defs @ [def]) in @@ -1309,7 +1337,7 @@ let rewrite_sizeof (Defs defs) = DEF_spec (VS_aux (VS_cast_spec (rewrite_typschm typschm id, id), a)) | _ -> def in - let (params_map, defs) = List.fold_left rewrite_sizeof_fundef + let (params_map, defs) = List.fold_left rewrite_sizeof_def (Bindings.empty, []) defs in let defs = List.map (rewrite_sizeof_valspec params_map) defs in Defs defs @@ -1412,8 +1440,12 @@ let remove_vector_concat_pat pat = let root = E_aux (E_id rootid, rannot) in let index_i = simple_num l i in let index_j = simple_num l j in - - let subv = fix_eff_exp (E_aux (E_vector_subrange (root, index_i, index_j), cannot)) in + + (* FIXME *) + (* let subv = fix_eff_exp (E_aux (E_vector_subrange (root, index_i, index_j), cannot)) in *) + let (_, _, ord, _) = vector_typ_args_of (Env.base_typ_of (env_of root) (typ_of root)) in + let subrange_id = if is_order_inc ord then "bitvector_subrange_inc" else "bitvector_subrange_dec" in + let subv = fix_eff_exp (E_aux (E_app (mk_id subrange_id, [root; index_i; index_j]), cannot)) in let id_pat = match typ_opt with @@ -1877,7 +1909,13 @@ let remove_bitvector_pat pat = (* Helper functions for generating guard expressions *) let access_bit_exp (rootid,rannot) l idx = let root : tannot exp = E_aux (E_id rootid,rannot) in - E_aux (E_vector_access (root,simple_num l idx), simple_annot l bit_typ) in + (* FIXME *) + (* E_aux (E_vector_access (root,simple_num l idx), simple_annot l bit_typ) in *) + let env = env_of_annot rannot in + let t = Env.base_typ_of env (typ_of_annot rannot) in + let (_, _, ord, _) = vector_typ_args_of t in + let access_id = if is_order_inc ord then "bitvector_access_inc" else "bitvector_access_dec" in + E_aux (E_app (mk_id access_id, [root; simple_num l idx]), simple_annot l bit_typ) in let test_bit_exp rootid l t idx exp = let rannot = simple_annot l t in @@ -1902,10 +1940,13 @@ let remove_bitvector_pat pat = | _ -> (*if vec_start t = i && vec_length t = List.length lits then E_id rootid - else*) E_vector_subrange ( + else*) + (* E_vector_subrange ( E_aux (E_id rootid, simple_annot l typ), simple_num l i, - simple_num l j) in + simple_num l j) in *) + let subrange_id = if is_order_inc ord then "bitvector_subrange_inc" else "bitvector_subrange_dec" in + E_app (mk_id subrange_id, [E_aux (E_id rootid, simple_annot l typ); simple_num l i; simple_num l j]) in E_aux (E_app( Id_aux (Id "eq_vec", Parse_ast.Generated l), [E_aux (subvec_exp, simple_annot l typ'); @@ -1923,7 +1964,12 @@ let remove_bitvector_pat pat = (letexp, letbind) in let compose_guards guards = - List.fold_right (Util.option_binop bitwise_and_exp) guards None in + let conj g1 g2 = match g1, g2 with + | Some g1, Some g2 -> Some (bitwise_and_exp g1 g2) + | Some g1, None -> Some g1 + | None, Some g2 -> Some g2 + | None, None -> None in + List.fold_right conj guards None in let flatten_guards_decls gd = let (guards,decls,letbinds) = Util.split3 gd in @@ -2094,7 +2140,10 @@ let rewrite_defs_remove_bitvector_pats (Defs defs) = let defvals = List.map (fun lb -> DEF_val lb) letbinds in [DEF_val (LB_aux (LB_val_implicit (pat',exp),a))] @ defvals | d -> [d] in - fst (check initial_env (Defs (List.flatten (List.map rewrite_def defs)))) + (* FIXME See above in rewrite_sizeof *) + (* fst (check initial_env ( *) + Defs (List.flatten (List.map rewrite_def defs)) + (* )) *) (* Remove pattern guards by rewriting them to if-expressions within the @@ -2129,6 +2178,42 @@ let rewrite_exp_guarded_pats rewriters (E_aux (exp,(l,annot)) as full_exp) = let rewrite_defs_guarded_pats = rewrite_defs_base { rewriters_base with rewrite_exp = rewrite_exp_guarded_pats } + +let id_is_local_var id env = match Env.lookup_id id env with + | Local _ | Unbound -> true + | _ -> false + +let rec lexp_is_local (LEXP_aux (lexp, _)) env = match lexp with + | LEXP_memory _ -> false + | LEXP_id id + | LEXP_cast (_, id) -> id_is_local_var id env + | LEXP_tup lexps -> List.for_all (fun lexp -> lexp_is_local lexp env) lexps + | LEXP_vector (lexp,_) + | LEXP_vector_range (lexp,_,_) + | LEXP_field (lexp,_) -> lexp_is_local lexp env + +let lexp_is_effectful (LEXP_aux (_, (_, annot))) = match annot with + | Some (_, _, eff) -> effectful_effs eff + | _ -> false + +let rec rewrite_local_lexp ((LEXP_aux(lexp,((l,_) as annot))) as le) = match lexp with + | LEXP_id id | LEXP_cast (_, id) -> + (le, E_aux (E_id id, annot), (fun exp -> exp)) + | LEXP_vector (lexp, e) -> + let (lexp, access, rexp) = rewrite_local_lexp lexp in + (lexp, E_aux (E_vector_access (access, e), annot), + (fun exp -> rexp (E_aux (E_vector_update (access, e, exp), annot)))) + | LEXP_vector_range (lexp, e1, e2) -> + let (lexp, access, rexp) = rewrite_local_lexp lexp in + (lexp, E_aux (E_vector_subrange (access, e1, e2), annot), + (fun exp -> rexp (E_aux (E_vector_update_subrange (access, e1, e2, exp), annot)))) + | LEXP_field (lexp, id) -> + let (lexp, access, rexp) = rewrite_local_lexp lexp in + let field_update exp = FES_aux (FES_Fexps ([FE_aux (FE_Fexp (id, exp), annot)], false), annot) in + (lexp, E_aux (E_field (access, id), annot), + (fun exp -> rexp (E_aux (E_record_update (access, field_update exp), annot)))) + | _ -> raise (Reporting_basic.err_unreachable l "unsupported lexp") + (*Expects to be called after rewrite_defs; thus the following should not appear: internal_exp of any form lit vectors in patterns or expressions @@ -2143,17 +2228,14 @@ let rewrite_exp_lift_assign_intro rewriters ((E_aux (exp,((l,_) as annot))) as f | E_block exps -> let rec walker exps = match exps with | [] -> [] - | (E_aux(E_assign((LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),_)) as le,e), - ((l, Some (env,typ,eff)) as annot)) as exp)::exps -> - (match Env.lookup_id id env with - | Unbound -> - let le' = rewriters.rewrite_lexp rewriters le in - let e' = rewrite_base e in - let exps' = walker exps in - let effects = union_eff_exps exps' in - let block = E_aux (E_block exps', (l, Some (env, unit_typ, effects))) in - [fix_eff_exp (E_aux (E_internal_let(le', e', block), annot))] - | _ -> (rewrite_rec exp)::(walker exps)) + | (E_aux(E_assign(le,e), ((l, Some (env,typ,eff)) as annot)) as exp)::exps + when lexp_is_local le env && not (lexp_is_effectful le)-> + let (le', _, re') = rewrite_local_lexp le in + let e' = re' (rewrite_base e) in + let exps' = walker exps in + let effects = union_eff_exps exps' in + let block = E_aux (E_block exps', (l, Some (env, unit_typ, effects))) in + [fix_eff_exp (E_aux (E_internal_let(le', e', block), annot))] (*| ((E_aux(E_if(c,t,e),(l,annot))) as exp)::exps -> let vars_t = introduced_variables t in let vars_e = introduced_variables e in @@ -2199,20 +2281,12 @@ let rewrite_exp_lift_assign_intro rewriters ((E_aux (exp,((l,_) as annot))) as f | e::exps -> (rewrite_rec e)::(walker exps) in rewrap (E_block (walker exps)) - | E_assign(((LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),lannot)) as le),e) -> - let le' = rewriters.rewrite_lexp rewriters le in - let e' = rewrite_base e in - let effects = effect_of e' in - (match Env.lookup_id id (env_of_annot annot) with - | Unbound -> - rewrap_effects - (E_internal_let(le', e', E_aux(E_block [], simple_annot l unit_typ))) - effects - | Local _ -> - let effects' = union_effects effects (effect_of_annot (snd lannot)) in - let annot' = Some (env_of_annot annot, unit_typ, effects') in - E_aux((E_assign(le', e')),(l, annot')) - | _ -> rewrite_base full_exp) + | E_assign(le,e) + when lexp_is_local le (env_of full_exp) && not (lexp_is_effectful le) -> + let (le', _, re') = rewrite_local_lexp le in + let e' = re' (rewrite_base e) in + let block = E_aux (E_block [], simple_annot l unit_typ) in + fix_eff_exp (E_aux (E_internal_let(le', e', block), annot)) | _ -> rewrite_base full_exp let rewrite_lexp_lift_assign_intro rewriters ((LEXP_aux(lexp,annot)) as le) = @@ -2777,7 +2851,8 @@ let rewrite_defs_letbind_effects = rewrap (E_let (lb,n_exp body k))) | E_sizeof nexp -> k (rewrap (E_sizeof nexp)) - | E_constraint nc -> failwith "E_constraint should have been removed till now" + | E_constraint nc -> + k (rewrap (E_constraint nc)) | E_sizeof_internal annot -> k (rewrap (E_sizeof_internal annot)) | E_assign (lexp,exp1) -> @@ -2841,13 +2916,15 @@ let rewrite_defs_effectful_let_expressions = else E_let (lb,body) in let e_internal_let = fun (lexp,exp1,exp2) -> - if effectful exp1 then - match lexp with - | LEXP_aux (LEXP_id id,annot) - | LEXP_aux (LEXP_cast (_,id),annot) -> + match lexp with + | LEXP_aux (LEXP_id id,annot) + | LEXP_aux (LEXP_cast (_,id),annot) -> + if effectful exp1 then E_internal_plet (P_aux (P_id id,annot),exp1,exp2) - | _ -> failwith "E_internal_plet with unexpected lexp" - else E_internal_let (lexp,exp1,exp2) in + else + let lb = LB_aux (LB_val_implicit (P_aux (P_id id,annot), exp1), annot) in + E_let (lb, exp2) + | _ -> failwith "E_internal_let with unexpected lexp" in let alg = { id_exp_alg with e_let = e_let; e_internal_let = e_internal_let } in rewrite_defs_base @@ -2873,93 +2950,30 @@ let eqidtyp (id1,_) (id2,_) = let name2 = match id2 with Id_aux ((Id name | DeIid name),_) -> name in name1 = name2 -let find_updated_vars (E_aux (_,(l,_)) as exp) = - let ( @@ ) (a,b) (a',b') = (a @ a',b @ b') in - let lapp2 (l : (('a list * 'b list) list)) : ('a list * 'b list) = - List.fold_left - (fun ((intros_acc : 'a list),(updates_acc : 'b list)) (intros,updates) -> - (intros_acc @ intros, updates_acc @ updates)) ([],[]) l in - - let (intros,updates) = - fold_exp - { e_aux = (fun (e,_) -> e) - ; e_id = (fun _ -> ([],[])) - ; e_lit = (fun _ -> ([],[])) - ; e_cast = (fun (_,e) -> e) - ; e_block = (fun es -> lapp2 es) - ; e_nondet = (fun es -> lapp2 es) - ; e_app = (fun (_,es) -> lapp2 es) - ; e_app_infix = (fun (e1,_,e2) -> e1 @@ e2) - ; e_tuple = (fun es -> lapp2 es) - ; e_if = (fun (e1,e2,e3) -> e1 @@ e2 @@ e3) - ; e_for = (fun (_,e1,e2,e3,_,e4) -> e1 @@ e2 @@ e3 @@ e4) - ; e_vector = (fun es -> lapp2 es) - ; e_vector_indexed = (fun (es,opt) -> opt @@ lapp2 (List.map snd es)) - ; e_vector_access = (fun (e1,e2) -> e1 @@ e2) - ; e_vector_subrange = (fun (e1,e2,e3) -> e1 @@ e2 @@ e3) - ; e_vector_update = (fun (e1,e2,e3) -> e1 @@ e2 @@ e3) - ; e_vector_update_subrange = (fun (e1,e2,e3,e4) -> e1 @@ e2 @@ e3 @@ e4) - ; e_vector_append = (fun (e1,e2) -> e1 @@ e2) - ; e_list = (fun es -> lapp2 es) - ; e_cons = (fun (e1,e2) -> e1 @@ e2) - ; e_record = (fun fexps -> fexps) - ; e_record_update = (fun (e1,fexp) -> e1 @@ fexp) - ; e_field = (fun (e1,id) -> e1) - ; e_case = (fun (e1,pexps) -> e1 @@ lapp2 pexps) - ; e_let = (fun (lb,e2) -> lb @@ e2) - ; e_assign = (fun ((ids,acc),e2) -> ([],ids) @@ acc @@ e2) - ; e_constraint = (fun nc -> ([],[])) - ; e_sizeof = (fun nexp -> ([],[])) - ; e_exit = (fun e1 -> ([],[])) - ; e_return = (fun e1 -> e1) - ; e_assert = (fun (e1,e2) -> ([],[])) - ; e_internal_cast = (fun (_,e1) -> e1) - ; e_internal_exp = (fun _ -> ([],[])) - ; e_internal_exp_user = (fun _ -> ([],[])) - ; e_comment = (fun _ -> ([],[])) - ; e_comment_struc = (fun _ -> ([],[])) - ; e_internal_let = - (fun ((ids,acc),e2,e3) -> - let id = match ids with - | [] -> raise (Reporting_basic.err_unreachable l "E_internal_let found not introducing a variable") - | [id] -> id - | _ -> raise (Reporting_basic.err_unreachable l "E_internal_let found introducing more than one variable") in - let (xs,ys) = ([id],[]) @@ acc @@ e2 @@ e3 in - let ys = List.filter (fun id2 -> not (eqidtyp id id2)) ys in - (xs,ys)) - ; e_internal_plet = (fun (_, e1, e2) -> e1 @@ e2) - ; e_internal_return = (fun e -> e) - ; lEXP_id = (fun id -> (Some id,[],([],[]))) - ; lEXP_memory = (fun (_,es) -> (None,[],lapp2 es)) - ; lEXP_cast = (fun (_,id) -> (Some id,[],([],[]))) - ; lEXP_tup = (fun tups -> failwith "FORCHRISTOPHER:: this needs implementing, not sure what you want to do") - ; lEXP_vector = (fun ((ids,acc),e1) -> (None,ids,acc @@ e1)) - ; lEXP_vector_range = (fun ((ids,acc),e1,e2) -> (None,ids,acc @@ e1 @@ e2)) - ; lEXP_field = (fun ((ids,acc),_) -> (None,ids,acc)) - ; lEXP_aux = - (function - | ((Some id,ids,acc),(annot)) -> - (match Env.lookup_id id (env_of_annot annot) with - | Unbound | Local _ -> ((id,annot) :: ids,acc) - | _ -> (ids,acc)) - | ((_,ids,acc),_) -> (ids,acc) - ) - ; fE_Fexp = (fun (_,e) -> e) - ; fE_aux = (fun (fexp,_) -> fexp) - ; fES_Fexps = (fun (fexps,_) -> lapp2 fexps) - ; fES_aux = (fun (fexp,_) -> fexp) - ; def_val_empty = ([],[]) - ; def_val_dec = (fun e -> e) - ; def_val_aux = (fun (defval,_) -> defval) - ; pat_exp = (fun (_,e) -> e) - ; pat_when = (fun (_,_,e) -> e) - ; pat_aux = (fun (pexp,_) -> pexp) - ; lB_val_explicit = (fun (_,_,e) -> e) - ; lB_val_implicit = (fun (_,e) -> e) - ; lB_aux = (fun (lb,_) -> lb) - ; pat_alg = id_pat_alg - } exp in - dedup eqidtyp updates +let find_introduced_vars exp = + let lEXP_aux ((ids,lexp),annot) = + let ids = match lexp, annot with + | LEXP_id id, (_, Some (env, _, _)) -> + (match Env.lookup_id id env with + | Unbound -> IdSet.add id ids + | _ -> ids) + | _ -> ids in + (ids, LEXP_aux (lexp, annot)) in + fst (fold_exp + { (compute_exp_alg IdSet.empty IdSet.union) with lEXP_aux = lEXP_aux } exp) + +let find_updated_vars exp = + let intros = find_introduced_vars exp in + let lEXP_aux ((ids,lexp),annot) = + let ids = match lexp, annot with + | LEXP_id id, (_, Some (env, _, _)) when not (IdSet.mem id intros) -> + (match Env.lookup_id id env with + | Local (Mutable, _) -> (id, annot) :: ids + | _ -> ids) + | _ -> ids in + (ids, LEXP_aux (lexp, annot)) in + dedup eqidtyp (fst (fold_exp + { (compute_exp_alg [] (@)) with lEXP_aux = lEXP_aux } exp)) let swaptyp typ (l,tannot) = match tannot with | Some (env, typ', eff) -> (l, Some (env, typ, eff)) |
