diff options
Diffstat (limited to 'src/rewriter.ml')
| -rw-r--r-- | src/rewriter.ml | 80 |
1 files changed, 52 insertions, 28 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml index 68bb2c2a..2a9782ba 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -1000,12 +1000,9 @@ let remove_blocks_exp_alg = let rec f = function | [e] -> e (* check with Kathy if that annotation is fine *) - | e :: es -> letbind_wild e (f es) - | e -> E_aux (E_lit (L_aux (L_unit,Unknown)), (Unknown,simple_annot ({t = Tid "unit"}))) in -(* - | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)), (Unknown,simple_annot ({t = Tid "unit"}))) + | e -> E_aux (E_lit (L_aux (L_unit,Unknown)), (Unknown,simple_annot ({t = Tid "unit"}))) | e :: es -> letbind_wild e (f es) in - *) + let e_aux = function | (E_block es,annot) -> f es | (e,annot) -> E_aux (e,annot) in @@ -1043,7 +1040,7 @@ let rec value ((E_aux (exp_aux,_)) as exp) = | E_tuple es | E_vector es | E_list es -> List.fold_left (&&) true (List.map value es) - | _ -> false + | _ -> false let only_local_eff (l,(Base ((t_params,t),tag,nexps,eff,effsum,bounds))) = (l,Base ((t_params,t),tag,nexps,eff,eff,bounds)) @@ -1059,6 +1056,9 @@ let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list let rec n_exp_name (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = n_exp exp (fun exp -> if value exp then k exp else letbind exp k) + +and n_exp_pure (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = + n_exp exp (fun exp -> if not (effectful exp) then k exp else letbind exp k) and n_exp_nameL (exps : 'a exp list) (k : 'a exp list -> 'a exp) : 'a exp = mapCont n_exp_name exps k @@ -1128,7 +1128,7 @@ and n_exp_term (new_return : bool) (exp : 'a exp) : 'a exp = (* changed this from n_exp to n_exp_name so that when we return updated variables * from a for-loop, for example, we can just add those into the returned tuple and * don't need to a-normalise again *) - n_exp exp (fun exp -> exp) + n_exp_pure exp (fun exp -> exp) and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = @@ -1226,6 +1226,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = n_exp_name exp1 (fun exp1 -> k (rewrap_localeff (E_field (exp1,id)))) | E_case (exp1,pexps) -> + (* PROBABLY NEED to insert E_returns here *) n_exp_name exp1 (fun exp1 -> n_pexpL pexps (fun pexps -> let geteffs (Pat_aux (_,(_,Base (_,_,_,_,eff,_)))) = eff in @@ -1308,7 +1309,11 @@ let find_updated_vars exp = ; e_field = (fun (e1,id) -> e1) ; e_case = (fun (e1,pexps) -> e1 @ List.flatten pexps) ; e_let = (fun (lb,e2) -> lb @ e2) - ; e_assign = (fun ((None,[(id,b)]),e2) -> if b then id :: e2 else e2) + ; e_assign = + (function + | ((None,[(id,b)]),e2) -> if b then id :: e2 else e2 + | ((None,[]),e2) -> e2 + ) ; e_exit = (fun e1 -> e1) ; e_internal_cast = (fun (_,e1) -> e1) ; e_internal_exp = (fun _ -> []) @@ -1511,9 +1516,11 @@ let rec rewrite_for_if_case ((E_aux (expaux,annot)) as exp) = | _ -> exp -let replace_e_assign = +let replace_var_update_e_assign = let e_aux (expaux,annot) = + + let f v body lexp = let letbind (E_aux (E_id id,_) as e_id) (v : 'a exp) (body : 'a exp) : 'a exp = (* body is a function : E_id variable -> actual body *) @@ -1523,40 +1530,57 @@ let replace_e_assign = let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (geteffs body)) in let pat = P_aux (P_id id,annot_pat) in E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) in - - let f v body = function - | LEXP_aux (LEXP_id id,annot) -> - let eid = E_aux (E_id id,(Unknown,simple_annot (gettype v))) in - letbind eid v body - | LEXP_aux (LEXP_vector (LEXP_aux (LEXP_id id,annot2),i),annot) -> - let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in - let v = E_aux (E_vector_update (eid,i,v),(Unknown,simple_annot (gettype_annot annot))) in - letbind eid v body - | LEXP_aux (LEXP_vector_range (LEXP_aux (LEXP_id id,annot2),i,j),annot) -> - let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in - let v = E_aux (E_vector_update_subrange (eid,i,j,v), - (Unknown,simple_annot (gettype_annot annot))) in - letbind eid v body in + + match lexp with + | LEXP_aux (LEXP_id id,annot) -> + let eid = E_aux (E_id id,(Unknown,simple_annot (gettype v))) in + letbind eid v body + | LEXP_aux (LEXP_cast (_,id),annot) -> + let eid = E_aux (E_id id,(Unknown,simple_annot (gettype v))) in + letbind eid v body + | LEXP_aux (LEXP_vector (LEXP_aux (LEXP_id id,annot2),i),annot) -> + let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in + let v = E_aux (E_vector_update (eid,i,v),(Unknown,simple_annot (gettype_annot annot))) in + letbind eid v body + | LEXP_aux (LEXP_vector_range (LEXP_aux (LEXP_id id,annot2),i,j),annot) -> + let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in + let v = E_aux (E_vector_update_subrange (eid,i,j,v), + (Unknown,simple_annot (gettype_annot annot))) in + letbind eid v body in match expaux with | E_let (LB_aux (LB_val_explicit (_,_,E_aux (E_assign (lexp,v),annot2)),_),body) | E_let (LB_aux (LB_val_implicit (_,E_aux (E_assign (lexp,v),annot2)),_),body) - when - let {effect = Eset effs} = geteffs_annot annot2 in - List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs -> + when let {effect = Eset effs} = geteffs_annot annot2 in + List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs -> f v body lexp | E_let (lb,body) -> E_aux (E_let (lb,body),annot) (* E_internal_plet is only used for effectful terms, shouldn't be needed to deal with here *) | E_internal_let (LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),_),v,body) -> let (E_aux (_,pannot)) = v in let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype body) (geteffs body)) in - E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_id id,pannot),v),lbannot),body),annot) + E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_id id,pannot),v),lbannot),body),annot) + | E_assign (lexp,v) + when let {effect = Eset effs} = geteffs_annot annot in + List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs -> + f v (E_aux (E_lit (L_aux (L_unit,Unknown)), (Unknown,simple_annot ({t = Tid "unit"})))) lexp + + | _ -> E_aux (expaux,annot) in { id_exp_alg with e_aux = e_aux } +let replace_memwrite_e_assign = + let e_aux = fun (expaux,annot) -> + match expaux with + | E_assign (LEXP_aux (LEXP_memory (id,args),_),v) -> E_aux (E_app (id,args @ [v]),annot) + | _ -> E_aux (expaux,annot) in + { id_exp_alg with e_aux = e_aux } + let rewrite_defs_remove_e_assign = - let rewrite_exp _ _ e = (fold_exp replace_e_assign) (rewrite_for_if_case e) in + let rewrite_exp _ _ e = + (fold_exp replace_memwrite_e_assign) + ((fold_exp replace_var_update_e_assign) (rewrite_for_if_case e)) in rewrite_defs_base {rewrite_exp = rewrite_exp ; rewrite_pat = rewrite_pat |
