diff options
| author | Thomas Bauereiss | 2019-03-14 18:34:49 +0000 |
|---|---|---|
| committer | Thomas Bauereiss | 2019-03-15 18:47:30 +0000 |
| commit | e92ff6875925c2fe8b6ebc95a6b328514abc0106 (patch) | |
| tree | 24ef95facd542364e9578ec55532ff9b84a96e53 /src | |
| parent | 11325d9bb5f4117c5b41413ac523b7d50577ebdd (diff) | |
Add a rewriting pass for constant propagation in mutrecs
Propagating constants into mutually recursive calls and removing dead branches
might break mutually recursive cycles.
Also make constant propagation use the existing interpreter-based constant
folding to evaluate function calls with only constant arguments (as opposed to
a mixture of inlining and hard-coded rewrite rules).
Diffstat (limited to 'src')
| -rw-r--r-- | src/constant_fold.ml | 33 | ||||
| -rw-r--r-- | src/constant_propagation.ml | 216 | ||||
| -rw-r--r-- | src/constant_propagation_mutrec.ml | 232 | ||||
| -rw-r--r-- | src/rewrites.ml | 4 | ||||
| -rw-r--r-- | src/sail.ml | 3 |
5 files changed, 332 insertions, 156 deletions
diff --git a/src/constant_fold.ml b/src/constant_fold.ml index f85fb673..fd9b322b 100644 --- a/src/constant_fold.ml +++ b/src/constant_fold.ml @@ -136,13 +136,17 @@ let fold_to_unit id = in IdSet.mem id remove -let rec is_constant (E_aux (e_aux, _)) = +let rec is_constant (E_aux (e_aux, _) as exp) = match e_aux with | E_lit _ -> true | E_vector exps -> List.for_all is_constant exps | E_record fexps -> List.for_all is_constant_fexp fexps | E_cast (_, exp) -> is_constant exp | E_tuple exps -> List.for_all is_constant exps + | E_id id -> + (match Env.lookup_id id (env_of exp) with + | Enum _ -> true + | _ -> false) | _ -> false and is_constant_fexp (FE_aux (FE_Fexp (_, exp), _)) = is_constant exp @@ -173,21 +177,18 @@ let rec run frame = - Throws an exception that isn't caught. *) -let rec rewrite_constant_function_calls' ast = - let rewrite_count = ref 0 in - let ok () = incr rewrite_count in - let not_ok () = decr rewrite_count in - +let initial_state ast = let lstate, gstate = Interpreter.initial_state ast safe_primops in - let gstate = { gstate with Interpreter.allow_registers = false } in + (lstate, { gstate with Interpreter.allow_registers = false }) +let rw_exp ok not_ok istate = let evaluate e_aux annot = let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in try begin - let v = run (Interpreter.Step (lazy "", (lstate, gstate), initial_monad, [])) in + let v = run (Interpreter.Step (lazy "", istate, initial_monad, [])) in let exp = exp_of_value v in try (ok (); Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot)) with | Type_error (env, l, err) -> @@ -231,11 +232,19 @@ let rec rewrite_constant_function_calls' ast = | _ -> E_aux (e_aux, annot) in - let rw_exp = { - id_exp_alg with - e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot) + fold_exp { id_exp_alg with e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot)} + +let rewrite_exp_once = rw_exp (fun _ -> ()) (fun _ -> ()) + +let rec rewrite_constant_function_calls' ast = + let rewrite_count = ref 0 in + let ok () = incr rewrite_count in + let not_ok () = decr rewrite_count in + + let rw_defs = { + rewriters_base with + rewrite_exp = (fun _ -> rw_exp ok not_ok (initial_state ast)) } in - let rw_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp rw_exp) } in let ast = rewrite_defs_base rw_defs ast in (* We keep iterating until we have no more re-writes to do *) if !rewrite_count > 0 diff --git a/src/constant_propagation.ml b/src/constant_propagation.ml index 33b67008..3ae46657 100644 --- a/src/constant_propagation.ml +++ b/src/constant_propagation.ml @@ -111,9 +111,13 @@ let rec is_value (E_aux (e,(l,annot))) = | E_id id -> is_constructor id | E_lit _ -> true | E_tuple es -> List.for_all is_value es + | E_record fes -> + List.for_all (fun (FE_aux (FE_Fexp (_, e), _)) -> is_value e) fes | E_app (id,es) -> is_constructor id && List.for_all is_value es (* We add casts to undefined to keep the type information in the AST *) | E_cast (typ,E_aux (E_lit (L_aux (L_undef,_)),_)) -> true + (* Also keep casts around records, as type inference fails without *) + | E_cast (_, (E_aux (E_record _, _) as e')) -> is_value e' (* TODO: more? *) | _ -> false @@ -263,93 +267,6 @@ let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) = | L_num i1, L_num i2 -> Some (Big_int.equal i1 i2) | _ -> Some (l1 = l2) -let try_app (l,ann) (id,args) = - let new_l = Parse_ast.Generated l in - let env = env_of_annot (l,ann) in - let get_overloads f = List.map string_of_id - (Env.get_overloads (Id_aux (Id f, Parse_ast.Unknown)) env @ - Env.get_overloads (Id_aux (DeIid f, Parse_ast.Unknown)) env) in - let is_id f = List.mem (string_of_id id) (f :: get_overloads f) in - if is_id "==" || is_id "!=" then - match args with - | [E_aux (E_lit l1,_); E_aux (E_lit l2,_)] -> - let lit b = if b then L_true else L_false in - let lit b = lit (if is_id "==" then b else not b) in - (match lit_eq l1 l2 with - | None -> None - | Some b -> Some (E_aux (E_lit (L_aux (lit b,new_l)),(l,ann)))) - | _ -> None - else if is_id "cast_bit_bool" then - match args with - | [E_aux (E_lit L_aux (L_zero,_),_)] -> Some (E_aux (E_lit (L_aux (L_false,new_l)),(l,ann))) - | [E_aux (E_lit L_aux (L_one ,_),_)] -> Some (E_aux (E_lit (L_aux (L_true ,new_l)),(l,ann))) - | _ -> None - else if is_id "UInt" || is_id "unsigned" then - match args with - | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit,_), _)] -> - Some (E_aux (E_lit (L_aux (L_num (int_of_str_lit lit),new_l)),(l,ann))) - | _ -> None - else if is_id "slice" then - match args with - | [E_aux (E_lit (L_aux ((L_hex _| L_bin _),_) as lit), annot); - E_aux (E_lit L_aux (L_num i,_), _); - E_aux (E_lit L_aux (L_num len,_), _)] -> - (match Env.base_typ_of (env_of_annot annot) (typ_of_annot annot) with - | Typ_aux (Typ_app (_,[_;A_aux (A_order ord,_);_]),_) -> - (match slice_lit lit i len ord with - | Some lit' -> Some (E_aux (E_lit lit',(l,ann))) - | None -> None) - | _ -> None) - | _ -> None - else if is_id "bitvector_concat" then - match args with - | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit1,_), _); - E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit2,_), _)] -> - Some (E_aux (E_lit (L_aux (concat_vec lit1 lit2,new_l)),(l,ann))) - | _ -> None - else if is_id "shl_int" then - match args with - | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> - Some (E_aux (E_lit (L_aux (L_num (Big_int.shift_left i (Big_int.to_int j)),new_l)),(l,ann))) - | _ -> None - else if is_id "mult_atom" || is_id "mult_int" || is_id "mult_range" then - match args with - | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> - Some (E_aux (E_lit (L_aux (L_num (Big_int.mul i j),new_l)),(l,ann))) - | _ -> None - else if is_id "quotient_nat" then - match args with - | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> - Some (E_aux (E_lit (L_aux (L_num (Big_int.div i j),new_l)),(l,ann))) - | _ -> None - else if is_id "add_atom" || is_id "add_int" || is_id "add_range" then - match args with - | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] -> - Some (E_aux (E_lit (L_aux (L_num (Big_int.add i j),new_l)),(l,ann))) - | _ -> None - else if is_id "negate_range" then - match args with - | [E_aux (E_lit L_aux (L_num i,_),_)] -> - Some (E_aux (E_lit (L_aux (L_num (Big_int.negate i),new_l)),(l,ann))) - | _ -> None - else if is_id "ex_int" then - match args with - | [E_aux (E_lit lit,(l,_))] -> Some (E_aux (E_lit lit,(l,ann))) - | [E_aux (E_cast (_,(E_aux (E_lit (L_aux (L_undef,_)),_) as e)),(l,_))] -> - Some (reduce_cast (typ_of_annot (l,ann)) e l ann) - | _ -> None - else if is_id "vector_access" || is_id "bitvector_access" then - match args with - | [E_aux (E_lit L_aux ((L_hex _ | L_bin _) as lit,_),_); - E_aux (E_lit L_aux (L_num i,_),_)] -> - let v = int_of_str_lit lit in - let b = Big_int.bitwise_and (Big_int.shift_right v (Big_int.to_int i)) (Big_int.of_int 1) in - let lit' = if Big_int.equal b (Big_int.of_int 1) then L_one else L_zero in - Some (E_aux (E_lit (L_aux (lit',new_l)),(l,ann))) - | _ -> None - else None - - let construct_lit_vector args = let rec aux l = function | [] -> Some (L_aux (L_bin (String.concat "" (List.rev l)),Unknown)) @@ -361,10 +278,18 @@ let construct_lit_vector args = (* Add a cast to undefined so that it retains its type, otherwise it can't be substituted safely *) let keep_undef_typ value = - match value with - | E_aux (E_lit (L_aux (L_undef,lann)),eann) -> - E_aux (E_cast (typ_of_annot eann,value),(Generated Unknown,snd eann)) - | _ -> value + let e_aux (e, ann) = + match e with + | E_lit (L_aux (L_undef, _)) -> + (* Add cast to undefined... *) + E_aux (E_cast (typ_of_annot ann, E_aux (e, ann)), ann) + | E_cast (typ, E_aux (E_cast (_, e), _)) -> + (* ... unless there was a cast already *) + E_aux (E_cast (typ, e), ann) + | _ -> E_aux (e, ann) + in + let open Rewriter in + fold_exp { id_exp_alg with e_aux = e_aux } value (* Check whether the current environment with the given kid assignments is inconsistent (and hence whether the code is dead) *) @@ -375,6 +300,15 @@ let is_env_inconsistent env ksubsts = let const_props defs ref_vars = + let const_fold exp = + try + strip_exp exp + |> infer_exp (env_of exp) + |> Constant_fold.rewrite_exp_once (Constant_fold.initial_state defs) + |> keep_undef_typ + with + | _ -> exp + in let rec const_prop_exp substs assigns ((E_aux (e,(l,annot))) as exp) = (* Functions to treat lists and tuples of subexpressions as possibly non-deterministic: that is, we stop making any assumptions about @@ -414,7 +348,8 @@ let const_props defs ref_vars = let e4',_ = const_prop_exp substs assigns e4 in e1',e2',e3',e4',assigns in - let re e assigns = E_aux (e,(l,annot)),assigns in + let rewrap e = E_aux (e,(l,annot)) in + let re e assigns = rewrap e,assigns in match e with (* TODO: are there more circumstances in which we should get rid of these? *) | E_block [e] -> const_prop_exp substs assigns e @@ -444,12 +379,7 @@ let const_props defs ref_vars = | E_app (id,es) -> let es',assigns = non_det_exp_list es in let env = Type_check.env_of_annot (l, annot) in - (match try_app (l,annot) (id,es') with - | None -> - (match const_prop_try_fn l env (id,es') with - | None -> re (E_app (id,es')) assigns - | Some r -> r,assigns) - | Some r -> r,assigns) + const_prop_try_fn env (id, es') (l, annot), assigns | E_tuple es -> let es',assigns = non_det_exp_list es in re (E_tuple es') assigns @@ -539,10 +469,33 @@ let const_props defs ref_vars = let assigned_in = IdSet.union (assigned_vars_in_fexps fes) (assigned_vars e) in let assigns = isubst_minus_set assigns assigned_in in let e',_ = const_prop_exp substs assigns e in - re (E_record_update (e', const_prop_fexps substs assigns fes)) assigns + let fes' = const_prop_fexps substs assigns fes in + begin + match unaux_exp (fst (uncast_exp e')) with + | E_record (fes0) -> + let apply_fexp (FE_aux (FE_Fexp (id, e), _)) (FE_aux (FE_Fexp (id', e'), ann)) = + if Id.compare id id' = 0 then + FE_aux (FE_Fexp (id', e), ann) + else + FE_aux (FE_Fexp (id', e'), ann) + in + let update_fields fexp = List.map (apply_fexp fexp) in + let fes0' = List.fold_right update_fields fes' fes0 in + re (E_record fes0') assigns + | _ -> + re (E_record_update (e', fes')) assigns + end | E_field (e,id) -> let e',assigns = const_prop_exp substs assigns e in - re (E_field (e',id)) assigns + begin + let is_field (FE_aux (FE_Fexp (id', _), _)) = Id.compare id id' = 0 in + match unaux_exp e' with + | E_record fes0 when List.exists is_field fes0 -> + let (FE_aux (FE_Fexp (_, e), _)) = List.find is_field fes0 in + re (unaux_exp e) assigns + | _ -> + re (E_field (e',id)) assigns + end | E_case (e,cases) -> let e',assigns = const_prop_exp substs assigns e in (match can_match e' cases substs assigns with @@ -568,7 +521,7 @@ let const_props defs ref_vars = let e2',assigns = const_prop_exp substs' assigns e2 in re (E_let (LB_aux (LB_val (p,e'), annot), e2')) assigns in - if is_value e' && not (is_value e) then + if is_value e' then match can_match e' [Pat_aux (Pat_exp (p,e2),(Unknown,empty_tannot))] substs assigns with | None -> plain () | Some (e'',bindings,kbindings) -> @@ -581,10 +534,10 @@ let const_props defs ref_vars = (* TODO maybe - tuple assignments *) | E_assign (le,e) -> let env = Type_check.env_of_annot (l, annot) in + let e',_ = const_prop_exp substs assigns e in let assigned_in = IdSet.union (assigned_vars_in_lexp le) (assigned_vars e) in let assigns = isubst_minus_set assigns assigned_in in let le',idopt = const_prop_lexp substs assigns le in - let e',_ = const_prop_exp substs assigns e in let assigns = match idopt with | Some id -> @@ -653,48 +606,23 @@ let const_props defs ref_vars = | LEXP_field (le,id) -> re (LEXP_field (fst (const_prop_lexp substs assigns le), id)) | LEXP_deref e -> re (LEXP_deref (fst (const_prop_exp substs assigns e))) - (* Reduce a function when - 1. all arguments are values, - 2. the function is pure, - 3. the result is a value - (and 4. the function is not scattered, but that's not terribly important) - to try and keep execution time and the results managable. + (* Try to evaluate function calls with constant arguments via + (interpreter-based) constant folding. + Boolean connectives are special-cased to support short-circuiting when one + argument has a suitable value (even if the other one is not constant). *) - and const_prop_try_fn l env (id,args) = - if not (List.for_all is_value args) then - None - else - let (tq,typ) = Env.get_val_spec_orig id env in - let eff = match typ with - | Typ_aux (Typ_fn (_,_,eff),_) -> Some eff - | _ -> None - in - let Defs ds = defs in - match eff, list_extract (function - | (DEF_fundef (FD_aux (FD_function (_,_,eff,((FCL_aux (FCL_Funcl (id',_),_))::_ as fcls)),_))) - -> if Id.compare id id' = 0 then Some fcls else None - | _ -> None) ds with - | None,_ | _,None -> None - | Some eff,_ when not (is_pure eff) -> None - | Some _,Some fcls -> - let arg = match args with - | [] -> E_aux (E_lit (L_aux (L_unit,Generated l)),(Generated l,empty_tannot)) - | [e] -> e - | _ -> E_aux (E_tuple args,(Generated l,empty_tannot)) in - let cases = List.map (function - | FCL_aux (FCL_Funcl (_,pexp), ann) -> pexp) - fcls in - match can_match_with_env env arg cases (Bindings.empty,KBindings.empty) Bindings.empty with - | Some (exp,bindings,kbindings) -> - let substs = bindings_from_list bindings, kbindings_from_list kbindings in - let result,_ = const_prop_exp substs Bindings.empty exp in - let result = match result with - | E_aux (E_return e,_) -> e - | _ -> result - in - if is_value result then Some result else None - | None -> None - + and const_prop_try_fn env (id, args) (l, annot) = + match (string_of_id id, args) with + | "and_bool", ([E_aux (E_lit (L_aux (L_false, _)), _) as e_false; _] | + [_; E_aux (E_lit (L_aux (L_false, _)), _) as e_false]) -> + e_false + | "or_bool", ([E_aux (E_lit (L_aux (L_true, _)), _) as e_true; _] | + [_; E_aux (E_lit (L_aux (L_true, _)), _) as e_true]) -> + e_true + | _ -> + let exp = (E_aux (E_app (id, args), (l, annot))) in + if List.for_all Constant_fold.is_constant args then const_fold exp else exp + and can_match_with_env env (E_aux (e,(l,annot)) as exp0) cases (substs,ksubsts) assigns = let rec findpat_generic check_pat description assigns = function | [] -> (Reporting.print_err l "Monomorphisation" @@ -816,6 +744,8 @@ let const_props defs ref_vars = (Reporting.print_err l' "Monomorphisation" "Unexpected kind of pattern for literal"; GiveUp) in findpat_generic checkpat "literal" assigns cases + | E_record _ | E_cast (_, E_aux (E_record _, _)) -> + findpat_generic (fun _ -> DoesNotMatch) "record" assigns cases | _ -> None and can_match exp = diff --git a/src/constant_propagation_mutrec.ml b/src/constant_propagation_mutrec.ml new file mode 100644 index 00000000..683cc6f3 --- /dev/null +++ b/src/constant_propagation_mutrec.ml @@ -0,0 +1,232 @@ +(**************************************************************************) +(* Sail *) +(* *) +(* Copyright (c) 2013-2017 *) +(* Kathyrn Gray *) +(* Shaked Flur *) +(* Stephen Kell *) +(* Gabriel Kerneis *) +(* Robert Norton-Wright *) +(* Christopher Pulte *) +(* Peter Sewell *) +(* Alasdair Armstrong *) +(* Brian Campbell *) +(* Thomas Bauereiss *) +(* Anthony Fox *) +(* Jon French *) +(* Dominic Mulligan *) +(* Stephen Kell *) +(* Mark Wassell *) +(* *) +(* All rights reserved. *) +(* *) +(* This software was developed by the University of Cambridge Computer *) +(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *) +(* (REMS) project, funded by EPSRC grant EP/K008528/1. *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(**************************************************************************) + +open Ast +open Ast_util +open Type_check +open Rewriter + +(* Unroll mutually recursive calls, starting with the functions given as + targets on the command line, by looking for recursive calls with (some) + constant arguments, and creating copies of those functions with the + constants propagated in. This may cause branches with mutually recursively + calls to disappear, breaking the mutually recursive cycle. *) + +let targets = ref ([] : id list) + +let rec is_const_exp exp = match unaux_exp exp with + | E_lit (L_aux ((L_true | L_false | L_one | L_zero | L_num _), _)) -> true + | E_vector es -> List.for_all is_const_exp es && is_bitvector_typ (typ_of exp) + | E_record fes -> List.for_all is_const_fexp fes + | _ -> false +and is_const_fexp (FE_aux (FE_Fexp (_, e), _)) = is_const_exp e + +let recheck_exp exp = check_exp (env_of exp) (strip_exp exp) (typ_of exp) + +(* Name function copy by encoding values of constant arguments *) +let generate_fun_id id args = + let rec suffix exp = match unaux_exp exp with + | E_lit (L_aux (L_one, _)) -> "1" + | E_lit (L_aux (L_zero, _)) -> "0" + | E_lit (L_aux (L_true, _)) -> "T" + | E_lit (L_aux (L_false, _)) -> "F" + | E_record fes when is_const_exp exp -> + let fsuffix (FE_aux (FE_Fexp (id, e), _)) = suffix e + in + "struct" ^ + Util.zencode_string (string_of_typ (typ_of exp)) ^ + "#" ^ + String.concat "" (List.map fsuffix fes) + | E_vector es when is_const_exp exp -> + String.concat "" (List.map suffix es) + | _ -> + if is_const_exp exp + then "#" ^ Util.zencode_string (string_of_exp exp) + else "v" + in + append_id id ("#mutrec_" ^ String.concat "" (List.map suffix args)) + +(* Generate a val spec for a function copy, removing the constant arguments + that will be propagated in *) +let generate_val_spec env id args l annot = + match Env.get_val_spec_orig id env with + | tq, Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) -> + let orig_ksubst (kid, typ_arg) = + match typ_arg with + | A_aux ((A_nexp _ | A_bool _), _) -> (orig_kid kid, typ_arg) + | _ -> raise (Reporting.err_todo l "Propagation of polymorphic arguments not implemented") + in + let ksubsts = + recheck_exp (E_aux (E_app (id, args), (l, annot))) + |> instantiation_of + |> KBindings.bindings + |> List.map orig_ksubst + |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty + in + let ret_typ' = KBindings.fold typ_subst ksubsts ret_typ in + let arg_typs' = + List.map (KBindings.fold typ_subst ksubsts) arg_typs + |> List.map2 (fun arg typ -> if is_const_exp arg then [] else [typ]) args + |> List.concat + |> function [] -> [unit_typ] | typs -> typs + in + let typ' = mk_typ (Typ_fn (arg_typs', ret_typ', eff)) in + let tyvars = tyvars_of_typ typ' in + let tq' = + quant_items tq |> + List.filter (fun qi -> KidSet.subset (tyvars_of_quant_item qi) tyvars) |> + mk_typquant + in + let typschm = mk_typschm tq' typ' in + mk_val_spec (VS_val_spec (typschm, generate_fun_id id args, (fun _ -> None), false)), + ksubsts + | _, Typ_aux (_, l) -> + raise (Reporting.err_unreachable l __POS__ "Function val spec is not a function type") + +let const_prop defs substs ksubsts exp = + (* Constant_propagation currently only supports nexps for kid substitutions *) + let nexp_substs = + KBindings.bindings ksubsts + |> List.map (function (kid, A_aux (A_nexp n, _)) -> [(kid, n)] | _ -> []) + |> List.concat + |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty + in + Constant_propagation.const_prop + (Defs defs) + (Constant_propagation.referenced_vars exp) + (substs, nexp_substs) + Bindings.empty + exp + |> fst + +(* Propagate constant arguments into function clause pexp *) +let prop_args_pexp defs ksubsts args pexp = + let pat, guard, exp, annot = destruct_pexp pexp in + let pats = match pat with + | P_aux (P_tup pats, _) -> pats + | _ -> [pat] + in + let match_arg (E_aux (_, (l, _)) as arg) pat (pats, substs) = + if is_const_exp arg then + match pat with + | P_aux (P_id id, _) -> (pats, Bindings.add id arg substs) + | _ -> + raise (Reporting.err_todo l + ("Unsupported pattern match in propagation of constant arguments: " ^ + string_of_exp arg ^ " and " ^ string_of_pat pat)) + else (pat :: pats, substs) + in + let pats, substs = List.fold_right2 match_arg args pats ([], Bindings.empty) in + let exp' = const_prop defs substs ksubsts exp in + let pat' = match pats with + | [pat] -> pat + | _ -> P_aux (P_tup pats, (Parse_ast.Unknown, empty_tannot)) + in + construct_pexp (pat', guard, exp', annot) + +let rewrite_defs env (Defs defs) = + let rec rewrite = function + | [] -> [] + | DEF_internal_mutrec mutrecs :: ds -> + let mutrec_ids = IdSet.of_list (List.map id_of_fundef mutrecs) in + let valspecs = ref ([] : unit def list) in + let fundefs = ref ([] : unit def list) in + (* Try to replace mutually recursive calls that have some constant arguments *) + let rec e_app (id, args) (l, annot) = + if IdSet.mem id mutrec_ids && List.exists is_const_exp args then + let id' = generate_fun_id id args in + let args' = match List.filter (fun e -> not (is_const_exp e)) args with + | [] -> [infer_exp env (mk_lit_exp L_unit)] + | args' -> args' + in + if not (IdSet.mem id' (ids_of_defs (Defs !valspecs))) then begin + (* Generate copy of function with constant arguments propagated in *) + let (FD_aux (FD_function (_, _, _, fcls), _)) = + List.find (fun fd -> Id.compare id (id_of_fundef fd) = 0) mutrecs + in + let valspec, ksubsts = generate_val_spec env id args l annot in + let const_prop_funcl (FCL_aux (FCL_Funcl (_, pexp), (l, _))) = + let pexp' = + prop_args_pexp defs ksubsts args pexp + |> rewrite_pexp + |> strip_pexp + in + FCL_aux (FCL_Funcl (id', pexp'), (Parse_ast.Generated l, ())) + in + valspecs := valspec :: !valspecs; + let fundef = mk_fundef (List.map const_prop_funcl fcls) in + fundefs := fundef :: !fundefs + end else (); + E_aux (E_app (id', args'), (l, annot)) + else E_aux (E_app (id, args), (l, annot)) + and e_aux (e, (l, annot)) = + match e with + | E_app (id, args) -> e_app (id, args) (l, annot) + | _ -> E_aux (e, (l, annot)) + and rewrite_pexp pexp = fold_pexp { id_exp_alg with e_aux = e_aux } pexp + and rewrite_funcl (FCL_aux (FCL_Funcl (id, pexp), a) as funcl) = + let pexp' = + if List.exists (fun id' -> Id.compare id id' = 0) !targets then + let pat, guard, body, annot = destruct_pexp pexp in + let body' = const_prop defs Bindings.empty KBindings.empty body in + rewrite_pexp (construct_pexp (pat, guard, recheck_exp body', annot)) + else pexp + in FCL_aux (FCL_Funcl (id, pexp'), a) + and rewrite_fundef (FD_aux (FD_function (ropt, topt, eopt, fcls), a)) = + let fcls' = List.map rewrite_funcl fcls in + FD_aux (FD_function (ropt, topt, eopt, fcls'), a) + in + let mutrecs' = List.map (fun fd -> DEF_fundef (rewrite_fundef fd)) mutrecs in + let (Defs fdefs) = fst (check env (Defs (!valspecs @ !fundefs))) in + mutrecs' @ fdefs @ rewrite ds + | d :: ds -> + d :: rewrite ds + in + Spec_analysis.top_sort_defs (Defs (rewrite defs)) diff --git a/src/rewrites.ml b/src/rewrites.ml index 34b9388d..8bfbc351 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -4688,9 +4688,11 @@ let rewrite_defs_lem = [ ("fix_val_specs", rewrite_fix_val_specs); ("split_execute", rewrite_split_fun_ctor_pats "execute"); ("recheck_defs", recheck_defs); + ("top_sort_defs", fun _ -> top_sort_defs); + ("const_prop_mutrec", Constant_propagation_mutrec.rewrite_defs); + ("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list); ("exp_lift_assign", rewrite_defs_exp_lift_assign); (* ("remove_assert", rewrite_defs_remove_assert); *) - ("top_sort_defs", fun _ -> top_sort_defs); (* ("sizeof", rewrite_sizeof); *) ("early_return", rewrite_defs_early_return); ("fix_val_specs", rewrite_fix_val_specs); diff --git a/src/sail.ml b/src/sail.ml index 23836b1d..fa8f990b 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -277,6 +277,9 @@ let options = Arg.align ([ ( "-dmono_continue", Arg.Set Rewrites.opt_dmono_continue, " continue despite monomorphisation errors"); + ( "-const_prop_mutrec", + Arg.String (fun name -> Constant_propagation_mutrec.targets := Ast_util.mk_id name :: !Constant_propagation_mutrec.targets), + " unroll function in a set of mutually recursive functions"); ( "-verbose", Arg.Int (fun verbosity -> Util.opt_verbosity := verbosity), " produce verbose output"); |
