From dd052bfc3e00a1ae988044ae81dd1624332dd899 Mon Sep 17 00:00:00 2001 From: Christopher Pulte Date: Sun, 25 Sep 2016 15:14:12 +0100 Subject: nicer lem output: no more unecessary 'unit' returns if if-expressions, for-loops or case-expressions also return updated variables --- src/gen_lib/prompt.lem | 12 +-- src/gen_lib/sail_values.lem | 27 +++--- src/gen_lib/state.lem | 12 +-- src/gen_lib/vector.lem | 7 +- src/pretty_print.ml | 1 - src/rewriter.ml | 228 ++++++++++++++++++-------------------------- 6 files changed, 120 insertions(+), 167 deletions(-) (limited to 'src') diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem index 4cd76156..a9aa3218 100644 --- a/src/gen_lib/prompt.lem +++ b/src/gen_lib/prompt.lem @@ -247,23 +247,23 @@ val write_reg_bitfield : forall 'e. register -> register_bitfield -> bit -> M 'e let write_reg_bitfield reg rbit = write_reg_bit reg (field_index_bit rbit) val foreachM_inc : forall 'e 'vars. (nat * nat * nat) -> 'vars -> - (nat -> 'vars -> M 'e (unit * 'vars)) -> M 'e (unit * 'vars) + (nat -> 'vars -> M 'e 'vars) -> M 'e 'vars let rec foreachM_inc (i,stop,by) vars body = if i <= stop then - body i vars >>= fun (_,vars) -> + body i vars >>= fun vars -> foreachM_inc (i + by,stop,by) vars body - else return ((),vars) + else return vars val foreachM_dec : forall 'e 'vars. (nat * nat * nat) -> 'vars -> - (nat -> 'vars -> M 'e (unit * 'vars)) -> M 'e (unit * 'vars) + (nat -> 'vars -> M 'e 'vars) -> M 'e 'vars let rec foreachM_dec (i,stop,by) vars body = if i >= stop then - body i vars >>= fun (_,vars) -> + body i vars >>= fun vars -> foreachM_dec (i - by,stop,by) vars body - else return ((),vars) + else return vars let write_two_regs r1 r2 vec = diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem index b9a4fbd1..454778c4 100644 --- a/src/gen_lib/sail_values.lem +++ b/src/gen_lib/sail_values.lem @@ -527,22 +527,21 @@ let toNaturalFiveTup (n1,n2,n3,n4,n5) = toNatural n5) - - -val foreach_inc : forall 'vars. (integer * integer * integer) (*(nat * nat * nat)*) -> 'vars -> - (integer (*nat*) -> 'vars -> (unit * 'vars)) -> (unit * 'vars) +val foreach_inc : forall 'vars. (integer * integer * integer) -> 'vars -> + (integer -> 'vars -> 'vars) -> 'vars let rec foreach_inc (i,stop,by) vars body = if i <= stop - then - let (_,vars) = body i vars in - foreach_inc (i + by,stop,by) vars body - else ((),vars) + then let vars = body i vars in + foreach_inc (i + by,stop,by) vars body + else vars -val foreach_dec : forall 'vars. (integer * integer * integer) (*(nat * nat * nat)*) -> 'vars -> - (integer (*nat*) -> 'vars -> (unit * 'vars)) -> (unit * 'vars) +val foreach_dec : forall 'vars. (integer * integer * integer) -> 'vars -> + (integer -> 'vars -> 'vars) -> 'vars let rec foreach_dec (i,stop,by) vars body = if i >= stop - then - let (_,vars) = body i vars in - foreach_dec (i - by,stop,by) vars body - else ((),vars) + then let vars = body i vars in + foreach_dec (i - by,stop,by) vars body + else vars + + + diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem index 51658d6e..5fc59207 100644 --- a/src/gen_lib/state.lem +++ b/src/gen_lib/state.lem @@ -158,20 +158,20 @@ let read_two_regs r1 r2 = return (v1 ^^ v2) val foreachM_inc : forall 'e 'vars. (i * i * i) -> 'vars -> - (i -> 'vars -> M 'e (unit * 'vars)) -> M 'e (unit * 'vars) + (i -> 'vars -> M 'e 'vars) -> M 'e 'vars let rec foreachM_inc (i,stop,by) vars body = if i <= stop then - body i vars >>= fun (_,vars) -> + body i vars >>= fun vars -> foreachM_inc (i + by,stop,by) vars body - else return ((),vars) + else return vars val foreachM_dec : forall 'e 'vars. (i * i * i) -> 'vars -> - (i -> 'vars -> M 'e (unit * 'vars)) -> M 'e (unit * 'vars) + (i -> 'vars -> M 'e 'vars) -> M 'e 'vars let rec foreachM_dec (i,stop,by) vars body = if i >= stop then - body i vars >>= fun (_,vars) -> + body i vars >>= fun vars -> foreachM_dec (i - by,stop,by) vars body - else return ((),vars) + else return vars diff --git a/src/gen_lib/vector.lem b/src/gen_lib/vector.lem index 7c22e3ba..b2d68132 100644 --- a/src/gen_lib/vector.lem +++ b/src/gen_lib/vector.lem @@ -34,7 +34,7 @@ let vector_concat (Vector bs start is_inc) (Vector bs' _ _) = let (^^) = vector_concat -val slice : vector bit -> integer -> integer -> vector bit +val slice : forall 'a. vector 'a -> integer -> integer -> vector 'a let slice (Vector bs start is_inc) n m = let n = natFromInteger n in let m = natFromInteger m in @@ -45,6 +45,7 @@ let slice (Vector bs start is_inc) n m = let n = integerFromNat n in Vector subvector n is_inc +val update : forall 'a. vector 'a -> integer -> integer -> vector 'a -> vector 'a let update (Vector bs start is_inc) n m (Vector bs' _ _) = let n = natFromInteger n in let m = natFromInteger m in @@ -55,10 +56,10 @@ let update (Vector bs start is_inc) n m (Vector bs' _ _) = let start = integerFromNat start in Vector (prefix ++ (List.take length bs') ++ suffix) start is_inc -val access : forall 'a. vector 'a -> (*nat*) integer -> 'a +val access : forall 'a. vector 'a -> integer -> 'a let access (Vector bs start is_inc) n = if is_inc then nth bs (n - start) else nth bs (start - n) -val update_pos : forall 'a. vector 'a -> (*nat*) integer -> 'a -> vector 'a +val update_pos : forall 'a. vector 'a -> integer -> 'a -> vector 'a let update_pos v n b = update v n n (Vector [b] 0 true) diff --git a/src/pretty_print.ml b/src/pretty_print.ml index 2d5974f2..e58beea4 100644 --- a/src/pretty_print.ml +++ b/src/pretty_print.ml @@ -2346,7 +2346,6 @@ let doc_exp_lem, doc_let_lem = (* temporary hack to make the loop body a function of the temporary variables *) | Id_aux ((Id (("foreach_inc" | "foreach_dec" | "foreachM_inc" | "foreachM_dec" ) as loopf),_)) -> - let call = doc_id_lem in let [id;indices;body;e5] = args in (match e5 with | E_aux (E_tuple vars,_) -> diff --git a/src/rewriter.ml b/src/rewriter.ml index cac5084b..c636d6fd 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -27,9 +27,6 @@ let fresh_name () = let reset_fresh_name_counter () = fresh_name_counter := 0 -let get_fresh_name_counter () = !fresh_name_counter -let set_fresh_name_counter i = (fresh_name_counter := i) - let get_effsum_annot (_,t) = match t with | Base (_,_,_,_,effs,_) -> effs | NoTyp -> failwith "no effect information" @@ -1309,7 +1306,7 @@ let rewrite_defs_ocaml defs = let defs_separate_nums = rewrite_defs_separate_numbs defs_lifted_assign in defs_separate_nums -let remove_blocks = +let rewrite_defs_remove_blocks = let letbind_wild v body = let (E_aux (_,(l,_))) = v in let annot_pat = (Parse_ast.Generated l,simple_annot (get_type v)) in @@ -1326,7 +1323,18 @@ let remove_blocks = | (E_block es,(l,_)) -> f l es | (e,annot) -> E_aux (e,annot) in - fold_exp { id_exp_alg with e_aux = e_aux } + let alg = { id_exp_alg with e_aux = e_aux } in + + rewrite_defs_base + {rewrite_exp = (fun _ _ -> fold_exp alg) + ; rewrite_pat = rewrite_pat + ; rewrite_let = rewrite_let + ; rewrite_lexp = rewrite_lexp + ; rewrite_fun = rewrite_fun + ; rewrite_def = rewrite_def + ; rewrite_defs = rewrite_defs_base + } + let fresh_id ((l,_) as annot) = @@ -1347,9 +1355,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = let annot_let = (Parse_ast.Generated l,simple_annot_efr (get_type body) (eff_union_exps [v;body])) in let pat = P_aux (P_wild,annot_pat) in - if effectful v - then E_aux (E_internal_plet (pat,v,body),annot_let) - else E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) + E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) | _ -> let (E_aux (_,((l,_) as annot))) = v in let ((E_aux (E_id id,_)) as e_id) = fresh_id annot in @@ -1360,9 +1366,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = let annot_let = (Parse_ast.Generated l,simple_annot_efr (get_type body) (eff_union_exps [v;body])) in let pat = P_aux (P_id id,annot_pat) in - if effectful v - then E_aux (E_internal_plet (pat,v,body),annot_let) - else E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) + E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list -> 'a exp) : 'a exp = @@ -1373,33 +1377,7 @@ let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list let rewrite_defs_letbind_effects = let rec value ((E_aux (exp_aux,_)) as exp) = - not (effectful exp) && not (updates_vars exp) (*&& - match exp_aux with - | E_id _ - | E_lit _ -> true - | E_tuple es - | E_vector es - | E_list es -> List.fold_left (&&) true (List.map value es) - - (* the ones below are debatable *) - | E_app (_,es) -> List.fold_left (fun b e -> b && value e) true es - | E_app_infix (e1,_,e2) -> value e1 && value e2 - - | E_cast (_,e) -> value e - | E_vector_indexed (ies,optdefault) -> - List.fold_left (fun b (i,e) -> b && value e) true ies && value_optdefault optdefault - | E_vector_append (e1,e2) - | E_vector_access (e1,e2) -> value e1 && value e2 - | E_vector_subrange (e1,e2,e3) - | E_vector_update (e1,e2,e3) -> value e1 && value e2 && value e3 - | E_vector_update_subrange (e1,e2,e3,e4) -> value e1 && value e2 && value e3 && value e4 - | E_cons (e1,e2) -> value e1 && value e2 - | E_record fexps -> value_fexps fexps - | E_record_update (e1,fexps) -> value e1 && value_fexps fexps - | E_field (e1,_) -> value e1 - | E_return e -> value e - - | _ -> false *) + not (effectful exp) && not (updates_vars exp) and value_optdefault (Def_val_aux (o,_)) = match o with | Def_val_empty -> true | Def_val_dec e -> value e @@ -1483,9 +1461,9 @@ let rewrite_defs_letbind_effects = E_aux (E_internal_return exp,(Parse_ast.Generated l,simple_annot_efr (get_type exp) (get_effsum_exp exp))) else exp in - (* changed this from n_exp to n_exp_pure so that when we return updated variables - * from a for-loop, for example, we can just add those into the returned tuple and - * don't need to a-normalise again *) + (* n_exp_term forces an expression to be translated into a form + "let .. let .. let .. in EXP" where EXP has no effect and does not update + variables *) n_exp_pure exp (fun exp -> exp) and n_exp (E_aux (exp_aux,annot) as exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = @@ -1624,7 +1602,7 @@ let rewrite_defs_letbind_effects = b || effectful_effs (get_localeff_annot annot)) false funcls in let rewrite_funcl (FCL_aux (FCL_Funcl(id,pat,exp),annot)) = let _ = reset_fresh_name_counter () in - FCL_aux (FCL_Funcl (id,pat,n_exp_term newreturn (remove_blocks exp)),annot) + FCL_aux (FCL_Funcl (id,pat,n_exp_term newreturn exp),annot) in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),fdannot) in rewrite_defs_base {rewrite_exp = rewrite_exp @@ -1637,23 +1615,25 @@ let rewrite_defs_letbind_effects = } let rewrite_defs_effectful_let_expressions = - let alg = { id_exp_alg with - e_let = (fun (lb,body) -> - match lb with - | LB_aux (LB_val_explicit (_,pat,exp'),annot') - | LB_aux (LB_val_implicit (pat,exp'),annot') -> - if effectful exp' - then E_internal_plet (pat,exp',body) - else E_let (lb,body)) - ; e_internal_let = fun (lexp,exp1,exp2) -> - if effectful exp1 then - match lexp with - | LEXP_aux (LEXP_id id,annot) - | LEXP_aux (LEXP_cast (_,id),annot) -> - E_internal_plet (P_aux (P_id id,annot),exp1,exp2) - | _ -> failwith "E_internal_plet with unexpected lexp" - else E_internal_let (lexp,exp1,exp2) - } in + + let e_let (lb,body) = + match lb with + | LB_aux (LB_val_explicit (_,pat,exp'),annot') + | LB_aux (LB_val_implicit (pat,exp'),annot') -> + if effectful exp' + then E_internal_plet (pat,exp',body) + else E_let (lb,body) in + + let e_internal_let = fun (lexp,exp1,exp2) -> + if effectful exp1 then + match lexp with + | LEXP_aux (LEXP_id id,annot) + | LEXP_aux (LEXP_cast (_,id),annot) -> + E_internal_plet (P_aux (P_id id,annot),exp1,exp2) + | _ -> failwith "E_internal_plet with unexpected lexp" + else E_internal_let (lexp,exp1,exp2) in + + let alg = { id_exp_alg with e_let = e_let; e_internal_let = e_internal_let } in rewrite_defs_base {rewrite_exp = (fun _ _ -> fold_exp alg) ; rewrite_pat = rewrite_pat @@ -1782,55 +1762,46 @@ type 'a updated_term = let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = - let rec add_vars (*overwrite*) ((E_aux (expaux,annot)) as exp) vars = + let rec add_vars overwrite ((E_aux (expaux,annot)) as exp) vars = match expaux with | E_let (lb,exp) -> - let exp = add_vars (*overwrite*) exp vars in + let exp = add_vars overwrite exp vars in E_aux (E_let (lb,exp),swaptyp (get_type exp) annot) | E_internal_let (lexp,exp1,exp2) -> - let exp2 = add_vars (*overwrite*) exp2 vars in + let exp2 = add_vars overwrite exp2 vars in E_aux (E_internal_let (lexp,exp1,exp2), swaptyp (get_type exp2) annot) | E_internal_plet (pat,exp1,exp2) -> - let exp2 = add_vars (*overwrite*) exp2 vars in + let exp2 = add_vars overwrite exp2 vars in E_aux (E_internal_plet (pat,exp1,exp2), swaptyp (get_type exp2) annot) | E_internal_return exp2 -> - let exp2 = add_vars (*overwrite*) exp2 vars in + let exp2 = add_vars overwrite exp2 vars in E_aux (E_internal_return exp2,swaptyp (get_type exp2) annot) | _ -> - (* after a-normalisation this will be pure: - * if the whole body of the function/if-expression/case-expression/for-loop was - * pure, then it's still pure; if it wasn't then the body was wrapped in E_return - * and (in this case) exp is a name contained in E_return that by definition of - * value must be pure - *) - let () = - if (effectful exp) then - failwith (e_to_string (get_effsum_exp exp)) - else - () in -(* if overwrite then -(* let () = match expaux with - | E_id _ when get_type exp = {t = Tid "unit"} -> () - | _ -> failwith "nono" in *) + (* after rewrite_defs_letbind_effects there cannot be terms that have + effects/update local variables in "tail-position": check n_exp_term + and where it is used. *) + if overwrite then + let () = if get_type exp = {t = Tid "unit"} then () + else failwith "nono" in vars - else*) + else E_aux (E_tuple [exp;vars],swaptyp {t = Ttup [get_type exp;get_type vars]} annot) in let rewrite (E_aux (expaux,((el,_) as annot))) (P_aux (_,(pl,pannot)) as pat) = + let overwrite = match get_type_annot annot with + | {t = Tid "unit"} -> true + | _ -> false in match expaux with | E_for(id,exp1,exp2,exp3,order,exp4) -> let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars exp4) in let vartuple = mktup el vars in -(* let overwrite = match get_type exp with - | {t = Tid "unit"} -> true - | _ -> false in*) - let exp4 = rewrite_var_updates (add_vars (*overwrite*) exp4 vartuple) in - let orderb = match order with - | Ord_aux (Ord_inc,_) -> true - | Ord_aux (Ord_dec,_) -> false in - let funcl = match effectful exp4 with - | false -> Id_aux (Id (if orderb then "foreach_inc" else "foreach_dec"),Parse_ast.Generated el) - | true -> Id_aux (Id (if orderb then "foreachM_inc" else "foreachM_dec"),Parse_ast.Generated el) in + let exp4 = rewrite_var_updates (add_vars overwrite exp4 vartuple) in + let fname = match effectful exp4,order with + | false, Ord_aux (Ord_inc,_) -> "foreach_inc" + | false, Ord_aux (Ord_dec,_) -> "foreach_dec" + | true, Ord_aux (Ord_inc,_) -> "foreachM_inc" + | true, Ord_aux (Ord_dec,_) -> "foreachM_dec" in + let funcl = Id_aux (Id fname,Parse_ast.Generated el) in let loopvar = let (bf,tf) = match get_type exp1 with | {t = Tapp ("atom",[TA_nexp f])} -> (TA_nexp f,TA_nexp f) @@ -1844,15 +1815,16 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | {t = Tapp ("range",[TA_nexp bt;TA_nexp tt])} -> (TA_nexp bt,TA_nexp tt) | {t = Tapp ("atom",[TA_typ {t = Tapp ("range",[TA_nexp bt;TA_nexp tt])}])} -> (TA_nexp bt,TA_nexp tt) | {t = Tapp (name,_)} -> failwith (name ^ " shouldn't be here") in - let t = {t = Tapp ("range",if orderb then [bf;tt] else [tf;bt])} in + let t = {t = Tapp ("range",match order with + | Ord_aux (Ord_inc,_) -> [bf;tt] + | Ord_aux (Ord_dec,_) -> [tf;bt])} in E_aux (E_id id,(Parse_ast.Generated el,simple_annot t)) in let v = E_aux (E_app (funcl,[loopvar;mktup el [exp1;exp2;exp3];exp4;vartuple]), (Parse_ast.Generated el,simple_annot_efr (get_type exp4) (get_effsum_exp exp4))) in let pat = -(* if overwrite then - mktup_pat vars - else *) - P_aux (P_tup [pat; mktup_pat pl vars], (Parse_ast.Generated pl,simple_annot (get_type v))) in + if overwrite then mktup_pat el vars + else P_aux (P_tup [pat; mktup_pat pl vars], + (Parse_ast.Generated pl,simple_annot (get_type v))) in Added_vars (v,pat) | E_if (c,e1,e2) -> let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) @@ -1861,25 +1833,20 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = (Same_vars (E_aux (E_if (c,rewrite_var_updates e1,rewrite_var_updates e2),annot))) else let vartuple = mktup el vars in -(* let overwrite = match get_type exp with - | {t = Tid "unit"} -> true - | _ -> false in *) - let e1 = rewrite_var_updates (add_vars (*overwrite*) e1 vartuple) in - let e2 = rewrite_var_updates (add_vars (*overwrite*) e2 vartuple) in - (* after a-normalisation c shouldn't need rewriting *) + let e1 = rewrite_var_updates (add_vars overwrite e1 vartuple) in + let e2 = rewrite_var_updates (add_vars overwrite e2 vartuple) in + (* after rewrite_defs_letbind_effects c has no variable updates *) let t = get_type e1 in - (* let () = assert (simple_annot t = simple_annot (get_type e2)) in *) let v = E_aux (E_if (c,e1,e2), (Parse_ast.Generated el,simple_annot_efr t (eff_union_exps [e1;e2]))) in let pat = -(* if overwrite then - mktup_pat vars - else*) - P_aux (P_tup [pat; mktup_pat pl vars],(Parse_ast.Generated pl,simple_annot (get_type v))) in + if overwrite then mktup_pat el vars + else P_aux (P_tup [pat; mktup_pat pl vars], + (Parse_ast.Generated pl,simple_annot (get_type v))) in Added_vars (v,pat) | E_case (e1,ps) -> - (* after a-normalisation e1 shouldn't need rewriting *) + (* after rewrite_defs_letbind_effects e1 needs no rewriting *) let vars = - let f acc (Pat_aux (Pat_exp (_,e),_)) = acc @ find_updated_vars e in + let f acc (Pat_aux (Pat_exp (_,e),_)) = acc @ find_updated_vars e in List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (dedup eqidtyp (List.fold_left f [] ps)) in if vars = [] then @@ -1887,9 +1854,6 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = Same_vars (E_aux (E_case (e1,ps),annot)) else let vartuple = mktup el vars in -(* let overwrite = match get_type exp with - | {t = Tid "unit"} -> true - | _ -> false in*) let typ = let (Pat_aux (Pat_exp (_,first),_)) = List.hd ps in get_type first in @@ -1897,7 +1861,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let f (acc,typ,effs) (Pat_aux (Pat_exp (p,e),pannot)) = let etyp = get_type e in let () = assert (simple_annot etyp = simple_annot typ) in - let e = rewrite_var_updates (add_vars (*overwrite*) e vartuple) in + let e = rewrite_var_updates (add_vars overwrite e vartuple) in let pannot = (Parse_ast.Generated pl,simple_annot (get_type e)) in let effs = union_effects effs (get_effsum_exp e) in let pat' = Pat_aux (Pat_exp (p,e),pannot) in @@ -1905,10 +1869,9 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = List.fold_left f ([],typ,{effect = Eset []}) ps in let v = E_aux (E_case (e1,ps), (Parse_ast.Generated pl,simple_annot_efr typ effs)) in let pat = -(* if overwrite then - P_aux (P_tup [mktup_pat vars],(Unknown,simple_annot (get_type v))) - else*) - P_aux (P_tup [pat; mktup_pat pl vars],(Parse_ast.Generated pl,simple_annot (get_type v))) in + if overwrite then mktup_pat el vars + else P_aux (P_tup [pat; mktup_pat pl vars], + (Parse_ast.Generated pl,simple_annot (get_type v))) in Added_vars (v,pat) | E_assign (lexp,vexp) -> let {effect = Eset effs} = get_effsum_annot annot in @@ -1936,8 +1899,8 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let pat = P_aux (P_id id,(Parse_ast.Generated pl,simple_annot (get_type vexp))) in Added_vars (vexp,pat)) | _ -> - (* assumes everying's a-normlised: an expression is a sequence of let-expressions, - * "control-flow" structures and a return value, possibly wrapped in E_return *) + (* after rewrite_defs_letbind_effects this expression is pure and updates + no variables: check n_exp_term and where it's used. *) Same_vars (E_aux (expaux,annot)) in match expaux with @@ -1958,32 +1921,22 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let lbannot = (Parse_ast.Generated l,simple_annot (get_type v)) in (get_effsum_exp v,LB_aux (LB_val_implicit (pat,v),lbannot)) | Same_vars v -> (get_effsum_exp v,LB_aux (LB_val_explicit (typ,pat,v),lbannot))) in - E_aux (E_let (lb,body), - (Parse_ast.Generated l,simple_annot_efr (get_type body) (union_effects eff (get_effsum_exp body)))) - | E_internal_plet (pat,v,body) -> - let body = rewrite_var_updates body in - (match rewrite v pat with - | Added_vars (v,pat) -> - E_aux (E_internal_plet (pat,v,body), - (Parse_ast.Generated l,simple_annot_efr (get_type body) (eff_union_exps [v;body]))) - | Same_vars v -> E_aux (E_internal_plet (pat,v,body),annot)) + let typ = simple_annot_efr (get_type body) (union_effects eff (get_effsum_exp body)) in + E_aux (E_let (lb,body),(Parse_ast.Generated l,typ)) | E_internal_let (lexp,v,body) -> - (* After a-normalisation E_internal_lets can only bind values to names, those don't - * need rewriting. *) - let body = rewrite_var_updates body in + (* Rewrite E_internal_let into E_let and call recursively *) let id = match lexp with | LEXP_aux (LEXP_id id,_) -> id | LEXP_aux (LEXP_cast (_,id),_) -> id in let pat = P_aux (P_id id, (Parse_ast.Generated l,simple_annot (get_type v))) in let lbannot = (Parse_ast.Generated l,simple_annot_efr (get_type v) (get_effsum_exp v)) in let lb = LB_aux (LB_val_implicit (pat,v),lbannot) in - E_aux (E_let (lb,body),(Parse_ast.Generated l,simple_annot_efr (get_type body) (eff_union_exps [v;body]))) - (* In tail-position there shouldn't be anything we need to do as the terms after - * a-normalisation are pure and don't update local variables. There can't be any variable - * assignments in tail-position (because of the effect), there could be pure pattern-match - * expressions, if-expressions that don't need rewriting. For-loops still need rewriting, - * but it would be pointless to have them in tail-position if they don't access memory or - * update variables. *) + let exp = E_aux (E_let (lb,body),(Parse_ast.Generated l,simple_annot_efr (get_type body) (eff_union_exps [v;body]))) in + rewrite_var_updates exp + | E_internal_plet (pat,v,body) -> + failwith "rewrite_var_updates: E_internal_plet shouldn't be introduced yet" + (* There are no expressions that have effects or variable updates in + "tail-position": check the definition nexp_term and where it is used. *) | _ -> exp let replace_memwrite_e_assign exp = @@ -2088,6 +2041,7 @@ let rewrite_defs_remove_e_assign = let rewrite_defs_lem defs = let defs = rewrite_defs_remove_vector_concat defs in let defs = rewrite_defs_exp_lift_assign defs in + let defs = rewrite_defs_remove_blocks defs in let defs = rewrite_defs_letbind_effects defs in let defs = rewrite_defs_remove_e_assign defs in let defs = rewrite_defs_effectful_let_expressions defs in -- cgit v1.2.3