diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 4 | ||||
| -rw-r--r-- | src/ast_util.mli | 1 | ||||
| -rw-r--r-- | src/rewrites.ml | 38 |
3 files changed, 27 insertions, 16 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 7d56d3e6..1c74381f 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -815,6 +815,10 @@ let destruct_range (Typ_aux (typ_aux, _)) = when string_of_id f = "range" -> Some (n1, n2) | _ -> None +let is_unit_typ = function + | Typ_aux (Typ_id u, _) -> string_of_id u = "unit" + | _ -> false + let rec is_number (Typ_aux (t,_)) = match t with | Typ_id (Id_aux (Id "int", _)) diff --git a/src/ast_util.mli b/src/ast_util.mli index 9f815899..d1827685 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -274,6 +274,7 @@ val is_nexp_constant : nexp -> bool val lexp_to_exp : 'a lexp -> 'a exp +val is_unit_typ : typ -> bool val is_number : typ -> bool val is_reftyp : typ -> bool val is_vector_typ : typ -> bool diff --git a/src/rewrites.ml b/src/rewrites.ml index 462cbb25..50a8ae68 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2207,7 +2207,7 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = (* body is a function : E_id variable -> actual body *) let (E_aux (_,(l,annot))) = v in match annot with - | Some (env, Typ_aux (Typ_id tid, _), eff) when string_of_id tid = "unit" -> + | Some (env, typ, eff) when is_unit_typ typ -> let body = body (annot_exp (E_lit (mk_lit L_unit)) l env unit_typ) in let body_typ = try typ_of body with _ -> unit_typ in let wild = add_p_typ (typ_of v) (annot_pat P_wild l env (typ_of v)) in @@ -2639,10 +2639,9 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = effects/update local variables in "tail-position": check n_exp_term and where it is used. *) if overwrite then - match typ_of exp with - | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> tuple_exp vars - | _ -> raise (Reporting_basic.err_unreachable l - "add_vars: trying to overwrite a non-unit expression in tail-position") + let lb = LB_aux (LB_val (P_aux (P_wild, annot), exp), annot) in + let exp' = tuple_exp vars in + E_aux (E_let (lb, exp'), swaptyp (typ_of exp') annot) else tuple_exp (exp :: vars) in let mk_var_exps_pats l env ids = @@ -2654,10 +2653,10 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = exp, P_aux (P_id id, a)) |> List.split in - let rewrite (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (_,(pl,pannot)) as pat) = + let rewrite (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (paux,(pl,pannot)) as pat) = let env = env_of_annot annot in - let overwrite = match typ_of full_exp with - | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> true + let overwrite = match paux with + | P_wild | P_typ (_, P_aux (P_wild, _)) -> true | _ -> false in match expaux with | E_for(id,exp1,exp2,exp3,order,exp4) -> @@ -2799,9 +2798,14 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let lb = annot_letbind (paux, v) l env typ in let exp = propagate_exp_effect (annot_exp (E_let (lb, body)) l env (typ_of body)) in rewrite_var_updates exp + | E_for _ | E_loop _ | E_if _ | E_case _ | E_assign _ -> + let var_id = fresh_id "u__" l in + let lb = LB_aux (LB_val (P_aux (P_id var_id, annot), exp), annot) in + let exp' = E_aux (E_let (lb, E_aux (E_id var_id, annot)), annot) in + rewrite_var_updates exp' | E_internal_plet (pat,v,body) -> failwith "rewrite_var_updates: E_internal_plet shouldn't be introduced yet" - (* There are no expressions that have effects or variable updates in + (* There are no other expressions that have effects or variable updates in "tail-position": check the definition nexp_term and where it is used. *) | _ -> exp @@ -2844,15 +2848,21 @@ let rewrite_defs_remove_superfluous_letbinds = (* 'let x = EXP1 in x' can be replaced with 'EXP1' *) | LB_aux (LB_val (P_aux (P_id id, _), exp1), _), E_aux (E_id id', _) + | LB_aux (LB_val (P_aux (P_typ (_, P_aux (P_id id, _)), _), exp1), _), + E_aux (E_id id', _) | LB_aux (LB_val (P_aux (P_id id, _), exp1), _), E_aux (E_cast (_,E_aux (E_id id', _)), _) - when Id.compare id id' == 0 && id_is_unbound id (env_of_annot annot) -> + when Id.compare id id' = 0 && id_is_unbound id (env_of_annot annot) -> exp1 + (* "let _ = () in exp" can be replaced with exp *) + | LB_aux (LB_val (P_aux (P_wild, _), E_aux (E_lit (L_aux (L_unit, _)), _)), _), + exp2 -> + exp2 (* "let x = EXP1 in return x" can be replaced with 'return (EXP1)', at least when EXP1 is 'small' enough *) | LB_aux (LB_val (P_aux (P_id id, _), exp1), _), E_aux (E_internal_return (E_aux (E_id id', _)), _) - when Id.compare id id' == 0 && small exp1 && id_is_unbound id (env_of_annot annot) -> + when Id.compare id id' = 0 && small exp1 && id_is_unbound id (env_of_annot annot) -> let (E_aux (_,e1annot)) = exp1 in E_aux (E_internal_return (exp1),e1annot) | _ -> E_aux (exp,annot) @@ -2873,10 +2883,6 @@ let rewrite_defs_remove_superfluous_letbinds = let rewrite_defs_remove_superfluous_returns = - let has_unittype e = match typ_of e with - | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> true - | _ -> false in - let untyp_pat = function | P_aux (P_typ (typ, pat), _) -> pat, Some typ | pat -> pat, None in @@ -2901,7 +2907,7 @@ let rewrite_defs_remove_superfluous_returns = end | (P_aux (P_wild,pannot), ptyp), (E_aux (E_internal_return (E_aux (E_lit (L_aux (L_unit,_)),_)), a), etyp) - when has_unittype exp1 -> + when is_unit_typ (typ_of exp1) -> begin match ptyp, etyp with | Some typ, _ | _, Some typ -> E_aux (E_cast (typ, exp1), a) |
