summaryrefslogtreecommitdiff
path: root/src/rewriter.ml
diff options
context:
space:
mode:
authorChristopher Pulte2015-11-05 08:45:31 +0000
committerChristopher Pulte2015-11-05 08:45:31 +0000
commitbf36f5273afa8a63adcd739e09f29bd0f64d9527 (patch)
treefe31b8b6d0ce14d073b474e4c31ddf229301e5de /src/rewriter.ml
parent0f935fbc68d0000bbb97eccfe54f54292cb2b36f (diff)
some progress on lem backend: rewrite away mutable variable assignments, rewrite for-loops, if/case-expressions to return updated variables
Diffstat (limited to 'src/rewriter.ml')
-rw-r--r--src/rewriter.ml396
1 files changed, 357 insertions, 39 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 39234b11..68bb2c2a 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -17,6 +17,15 @@ type 'a rewriters = { rewrite_exp : 'a rewriters -> (nexp_map * 'a namemap) opt
rewrite_defs : 'a rewriters -> 'a defs -> 'a defs;
}
+
+let fresh_name_counter = ref 0
+
+let fresh_name () =
+ let current = !fresh_name_counter in
+ let () = fresh_name_counter := (current + 1) in
+ current
+
+
let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option = match ls with
| [] -> None
| (v1,v2)::ls -> if (eq v1 v) then Some v2 else partial_assoc eq v ls
@@ -593,7 +602,6 @@ let id_exp_alg =
}
-let remove_vector_concat_pat_counter = ref 0
let remove_vector_concat_pat pat =
(* ivc: bool that indicates whether the exp is in a vector_concat pattern *)
let remove_tannot_in_vector_concats =
@@ -624,8 +632,7 @@ let remove_vector_concat_pat pat =
let pat = (fold_pat remove_tannot_in_vector_concats pat) false in
let fresh_name () =
- let current = !remove_vector_concat_pat_counter in
- let () = remove_vector_concat_pat_counter := (current + 1) in
+ let current = fresh_name () in
Id_aux (Id ("__v" ^ string_of_int current), Parse_ast.Unknown) in
(* expects that P_typ elements have been removed from AST,
@@ -956,24 +963,26 @@ let rewrite_defs_ocaml defs =
-let geteffs_annot = function
- | (_,Base (_,_,_,_,effs,_)) -> effs
- | (_,NoTyp) -> failwith "no effect information"
+let geteffs_annot (_,t) = match t with
+ | Base (_,_,_,_,effs,_) -> effs
+ | NoTyp -> failwith "no effect information"
| _ -> failwith "a_normalise doesn't support Overload"
-let geteffs (E_aux (_,a)) = geteffs_annot a
-
-let gettype (E_aux (_,(_,a))) =
- match a with
+let gettype_annot (_,t) = match t with
| Base((_,t),_,_,_,_,_) -> t
| NoTyp -> failwith "no type information"
| _ -> failwith "a_normalise doesn't support Overload"
-
+let gettype (E_aux (_,a)) = gettype_annot a
+let geteffs (E_aux (_,a)) = geteffs_annot a
+
let effectful_effs {effect = Eset effs} =
List.exists
- (fun (BE_aux (be,_)) -> match be with BE_nondet | BE_unspec | BE_undef -> false | _ -> true)
- effs
+ (fun (BE_aux (be,_)) ->
+ match be with
+ | BE_nondet | BE_unspec | BE_undef | BE_lset -> false
+ | _ -> true
+ ) effs
let effectful eaux =
effectful_effs (geteffs eaux)
@@ -992,8 +1001,11 @@ let remove_blocks_exp_alg =
let rec f = function
| [e] -> e (* check with Kathy if that annotation is fine *)
| e :: es -> letbind_wild e (f es)
- | [] -> failwith "empty block encountered" in
-
+ | e -> E_aux (E_lit (L_aux (L_unit,Unknown)), (Unknown,simple_annot ({t = Tid "unit"}))) in
+(*
+ | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)), (Unknown,simple_annot ({t = Tid "unit"})))
+ | e :: es -> letbind_wild e (f es) in
+ *)
let e_aux = function
| (E_block es,annot) -> f es
| (e,annot) -> E_aux (e,annot) in
@@ -1001,26 +1013,23 @@ let remove_blocks_exp_alg =
{ id_exp_alg with e_aux = e_aux }
-let a_normalise_counter = ref 0
+
+let fresh_id annot =
+ let current = fresh_name () in
+ let id = Id_aux (Id ("__w" ^ string_of_int current), Parse_ast.Unknown) in
+ let annot_var = (Parse_ast.Unknown,simple_annot (gettype_annot annot)) in
+ E_aux (E_id id, annot_var)
let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
(* body is a function : E_id variable -> actual body *)
-
- let fresh_id () =
- let current = !a_normalise_counter in
- let () = a_normalise_counter := (current + 1) in
- Id_aux (Id ("__w" ^ string_of_int current), Parse_ast.Unknown) in
-
- let freshid = fresh_id () in
- let annot_var = (Parse_ast.Unknown,simple_annot (gettype v)) in
- let eid = E_aux (E_id freshid, annot_var) in
-
- let body = body eid in
+ let (E_aux (_,annot)) = v in
+ let ((E_aux (E_id id,_)) as e_id) = fresh_id annot in
+ let body = body e_id in
- let annot_pat = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in
+ let annot_pat = (Parse_ast.Unknown,simple_annot (gettype v)) in
let annot_lb = annot_pat in
- let annot_let = (Parse_ast.Unknown,simple_annot_efr {t = Tid "unit"} (eff_union v body)) in
- let pat = P_aux (P_id freshid,annot_pat) in
+ let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union v body)) in
+ let pat = P_aux (P_id id,annot_pat) in
if effectful v
then E_aux (E_internal_plet (pat,v,body),annot_let)
@@ -1034,11 +1043,15 @@ let rec value ((E_aux (exp_aux,_)) as exp) =
| E_tuple es
| E_vector es
| E_list es -> List.fold_left (&&) true (List.map value es)
- | _ -> false
+ | _ -> false
let only_local_eff (l,(Base ((t_params,t),tag,nexps,eff,effsum,bounds))) =
(l,Base ((t_params,t),tag,nexps,eff,eff,bounds))
+let local_eff_plus eff2 (l,(Base ((t_params,t),tag,nexps,eff,effsum,bounds))) =
+ let effsum = union_effects eff eff2 in
+ (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds))
+
let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list -> 'a exp) : 'a exp =
match l with
| [] -> k []
@@ -1093,23 +1106,34 @@ and n_lexp (lexp : 'a lexp) (k : 'a lexp -> 'a exp) : 'a exp =
n_exp_nameL es (fun es -> k (LEXP_aux (LEXP_memory (id,es),only_local_eff annot)))
| LEXP_cast (typ,id) -> k (LEXP_aux (LEXP_cast (typ,id),only_local_eff annot))
| LEXP_vector (lexp,id) ->
- n_lexp lexp (fun lexp -> k (LEXP_aux (LEXP_vector (lexp,id),only_local_eff annot)))
+ let (LEXP_aux (_,annot)) = lexp in
+ let effs = geteffs_annot annot in
+ n_lexp lexp (fun lexp -> k (LEXP_aux (LEXP_vector (lexp,id),local_eff_plus effs annot)))
| LEXP_vector_range (lexp,exp1,exp2) ->
n_lexp lexp (fun lexp ->
+ let (LEXP_aux (_,annot)) = lexp in
+ let effs = geteffs_annot annot in
n_exp_name exp1 (fun exp1 ->
n_exp_name exp2 (fun exp2 ->
- k (LEXP_aux (LEXP_vector_range (lexp,exp1,exp2),only_local_eff annot)))))
+ k (LEXP_aux (LEXP_vector_range (lexp,exp1,exp2),local_eff_plus effs annot)))))
| LEXP_field (lexp,id) ->
- n_lexp lexp (fun lexp ->
- k (LEXP_aux (LEXP_field (lexp,id),only_local_eff annot)))
+ n_lexp lexp (fun lexp ->
+ let (LEXP_aux (_,annot)) = lexp in
+ let effs = geteffs_annot annot in
+ k (LEXP_aux (LEXP_field (lexp,id),local_eff_plus effs annot)))
and n_exp_term (new_return : bool) (exp : 'a exp) : 'a exp =
let (E_aux (_,annot)) = exp in
let exp = if new_return then E_aux (E_internal_return exp,annot) else exp in
+ (* changed this from n_exp to n_exp_name so that when we return updated variables
+ * from a for-loop, for example, we can just add those into the returned tuple and
+ * don't need to a-normalise again *)
n_exp exp (fun exp -> exp)
and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
+(* if not (effectful exp) then k exp else comment out this line for full a-normalisation *)
+
let (E_aux (exp_aux,annot)) = exp in
let rewrap_effs effsum exp_aux = (* explicitly give effect sum *)
@@ -1125,7 +1149,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
| E_id id -> k exp (* if value exp then return exp else letbind exp *)
| E_lit _ -> k exp
| E_cast (typ,exp') ->
- n_exp_name exp' (fun exp' ->
+ n_exp exp' (fun exp' ->
k (rewrap_localeff (E_cast (typ,exp'))))
| E_app (id,exps) ->
n_exp_nameL exps (fun exps ->
@@ -1149,7 +1173,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
n_exp_name start (fun start ->
n_exp_name stop (fun stop ->
n_exp_name by (fun by ->
- let body = n_exp_term (effectful body) body in
+ let body = n_exp_term (false) body in
k (rewrap_effs (geteffs body) (E_for (id,start,stop,by,dir,body))))))
| E_vector exps ->
n_exp_nameL exps (fun exps ->
@@ -1230,8 +1254,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
(if effectful exp1
then n_exp_name exp1
else n_exp exp1) (fun exp1 ->
- n_exp exp2 (fun exp2 ->
- k (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,exp2)))))
+ (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,n_exp exp2 k))))
| E_internal_return exp1 ->
n_exp_name exp1 (fun exp1 ->
k (rewrap_localeff (E_internal_return exp1)))
@@ -1251,10 +1274,305 @@ let rewrite_defs_a_normalise =
; rewrite_defs = rewrite_defs_base
}
+let dedup eq = List.fold_left (fun acc e -> if List.exists (eq e) acc then acc else e :: acc) []
+
+let eqidtyp (id1,_) (id2,_) =
+ let name1 = match id1 with Id_aux ((Id name | DeIid name),_) -> name in
+ let name2 = match id1 with Id_aux ((Id name | DeIid name),_) -> name in
+ name1 = name2
+
+let find_updated_vars exp =
+ let names =
+ fold_exp
+ { e_block = (fun es -> List.flatten es)
+ ; e_nondet = (fun es -> List.flatten es)
+ ; e_id = (fun _ -> [])
+ ; e_lit = (fun _ -> [])
+ ; e_cast = (fun (_,e) -> e)
+ ; e_app = (fun (_,es) -> List.flatten es)
+ ; e_app_infix = (fun (e1,_,e2) -> e1 @ e2)
+ ; e_tuple = (fun es -> List.flatten es)
+ ; e_if = (fun (e1,e2,e3) -> e1 @ e2 @ e3)
+ ; e_for = (fun (_,e1,e2,e3,_,e4) -> e1 @ e2 @ e3 @ e4)
+ ; e_vector = (fun es -> List.flatten es)
+ ; e_vector_indexed = (fun (es,opt2) -> (List.flatten (List.map snd es)) @ opt2)
+ ; e_vector_access = (fun (e1,e2) -> e1 @ e2)
+ ; e_vector_subrange = (fun (e1,e2,e3) -> e1 @ e2 @ e3)
+ ; e_vector_update = (fun (e1,e2,e3) -> e1 @ e2 @ e3)
+ ; e_vector_update_subrange = (fun (e1,e2,e3,e4) -> e1 @ e2 @ e3 @ e4)
+ ; e_vector_append = (fun (e1,e2) -> e1 @ e2)
+ ; e_list = (fun es -> List.flatten es)
+ ; e_cons = (fun (e1,e2) -> e1 @ e2)
+ ; e_record = (fun fexps -> fexps)
+ ; e_record_update = (fun (e1,fexp) -> e1 @ fexp)
+ ; e_field = (fun (e1,id) -> e1)
+ ; e_case = (fun (e1,pexps) -> e1 @ List.flatten pexps)
+ ; e_let = (fun (lb,e2) -> lb @ e2)
+ ; e_assign = (fun ((None,[(id,b)]),e2) -> if b then id :: e2 else e2)
+ ; e_exit = (fun e1 -> e1)
+ ; e_internal_cast = (fun (_,e1) -> e1)
+ ; e_internal_exp = (fun _ -> [])
+ ; e_internal_exp_user = (fun _ -> [])
+ ; e_internal_let = (fun ((None,[(id,_)]), e2, e3) -> List.filter (eqidtyp id) (e2 @ e3))
+ ; e_internal_plet = (fun (_, e1, e2) -> e1 @ e2)
+ ; e_internal_return = (fun e -> e)
+ ; e_aux = (fun (e,_) -> e)
+ ; lEXP_id = (fun id -> (Some id,[]))
+ ; lEXP_memory = (fun (_,_) -> (None,[]))
+ ; lEXP_cast = (fun (_,id) -> (Some id,[]))
+ ; lEXP_vector = (fun ((None,lexp),_) -> (None,lexp))
+ ; lEXP_vector_range = (fun ((None,lexp),_,_) -> (None,lexp))
+ ; lEXP_field = (fun ((None,lexp),_) -> (None,lexp))
+ ; lEXP_aux =
+ (function
+ | ((Some id,[]),annot) ->
+ let effs = geteffs_annot annot in
+ let b =
+ match effs with
+ | {effect = Eset [BE_aux (BE_lset,_)]} -> true
+ | _ -> false in
+ (None,[((id,annot),b)])
+ | ((None,es),_) -> (None,es)
+ )
+ ; fE_Fexp = (fun (_,e) -> e)
+ ; fE_aux = (fun (fexp,_) -> fexp)
+ ; fES_Fexps = (fun (fexps,_) -> List.flatten fexps)
+ ; fES_aux = (fun (fexp,_) -> fexp)
+ ; def_val_empty = []
+ ; def_val_dec = (fun e -> e)
+ ; def_val_aux = (fun (defval,_) -> defval)
+ ; pat_exp = (fun (_,e) -> e)
+ ; pat_aux = (fun (pexp,_) -> pexp)
+ ; lB_val_explicit = (fun (_,_,e) -> e)
+ ; lB_val_implicit = (fun (_,e) -> e)
+ ; lB_aux = (fun (lb,_) -> lb)
+ ; pat_alg = id_pat_alg
+ } exp in
+ dedup eqidtyp names
+
+
+
+let swaptyp t (l,(Base ((t_params,_),tag,nexps,eff,effsum,bounds))) =
+ (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds))
+
+let mktup es =
+ if es = [] then
+ E_aux (E_lit (L_aux (L_unit,Unknown)),(Unknown,simple_annot unit_t))
+ else
+ let effs = List.fold_left (fun acc e -> union_effects acc (geteffs e)) {effect = Eset []} es in
+ let typs = List.map gettype es in
+ E_aux (E_tuple es,(Parse_ast.Unknown,simple_annot_efr {t = Ttup typs} effs))
+
+let mktup_pat es =
+ if es = [] then
+ P_aux (P_wild,(Unknown,simple_annot unit_t))
+ else
+ let typs = List.map gettype es in
+ let pats = List.map (fun (E_aux (E_id id,_) as exp) ->
+ P_aux (P_id id,(Unknown,simple_annot (gettype exp)))) es in
+ P_aux (P_tup pats,(Parse_ast.Unknown,simple_annot {t = Ttup typs}))
+
+
+let rec rewrite_for_if_case ((E_aux (expaux,annot)) as exp) =
+
+ let rec add_vars ((E_aux (expaux,annot)) as exp) vars =
+ let rewrap expaux = E_aux (expaux,annot) in
+ match expaux with
+ | E_let (lb,exp) -> rewrap (E_let (lb,add_vars exp vars))
+ | E_internal_let (lexp,exp1,exp2) -> rewrap (E_internal_let (lexp,exp1,add_vars exp2 vars))
+ | E_internal_plet (pat,exp1,exp2) -> rewrap (E_internal_plet (pat,exp1,add_vars exp2 vars))
+ | E_internal_return exp2 ->
+ E_aux (E_internal_return (add_vars exp2 vars),
+ swaptyp {t = Ttup [gettype exp;gettype vars]} annot)
+ | _ ->
+ (* after a-normalisation this will be pure:
+ * if the whole body of the function/if-expression/case-expression/for-loop was
+ * pure, then it's still pure; if it wasn't then the body was wrapped in E_return
+ * and (in this case) exp is a name contained in E_return that by definition of
+ * value must be pure
+ *)
+ let () = assert (not (effectful exp)) in
+ E_aux (E_tuple [exp;vars],swaptyp {t = Ttup [gettype exp;gettype vars]} annot)in
+
+ let rewrite (E_aux (eaux,annot)) =
+ match eaux with
+ | E_for(id,exp1,exp2,exp3,order,exp4) ->
+ let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars exp4) in
+ let vartuple = mktup vars in
+ let exp4 = rewrite_for_if_case (add_vars exp4 vartuple) in
+ let funcl = match order with
+ | Ord_aux (Ord_inc,_) -> Id_aux (Id "foreach_inc",Unknown)
+ | Ord_aux (Ord_dec,_) -> Id_aux (Id "foreach_dec",Unknown) in
+ Some (E_aux (E_app (funcl,[mktup [exp1;exp2;exp3];exp4;vartuple]),
+ swaptyp (gettype exp4) annot),vars)
+ | E_if (c,e1,e2) ->
+ let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t)))
+ (dedup eqidtyp (find_updated_vars e1 @ find_updated_vars e2)) in
+ if vars = [] then None else
+ let vartuple = mktup vars in
+ let e1 = rewrite_for_if_case (add_vars e1 vartuple) in
+ let e2 = rewrite_for_if_case (add_vars e2 vartuple) in
+ (* after a-normalisation c shouldn't need rewriting *)
+ let t = gettype e1 in
+ (* let () = assert (simple_annot t = simple_annot (gettype e2)) in *)
+ Some (E_aux (E_if (c,e1,e2), swaptyp t annot),vars)
+ | E_case (e1,ps) ->
+ (* after a-normalisation e1 shouldn't need rewriting *)
+ let vars =
+ let f acc (Pat_aux (Pat_exp (_,e),_)) = acc @ find_updated_vars e in
+ List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t)))
+ (dedup eqidtyp (List.fold_left f [] ps)) in
+ if vars = [] then None else
+ let vartuple = mktup vars in
+ let typ =
+ let (Pat_aux (Pat_exp (_,first),_)) = List.hd ps in
+ gettype first in
+ let (ps,typ) =
+ let f (acc,typ) (Pat_aux (Pat_exp (p,e),pannot)) =
+ let etyp = gettype e in
+ let () = assert (simple_annot etyp = simple_annot typ) in
+ let e = rewrite_for_if_case (add_vars e vartuple) in
+ let pannot = swaptyp (gettype e) pannot in
+ (acc @ [Pat_aux (Pat_exp (p,e),pannot)],typ) in
+ List.fold_left f ([],typ) ps in
+ Some (E_aux (E_case (e1,ps), swaptyp typ annot),vars)
+ | _ ->
+ (* assumes everying's a-normlised: an expression is a sequence of let-expressions,
+ * "control-flow" structures and a return value, possibly wrapped in E_return *)
+ None in
+
+ match expaux with
+ | E_let (lb,body) ->
+ let body = rewrite_for_if_case body in
+ let lb = match lb with
+ | LB_aux (LB_val_implicit (pat,v),lbannot) ->
+ (match rewrite v with
+ | Some (v,vars) ->
+ let varpat = mktup_pat vars in
+ let (P_aux (_,pannot)) = pat in
+ let pat = P_aux (P_tup [pat;varpat],swaptyp (gettype v) annot) in
+ let lbannot = swaptyp (gettype v) lbannot in
+ LB_aux (LB_val_implicit (pat,v),lbannot)
+ | None -> lb)
+ | LB_aux (LB_val_explicit (typ,pat,v),lbannot) ->
+ (match rewrite v with
+ | Some (v,vars) ->
+ let varpat = mktup_pat vars in
+ let (P_aux (_,pannot)) = pat in
+ let pat = P_aux (P_tup [pat;varpat],swaptyp (gettype v) annot) in
+ let lbannot = swaptyp (gettype v) lbannot in
+ LB_aux (LB_val_implicit (pat,v),lbannot)
+ | None -> lb) in
+ (* as let-expressions have type unit exp's annot doesn't need to change *)
+ E_aux (E_let (lb,body),annot)
+ | E_internal_plet (pat,v,body) ->
+ let body = rewrite_for_if_case body in
+ (match rewrite v with
+ | Some (v,vars) ->
+ let varpat = mktup_pat vars in
+ let (P_aux (_,pannot)) = pat in
+ let pat = P_aux (P_tup [pat;varpat],swaptyp (gettype v) annot) in
+ (* as let-expressions have type unit exp's annot doesn't need to change *)
+ E_aux (E_internal_plet (pat,v,body),annot)
+ | None -> E_aux (E_internal_plet (pat,v,body),annot))
+ | E_internal_let (lexp,v,body) ->
+ (* because we need patterns and internal_plets are needed to distinguish monadic
+ * expressions E_internal_lets are rewritten to E_lets. We only need them for OCaml
+ * anyways. *)
+ let body = rewrite_for_if_case body in
+ let id = match lexp with
+ | LEXP_aux (LEXP_id id,_) -> id
+ | LEXP_aux (LEXP_cast (_,id),_) -> id in
+ let pat = P_aux (P_id id, (Parse_ast.Unknown,simple_annot (gettype v))) in
+ let lb = (match rewrite v with
+ | Some (v,vars) ->
+ let varpat = mktup_pat vars in
+ let (P_aux (_,pannot)) = pat in
+ let pat = P_aux (P_tup [pat;varpat],swaptyp (gettype v) annot) in
+ let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in
+ LB_aux (LB_val_implicit (pat,v),lbannot)
+ | None ->
+ let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in
+ LB_aux (LB_val_implicit (pat,v),lbannot)) in
+ E_aux (E_let (lb,body),annot)
+ (* In tail-position only for-loops matter: if if-expressions or pattern-matching expressions
+ * are in tail-position their updated variables won't be read anyways. For for-loops, however,
+ * we do have to rewrite but also make sure the return type is as expected. *)
+ | E_for _ ->
+ let (Some (exp,_)) = rewrite exp in
+ let annot_pat = (Parse_ast.Unknown,simple_annot (gettype exp)) in
+ let pat = (P_aux (P_wild,annot_pat)) in
+ let body = E_aux (E_lit (L_aux (L_unit,Unknown)),(Unknown,simple_annot unit_t)) in
+ let annot_lb = annot_pat in
+ let annot_let = (Parse_ast.Unknown,simple_annot unit_t) in
+ if effectful exp
+ then E_aux (E_internal_plet (pat,exp,body),annot_let)
+ else E_aux (E_let (LB_aux (LB_val_implicit (pat,exp),annot_lb),body),annot_let)
+
+ | _ -> exp
+
+let replace_e_assign =
+
+ let e_aux (expaux,annot) =
+
+ let letbind (E_aux (E_id id,_) as e_id) (v : 'a exp) (body : 'a exp) : 'a exp =
+ (* body is a function : E_id variable -> actual body *)
+ let (E_aux (_,annot)) = v in
+ let annot_pat = (Parse_ast.Unknown,simple_annot (gettype v)) in
+ let annot_lb = annot_pat in
+ let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (geteffs body)) in
+ let pat = P_aux (P_id id,annot_pat) in
+ E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) in
+
+ let f v body = function
+ | LEXP_aux (LEXP_id id,annot) ->
+ let eid = E_aux (E_id id,(Unknown,simple_annot (gettype v))) in
+ letbind eid v body
+ | LEXP_aux (LEXP_vector (LEXP_aux (LEXP_id id,annot2),i),annot) ->
+ let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in
+ let v = E_aux (E_vector_update (eid,i,v),(Unknown,simple_annot (gettype_annot annot))) in
+ letbind eid v body
+ | LEXP_aux (LEXP_vector_range (LEXP_aux (LEXP_id id,annot2),i,j),annot) ->
+ let eid = E_aux (E_id id,(Unknown,simple_annot (gettype_annot annot2))) in
+ let v = E_aux (E_vector_update_subrange (eid,i,j,v),
+ (Unknown,simple_annot (gettype_annot annot))) in
+ letbind eid v body in
+
+ match expaux with
+ | E_let (LB_aux (LB_val_explicit (_,_,E_aux (E_assign (lexp,v),annot2)),_),body)
+ | E_let (LB_aux (LB_val_implicit (_,E_aux (E_assign (lexp,v),annot2)),_),body)
+ when
+ let {effect = Eset effs} = geteffs_annot annot2 in
+ List.exists (function BE_aux (BE_lset,_) -> true | _ -> false) effs ->
+ f v body lexp
+ | E_let (lb,body) -> E_aux (E_let (lb,body),annot)
+ (* E_internal_plet is only used for effectful terms, shouldn't be needed to deal with here *)
+ | E_internal_let (LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),_),v,body) ->
+ let (E_aux (_,pannot)) = v in
+ let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype body) (geteffs body)) in
+ E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_id id,pannot),v),lbannot),body),annot)
+ | _ -> E_aux (expaux,annot) in
+
+ { id_exp_alg with e_aux = e_aux }
+
+let rewrite_defs_remove_e_assign =
+ let rewrite_exp _ _ e = (fold_exp replace_e_assign) (rewrite_for_if_case e) in
+ rewrite_defs_base
+ {rewrite_exp = rewrite_exp
+ ; rewrite_pat = rewrite_pat
+ ; rewrite_let = rewrite_let
+ ; rewrite_lexp = rewrite_lexp
+ ; rewrite_fun = rewrite_fun
+ ; rewrite_def = rewrite_def
+ ; rewrite_defs = rewrite_defs_base
+ }
+
+
let rewrite_defs_lem defs =
let defs = rewrite_defs_remove_vector_concat defs in
let defs = rewrite_defs_exp_lift_assign defs in
let defs = rewrite_defs_a_normalise defs in
+ let defs = rewrite_defs_remove_e_assign defs in
defs