diff options
Diffstat (limited to 'src/rewrites.ml')
| -rw-r--r-- | src/rewrites.ml | 243 |
1 files changed, 158 insertions, 85 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index 48ea78ae..35659bb4 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -385,7 +385,7 @@ let remove_vector_concat_pat pat = let id_pat = match typ_opt with - | Some typ -> add_p_typ typ (P_aux (P_id child,cannot)) + | Some typ -> add_p_typ env typ (P_aux (P_id child,cannot)) | None -> P_aux (P_id child,cannot) in let letbind = fix_eff_lb (LB_aux (LB_val (id_pat,subv),cannot)) in (letbind, @@ -716,6 +716,26 @@ let rec subsumes_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) = | _, P_wild -> if is_irrefutable_pattern pat1 then Some [] else None | _ -> None +let vector_string_to_bits_pat (L_aux (lit, _) as l_aux) (l, tannot) = + let bit_annot = match destruct_tannot tannot with + | Some (env, _, _) -> mk_tannot env bit_typ no_effect + | None -> empty_tannot + in + begin match lit with + | L_hex _ | L_bin _ -> P_aux (P_vector (List.map (fun p -> P_aux (P_lit p, (l, bit_annot))) (vector_string_to_bit_list l_aux)), (l, tannot)) + | lit -> P_aux (P_lit l_aux, (l, tannot)) + end + +let vector_string_to_bits_exp (L_aux (lit, _) as l_aux) (l, tannot) = + let bit_annot = match destruct_tannot tannot with + | Some (env, _, _) -> mk_tannot env bit_typ no_effect + | None -> empty_tannot + in + begin match lit with + | L_hex _ | L_bin _ -> E_aux (E_vector (List.map (fun p -> E_aux (E_lit p, (l, bit_annot))) (vector_string_to_bit_list l_aux)), (l, tannot)) + | lit -> E_aux (E_lit l_aux, (l, tannot)) + end + (* A simple check for pattern disjointness; used for optimisation in the guarded pattern rewrite step *) let rec disjoint_pat env (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) = @@ -729,6 +749,14 @@ let rec disjoint_pat env (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) | P_id id, _ when id_is_unbound id env -> false | _, P_id id when id_is_unbound id env -> false | P_id id1, P_id id2 -> Id.compare id1 id2 <> 0 + | P_lit (L_aux ((L_bin _ | L_hex _), _) as lit), _ -> + disjoint_pat env (vector_string_to_bits_pat lit (Unknown, empty_tannot)) pat2 + | _, P_lit (L_aux ((L_bin _ | L_hex _), _) as lit) -> + disjoint_pat env pat1 (vector_string_to_bits_pat lit (Unknown, empty_tannot)) + | P_lit (L_aux (L_num n1, _)), P_lit (L_aux (L_num n2, _)) -> + not (Big_int.equal n1 n2) + | P_lit (L_aux (l1, _)), P_lit (L_aux (l2, _)) -> + l1 <> l2 | P_app (id1, args1), P_app (id2, args2) -> Id.compare id1 id2 <> 0 || List.exists2 (disjoint_pat env) args1 args2 | P_vector pats1, P_vector pats2 @@ -1221,21 +1249,13 @@ let rewrite_defs_vector_string_pats_to_bit_list env = let rewrite_p_aux (pat, (annot : tannot annot)) = let env = env_of_annot annot in match pat with - | P_lit (L_aux (lit, l) as l_aux) -> - begin match lit with - | L_hex _ | L_bin _ -> P_aux (P_vector (List.map (fun p -> P_aux (P_lit p, (l, mk_tannot env bit_typ no_effect))) (vector_string_to_bit_list l_aux)), annot) - | lit -> P_aux (P_lit l_aux, annot) - end + | P_lit lit -> vector_string_to_bits_pat lit annot | pat -> (P_aux (pat, annot)) in let rewrite_e_aux (exp, (annot : tannot annot)) = let env = env_of_annot annot in match exp with - | E_lit (L_aux (lit, l) as l_aux) -> - begin match lit with - | L_hex _ | L_bin _ -> E_aux (E_vector (List.map (fun e -> E_aux (E_lit e, (l, mk_tannot env bit_typ no_effect))) (vector_string_to_bit_list l_aux)), annot) - | lit -> E_aux (E_lit l_aux, annot) - end + | E_lit lit -> vector_string_to_bits_exp lit annot | exp -> (E_aux (exp, annot)) in let pat_alg = { id_pat_alg with p_aux = rewrite_p_aux } in @@ -1556,7 +1576,7 @@ let rewrite_defs_early_return env (Defs defs) = let eff = effect_of_annot tannot in let tannot' = mk_tannot env typ (union_effects eff (mk_effect [BE_escape])) in let exp' = match Env.get_ret_typ env with - | Some typ -> add_e_cast typ exp + | Some typ -> add_e_cast env typ exp | None -> exp in E_aux (E_app (mk_id "early_return", [exp']), (l, tannot')) | _ -> full_exp in @@ -2067,35 +2087,36 @@ let rewrite_vector_concat_assignments env defs = match nexp_simp len with | Nexp_aux (Nexp_constant len, _) -> len | _ -> (Big_int.of_int 1) - else (Big_int.of_int 1) in + else (Big_int.of_int 1) + in let next i step = if is_order_inc ord then (Big_int.sub (Big_int.add i step) (Big_int.of_int 1), Big_int.add i step) - else (Big_int.add (Big_int.sub i step) (Big_int.of_int 1), Big_int.sub i step) in + else (Big_int.add (Big_int.sub i step) (Big_int.of_int 1), Big_int.sub i step) + in let i = match nexp_simp start with - | (Nexp_aux (Nexp_constant i, _)) -> i - | _ -> if is_order_inc ord then Big_int.zero else Big_int.of_int (List.length lexps - 1) in + | (Nexp_aux (Nexp_constant i, _)) -> i + | _ -> if is_order_inc ord then Big_int.zero else Big_int.of_int (List.length lexps - 1) + in let l = gen_loc (fst annot) in - let exp' = - if small exp then strip_exp exp - else mk_exp (E_id (mk_id "split_vec")) in + let vec_id = mk_id "split_vec" in + let exp' = if small exp then strip_exp exp else mk_exp (E_id vec_id) in let lexp_to_exp (i, exps) lexp = let (j, i') = next i (len lexp) in let i_exp = mk_exp (E_lit (mk_lit (L_num i))) in let j_exp = mk_exp (E_lit (mk_lit (L_num j))) in let sub = mk_exp (E_vector_subrange (exp', i_exp, j_exp)) in - (i', exps @ [sub]) in + (i', exps @ [sub]) + in let (_, exps) = List.fold_left lexp_to_exp (i, []) lexps in - let tup = mk_exp (E_tuple exps) in - let lexp = LEXP_aux (LEXP_tup (List.map strip_lexp lexps), (l, ())) in - let e_aux = - if small exp then mk_exp (E_assign (lexp, tup)) - else mk_exp ( - E_let ( - mk_letbind (mk_pat (P_id (mk_id "split_vec"))) (strip_exp exp), - mk_exp (E_assign (lexp, tup)))) in + let assign lexp exp = mk_exp (E_assign (strip_lexp lexp, exp)) in + let block = mk_exp (E_block (List.map2 assign lexps exps)) in + let full_exp = + if small exp then block else + mk_exp (E_let (mk_letbind (mk_pat (P_id vec_id)) (strip_exp exp), block)) + in begin - try check_exp env e_aux unit_typ with + try check_exp env full_exp unit_typ with | Type_error (_, l, err) -> raise (Reporting.err_typ l (Type_error.string_of_type_error err)) end @@ -2114,14 +2135,12 @@ let rewrite_tuple_assignments env defs = let env = env_of_annot annot in match e_aux with | E_assign (LEXP_aux (LEXP_tup lexps, _), exp) -> - (* let _ = Pretty_print_common.print stderr (Pretty_print_sail.doc_exp (E_aux (e_aux, annot))) in *) let (_, ids) = List.fold_left (fun (n, ids) _ -> (n + 1, ids @ [mk_id ("tup__" ^ string_of_int n)])) (0, []) lexps in let block_assign i lexp = mk_exp (E_assign (strip_lexp lexp, mk_exp (E_id (mk_id ("tup__" ^ string_of_int i))))) in let block = mk_exp (E_block (List.mapi block_assign lexps)) in - let letbind = mk_letbind (mk_pat (P_tup (List.map (fun id -> mk_pat (P_id id)) ids))) - (strip_exp exp) - in - let let_exp = mk_exp (E_let (letbind, block)) in + let pat = mk_pat (P_tup (List.map (fun id -> mk_pat (P_id id)) ids)) in + let exp' = add_e_cast env (typ_of exp) exp in + let let_exp = mk_exp (E_let (mk_letbind pat (strip_exp exp'), block)) in begin try check_exp env let_exp unit_typ with | Type_error (_, l, err) -> @@ -2162,9 +2181,11 @@ let rewrite_defs_remove_blocks env = let l = get_loc_exp v in let env = env_of v in let typ = typ_of v in - let wild = add_p_typ typ (annot_pat P_wild l env typ) in + let wild = annot_pat P_wild l env typ in let e_aux = E_let (annot_letbind (unaux_pat wild, v) l env typ, body) in - fix_eff_exp (annot_exp e_aux l env (typ_of body)) in + fix_eff_exp (annot_exp e_aux l env (typ_of body)) + |> add_typs_let env typ (typ_of body) + in let rec f l = function | [] -> E_aux (E_lit (L_aux (L_unit,gen_loc l)), (simple_annot l unit_typ)) @@ -2195,19 +2216,20 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = | Some (env, typ, eff) when is_unit_typ typ -> let body = body (annot_exp (E_lit (mk_lit L_unit)) l env unit_typ) in let body_typ = try typ_of body with _ -> unit_typ in - let wild = add_p_typ (typ_of v) (annot_pat P_wild l env (typ_of v)) in + let wild = annot_pat P_wild l env typ in let lb = fix_eff_lb (annot_letbind (unaux_pat wild, v) l env unit_typ) in fix_eff_exp (annot_exp (E_let (lb, body)) l env body_typ) + |> add_typs_let env typ body_typ | Some (env, typ, eff) -> let id = fresh_id "w__" l in - let pat = add_p_typ (typ_of v) (annot_pat (P_id id) l env (typ_of v)) in + let pat = annot_pat (P_id id) l env typ in let lb = fix_eff_lb (annot_letbind (unaux_pat pat, v) l env typ) in let body = body (annot_exp (E_id id) l env typ) in fix_eff_exp (annot_exp (E_let (lb, body)) l env (typ_of body)) + |> add_typs_let env typ (typ_of body) | None -> raise (Reporting.err_unreachable l __POS__ "no type information") - let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list -> 'a exp) : 'a exp = match l with | [] -> k [] @@ -2244,9 +2266,9 @@ let rewrite_defs_letbind_effects env = and n_pexp : 'b. bool -> 'a pexp -> ('a pexp -> 'b) -> 'b = fun newreturn pexp k -> match pexp with | Pat_aux (Pat_exp (pat,exp),annot) -> - k (fix_eff_pexp (Pat_aux (Pat_exp (pat,n_exp_term newreturn exp), annot))) + k (fix_eff_pexp (Pat_aux (Pat_exp (pat, n_exp_term newreturn exp), annot))) | Pat_aux (Pat_when (pat,guard,exp),annot) -> - k (fix_eff_pexp (Pat_aux (Pat_when (pat,n_exp_term newreturn guard,n_exp_term newreturn exp), annot))) + k (fix_eff_pexp (Pat_aux (Pat_when (pat, n_exp_term newreturn guard, n_exp_term newreturn exp), annot))) and n_pexpL (newreturn : bool) (pexps : 'a pexp list) (k : 'a pexp list -> 'a exp) : 'a exp = mapCont (n_pexp newreturn) pexps k @@ -2300,15 +2322,15 @@ let rewrite_defs_letbind_effects env = and n_lexpL (lexps : 'a lexp list) (k : 'a lexp list -> 'a exp) : 'a exp = mapCont n_lexp lexps k - and n_exp_term (newreturn : bool) (exp : 'a exp) : 'a exp = + and n_exp_term ?cast:(cast=false) (newreturn : bool) (exp : 'a exp) : 'a exp = let (E_aux (_,(l,tannot))) = exp in let exp = if newreturn then (* let typ = try typ_of exp with _ -> unit_typ in *) - let exp = add_e_cast (typ_of exp) exp in + let exp = if cast then add_e_cast (env_of exp) (typ_of exp) exp else exp in annot_exp (E_internal_return exp) l (env_of exp) (typ_of exp) - else - exp in + else exp + in (* n_exp_term forces an expression to be translated into a form "let .. let .. let .. in EXP" where EXP has no effect and does not update variables *) @@ -2331,8 +2353,8 @@ let rewrite_defs_letbind_effects env = (* Leave effectful operands of Boolean "and"/"or" in place to allow short-circuiting. *) let newreturn = effectful l || effectful r in - let l = n_exp_term newreturn l in - let r = n_exp_term newreturn r in + let l = n_exp_term ~cast:true newreturn l in + let r = n_exp_term ~cast:true newreturn r in k (rewrap (E_app (op_bool, [l; r]))) | E_app (id,exps) -> n_exp_nameL exps (fun exps -> @@ -2351,7 +2373,8 @@ let rewrite_defs_letbind_effects env = let newreturn = effectful exp2 || effectful exp3 in let exp2 = n_exp_term newreturn exp2 in let exp3 = n_exp_term newreturn exp3 in - k (rewrap (E_if (exp1,exp2,exp3))) in + k (rewrap (E_if (exp1,exp2,exp3))) + in if value exp1 then e_if (n_exp_term false exp1) else n_exp_name exp1 e_if | E_for (id,start,stop,by,dir,body) -> n_exp_name start (fun start -> @@ -2365,7 +2388,7 @@ let rewrite_defs_letbind_effects env = | Measure_aux (Measure_some exp,l) -> Measure_aux (Measure_some (n_exp_term false exp),l) in - let cond = n_exp_term (effectful cond) cond in + let cond = n_exp_term ~cast:true (effectful cond) cond in let body = n_exp_term (effectful body) body in k (rewrap (E_loop (loop,measure,cond,body))) | E_vector exps -> @@ -2437,7 +2460,7 @@ let rewrite_defs_letbind_effects env = | E_assert (exp1,exp2) -> n_exp_name exp1 (fun exp1 -> n_exp_name exp2 (fun exp2 -> - k (rewrap (E_assert (exp1,exp2))))) + k (rewrap (E_assert (exp1, exp2))))) | E_var (lexp,exp1,exp2) -> n_lexp lexp (fun lexp -> n_exp exp1 (fun exp1 -> @@ -2505,7 +2528,7 @@ let rewrite_defs_internal_lets env = let rec pat_of_local_lexp (LEXP_aux (lexp, ((l, _) as annot))) = match lexp with | LEXP_id id -> P_aux (P_id id, annot) - | LEXP_cast (typ, id) -> add_p_typ typ (P_aux (P_id id, annot)) + | LEXP_cast (typ, id) -> add_p_typ (env_of_annot annot) typ (P_aux (P_id id, annot)) | LEXP_tup lexps -> P_aux (P_tup (List.map pat_of_local_lexp lexps), annot) | _ -> raise (Reporting.err_unreachable l __POS__ "unexpected local lexp") in @@ -2524,7 +2547,7 @@ let rewrite_defs_internal_lets env = | LEXP_aux (_,lexp_annot') -> lexp_annot' | exception _ -> lannot) in - let rhs = add_e_cast ltyp (rhs exp) in + let rhs = add_e_cast (env_of exp) ltyp (rhs exp) in E_let (LB_aux (LB_val (pat_of_local_lexp lhs, rhs), annot), body) | LB_aux (LB_val (pat,exp'),annot') -> if effectful exp' @@ -2536,7 +2559,7 @@ let rewrite_defs_internal_lets env = | LEXP_aux (LEXP_id id, annot) -> (P_id id, annot) | LEXP_aux (LEXP_cast (typ, id), annot) -> - (unaux_pat (add_p_typ typ (P_aux (P_id id, annot))), annot) + (unaux_pat (add_p_typ (env_of_annot annot) typ (P_aux (P_id id, annot))), annot) | _ -> failwith "E_var with unexpected lexp" in if effectful exp1 then E_internal_plet (P_aux (paux, annot), exp1, exp2) @@ -2566,7 +2589,7 @@ let fold_typed_guards env guards = | g :: gs -> List.fold_left (fun g g' -> annot_exp (E_app (mk_id "and_bool", [g; g'])) Parse_ast.Unknown env bool_typ) g gs -let rewrite_pexp_with_guards rewrite_pat (Pat_aux (pexp_aux, (l, _)) as pexp) = +let rewrite_pexp_with_guards rewrite_pat (Pat_aux (pexp_aux, (annot: tannot annot)) as pexp) = let guards = ref [] in match pexp_aux with @@ -2576,13 +2599,13 @@ let rewrite_pexp_with_guards rewrite_pat (Pat_aux (pexp_aux, (l, _)) as pexp) = match !guards with | [] -> pexp | gs -> - let unchecked_pexp = mk_pexp ~loc:l (Pat_when (strip_pat pat, List.map strip_exp gs |> fold_guards, strip_exp exp)) in + let unchecked_pexp = mk_pexp (Pat_when (strip_pat pat, List.map strip_exp gs |> fold_guards, strip_exp exp)) in check_case (env_of_pat pat) (typ_of_pat pat) unchecked_pexp (typ_of exp) end | Pat_when (pat, guard, exp) -> begin let pat = fold_pat { id_pat_alg with p_aux = rewrite_pat guards } pat in - let unchecked_pexp = mk_pexp ~loc:l (Pat_when (strip_pat pat, List.map strip_exp !guards |> fold_guards, strip_exp exp)) in + let unchecked_pexp = mk_pexp (Pat_when (strip_pat pat, List.map strip_exp !guards |> fold_guards, strip_exp exp)) in check_case (env_of_pat pat) (typ_of_pat pat) unchecked_pexp (typ_of exp) end @@ -3334,22 +3357,24 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let tuple_exp = function | [] -> annot_exp (E_lit (mk_lit L_unit)) l env unit_typ | [e] -> e - | es -> annot_exp (E_tuple es) l env (tuple_typ (List.map typ_of es)) in + | es -> annot_exp (E_tuple es) l env (tuple_typ (List.map typ_of es)) + in let tuple_pat = function | [] -> annot_pat P_wild l env unit_typ | [pat] -> let typ = typ_of_pat pat in - add_p_typ typ pat + add_p_typ env typ pat | pats -> let typ = tuple_typ (List.map typ_of_pat pats) in - add_p_typ typ (annot_pat (P_tup pats) l env typ) in + add_p_typ env typ (annot_pat (P_tup pats) l env typ) + in let rec add_vars overwrite ((E_aux (expaux,annot)) as exp) vars = match expaux with | E_let (lb,exp) -> let exp = add_vars overwrite exp vars in - E_aux (E_let (lb,exp),swaptyp (typ_of exp) annot) + E_aux (E_let (lb,exp), swaptyp (typ_of exp) annot) | E_var (lexp,exp1,exp2) -> let exp2 = add_vars overwrite exp2 vars in E_aux (E_var (lexp,exp1,exp2), swaptyp (typ_of exp2) annot) @@ -3358,7 +3383,11 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = E_aux (E_internal_plet (pat,exp1,exp2), swaptyp (typ_of exp2) annot) | E_internal_return exp2 -> let exp2 = add_vars overwrite exp2 vars in - E_aux (E_internal_return exp2,swaptyp (typ_of exp2) annot) + E_aux (E_internal_return exp2, swaptyp (typ_of exp2) annot) + | E_cast (typ, exp) -> + let (E_aux (expaux, annot) as exp) = add_vars overwrite exp vars in + let typ' = typ_of exp in + add_e_cast (env_of exp) typ' (E_aux (expaux, swaptyp typ' annot)) | _ -> (* after rewrite_defs_letbind_effects there cannot be terms that have effects/update local variables in "tail-position": check n_exp_term @@ -3367,6 +3396,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let lb = LB_aux (LB_val (P_aux (P_wild, annot), exp), annot) in let exp' = tuple_exp vars in E_aux (E_let (lb, exp'), swaptyp (typ_of exp') annot) + |> add_typs_let env (typ_of exp) (typ_of exp') else tuple_exp (exp :: vars) in let mk_var_exps_pats l env ids = @@ -3378,7 +3408,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = exp, P_aux (P_id id, a)) |> List.split in - let rewrite used_vars (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (paux,(pl,pannot)) as pat) = + let rec rewrite used_vars (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (paux,(pl,pannot)) as pat) = let env = env_of_annot annot in let overwrite = match paux with | P_wild | P_typ (_, P_aux (P_wild, _)) -> true @@ -3417,11 +3447,15 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = in let lvar_nc = nc_and (nc_lteq (nvar lower_kid) (nvar lvar_kid)) (nc_lteq (nvar lvar_kid) (nvar upper_kid)) in let lvar_typ = mk_typ (Typ_exist (List.map (mk_kopt K_int) [lvar_kid], lvar_nc, atom_typ (nvar lvar_kid))) in - let lvar_pat = unaux_pat (add_p_typ lvar_typ (annot_pat (P_var ( + let lvar_pat = unaux_pat (annot_pat (P_var ( annot_pat (P_id id) el env (atom_typ (nvar lvar_kid)), - TP_aux (TP_var lvar_kid, gen_loc el))) el env lvar_typ)) in + TP_aux (TP_var lvar_kid, gen_loc el))) el env lvar_typ) + in let lb = fix_eff_lb (annot_letbind (lvar_pat, exp1) el env lvar_typ) in - let body = fix_eff_exp (annot_exp (E_let (lb, exp4)) el env (typ_of exp4)) in + let body = + fix_eff_exp (annot_exp (E_let (lb, exp4)) el env (typ_of exp4)) + |> add_typs_let env lvar_typ (typ_of exp4) + in (* If lower > upper, the loop body never gets executed, and the type checker might not be able to prove that the initial value exp1 satisfies the constraints on the loop variable. @@ -3521,7 +3555,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | E_assign (lexp,vexp) -> let mk_id_pat id = let typ = lvar_typ (Env.lookup_id id env) in - add_p_typ typ (annot_pat (P_id id) pl env typ) + add_p_typ env typ (annot_pat (P_id id) pl env typ) in if effectful exp then Same_vars (E_aux (E_assign (lexp,vexp),annot)) @@ -3531,7 +3565,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let pat = annot_pat (P_id id) pl env (typ_of vexp) in Added_vars (vexp, mk_id_pat id) | LEXP_aux (LEXP_cast (typ,id),annot) -> - let pat = add_p_typ typ (annot_pat (P_id id) pl env (typ_of vexp)) in + let pat = add_p_typ env typ (annot_pat (P_id id) pl env (typ_of vexp)) in Added_vars (vexp,pat) | LEXP_aux (LEXP_vector (LEXP_aux (LEXP_id id,((l2,_) as annot2)),i),((l1,_) as annot)) -> let eid = annot_exp (E_id id) l2 env (typ_of_annot annot2) in @@ -3545,6 +3579,13 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let pat = annot_pat (P_id id) pl env (typ_of vexp) in Added_vars (vexp,pat) | _ -> Same_vars (E_aux (E_assign (lexp,vexp),annot))) + | E_cast (typ, exp) -> + begin match rewrite used_vars exp pat with + | Added_vars (exp', pat') -> + Added_vars (add_e_cast (env_of exp') (typ_of exp') exp', pat') + | Same_vars (exp') -> + Same_vars (E_aux (E_cast (typ, exp'), annot)) + end | _ -> (* after rewrite_defs_letbind_effects this expression is pure and updates no variables: check n_exp_term and where it's used. *) @@ -3565,7 +3606,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | LEXP_aux (LEXP_id id, _) -> P_id id, typ_of v | LEXP_aux (LEXP_cast (typ, id), _) -> - unaux_pat (add_p_typ typ (annot_pat (P_id id) l env (typ_of v))), typ + unaux_pat (add_p_typ env typ (annot_pat (P_id id) l env (typ_of v))), typ | _ -> raise (Reporting.err_unreachable l __POS__ "E_var with a lexp that is not a variable") in @@ -3579,6 +3620,9 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = rewrite_var_updates exp' | E_internal_plet (pat,v,body) -> failwith "rewrite_var_updates: E_internal_plet shouldn't be introduced yet" + | E_cast (typ, exp) -> + let exp' = rewrite_var_updates exp in + E_aux (E_cast (typ, exp'), annot) (* There are no other expressions that have effects or variable updates in "tail-position": check the definition nexp_term and where it is used. *) | _ -> exp @@ -3714,7 +3758,7 @@ let rewrite_defs_remove_superfluous_returns env = let add_opt_cast typopt1 typopt2 annot exp = match typopt1, typopt2 with - | Some typ, _ | _, Some typ -> add_e_cast typ exp + | Some typ, _ | _, Some typ -> add_e_cast (env_of exp) typ exp | None, None -> exp in @@ -3765,11 +3809,11 @@ let rewrite_defs_remove_superfluous_returns env = let rewrite_defs_remove_e_assign env (Defs defs) = let (Defs loop_specs) = fst (Type_error.check initial_env (Defs (List.map gen_vs - [("foreach#", "forall ('vars : Type). (int, int, int, bool, 'vars, 'vars) -> 'vars"); - ("while#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars"); - ("until#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars"); - ("while#t", "forall ('vars : Type). (bool, 'vars, 'vars, int) -> 'vars"); - ("until#t", "forall ('vars : Type). (bool, 'vars, 'vars, int) -> 'vars")]))) in + [("foreach#", "forall ('vars_in 'vars_out : Type). (int, int, int, bool, 'vars_in, 'vars_out) -> 'vars_out"); + ("while#", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out) -> 'vars_out"); + ("until#", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out) -> 'vars_out"); + ("while#t", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out, int) -> 'vars_out"); + ("until#t", "forall ('vars_in 'vars_out : Type). (bool, 'vars_in, 'vars_out, int) -> 'vars_out")]))) in let rewrite_exp _ e = replace_memwrite_e_assign (remove_reference_types (rewrite_var_updates e)) in rewrite_defs_base @@ -4752,6 +4796,36 @@ let rec move_loop_measures (Defs defs) = in Defs (List.rev rev_defs) +let rewrite_toplevel_consts target type_env (Defs defs) = + let istate = Constant_fold.initial_state (Defs defs) type_env in + let subst consts exp = + let open Rewriter in + let used_ids = fold_exp { (pure_exp_alg IdSet.empty IdSet.union) with e_id = IdSet.singleton } exp in + let subst_ids = IdSet.filter (fun id -> Bindings.mem id consts) used_ids in + IdSet.fold (fun id -> subst id (Bindings.find id consts)) subst_ids exp + in + let rewrite_def (revdefs, consts) = function + | DEF_val (LB_aux (LB_val (pat, exp), a) as lb) -> + begin match unaux_pat pat with + | P_id id | P_typ (_, P_aux (P_id id, _)) -> + let exp' = Constant_fold.rewrite_exp_once target istate (subst consts exp) in + if Constant_fold.is_constant exp' then + try + let exp' = infer_exp (env_of exp') (strip_exp exp') in + let pannot = (pat_loc pat, mk_tannot (env_of_pat pat) (typ_of exp') no_effect) in + let pat' = P_aux (P_typ (typ_of exp', P_aux (P_id id, pannot)), pannot) in + let consts' = Bindings.add id exp' consts in + (DEF_val (LB_aux (LB_val (pat', exp'), a)) :: revdefs, consts') + with + | _ -> (DEF_val lb :: revdefs, consts) + else (DEF_val lb :: revdefs, consts) + | _ -> (DEF_val lb :: revdefs, consts) + end + | def -> (def :: revdefs, consts) + in + let (revdefs, _) = List.fold_left rewrite_def ([], Bindings.empty) defs in + Defs (List.rev revdefs) + let opt_mono_rewrites = ref false let opt_mono_complex_nexps = ref true @@ -4851,10 +4925,10 @@ let all_rewrites = [ ("mapping_builtins", Basic_rewriter rewrite_defs_mapping_patterns); ("mono_rewrites", Basic_rewriter mono_rewrites); ("toplevel_nexps", Basic_rewriter rewrite_toplevel_nexps); + ("toplevel_consts", String_rewriter (fun target -> Basic_rewriter (rewrite_toplevel_consts target))); ("monomorphise", String_rewriter (fun target -> Basic_rewriter (monomorphise target))); - ("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons)); - ("add_bitvector_casts", Basic_rewriter (fun _ -> Monomorphise.add_bitvector_casts)); - ("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons)); + ("atoms_to_singletons", String_rewriter (fun target -> (Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons target)))); + ("add_bitvector_casts", Basic_rewriter Monomorphise.add_bitvector_casts); ("remove_impossible_int_cases", Basic_rewriter Constant_propagation.remove_impossible_int_cases); ("const_prop_mutrec", String_rewriter (fun target -> Basic_rewriter (Constant_propagation_mutrec.rewrite_defs target))); ("make_cases_exhaustive", Basic_rewriter MakeExhaustive.rewrite); @@ -4902,17 +4976,18 @@ let rewrites_lem = [ ("mono_rewrites", []); ("recheck_defs", [If_mono_arg]); ("undefined", [Bool_arg false]); + ("toplevel_consts", [String_arg "lem"; If_mwords_arg]); ("toplevel_nexps", [If_mono_arg]); ("monomorphise", [String_arg "lem"; If_mono_arg]); ("recheck_defs", [If_mwords_arg]); ("add_bitvector_casts", [If_mwords_arg]); - ("atoms_to_singletons", [If_mono_arg]); + ("atoms_to_singletons", [String_arg "lem"; If_mono_arg]); ("recheck_defs", [If_mwords_arg]); ("vector_string_pats_to_bit_list", []); ("remove_not_pats", []); ("remove_impossible_int_cases", []); - ("vector_concat_assignments", []); ("tuple_assignments", []); + ("vector_concat_assignments", []); ("simple_assignments", []); ("remove_vector_concat", []); ("remove_bitvector_pats", []); @@ -4953,8 +5028,8 @@ let rewrites_coq = [ ("vector_string_pats_to_bit_list", []); ("remove_not_pats", []); ("remove_impossible_int_cases", []); - ("vector_concat_assignments", []); ("tuple_assignments", []); + ("vector_concat_assignments", []); ("simple_assignments", []); ("remove_vector_concat", []); ("remove_bitvector_pats", []); @@ -5002,8 +5077,8 @@ let rewrites_ocaml = [ ("mapping_builtins", []); ("undefined", [Bool_arg false]); ("vector_string_pats_to_bit_list", []); - ("vector_concat_assignments", []); ("tuple_assignments", []); + ("vector_concat_assignments", []); ("simple_assignments", []); ("remove_not_pats", []); ("remove_vector_concat", []); @@ -5026,7 +5101,7 @@ let rewrites_c = [ ("recheck_defs", [If_mono_arg]); ("toplevel_nexps", [If_mono_arg]); ("monomorphise", [String_arg "c"; If_mono_arg]); - ("atoms_to_singletons", [If_mono_arg]); + ("atoms_to_singletons", [String_arg "c"; If_mono_arg]); ("recheck_defs", [If_mono_arg]); ("undefined", [Bool_arg false]); ("vector_string_pats_to_bit_list", []); @@ -5034,10 +5109,8 @@ let rewrites_c = [ ("remove_vector_concat", []); ("remove_bitvector_pats", []); ("pattern_literals", [Literal_arg "all"]); - ("vector_concat_assignments", []); ("tuple_assignments", []); ("vector_concat_assignments", []); - ("tuple_assignments", []); ("simple_assignments", []); ("exp_lift_assign", []); ("merge_function_clauses", []); @@ -5052,8 +5125,8 @@ let rewrites_interpreter = [ ("pat_string_append", []); ("mapping_builtins", []); ("undefined", [Bool_arg false]); - ("vector_concat_assignments", []); ("tuple_assignments", []); + ("vector_concat_assignments", []); ("simple_assignments", []) ] |
