summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorChristopher Pulte2016-09-25 15:14:12 +0100
committerChristopher Pulte2016-09-25 15:14:12 +0100
commitdd052bfc3e00a1ae988044ae81dd1624332dd899 (patch)
tree357d6e14136545dce7d0d120b7c1e5bccf27970d /src
parent6e7cee1575a7c49f4bdc30dfd6f25546c6c70995 (diff)
nicer lem output: no more unecessary 'unit' returns if if-expressions, for-loops or case-expressions also return updated variables
Diffstat (limited to 'src')
-rw-r--r--src/gen_lib/prompt.lem12
-rw-r--r--src/gen_lib/sail_values.lem27
-rw-r--r--src/gen_lib/state.lem12
-rw-r--r--src/gen_lib/vector.lem7
-rw-r--r--src/pretty_print.ml1
-rw-r--r--src/rewriter.ml228
6 files changed, 120 insertions, 167 deletions
diff --git a/src/gen_lib/prompt.lem b/src/gen_lib/prompt.lem
index 4cd76156..a9aa3218 100644
--- a/src/gen_lib/prompt.lem
+++ b/src/gen_lib/prompt.lem
@@ -247,23 +247,23 @@ val write_reg_bitfield : forall 'e. register -> register_bitfield -> bit -> M 'e
let write_reg_bitfield reg rbit = write_reg_bit reg (field_index_bit rbit)
val foreachM_inc : forall 'e 'vars. (nat * nat * nat) -> 'vars ->
- (nat -> 'vars -> M 'e (unit * 'vars)) -> M 'e (unit * 'vars)
+ (nat -> 'vars -> M 'e 'vars) -> M 'e 'vars
let rec foreachM_inc (i,stop,by) vars body =
if i <= stop
then
- body i vars >>= fun (_,vars) ->
+ body i vars >>= fun vars ->
foreachM_inc (i + by,stop,by) vars body
- else return ((),vars)
+ else return vars
val foreachM_dec : forall 'e 'vars. (nat * nat * nat) -> 'vars ->
- (nat -> 'vars -> M 'e (unit * 'vars)) -> M 'e (unit * 'vars)
+ (nat -> 'vars -> M 'e 'vars) -> M 'e 'vars
let rec foreachM_dec (i,stop,by) vars body =
if i >= stop
then
- body i vars >>= fun (_,vars) ->
+ body i vars >>= fun vars ->
foreachM_dec (i - by,stop,by) vars body
- else return ((),vars)
+ else return vars
let write_two_regs r1 r2 vec =
diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem
index b9a4fbd1..454778c4 100644
--- a/src/gen_lib/sail_values.lem
+++ b/src/gen_lib/sail_values.lem
@@ -527,22 +527,21 @@ let toNaturalFiveTup (n1,n2,n3,n4,n5) =
toNatural n5)
-
-
-val foreach_inc : forall 'vars. (integer * integer * integer) (*(nat * nat * nat)*) -> 'vars ->
- (integer (*nat*) -> 'vars -> (unit * 'vars)) -> (unit * 'vars)
+val foreach_inc : forall 'vars. (integer * integer * integer) -> 'vars ->
+ (integer -> 'vars -> 'vars) -> 'vars
let rec foreach_inc (i,stop,by) vars body =
if i <= stop
- then
- let (_,vars) = body i vars in
- foreach_inc (i + by,stop,by) vars body
- else ((),vars)
+ then let vars = body i vars in
+ foreach_inc (i + by,stop,by) vars body
+ else vars
-val foreach_dec : forall 'vars. (integer * integer * integer) (*(nat * nat * nat)*) -> 'vars ->
- (integer (*nat*) -> 'vars -> (unit * 'vars)) -> (unit * 'vars)
+val foreach_dec : forall 'vars. (integer * integer * integer) -> 'vars ->
+ (integer -> 'vars -> 'vars) -> 'vars
let rec foreach_dec (i,stop,by) vars body =
if i >= stop
- then
- let (_,vars) = body i vars in
- foreach_dec (i - by,stop,by) vars body
- else ((),vars)
+ then let vars = body i vars in
+ foreach_dec (i - by,stop,by) vars body
+ else vars
+
+
+
diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem
index 51658d6e..5fc59207 100644
--- a/src/gen_lib/state.lem
+++ b/src/gen_lib/state.lem
@@ -158,20 +158,20 @@ let read_two_regs r1 r2 =
return (v1 ^^ v2)
val foreachM_inc : forall 'e 'vars. (i * i * i) -> 'vars ->
- (i -> 'vars -> M 'e (unit * 'vars)) -> M 'e (unit * 'vars)
+ (i -> 'vars -> M 'e 'vars) -> M 'e 'vars
let rec foreachM_inc (i,stop,by) vars body =
if i <= stop
then
- body i vars >>= fun (_,vars) ->
+ body i vars >>= fun vars ->
foreachM_inc (i + by,stop,by) vars body
- else return ((),vars)
+ else return vars
val foreachM_dec : forall 'e 'vars. (i * i * i) -> 'vars ->
- (i -> 'vars -> M 'e (unit * 'vars)) -> M 'e (unit * 'vars)
+ (i -> 'vars -> M 'e 'vars) -> M 'e 'vars
let rec foreachM_dec (i,stop,by) vars body =
if i >= stop
then
- body i vars >>= fun (_,vars) ->
+ body i vars >>= fun vars ->
foreachM_dec (i - by,stop,by) vars body
- else return ((),vars)
+ else return vars
diff --git a/src/gen_lib/vector.lem b/src/gen_lib/vector.lem
index 7c22e3ba..b2d68132 100644
--- a/src/gen_lib/vector.lem
+++ b/src/gen_lib/vector.lem
@@ -34,7 +34,7 @@ let vector_concat (Vector bs start is_inc) (Vector bs' _ _) =
let (^^) = vector_concat
-val slice : vector bit -> integer -> integer -> vector bit
+val slice : forall 'a. vector 'a -> integer -> integer -> vector 'a
let slice (Vector bs start is_inc) n m =
let n = natFromInteger n in
let m = natFromInteger m in
@@ -45,6 +45,7 @@ let slice (Vector bs start is_inc) n m =
let n = integerFromNat n in
Vector subvector n is_inc
+val update : forall 'a. vector 'a -> integer -> integer -> vector 'a -> vector 'a
let update (Vector bs start is_inc) n m (Vector bs' _ _) =
let n = natFromInteger n in
let m = natFromInteger m in
@@ -55,10 +56,10 @@ let update (Vector bs start is_inc) n m (Vector bs' _ _) =
let start = integerFromNat start in
Vector (prefix ++ (List.take length bs') ++ suffix) start is_inc
-val access : forall 'a. vector 'a -> (*nat*) integer -> 'a
+val access : forall 'a. vector 'a -> integer -> 'a
let access (Vector bs start is_inc) n =
if is_inc then nth bs (n - start) else nth bs (start - n)
-val update_pos : forall 'a. vector 'a -> (*nat*) integer -> 'a -> vector 'a
+val update_pos : forall 'a. vector 'a -> integer -> 'a -> vector 'a
let update_pos v n b =
update v n n (Vector [b] 0 true)
diff --git a/src/pretty_print.ml b/src/pretty_print.ml
index 2d5974f2..e58beea4 100644
--- a/src/pretty_print.ml
+++ b/src/pretty_print.ml
@@ -2346,7 +2346,6 @@ let doc_exp_lem, doc_let_lem =
(* temporary hack to make the loop body a function of the temporary variables *)
| Id_aux ((Id (("foreach_inc" | "foreach_dec" |
"foreachM_inc" | "foreachM_dec" ) as loopf),_)) ->
- let call = doc_id_lem in
let [id;indices;body;e5] = args in
(match e5 with
| E_aux (E_tuple vars,_) ->
diff --git a/src/rewriter.ml b/src/rewriter.ml
index cac5084b..c636d6fd 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -27,9 +27,6 @@ let fresh_name () =
let reset_fresh_name_counter () =
fresh_name_counter := 0
-let get_fresh_name_counter () = !fresh_name_counter
-let set_fresh_name_counter i = (fresh_name_counter := i)
-
let get_effsum_annot (_,t) = match t with
| Base (_,_,_,_,effs,_) -> effs
| NoTyp -> failwith "no effect information"
@@ -1309,7 +1306,7 @@ let rewrite_defs_ocaml defs =
let defs_separate_nums = rewrite_defs_separate_numbs defs_lifted_assign in
defs_separate_nums
-let remove_blocks =
+let rewrite_defs_remove_blocks =
let letbind_wild v body =
let (E_aux (_,(l,_))) = v in
let annot_pat = (Parse_ast.Generated l,simple_annot (get_type v)) in
@@ -1326,7 +1323,18 @@ let remove_blocks =
| (E_block es,(l,_)) -> f l es
| (e,annot) -> E_aux (e,annot) in
- fold_exp { id_exp_alg with e_aux = e_aux }
+ let alg = { id_exp_alg with e_aux = e_aux } in
+
+ rewrite_defs_base
+ {rewrite_exp = (fun _ _ -> fold_exp alg)
+ ; rewrite_pat = rewrite_pat
+ ; rewrite_let = rewrite_let
+ ; rewrite_lexp = rewrite_lexp
+ ; rewrite_fun = rewrite_fun
+ ; rewrite_def = rewrite_def
+ ; rewrite_defs = rewrite_defs_base
+ }
+
let fresh_id ((l,_) as annot) =
@@ -1347,9 +1355,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
let annot_let = (Parse_ast.Generated l,simple_annot_efr (get_type body) (eff_union_exps [v;body])) in
let pat = P_aux (P_wild,annot_pat) in
- if effectful v
- then E_aux (E_internal_plet (pat,v,body),annot_let)
- else E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let)
+ E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let)
| _ ->
let (E_aux (_,((l,_) as annot))) = v in
let ((E_aux (E_id id,_)) as e_id) = fresh_id annot in
@@ -1360,9 +1366,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
let annot_let = (Parse_ast.Generated l,simple_annot_efr (get_type body) (eff_union_exps [v;body])) in
let pat = P_aux (P_id id,annot_pat) in
- if effectful v
- then E_aux (E_internal_plet (pat,v,body),annot_let)
- else E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let)
+ E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let)
let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list -> 'a exp) : 'a exp =
@@ -1373,33 +1377,7 @@ let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list
let rewrite_defs_letbind_effects =
let rec value ((E_aux (exp_aux,_)) as exp) =
- not (effectful exp) && not (updates_vars exp) (*&&
- match exp_aux with
- | E_id _
- | E_lit _ -> true
- | E_tuple es
- | E_vector es
- | E_list es -> List.fold_left (&&) true (List.map value es)
-
- (* the ones below are debatable *)
- | E_app (_,es) -> List.fold_left (fun b e -> b && value e) true es
- | E_app_infix (e1,_,e2) -> value e1 && value e2
-
- | E_cast (_,e) -> value e
- | E_vector_indexed (ies,optdefault) ->
- List.fold_left (fun b (i,e) -> b && value e) true ies && value_optdefault optdefault
- | E_vector_append (e1,e2)
- | E_vector_access (e1,e2) -> value e1 && value e2
- | E_vector_subrange (e1,e2,e3)
- | E_vector_update (e1,e2,e3) -> value e1 && value e2 && value e3
- | E_vector_update_subrange (e1,e2,e3,e4) -> value e1 && value e2 && value e3 && value e4
- | E_cons (e1,e2) -> value e1 && value e2
- | E_record fexps -> value_fexps fexps
- | E_record_update (e1,fexps) -> value e1 && value_fexps fexps
- | E_field (e1,_) -> value e1
- | E_return e -> value e
-
- | _ -> false *)
+ not (effectful exp) && not (updates_vars exp)
and value_optdefault (Def_val_aux (o,_)) = match o with
| Def_val_empty -> true
| Def_val_dec e -> value e
@@ -1483,9 +1461,9 @@ let rewrite_defs_letbind_effects =
E_aux (E_internal_return exp,(Parse_ast.Generated l,simple_annot_efr (get_type exp) (get_effsum_exp exp)))
else
exp in
- (* changed this from n_exp to n_exp_pure so that when we return updated variables
- * from a for-loop, for example, we can just add those into the returned tuple and
- * don't need to a-normalise again *)
+ (* n_exp_term forces an expression to be translated into a form
+ "let .. let .. let .. in EXP" where EXP has no effect and does not update
+ variables *)
n_exp_pure exp (fun exp -> exp)
and n_exp (E_aux (exp_aux,annot) as exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
@@ -1624,7 +1602,7 @@ let rewrite_defs_letbind_effects =
b || effectful_effs (get_localeff_annot annot)) false funcls in
let rewrite_funcl (FCL_aux (FCL_Funcl(id,pat,exp),annot)) =
let _ = reset_fresh_name_counter () in
- FCL_aux (FCL_Funcl (id,pat,n_exp_term newreturn (remove_blocks exp)),annot)
+ FCL_aux (FCL_Funcl (id,pat,n_exp_term newreturn exp),annot)
in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),fdannot) in
rewrite_defs_base
{rewrite_exp = rewrite_exp
@@ -1637,23 +1615,25 @@ let rewrite_defs_letbind_effects =
}
let rewrite_defs_effectful_let_expressions =
- let alg = { id_exp_alg with
- e_let = (fun (lb,body) ->
- match lb with
- | LB_aux (LB_val_explicit (_,pat,exp'),annot')
- | LB_aux (LB_val_implicit (pat,exp'),annot') ->
- if effectful exp'
- then E_internal_plet (pat,exp',body)
- else E_let (lb,body))
- ; e_internal_let = fun (lexp,exp1,exp2) ->
- if effectful exp1 then
- match lexp with
- | LEXP_aux (LEXP_id id,annot)
- | LEXP_aux (LEXP_cast (_,id),annot) ->
- E_internal_plet (P_aux (P_id id,annot),exp1,exp2)
- | _ -> failwith "E_internal_plet with unexpected lexp"
- else E_internal_let (lexp,exp1,exp2)
- } in
+
+ let e_let (lb,body) =
+ match lb with
+ | LB_aux (LB_val_explicit (_,pat,exp'),annot')
+ | LB_aux (LB_val_implicit (pat,exp'),annot') ->
+ if effectful exp'
+ then E_internal_plet (pat,exp',body)
+ else E_let (lb,body) in
+
+ let e_internal_let = fun (lexp,exp1,exp2) ->
+ if effectful exp1 then
+ match lexp with
+ | LEXP_aux (LEXP_id id,annot)
+ | LEXP_aux (LEXP_cast (_,id),annot) ->
+ E_internal_plet (P_aux (P_id id,annot),exp1,exp2)
+ | _ -> failwith "E_internal_plet with unexpected lexp"
+ else E_internal_let (lexp,exp1,exp2) in
+
+ let alg = { id_exp_alg with e_let = e_let; e_internal_let = e_internal_let } in
rewrite_defs_base
{rewrite_exp = (fun _ _ -> fold_exp alg)
; rewrite_pat = rewrite_pat
@@ -1782,55 +1762,46 @@ type 'a updated_term =
let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
- let rec add_vars (*overwrite*) ((E_aux (expaux,annot)) as exp) vars =
+ let rec add_vars overwrite ((E_aux (expaux,annot)) as exp) vars =
match expaux with
| E_let (lb,exp) ->
- let exp = add_vars (*overwrite*) exp vars in
+ let exp = add_vars overwrite exp vars in
E_aux (E_let (lb,exp),swaptyp (get_type exp) annot)
| E_internal_let (lexp,exp1,exp2) ->
- let exp2 = add_vars (*overwrite*) exp2 vars in
+ let exp2 = add_vars overwrite exp2 vars in
E_aux (E_internal_let (lexp,exp1,exp2), swaptyp (get_type exp2) annot)
| E_internal_plet (pat,exp1,exp2) ->
- let exp2 = add_vars (*overwrite*) exp2 vars in
+ let exp2 = add_vars overwrite exp2 vars in
E_aux (E_internal_plet (pat,exp1,exp2), swaptyp (get_type exp2) annot)
| E_internal_return exp2 ->
- let exp2 = add_vars (*overwrite*) exp2 vars in
+ let exp2 = add_vars overwrite exp2 vars in
E_aux (E_internal_return exp2,swaptyp (get_type exp2) annot)
| _ ->
- (* after a-normalisation this will be pure:
- * if the whole body of the function/if-expression/case-expression/for-loop was
- * pure, then it's still pure; if it wasn't then the body was wrapped in E_return
- * and (in this case) exp is a name contained in E_return that by definition of
- * value must be pure
- *)
- let () =
- if (effectful exp) then
- failwith (e_to_string (get_effsum_exp exp))
- else
- () in
-(* if overwrite then
-(* let () = match expaux with
- | E_id _ when get_type exp = {t = Tid "unit"} -> ()
- | _ -> failwith "nono" in *)
+ (* after rewrite_defs_letbind_effects there cannot be terms that have
+ effects/update local variables in "tail-position": check n_exp_term
+ and where it is used. *)
+ if overwrite then
+ let () = if get_type exp = {t = Tid "unit"} then ()
+ else failwith "nono" in
vars
- else*)
+ else
E_aux (E_tuple [exp;vars],swaptyp {t = Ttup [get_type exp;get_type vars]} annot) in
let rewrite (E_aux (expaux,((el,_) as annot))) (P_aux (_,(pl,pannot)) as pat) =
+ let overwrite = match get_type_annot annot with
+ | {t = Tid "unit"} -> true
+ | _ -> false in
match expaux with
| E_for(id,exp1,exp2,exp3,order,exp4) ->
let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars exp4) in
let vartuple = mktup el vars in
-(* let overwrite = match get_type exp with
- | {t = Tid "unit"} -> true
- | _ -> false in*)
- let exp4 = rewrite_var_updates (add_vars (*overwrite*) exp4 vartuple) in
- let orderb = match order with
- | Ord_aux (Ord_inc,_) -> true
- | Ord_aux (Ord_dec,_) -> false in
- let funcl = match effectful exp4 with
- | false -> Id_aux (Id (if orderb then "foreach_inc" else "foreach_dec"),Parse_ast.Generated el)
- | true -> Id_aux (Id (if orderb then "foreachM_inc" else "foreachM_dec"),Parse_ast.Generated el) in
+ let exp4 = rewrite_var_updates (add_vars overwrite exp4 vartuple) in
+ let fname = match effectful exp4,order with
+ | false, Ord_aux (Ord_inc,_) -> "foreach_inc"
+ | false, Ord_aux (Ord_dec,_) -> "foreach_dec"
+ | true, Ord_aux (Ord_inc,_) -> "foreachM_inc"
+ | true, Ord_aux (Ord_dec,_) -> "foreachM_dec" in
+ let funcl = Id_aux (Id fname,Parse_ast.Generated el) in
let loopvar =
let (bf,tf) = match get_type exp1 with
| {t = Tapp ("atom",[TA_nexp f])} -> (TA_nexp f,TA_nexp f)
@@ -1844,15 +1815,16 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
| {t = Tapp ("range",[TA_nexp bt;TA_nexp tt])} -> (TA_nexp bt,TA_nexp tt)
| {t = Tapp ("atom",[TA_typ {t = Tapp ("range",[TA_nexp bt;TA_nexp tt])}])} -> (TA_nexp bt,TA_nexp tt)
| {t = Tapp (name,_)} -> failwith (name ^ " shouldn't be here") in
- let t = {t = Tapp ("range",if orderb then [bf;tt] else [tf;bt])} in
+ let t = {t = Tapp ("range",match order with
+ | Ord_aux (Ord_inc,_) -> [bf;tt]
+ | Ord_aux (Ord_dec,_) -> [tf;bt])} in
E_aux (E_id id,(Parse_ast.Generated el,simple_annot t)) in
let v = E_aux (E_app (funcl,[loopvar;mktup el [exp1;exp2;exp3];exp4;vartuple]),
(Parse_ast.Generated el,simple_annot_efr (get_type exp4) (get_effsum_exp exp4))) in
let pat =
-(* if overwrite then
- mktup_pat vars
- else *)
- P_aux (P_tup [pat; mktup_pat pl vars], (Parse_ast.Generated pl,simple_annot (get_type v))) in
+ if overwrite then mktup_pat el vars
+ else P_aux (P_tup [pat; mktup_pat pl vars],
+ (Parse_ast.Generated pl,simple_annot (get_type v))) in
Added_vars (v,pat)
| E_if (c,e1,e2) ->
let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t)))
@@ -1861,25 +1833,20 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
(Same_vars (E_aux (E_if (c,rewrite_var_updates e1,rewrite_var_updates e2),annot)))
else
let vartuple = mktup el vars in
-(* let overwrite = match get_type exp with
- | {t = Tid "unit"} -> true
- | _ -> false in *)
- let e1 = rewrite_var_updates (add_vars (*overwrite*) e1 vartuple) in
- let e2 = rewrite_var_updates (add_vars (*overwrite*) e2 vartuple) in
- (* after a-normalisation c shouldn't need rewriting *)
+ let e1 = rewrite_var_updates (add_vars overwrite e1 vartuple) in
+ let e2 = rewrite_var_updates (add_vars overwrite e2 vartuple) in
+ (* after rewrite_defs_letbind_effects c has no variable updates *)
let t = get_type e1 in
- (* let () = assert (simple_annot t = simple_annot (get_type e2)) in *)
let v = E_aux (E_if (c,e1,e2), (Parse_ast.Generated el,simple_annot_efr t (eff_union_exps [e1;e2]))) in
let pat =
-(* if overwrite then
- mktup_pat vars
- else*)
- P_aux (P_tup [pat; mktup_pat pl vars],(Parse_ast.Generated pl,simple_annot (get_type v))) in
+ if overwrite then mktup_pat el vars
+ else P_aux (P_tup [pat; mktup_pat pl vars],
+ (Parse_ast.Generated pl,simple_annot (get_type v))) in
Added_vars (v,pat)
| E_case (e1,ps) ->
- (* after a-normalisation e1 shouldn't need rewriting *)
+ (* after rewrite_defs_letbind_effects e1 needs no rewriting *)
let vars =
- let f acc (Pat_aux (Pat_exp (_,e),_)) = acc @ find_updated_vars e in
+ let f acc (Pat_aux (Pat_exp (_,e),_)) = acc @ find_updated_vars e in
List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t)))
(dedup eqidtyp (List.fold_left f [] ps)) in
if vars = [] then
@@ -1887,9 +1854,6 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
Same_vars (E_aux (E_case (e1,ps),annot))
else
let vartuple = mktup el vars in
-(* let overwrite = match get_type exp with
- | {t = Tid "unit"} -> true
- | _ -> false in*)
let typ =
let (Pat_aux (Pat_exp (_,first),_)) = List.hd ps in
get_type first in
@@ -1897,7 +1861,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let f (acc,typ,effs) (Pat_aux (Pat_exp (p,e),pannot)) =
let etyp = get_type e in
let () = assert (simple_annot etyp = simple_annot typ) in
- let e = rewrite_var_updates (add_vars (*overwrite*) e vartuple) in
+ let e = rewrite_var_updates (add_vars overwrite e vartuple) in
let pannot = (Parse_ast.Generated pl,simple_annot (get_type e)) in
let effs = union_effects effs (get_effsum_exp e) in
let pat' = Pat_aux (Pat_exp (p,e),pannot) in
@@ -1905,10 +1869,9 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
List.fold_left f ([],typ,{effect = Eset []}) ps in
let v = E_aux (E_case (e1,ps), (Parse_ast.Generated pl,simple_annot_efr typ effs)) in
let pat =
-(* if overwrite then
- P_aux (P_tup [mktup_pat vars],(Unknown,simple_annot (get_type v)))
- else*)
- P_aux (P_tup [pat; mktup_pat pl vars],(Parse_ast.Generated pl,simple_annot (get_type v))) in
+ if overwrite then mktup_pat el vars
+ else P_aux (P_tup [pat; mktup_pat pl vars],
+ (Parse_ast.Generated pl,simple_annot (get_type v))) in
Added_vars (v,pat)
| E_assign (lexp,vexp) ->
let {effect = Eset effs} = get_effsum_annot annot in
@@ -1936,8 +1899,8 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let pat = P_aux (P_id id,(Parse_ast.Generated pl,simple_annot (get_type vexp))) in
Added_vars (vexp,pat))
| _ ->
- (* assumes everying's a-normlised: an expression is a sequence of let-expressions,
- * "control-flow" structures and a return value, possibly wrapped in E_return *)
+ (* after rewrite_defs_letbind_effects this expression is pure and updates
+ no variables: check n_exp_term and where it's used. *)
Same_vars (E_aux (expaux,annot)) in
match expaux with
@@ -1958,32 +1921,22 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let lbannot = (Parse_ast.Generated l,simple_annot (get_type v)) in
(get_effsum_exp v,LB_aux (LB_val_implicit (pat,v),lbannot))
| Same_vars v -> (get_effsum_exp v,LB_aux (LB_val_explicit (typ,pat,v),lbannot))) in
- E_aux (E_let (lb,body),
- (Parse_ast.Generated l,simple_annot_efr (get_type body) (union_effects eff (get_effsum_exp body))))
- | E_internal_plet (pat,v,body) ->
- let body = rewrite_var_updates body in
- (match rewrite v pat with
- | Added_vars (v,pat) ->
- E_aux (E_internal_plet (pat,v,body),
- (Parse_ast.Generated l,simple_annot_efr (get_type body) (eff_union_exps [v;body])))
- | Same_vars v -> E_aux (E_internal_plet (pat,v,body),annot))
+ let typ = simple_annot_efr (get_type body) (union_effects eff (get_effsum_exp body)) in
+ E_aux (E_let (lb,body),(Parse_ast.Generated l,typ))
| E_internal_let (lexp,v,body) ->
- (* After a-normalisation E_internal_lets can only bind values to names, those don't
- * need rewriting. *)
- let body = rewrite_var_updates body in
+ (* Rewrite E_internal_let into E_let and call recursively *)
let id = match lexp with
| LEXP_aux (LEXP_id id,_) -> id
| LEXP_aux (LEXP_cast (_,id),_) -> id in
let pat = P_aux (P_id id, (Parse_ast.Generated l,simple_annot (get_type v))) in
let lbannot = (Parse_ast.Generated l,simple_annot_efr (get_type v) (get_effsum_exp v)) in
let lb = LB_aux (LB_val_implicit (pat,v),lbannot) in
- E_aux (E_let (lb,body),(Parse_ast.Generated l,simple_annot_efr (get_type body) (eff_union_exps [v;body])))
- (* In tail-position there shouldn't be anything we need to do as the terms after
- * a-normalisation are pure and don't update local variables. There can't be any variable
- * assignments in tail-position (because of the effect), there could be pure pattern-match
- * expressions, if-expressions that don't need rewriting. For-loops still need rewriting,
- * but it would be pointless to have them in tail-position if they don't access memory or
- * update variables. *)
+ let exp = E_aux (E_let (lb,body),(Parse_ast.Generated l,simple_annot_efr (get_type body) (eff_union_exps [v;body]))) in
+ rewrite_var_updates exp
+ | E_internal_plet (pat,v,body) ->
+ failwith "rewrite_var_updates: E_internal_plet shouldn't be introduced yet"
+ (* There are no expressions that have effects or variable updates in
+ "tail-position": check the definition nexp_term and where it is used. *)
| _ -> exp
let replace_memwrite_e_assign exp =
@@ -2088,6 +2041,7 @@ let rewrite_defs_remove_e_assign =
let rewrite_defs_lem defs =
let defs = rewrite_defs_remove_vector_concat defs in
let defs = rewrite_defs_exp_lift_assign defs in
+ let defs = rewrite_defs_remove_blocks defs in
let defs = rewrite_defs_letbind_effects defs in
let defs = rewrite_defs_remove_e_assign defs in
let defs = rewrite_defs_effectful_let_expressions defs in