summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/rewriter.ml482
1 files changed, 425 insertions, 57 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml
index d26879e9..0e70e9e3 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -63,15 +63,6 @@ type 'a rewriters = {
let (>>) f g = fun x -> g(f(x))
-let fresh_name_counter = ref 0
-
-let fresh_name () =
- let current = !fresh_name_counter in
- let () = fresh_name_counter := (current + 1) in
- current
-let reset_fresh_name_counter () =
- fresh_name_counter := 0
-
let get_effsum_annot (_,t) = match t with
| Base (_,_,_,_,effs,_) -> effs
| NoTyp -> failwith "no effect information"
@@ -89,6 +80,31 @@ let get_type_annot (_,t) = match t with
let get_type (E_aux (_,a)) = get_type_annot a
+let get_loc (E_aux (_,(l,_))) = l
+
+let fresh_name_counter = ref 0
+
+let fresh_name () =
+ let current = !fresh_name_counter in
+ let () = fresh_name_counter := (current + 1) in
+ current
+let reset_fresh_name_counter () =
+ fresh_name_counter := 0
+
+let fresh_id pre l =
+ let current = fresh_name () in
+ Id_aux (Id (pre ^ string_of_int current), Parse_ast.Generated l)
+
+let fresh_id_exp pre ((l,_) as annot) =
+ let id = fresh_id pre l in
+ let annot_var = (Parse_ast.Generated l,simple_annot (get_type_annot annot)) in
+ E_aux (E_id id, annot_var)
+
+let fresh_id_pat pre ((l,_) as annot) =
+ let id = fresh_id pre l in
+ let annot_var = (Parse_ast.Generated l,simple_annot (get_type_annot annot)) in
+ P_aux (P_id id, annot_var)
+
let union_effs effs =
List.fold_left (fun acc eff -> union_effects acc eff) pure_e effs
@@ -210,6 +226,11 @@ let updates_vars_effs {effect = Eset effs} =
let updates_vars eaux = updates_vars_effs (get_effsum_exp eaux)
+let id_to_string (Id_aux(id,l)) =
+ match id with
+ | Id(s) -> s
+ | DeIid(s) -> s
+
let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option = match ls with
| [] -> None
@@ -217,6 +238,10 @@ let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b
let mk_atom_typ i = {t=Tapp("atom",[TA_nexp i])}
+let simple_num l n : tannot exp =
+ let typ = simple_annot (mk_atom_typ (mk_c (big_int_of_int n))) in
+ E_aux (E_lit (L_aux (L_num n,l)), (l,typ))
+
let rec rewrite_nexp_to_exp program_vars l nexp =
let rewrite n = rewrite_nexp_to_exp program_vars l n in
let typ = mk_atom_typ nexp in
@@ -832,10 +857,8 @@ let remove_vector_concat_pat pat =
let pat = remove_typed_patterns pat in
- let fresh_name l =
- let current = fresh_name () in
- Id_aux (Id ("v__" ^ string_of_int current), Parse_ast.Generated l) in
-
+ let fresh_id_v = fresh_id "v__" in
+
(* expects that P_typ elements have been removed from AST,
that the length of all vectors involved is known,
that we don't have indexed vectors *)
@@ -860,7 +883,7 @@ let remove_vector_concat_pat pat =
| P_vector_concat pats ->
(if contained_in_p_as
then P_aux (pat,annot)
- else P_aux (P_as (P_aux (pat,annot),fresh_name l),annot))
+ else P_aux (P_as (P_aux (pat,annot),fresh_id_v l),annot))
| _ -> P_aux (pat,annot)
)
; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot))
@@ -873,7 +896,7 @@ let remove_vector_concat_pat pat =
let name_vector_concat_elements =
let p_vector_concat pats =
let aux ((P_aux (p,((l,_) as a))) as pat) = match p with
- | P_vector _ -> P_aux (P_as (pat,fresh_name l),a)
+ | P_vector _ -> P_aux (P_as (pat,fresh_id_v l),a)
| P_id id -> P_aux (P_id id,a)
| P_as (p,id) -> P_aux (P_as (p,id),a)
| P_wild -> P_aux (P_wild,a)
@@ -908,17 +931,13 @@ let remove_vector_concat_pat pat =
let (Id_aux (Id rootname,_)) = rootid in
let (Id_aux (Id childname,_)) = child in
- let simple_num n : tannot exp =
- let typ = simple_annot (mk_atom_typ (mk_c (big_int_of_int n))) in
- E_aux (E_lit (L_aux (L_num n,l)), (l,typ)) in
-
let vlength_info (Base ((_,{t = Tapp("vector",[_;TA_nexp nexp;_;_])}),_,_,_,_,_)) =
nexp in
let root : tannot exp = E_aux (E_id rootid,rannot) in
- let index_i = simple_num i in
+ let index_i = simple_num l i in
let index_j : tannot exp = match j with
- | Some j -> simple_num j
+ | Some j -> simple_num l j
| None ->
let length_root_nexp = vlength_info (snd rannot) in
let length_app_exp : tannot exp =
@@ -950,41 +969,31 @@ let remove_vector_concat_pat pat =
let p_aux = function
| ((P_as (P_aux (P_vector_concat pats,rannot'),rootid),decls),rannot) ->
let aux (pos,pat_acc,decl_acc) (P_aux (p,cannot),is_last) = match cannot with
- | (_,Base((_,{t = Tapp ("vector",[_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_,_))
- | (_,Base((_,{t = Tabbrev (_,{t = Tapp ("vector",[_;TA_nexp {nexp = Nconst length};_;_])})}),_,_,_,_,_)) ->
- let length = int_of_big_int length in
+ | (l,Base((_,({t = Tapp ("vector",[_;TA_nexp length;_;_])} as t)),_,_,_,_,_))
+ | (l,Base((_,({t = Tabbrev (_,{t = Tapp ("vector",[_;TA_nexp length;_;_])})} as t)),_,_,_,_,_)) ->
+ let (pos',index_j) = match has_const_vector_length t with
+ | Some i ->
+ let length = int_of_big_int i in
+ (pos+length, Some(pos+length-1))
+ | None ->
+ if is_last then (pos,None)
+ else
+ raise
+ (Reporting_basic.err_unreachable
+ l ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern")) in
(match p with
(* if we see a named vector pattern, remove the name and remember to
declare it later *)
| P_as (P_aux (p,cannot),cname) ->
- let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,Some(pos+length-1)) in
- (pos + length, pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)])
+ let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in
+ (pos', pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)])
(* if we see a P_id variable, remember to declare it later *)
| P_id cname ->
- let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,Some(pos+length-1)) in
- (pos + length, pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)])
+ let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,index_j) in
+ (pos', pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)])
(* normal vector patterns are fine *)
- | _ -> (pos + length, pat_acc @ [P_aux (p,cannot)],decl_acc) )
+ | _ -> (pos', pat_acc @ [P_aux (p,cannot)],decl_acc) )
(* non-vector patterns aren't *)
- | (l,Base((_,{t = Tapp ("vector",[_;_;_;_])}),_,_,_,_,_))
- | (l,Base((_,{t = Tabbrev (_,{t = Tapp ("vector",[_;_;_;_])})}),_,_,_,_,_)) ->
- if is_last then
- match p with
- (* if we see a named vector pattern, remove the name and remember to
- declare it later *)
- | P_as (P_aux (p,cannot),cname) ->
- let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,None) in
- (pos, pat_acc @ [P_aux (p,cannot)], decl_acc @ [((lb,decl),info)])
- (* if we see a P_id variable, remember to declare it later *)
- | P_id cname ->
- let (lb,decl,info) = letbind_vec (rootid,rannot) (cname,cannot) (pos,None) in
- (pos, pat_acc @ [P_aux (P_id cname,cannot)], decl_acc @ [((lb,decl),info)])
- (* normal vector patterns are fine *)
- | _ -> (pos, pat_acc @ [P_aux (p,cannot)],decl_acc)
- else
- raise
- (Reporting_basic.err_unreachable
- l ("unname_vector_concat_elements: vector of unspecified length in vector-concat pattern"))
| (l,Base((_,t),_,_,_,_,_)) ->
raise
(Reporting_basic.err_unreachable
@@ -1185,7 +1194,372 @@ let rewrite_defs_remove_vector_concat defs = rewrite_defs_base
rewrite_fun = rewrite_fun_remove_vector_concat_pat;
rewrite_def = rewrite_def;
rewrite_defs = rewrite_defs_remove_vector_concat_pat} defs
-
+
+let map_default f = function
+| None -> None
+| Some x -> f x
+
+let rec binop_opt f x y = match x, y with
+| None, None -> None
+| Some x, None -> Some x
+| None, Some y -> Some y
+| Some x, Some y -> Some (f x y)
+
+let rec contains_bitvector_pat (P_aux (pat,annot)) = match pat with
+| P_lit _ | P_wild _ | P_id _ -> false
+| P_as (pat,_) | P_typ (_,pat) -> contains_bitvector_pat pat
+| P_vector _ | P_vector_concat _ | P_vector_indexed _ ->
+ is_bit_vector (get_type_annot annot)
+| P_app (_,pats) | P_tup pats | P_list pats ->
+ List.exists contains_bitvector_pat pats
+| P_record (fpats,_) ->
+ List.exists (fun (FP_aux (FP_Fpat (_,pat),_)) -> contains_bitvector_pat pat) fpats
+
+let remove_bitvector_pat pat =
+
+ (* first introduce names for bitvector patterns *)
+ let name_bitvector_roots =
+ { p_lit = (fun lit -> P_lit lit)
+ ; p_typ = (fun (typ,p) -> P_typ (typ,p false))
+ ; p_wild = P_wild
+ ; p_as = (fun (pat,id) -> P_as (pat true,id))
+ ; p_id = (fun id -> P_id id)
+ ; p_app = (fun (id,ps) -> P_app (id, List.map (fun p -> p false) ps))
+ ; p_record = (fun (fpats,b) -> P_record (fpats, b))
+ ; p_vector = (fun ps -> P_vector (List.map (fun p -> p false) ps))
+ ; p_vector_indexed = (fun ps -> P_vector_indexed (List.map (fun (i,p) -> (i,p false)) ps))
+ ; p_vector_concat = (fun ps -> P_vector_concat (List.map (fun p -> p false) ps))
+ ; p_tup = (fun ps -> P_tup (List.map (fun p -> p false) ps))
+ ; p_list = (fun ps -> P_list (List.map (fun p -> p false) ps))
+ ; p_aux =
+ (fun (pat,annot) contained_in_p_as ->
+ match pat, annot with
+ | P_vector _, (l, Base((_,t),_,_,_,_,_))
+ | P_vector_indexed _, (l, Base((_,t),_,_,_,_,_)) ->
+ (if is_bit_vector t && not contained_in_p_as
+ then P_aux (P_as (P_aux (pat,annot),fresh_id "b__" l), annot)
+ else P_aux (pat,annot))
+ | _ -> P_aux (pat,annot)
+ )
+ ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot))
+ ; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false))
+ } in
+
+ let pat = (fold_pat name_bitvector_roots pat) false in
+
+ let bit_annot l eaux =
+ let bitannot = (Parse_ast.Generated l, simple_annot {t = Tid "bit"}) in
+ E_aux (eaux, bitannot) in
+
+ let access_bit_exp (rootid,rannot) l idx =
+ let root : tannot exp = E_aux (E_id rootid,rannot) in
+ let idx_exp = simple_num l idx in
+ bit_annot l (E_vector_access (root,idx_exp)) in
+
+ let bitwise_and exp1 exp2 =
+ let (E_aux (_,(l,_))) = exp1 in
+ let andid = Id_aux (Id "&", Parse_ast.Generated l) in
+ let andannot = (Parse_ast.Generated l,
+ tag_annot {t = Tid "bit"} (External (Some "bitwise_and_bit"))) in
+ let andexp : tannot exp = E_aux (E_app_infix(exp1,andid,exp2), andannot) in
+ andexp in
+
+ let compose_guards guards (*root_bv*) =
+ (*let guards = List.map ((|>) root_bv) guards in*)
+ List.fold_right (binop_opt bitwise_and) guards None in
+
+ let test_bit_exp rootid t idx (pat, guard) =
+ (match guard with
+ | Some exp ->
+ let (P_aux (_,(l,_))) = pat in
+ let rannot = (Parse_ast.Generated l, simple_annot t) in
+ let elem = access_bit_exp (rootid,rannot) l idx in
+ let eqid = Id_aux (Id "==", Parse_ast.Generated l) in
+ let eqannot = (Parse_ast.Generated l,
+ tag_annot {t = Tid "bit"} (External (Some "eq_bit"))) in
+ let eqexp : tannot exp = E_aux (E_app_infix(elem,eqid,exp), eqannot) in
+ Some (eqexp)
+ | None -> None) in
+
+ (* TODO: Collect let bindings for bits that are bound via P_as or P_id *)
+ let guard_bitvector_pat =
+ { p_lit = (fun lit bvid t -> (P_lit lit, match bvid, (normalize_t t).t with Some (Id_aux (_,l)), Tid "bit" -> Some (bit_annot l (E_lit lit)) | _, _ -> None))
+ ; p_wild = (fun _ _ -> (P_wild, None))
+ ; p_as = (fun (p,id) _ _ -> let (pat,guard) = p (Some id) in (P_as (pat,id), guard))
+ ; p_typ = (fun (typ,p) _ _ -> let (pat,guards) = p None in (P_typ (typ,pat), guards))
+ ; p_id = (fun id _ _ -> (P_id id, None))
+ ; p_app = (fun (id,ps) _ _ -> let (ps,guards) = List.split (List.map ((|>) None) ps) in
+ (P_app (id,ps), compose_guards guards))
+ ; p_record = (fun (ps,b) _ _ -> let (ps,guards) = List.split (List.map ((|>) None) ps) in
+ (P_record (ps,b), compose_guards guards))
+ ; p_vector = (fun p bvid t -> let p = List.map ((|>) bvid) p in
+ let (ps,guards) = List.split p in
+ (*let guard = (match bvid with
+ | Some (Id_aux (_,l)) -> Some (bit_annot l (E_lit (L_aux (L_true,l))))
+ | None -> None) in*)
+ let guards = (match bvid, is_bit_vector t with
+ | Some id, true -> List.mapi (test_bit_exp id t) p
+ | _, _ -> guards) in
+ (*let guards' = (function
+ | Some root_bv ->
+ let tests = List.mapi (test_bit_exp root_bv) p in
+ List.fold_right (binop_opt bitwise_and) tests None
+ | None -> compose_guards guards None) in*)
+ (P_vector ps, compose_guards guards))
+ ; p_vector_indexed = (fun p bvid t -> let (is,p) = List.split p in
+ let p = List.map ((|>) bvid) p in
+ let (ps,guards) = List.split p in
+ let guards = (match bvid, is_bit_vector t with
+ | Some id, true -> List.map2 (test_bit_exp id t) is p
+ | _, _ -> guards) in
+ (*let guards' = (function
+ | Some root_bv ->
+ let tests = List.map2 (test_bit_exp root_bv) is p in
+ List.fold_right (binop_opt bitwise_and) tests None
+ | None -> compose_guards guards None) in*)
+ let ps = List.combine is ps in
+ (P_vector_indexed ps, (*compose_guards guards*) None))
+ ; p_vector_concat = (fun ps _ _ -> let (ps,guards) = List.split (List.map ((|>) None) ps) in
+ (P_vector_concat ps, compose_guards guards))
+ ; p_tup = (fun ps _ _ -> let (ps,guards) = List.split (List.map ((|>) None) ps) in
+ (P_tup ps, compose_guards guards))
+ ; p_list = (fun ps _ _ -> let (ps,guards) = List.split (List.map ((|>) None) ps) in
+ (P_list ps, compose_guards guards))
+ ; p_aux = (fun (p,annot) bvid ->
+ let (l,Base((_,t),_,_,_,_,_)) = annot in
+ let (pat,guard) = p bvid t in
+ (*(P_aux (pat,annot), guard))*)
+ (match pat, is_bit_vector t with
+ | P_as (P_aux (P_vector _, _), id), true
+ | P_as (P_aux (P_vector_indexed _, _), id), true ->
+ (P_aux (P_id id, annot), guard)
+ | _, _ -> (P_aux (pat,annot), guard)))
+ ; fP_aux = (fun ((fpat,guard),annot) _ -> (FP_aux (fpat,annot), guard))
+ ; fP_Fpat = (fun (id,p) -> let (pat,guard) = p None in (FP_Fpat (id,pat), guard))
+ } in
+
+ let (pat,guard) = fold_pat guard_bitvector_pat pat None in
+ (pat, guard)
+
+let remove_wildcards pre (P_aux (_,(l,_)) as pat) =
+ fold_pat
+ {id_pat_alg with
+ p_aux = function
+ | (P_wild,(l,annot)) -> P_aux (P_id (fresh_id pre l),(l,annot))
+ | (p,annot) -> P_aux (p,annot) }
+ pat
+
+(* Check if one pattern subsumes the other, and if so, calculate a
+ substitution of variables that are used in the same position.
+ TODO: Check somewhere that there are no variable clashes (the same variable
+ name used in different positions of the patterns)
+ *)
+let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,_) as pat2) =
+ let rewrap p = P_aux (p,annot1) in
+ match p1, p2 with
+ | P_lit (L_aux (lit1,_)), P_lit (L_aux (lit2,_)) ->
+ if lit1 = lit2 then Some [] else None
+ | P_as (pat1,_), _ -> subsumes_pat pat1 pat2
+ | _, P_as (pat2,_) -> subsumes_pat pat1 pat2
+ | P_typ (_,pat1), _ -> subsumes_pat pat1 pat2
+ | _, P_typ (_,pat2) -> subsumes_pat pat1 pat2
+ | P_id (Id_aux (id1,_) as aid1), P_id (Id_aux (id2,_) as aid2) ->
+ if id1 = id2 then Some [] else Some [(id2,id1)]
+ | P_id id1, _ -> Some []
+ | P_wild, _ -> Some []
+ | P_app (Id_aux (id1,l1),args1), P_app (Id_aux (id2,_),args2) ->
+ if id1 = id2 then subsumes_pat_list args1 args2 else None
+ | P_vector pats1, P_vector pats2
+ | P_tup pats1, P_tup pats2
+ | P_list pats1, P_list pats2 ->
+ subsumes_pat_list pats1 pats2
+ (* TODO: records *)
+ (* TODO: indexed vectors, vector concats *)
+ | _ -> None
+and subsumes_pat_list pats1 pats2 =
+ if List.length pats1 = List.length pats2
+ then
+ let subs = List.map2 subsumes_pat pats1 pats2 in
+ List.fold_right
+ (fun p acc -> match p, acc with
+ | Some subst, Some substs -> Some (subst @ substs)
+ | _ -> None)
+ subs (Some [])
+ else None
+
+let equiv_pats pat1 pat2 =
+ match subsumes_pat pat1 pat2, subsumes_pat pat2 pat1 with
+ | Some _, Some _ -> true
+ | _, _ -> false
+
+let subst_id_pat pat (id1,id2) =
+ let p_id (Id_aux (id,l)) = (if id = id1 then P_id (Id_aux (id2,l)) else P_id (Id_aux (id,l))) in
+ fold_pat {id_pat_alg with p_id = p_id} pat
+
+let subst_id_exp exp (id1,id2) =
+ (* TODO Don't substitute bound occurrences inside let expressions etc *)
+ let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in
+ fold_exp {id_exp_alg with e_id = e_id} exp
+
+let gen_annot l t efr = (Parse_ast.Generated l,simple_annot_efr t efr)
+
+let rec pat_to_exp (P_aux (pat,(l,annot))) =
+ let rewrap e = E_aux (e,(l,annot)) in
+ match pat with
+ | P_lit lit -> rewrap (E_lit lit)
+ | P_wild -> raise (Reporting_basic.err_unreachable l
+ "pat_to_exp given wildcard pattern")
+ | P_as (pat,id) -> rewrap (E_id id)
+ | P_typ (_,pat) -> pat_to_exp pat
+ | P_id id -> rewrap (E_id id)
+ | P_app (id,pats) -> rewrap (E_app (id, List.map pat_to_exp pats))
+ | P_record (fpats,b) ->
+ rewrap (E_record (FES_aux (FES_Fexps (List.map fpat_to_fexp fpats,b),(l,annot))))
+ | P_vector pats -> rewrap (E_vector (List.map pat_to_exp pats))
+ | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l
+ "pat_to_exp not implemented for P_vector_concat")
+ (* We assume that vector concatenation patterns have been transformed
+ away already *)
+ | P_tup pats -> rewrap (E_tuple (List.map pat_to_exp pats))
+ | P_list pats -> rewrap (E_list (List.map pat_to_exp pats))
+ | P_vector_indexed ipats -> raise (Reporting_basic.err_unreachable l
+ "pat_to_exp not implemented for P_vector_indexed")
+ (* TODO: We can't guess the default value for the indexed vector
+ expression here. We should make sure that indexed vector patterns are
+ bound to a variable via P_as before calling pat_to_exp *)
+and fpat_to_fexp (FP_aux (FP_Fpat (id,pat),(l,annot))) =
+ FE_aux (FE_Fexp (id, pat_to_exp pat),(l,annot))
+
+let case_exp e t cs =
+ let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in
+ let ps = List.map pexp cs in
+ (* let efr = union_effs (List.map get_effsum_pexp ps) in *)
+ fix_effsum_exp (E_aux (E_case (e,ps), gen_annot (get_loc e) t pure_e))
+
+let rewrite_guarded_clauses l cs =
+ let rec group clauses =
+ let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in
+ let rec group_aux current acc = (function
+ | ((pat,guard,body,annot) as c) :: cs ->
+ let (current_pat,_,_) = current in
+ (match subsumes_pat current_pat pat with
+ | Some substs ->
+ let pat' = List.fold_left subst_id_pat pat substs in
+ let guard' = (match guard with
+ | Some exp -> Some (List.fold_left subst_id_exp exp substs)
+ | None -> None) in
+ let body' = List.fold_left subst_id_exp body substs in
+ let c' = (pat',guard',body',annot) in
+ group_aux (add_clause current c') acc cs
+ | None ->
+ let pat = remove_wildcards "g__" pat in
+ group_aux (pat,[c],annot) (acc @ [current]) cs)
+ | [] -> acc @ [current]) in
+ let groups = match clauses with
+ | ((pat,guard,body,annot) as c) :: cs ->
+ group_aux (remove_wildcards "g__" pat, [c], annot) [] cs
+ | _ ->
+ raise (Reporting_basic.err_unreachable l
+ "group given empty list in rewrite_guarded_clauses") in
+ List.map (fun cs -> if_pexp cs) groups
+ and if_pexp (pat,cs,annot) = (match cs with
+ | c :: _ ->
+ (* fix_effsum_pexp (pexp *)
+ let body = if_exp pat cs in
+ let pexp = fix_effsum_pexp (Pat_aux (Pat_exp (pat,body),annot)) in
+ let (Pat_aux (Pat_exp (_,_),annot)) = pexp in
+ (pat, body, annot)
+ | [] ->
+ raise (Reporting_basic.err_unreachable l
+ "if_pexp given empty list in rewrite_guarded_clauses"))
+ and if_exp current_pat = (function
+ | (pat,guard,body,annot) :: ((pat',guard',body',annot') as c') :: cs ->
+ (match guard with
+ | Some exp ->
+ let else_exp =
+ if equiv_pats current_pat pat'
+ then if_exp current_pat (c' :: cs)
+ else case_exp (pat_to_exp current_pat) (get_type_annot annot') (group (c' :: cs)) in
+ fix_effsum_exp (E_aux (E_if (exp,body,else_exp), annot))
+ | None -> body)
+ | [(pat,guard,body,annot)] -> body
+ | [] ->
+ raise (Reporting_basic.err_unreachable l
+ "if_exp given empty list in rewrite_guarded_clauses")) in
+ group cs
+
+let rewrite_exp_remove_bitvector_pat rewriters nmap (E_aux (exp,(l,annot)) as full_exp) =
+ let rewrap e = E_aux (e,(l,annot)) in
+ let rewrite_rec = rewriters.rewrite_exp rewriters nmap in
+ let rewrite_base = rewrite_exp rewriters nmap in
+ match exp with
+ | E_case (e,ps)
+ when List.exists (fun (Pat_aux (Pat_exp (pat,_),_)) -> contains_bitvector_pat pat) ps ->
+ let clause (Pat_aux (Pat_exp (pat,body),annot')) =
+ let (pat',guard) = remove_bitvector_pat pat in
+ (pat',guard,rewrite_rec body,annot') in
+ let clauses = rewrite_guarded_clauses l (List.map clause ps) in
+ if (effectful e) then
+ let e = rewrite_rec e in
+ let (E_aux (_,(el,eannot))) = e in
+ let pat_e' = fresh_id_pat "p__" (el,eannot) in
+ let exp_e' = pat_to_exp pat_e' in
+ (* let fresh = fresh_id "p__" el in
+ let exp_e' = E_aux (E_id fresh, gen_annot l (get_type e) pure_e) in
+ let pat_e' = P_aux (P_id fresh, gen_annot l (get_type e) pure_e) in *)
+ let letbind_e = LB_aux (LB_val_implicit (pat_e',e), gen_annot l (get_type e) (get_effsum_exp e)) in
+ let exp' = case_exp exp_e' (get_type full_exp) clauses in
+ rewrap (E_let (letbind_e, exp'))
+ else case_exp e (get_type full_exp) clauses
+ (*| E_let (LB_aux (LB_val_explicit (typ,pat,v),annot'),body) ->
+ let (pat,_,decls) = remove_vector_concat_pat pat in
+ rewrap (E_let (LB_aux (LB_val_explicit (typ,pat,rewrite_rec v),annot'),
+ decls (rewrite_rec body)))
+ | E_let (LB_aux (LB_val_implicit (pat,v),annot'),body) ->
+ let (pat,_,decls) = remove_vector_concat_pat pat in
+ rewrap (E_let (LB_aux (LB_val_implicit (pat,rewrite_rec v),annot'),
+ decls (rewrite_rec body)))*)
+ | _ -> rewrite_base full_exp
+
+let rewrite_fun_remove_bitvector_pat
+ rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) =
+ let _ = reset_fresh_name_counter () in
+ (* TODO Can there be clauses with different id's in one FD_function? *)
+ let funcls = match funcls with
+ | (FCL_aux (FCL_Funcl(id,_,_),_) :: _) ->
+ let clause (FCL_aux (FCL_Funcl(_,pat,exp),annot)) =
+ let (pat,guard) = remove_bitvector_pat pat in
+ let exp = rewriters.rewrite_exp rewriters None exp in
+ (pat,guard,exp,annot) in
+ let cs = rewrite_guarded_clauses l (List.map clause funcls) in
+ List.map (fun (pat,exp,annot) -> FCL_aux (FCL_Funcl(id,pat,exp),annot)) cs
+ | _ -> funcls (* TODO is the empty list possible here? *) in
+ FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))
+
+(*let rewrite_defs_remove_vector_concat_pat rewriters (Defs defs) =
+ let rewrite_def d =
+ let d = rewriters.rewrite_def rewriters d in
+ match d with
+ | DEF_val (LB_aux (LB_val_explicit (t,pat,exp),a)) ->
+ let (pat,letbinds,_) = remove_vector_concat_pat pat in
+ let defvals = List.map (fun lb -> DEF_val lb) letbinds in
+ [DEF_val (LB_aux (LB_val_explicit (t,pat,exp),a))] @ defvals
+ | DEF_val (LB_aux (LB_val_implicit (pat,exp),a)) ->
+ let (pat,letbinds,_) = remove_vector_concat_pat pat in
+ let defvals = List.map (fun lb -> DEF_val lb) letbinds in
+ [DEF_val (LB_aux (LB_val_implicit (pat,exp),a))] @ defvals
+ | d -> [rewriters.rewrite_def rewriters d] in
+ Defs (List.flatten (List.map rewrite_def defs))*)
+
+let rewrite_defs_remove_bitvector_pats defs = rewrite_defs_base
+ {rewrite_exp = rewrite_exp_remove_bitvector_pat;
+ rewrite_pat = rewrite_pat;
+ rewrite_let = rewrite_let;
+ rewrite_lexp = rewrite_lexp;
+ rewrite_fun = rewrite_fun_remove_bitvector_pat;
+ rewrite_def = rewrite_def;
+ rewrite_defs = rewrite_defs_base } defs
+
(*Expects to be called after rewrite_defs; thus the following should not appear:
internal_exp of any form
lit vectors in patterns or expressions
@@ -1384,12 +1758,6 @@ let rewrite_defs_remove_blocks =
-let fresh_id ((l,_) as annot) =
- let current = fresh_name () in
- let id = Id_aux (Id ("w__" ^ string_of_int current), Parse_ast.Generated l) in
- let annot_var = (Parse_ast.Generated l,simple_annot (get_type_annot annot)) in
- E_aux (E_id id, annot_var)
-
let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
(* body is a function : E_id variable -> actual body *)
match get_type v with
@@ -1405,7 +1773,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let)
| _ ->
let (E_aux (_,((l,_) as annot))) = v in
- let ((E_aux (E_id id,_)) as e_id) = fresh_id annot in
+ let ((E_aux (E_id id,_)) as e_id) = fresh_id_exp "w__" annot in
let body = body e_id in
let annot_pat = (Parse_ast.Generated l,simple_annot (get_type v)) in
@@ -2146,6 +2514,7 @@ let rewrite_defs_remove_e_assign =
let rewrite_defs_lem =
top_sort_defs >>
rewrite_defs_remove_vector_concat >>
+ rewrite_defs_remove_bitvector_pats >>
rewrite_defs_exp_lift_assign >>
rewrite_defs_remove_blocks >>
rewrite_defs_letbind_effects >>
@@ -2154,4 +2523,3 @@ let rewrite_defs_lem =
rewrite_defs_remove_superfluous_letbinds >>
rewrite_defs_remove_superfluous_returns
-