summaryrefslogtreecommitdiff
path: root/src/rewrites.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewrites.ml')
-rw-r--r--src/rewrites.ml14
1 files changed, 10 insertions, 4 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 86560415..a143175d 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -1674,6 +1674,10 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base
rewriting of early returns
*)
let rewrite_defs_early_return (Defs defs) =
+ let is_unit (E_aux (exp, _)) = match exp with
+ | E_lit (L_aux (L_unit, _)) -> true
+ | _ -> false in
+
let is_return (E_aux (exp, _)) = match exp with
| E_return _ -> true
| _ -> false in
@@ -1693,19 +1697,21 @@ let rewrite_defs_early_return (Defs defs) =
return, fold the rest of the block after the if-expression into the
other branch *)
let fold_if_return exp block = match exp with
- | E_aux (E_if (c, t, e), annot) when is_return t ->
+ | E_aux (E_if (c, t, (E_aux (_, annot) as e)), _) when is_return t ->
let annot = match block with
| [] -> annot
| _ -> let (E_aux (_, annot)) = Util.last block in annot
in
- let e' = E_aux (e_block (e :: block), annot) in
+ let block = if is_unit e then block else e :: block in
+ let e' = E_aux (e_block block, annot) in
[E_aux (e_if (c, t, e'), annot)]
- | E_aux (E_if (c, t, e), annot) when is_return e ->
+ | E_aux (E_if (c, (E_aux (_, annot) as t), e), _) when is_return e ->
let annot = match block with
| [] -> annot
| _ -> let (E_aux (_, annot)) = Util.last block in annot
in
- let t' = E_aux (e_block (t :: block), annot) in
+ let block = if is_unit t then block else t :: block in
+ let t' = E_aux (e_block block, annot) in
[E_aux (e_if (c, t', e), annot)]
| _ -> exp :: block in
let es = List.fold_right fold_if_return es [] in