diff options
| author | Thomas Bauereiss | 2017-06-29 20:38:30 +0100 |
|---|---|---|
| committer | Thomas Bauereiss | 2017-06-29 21:22:24 +0100 |
| commit | 4ee5648506dce2675408d5ccf98318ff6003fb03 (patch) | |
| tree | 873decf90b23ed3a715fc724e2836ae86ad15484 | |
| parent | 0b3d26f0c7727631ac47c61ff88a16e0a217641d (diff) | |
Rewrite bitvector patterns
Seems to work for CHERI-MIPS, but still a few things to be done, e.g.
collecting let bindings for variables bound in bitvector patterns
| -rw-r--r-- | src/rewriter.ml | 482 |
1 files changed, 425 insertions, 57 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml index d26879e9..0e70e9e3 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -63,15 +63,6 @@ type 'a rewriters = { let (>>) f g = fun x -> g(f(x)) -let fresh_name_counter = ref 0 - -let fresh_name () = - let current = !fresh_name_counter in - let () = fresh_name_counter := (current + 1) in - current -let reset_fresh_name_counter () = - fresh_name_counter := 0 - let get_effsum_annot (_,t) = match t with | Base (_,_,_,_,effs,_) -> effs | NoTyp -> failwith "no effect information" @@ -89,6 +80,31 @@ let get_type_annot (_,t) = match t with let get_type (E_aux (_,a)) = get_type_annot a +let get_loc (E_aux (_,(l,_))) = l + +let fresh_name_counter = ref 0 + +let fresh_name () = + let current = !fresh_name_counter in + let () = fresh_name_counter := (current + 1) in + current +let reset_fresh_name_counter () = + fresh_name_counter := 0 + +let fresh_id pre l = + let current = fresh_name () in + Id_aux (Id (pre ^ string_of_int current), Parse_ast.Generated l) + +let fresh_id_exp pre ((l,_) as annot) = + let id = fresh_id pre l in + let annot_var = (Parse_ast.Generated l,simple_annot (get_type_annot annot)) in + E_aux (E_id id, annot_var) + +let fresh_id_pat pre ((l,_) as annot) = + let id = fresh_id pre l in + let annot_var = (Parse_ast.Generated l,simple_annot (get_type_annot annot)) in + P_aux (P_id id, annot_var) + let union_effs effs = List.fold_left (fun acc eff -> union_effects acc eff) pure_e effs @@ -210,6 +226,11 @@ let updates_vars_effs {effect = Eset effs} = let updates_vars eaux = updates_vars_effs (get_effsum_exp eaux) +let id_to_string (Id_aux(id,l)) = + match id with + | Id(s) -> s + | DeIid(s) -> s + let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option = match ls with | [] -> None @@ -217,6 +238,10 @@ let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b let mk_atom_typ i = {t=Tapp("atom",[TA_nexp i])} +let simple_num l n : tannot exp = + let typ = simple_annot (mk_atom_typ (mk_c (big_int_of_int n))) in + E_aux (E_lit (L_aux (L_num n,l)), (l,typ)) + let rec rewrite_nexp_to_exp program_vars l nexp = let rewrite n = rewrite_nexp_to_exp program_vars l n in let typ = mk_atom_typ nexp in @@ -832,10 +857,8 @@ let remove_vector_concat_pat pat = let pat = remove_typed_patterns pat in - let fresh_name l = - let current = fresh_name () in - Id_aux (Id ("v__" ^ string_of_int current), Parse_ast.Generated l) in - + let fresh_id_v = fresh_id "v__" in + (* expects that P_typ elements have been removed from AST, that the length of all vectors involved is known, that we don't have indexed vectors *) @@ -860,7 +883,7 @@ let remove_vector_concat_pat pat = | P_vector_concat pats -> (if contained_in_p_as then P_aux (pat,annot) - else P_aux (P_as (P_aux (pat,annot),fresh_name l),annot)) + else P_aux (P_as (P_aux (pat,annot),fresh_id_v l),annot)) | _ -> P_aux (pat,annot) ) ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) @@ -873,7 +896,7 @@ let remove_vector_concat_pat pat = let name_vector_concat_elements = let p_vector_concat pats = let aux ((P_aux (p,((l,_) as a))) as pat) = match p with - | P_vector _ -> P_aux (P_as (pat,fresh_name l),a) + | P_vector _ -> P_aux (P_as (pat,fresh_id_v l),a) | P_id id -> P_aux (P_id id,a) | P_as (p,id) -> P_aux (P_as (p,id),a) | P_wild -> P_aux (P_wild,a) @@ -908,17 +931,13 @@ let remove_vector_concat_pat pat = let (Id_aux (Id rootname,_)) = rootid in let (Id_aux (Id childname,_)) = child in - let simple_num n : tannot exp = - let typ = simple_annot (mk_atom_typ (mk_c (big_int_of_int n))) in - E_aux (E_lit (L_aux (L_num n,l)), (l,typ)) in - let vlength_info (Base ((_,{t = Tapp("vector",[_;TA_nexp nexp;_;_])}),_,_,_,_,_)) = nexp in let root : tannot exp = E_aux (E_id rootid,rannot) in - let index_i = simple_num i in + let index_i = simple_num l i in let index_j : tannot exp = match j with - | Some j -> simple_num j + | Some j -> simple_num l j | None -> let length_root_nexp = vlength_info (snd rannot) in let length_app_exp : tannot exp = @@ -950,41 +969,31 @@ let remove_vector_concat_pat pat = let p_aux = function | ((P_as (P_aux (P_vector_concat pats,rannot'),rootid),decls),rannot) -> let aux (pos,pat_acc,decl_acc) (P_aux (p,cannot),is_last) = match cannot with - | (_,Base((_,{t = Tapp ("vector",[_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_,_)) - | (_,Base((_,{t = Tabbrev (_,{t = Tapp ("vector",[_;TA_nexp {nexp = Nconst length};_;_])})}),_,_,_,_,_)) -> - let length = int_of_big_int length in + | (l,Base((_,({t = Tapp ("vector",[_;TA_nexp length;_;_])} as t)),_,_,_,_,_)) + | (l,Base((_,({t = Tabbrev (_,{t = Tapp ("vector",[_;TA_nexp length;_;_])})} as t)),_,_,_,_,_)) -> + let (pos',index_j) = match has_const_vector_length t with + | Some i -> + let length = int_of_big_int i in + (pos+length, Some(pos+length-1)) + | None -> + if is_last then (pos,None) + else + raise + (Reporting_basic.err_unreachable + l ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern")) in (match p with (* if we see a named vector pattern, remove the name and remember to declare it later *) | P_as (P_aux (p,cannot),cname) -> - let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,Some(pos+length-1)) in - (pos + length, pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)]) + let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in + (pos', pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)]) (* if we see a P_id variable, remember to declare it later *) | P_id cname -> - let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,Some(pos+length-1)) in - (pos + length, pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)]) + let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in + (pos', pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)]) (* normal vector patterns are fine *) - | _ -> (pos + length, pat_acc @ [P_aux (p,cannot)],decl_acc) ) + | _ -> (pos', pat_acc @ [P_aux (p,cannot)],decl_acc) ) (* non-vector patterns aren't *) - | (l,Base((_,{t = Tapp ("vector",[_;_;_;_])}),_,_,_,_,_)) - | (l,Base((_,{t = Tabbrev (_,{t = Tapp ("vector",[_;_;_;_])})}),_,_,_,_,_)) -> - if is_last then - match p with - (* if we see a named vector pattern, remove the name and remember to - declare it later *) - | P_as (P_aux (p,cannot),cname) -> - let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,None) in - (pos, pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)]) - (* if we see a P_id variable, remember to declare it later *) - | P_id cname -> - let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,None) in - (pos, pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)]) - (* normal vector patterns are fine *) - | _ -> (pos, pat_acc @ [P_aux (p,cannot)],decl_acc) - else - raise - (Reporting_basic.err_unreachable - l ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern")) | (l,Base((_,t),_,_,_,_,_)) -> raise (Reporting_basic.err_unreachable @@ -1185,7 +1194,372 @@ let rewrite_defs_remove_vector_concat defs = rewrite_defs_base rewrite_fun = rewrite_fun_remove_vector_concat_pat; rewrite_def = rewrite_def; rewrite_defs = rewrite_defs_remove_vector_concat_pat} defs - + +let map_default f = function +| None -> None +| Some x -> f x + +let rec binop_opt f x y = match x, y with +| None, None -> None +| Some x, None -> Some x +| None, Some y -> Some y +| Some x, Some y -> Some (f x y) + +let rec contains_bitvector_pat (P_aux (pat,annot)) = match pat with +| P_lit _ | P_wild _ | P_id _ -> false +| P_as (pat,_) | P_typ (_,pat) -> contains_bitvector_pat pat +| P_vector _ | P_vector_concat _ | P_vector_indexed _ -> + is_bit_vector (get_type_annot annot) +| P_app (_,pats) | P_tup pats | P_list pats -> + List.exists contains_bitvector_pat pats +| P_record (fpats,_) -> + List.exists (fun (FP_aux (FP_Fpat (_,pat),_)) -> contains_bitvector_pat pat) fpats + +let remove_bitvector_pat pat = + + (* first introduce names for bitvector patterns *) + let name_bitvector_roots = + { p_lit = (fun lit -> P_lit lit) + ; p_typ = (fun (typ,p) -> P_typ (typ,p false)) + ; p_wild = P_wild + ; p_as = (fun (pat,id) -> P_as (pat true,id)) + ; p_id = (fun id -> P_id id) + ; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps)) + ; p_record = (fun (fpats,b) -> P_record (fpats, b)) + ; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps)) + ; p_vector_indexed = (fun ps -> P_vector_indexed (List.map (fun (i,p) -> (i,p false)) ps)) + ; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps)) + ; p_tup = (fun ps -> P_tup (List.map (fun p -> p false) ps)) + ; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps)) + ; p_aux = + (fun (pat,annot) contained_in_p_as -> + match pat, annot with + | P_vector _, (l, Base((_,t),_,_,_,_,_)) + | P_vector_indexed _, (l, Base((_,t),_,_,_,_,_)) -> + (if is_bit_vector t && not contained_in_p_as + then P_aux (P_as (P_aux (pat,annot),fresh_id "b__" l), annot) + else P_aux (pat,annot)) + | _ -> P_aux (pat,annot) + ) + ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) + ; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false)) + } in + + let pat = (fold_pat name_bitvector_roots pat) false in + + let bit_annot l eaux = + let bitannot = (Parse_ast.Generated l, simple_annot {t = Tid "bit"}) in + E_aux (eaux, bitannot) in + + let access_bit_exp (rootid,rannot) l idx = + let root : tannot exp = E_aux (E_id rootid,rannot) in + let idx_exp = simple_num l idx in + bit_annot l (E_vector_access (root,idx_exp)) in + + let bitwise_and exp1 exp2 = + let (E_aux (_,(l,_))) = exp1 in + let andid = Id_aux (Id "&", Parse_ast.Generated l) in + let andannot = (Parse_ast.Generated l, + tag_annot {t = Tid "bit"} (External (Some "bitwise_and_bit"))) in + let andexp : tannot exp = E_aux (E_app_infix(exp1,andid,exp2), andannot) in + andexp in + + let compose_guards guards (*root_bv*) = + (*let guards = List.map ((|>) root_bv) guards in*) + List.fold_right (binop_opt bitwise_and) guards None in + + let test_bit_exp rootid t idx (pat, guard) = + (match guard with + | Some exp -> + let (P_aux (_,(l,_))) = pat in + let rannot = (Parse_ast.Generated l, simple_annot t) in + let elem = access_bit_exp (rootid,rannot) l idx in + let eqid = Id_aux (Id "==", Parse_ast.Generated l) in + let eqannot = (Parse_ast.Generated l, + tag_annot {t = Tid "bit"} (External (Some "eq_bit"))) in + let eqexp : tannot exp = E_aux (E_app_infix(elem,eqid,exp), eqannot) in + Some (eqexp) + | None -> None) in + + (* TODO: Collect let bindings for bits that are bound via P_as or P_id *) + let guard_bitvector_pat = + { p_lit = (fun lit bvid t -> (P_lit lit, match bvid, (normalize_t t).t with Some (Id_aux (_,l)), Tid "bit" -> Some (bit_annot l (E_lit lit)) | _, _ -> None)) + ; p_wild = (fun _ _ -> (P_wild, None)) + ; p_as = (fun (p,id) _ _ -> let (pat,guard) = p (Some id) in (P_as (pat,id), guard)) + ; p_typ = (fun (typ,p) _ _ -> let (pat,guards) = p None in (P_typ (typ,pat), guards)) + ; p_id = (fun id _ _ -> (P_id id, None)) + ; p_app = (fun (id,ps) _ _ -> let (ps,guards) = List.split (List.map ((|>) None) ps) in + (P_app (id,ps), compose_guards guards)) + ; p_record = (fun (ps,b) _ _ -> let (ps,guards) = List.split (List.map ((|>) None) ps) in + (P_record (ps,b), compose_guards guards)) + ; p_vector = (fun p bvid t -> let p = List.map ((|>) bvid) p in + let (ps,guards) = List.split p in + (*let guard = (match bvid with + | Some (Id_aux (_,l)) -> Some (bit_annot l (E_lit (L_aux (L_true,l)))) + | None -> None) in*) + let guards = (match bvid, is_bit_vector t with + | Some id, true -> List.mapi (test_bit_exp id t) p + | _, _ -> guards) in + (*let guards' = (function + | Some root_bv -> + let tests = List.mapi (test_bit_exp root_bv) p in + List.fold_right (binop_opt bitwise_and) tests None + | None -> compose_guards guards None) in*) + (P_vector ps, compose_guards guards)) + ; p_vector_indexed = (fun p bvid t -> let (is,p) = List.split p in + let p = List.map ((|>) bvid) p in + let (ps,guards) = List.split p in + let guards = (match bvid, is_bit_vector t with + | Some id, true -> List.map2 (test_bit_exp id t) is p + | _, _ -> guards) in + (*let guards' = (function + | Some root_bv -> + let tests = List.map2 (test_bit_exp root_bv) is p in + List.fold_right (binop_opt bitwise_and) tests None + | None -> compose_guards guards None) in*) + let ps = List.combine is ps in + (P_vector_indexed ps, (*compose_guards guards*) None)) + ; p_vector_concat = (fun ps _ _ -> let (ps,guards) = List.split (List.map ((|>) None) ps) in + (P_vector_concat ps, compose_guards guards)) + ; p_tup = (fun ps _ _ -> let (ps,guards) = List.split (List.map ((|>) None) ps) in + (P_tup ps, compose_guards guards)) + ; p_list = (fun ps _ _ -> let (ps,guards) = List.split (List.map ((|>) None) ps) in + (P_list ps, compose_guards guards)) + ; p_aux = (fun (p,annot) bvid -> + let (l,Base((_,t),_,_,_,_,_)) = annot in + let (pat,guard) = p bvid t in + (*(P_aux (pat,annot), guard))*) + (match pat, is_bit_vector t with + | P_as (P_aux (P_vector _, _), id), true + | P_as (P_aux (P_vector_indexed _, _), id), true -> + (P_aux (P_id id, annot), guard) + | _, _ -> (P_aux (pat,annot), guard))) + ; fP_aux = (fun ((fpat,guard),annot) _ -> (FP_aux (fpat,annot), guard)) + ; fP_Fpat = (fun (id,p) -> let (pat,guard) = p None in (FP_Fpat (id,pat), guard)) + } in + + let (pat,guard) = fold_pat guard_bitvector_pat pat None in + (pat, guard) + +let remove_wildcards pre (P_aux (_,(l,_)) as pat) = + fold_pat + {id_pat_alg with + p_aux = function + | (P_wild,(l,annot)) -> P_aux (P_id (fresh_id pre l),(l,annot)) + | (p,annot) -> P_aux (p,annot) } + pat + +(* Check if one pattern subsumes the other, and if so, calculate a + substitution of variables that are used in the same position. + TODO: Check somewhere that there are no variable clashes (the same variable + name used in different positions of the patterns) + *) +let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,_) as pat2) = + let rewrap p = P_aux (p,annot1) in + match p1, p2 with + | P_lit (L_aux (lit1,_)), P_lit (L_aux (lit2,_)) -> + if lit1 = lit2 then Some [] else None + | P_as (pat1,_), _ -> subsumes_pat pat1 pat2 + | _, P_as (pat2,_) -> subsumes_pat pat1 pat2 + | P_typ (_,pat1), _ -> subsumes_pat pat1 pat2 + | _, P_typ (_,pat2) -> subsumes_pat pat1 pat2 + | P_id (Id_aux (id1,_) as aid1), P_id (Id_aux (id2,_) as aid2) -> + if id1 = id2 then Some [] else Some [(id2,id1)] + | P_id id1, _ -> Some [] + | P_wild, _ -> Some [] + | P_app (Id_aux (id1,l1),args1), P_app (Id_aux (id2,_),args2) -> + if id1 = id2 then subsumes_pat_list args1 args2 else None + | P_vector pats1, P_vector pats2 + | P_tup pats1, P_tup pats2 + | P_list pats1, P_list pats2 -> + subsumes_pat_list pats1 pats2 + (* TODO: records *) + (* TODO: indexed vectors, vector concats *) + | _ -> None +and subsumes_pat_list pats1 pats2 = + if List.length pats1 = List.length pats2 + then + let subs = List.map2 subsumes_pat pats1 pats2 in + List.fold_right + (fun p acc -> match p, acc with + | Some subst, Some substs -> Some (subst @ substs) + | _ -> None) + subs (Some []) + else None + +let equiv_pats pat1 pat2 = + match subsumes_pat pat1 pat2, subsumes_pat pat2 pat1 with + | Some _, Some _ -> true + | _, _ -> false + +let subst_id_pat pat (id1,id2) = + let p_id (Id_aux (id,l)) = (if id = id1 then P_id (Id_aux (id2,l)) else P_id (Id_aux (id,l))) in + fold_pat {id_pat_alg with p_id = p_id} pat + +let subst_id_exp exp (id1,id2) = + (* TODO Don't substitute bound occurrences inside let expressions etc *) + let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in + fold_exp {id_exp_alg with e_id = e_id} exp + +let gen_annot l t efr = (Parse_ast.Generated l,simple_annot_efr t efr) + +let rec pat_to_exp (P_aux (pat,(l,annot))) = + let rewrap e = E_aux (e,(l,annot)) in + match pat with + | P_lit lit -> rewrap (E_lit lit) + | P_wild -> raise (Reporting_basic.err_unreachable l + "pat_to_exp given wildcard pattern") + | P_as (pat,id) -> rewrap (E_id id) + | P_typ (_,pat) -> pat_to_exp pat + | P_id id -> rewrap (E_id id) + | P_app (id,pats) -> rewrap (E_app (id, List.map pat_to_exp pats)) + | P_record (fpats,b) -> + rewrap (E_record (FES_aux (FES_Fexps (List.map fpat_to_fexp fpats,b),(l,annot)))) + | P_vector pats -> rewrap (E_vector (List.map pat_to_exp pats)) + | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l + "pat_to_exp not implemented for P_vector_concat") + (* We assume that vector concatenation patterns have been transformed + away already *) + | P_tup pats -> rewrap (E_tuple (List.map pat_to_exp pats)) + | P_list pats -> rewrap (E_list (List.map pat_to_exp pats)) + | P_vector_indexed ipats -> raise (Reporting_basic.err_unreachable l + "pat_to_exp not implemented for P_vector_indexed") + (* TODO: We can't guess the default value for the indexed vector + expression here. We should make sure that indexed vector patterns are + bound to a variable via P_as before calling pat_to_exp *) +and fpat_to_fexp (FP_aux (FP_Fpat (id,pat),(l,annot))) = + FE_aux (FE_Fexp (id, pat_to_exp pat),(l,annot)) + +let case_exp e t cs = + let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in + let ps = List.map pexp cs in + (* let efr = union_effs (List.map get_effsum_pexp ps) in *) + fix_effsum_exp (E_aux (E_case (e,ps), gen_annot (get_loc e) t pure_e)) + +let rewrite_guarded_clauses l cs = + let rec group clauses = + let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in + let rec group_aux current acc = (function + | ((pat,guard,body,annot) as c) :: cs -> + let (current_pat,_,_) = current in + (match subsumes_pat current_pat pat with + | Some substs -> + let pat' = List.fold_left subst_id_pat pat substs in + let guard' = (match guard with + | Some exp -> Some (List.fold_left subst_id_exp exp substs) + | None -> None) in + let body' = List.fold_left subst_id_exp body substs in + let c' = (pat',guard',body',annot) in + group_aux (add_clause current c') acc cs + | None -> + let pat = remove_wildcards "g__" pat in + group_aux (pat,[c],annot) (acc @ [current]) cs) + | [] -> acc @ [current]) in + let groups = match clauses with + | ((pat,guard,body,annot) as c) :: cs -> + group_aux (remove_wildcards "g__" pat, [c], annot) [] cs + | _ -> + raise (Reporting_basic.err_unreachable l + "group given empty list in rewrite_guarded_clauses") in + List.map (fun cs -> if_pexp cs) groups + and if_pexp (pat,cs,annot) = (match cs with + | c :: _ -> + (* fix_effsum_pexp (pexp *) + let body = if_exp pat cs in + let pexp = fix_effsum_pexp (Pat_aux (Pat_exp (pat,body),annot)) in + let (Pat_aux (Pat_exp (_,_),annot)) = pexp in + (pat, body, annot) + | [] -> + raise (Reporting_basic.err_unreachable l + "if_pexp given empty list in rewrite_guarded_clauses")) + and if_exp current_pat = (function + | (pat,guard,body,annot) :: ((pat',guard',body',annot') as c') :: cs -> + (match guard with + | Some exp -> + let else_exp = + if equiv_pats current_pat pat' + then if_exp current_pat (c' :: cs) + else case_exp (pat_to_exp current_pat) (get_type_annot annot') (group (c' :: cs)) in + fix_effsum_exp (E_aux (E_if (exp,body,else_exp), annot)) + | None -> body) + | [(pat,guard,body,annot)] -> body + | [] -> + raise (Reporting_basic.err_unreachable l + "if_exp given empty list in rewrite_guarded_clauses")) in + group cs + +let rewrite_exp_remove_bitvector_pat rewriters nmap (E_aux (exp,(l,annot)) as full_exp) = + let rewrap e = E_aux (e,(l,annot)) in + let rewrite_rec = rewriters.rewrite_exp rewriters nmap in + let rewrite_base = rewrite_exp rewriters nmap in + match exp with + | E_case (e,ps) + when List.exists (fun (Pat_aux (Pat_exp (pat,_),_)) -> contains_bitvector_pat pat) ps -> + let clause (Pat_aux (Pat_exp (pat,body),annot')) = + let (pat',guard) = remove_bitvector_pat pat in + (pat',guard,rewrite_rec body,annot') in + let clauses = rewrite_guarded_clauses l (List.map clause ps) in + if (effectful e) then + let e = rewrite_rec e in + let (E_aux (_,(el,eannot))) = e in + let pat_e' = fresh_id_pat "p__" (el,eannot) in + let exp_e' = pat_to_exp pat_e' in + (* let fresh = fresh_id "p__" el in + let exp_e' = E_aux (E_id fresh, gen_annot l (get_type e) pure_e) in + let pat_e' = P_aux (P_id fresh, gen_annot l (get_type e) pure_e) in *) + let letbind_e = LB_aux (LB_val_implicit (pat_e',e), gen_annot l (get_type e) (get_effsum_exp e)) in + let exp' = case_exp exp_e' (get_type full_exp) clauses in + rewrap (E_let (letbind_e, exp')) + else case_exp e (get_type full_exp) clauses + (*| E_let (LB_aux (LB_val_explicit (typ,pat,v),annot'),body) -> + let (pat,_,decls) = remove_vector_concat_pat pat in + rewrap (E_let (LB_aux (LB_val_explicit (typ,pat,rewrite_rec v),annot'), + decls (rewrite_rec body))) + | E_let (LB_aux (LB_val_implicit (pat,v),annot'),body) -> + let (pat,_,decls) = remove_vector_concat_pat pat in + rewrap (E_let (LB_aux (LB_val_implicit (pat,rewrite_rec v),annot'), + decls (rewrite_rec body)))*) + | _ -> rewrite_base full_exp + +let rewrite_fun_remove_bitvector_pat + rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) = + let _ = reset_fresh_name_counter () in + (* TODO Can there be clauses with different id's in one FD_function? *) + let funcls = match funcls with + | (FCL_aux (FCL_Funcl(id,_,_),_) :: _) -> + let clause (FCL_aux (FCL_Funcl(_,pat,exp),annot)) = + let (pat,guard) = remove_bitvector_pat pat in + let exp = rewriters.rewrite_exp rewriters None exp in + (pat,guard,exp,annot) in + let cs = rewrite_guarded_clauses l (List.map clause funcls) in + List.map (fun (pat,exp,annot) -> FCL_aux (FCL_Funcl(id,pat,exp),annot)) cs + | _ -> funcls (* TODO is the empty list possible here? *) in + FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot)) + +(*let rewrite_defs_remove_vector_concat_pat rewriters (Defs defs) = + let rewrite_def d = + let d = rewriters.rewrite_def rewriters d in + match d with + | DEF_val (LB_aux (LB_val_explicit (t,pat,exp),a)) -> + let (pat,letbinds,_) = remove_vector_concat_pat pat in + let defvals = List.map (fun lb -> DEF_val lb) letbinds in + [DEF_val (LB_aux (LB_val_explicit (t,pat,exp),a))] @ defvals + | DEF_val (LB_aux (LB_val_implicit (pat,exp),a)) -> + let (pat,letbinds,_) = remove_vector_concat_pat pat in + let defvals = List.map (fun lb -> DEF_val lb) letbinds in + [DEF_val (LB_aux (LB_val_implicit (pat,exp),a))] @ defvals + | d -> [rewriters.rewrite_def rewriters d] in + Defs (List.flatten (List.map rewrite_def defs))*) + +let rewrite_defs_remove_bitvector_pats defs = rewrite_defs_base + {rewrite_exp = rewrite_exp_remove_bitvector_pat; + rewrite_pat = rewrite_pat; + rewrite_let = rewrite_let; + rewrite_lexp = rewrite_lexp; + rewrite_fun = rewrite_fun_remove_bitvector_pat; + rewrite_def = rewrite_def; + rewrite_defs = rewrite_defs_base } defs + (*Expects to be called after rewrite_defs; thus the following should not appear: internal_exp of any form lit vectors in patterns or expressions @@ -1384,12 +1758,6 @@ let rewrite_defs_remove_blocks = -let fresh_id ((l,_) as annot) = - let current = fresh_name () in - let id = Id_aux (Id ("w__" ^ string_of_int current), Parse_ast.Generated l) in - let annot_var = (Parse_ast.Generated l,simple_annot (get_type_annot annot)) in - E_aux (E_id id, annot_var) - let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = (* body is a function : E_id variable -> actual body *) match get_type v with @@ -1405,7 +1773,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) | _ -> let (E_aux (_,((l,_) as annot))) = v in - let ((E_aux (E_id id,_)) as e_id) = fresh_id annot in + let ((E_aux (E_id id,_)) as e_id) = fresh_id_exp "w__" annot in let body = body e_id in let annot_pat = (Parse_ast.Generated l,simple_annot (get_type v)) in @@ -2146,6 +2514,7 @@ let rewrite_defs_remove_e_assign = let rewrite_defs_lem = top_sort_defs >> rewrite_defs_remove_vector_concat >> + rewrite_defs_remove_bitvector_pats >> rewrite_defs_exp_lift_assign >> rewrite_defs_remove_blocks >> rewrite_defs_letbind_effects >> @@ -2154,4 +2523,3 @@ let rewrite_defs_lem = rewrite_defs_remove_superfluous_letbinds >> rewrite_defs_remove_superfluous_returns - |
