diff options
Diffstat (limited to 'src/constant_fold.ml')
| -rw-r--r-- | src/constant_fold.ml | 101 |
1 files changed, 75 insertions, 26 deletions
diff --git a/src/constant_fold.ml b/src/constant_fold.ml index 7a35226e..45d3efe0 100644 --- a/src/constant_fold.ml +++ b/src/constant_fold.ml @@ -59,10 +59,25 @@ module StringMap = Map.Make(String);; false = no folding, true = perform constant folding. *) let optimize_constant_fold = ref false -let exp_of_value = +let rec fexp_of_ctor (field, value) = + FE_aux (FE_Fexp (mk_id field, exp_of_value value), no_annot) + +and exp_of_value = let open Value in function | V_int n -> mk_lit_exp (L_num n) + | V_bit Sail_lib.B0 -> mk_lit_exp L_zero + | V_bit Sail_lib.B1 -> mk_lit_exp L_one + | V_bool true -> mk_lit_exp L_true + | V_bool false -> mk_lit_exp L_false + | V_string str -> mk_lit_exp (L_string str) + | V_record ctors -> + mk_exp (E_record (FES_aux (FES_Fexps (List.map fexp_of_ctor (StringMap.bindings ctors), false), no_annot))) + | V_vector vs -> + mk_exp (E_vector (List.map exp_of_value vs)) + | V_tuple vs -> + mk_exp (E_tuple (List.map exp_of_value vs)) + | V_unit -> mk_lit_exp L_unit | _ -> failwith "No expression for value" (* We want to avoid evaluating things like print statements at compile @@ -85,15 +100,23 @@ let safe_primops = "Elf_loader.elf_tohost" ] -let is_literal = function - | E_aux (E_lit _, _) -> true +let rec is_constant (E_aux (e_aux, _)) = + match e_aux with + | E_lit _ -> true + | E_vector exps -> List.for_all is_constant exps + | E_record (FES_aux (FES_Fexps (fexps, _), _)) -> List.for_all is_constant_fexp fexps + | E_cast (_, exp) -> is_constant exp + | E_tuple exps -> List.for_all is_constant exps | _ -> false +and is_constant_fexp (FE_aux (FE_Fexp (_, exp), _)) = is_constant exp + (* Wrapper around interpreter that repeatedly steps until done. *) let rec run ast frame = match frame with | Interpreter.Done (state, v) -> v - | Interpreter.Step _ -> + | Interpreter.Step (lazy_str, _, _, _) -> + prerr_endline (Lazy.force lazy_str); run ast (Interpreter.eval_frame ast frame) | Interpreter.Break frame -> run ast (Interpreter.eval_frame ast frame) @@ -115,35 +138,57 @@ let rec run ast frame = - Throws an exception that isn't caught. *) -let rewrite_constant_function_calls' ast = +let rec rewrite_constant_function_calls' ast = + let rewrite_count = ref 0 in + let ok () = incr rewrite_count in + let not_ok () = decr rewrite_count in + let lstate, gstate = Interpreter.initial_state ast safe_primops in let gstate = { gstate with Interpreter.allow_registers = false } in + let evaluate e_aux annot = + let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in + try + begin + let v = run ast (Interpreter.Step (lazy "", (lstate, gstate), initial_monad, [])) in + let exp = exp_of_value v in + try (ok (); Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot)) with + | Type_error (l, err) -> + (* A type error here would be unexpected, so don't ignore it! *) + Util.warn ("Type error when folding constants in " + ^ string_of_exp (E_aux (e_aux, annot)) + ^ "\n" ^ Type_error.string_of_type_error err); + not_ok (); + E_aux (e_aux, annot) + end + with + (* Otherwise if anything goes wrong when trying to constant + fold, just continue without optimising. *) + | _ -> E_aux (e_aux, annot) + in + let rw_funcall e_aux annot = match e_aux with - | E_app (id, args) when List.for_all is_literal args -> - begin - let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in - try - begin - let v = run ast (Interpreter.Step (lazy "", (lstate, gstate), initial_monad, [])) in - let exp = exp_of_value v in - try Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot) with - | Type_error (l, err) -> - (* A type error here would be unexpected, so don't ignore it! *) - Util.warn ("Type error when folding constants in " - ^ string_of_exp (E_aux (e_aux, annot)) - ^ "\n" ^ Type_error.string_of_type_error err); - E_aux (e_aux, annot) - end - with - (* Otherwise if anything goes wrong when trying to constant - fold, just continue without optimising. *) - | _ -> E_aux (e_aux, annot) - end + | E_app (id, args) when List.for_all is_constant args -> + evaluate e_aux annot + + | E_field (exp, id) when is_constant exp -> + evaluate e_aux annot + + | E_if (E_aux (E_lit (L_aux (L_true, _)), _), then_exp, _) -> ok (); then_exp + | E_if (E_aux (E_lit (L_aux (L_false, _)), _), _, else_exp) -> ok (); else_exp + + | E_let (LB_aux (LB_val (P_aux (P_id id, _), bind), _), exp) when is_constant bind -> + ok (); + subst id bind exp + | E_let (LB_aux (LB_val (P_aux (P_typ (typ, P_aux (P_id id, _)), annot), bind), _), exp) + when is_constant bind -> + ok (); + subst id (E_aux (E_cast (typ, bind), annot)) exp + | _ -> E_aux (e_aux, annot) in let rw_exp = { @@ -151,7 +196,11 @@ let rewrite_constant_function_calls' ast = e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot) } in let rw_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp rw_exp) } in - rewrite_defs_base rw_defs ast + let ast = rewrite_defs_base rw_defs ast in + (* We keep iterating until we have no more re-writes to do *) + if !rewrite_count > 0 + then rewrite_constant_function_calls' ast + else ast let rewrite_constant_function_calls ast = if !optimize_constant_fold then |
