summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorThomas Bauereiss2019-03-14 18:34:49 +0000
committerThomas Bauereiss2019-03-15 18:47:30 +0000
commite92ff6875925c2fe8b6ebc95a6b328514abc0106 (patch)
tree24ef95facd542364e9578ec55532ff9b84a96e53 /src
parent11325d9bb5f4117c5b41413ac523b7d50577ebdd (diff)
Add a rewriting pass for constant propagation in mutrecs
Propagating constants into mutually recursive calls and removing dead branches might break mutually recursive cycles. Also make constant propagation use the existing interpreter-based constant folding to evaluate function calls with only constant arguments (as opposed to a mixture of inlining and hard-coded rewrite rules).
Diffstat (limited to 'src')
-rw-r--r--src/constant_fold.ml33
-rw-r--r--src/constant_propagation.ml216
-rw-r--r--src/constant_propagation_mutrec.ml232
-rw-r--r--src/rewrites.ml4
-rw-r--r--src/sail.ml3
5 files changed, 332 insertions, 156 deletions
diff --git a/src/constant_fold.ml b/src/constant_fold.ml
index f85fb673..fd9b322b 100644
--- a/src/constant_fold.ml
+++ b/src/constant_fold.ml
@@ -136,13 +136,17 @@ let fold_to_unit id =
in
IdSet.mem id remove
-let rec is_constant (E_aux (e_aux, _)) =
+let rec is_constant (E_aux (e_aux, _) as exp) =
match e_aux with
| E_lit _ -> true
| E_vector exps -> List.for_all is_constant exps
| E_record fexps -> List.for_all is_constant_fexp fexps
| E_cast (_, exp) -> is_constant exp
| E_tuple exps -> List.for_all is_constant exps
+ | E_id id ->
+ (match Env.lookup_id id (env_of exp) with
+ | Enum _ -> true
+ | _ -> false)
| _ -> false
and is_constant_fexp (FE_aux (FE_Fexp (_, exp), _)) = is_constant exp
@@ -173,21 +177,18 @@ let rec run frame =
- Throws an exception that isn't caught.
*)
-let rec rewrite_constant_function_calls' ast =
- let rewrite_count = ref 0 in
- let ok () = incr rewrite_count in
- let not_ok () = decr rewrite_count in
-
+let initial_state ast =
let lstate, gstate =
Interpreter.initial_state ast safe_primops
in
- let gstate = { gstate with Interpreter.allow_registers = false } in
+ (lstate, { gstate with Interpreter.allow_registers = false })
+let rw_exp ok not_ok istate =
let evaluate e_aux annot =
let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in
try
begin
- let v = run (Interpreter.Step (lazy "", (lstate, gstate), initial_monad, [])) in
+ let v = run (Interpreter.Step (lazy "", istate, initial_monad, [])) in
let exp = exp_of_value v in
try (ok (); Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot)) with
| Type_error (env, l, err) ->
@@ -231,11 +232,19 @@ let rec rewrite_constant_function_calls' ast =
| _ -> E_aux (e_aux, annot)
in
- let rw_exp = {
- id_exp_alg with
- e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot)
+ fold_exp { id_exp_alg with e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot)}
+
+let rewrite_exp_once = rw_exp (fun _ -> ()) (fun _ -> ())
+
+let rec rewrite_constant_function_calls' ast =
+ let rewrite_count = ref 0 in
+ let ok () = incr rewrite_count in
+ let not_ok () = decr rewrite_count in
+
+ let rw_defs = {
+ rewriters_base with
+ rewrite_exp = (fun _ -> rw_exp ok not_ok (initial_state ast))
} in
- let rw_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp rw_exp) } in
let ast = rewrite_defs_base rw_defs ast in
(* We keep iterating until we have no more re-writes to do *)
if !rewrite_count > 0
diff --git a/src/constant_propagation.ml b/src/constant_propagation.ml
index 33b67008..3ae46657 100644
--- a/src/constant_propagation.ml
+++ b/src/constant_propagation.ml
@@ -111,9 +111,13 @@ let rec is_value (E_aux (e,(l,annot))) =
| E_id id -> is_constructor id
| E_lit _ -> true
| E_tuple es -> List.for_all is_value es
+ | E_record fes ->
+ List.for_all (fun (FE_aux (FE_Fexp (_, e), _)) -> is_value e) fes
| E_app (id,es) -> is_constructor id && List.for_all is_value es
(* We add casts to undefined to keep the type information in the AST *)
| E_cast (typ,E_aux (E_lit (L_aux (L_undef,_)),_)) -> true
+ (* Also keep casts around records, as type inference fails without *)
+ | E_cast (_, (E_aux (E_record _, _) as e')) -> is_value e'
(* TODO: more? *)
| _ -> false
@@ -263,93 +267,6 @@ let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) =
| L_num i1, L_num i2 -> Some (Big_int.equal i1 i2)
| _ -> Some (l1 = l2)
-let try_app (l,ann) (id,args) =
- let new_l = Parse_ast.Generated l in
- let env = env_of_annot (l,ann) in
- let get_overloads f = List.map string_of_id
- (Env.get_overloads (Id_aux (Id f, Parse_ast.Unknown)) env @
- Env.get_overloads (Id_aux (DeIid f, Parse_ast.Unknown)) env) in
- let is_id f = List.mem (string_of_id id) (f :: get_overloads f) in
- if is_id "==" || is_id "!=" then
- match args with
- | [E_aux (E_lit l1,_); E_aux (E_lit l2,_)] ->
- let lit b = if b then L_true else L_false in
- let lit b = lit (if is_id "==" then b else not b) in
- (match lit_eq l1 l2 with
- | None -> None
- | Some b -> Some (E_aux (E_lit (L_aux (lit b,new_l)),(l,ann))))
- | _ -> None
- else if is_id "cast_bit_bool" then
- match args with
- | [E_aux (E_lit L_aux (L_zero,_),_)] -> Some (E_aux (E_lit (L_aux (L_false,new_l)),(l,ann)))
- | [E_aux (E_lit L_aux (L_one ,_),_)] -> Some (E_aux (E_lit (L_aux (L_true ,new_l)),(l,ann)))
- | _ -> None
- else if is_id "UInt" || is_id "unsigned" then
- match args with
- | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit,_), _)] ->
- Some (E_aux (E_lit (L_aux (L_num (int_of_str_lit lit),new_l)),(l,ann)))
- | _ -> None
- else if is_id "slice" then
- match args with
- | [E_aux (E_lit (L_aux ((L_hex _| L_bin _),_) as lit), annot);
- E_aux (E_lit L_aux (L_num i,_), _);
- E_aux (E_lit L_aux (L_num len,_), _)] ->
- (match Env.base_typ_of (env_of_annot annot) (typ_of_annot annot) with
- | Typ_aux (Typ_app (_,[_;A_aux (A_order ord,_);_]),_) ->
- (match slice_lit lit i len ord with
- | Some lit' -> Some (E_aux (E_lit lit',(l,ann)))
- | None -> None)
- | _ -> None)
- | _ -> None
- else if is_id "bitvector_concat" then
- match args with
- | [E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit1,_), _);
- E_aux (E_lit L_aux ((L_hex _| L_bin _) as lit2,_), _)] ->
- Some (E_aux (E_lit (L_aux (concat_vec lit1 lit2,new_l)),(l,ann)))
- | _ -> None
- else if is_id "shl_int" then
- match args with
- | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
- Some (E_aux (E_lit (L_aux (L_num (Big_int.shift_left i (Big_int.to_int j)),new_l)),(l,ann)))
- | _ -> None
- else if is_id "mult_atom" || is_id "mult_int" || is_id "mult_range" then
- match args with
- | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
- Some (E_aux (E_lit (L_aux (L_num (Big_int.mul i j),new_l)),(l,ann)))
- | _ -> None
- else if is_id "quotient_nat" then
- match args with
- | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
- Some (E_aux (E_lit (L_aux (L_num (Big_int.div i j),new_l)),(l,ann)))
- | _ -> None
- else if is_id "add_atom" || is_id "add_int" || is_id "add_range" then
- match args with
- | [E_aux (E_lit L_aux (L_num i,_),_); E_aux (E_lit L_aux (L_num j,_),_)] ->
- Some (E_aux (E_lit (L_aux (L_num (Big_int.add i j),new_l)),(l,ann)))
- | _ -> None
- else if is_id "negate_range" then
- match args with
- | [E_aux (E_lit L_aux (L_num i,_),_)] ->
- Some (E_aux (E_lit (L_aux (L_num (Big_int.negate i),new_l)),(l,ann)))
- | _ -> None
- else if is_id "ex_int" then
- match args with
- | [E_aux (E_lit lit,(l,_))] -> Some (E_aux (E_lit lit,(l,ann)))
- | [E_aux (E_cast (_,(E_aux (E_lit (L_aux (L_undef,_)),_) as e)),(l,_))] ->
- Some (reduce_cast (typ_of_annot (l,ann)) e l ann)
- | _ -> None
- else if is_id "vector_access" || is_id "bitvector_access" then
- match args with
- | [E_aux (E_lit L_aux ((L_hex _ | L_bin _) as lit,_),_);
- E_aux (E_lit L_aux (L_num i,_),_)] ->
- let v = int_of_str_lit lit in
- let b = Big_int.bitwise_and (Big_int.shift_right v (Big_int.to_int i)) (Big_int.of_int 1) in
- let lit' = if Big_int.equal b (Big_int.of_int 1) then L_one else L_zero in
- Some (E_aux (E_lit (L_aux (lit',new_l)),(l,ann)))
- | _ -> None
- else None
-
-
let construct_lit_vector args =
let rec aux l = function
| [] -> Some (L_aux (L_bin (String.concat "" (List.rev l)),Unknown))
@@ -361,10 +278,18 @@ let construct_lit_vector args =
(* Add a cast to undefined so that it retains its type, otherwise it can't be
substituted safely *)
let keep_undef_typ value =
- match value with
- | E_aux (E_lit (L_aux (L_undef,lann)),eann) ->
- E_aux (E_cast (typ_of_annot eann,value),(Generated Unknown,snd eann))
- | _ -> value
+ let e_aux (e, ann) =
+ match e with
+ | E_lit (L_aux (L_undef, _)) ->
+ (* Add cast to undefined... *)
+ E_aux (E_cast (typ_of_annot ann, E_aux (e, ann)), ann)
+ | E_cast (typ, E_aux (E_cast (_, e), _)) ->
+ (* ... unless there was a cast already *)
+ E_aux (E_cast (typ, e), ann)
+ | _ -> E_aux (e, ann)
+ in
+ let open Rewriter in
+ fold_exp { id_exp_alg with e_aux = e_aux } value
(* Check whether the current environment with the given kid assignments is
inconsistent (and hence whether the code is dead) *)
@@ -375,6 +300,15 @@ let is_env_inconsistent env ksubsts =
let const_props defs ref_vars =
+ let const_fold exp =
+ try
+ strip_exp exp
+ |> infer_exp (env_of exp)
+ |> Constant_fold.rewrite_exp_once (Constant_fold.initial_state defs)
+ |> keep_undef_typ
+ with
+ | _ -> exp
+ in
let rec const_prop_exp substs assigns ((E_aux (e,(l,annot))) as exp) =
(* Functions to treat lists and tuples of subexpressions as possibly
non-deterministic: that is, we stop making any assumptions about
@@ -414,7 +348,8 @@ let const_props defs ref_vars =
let e4',_ = const_prop_exp substs assigns e4 in
e1',e2',e3',e4',assigns
in
- let re e assigns = E_aux (e,(l,annot)),assigns in
+ let rewrap e = E_aux (e,(l,annot)) in
+ let re e assigns = rewrap e,assigns in
match e with
(* TODO: are there more circumstances in which we should get rid of these? *)
| E_block [e] -> const_prop_exp substs assigns e
@@ -444,12 +379,7 @@ let const_props defs ref_vars =
| E_app (id,es) ->
let es',assigns = non_det_exp_list es in
let env = Type_check.env_of_annot (l, annot) in
- (match try_app (l,annot) (id,es') with
- | None ->
- (match const_prop_try_fn l env (id,es') with
- | None -> re (E_app (id,es')) assigns
- | Some r -> r,assigns)
- | Some r -> r,assigns)
+ const_prop_try_fn env (id, es') (l, annot), assigns
| E_tuple es ->
let es',assigns = non_det_exp_list es in
re (E_tuple es') assigns
@@ -539,10 +469,33 @@ let const_props defs ref_vars =
let assigned_in = IdSet.union (assigned_vars_in_fexps fes) (assigned_vars e) in
let assigns = isubst_minus_set assigns assigned_in in
let e',_ = const_prop_exp substs assigns e in
- re (E_record_update (e', const_prop_fexps substs assigns fes)) assigns
+ let fes' = const_prop_fexps substs assigns fes in
+ begin
+ match unaux_exp (fst (uncast_exp e')) with
+ | E_record (fes0) ->
+ let apply_fexp (FE_aux (FE_Fexp (id, e), _)) (FE_aux (FE_Fexp (id', e'), ann)) =
+ if Id.compare id id' = 0 then
+ FE_aux (FE_Fexp (id', e), ann)
+ else
+ FE_aux (FE_Fexp (id', e'), ann)
+ in
+ let update_fields fexp = List.map (apply_fexp fexp) in
+ let fes0' = List.fold_right update_fields fes' fes0 in
+ re (E_record fes0') assigns
+ | _ ->
+ re (E_record_update (e', fes')) assigns
+ end
| E_field (e,id) ->
let e',assigns = const_prop_exp substs assigns e in
- re (E_field (e',id)) assigns
+ begin
+ let is_field (FE_aux (FE_Fexp (id', _), _)) = Id.compare id id' = 0 in
+ match unaux_exp e' with
+ | E_record fes0 when List.exists is_field fes0 ->
+ let (FE_aux (FE_Fexp (_, e), _)) = List.find is_field fes0 in
+ re (unaux_exp e) assigns
+ | _ ->
+ re (E_field (e',id)) assigns
+ end
| E_case (e,cases) ->
let e',assigns = const_prop_exp substs assigns e in
(match can_match e' cases substs assigns with
@@ -568,7 +521,7 @@ let const_props defs ref_vars =
let e2',assigns = const_prop_exp substs' assigns e2 in
re (E_let (LB_aux (LB_val (p,e'), annot),
e2')) assigns in
- if is_value e' && not (is_value e) then
+ if is_value e' then
match can_match e' [Pat_aux (Pat_exp (p,e2),(Unknown,empty_tannot))] substs assigns with
| None -> plain ()
| Some (e'',bindings,kbindings) ->
@@ -581,10 +534,10 @@ let const_props defs ref_vars =
(* TODO maybe - tuple assignments *)
| E_assign (le,e) ->
let env = Type_check.env_of_annot (l, annot) in
+ let e',_ = const_prop_exp substs assigns e in
let assigned_in = IdSet.union (assigned_vars_in_lexp le) (assigned_vars e) in
let assigns = isubst_minus_set assigns assigned_in in
let le',idopt = const_prop_lexp substs assigns le in
- let e',_ = const_prop_exp substs assigns e in
let assigns =
match idopt with
| Some id ->
@@ -653,48 +606,23 @@ let const_props defs ref_vars =
| LEXP_field (le,id) -> re (LEXP_field (fst (const_prop_lexp substs assigns le), id))
| LEXP_deref e ->
re (LEXP_deref (fst (const_prop_exp substs assigns e)))
- (* Reduce a function when
- 1. all arguments are values,
- 2. the function is pure,
- 3. the result is a value
- (and 4. the function is not scattered, but that's not terribly important)
- to try and keep execution time and the results managable.
+ (* Try to evaluate function calls with constant arguments via
+ (interpreter-based) constant folding.
+ Boolean connectives are special-cased to support short-circuiting when one
+ argument has a suitable value (even if the other one is not constant).
*)
- and const_prop_try_fn l env (id,args) =
- if not (List.for_all is_value args) then
- None
- else
- let (tq,typ) = Env.get_val_spec_orig id env in
- let eff = match typ with
- | Typ_aux (Typ_fn (_,_,eff),_) -> Some eff
- | _ -> None
- in
- let Defs ds = defs in
- match eff, list_extract (function
- | (DEF_fundef (FD_aux (FD_function (_,_,eff,((FCL_aux (FCL_Funcl (id',_),_))::_ as fcls)),_)))
- -> if Id.compare id id' = 0 then Some fcls else None
- | _ -> None) ds with
- | None,_ | _,None -> None
- | Some eff,_ when not (is_pure eff) -> None
- | Some _,Some fcls ->
- let arg = match args with
- | [] -> E_aux (E_lit (L_aux (L_unit,Generated l)),(Generated l,empty_tannot))
- | [e] -> e
- | _ -> E_aux (E_tuple args,(Generated l,empty_tannot)) in
- let cases = List.map (function
- | FCL_aux (FCL_Funcl (_,pexp), ann) -> pexp)
- fcls in
- match can_match_with_env env arg cases (Bindings.empty,KBindings.empty) Bindings.empty with
- | Some (exp,bindings,kbindings) ->
- let substs = bindings_from_list bindings, kbindings_from_list kbindings in
- let result,_ = const_prop_exp substs Bindings.empty exp in
- let result = match result with
- | E_aux (E_return e,_) -> e
- | _ -> result
- in
- if is_value result then Some result else None
- | None -> None
-
+ and const_prop_try_fn env (id, args) (l, annot) =
+ match (string_of_id id, args) with
+ | "and_bool", ([E_aux (E_lit (L_aux (L_false, _)), _) as e_false; _] |
+ [_; E_aux (E_lit (L_aux (L_false, _)), _) as e_false]) ->
+ e_false
+ | "or_bool", ([E_aux (E_lit (L_aux (L_true, _)), _) as e_true; _] |
+ [_; E_aux (E_lit (L_aux (L_true, _)), _) as e_true]) ->
+ e_true
+ | _ ->
+ let exp = (E_aux (E_app (id, args), (l, annot))) in
+ if List.for_all Constant_fold.is_constant args then const_fold exp else exp
+
and can_match_with_env env (E_aux (e,(l,annot)) as exp0) cases (substs,ksubsts) assigns =
let rec findpat_generic check_pat description assigns = function
| [] -> (Reporting.print_err l "Monomorphisation"
@@ -816,6 +744,8 @@ let const_props defs ref_vars =
(Reporting.print_err l' "Monomorphisation"
"Unexpected kind of pattern for literal"; GiveUp)
in findpat_generic checkpat "literal" assigns cases
+ | E_record _ | E_cast (_, E_aux (E_record _, _)) ->
+ findpat_generic (fun _ -> DoesNotMatch) "record" assigns cases
| _ -> None
and can_match exp =
diff --git a/src/constant_propagation_mutrec.ml b/src/constant_propagation_mutrec.ml
new file mode 100644
index 00000000..683cc6f3
--- /dev/null
+++ b/src/constant_propagation_mutrec.ml
@@ -0,0 +1,232 @@
+(**************************************************************************)
+(* Sail *)
+(* *)
+(* Copyright (c) 2013-2017 *)
+(* Kathyrn Gray *)
+(* Shaked Flur *)
+(* Stephen Kell *)
+(* Gabriel Kerneis *)
+(* Robert Norton-Wright *)
+(* Christopher Pulte *)
+(* Peter Sewell *)
+(* Alasdair Armstrong *)
+(* Brian Campbell *)
+(* Thomas Bauereiss *)
+(* Anthony Fox *)
+(* Jon French *)
+(* Dominic Mulligan *)
+(* Stephen Kell *)
+(* Mark Wassell *)
+(* *)
+(* All rights reserved. *)
+(* *)
+(* This software was developed by the University of Cambridge Computer *)
+(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *)
+(* (REMS) project, funded by EPSRC grant EP/K008528/1. *)
+(* *)
+(* Redistribution and use in source and binary forms, with or without *)
+(* modification, are permitted provided that the following conditions *)
+(* are met: *)
+(* 1. Redistributions of source code must retain the above copyright *)
+(* notice, this list of conditions and the following disclaimer. *)
+(* 2. Redistributions in binary form must reproduce the above copyright *)
+(* notice, this list of conditions and the following disclaimer in *)
+(* the documentation and/or other materials provided with the *)
+(* distribution. *)
+(* *)
+(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *)
+(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *)
+(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *)
+(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *)
+(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *)
+(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *)
+(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *)
+(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *)
+(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *)
+(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *)
+(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *)
+(* SUCH DAMAGE. *)
+(**************************************************************************)
+
+open Ast
+open Ast_util
+open Type_check
+open Rewriter
+
+(* Unroll mutually recursive calls, starting with the functions given as
+ targets on the command line, by looking for recursive calls with (some)
+ constant arguments, and creating copies of those functions with the
+ constants propagated in. This may cause branches with mutually recursively
+ calls to disappear, breaking the mutually recursive cycle. *)
+
+let targets = ref ([] : id list)
+
+let rec is_const_exp exp = match unaux_exp exp with
+ | E_lit (L_aux ((L_true | L_false | L_one | L_zero | L_num _), _)) -> true
+ | E_vector es -> List.for_all is_const_exp es && is_bitvector_typ (typ_of exp)
+ | E_record fes -> List.for_all is_const_fexp fes
+ | _ -> false
+and is_const_fexp (FE_aux (FE_Fexp (_, e), _)) = is_const_exp e
+
+let recheck_exp exp = check_exp (env_of exp) (strip_exp exp) (typ_of exp)
+
+(* Name function copy by encoding values of constant arguments *)
+let generate_fun_id id args =
+ let rec suffix exp = match unaux_exp exp with
+ | E_lit (L_aux (L_one, _)) -> "1"
+ | E_lit (L_aux (L_zero, _)) -> "0"
+ | E_lit (L_aux (L_true, _)) -> "T"
+ | E_lit (L_aux (L_false, _)) -> "F"
+ | E_record fes when is_const_exp exp ->
+ let fsuffix (FE_aux (FE_Fexp (id, e), _)) = suffix e
+ in
+ "struct" ^
+ Util.zencode_string (string_of_typ (typ_of exp)) ^
+ "#" ^
+ String.concat "" (List.map fsuffix fes)
+ | E_vector es when is_const_exp exp ->
+ String.concat "" (List.map suffix es)
+ | _ ->
+ if is_const_exp exp
+ then "#" ^ Util.zencode_string (string_of_exp exp)
+ else "v"
+ in
+ append_id id ("#mutrec_" ^ String.concat "" (List.map suffix args))
+
+(* Generate a val spec for a function copy, removing the constant arguments
+ that will be propagated in *)
+let generate_val_spec env id args l annot =
+ match Env.get_val_spec_orig id env with
+ | tq, Typ_aux (Typ_fn (arg_typs, ret_typ, eff), _) ->
+ let orig_ksubst (kid, typ_arg) =
+ match typ_arg with
+ | A_aux ((A_nexp _ | A_bool _), _) -> (orig_kid kid, typ_arg)
+ | _ -> raise (Reporting.err_todo l "Propagation of polymorphic arguments not implemented")
+ in
+ let ksubsts =
+ recheck_exp (E_aux (E_app (id, args), (l, annot)))
+ |> instantiation_of
+ |> KBindings.bindings
+ |> List.map orig_ksubst
+ |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty
+ in
+ let ret_typ' = KBindings.fold typ_subst ksubsts ret_typ in
+ let arg_typs' =
+ List.map (KBindings.fold typ_subst ksubsts) arg_typs
+ |> List.map2 (fun arg typ -> if is_const_exp arg then [] else [typ]) args
+ |> List.concat
+ |> function [] -> [unit_typ] | typs -> typs
+ in
+ let typ' = mk_typ (Typ_fn (arg_typs', ret_typ', eff)) in
+ let tyvars = tyvars_of_typ typ' in
+ let tq' =
+ quant_items tq |>
+ List.filter (fun qi -> KidSet.subset (tyvars_of_quant_item qi) tyvars) |>
+ mk_typquant
+ in
+ let typschm = mk_typschm tq' typ' in
+ mk_val_spec (VS_val_spec (typschm, generate_fun_id id args, (fun _ -> None), false)),
+ ksubsts
+ | _, Typ_aux (_, l) ->
+ raise (Reporting.err_unreachable l __POS__ "Function val spec is not a function type")
+
+let const_prop defs substs ksubsts exp =
+ (* Constant_propagation currently only supports nexps for kid substitutions *)
+ let nexp_substs =
+ KBindings.bindings ksubsts
+ |> List.map (function (kid, A_aux (A_nexp n, _)) -> [(kid, n)] | _ -> [])
+ |> List.concat
+ |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty
+ in
+ Constant_propagation.const_prop
+ (Defs defs)
+ (Constant_propagation.referenced_vars exp)
+ (substs, nexp_substs)
+ Bindings.empty
+ exp
+ |> fst
+
+(* Propagate constant arguments into function clause pexp *)
+let prop_args_pexp defs ksubsts args pexp =
+ let pat, guard, exp, annot = destruct_pexp pexp in
+ let pats = match pat with
+ | P_aux (P_tup pats, _) -> pats
+ | _ -> [pat]
+ in
+ let match_arg (E_aux (_, (l, _)) as arg) pat (pats, substs) =
+ if is_const_exp arg then
+ match pat with
+ | P_aux (P_id id, _) -> (pats, Bindings.add id arg substs)
+ | _ ->
+ raise (Reporting.err_todo l
+ ("Unsupported pattern match in propagation of constant arguments: " ^
+ string_of_exp arg ^ " and " ^ string_of_pat pat))
+ else (pat :: pats, substs)
+ in
+ let pats, substs = List.fold_right2 match_arg args pats ([], Bindings.empty) in
+ let exp' = const_prop defs substs ksubsts exp in
+ let pat' = match pats with
+ | [pat] -> pat
+ | _ -> P_aux (P_tup pats, (Parse_ast.Unknown, empty_tannot))
+ in
+ construct_pexp (pat', guard, exp', annot)
+
+let rewrite_defs env (Defs defs) =
+ let rec rewrite = function
+ | [] -> []
+ | DEF_internal_mutrec mutrecs :: ds ->
+ let mutrec_ids = IdSet.of_list (List.map id_of_fundef mutrecs) in
+ let valspecs = ref ([] : unit def list) in
+ let fundefs = ref ([] : unit def list) in
+ (* Try to replace mutually recursive calls that have some constant arguments *)
+ let rec e_app (id, args) (l, annot) =
+ if IdSet.mem id mutrec_ids && List.exists is_const_exp args then
+ let id' = generate_fun_id id args in
+ let args' = match List.filter (fun e -> not (is_const_exp e)) args with
+ | [] -> [infer_exp env (mk_lit_exp L_unit)]
+ | args' -> args'
+ in
+ if not (IdSet.mem id' (ids_of_defs (Defs !valspecs))) then begin
+ (* Generate copy of function with constant arguments propagated in *)
+ let (FD_aux (FD_function (_, _, _, fcls), _)) =
+ List.find (fun fd -> Id.compare id (id_of_fundef fd) = 0) mutrecs
+ in
+ let valspec, ksubsts = generate_val_spec env id args l annot in
+ let const_prop_funcl (FCL_aux (FCL_Funcl (_, pexp), (l, _))) =
+ let pexp' =
+ prop_args_pexp defs ksubsts args pexp
+ |> rewrite_pexp
+ |> strip_pexp
+ in
+ FCL_aux (FCL_Funcl (id', pexp'), (Parse_ast.Generated l, ()))
+ in
+ valspecs := valspec :: !valspecs;
+ let fundef = mk_fundef (List.map const_prop_funcl fcls) in
+ fundefs := fundef :: !fundefs
+ end else ();
+ E_aux (E_app (id', args'), (l, annot))
+ else E_aux (E_app (id, args), (l, annot))
+ and e_aux (e, (l, annot)) =
+ match e with
+ | E_app (id, args) -> e_app (id, args) (l, annot)
+ | _ -> E_aux (e, (l, annot))
+ and rewrite_pexp pexp = fold_pexp { id_exp_alg with e_aux = e_aux } pexp
+ and rewrite_funcl (FCL_aux (FCL_Funcl (id, pexp), a) as funcl) =
+ let pexp' =
+ if List.exists (fun id' -> Id.compare id id' = 0) !targets then
+ let pat, guard, body, annot = destruct_pexp pexp in
+ let body' = const_prop defs Bindings.empty KBindings.empty body in
+ rewrite_pexp (construct_pexp (pat, guard, recheck_exp body', annot))
+ else pexp
+ in FCL_aux (FCL_Funcl (id, pexp'), a)
+ and rewrite_fundef (FD_aux (FD_function (ropt, topt, eopt, fcls), a)) =
+ let fcls' = List.map rewrite_funcl fcls in
+ FD_aux (FD_function (ropt, topt, eopt, fcls'), a)
+ in
+ let mutrecs' = List.map (fun fd -> DEF_fundef (rewrite_fundef fd)) mutrecs in
+ let (Defs fdefs) = fst (check env (Defs (!valspecs @ !fundefs))) in
+ mutrecs' @ fdefs @ rewrite ds
+ | d :: ds ->
+ d :: rewrite ds
+ in
+ Spec_analysis.top_sort_defs (Defs (rewrite defs))
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 34b9388d..8bfbc351 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -4688,9 +4688,11 @@ let rewrite_defs_lem = [
("fix_val_specs", rewrite_fix_val_specs);
("split_execute", rewrite_split_fun_ctor_pats "execute");
("recheck_defs", recheck_defs);
+ ("top_sort_defs", fun _ -> top_sort_defs);
+ ("const_prop_mutrec", Constant_propagation_mutrec.rewrite_defs);
+ ("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list);
("exp_lift_assign", rewrite_defs_exp_lift_assign);
(* ("remove_assert", rewrite_defs_remove_assert); *)
- ("top_sort_defs", fun _ -> top_sort_defs);
(* ("sizeof", rewrite_sizeof); *)
("early_return", rewrite_defs_early_return);
("fix_val_specs", rewrite_fix_val_specs);
diff --git a/src/sail.ml b/src/sail.ml
index 23836b1d..fa8f990b 100644
--- a/src/sail.ml
+++ b/src/sail.ml
@@ -277,6 +277,9 @@ let options = Arg.align ([
( "-dmono_continue",
Arg.Set Rewrites.opt_dmono_continue,
" continue despite monomorphisation errors");
+ ( "-const_prop_mutrec",
+ Arg.String (fun name -> Constant_propagation_mutrec.targets := Ast_util.mk_id name :: !Constant_propagation_mutrec.targets),
+ " unroll function in a set of mutually recursive functions");
( "-verbose",
Arg.Int (fun verbosity -> Util.opt_verbosity := verbosity),
" produce verbose output");