diff options
| author | Alasdair | 2018-06-21 03:19:11 +0100 |
|---|---|---|
| committer | Alasdair | 2018-06-29 23:45:31 +0100 |
| commit | 668ae1bc10fc1d9ea4d62ae0a2708a52cd83e211 (patch) | |
| tree | 3f461ebea3b8bf74b24b1491091d1f0b0b1d6e56 /src | |
| parent | b5424eea9c13935680b4cd6b76a377ff524699cd (diff) | |
Constant folding improvements
Diffstat (limited to 'src')
| -rw-r--r-- | src/constant_fold.ml | 101 | ||||
| -rw-r--r-- | src/interpreter.ml | 2 | ||||
| -rw-r--r-- | src/process_file.ml | 2 | ||||
| -rw-r--r-- | src/type_check.ml | 17 | ||||
| -rw-r--r-- | src/value2.lem | 20 |
5 files changed, 84 insertions, 58 deletions
diff --git a/src/constant_fold.ml b/src/constant_fold.ml index 7a35226e..45d3efe0 100644 --- a/src/constant_fold.ml +++ b/src/constant_fold.ml @@ -59,10 +59,25 @@ module StringMap = Map.Make(String);; false = no folding, true = perform constant folding. *) let optimize_constant_fold = ref false -let exp_of_value = +let rec fexp_of_ctor (field, value) = + FE_aux (FE_Fexp (mk_id field, exp_of_value value), no_annot) + +and exp_of_value = let open Value in function | V_int n -> mk_lit_exp (L_num n) + | V_bit Sail_lib.B0 -> mk_lit_exp L_zero + | V_bit Sail_lib.B1 -> mk_lit_exp L_one + | V_bool true -> mk_lit_exp L_true + | V_bool false -> mk_lit_exp L_false + | V_string str -> mk_lit_exp (L_string str) + | V_record ctors -> + mk_exp (E_record (FES_aux (FES_Fexps (List.map fexp_of_ctor (StringMap.bindings ctors), false), no_annot))) + | V_vector vs -> + mk_exp (E_vector (List.map exp_of_value vs)) + | V_tuple vs -> + mk_exp (E_tuple (List.map exp_of_value vs)) + | V_unit -> mk_lit_exp L_unit | _ -> failwith "No expression for value" (* We want to avoid evaluating things like print statements at compile @@ -85,15 +100,23 @@ let safe_primops = "Elf_loader.elf_tohost" ] -let is_literal = function - | E_aux (E_lit _, _) -> true +let rec is_constant (E_aux (e_aux, _)) = + match e_aux with + | E_lit _ -> true + | E_vector exps -> List.for_all is_constant exps + | E_record (FES_aux (FES_Fexps (fexps, _), _)) -> List.for_all is_constant_fexp fexps + | E_cast (_, exp) -> is_constant exp + | E_tuple exps -> List.for_all is_constant exps | _ -> false +and is_constant_fexp (FE_aux (FE_Fexp (_, exp), _)) = is_constant exp + (* Wrapper around interpreter that repeatedly steps until done. *) let rec run ast frame = match frame with | Interpreter.Done (state, v) -> v - | Interpreter.Step _ -> + | Interpreter.Step (lazy_str, _, _, _) -> + prerr_endline (Lazy.force lazy_str); run ast (Interpreter.eval_frame ast frame) | Interpreter.Break frame -> run ast (Interpreter.eval_frame ast frame) @@ -115,35 +138,57 @@ let rec run ast frame = - Throws an exception that isn't caught. *) -let rewrite_constant_function_calls' ast = +let rec rewrite_constant_function_calls' ast = + let rewrite_count = ref 0 in + let ok () = incr rewrite_count in + let not_ok () = decr rewrite_count in + let lstate, gstate = Interpreter.initial_state ast safe_primops in let gstate = { gstate with Interpreter.allow_registers = false } in + let evaluate e_aux annot = + let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in + try + begin + let v = run ast (Interpreter.Step (lazy "", (lstate, gstate), initial_monad, [])) in + let exp = exp_of_value v in + try (ok (); Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot)) with + | Type_error (l, err) -> + (* A type error here would be unexpected, so don't ignore it! *) + Util.warn ("Type error when folding constants in " + ^ string_of_exp (E_aux (e_aux, annot)) + ^ "\n" ^ Type_error.string_of_type_error err); + not_ok (); + E_aux (e_aux, annot) + end + with + (* Otherwise if anything goes wrong when trying to constant + fold, just continue without optimising. *) + | _ -> E_aux (e_aux, annot) + in + let rw_funcall e_aux annot = match e_aux with - | E_app (id, args) when List.for_all is_literal args -> - begin - let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in - try - begin - let v = run ast (Interpreter.Step (lazy "", (lstate, gstate), initial_monad, [])) in - let exp = exp_of_value v in - try Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot) with - | Type_error (l, err) -> - (* A type error here would be unexpected, so don't ignore it! *) - Util.warn ("Type error when folding constants in " - ^ string_of_exp (E_aux (e_aux, annot)) - ^ "\n" ^ Type_error.string_of_type_error err); - E_aux (e_aux, annot) - end - with - (* Otherwise if anything goes wrong when trying to constant - fold, just continue without optimising. *) - | _ -> E_aux (e_aux, annot) - end + | E_app (id, args) when List.for_all is_constant args -> + evaluate e_aux annot + + | E_field (exp, id) when is_constant exp -> + evaluate e_aux annot + + | E_if (E_aux (E_lit (L_aux (L_true, _)), _), then_exp, _) -> ok (); then_exp + | E_if (E_aux (E_lit (L_aux (L_false, _)), _), _, else_exp) -> ok (); else_exp + + | E_let (LB_aux (LB_val (P_aux (P_id id, _), bind), _), exp) when is_constant bind -> + ok (); + subst id bind exp + | E_let (LB_aux (LB_val (P_aux (P_typ (typ, P_aux (P_id id, _)), annot), bind), _), exp) + when is_constant bind -> + ok (); + subst id (E_aux (E_cast (typ, bind), annot)) exp + | _ -> E_aux (e_aux, annot) in let rw_exp = { @@ -151,7 +196,11 @@ let rewrite_constant_function_calls' ast = e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot) } in let rw_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp rw_exp) } in - rewrite_defs_base rw_defs ast + let ast = rewrite_defs_base rw_defs ast in + (* We keep iterating until we have no more re-writes to do *) + if !rewrite_count > 0 + then rewrite_constant_function_calls' ast + else ast let rewrite_constant_function_calls ast = if !optimize_constant_fold then diff --git a/src/interpreter.ml b/src/interpreter.ml index 00846d73..99d5889a 100644 --- a/src/interpreter.ml +++ b/src/interpreter.ml @@ -232,7 +232,6 @@ let is_value_fexp (FE_aux (FE_Fexp (id, exp), _)) = is_value exp let value_of_fexp (FE_aux (FE_Fexp (id, exp), _)) = (string_of_id id, value_of_exp exp) let rec build_letchain id lbs (E_aux (_, annot) as exp) = - (* print_endline ("LETCHAIN " ^ string_of_exp exp); *) match lbs with | [] -> exp | lb :: lbs when IdSet.mem id (letbind_pat_ids lb)-> @@ -311,7 +310,6 @@ let rec step (E_aux (e_aux, annot) as orig_exp) = else failwith "Match failure" - | E_vector_subrange (vec, n, m) -> wrap (E_app (mk_id "vector_subrange_dec", [vec; n; m])) | E_vector_access (vec, n) -> diff --git a/src/process_file.ml b/src/process_file.ml index 9603e986..9ed52e8d 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -403,7 +403,7 @@ let rewrite_ast_ocaml = rewrite Rewrites.rewrite_defs_ocaml let rewrite_ast_c ast = ast |> rewrite Rewrites.rewrite_defs_c - |> Constant_fold.rewrite_constant_function_calls + |> rewrite [("constant_fold", Constant_fold.rewrite_constant_function_calls)] let rewrite_ast_interpreter = rewrite Rewrites.rewrite_defs_interpreter let rewrite_ast_check = rewrite Rewrites.rewrite_defs_check diff --git a/src/type_check.ml b/src/type_check.ml index 0a29b6d6..afdf41c9 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -132,28 +132,11 @@ let is_unknown_type = function | (Typ_aux (Typ_internal_unknown, _)) -> true | _ -> false -(* An index_sort is a more general form of range type: it can either - be IS_int, which represents every natural number, or some set of - natural numbers given by an IS_prop expression of the form - {'n. f('n) <= g('n) /\ ...} *) -type index_sort = - | IS_int - | IS_prop of kid * (nexp * nexp) list - -let string_of_index_sort = function - | IS_int -> "INT" - | IS_prop (kid, constraints) -> - "{" ^ string_of_kid kid ^ " | " - ^ string_of_list " & " (fun (x, y) -> string_of_nexp x ^ " <= " ^ string_of_nexp y) constraints - ^ "}" - let is_atom (Typ_aux (typ_aux, _)) = match typ_aux with | Typ_app (f, [_]) when string_of_id f = "atom" -> true | _ -> false - - let rec strip_id = function | Id_aux (Id x, _) -> Id_aux (Id x, Parse_ast.Unknown) | Id_aux (DeIid x, _) -> Id_aux (DeIid x, Parse_ast.Unknown) diff --git a/src/value2.lem b/src/value2.lem index 33416503..d9fd1263 100644 --- a/src/value2.lem +++ b/src/value2.lem @@ -70,17 +70,13 @@ type vl = | V_record of list (string * vl) | V_null (* Used for unitialized values and null pointers in C compilation *) -let primops extern args = - match (extern, args) with - | ("and_bool", [V_bool b1; V_bool b2]) -> V_bool (b1 && b2) - | ("and_bool", [V_nondet; V_bool false]) -> V_bool false - | ("and_bool", [V_bool false; V_nondet]) -> V_bool false - | ("and_bool", _) -> V_nondet - | ("or_bool", [V_bool b1; V_bool b2]) -> V_bool (b1 || b2) - | ("or_bool", [V_nondet; V_bool true]) -> V_bool true - | ("or_bool", [V_bool true; V_nondet]) -> V_bool true - | ("or_bool", _) -> V_nondet +let value_int_op_int op = function + | [V_int v1; V_int v2] -> V_int (op v1 v2) + | _ -> V_null +end - | _ -> failwith ("Unimplemented primitive operation " ^ extern) - end +let value_bool_op_int op = function + | [V_int v1; V_int v2] -> V_bool (op v1 v2) + | _ -> V_null +end |
