summaryrefslogtreecommitdiff
path: root/src/rewriter.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewriter.ml')
-rw-r--r--src/rewriter.ml137
1 files changed, 69 insertions, 68 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 3c4ff5e8..c7e6986b 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -25,6 +25,43 @@ let fresh_name () =
let () = fresh_name_counter := (current + 1) in
current
+let geteffs_annot (_,t) = match t with
+ | Base (_,_,_,_,effs,_) -> effs
+ | NoTyp -> failwith "no effect information"
+ | _ -> failwith "a_normalise doesn't support Overload"
+
+let gettype_annot (_,t) = match t with
+ | Base((_,t),_,_,_,_,_) -> t
+ | NoTyp -> failwith "no type information"
+ | _ -> failwith "a_normalise doesn't support Overload"
+
+let gettype (E_aux (_,a)) = gettype_annot a
+let geteffs (E_aux (_,a)) = geteffs_annot a
+
+let effectful_effs {effect = Eset effs} =
+ List.exists
+ (fun (BE_aux (be,_)) ->
+ match be with
+ | BE_nondet | BE_unspec | BE_undef | BE_lset -> false
+ | _ -> true
+ ) effs
+
+let effectful eaux =
+ effectful_effs (geteffs eaux)
+
+let updates_vars_effs {effect = Eset effs} =
+ List.exists
+ (fun (BE_aux (be,_)) ->
+ match be with
+ | BE_lset -> true
+ | _ -> false
+ ) effs
+
+let updates_vars eaux =
+ updates_vars_effs (geteffs eaux)
+
+let eff_union es =
+ List.fold_left (fun acc e -> union_effects acc (geteffs e)) pure_e es
let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option = match ls with
| [] -> None
@@ -871,8 +908,9 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful
let le' = rewriters.rewrite_lexp rewriters nmap le in
let e' = rewrite_base e in
let exps' = walker exps in
- [(E_aux (E_internal_let(le', e', E_aux(E_block exps', (l, simple_annot {t=Tid "unit"}))),
- (l, simple_annot t)))]
+ let effects = eff_union exps' in
+ [E_aux (E_internal_let(le', e', E_aux(E_block exps', (l, simple_annot_efr {t=Tid "unit"} effects))),
+ (l, simple_annot_efr t (eff_union (e::exps'))))]
| ((E_aux(E_if(c,t,e),(l,annot))) as exp)::exps ->
let vars_t = introduced_variables t in
let vars_e = introduced_variables e in
@@ -886,8 +924,9 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful
let c' = rewrite_base c in
let t' = rewriters.rewrite_exp rewriters new_nmap t in
let e' = rewriters.rewrite_exp rewriters new_nmap e in
- Envmap.fold
- (fun res i (t,e) ->
+ let exps' = walker exps in
+ fst ((Envmap.fold
+ (fun (res,effects) i (t,e) ->
let bitlit = E_aux (E_lit (L_aux(L_zero, Parse_ast.Unknown)),
(Parse_ast.Unknown, simple_annot bit_t)) in
let rangelit = E_aux (E_lit (L_aux (L_num 0, Parse_ast.Unknown)),
@@ -907,12 +946,13 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful
(Parse_ast.Unknown,simple_annot bit_t))),
(Parse_ast.Unknown, simple_annot t))
| _ -> e in
- [E_aux (E_internal_let (LEXP_aux (LEXP_id (Id_aux (Id i, Parse_ast.Unknown)),
+ let unioneffs = union_effects effects (geteffs set_exp) in
+ ([E_aux (E_internal_let (LEXP_aux (LEXP_id (Id_aux (Id i, Parse_ast.Unknown)),
(Parse_ast.Unknown, (tag_annot t Emp_intro))),
set_exp,
- E_aux (E_block res, (Parse_ast.Unknown, (simple_annot unit_t)))),
- (Parse_ast.Unknown, simple_annot unit_t))])
- (E_aux(E_if(c',t',e'), (Parse_ast.Unknown, annot))::(walker exps)) new_vars
+ E_aux (E_block res, (Parse_ast.Unknown, (simple_annot_efr unit_t effects)))),
+ (Parse_ast.Unknown, simple_annot_efr unit_t unioneffs))],unioneffs)))
+ (E_aux(E_if(c',t',e'), (Parse_ast.Unknown, annot))::exps',eff_union exps') new_vars)
| e::exps -> (rewrite_rec e)::(walker exps)
in
rewrap (E_block (walker exps))
@@ -923,7 +963,7 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful
(match le' with
| LEXP_aux(_, (_,Base(_,Emp_intro,_,_,_,_))) ->
let e' = rewrite_base e in
- rewrap (E_internal_let(le', e', E_aux(E_block [], (l, simple_annot unit_t))))
+ rewrap (E_internal_let(le', e', E_aux(E_block [], (l, simple_annot_efr unit_t (geteffs e')))))
| _ -> E_aux((E_assign(le', rewrite_base e)),(l, tag_annot unit_t Emp_set)))
| _ -> rewrite_base full_exp)
| _ -> rewrite_base full_exp
@@ -961,52 +1001,12 @@ let rewrite_defs_ocaml defs =
let defs_lifted_assign = rewrite_defs_exp_lift_assign defs_vec_concat_removed in
defs_lifted_assign
-
-
-let geteffs_annot (_,t) = match t with
- | Base (_,_,_,_,effs,_) -> effs
- | NoTyp -> failwith "no effect information"
- | _ -> failwith "a_normalise doesn't support Overload"
-
-let gettype_annot (_,t) = match t with
- | Base((_,t),_,_,_,_,_) -> t
- | NoTyp -> failwith "no type information"
- | _ -> failwith "a_normalise doesn't support Overload"
-
-let gettype (E_aux (_,a)) = gettype_annot a
-let geteffs (E_aux (_,a)) = geteffs_annot a
-
-let effectful_effs {effect = Eset effs} =
- List.exists
- (fun (BE_aux (be,_)) ->
- match be with
- | BE_nondet | BE_unspec | BE_undef | BE_lset -> false
- | _ -> true
- ) effs
-
-let effectful eaux =
- effectful_effs (geteffs eaux)
-
-let updates_vars_effs {effect = Eset effs} =
- List.exists
- (fun (BE_aux (be,_)) ->
- match be with
- | BE_lset -> true
- | _ -> false
- ) effs
-
-let updates_vars eaux =
- updates_vars_effs (geteffs eaux)
-
-let eff_union e1 e2 = union_effects (geteffs e1) (geteffs e2)
-
-let remove_blocks_exp_alg =
+let remove_blocks =
let letbind_wild v body =
- let annot_pat = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in
- let annot_lb = annot_pat in
- let annot_let =
- (Parse_ast.Unknown,simple_annot_efr {t = Tid "unit"} (eff_union v body)) in
+ let annot_pat = (Parse_ast.Unknown,simple_annot (gettype v)) in
+ let annot_lb = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in
+ let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in
E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_wild,annot_pat),v),annot_lb),body),annot_let) in
let rec f = function
@@ -1018,7 +1018,7 @@ let remove_blocks_exp_alg =
| (E_block es,annot) -> f es
| (e,annot) -> E_aux (e,annot) in
- { id_exp_alg with e_aux = e_aux }
+ fold_exp { id_exp_alg with e_aux = e_aux }
let fresh_id annot =
@@ -1036,7 +1036,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
let body = body e in
let annot_pat = (Parse_ast.Unknown,simple_annot unit_t) in
let annot_lb = annot_pat in
- let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union v body)) in
+ let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in
let pat = P_aux (P_wild,annot_pat) in
if effectful v
@@ -1049,7 +1049,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
let annot_pat = (Parse_ast.Unknown,simple_annot (gettype v)) in
let annot_lb = annot_pat in
- let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union v body)) in
+ let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in
let pat = P_aux (P_id id,annot_pat) in
if effectful v
@@ -1106,7 +1106,7 @@ let rec n_exp_name (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
and n_exp_pure (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
n_exp exp (fun exp -> if not (effectful exp || updates_vars exp) then k exp else letbind exp k)
-
+
and n_exp_nameL (exps : 'a exp list) (k : 'a exp list -> 'a exp) : 'a exp =
mapCont n_exp_name exps k
@@ -1170,7 +1170,6 @@ and n_lexp (lexp : 'a lexp) (k : 'a lexp -> 'a exp) : 'a exp =
k (LEXP_aux (LEXP_field (lexp,id),local_eff_plus effs annot)))
and n_exp_term (new_return : bool) (exp : 'a exp) : 'a exp =
- let (E_aux (_,annot)) = exp in
let exp =
if new_return
then E_aux (E_internal_return exp,(Unknown,simple_annot_efr (gettype exp) (geteffs exp)))
@@ -1187,8 +1186,8 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
let (E_aux (exp_aux,annot)) = exp in
let rewrap_effs effsum exp_aux = (* explicitly give effect sum *)
- let (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds)) = annot in
- E_aux (exp_aux, (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds))) in
+ let (l,Base (t,tag,nexps,eff,_,bounds)) = annot in
+ E_aux (exp_aux, (l,Base (t,tag,nexps,eff,effsum,bounds))) in
let rewrap_localeff exp_aux = (* give exp_aux the local effect as the effect sum *)
E_aux (exp_aux,only_local_eff annot) in
@@ -1218,7 +1217,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
let new_return = effectful exp2 || effectful exp3 in
let exp2 = n_exp_term new_return exp2 in
let exp3 = n_exp_term new_return exp3 in
- k (rewrap_effs (eff_union exp2 exp3) (E_if (exp1,exp2,exp3))))
+ k (rewrap_effs (eff_union [exp2;exp3]) (E_if (exp1,exp2,exp3))))
| E_for (id,start,stop,by,dir,body) ->
n_exp_name start (fun start ->
n_exp_name stop (fun stop ->
@@ -1292,7 +1291,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
| LB_aux (LB_val_explicit (_,pat,exp'),annot')
| LB_aux (LB_val_implicit (pat,exp'),annot') ->
if effectful exp'
- then (rewrap_effs (eff_union exp' body) (E_internal_plet (pat,exp',n_exp body k)))
+ then (rewrap_effs (eff_union [exp';body]) (E_internal_plet (pat,exp',n_exp body k)))
else (rewrap_effs (geteffs body) (E_let (lb,n_exp body k)))
))
| E_assign (lexp,exp1) ->
@@ -1309,7 +1308,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
(if effectful exp1
then n_exp_name exp1
else n_exp exp1) (fun exp1 ->
- (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,n_exp exp2 k))))
+ rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,n_exp exp2 k)))
| E_internal_return exp1 ->
n_exp_name exp1 (fun exp1 ->
k (rewrap_localeff (E_internal_return exp1)))
@@ -1318,7 +1317,8 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
let rewrite_defs_a_normalise =
let rewrite_exp _ _ e =
- n_exp_term (effectful e) (fold_exp remove_blocks_exp_alg e) in
+ let e = remove_blocks e in
+ n_exp_term (effectful e) e in
rewrite_defs_base
{rewrite_exp = rewrite_exp
; rewrite_pat = rewrite_pat
@@ -1376,7 +1376,8 @@ let find_updated_vars exp =
; e_internal_cast = (fun (_,e1) -> e1)
; e_internal_exp = (fun _ -> [])
; e_internal_exp_user = (fun _ -> [])
- ; e_internal_let = (fun ((None,[(id,_)]), e2, e3) -> List.filter (eqidtyp id) (e2 @ e3))
+ ; e_internal_let = (fun ((None,[(id,_)]), e2, e3) ->
+ List.filter (fun id2 -> not (eqidtyp id id2)) (e2 @ e3))
; e_internal_plet = (fun (_, e1, e2) -> e1 @ e2)
; e_internal_return = (fun e -> e)
; e_aux = (fun (e,_) -> e)
@@ -1523,7 +1524,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) =
(* after a-normalisation c shouldn't need rewriting *)
let t = gettype e1 in
(* let () = assert (simple_annot t = simple_annot (gettype e2)) in *)
- let v = E_aux (E_if (c,e1,e2), (Unknown,simple_annot_efr t (eff_union e1 e2))) in
+ let v = E_aux (E_if (c,e1,e2), (Unknown,simple_annot_efr t (eff_union [e1;e2]))) in
let pat =
(* if overwrite then
mktup_pat vars
@@ -1615,7 +1616,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) =
(match rewrite v pat with
| Added_vars (v,pat) ->
E_aux (E_internal_plet (pat,v,body),
- (Unknown,simple_annot_efr (gettype body) (eff_union v body)))
+ (Unknown,simple_annot_efr (gettype body) (eff_union [v;body])))
| Same_vars v -> E_aux (E_internal_plet (pat,v,body),annot))
| E_internal_let (lexp,v,body) ->
(* After a-normalisation E_internal_lets can only bind values to names, those don't
@@ -1627,7 +1628,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) =
let pat = P_aux (P_id id, (Parse_ast.Unknown,simple_annot (gettype v))) in
let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in
let lb = LB_aux (LB_val_implicit (pat,v),lbannot) in
- E_aux (E_let (lb,body),(Unknown,simple_annot_efr (gettype body) (eff_union v body)))
+ E_aux (E_let (lb,body),(Unknown,simple_annot_efr (gettype body) (eff_union [v;body])))
(* In tail-position there shouldn't be anything we need to do as the terms after
* a-normalisation are pure and don't update local variables. There can't be any variable
* assignments in tail-position (because of the effect), there could be pure pattern-match