summaryrefslogtreecommitdiff
path: root/src/rewrites.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewrites.ml')
-rw-r--r--src/rewrites.ml243
1 files changed, 158 insertions, 85 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 48ea78ae..35659bb4 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -385,7 +385,7 @@ let remove_vector_concat_pat pat =
let id_pat =
match typ_opt with
- | Some typ -> add_p_typ typ (P_aux (P_id child,cannot))
+ | Some typ -> add_p_typ env typ (P_aux (P_id child,cannot))
| None -> P_aux (P_id child,cannot) in
let letbind = fix_eff_lb (LB_aux (LB_val (id_pat,subv),cannot)) in
(letbind,
@@ -716,6 +716,26 @@ let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) =
| _, P_wild -> if is_irrefutable_pattern pat1 then Some [] else None
| _ -> None
+let vector_string_to_bits_pat (L_aux (lit, _) as l_aux) (l, tannot) =
+ let bit_annot = match destruct_tannot tannot with
+ | Some (env, _, _) -> mk_tannot env bit_typ no_effect
+ | None -> empty_tannot
+ in
+ begin match lit with
+ | L_hex _ | L_bin _ -> P_aux (P_vector (List.map (fun p -> P_aux (P_lit p, (l, bit_annot))) (vector_string_to_bit_list l_aux)), (l, tannot))
+ | lit -> P_aux (P_lit l_aux, (l, tannot))
+ end
+
+let vector_string_to_bits_exp (L_aux (lit, _) as l_aux) (l, tannot) =
+ let bit_annot = match destruct_tannot tannot with
+ | Some (env, _, _) -> mk_tannot env bit_typ no_effect
+ | None -> empty_tannot
+ in
+ begin match lit with
+ | L_hex _ | L_bin _ -> E_aux (E_vector (List.map (fun p -> E_aux (E_lit p, (l, bit_annot))) (vector_string_to_bit_list l_aux)), (l, tannot))
+ | lit -> E_aux (E_lit l_aux, (l, tannot))
+ end
+
(* A simple check for pattern disjointness; used for optimisation in the
guarded pattern rewrite step *)
let rec disjoint_pat env (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) =
@@ -729,6 +749,14 @@ let rec disjoint_pat env (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2)
| P_id id, _ when id_is_unbound id env -> false
| _, P_id id when id_is_unbound id env -> false
| P_id id1, P_id id2 -> Id.compare id1 id2 <> 0
+ | P_lit (L_aux ((L_bin _ | L_hex _), _) as lit), _ ->
+ disjoint_pat env (vector_string_to_bits_pat lit (Unknown, empty_tannot)) pat2
+ | _, P_lit (L_aux ((L_bin _ | L_hex _), _) as lit) ->
+ disjoint_pat env pat1 (vector_string_to_bits_pat lit (Unknown, empty_tannot))
+ | P_lit (L_aux (L_num n1, _)), P_lit (L_aux (L_num n2, _)) ->
+ not (Big_int.equal n1 n2)
+ | P_lit (L_aux (l1, _)), P_lit (L_aux (l2, _)) ->
+ l1 <> l2
| P_app (id1, args1), P_app (id2, args2) ->
Id.compare id1 id2 <> 0 || List.exists2 (disjoint_pat env) args1 args2
| P_vector pats1, P_vector pats2
@@ -1221,21 +1249,13 @@ let rewrite_defs_vector_string_pats_to_bit_list env =
let rewrite_p_aux (pat, (annot : tannot annot)) =
let env = env_of_annot annot in
match pat with
- | P_lit (L_aux (lit, l) as l_aux) ->
- begin match lit with
- | L_hex _ | L_bin _ -> P_aux (P_vector (List.map (fun p -> P_aux (P_lit p, (l, mk_tannot env bit_typ no_effect))) (vector_string_to_bit_list l_aux)), annot)
- | lit -> P_aux (P_lit l_aux, annot)
- end
+ | P_lit lit -> vector_string_to_bits_pat lit annot
| pat -> (P_aux (pat, annot))
in
let rewrite_e_aux (exp, (annot : tannot annot)) =
let env = env_of_annot annot in
match exp with
- | E_lit (L_aux (lit, l) as l_aux) ->
- begin match lit with
- | L_hex _ | L_bin _ -> E_aux (E_vector (List.map (fun e -> E_aux (E_lit e, (l, mk_tannot env bit_typ no_effect))) (vector_string_to_bit_list l_aux)), annot)
- | lit -> E_aux (E_lit l_aux, annot)
- end
+ | E_lit lit -> vector_string_to_bits_exp lit annot
| exp -> (E_aux (exp, annot))
in
let pat_alg = { id_pat_alg with p_aux = rewrite_p_aux } in
@@ -1556,7 +1576,7 @@ let rewrite_defs_early_return env (Defs defs) =
let eff = effect_of_annot tannot in
let tannot' = mk_tannot env typ (union_effects eff (mk_effect [BE_escape])) in
let exp' = match Env.get_ret_typ env with
- | Some typ -> add_e_cast typ exp
+ | Some typ -> add_e_cast env typ exp
| None -> exp in
E_aux (E_app (mk_id "early_return", [exp']), (l, tannot'))
| _ -> full_exp in
@@ -2067,35 +2087,36 @@ let rewrite_vector_concat_assignments env defs =
match nexp_simp len with
| Nexp_aux (Nexp_constant len, _) -> len
| _ -> (Big_int.of_int 1)
- else (Big_int.of_int 1) in
+ else (Big_int.of_int 1)
+ in
let next i step =
if is_order_inc ord
then (Big_int.sub (Big_int.add i step) (Big_int.of_int 1), Big_int.add i step)
- else (Big_int.add (Big_int.sub i step) (Big_int.of_int 1), Big_int.sub i step) in
+ else (Big_int.add (Big_int.sub i step) (Big_int.of_int 1), Big_int.sub i step)
+ in
let i = match nexp_simp start with
- | (Nexp_aux (Nexp_constant i, _)) -> i
- | _ -> if is_order_inc ord then Big_int.zero else Big_int.of_int (List.length lexps - 1) in
+ | (Nexp_aux (Nexp_constant i, _)) -> i
+ | _ -> if is_order_inc ord then Big_int.zero else Big_int.of_int (List.length lexps - 1)
+ in
let l = gen_loc (fst annot) in
- let exp' =
- if small exp then strip_exp exp
- else mk_exp (E_id (mk_id "split_vec")) in
+ let vec_id = mk_id "split_vec" in
+ let exp' = if small exp then strip_exp exp else mk_exp (E_id vec_id) in
let lexp_to_exp (i, exps) lexp =
let (j, i') = next i (len lexp) in
let i_exp = mk_exp (E_lit (mk_lit (L_num i))) in
let j_exp = mk_exp (E_lit (mk_lit (L_num j))) in
let sub = mk_exp (E_vector_subrange (exp', i_exp, j_exp)) in
- (i', exps @ [sub]) in
+ (i', exps @ [sub])
+ in
let (_, exps) = List.fold_left lexp_to_exp (i, []) lexps in
- let tup = mk_exp (E_tuple exps) in
- let lexp = LEXP_aux (LEXP_tup (List.map strip_lexp lexps), (l, ())) in
- let e_aux =
- if small exp then mk_exp (E_assign (lexp, tup))
- else mk_exp (
- E_let (
- mk_letbind (mk_pat (P_id (mk_id "split_vec"))) (strip_exp exp),
- mk_exp (E_assign (lexp, tup)))) in
+ let assign lexp exp = mk_exp (E_assign (strip_lexp lexp, exp)) in
+ let block = mk_exp (E_block (List.map2 assign lexps exps)) in
+ let full_exp =
+ if small exp then block else
+ mk_exp (E_let (mk_letbind (mk_pat (P_id vec_id)) (strip_exp exp), block))
+ in
begin
- try check_exp env e_aux unit_typ with
+ try check_exp env full_exp unit_typ with
| Type_error (_, l, err) ->
raise (Reporting.err_typ l (Type_error.string_of_type_error err))
end
@@ -2114,14 +2135,12 @@ let rewrite_tuple_assignments env defs =
let env = env_of_annot annot in
match e_aux with
| E_assign (LEXP_aux (LEXP_tup lexps, _), exp) ->
- (* let _ = Pretty_print_common.print stderr (Pretty_print_sail.doc_exp (E_aux (e_aux, annot))) in *)
let (_, ids) = List.fold_left (fun (n, ids) _ -> (n + 1, ids @ [mk_id ("tup__" ^ string_of_int n)])) (0, []) lexps in
let block_assign i lexp = mk_exp (E_assign (strip_lexp lexp, mk_exp (E_id (mk_id ("tup__" ^ string_of_int i))))) in
let block = mk_exp (E_block (List.mapi block_assign lexps)) in
- let letbind = mk_letbind (mk_pat (P_tup (List.map (fun id -> mk_pat (P_id id)) ids)))
- (strip_exp exp)
- in
- let let_exp = mk_exp (E_let (letbind, block)) in
+ let pat = mk_pat (P_tup (List.map (fun id -> mk_pat (P_id id)) ids)) in
+ let exp' = add_e_cast env (typ_of exp) exp in
+ let let_exp = mk_exp (E_let (mk_letbind pat (strip_exp exp'), block)) in
begin
try check_exp env let_exp unit_typ with
| Type_error (_, l, err) ->
@@ -2162,9 +2181,11 @@ let rewrite_defs_remove_blocks env =
let l = get_loc_exp v in
let env = env_of v in
let typ = typ_of v in
- let wild = add_p_typ typ (annot_pat P_wild l env typ) in
+ let wild = annot_pat P_wild l env typ in
let e_aux = E_let (annot_letbind (unaux_pat wild, v) l env typ, body) in
- fix_eff_exp (annot_exp e_aux l env (typ_of body)) in
+ fix_eff_exp (annot_exp e_aux l env (typ_of body))
+ |> add_typs_let env typ (typ_of body)
+ in
let rec f l = function
| [] -> E_aux (E_lit (L_aux (L_unit,gen_loc l)), (simple_annot l unit_typ))
@@ -2195,19 +2216,20 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
| Some (env, typ, eff) when is_unit_typ typ ->
let body = body (annot_exp (E_lit (mk_lit L_unit)) l env unit_typ) in
let body_typ = try typ_of body with _ -> unit_typ in
- let wild = add_p_typ (typ_of v) (annot_pat P_wild l env (typ_of v)) in
+ let wild = annot_pat P_wild l env typ in
let lb = fix_eff_lb (annot_letbind (unaux_pat wild, v) l env unit_typ) in
fix_eff_exp (annot_exp (E_let (lb, body)) l env body_typ)
+ |> add_typs_let env typ body_typ
| Some (env, typ, eff) ->
let id = fresh_id "w__" l in
- let pat = add_p_typ (typ_of v) (annot_pat (P_id id) l env (typ_of v)) in
+ let pat = annot_pat (P_id id) l env typ in
let lb = fix_eff_lb (annot_letbind (unaux_pat pat, v) l env typ) in
let body = body (annot_exp (E_id id) l env typ) in
fix_eff_exp (annot_exp (E_let (lb, body)) l env (typ_of body))
+ |> add_typs_let env typ (typ_of body)
| None ->
raise (Reporting.err_unreachable l __POS__ "no type information")
-
let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list -> 'a exp) : 'a exp =
match l with
| [] -> k []
@@ -2244,9 +2266,9 @@ let rewrite_defs_letbind_effects env =
and n_pexp : 'b. bool -> 'a pexp -> ('a pexp -> 'b) -> 'b = fun newreturn pexp k ->
match pexp with
| Pat_aux (Pat_exp (pat,exp),annot) ->
- k (fix_eff_pexp (Pat_aux (Pat_exp (pat,n_exp_term newreturn exp), annot)))
+ k (fix_eff_pexp (Pat_aux (Pat_exp (pat, n_exp_term newreturn exp), annot)))
| Pat_aux (Pat_when (pat,guard,exp),annot) ->
- k (fix_eff_pexp (Pat_aux (Pat_when (pat,n_exp_term newreturn guard,n_exp_term newreturn exp), annot)))
+ k (fix_eff_pexp (Pat_aux (Pat_when (pat, n_exp_term newreturn guard, n_exp_term newreturn exp), annot)))
and n_pexpL (newreturn : bool) (pexps : 'a pexp list) (k : 'a pexp list -> 'a exp) : 'a exp =
mapCont (n_pexp newreturn) pexps k
@@ -2300,15 +2322,15 @@ let rewrite_defs_letbind_effects env =
and n_lexpL (lexps : 'a lexp list) (k : 'a lexp list -> 'a exp) : 'a exp =
mapCont n_lexp lexps k
- and n_exp_term (newreturn : bool) (exp : 'a exp) : 'a exp =
+ and n_exp_term ?cast:(cast=false) (newreturn : bool) (exp : 'a exp) : 'a exp =
let (E_aux (_,(l,tannot))) = exp in
let exp =
if newreturn then
(* let typ = try typ_of exp with _ -> unit_typ in *)
- let exp = add_e_cast (typ_of exp) exp in
+ let exp = if cast then add_e_cast (env_of exp) (typ_of exp) exp else exp in
annot_exp (E_internal_return exp) l (env_of exp) (typ_of exp)
- else
- exp in
+ else exp
+ in
(* n_exp_term forces an expression to be translated into a form
"let .. let .. let .. in EXP" where EXP has no effect and does not update
variables *)
@@ -2331,8 +2353,8 @@ let rewrite_defs_letbind_effects env =
(* Leave effectful operands of Boolean "and"/"or" in place to allow
short-circuiting. *)
let newreturn = effectful l || effectful r in
- let l = n_exp_term newreturn l in
- let r = n_exp_term newreturn r in
+ let l = n_exp_term ~cast:true newreturn l in
+ let r = n_exp_term ~cast:true newreturn r in
k (rewrap (E_app (op_bool, [l; r])))
| E_app (id,exps) ->
n_exp_nameL exps (fun exps ->
@@ -2351,7 +2373,8 @@ let rewrite_defs_letbind_effects env =
let newreturn = effectful exp2 || effectful exp3 in
let exp2 = n_exp_term newreturn exp2 in
let exp3 = n_exp_term newreturn exp3 in
- k (rewrap (E_if (exp1,exp2,exp3))) in
+ k (rewrap (E_if (exp1,exp2,exp3)))
+ in
if value exp1 then e_if (n_exp_term false exp1) else n_exp_name exp1 e_if
| E_for (id,start,stop,by,dir,body) ->
n_exp_name start (fun start ->
@@ -2365,7 +2388,7 @@ let rewrite_defs_letbind_effects env =
| Measure_aux (Measure_some exp,l) ->
Measure_aux (Measure_some (n_exp_term false exp),l)
in
- let cond = n_exp_term (effectful cond) cond in
+ let cond = n_exp_term ~cast:true (effectful cond) cond in
let body = n_exp_term (effectful body) body in
k (rewrap (E_loop (loop,measure,cond,body)))
| E_vector exps ->
@@ -2437,7 +2460,7 @@ let rewrite_defs_letbind_effects env =
| E_assert (exp1,exp2) ->
n_exp_name exp1 (fun exp1 ->
n_exp_name exp2 (fun exp2 ->
- k (rewrap (E_assert (exp1,exp2)))))
+ k (rewrap (E_assert (exp1, exp2)))))
| E_var (lexp,exp1,exp2) ->
n_lexp lexp (fun lexp ->
n_exp exp1 (fun exp1 ->
@@ -2505,7 +2528,7 @@ let rewrite_defs_internal_lets env =
let rec pat_of_local_lexp (LEXP_aux (lexp, ((l, _) as annot))) = match lexp with
| LEXP_id id -> P_aux (P_id id, annot)
- | LEXP_cast (typ, id) -> add_p_typ typ (P_aux (P_id id, annot))
+ | LEXP_cast (typ, id) -> add_p_typ (env_of_annot annot) typ (P_aux (P_id id, annot))
| LEXP_tup lexps -> P_aux (P_tup (List.map pat_of_local_lexp lexps), annot)
| _ -> raise (Reporting.err_unreachable l __POS__ "unexpected local lexp") in
@@ -2524,7 +2547,7 @@ let rewrite_defs_internal_lets env =
| LEXP_aux (_,lexp_annot') -> lexp_annot'
| exception _ -> lannot)
in
- let rhs = add_e_cast ltyp (rhs exp) in
+ let rhs = add_e_cast (env_of exp) ltyp (rhs exp) in
E_let (LB_aux (LB_val (pat_of_local_lexp lhs, rhs), annot), body)
| LB_aux (LB_val (pat,exp'),annot') ->
if effectful exp'
@@ -2536,7 +2559,7 @@ let rewrite_defs_internal_lets env =
| LEXP_aux (LEXP_id id, annot) ->
(P_id id, annot)
| LEXP_aux (LEXP_cast (typ, id), annot) ->
- (unaux_pat (add_p_typ typ (P_aux (P_id id, annot))), annot)
+ (unaux_pat (add_p_typ (env_of_annot annot) typ (P_aux (P_id id, annot))), annot)
| _ -> failwith "E_var with unexpected lexp" in
if effectful exp1 then
E_internal_plet (P_aux (paux, annot), exp1, exp2)
@@ -2566,7 +2589,7 @@ let fold_typed_guards env guards =
| g :: gs -> List.fold_left (fun g g' -> annot_exp (E_app (mk_id "and_bool", [g; g'])) Parse_ast.Unknown env bool_typ) g gs
-let rewrite_pexp_with_guards rewrite_pat (Pat_aux (pexp_aux, (l, _)) as pexp) =
+let rewrite_pexp_with_guards rewrite_pat (Pat_aux (pexp_aux, (annot: tannot annot)) as pexp) =
let guards = ref [] in
match pexp_aux with
@@ -2576,13 +2599,13 @@ let rewrite_pexp_with_guards rewrite_pat (Pat_aux (pexp_aux, (l, _)) as pexp) =
match !guards with
| [] -> pexp
| gs ->
- let unchecked_pexp = mk_pexp ~loc:l (Pat_when (strip_pat pat, List.map strip_exp gs |> fold_guards, strip_exp exp)) in
+ let unchecked_pexp = mk_pexp (Pat_when (strip_pat pat, List.map strip_exp gs |> fold_guards, strip_exp exp)) in
check_case (env_of_pat pat) (typ_of_pat pat) unchecked_pexp (typ_of exp)
end
| Pat_when (pat, guard, exp) ->
begin
let pat = fold_pat { id_pat_alg with p_aux = rewrite_pat guards } pat in
- let unchecked_pexp = mk_pexp ~loc:l (Pat_when (strip_pat pat, List.map strip_exp !guards |> fold_guards, strip_exp exp)) in
+ let unchecked_pexp = mk_pexp (Pat_when (strip_pat pat, List.map strip_exp !guards |> fold_guards, strip_exp exp)) in
check_case (env_of_pat pat) (typ_of_pat pat) unchecked_pexp (typ_of exp)
end
@@ -3334,22 +3357,24 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let tuple_exp = function
| [] -> annot_exp (E_lit (mk_lit L_unit)) l env unit_typ
| [e] -> e
- | es -> annot_exp (E_tuple es) l env (tuple_typ (List.map typ_of es)) in
+ | es -> annot_exp (E_tuple es) l env (tuple_typ (List.map typ_of es))
+ in
let tuple_pat = function
| [] -> annot_pat P_wild l env unit_typ
| [pat] ->
let typ = typ_of_pat pat in
- add_p_typ typ pat
+ add_p_typ env typ pat
| pats ->
let typ = tuple_typ (List.map typ_of_pat pats) in
- add_p_typ typ (annot_pat (P_tup pats) l env typ) in
+ add_p_typ env typ (annot_pat (P_tup pats) l env typ)
+ in
let rec add_vars overwrite ((E_aux (expaux,annot)) as exp) vars =
match expaux with
| E_let (lb,exp) ->
let exp = add_vars overwrite exp vars in
- E_aux (E_let (lb,exp),swaptyp (typ_of exp) annot)
+ E_aux (E_let (lb,exp), swaptyp (typ_of exp) annot)
| E_var (lexp,exp1,exp2) ->
let exp2 = add_vars overwrite exp2 vars in
E_aux (E_var (lexp,exp1,exp2), swaptyp (typ_of exp2) annot)
@@ -3358,7 +3383,11 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
E_aux (E_internal_plet (pat,exp1,exp2), swaptyp (typ_of exp2) annot)
| E_internal_return exp2 ->
let exp2 = add_vars overwrite exp2 vars in
- E_aux (E_internal_return exp2,swaptyp (typ_of exp2) annot)
+ E_aux (E_internal_return exp2, swaptyp (typ_of exp2) annot)
+ | E_cast (typ, exp) ->
+ let (E_aux (expaux, annot) as exp) = add_vars overwrite exp vars in
+ let typ' = typ_of exp in
+ add_e_cast (env_of exp) typ' (E_aux (expaux, swaptyp typ' annot))
| _ ->
(* after rewrite_defs_letbind_effects there cannot be terms that have
effects/update local variables in "tail-position": check n_exp_term
@@ -3367,6 +3396,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let lb = LB_aux (LB_val (P_aux (P_wild, annot), exp), annot) in
let exp' = tuple_exp vars in
E_aux (E_let (lb, exp'), swaptyp (typ_of exp') annot)
+ |> add_typs_let env (typ_of exp) (typ_of exp')
else tuple_exp (exp :: vars) in
let mk_var_exps_pats l env ids =
@@ -3378,7 +3408,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
exp, P_aux (P_id id, a))
|> List.split in
- let rewrite used_vars (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (paux,(pl,pannot)) as pat) =
+ let rec rewrite used_vars (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (paux,(pl,pannot)) as pat) =
let env = env_of_annot annot in
let overwrite = match paux with
| P_wild | P_typ (_, P_aux (P_wild, _)) -> true
@@ -3417,11 +3447,15 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
in
let lvar_nc = nc_and (nc_lteq (nvar lower_kid) (nvar lvar_kid)) (nc_lteq (nvar lvar_kid) (nvar upper_kid)) in
let lvar_typ = mk_typ (Typ_exist (List.map (mk_kopt K_int) [lvar_kid], lvar_nc, atom_typ (nvar lvar_kid))) in
- let lvar_pat = unaux_pat (add_p_typ lvar_typ (annot_pat (P_var (
+ let lvar_pat = unaux_pat (annot_pat (P_var (
annot_pat (P_id id) el env (atom_typ (nvar lvar_kid)),
- TP_aux (TP_var lvar_kid, gen_loc el))) el env lvar_typ)) in
+ TP_aux (TP_var lvar_kid, gen_loc el))) el env lvar_typ)
+ in
let lb = fix_eff_lb (annot_letbind (lvar_pat, exp1) el env lvar_typ) in
- let body = fix_eff_exp (annot_exp (E_let (lb, exp4)) el env (typ_of exp4)) in
+ let body =
+ fix_eff_exp (annot_exp (E_let (lb, exp4)) el env (typ_of exp4))
+ |> add_typs_let env lvar_typ (typ_of exp4)
+ in
(* If lower > upper, the loop body never gets executed, and the type
checker might not be able to prove that the initial value exp1
satisfies the constraints on the loop variable.
@@ -3521,7 +3555,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
| E_assign (lexp,vexp) ->
let mk_id_pat id =
let typ = lvar_typ (Env.lookup_id id env) in
- add_p_typ typ (annot_pat (P_id id) pl env typ)
+ add_p_typ env typ (annot_pat (P_id id) pl env typ)
in
if effectful exp then
Same_vars (E_aux (E_assign (lexp,vexp),annot))
@@ -3531,7 +3565,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let pat = annot_pat (P_id id) pl env (typ_of vexp) in
Added_vars (vexp, mk_id_pat id)
| LEXP_aux (LEXP_cast (typ,id),annot) ->
- let pat = add_p_typ typ (annot_pat (P_id id) pl env (typ_of vexp)) in
+ let pat = add_p_typ env typ (annot_pat (P_id id) pl env (typ_of vexp)) in
Added_vars (vexp,pat)
| LEXP_aux (LEXP_vector (LEXP_aux (LEXP_id id,((l2,_) as annot2)),i),((l1,_) as annot)) ->
let eid = annot_exp (E_id id) l2 env (typ_of_annot annot2) in
@@ -3545,6 +3579,13 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let pat = annot_pat (P_id id) pl env (typ_of vexp) in
Added_vars (vexp,pat)
| _ -> Same_vars (E_aux (E_assign (lexp,vexp),annot)))
+ | E_cast (typ, exp) ->
+ begin match rewrite used_vars exp pat with
+ | Added_vars (exp', pat') ->
+ Added_vars (add_e_cast (env_of exp') (typ_of exp') exp', pat')
+ | Same_vars (exp') ->
+ Same_vars (E_aux (E_cast (typ, exp'), annot))
+ end
| _ ->
(* after rewrite_defs_letbind_effects this expression is pure and updates
no variables: check n_exp_term and where it's used. *)
@@ -3565,7 +3606,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
| LEXP_aux (LEXP_id id, _) ->
P_id id, typ_of v
| LEXP_aux (LEXP_cast (typ, id), _) ->
- unaux_pat (add_p_typ typ (annot_pat (P_id id) l env (typ_of v))), typ
+ unaux_pat (add_p_typ env typ (annot_pat (P_id id) l env (typ_of v))), typ
| _ ->
raise (Reporting.err_unreachable l __POS__
"E_var with a lexp that is not a variable") in
@@ -3579,6 +3620,9 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
rewrite_var_updates exp'
| E_internal_plet (pat,v,body) ->
failwith "rewrite_var_updates: E_internal_plet shouldn't be introduced yet"
+ | E_cast (typ, exp) ->
+ let exp' = rewrite_var_updates exp in
+ E_aux (E_cast (typ, exp'), annot)
(* There are no other expressions that have effects or variable updates in
"tail-position": check the definition nexp_term and where it is used. *)
| _ -> exp
@@ -3714,7 +3758,7 @@ let rewrite_defs_remove_superfluous_returns env =
let add_opt_cast typopt1 typopt2 annot exp =
match typopt1, typopt2 with
- | Some typ, _ | _, Some typ -> add_e_cast typ exp
+ | Some typ, _ | _, Some typ -> add_e_cast (env_of exp) typ exp
| None, None -> exp
in
@@ -3765,11 +3809,11 @@ let rewrite_defs_remove_superfluous_returns env =
let rewrite_defs_remove_e_assign env (Defs defs) =
let (Defs loop_specs) = fst (Type_error.check initial_env (Defs (List.map gen_vs
- [("foreach#", "forall ('vars : Type). (int, int, int, bool, 'vars, 'vars) -> 'vars");
- ("while#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars");
- ("until#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars");
- ("while#t", "forall ('vars : Type). (bool, 'vars, 'vars, int) -> 'vars");
- ("until#t", "forall ('vars : Type). (bool, 'vars, 'vars, int) -> 'vars")]))) in
+ [("foreach#", "forall ('vars_in 'vars_out : Type). (int, int, int, bool, 'vars_in, 'vars_out) -> 'vars_out");
+ ("while#", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out) -> 'vars_out");
+ ("until#", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out) -> 'vars_out");
+ ("while#t", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out, int) -> 'vars_out");
+ ("until#t", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out, int) -> 'vars_out")]))) in
let rewrite_exp _ e =
replace_memwrite_e_assign (remove_reference_types (rewrite_var_updates e)) in
rewrite_defs_base
@@ -4752,6 +4796,36 @@ let rec move_loop_measures (Defs defs) =
in Defs (List.rev rev_defs)
+let rewrite_toplevel_consts target type_env (Defs defs) =
+ let istate = Constant_fold.initial_state (Defs defs) type_env in
+ let subst consts exp =
+ let open Rewriter in
+ let used_ids = fold_exp { (pure_exp_alg IdSet.empty IdSet.union) with e_id = IdSet.singleton } exp in
+ let subst_ids = IdSet.filter (fun id -> Bindings.mem id consts) used_ids in
+ IdSet.fold (fun id -> subst id (Bindings.find id consts)) subst_ids exp
+ in
+ let rewrite_def (revdefs, consts) = function
+ | DEF_val (LB_aux (LB_val (pat, exp), a) as lb) ->
+ begin match unaux_pat pat with
+ | P_id id | P_typ (_, P_aux (P_id id, _)) ->
+ let exp' = Constant_fold.rewrite_exp_once target istate (subst consts exp) in
+ if Constant_fold.is_constant exp' then
+ try
+ let exp' = infer_exp (env_of exp') (strip_exp exp') in
+ let pannot = (pat_loc pat, mk_tannot (env_of_pat pat) (typ_of exp') no_effect) in
+ let pat' = P_aux (P_typ (typ_of exp', P_aux (P_id id, pannot)), pannot) in
+ let consts' = Bindings.add id exp' consts in
+ (DEF_val (LB_aux (LB_val (pat', exp'), a)) :: revdefs, consts')
+ with
+ | _ -> (DEF_val lb :: revdefs, consts)
+ else (DEF_val lb :: revdefs, consts)
+ | _ -> (DEF_val lb :: revdefs, consts)
+ end
+ | def -> (def :: revdefs, consts)
+ in
+ let (revdefs, _) = List.fold_left rewrite_def ([], Bindings.empty) defs in
+ Defs (List.rev revdefs)
+
let opt_mono_rewrites = ref false
let opt_mono_complex_nexps = ref true
@@ -4851,10 +4925,10 @@ let all_rewrites = [
("mapping_builtins", Basic_rewriter rewrite_defs_mapping_patterns);
("mono_rewrites", Basic_rewriter mono_rewrites);
("toplevel_nexps", Basic_rewriter rewrite_toplevel_nexps);
+ ("toplevel_consts", String_rewriter (fun target -> Basic_rewriter (rewrite_toplevel_consts target)));
("monomorphise", String_rewriter (fun target -> Basic_rewriter (monomorphise target)));
- ("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons));
- ("add_bitvector_casts", Basic_rewriter (fun _ -> Monomorphise.add_bitvector_casts));
- ("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons));
+ ("atoms_to_singletons", String_rewriter (fun target -> (Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons target))));
+ ("add_bitvector_casts", Basic_rewriter Monomorphise.add_bitvector_casts);
("remove_impossible_int_cases", Basic_rewriter Constant_propagation.remove_impossible_int_cases);
("const_prop_mutrec", String_rewriter (fun target -> Basic_rewriter (Constant_propagation_mutrec.rewrite_defs target)));
("make_cases_exhaustive", Basic_rewriter MakeExhaustive.rewrite);
@@ -4902,17 +4976,18 @@ let rewrites_lem = [
("mono_rewrites", []);
("recheck_defs", [If_mono_arg]);
("undefined", [Bool_arg false]);
+ ("toplevel_consts", [String_arg "lem"; If_mwords_arg]);
("toplevel_nexps", [If_mono_arg]);
("monomorphise", [String_arg "lem"; If_mono_arg]);
("recheck_defs", [If_mwords_arg]);
("add_bitvector_casts", [If_mwords_arg]);
- ("atoms_to_singletons", [If_mono_arg]);
+ ("atoms_to_singletons", [String_arg "lem"; If_mono_arg]);
("recheck_defs", [If_mwords_arg]);
("vector_string_pats_to_bit_list", []);
("remove_not_pats", []);
("remove_impossible_int_cases", []);
- ("vector_concat_assignments", []);
("tuple_assignments", []);
+ ("vector_concat_assignments", []);
("simple_assignments", []);
("remove_vector_concat", []);
("remove_bitvector_pats", []);
@@ -4953,8 +5028,8 @@ let rewrites_coq = [
("vector_string_pats_to_bit_list", []);
("remove_not_pats", []);
("remove_impossible_int_cases", []);
- ("vector_concat_assignments", []);
("tuple_assignments", []);
+ ("vector_concat_assignments", []);
("simple_assignments", []);
("remove_vector_concat", []);
("remove_bitvector_pats", []);
@@ -5002,8 +5077,8 @@ let rewrites_ocaml = [
("mapping_builtins", []);
("undefined", [Bool_arg false]);
("vector_string_pats_to_bit_list", []);
- ("vector_concat_assignments", []);
("tuple_assignments", []);
+ ("vector_concat_assignments", []);
("simple_assignments", []);
("remove_not_pats", []);
("remove_vector_concat", []);
@@ -5026,7 +5101,7 @@ let rewrites_c = [
("recheck_defs", [If_mono_arg]);
("toplevel_nexps", [If_mono_arg]);
("monomorphise", [String_arg "c"; If_mono_arg]);
- ("atoms_to_singletons", [If_mono_arg]);
+ ("atoms_to_singletons", [String_arg "c"; If_mono_arg]);
("recheck_defs", [If_mono_arg]);
("undefined", [Bool_arg false]);
("vector_string_pats_to_bit_list", []);
@@ -5034,10 +5109,8 @@ let rewrites_c = [
("remove_vector_concat", []);
("remove_bitvector_pats", []);
("pattern_literals", [Literal_arg "all"]);
- ("vector_concat_assignments", []);
("tuple_assignments", []);
("vector_concat_assignments", []);
- ("tuple_assignments", []);
("simple_assignments", []);
("exp_lift_assign", []);
("merge_function_clauses", []);
@@ -5052,8 +5125,8 @@ let rewrites_interpreter = [
("pat_string_append", []);
("mapping_builtins", []);
("undefined", [Bool_arg false]);
- ("vector_concat_assignments", []);
("tuple_assignments", []);
+ ("vector_concat_assignments", []);
("simple_assignments", [])
]