summaryrefslogtreecommitdiff
path: root/src/rewriter.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewriter.ml')
-rw-r--r--src/rewriter.ml80
1 files changed, 52 insertions, 28 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 68bb2c2a..2a9782ba 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -1000,12 +1000,9 @@ let remove_blocks_exp_alg =
let rec f = function
| [e] -> e (* check with Kathy if that annotation is fine *)
- | e :: es -> letbind_wild e (f es)
- | e -> E_aux (E_lit (L_aux (L_unit,Unknown)), (Unknown,simple_annot ({t = Tid "unit"}))) in
-(*
- | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)), (Unknown,simple_annot ({t = Tid "unit"})))
+ | e -> E_aux (E_lit (L_aux (L_unit,Unknown)), (Unknown,simple_annot ({t = Tid "unit"})))
| e :: es -> letbind_wild e (f es) in
- *)
+
let e_aux = function
| (E_block es,annot) -> f es
| (e,annot) -> E_aux (e,annot) in
@@ -1043,7 +1040,7 @@ let rec value ((E_aux (exp_aux,_)) as exp) =
| E_tuple es
| E_vector es
| E_list es -> List.fold_left (&&) true (List.map value es)
- | _ -> false
+ | _ -> false
let only_local_eff (l,(Base ((t_params,t),tag,nexps,eff,effsum,bounds))) =
(l,Base ((t_params,t),tag,nexps,eff,eff,bounds))
@@ -1059,6 +1056,9 @@ let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list
let rec n_exp_name (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
n_exp exp (fun exp -> if value exp then k exp else letbind exp k)
+
+and n_exp_pure (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
+ n_exp exp (fun exp -> if not (effectful exp) then k exp else letbind exp k)
and n_exp_nameL (exps : 'a exp list) (k : 'a exp list -> 'a exp) : 'a exp =
mapCont n_exp_name exps k
@@ -1128,7 +1128,7 @@ and n_exp_term (new_return : bool) (exp : 'a exp) : 'a exp =
(* changed this from n_exp to n_exp_name so that when we return updated variables
* from a for-loop, for example, we can just add those into the returned tuple and
* don't need to a-normalise again *)
- n_exp exp (fun exp -> exp)
+ n_exp_pure exp (fun exp -> exp)
and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
@@ -1226,6 +1226,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
n_exp_name exp1 (fun exp1 ->
k (rewrap_localeff (E_field (exp1,id))))
| E_case (exp1,pexps) ->
+ (* PROBABLY NEED to insert E_returns here *)
n_exp_name exp1 (fun exp1 ->
n_pexpL pexps (fun pexps ->
let geteffs (Pat_aux (_,(_,Base (_,_,_,_,eff,_)))) = eff in
@@ -1308,7 +1309,11 @@ let find_updated_vars exp =
; e_field = (fun (e1,id) -> e1)
; e_case = (fun (e1,pexps) -> e1 @ List.flatten pexps)
; e_let = (fun (lb,e2) -> lb @ e2)
- ; e_assign = (fun ((None,[(id,b)]),e2) -> if b then id :: e2 else e2)
+ ; e_assign =
+ (function
+ | ((None,[(id,b)]),e2) -> if b then id :: e2 else e2
+ | ((None,[]),e2) -> e2
+ )
; e_exit = (fun e1 -> e1)
; e_internal_cast = (fun (_,e1) -> e1)
; e_internal_exp = (fun _ -> [])
@@ -1511,9 +1516,11 @@ let rec rewrite_for_if_case ((E_aux (expaux,annot)) as exp) =
| _ -> exp
-let replace_e_assign =
+let replace_var_update_e_assign =
let e_aux (expaux,annot) =
+
+ let f v body lexp =
let letbind (E_aux (E_id id,_) as e_id) (v : 'a exp) (body : 'a exp) : 'a exp =
(* body is a function : E_id variable -> actual body *)
@@ -1523,40 +1530,57 @@ let replace_e_assign =
let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (geteffs body)) in
let pat = P_aux (P_id id,annot_pat) in
E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) in
-
- let f v body = function
- | LEXP_aux (LEXP_id id,annot) ->
- let eid = E_aux (E_id id,(Unknown,simple_annot (gettype v))) in
- letbind eid v body
- | LEXP_aux (LEXP_vector (LEXP_aux (LEXP_id id,annot2),i),annot) ->
- let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in
- let v = E_aux (E_vector_update (eid,i,v),(Unknown,simple_annot (gettype_annot annot))) in
- letbind eid v body
- | LEXP_aux (LEXP_vector_range (LEXP_aux (LEXP_id id,annot2),i,j),annot) ->
- let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in
- let v = E_aux (E_vector_update_subrange (eid,i,j,v),
- (Unknown,simple_annot (gettype_annot annot))) in
- letbind eid v body in
+
+ match lexp with
+ | LEXP_aux (LEXP_id id,annot) ->
+ let eid = E_aux (E_id id,(Unknown,simple_annot (gettype v))) in
+ letbind eid v body
+ | LEXP_aux (LEXP_cast (_,id),annot) ->
+ let eid = E_aux (E_id id,(Unknown,simple_annot (gettype v))) in
+ letbind eid v body
+ | LEXP_aux (LEXP_vector (LEXP_aux (LEXP_id id,annot2),i),annot) ->
+ let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in
+ let v = E_aux (E_vector_update (eid,i,v),(Unknown,simple_annot (gettype_annot annot))) in
+ letbind eid v body
+ | LEXP_aux (LEXP_vector_range (LEXP_aux (LEXP_id id,annot2),i,j),annot) ->
+ let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in
+ let v = E_aux (E_vector_update_subrange (eid,i,j,v),
+ (Unknown,simple_annot (gettype_annot annot))) in
+ letbind eid v body in
match expaux with
| E_let (LB_aux (LB_val_explicit (_,_,E_aux (E_assign (lexp,v),annot2)),_),body)
| E_let (LB_aux (LB_val_implicit (_,E_aux (E_assign (lexp,v),annot2)),_),body)
- when
- let {effect = Eset effs} = geteffs_annot annot2 in
- List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs ->
+ when let {effect = Eset effs} = geteffs_annot annot2 in
+ List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs ->
f v body lexp
| E_let (lb,body) -> E_aux (E_let (lb,body),annot)
(* E_internal_plet is only used for effectful terms, shouldn't be needed to deal with here *)
| E_internal_let (LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),_),v,body) ->
let (E_aux (_,pannot)) = v in
let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype body) (geteffs body)) in
- E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_id id,pannot),v),lbannot),body),annot)
+ E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_id id,pannot),v),lbannot),body),annot)
+ | E_assign (lexp,v)
+ when let {effect = Eset effs} = geteffs_annot annot in
+ List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs ->
+ f v (E_aux (E_lit (L_aux (L_unit,Unknown)), (Unknown,simple_annot ({t = Tid "unit"})))) lexp
+
+
| _ -> E_aux (expaux,annot) in
{ id_exp_alg with e_aux = e_aux }
+let replace_memwrite_e_assign =
+ let e_aux = fun (expaux,annot) ->
+ match expaux with
+ | E_assign (LEXP_aux (LEXP_memory (id,args),_),v) -> E_aux (E_app (id,args @ [v]),annot)
+ | _ -> E_aux (expaux,annot) in
+ { id_exp_alg with e_aux = e_aux }
+
let rewrite_defs_remove_e_assign =
- let rewrite_exp _ _ e = (fold_exp replace_e_assign) (rewrite_for_if_case e) in
+ let rewrite_exp _ _ e =
+ (fold_exp replace_memwrite_e_assign)
+ ((fold_exp replace_var_update_e_assign) (rewrite_for_if_case e)) in
rewrite_defs_base
{rewrite_exp = rewrite_exp
; rewrite_pat = rewrite_pat