summaryrefslogtreecommitdiff
path: root/src/rewriter.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewriter.ml')
-rw-r--r--src/rewriter.ml427
1 files changed, 235 insertions, 192 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 981de14c..0cf25103 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -1015,11 +1015,15 @@ let rewrite_sizeof (Defs defs) =
when string_of_id atom = "atom" ->
[nexp, E_id id]
| Typ_app (vector, _) when string_of_id vector = "vector" ->
- let (_,len,_,_) = vector_typ_args_of typ_aux in
- let exp = E_app
- (Id_aux (Id "length", Parse_ast.Generated l),
- [E_aux (E_id id, annot)]) in
- [len, exp]
+ let id_length = Id_aux (Id "length", Parse_ast.Generated l) in
+ (try
+ (match Env.get_val_spec id_length (env_of_annot annot) with
+ | _ ->
+ let (_,len,_,_) = vector_typ_args_of typ_aux in
+ let exp = E_app (id_length, [E_aux (E_id id, annot)]) in
+ [len, exp])
+ with
+ | _ -> [])
| _ -> [])
| _ -> [] in
(v @ v', P_aux (pat,annot)))} pat) in
@@ -1488,6 +1492,177 @@ let rewrite_defs_remove_vector_concat (Defs defs) =
| d -> [d] in
Defs (List.flatten (List.map rewrite_def defs))
+(* A few helper functions for rewriting guarded pattern clauses.
+ Used both by the rewriting of P_when and separately by the rewriting of
+ bitvectors in parameter patterns of function clauses *)
+
+let remove_wildcards pre (P_aux (_,(l,_)) as pat) =
+ fold_pat
+ {id_pat_alg with
+ p_aux = function
+ | (P_wild,(l,annot)) -> P_aux (P_id (fresh_id pre l),(l,annot))
+ | (p,annot) -> P_aux (p,annot) }
+ pat
+
+(* Check if one pattern subsumes the other, and if so, calculate a
+ substitution of variables that are used in the same position.
+ TODO: Check somewhere that there are no variable clashes (the same variable
+ name used in different positions of the patterns)
+ *)
+let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) =
+ let rewrap p = P_aux (p,annot1) in
+ let subsumes_list s pats1 pats2 =
+ if List.length pats1 = List.length pats2
+ then
+ let subs = List.map2 s pats1 pats2 in
+ List.fold_right
+ (fun p acc -> match p, acc with
+ | Some subst, Some substs -> Some (subst @ substs)
+ | _ -> None)
+ subs (Some [])
+ else None in
+ match p1, p2 with
+ | P_lit (L_aux (lit1,_)), P_lit (L_aux (lit2,_)) ->
+ if lit1 = lit2 then Some [] else None
+ | P_as (pat1,_), _ -> subsumes_pat pat1 pat2
+ | _, P_as (pat2,_) -> subsumes_pat pat1 pat2
+ | P_typ (_,pat1), _ -> subsumes_pat pat1 pat2
+ | _, P_typ (_,pat2) -> subsumes_pat pat1 pat2
+ | P_id (Id_aux (id1,_) as aid1), P_id (Id_aux (id2,_) as aid2) ->
+ if id1 = id2 then Some []
+ else if Env.lookup_id aid1 (env_of_annot annot1) = Unbound &&
+ Env.lookup_id aid2 (env_of_annot annot2) = Unbound
+ then Some [(id2,id1)] else None
+ | P_id id1, _ ->
+ if Env.lookup_id id1 (env_of_annot annot1) = Unbound then Some [] else None
+ | P_wild, _ -> Some []
+ | P_app (Id_aux (id1,l1),args1), P_app (Id_aux (id2,_),args2) ->
+ if id1 = id2 then subsumes_list subsumes_pat args1 args2 else None
+ | P_record (fps1,b1), P_record (fps2,b2) ->
+ if b1 = b2 then subsumes_list subsumes_fpat fps1 fps2 else None
+ | P_vector pats1, P_vector pats2
+ | P_vector_concat pats1, P_vector_concat pats2
+ | P_tup pats1, P_tup pats2
+ | P_list pats1, P_list pats2 ->
+ subsumes_list subsumes_pat pats1 pats2
+ | P_list (pat1 :: pats1), P_cons _ ->
+ subsumes_pat (rewrap (P_cons (pat1, rewrap (P_list pats1)))) pat2
+ | P_cons _, P_list (pat2 :: pats2)->
+ subsumes_pat pat1 (rewrap (P_cons (pat2, rewrap (P_list pats2))))
+ | P_cons (pat1, pats1), P_cons (pat2, pats2) ->
+ (match subsumes_pat pat1 pat2, subsumes_pat pats1 pats2 with
+ | Some substs1, Some substs2 -> Some (substs1 @ substs2)
+ | _ -> None)
+ | P_vector_indexed ips1, P_vector_indexed ips2 ->
+ let (is1,ps1) = List.split ips1 in
+ let (is2,ps2) = List.split ips2 in
+ if is1 = is2 then subsumes_list subsumes_pat ps1 ps2 else None
+ | _ -> None
+and subsumes_fpat (FP_aux (FP_Fpat (id1,pat1),_)) (FP_aux (FP_Fpat (id2,pat2),_)) =
+ if id1 = id2 then subsumes_pat pat1 pat2 else None
+
+let equiv_pats pat1 pat2 =
+ match subsumes_pat pat1 pat2, subsumes_pat pat2 pat1 with
+ | Some _, Some _ -> true
+ | _, _ -> false
+
+let subst_id_pat pat (id1,id2) =
+ let p_id (Id_aux (id,l)) = (if id = id1 then P_id (Id_aux (id2,l)) else P_id (Id_aux (id,l))) in
+ fold_pat {id_pat_alg with p_id = p_id} pat
+
+let subst_id_exp exp (id1,id2) =
+ (* TODO Don't substitute bound occurrences inside let expressions etc *)
+ let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in
+ fold_exp {id_exp_alg with e_id = e_id} exp
+
+let rec pat_to_exp (P_aux (pat,(l,annot))) =
+ let rewrap e = E_aux (e,(l,annot)) in
+ match pat with
+ | P_lit lit -> rewrap (E_lit lit)
+ | P_wild -> raise (Reporting_basic.err_unreachable l
+ "pat_to_exp given wildcard pattern")
+ | P_as (pat,id) -> rewrap (E_id id)
+ | P_typ (_,pat) -> pat_to_exp pat
+ | P_id id -> rewrap (E_id id)
+ | P_app (id,pats) -> rewrap (E_app (id, List.map pat_to_exp pats))
+ | P_record (fpats,b) ->
+ rewrap (E_record (FES_aux (FES_Fexps (List.map fpat_to_fexp fpats,b),(l,annot))))
+ | P_vector pats -> rewrap (E_vector (List.map pat_to_exp pats))
+ | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l
+ "pat_to_exp not implemented for P_vector_concat")
+ (* We assume that vector concatenation patterns have been transformed
+ away already *)
+ | P_tup pats -> rewrap (E_tuple (List.map pat_to_exp pats))
+ | P_list pats -> rewrap (E_list (List.map pat_to_exp pats))
+ | P_cons (p,ps) -> rewrap (E_cons (pat_to_exp p, pat_to_exp ps))
+ | P_vector_indexed ipats -> raise (Reporting_basic.err_unreachable l
+ "pat_to_exp not implemented for P_vector_indexed") (* TODO *)
+and fpat_to_fexp (FP_aux (FP_Fpat (id,pat),(l,annot))) =
+ FE_aux (FE_Fexp (id, pat_to_exp pat),(l,annot))
+
+let case_exp e t cs =
+ let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in
+ let ps = List.map pexp cs in
+ (* let efr = union_effs (List.map effect_of_pexp ps) in *)
+ fix_eff_exp (E_aux (E_case (e,ps), (get_loc_exp e, Some (env_of e, t, no_effect))))
+
+let rewrite_guarded_clauses l cs =
+ let rec group clauses =
+ let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in
+ let rec group_aux current acc = (function
+ | ((pat,guard,body,annot) as c) :: cs ->
+ let (current_pat,_,_) = current in
+ (match subsumes_pat current_pat pat with
+ | Some substs ->
+ let pat' = List.fold_left subst_id_pat pat substs in
+ let guard' = (match guard with
+ | Some exp -> Some (List.fold_left subst_id_exp exp substs)
+ | None -> None) in
+ let body' = List.fold_left subst_id_exp body substs in
+ let c' = (pat',guard',body',annot) in
+ group_aux (add_clause current c') acc cs
+ | None ->
+ let pat = remove_wildcards "g__" pat in
+ group_aux (pat,[c],annot) (acc @ [current]) cs)
+ | [] -> acc @ [current]) in
+ let groups = match clauses with
+ | ((pat,guard,body,annot) as c) :: cs ->
+ group_aux (remove_wildcards "g__" pat, [c], annot) [] cs
+ | _ ->
+ raise (Reporting_basic.err_unreachable l
+ "group given empty list in rewrite_guarded_clauses") in
+ List.map (fun cs -> if_pexp cs) groups
+ and if_pexp (pat,cs,annot) = (match cs with
+ | c :: _ ->
+ (* fix_eff_pexp (pexp *)
+ let body = if_exp pat cs in
+ let pexp = fix_eff_pexp (Pat_aux (Pat_exp (pat,body),annot)) in
+ let (Pat_aux (_,annot)) = pexp in
+ (pat, body, annot)
+ | [] ->
+ raise (Reporting_basic.err_unreachable l
+ "if_pexp given empty list in rewrite_guarded_clauses"))
+ and if_exp current_pat = (function
+ | (pat,guard,body,annot) :: ((pat',guard',body',annot') as c') :: cs ->
+ (match guard with
+ | Some exp ->
+ let else_exp =
+ if equiv_pats current_pat pat'
+ then if_exp current_pat (c' :: cs)
+ else case_exp (pat_to_exp current_pat) (typ_of body') (group (c' :: cs)) in
+ fix_eff_exp (E_aux (E_if (exp,body,else_exp), simple_annot (fst annot) (typ_of body)))
+ | None -> body)
+ | [(pat,guard,body,annot)] -> body
+ | [] ->
+ raise (Reporting_basic.err_unreachable l
+ "if_exp given empty list in rewrite_guarded_clauses")) in
+ group cs
+
+let bitwise_and_exp exp1 exp2 =
+ let (E_aux (_,(l,_))) = exp1 in
+ let andid = Id_aux (Id "bool_and", Parse_ast.Generated l) in
+ E_aux (E_app(andid,[exp1;exp2]), simple_annot l bool_typ)
+
let rec contains_bitvector_pat (P_aux (pat,annot)) = match pat with
| P_lit _ | P_wild | P_id _ -> false
| P_as (pat,_) | P_typ (_,pat) -> contains_bitvector_pat pat
@@ -1500,6 +1675,12 @@ let rec contains_bitvector_pat (P_aux (pat,annot)) = match pat with
| P_record (fpats,_) ->
List.exists (fun (FP_aux (FP_Fpat (_,pat),_)) -> contains_bitvector_pat pat) fpats
+let contains_bitvector_pexp = function
+| Pat_aux (Pat_exp (pat,_),_) | Pat_aux (Pat_when (pat,_,_),_) ->
+ contains_bitvector_pat pat
+
+(* Rewrite bitvector patterns to guarded patterns *)
+
let remove_bitvector_pat pat =
(* first introduce names for bitvector patterns *)
@@ -1585,14 +1766,8 @@ let remove_bitvector_pat pat =
E_aux (E_let (letbind,body), (Parse_ast.Generated l, bannot))) in
(letexp, letbind) in
- (* Helper functions for composing guards *)
- let bitwise_and exp1 exp2 =
- let (E_aux (_,(l,_))) = exp1 in
- let andid = Id_aux (Id "bool_and", Parse_ast.Generated l) in
- E_aux (E_app(andid,[exp1;exp2]), simple_annot l bool_typ) in
-
let compose_guards guards =
- List.fold_right (Util.option_binop bitwise_and) guards None in
+ List.fold_right (Util.option_binop bitwise_and_exp) guards None in
let flatten_guards_decls gd =
let (guards,decls,letbinds) = Util.split3 gd in
@@ -1695,192 +1870,27 @@ let remove_bitvector_pat pat =
} in
fold_pat guard_bitvector_pat pat
-let remove_wildcards pre (P_aux (_,(l,_)) as pat) =
- fold_pat
- {id_pat_alg with
- p_aux = function
- | (P_wild,(l,annot)) -> P_aux (P_id (fresh_id pre l),(l,annot))
- | (p,annot) -> P_aux (p,annot) }
- pat
-
-(* Check if one pattern subsumes the other, and if so, calculate a
- substitution of variables that are used in the same position.
- TODO: Check somewhere that there are no variable clashes (the same variable
- name used in different positions of the patterns)
- *)
-let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) =
- let rewrap p = P_aux (p,annot1) in
- let subsumes_list s pats1 pats2 =
- if List.length pats1 = List.length pats2
- then
- let subs = List.map2 s pats1 pats2 in
- List.fold_right
- (fun p acc -> match p, acc with
- | Some subst, Some substs -> Some (subst @ substs)
- | _ -> None)
- subs (Some [])
- else None in
- match p1, p2 with
- | P_lit (L_aux (lit1,_)), P_lit (L_aux (lit2,_)) ->
- if lit1 = lit2 then Some [] else None
- | P_as (pat1,_), _ -> subsumes_pat pat1 pat2
- | _, P_as (pat2,_) -> subsumes_pat pat1 pat2
- | P_typ (_,pat1), _ -> subsumes_pat pat1 pat2
- | _, P_typ (_,pat2) -> subsumes_pat pat1 pat2
- | P_id (Id_aux (id1,_) as aid1), P_id (Id_aux (id2,_) as aid2) ->
- if id1 = id2 then Some []
- else if Env.lookup_id aid1 (env_of_annot annot1) = Unbound &&
- Env.lookup_id aid2 (env_of_annot annot2) = Unbound
- then Some [(id2,id1)] else None
- | P_id id1, _ ->
- if Env.lookup_id id1 (env_of_annot annot1) = Unbound then Some [] else None
- | P_wild, _ -> Some []
- | P_app (Id_aux (id1,l1),args1), P_app (Id_aux (id2,_),args2) ->
- if id1 = id2 then subsumes_list subsumes_pat args1 args2 else None
- | P_record (fps1,b1), P_record (fps2,b2) ->
- if b1 = b2 then subsumes_list subsumes_fpat fps1 fps2 else None
- | P_vector pats1, P_vector pats2
- | P_vector_concat pats1, P_vector_concat pats2
- | P_tup pats1, P_tup pats2
- | P_list pats1, P_list pats2 ->
- subsumes_list subsumes_pat pats1 pats2
- | P_list (pat1 :: pats1), P_cons _ ->
- subsumes_pat (rewrap (P_cons (pat1, rewrap (P_list pats1)))) pat2
- | P_cons _, P_list (pat2 :: pats2)->
- subsumes_pat pat1 (rewrap (P_cons (pat2, rewrap (P_list pats2))))
- | P_cons (pat1, pats1), P_cons (pat2, pats2) ->
- (match subsumes_pat pat1 pat2, subsumes_pat pats1 pats2 with
- | Some substs1, Some substs2 -> Some (substs1 @ substs2)
- | _ -> None)
- | P_vector_indexed ips1, P_vector_indexed ips2 ->
- let (is1,ps1) = List.split ips1 in
- let (is2,ps2) = List.split ips2 in
- if is1 = is2 then subsumes_list subsumes_pat ps1 ps2 else None
- | _ -> None
-and subsumes_fpat (FP_aux (FP_Fpat (id1,pat1),_)) (FP_aux (FP_Fpat (id2,pat2),_)) =
- if id1 = id2 then subsumes_pat pat1 pat2 else None
-
-let equiv_pats pat1 pat2 =
- match subsumes_pat pat1 pat2, subsumes_pat pat2 pat1 with
- | Some _, Some _ -> true
- | _, _ -> false
-
-let subst_id_pat pat (id1,id2) =
- let p_id (Id_aux (id,l)) = (if id = id1 then P_id (Id_aux (id2,l)) else P_id (Id_aux (id,l))) in
- fold_pat {id_pat_alg with p_id = p_id} pat
-
-let subst_id_exp exp (id1,id2) =
- (* TODO Don't substitute bound occurrences inside let expressions etc *)
- let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in
- fold_exp {id_exp_alg with e_id = e_id} exp
-
-let rec pat_to_exp (P_aux (pat,(l,annot))) =
- let rewrap e = E_aux (e,(l,annot)) in
- match pat with
- | P_lit lit -> rewrap (E_lit lit)
- | P_wild -> raise (Reporting_basic.err_unreachable l
- "pat_to_exp given wildcard pattern")
- | P_as (pat,id) -> rewrap (E_id id)
- | P_typ (_,pat) -> pat_to_exp pat
- | P_id id -> rewrap (E_id id)
- | P_app (id,pats) -> rewrap (E_app (id, List.map pat_to_exp pats))
- | P_record (fpats,b) ->
- rewrap (E_record (FES_aux (FES_Fexps (List.map fpat_to_fexp fpats,b),(l,annot))))
- | P_vector pats -> rewrap (E_vector (List.map pat_to_exp pats))
- | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l
- "pat_to_exp not implemented for P_vector_concat")
- (* We assume that vector concatenation patterns have been transformed
- away already *)
- | P_tup pats -> rewrap (E_tuple (List.map pat_to_exp pats))
- | P_list pats -> rewrap (E_list (List.map pat_to_exp pats))
- | P_cons (p,ps) -> rewrap (E_cons (pat_to_exp p, pat_to_exp ps))
- | P_vector_indexed ipats -> raise (Reporting_basic.err_unreachable l
- "pat_to_exp not implemented for P_vector_indexed") (* TODO *)
-and fpat_to_fexp (FP_aux (FP_Fpat (id,pat),(l,annot))) =
- FE_aux (FE_Fexp (id, pat_to_exp pat),(l,annot))
-
-let case_exp e t cs =
- let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in
- let ps = List.map pexp cs in
- (* let efr = union_effs (List.map effect_of_pexp ps) in *)
- fix_eff_exp (E_aux (E_case (e,ps), (get_loc_exp e, Some (env_of e, t, no_effect))))
-
-let rewrite_guarded_clauses l cs =
- let rec group clauses =
- let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in
- let rec group_aux current acc = (function
- | ((pat,guard,body,annot) as c) :: cs ->
- let (current_pat,_,_) = current in
- (match subsumes_pat current_pat pat with
- | Some substs ->
- let pat' = List.fold_left subst_id_pat pat substs in
- let guard' = (match guard with
- | Some exp -> Some (List.fold_left subst_id_exp exp substs)
- | None -> None) in
- let body' = List.fold_left subst_id_exp body substs in
- let c' = (pat',guard',body',annot) in
- group_aux (add_clause current c') acc cs
- | None ->
- let pat = remove_wildcards "g__" pat in
- group_aux (pat,[c],annot) (acc @ [current]) cs)
- | [] -> acc @ [current]) in
- let groups = match clauses with
- | ((pat,guard,body,annot) as c) :: cs ->
- group_aux (remove_wildcards "g__" pat, [c], annot) [] cs
- | _ ->
- raise (Reporting_basic.err_unreachable l
- "group given empty list in rewrite_guarded_clauses") in
- List.map (fun cs -> if_pexp cs) groups
- and if_pexp (pat,cs,annot) = (match cs with
- | c :: _ ->
- (* fix_eff_pexp (pexp *)
- let body = if_exp pat cs in
- let pexp = fix_eff_pexp (Pat_aux (Pat_exp (pat,body),annot)) in
- let (Pat_aux (_,annot)) = pexp in
- (pat, body, annot)
- | [] ->
- raise (Reporting_basic.err_unreachable l
- "if_pexp given empty list in rewrite_guarded_clauses"))
- and if_exp current_pat = (function
- | (pat,guard,body,annot) :: ((pat',guard',body',annot') as c') :: cs ->
- (match guard with
- | Some exp ->
- let else_exp =
- if equiv_pats current_pat pat'
- then if_exp current_pat (c' :: cs)
- else case_exp (pat_to_exp current_pat) (typ_of body') (group (c' :: cs)) in
- fix_eff_exp (E_aux (E_if (exp,body,else_exp), simple_annot (fst annot) (typ_of body)))
- | None -> body)
- | [(pat,guard,body,annot)] -> body
- | [] ->
- raise (Reporting_basic.err_unreachable l
- "if_exp given empty list in rewrite_guarded_clauses")) in
- group cs
-
let rewrite_exp_remove_bitvector_pat rewriters (E_aux (exp,(l,annot)) as full_exp) =
let rewrap e = E_aux (e,(l,annot)) in
let rewrite_rec = rewriters.rewrite_exp rewriters in
let rewrite_base = rewrite_exp rewriters in
match exp with
| E_case (e,ps)
- when List.exists (fun (Pat_aux ((Pat_exp (pat,_)|Pat_when(pat,_,_)),_)) -> contains_bitvector_pat pat) ps ->
- let clause (Pat_aux (Pat_exp (pat,body),annot')) =
- let (pat',(guard,decls,_)) = remove_bitvector_pat pat in
+ when List.exists contains_bitvector_pexp ps ->
+ let rewrite_pexp = function
+ | Pat_aux (Pat_exp (pat,body),annot') ->
+ let (pat',(guard',decls,_)) = remove_bitvector_pat pat in
+ let body' = decls (rewrite_rec body) in
+ (match guard' with
+ | Some guard' -> Pat_aux (Pat_when (pat', guard', body'), annot')
+ | None -> Pat_aux (Pat_exp (pat', body'), annot'))
+ | Pat_aux (Pat_when (pat,guard,body),annot') ->
+ let (pat',(guard',decls,_)) = remove_bitvector_pat pat in
let body' = decls (rewrite_rec body) in
- (pat',guard,body',annot') in
- let clauses = rewrite_guarded_clauses l (List.map clause ps) in
- if (effectful e) then
- let e = rewrite_rec e in
- let (E_aux (_,(el,eannot))) = e in
- let pat_e' = fresh_id_pat "p__" (el,eannot) in
- let exp_e' = pat_to_exp pat_e' in
- (* let fresh = fresh_id "p__" el in
- let exp_e' = E_aux (E_id fresh, gen_annot l (get_type e) pure_e) in
- let pat_e' = P_aux (P_id fresh, gen_annot l (get_type e) pure_e) in *)
- let letbind_e = LB_aux (LB_val_implicit (pat_e',e), (el,eannot)) in
- let exp' = case_exp exp_e' (typ_of full_exp) clauses in
- rewrap (E_let (letbind_e, exp'))
- else case_exp e (typ_of full_exp) clauses
+ (match guard' with
+ | Some guard' -> Pat_aux (Pat_when (pat', bitwise_and_exp guard guard', body'), annot')
+ | None -> Pat_aux (Pat_when (pat', guard, body'), annot')) in
+ rewrap (E_case (e, List.map rewrite_pexp ps))
| E_let (LB_aux (LB_val_explicit (typ,pat,v),annot'),body) ->
let (pat,(_,decls,_)) = remove_bitvector_pat pat in
rewrap (E_let (LB_aux (LB_val_explicit (typ,pat,rewrite_rec v),annot'),
@@ -1930,6 +1940,38 @@ let rewrite_defs_remove_bitvector_pats (Defs defs) =
Defs (List.flatten (List.map rewrite_def defs))
+(* Remove pattern guards by rewriting them to if-expressions within the
+ pattern expression. Shares code with the rewriting of bitvector patterns. *)
+let rewrite_exp_guarded_pats rewriters (E_aux (exp,(l,annot)) as full_exp) =
+ let rewrap e = E_aux (e,(l,annot)) in
+ let rewrite_rec = rewriters.rewrite_exp rewriters in
+ let rewrite_base = rewrite_exp rewriters in
+ let is_guarded_pexp = function
+ | Pat_aux (Pat_when (_,_,_),_) -> true
+ | _ -> false in
+ match exp with
+ | E_case (e,ps)
+ when List.exists is_guarded_pexp ps ->
+ let clause = function
+ | Pat_aux (Pat_exp (pat, body), annot) ->
+ (pat, None, rewrite_rec body, annot)
+ | Pat_aux (Pat_when (pat, guard, body), annot) ->
+ (pat, Some guard, rewrite_rec body, annot) in
+ let clauses = rewrite_guarded_clauses l (List.map clause ps) in
+ if (effectful e) then
+ let e = rewrite_rec e in
+ let (E_aux (_,(el,eannot))) = e in
+ let pat_e' = fresh_id_pat "p__" (el,eannot) in
+ let exp_e' = pat_to_exp pat_e' in
+ let letbind_e = LB_aux (LB_val_implicit (pat_e',e), (el,eannot)) in
+ let exp' = case_exp exp_e' (typ_of full_exp) clauses in
+ rewrap (E_let (letbind_e, exp'))
+ else case_exp e (typ_of full_exp) clauses
+ | _ -> rewrite_base full_exp
+
+let rewrite_defs_guarded_pats =
+ rewrite_defs_base { rewriters_base with rewrite_exp = rewrite_exp_guarded_pats }
+
(*Expects to be called after rewrite_defs; thus the following should not appear:
internal_exp of any form
lit vectors in patterns or expressions
@@ -2919,9 +2961,10 @@ let rewrite_defs_remove_e_assign =
let rewrite_defs_lem =
top_sort_defs >>
+ rewrite_sizeof >>
rewrite_defs_remove_vector_concat >>
rewrite_defs_remove_bitvector_pats >>
- rewrite_sizeof >>
+ rewrite_defs_guarded_pats >>
rewrite_defs_exp_lift_assign >>
rewrite_defs_remove_blocks >>
rewrite_defs_letbind_effects >>