diff options
Diffstat (limited to 'src/constant_fold.ml')
| -rw-r--r-- | src/constant_fold.ml | 116 |
1 files changed, 108 insertions, 8 deletions
diff --git a/src/constant_fold.ml b/src/constant_fold.ml index 7a7067ef..35417ac8 100644 --- a/src/constant_fold.ml +++ b/src/constant_fold.ml @@ -91,7 +91,7 @@ and exp_of_value = let safe_primops = List.fold_left (fun m k -> StringMap.remove k m) - Value.primops + !Value.primops [ "print_endline"; "prerr_endline"; "putchar"; @@ -191,7 +191,17 @@ let rec run frame = let initial_state ast env = Interpreter.initial_state ~registers:false ast env safe_primops -let rw_exp target ok not_ok istate = +type fixed = { + registers: tannot exp Bindings.t; + fields: tannot exp Bindings.t Bindings.t; + } + +let no_fixed = { + registers = Bindings.empty; + fields = Bindings.empty; + } + +let rw_exp fixed target ok not_ok istate = let evaluate e_aux annot = let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in try @@ -219,6 +229,34 @@ let rw_exp target ok not_ok istate = | E_app (id, args) when fold_to_unit id -> ok (); E_aux (E_lit (L_aux (L_unit, fst annot)), annot) + | E_id id -> + begin match Bindings.find_opt id fixed.registers with + | Some exp -> + ok (); exp + | None -> + E_aux (e_aux, annot) + end + + | E_field (E_aux (E_id id, _), field) -> + begin match Bindings.find_opt id fixed.fields with + | Some fields -> + begin match Bindings.find_opt field fields with + | Some exp -> + ok (); exp + | None -> + E_aux (e_aux, annot) + end + | None -> + E_aux (e_aux, annot) + end + + (* Short-circuit boolean operators with constants *) + | E_app (id, [(E_aux (E_lit (L_aux (L_false, _)), _) as false_exp); _]) when string_of_id id = "and_bool" -> + ok (); false_exp + + | E_app (id, [(E_aux (E_lit (L_aux (L_true, _)), _) as true_exp); _]) when string_of_id id = "or_bool" -> + ok (); true_exp + | E_app (id, args) when List.for_all is_constant args -> let env = env_of_annot annot in (* We want to fold all primitive operations, but avoid folding @@ -252,9 +290,9 @@ let rw_exp target ok not_ok istate = in fold_exp { id_exp_alg with e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot)} -let rewrite_exp_once target = rw_exp target (fun _ -> ()) (fun _ -> ()) +let rewrite_exp_once target = rw_exp no_fixed target (fun _ -> ()) (fun _ -> ()) -let rec rewrite_constant_function_calls' target ast = +let rec rewrite_constant_function_calls' fixed target ast = let rewrite_count = ref 0 in let ok () = incr rewrite_count in let not_ok () = decr rewrite_count in @@ -262,16 +300,78 @@ let rec rewrite_constant_function_calls' target ast = let rw_defs = { rewriters_base with - rewrite_exp = (fun _ -> rw_exp target ok not_ok istate) + rewrite_exp = (fun _ -> rw_exp fixed target ok not_ok istate) } in let ast = rewrite_defs_base rw_defs ast in (* We keep iterating until we have no more re-writes to do *) if !rewrite_count > 0 - then rewrite_constant_function_calls' target ast + then rewrite_constant_function_calls' fixed target ast else ast -let rewrite_constant_function_calls target ast = +let rewrite_constant_function_calls fixed target ast = if !optimize_constant_fold then - rewrite_constant_function_calls' target ast + rewrite_constant_function_calls' fixed target ast else ast + +type to_constant = + | Register of id * typ * tannot exp + | Register_field of id * id * typ * tannot exp + +let () = + let open Interactive in + let open Printf in + + let update_fixed fixed = function + | Register (id, _, exp) -> + { fixed with registers = Bindings.add id exp fixed.registers } + | Register_field (id, field, _, exp) -> + let prev_fields = match Bindings.find_opt id fixed.fields with Some f -> f | None -> Bindings.empty in + let updated_fields = Bindings.add field exp prev_fields in + { fixed with fields = Bindings.add id updated_fields fixed.fields } + in + + ArgString ("target", fun target -> ArgString ("assignments", fun assignments -> Action (fun () -> + let assignments = Str.split (Str.regexp " +") assignments in + let assignments = + List.map (fun assignment -> + match String.split_on_char '=' assignment with + | [reg; value] -> + begin match String.split_on_char '.' reg with + | [reg; field] -> + let reg = mk_id reg in + let field = mk_id field in + begin match Env.lookup_id reg !env with + | Register (_, _, Typ_aux (Typ_id rec_id, _)) -> + let (_, fields) = Env.get_record rec_id !env in + let typ = match List.find_opt (fun (typ, id) -> Id.compare id field = 0) fields with + | Some (typ, _) -> typ + | None -> failwith (sprintf "Register %s does not have a field %s" (string_of_id reg) (string_of_id field)) + in + let exp = Initial_check.exp_of_string value in + let exp = check_exp !env exp typ in + Register_field (reg, field, typ, exp) + | _ -> + failwith (sprintf "Register %s is not defined as a record in the current environment" (string_of_id reg)) + end + | _ -> + let reg = mk_id reg in + begin match Env.lookup_id reg !env with + | Register (_, _, typ) -> + let exp = Initial_check.exp_of_string value in + let exp = check_exp !env exp typ in + Register (reg, typ, exp) + | _ -> + failwith (sprintf "Register %s is not defined in the current environment" (string_of_id reg)) + end + end + | _ -> failwith (sprintf "Could not parse '%s' as an assignment <register>=<value>" assignment) + ) assignments in + let assignments = List.fold_left update_fixed no_fixed assignments in + + ast := rewrite_constant_function_calls' assignments target !ast))) + |> register_command + ~name:"fix_registers" + ~help:"Fix the value of specified registers, specified as a \ + list of <register>=<value>. Can also fix a specific \ + register field as <register>.<field>=<value>." |
