summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast_util.ml4
-rw-r--r--src/ast_util.mli1
-rw-r--r--src/rewrites.ml38
3 files changed, 27 insertions, 16 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 7d56d3e6..1c74381f 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -815,6 +815,10 @@ let destruct_range (Typ_aux (typ_aux, _)) =
when string_of_id f = "range" -> Some (n1, n2)
| _ -> None
+let is_unit_typ = function
+ | Typ_aux (Typ_id u, _) -> string_of_id u = "unit"
+ | _ -> false
+
let rec is_number (Typ_aux (t,_)) =
match t with
| Typ_id (Id_aux (Id "int", _))
diff --git a/src/ast_util.mli b/src/ast_util.mli
index 9f815899..d1827685 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -274,6 +274,7 @@ val is_nexp_constant : nexp -> bool
val lexp_to_exp : 'a lexp -> 'a exp
+val is_unit_typ : typ -> bool
val is_number : typ -> bool
val is_reftyp : typ -> bool
val is_vector_typ : typ -> bool
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 462cbb25..50a8ae68 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -2207,7 +2207,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
(* body is a function : E_id variable -> actual body *)
let (E_aux (_,(l,annot))) = v in
match annot with
- | Some (env, Typ_aux (Typ_id tid, _), eff) when string_of_id tid = "unit" ->
+ | Some (env, typ, eff) when is_unit_typ typ ->
let body = body (annot_exp (E_lit (mk_lit L_unit)) l env unit_typ) in
let body_typ = try typ_of body with _ -> unit_typ in
let wild = add_p_typ (typ_of v) (annot_pat P_wild l env (typ_of v)) in
@@ -2639,10 +2639,9 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
effects/update local variables in "tail-position": check n_exp_term
and where it is used. *)
if overwrite then
- match typ_of exp with
- | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> tuple_exp vars
- | _ -> raise (Reporting_basic.err_unreachable l
- "add_vars: trying to overwrite a non-unit expression in tail-position")
+ let lb = LB_aux (LB_val (P_aux (P_wild, annot), exp), annot) in
+ let exp' = tuple_exp vars in
+ E_aux (E_let (lb, exp'), swaptyp (typ_of exp') annot)
else tuple_exp (exp :: vars) in
let mk_var_exps_pats l env ids =
@@ -2654,10 +2653,10 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
exp, P_aux (P_id id, a))
|> List.split in
- let rewrite (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (_,(pl,pannot)) as pat) =
+ let rewrite (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (paux,(pl,pannot)) as pat) =
let env = env_of_annot annot in
- let overwrite = match typ_of full_exp with
- | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> true
+ let overwrite = match paux with
+ | P_wild | P_typ (_, P_aux (P_wild, _)) -> true
| _ -> false in
match expaux with
| E_for(id,exp1,exp2,exp3,order,exp4) ->
@@ -2799,9 +2798,14 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let lb = annot_letbind (paux, v) l env typ in
let exp = propagate_exp_effect (annot_exp (E_let (lb, body)) l env (typ_of body)) in
rewrite_var_updates exp
+ | E_for _ | E_loop _ | E_if _ | E_case _ | E_assign _ ->
+ let var_id = fresh_id "u__" l in
+ let lb = LB_aux (LB_val (P_aux (P_id var_id, annot), exp), annot) in
+ let exp' = E_aux (E_let (lb, E_aux (E_id var_id, annot)), annot) in
+ rewrite_var_updates exp'
| E_internal_plet (pat,v,body) ->
failwith "rewrite_var_updates: E_internal_plet shouldn't be introduced yet"
- (* There are no expressions that have effects or variable updates in
+ (* There are no other expressions that have effects or variable updates in
"tail-position": check the definition nexp_term and where it is used. *)
| _ -> exp
@@ -2844,15 +2848,21 @@ let rewrite_defs_remove_superfluous_letbinds =
(* 'let x = EXP1 in x' can be replaced with 'EXP1' *)
| LB_aux (LB_val (P_aux (P_id id, _), exp1), _),
E_aux (E_id id', _)
+ | LB_aux (LB_val (P_aux (P_typ (_, P_aux (P_id id, _)), _), exp1), _),
+ E_aux (E_id id', _)
| LB_aux (LB_val (P_aux (P_id id, _), exp1), _),
E_aux (E_cast (_,E_aux (E_id id', _)), _)
- when Id.compare id id' == 0 && id_is_unbound id (env_of_annot annot) ->
+ when Id.compare id id' = 0 && id_is_unbound id (env_of_annot annot) ->
exp1
+ (* "let _ = () in exp" can be replaced with exp *)
+ | LB_aux (LB_val (P_aux (P_wild, _), E_aux (E_lit (L_aux (L_unit, _)), _)), _),
+ exp2 ->
+ exp2
(* "let x = EXP1 in return x" can be replaced with 'return (EXP1)', at
least when EXP1 is 'small' enough *)
| LB_aux (LB_val (P_aux (P_id id, _), exp1), _),
E_aux (E_internal_return (E_aux (E_id id', _)), _)
- when Id.compare id id' == 0 && small exp1 && id_is_unbound id (env_of_annot annot) ->
+ when Id.compare id id' = 0 && small exp1 && id_is_unbound id (env_of_annot annot) ->
let (E_aux (_,e1annot)) = exp1 in
E_aux (E_internal_return (exp1),e1annot)
| _ -> E_aux (exp,annot)
@@ -2873,10 +2883,6 @@ let rewrite_defs_remove_superfluous_letbinds =
let rewrite_defs_remove_superfluous_returns =
- let has_unittype e = match typ_of e with
- | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> true
- | _ -> false in
-
let untyp_pat = function
| P_aux (P_typ (typ, pat), _) -> pat, Some typ
| pat -> pat, None in
@@ -2901,7 +2907,7 @@ let rewrite_defs_remove_superfluous_returns =
end
| (P_aux (P_wild,pannot), ptyp),
(E_aux (E_internal_return (E_aux (E_lit (L_aux (L_unit,_)),_)), a), etyp)
- when has_unittype exp1 ->
+ when is_unit_typ (typ_of exp1) ->
begin
match ptyp, etyp with
| Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp1), a)