diff options
Diffstat (limited to 'src/rewriter.ml')
| -rw-r--r-- | src/rewriter.ml | 137 |
1 files changed, 69 insertions, 68 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml index 3c4ff5e8..c7e6986b 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -25,6 +25,43 @@ let fresh_name () = let () = fresh_name_counter := (current + 1) in current +let geteffs_annot (_,t) = match t with + | Base (_,_,_,_,effs,_) -> effs + | NoTyp -> failwith "no effect information" + | _ -> failwith "a_normalise doesn't support Overload" + +let gettype_annot (_,t) = match t with + | Base((_,t),_,_,_,_,_) -> t + | NoTyp -> failwith "no type information" + | _ -> failwith "a_normalise doesn't support Overload" + +let gettype (E_aux (_,a)) = gettype_annot a +let geteffs (E_aux (_,a)) = geteffs_annot a + +let effectful_effs {effect = Eset effs} = + List.exists + (fun (BE_aux (be,_)) -> + match be with + | BE_nondet | BE_unspec | BE_undef | BE_lset -> false + | _ -> true + ) effs + +let effectful eaux = + effectful_effs (geteffs eaux) + +let updates_vars_effs {effect = Eset effs} = + List.exists + (fun (BE_aux (be,_)) -> + match be with + | BE_lset -> true + | _ -> false + ) effs + +let updates_vars eaux = + updates_vars_effs (geteffs eaux) + +let eff_union es = + List.fold_left (fun acc e -> union_effects acc (geteffs e)) pure_e es let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option = match ls with | [] -> None @@ -871,8 +908,9 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful let le' = rewriters.rewrite_lexp rewriters nmap le in let e' = rewrite_base e in let exps' = walker exps in - [(E_aux (E_internal_let(le', e', E_aux(E_block exps', (l, simple_annot {t=Tid "unit"}))), - (l, simple_annot t)))] + let effects = eff_union exps' in + [E_aux (E_internal_let(le', e', E_aux(E_block exps', (l, simple_annot_efr {t=Tid "unit"} effects))), + (l, simple_annot_efr t (eff_union (e::exps'))))] | ((E_aux(E_if(c,t,e),(l,annot))) as exp)::exps -> let vars_t = introduced_variables t in let vars_e = introduced_variables e in @@ -886,8 +924,9 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful let c' = rewrite_base c in let t' = rewriters.rewrite_exp rewriters new_nmap t in let e' = rewriters.rewrite_exp rewriters new_nmap e in - Envmap.fold - (fun res i (t,e) -> + let exps' = walker exps in + fst ((Envmap.fold + (fun (res,effects) i (t,e) -> let bitlit = E_aux (E_lit (L_aux(L_zero, Parse_ast.Unknown)), (Parse_ast.Unknown, simple_annot bit_t)) in let rangelit = E_aux (E_lit (L_aux (L_num 0, Parse_ast.Unknown)), @@ -907,12 +946,13 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful (Parse_ast.Unknown,simple_annot bit_t))), (Parse_ast.Unknown, simple_annot t)) | _ -> e in - [E_aux (E_internal_let (LEXP_aux (LEXP_id (Id_aux (Id i, Parse_ast.Unknown)), + let unioneffs = union_effects effects (geteffs set_exp) in + ([E_aux (E_internal_let (LEXP_aux (LEXP_id (Id_aux (Id i, Parse_ast.Unknown)), (Parse_ast.Unknown, (tag_annot t Emp_intro))), set_exp, - E_aux (E_block res, (Parse_ast.Unknown, (simple_annot unit_t)))), - (Parse_ast.Unknown, simple_annot unit_t))]) - (E_aux(E_if(c',t',e'), (Parse_ast.Unknown, annot))::(walker exps)) new_vars + E_aux (E_block res, (Parse_ast.Unknown, (simple_annot_efr unit_t effects)))), + (Parse_ast.Unknown, simple_annot_efr unit_t unioneffs))],unioneffs))) + (E_aux(E_if(c',t',e'), (Parse_ast.Unknown, annot))::exps',eff_union exps') new_vars) | e::exps -> (rewrite_rec e)::(walker exps) in rewrap (E_block (walker exps)) @@ -923,7 +963,7 @@ let rewrite_exp_lift_assign_intro rewriters nmap ((E_aux (exp,(l,annot))) as ful (match le' with | LEXP_aux(_, (_,Base(_,Emp_intro,_,_,_,_))) -> let e' = rewrite_base e in - rewrap (E_internal_let(le', e', E_aux(E_block [], (l, simple_annot unit_t)))) + rewrap (E_internal_let(le', e', E_aux(E_block [], (l, simple_annot_efr unit_t (geteffs e'))))) | _ -> E_aux((E_assign(le', rewrite_base e)),(l, tag_annot unit_t Emp_set))) | _ -> rewrite_base full_exp) | _ -> rewrite_base full_exp @@ -961,52 +1001,12 @@ let rewrite_defs_ocaml defs = let defs_lifted_assign = rewrite_defs_exp_lift_assign defs_vec_concat_removed in defs_lifted_assign - - -let geteffs_annot (_,t) = match t with - | Base (_,_,_,_,effs,_) -> effs - | NoTyp -> failwith "no effect information" - | _ -> failwith "a_normalise doesn't support Overload" - -let gettype_annot (_,t) = match t with - | Base((_,t),_,_,_,_,_) -> t - | NoTyp -> failwith "no type information" - | _ -> failwith "a_normalise doesn't support Overload" - -let gettype (E_aux (_,a)) = gettype_annot a -let geteffs (E_aux (_,a)) = geteffs_annot a - -let effectful_effs {effect = Eset effs} = - List.exists - (fun (BE_aux (be,_)) -> - match be with - | BE_nondet | BE_unspec | BE_undef | BE_lset -> false - | _ -> true - ) effs - -let effectful eaux = - effectful_effs (geteffs eaux) - -let updates_vars_effs {effect = Eset effs} = - List.exists - (fun (BE_aux (be,_)) -> - match be with - | BE_lset -> true - | _ -> false - ) effs - -let updates_vars eaux = - updates_vars_effs (geteffs eaux) - -let eff_union e1 e2 = union_effects (geteffs e1) (geteffs e2) - -let remove_blocks_exp_alg = +let remove_blocks = let letbind_wild v body = - let annot_pat = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in - let annot_lb = annot_pat in - let annot_let = - (Parse_ast.Unknown,simple_annot_efr {t = Tid "unit"} (eff_union v body)) in + let annot_pat = (Parse_ast.Unknown,simple_annot (gettype v)) in + let annot_lb = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in + let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_wild,annot_pat),v),annot_lb),body),annot_let) in let rec f = function @@ -1018,7 +1018,7 @@ let remove_blocks_exp_alg = | (E_block es,annot) -> f es | (e,annot) -> E_aux (e,annot) in - { id_exp_alg with e_aux = e_aux } + fold_exp { id_exp_alg with e_aux = e_aux } let fresh_id annot = @@ -1036,7 +1036,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = let body = body e in let annot_pat = (Parse_ast.Unknown,simple_annot unit_t) in let annot_lb = annot_pat in - let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union v body)) in + let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in let pat = P_aux (P_wild,annot_pat) in if effectful v @@ -1049,7 +1049,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = let annot_pat = (Parse_ast.Unknown,simple_annot (gettype v)) in let annot_lb = annot_pat in - let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union v body)) in + let annot_let = (Parse_ast.Unknown,simple_annot_efr (gettype body) (eff_union [v;body])) in let pat = P_aux (P_id id,annot_pat) in if effectful v @@ -1106,7 +1106,7 @@ let rec n_exp_name (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = and n_exp_pure (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = n_exp exp (fun exp -> if not (effectful exp || updates_vars exp) then k exp else letbind exp k) - + and n_exp_nameL (exps : 'a exp list) (k : 'a exp list -> 'a exp) : 'a exp = mapCont n_exp_name exps k @@ -1170,7 +1170,6 @@ and n_lexp (lexp : 'a lexp) (k : 'a lexp -> 'a exp) : 'a exp = k (LEXP_aux (LEXP_field (lexp,id),local_eff_plus effs annot))) and n_exp_term (new_return : bool) (exp : 'a exp) : 'a exp = - let (E_aux (_,annot)) = exp in let exp = if new_return then E_aux (E_internal_return exp,(Unknown,simple_annot_efr (gettype exp) (geteffs exp))) @@ -1187,8 +1186,8 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = let (E_aux (exp_aux,annot)) = exp in let rewrap_effs effsum exp_aux = (* explicitly give effect sum *) - let (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds)) = annot in - E_aux (exp_aux, (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds))) in + let (l,Base (t,tag,nexps,eff,_,bounds)) = annot in + E_aux (exp_aux, (l,Base (t,tag,nexps,eff,effsum,bounds))) in let rewrap_localeff exp_aux = (* give exp_aux the local effect as the effect sum *) E_aux (exp_aux,only_local_eff annot) in @@ -1218,7 +1217,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = let new_return = effectful exp2 || effectful exp3 in let exp2 = n_exp_term new_return exp2 in let exp3 = n_exp_term new_return exp3 in - k (rewrap_effs (eff_union exp2 exp3) (E_if (exp1,exp2,exp3)))) + k (rewrap_effs (eff_union [exp2;exp3]) (E_if (exp1,exp2,exp3)))) | E_for (id,start,stop,by,dir,body) -> n_exp_name start (fun start -> n_exp_name stop (fun stop -> @@ -1292,7 +1291,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = | LB_aux (LB_val_explicit (_,pat,exp'),annot') | LB_aux (LB_val_implicit (pat,exp'),annot') -> if effectful exp' - then (rewrap_effs (eff_union exp' body) (E_internal_plet (pat,exp',n_exp body k))) + then (rewrap_effs (eff_union [exp';body]) (E_internal_plet (pat,exp',n_exp body k))) else (rewrap_effs (geteffs body) (E_let (lb,n_exp body k))) )) | E_assign (lexp,exp1) -> @@ -1309,7 +1308,7 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = (if effectful exp1 then n_exp_name exp1 else n_exp exp1) (fun exp1 -> - (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,n_exp exp2 k)))) + rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,n_exp exp2 k))) | E_internal_return exp1 -> n_exp_name exp1 (fun exp1 -> k (rewrap_localeff (E_internal_return exp1))) @@ -1318,7 +1317,8 @@ and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = let rewrite_defs_a_normalise = let rewrite_exp _ _ e = - n_exp_term (effectful e) (fold_exp remove_blocks_exp_alg e) in + let e = remove_blocks e in + n_exp_term (effectful e) e in rewrite_defs_base {rewrite_exp = rewrite_exp ; rewrite_pat = rewrite_pat @@ -1376,7 +1376,8 @@ let find_updated_vars exp = ; e_internal_cast = (fun (_,e1) -> e1) ; e_internal_exp = (fun _ -> []) ; e_internal_exp_user = (fun _ -> []) - ; e_internal_let = (fun ((None,[(id,_)]), e2, e3) -> List.filter (eqidtyp id) (e2 @ e3)) + ; e_internal_let = (fun ((None,[(id,_)]), e2, e3) -> + List.filter (fun id2 -> not (eqidtyp id id2)) (e2 @ e3)) ; e_internal_plet = (fun (_, e1, e2) -> e1 @ e2) ; e_internal_return = (fun e -> e) ; e_aux = (fun (e,_) -> e) @@ -1523,7 +1524,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = (* after a-normalisation c shouldn't need rewriting *) let t = gettype e1 in (* let () = assert (simple_annot t = simple_annot (gettype e2)) in *) - let v = E_aux (E_if (c,e1,e2), (Unknown,simple_annot_efr t (eff_union e1 e2))) in + let v = E_aux (E_if (c,e1,e2), (Unknown,simple_annot_efr t (eff_union [e1;e2]))) in let pat = (* if overwrite then mktup_pat vars @@ -1615,7 +1616,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = (match rewrite v pat with | Added_vars (v,pat) -> E_aux (E_internal_plet (pat,v,body), - (Unknown,simple_annot_efr (gettype body) (eff_union v body))) + (Unknown,simple_annot_efr (gettype body) (eff_union [v;body]))) | Same_vars v -> E_aux (E_internal_plet (pat,v,body),annot)) | E_internal_let (lexp,v,body) -> (* After a-normalisation E_internal_lets can only bind values to names, those don't @@ -1627,7 +1628,7 @@ let rec rewrite_var_updates ((E_aux (expaux,annot)) as exp) = let pat = P_aux (P_id id, (Parse_ast.Unknown,simple_annot (gettype v))) in let lbannot = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in let lb = LB_aux (LB_val_implicit (pat,v),lbannot) in - E_aux (E_let (lb,body),(Unknown,simple_annot_efr (gettype body) (eff_union v body))) + E_aux (E_let (lb,body),(Unknown,simple_annot_efr (gettype body) (eff_union [v;body]))) (* In tail-position there shouldn't be anything we need to do as the terms after * a-normalisation are pure and don't update local variables. There can't be any variable * assignments in tail-position (because of the effect), there could be pure pattern-match |
