summaryrefslogtreecommitdiff
path: root/src/jib/jib_optimize.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/jib/jib_optimize.ml')
-rw-r--r--src/jib/jib_optimize.ml143
1 files changed, 120 insertions, 23 deletions
diff --git a/src/jib/jib_optimize.ml b/src/jib/jib_optimize.ml
index 73b175a1..331cf65e 100644
--- a/src/jib/jib_optimize.ml
+++ b/src/jib/jib_optimize.ml
@@ -55,7 +55,7 @@ open Jib_util
let optimize_unit instrs =
let unit_cval cval =
match cval_ctyp cval with
- | CT_unit -> (F_lit V_unit, CT_unit)
+ | CT_unit -> (V_lit (VL_unit, CT_unit))
| _ -> cval
in
let unit_instr = function
@@ -81,16 +81,20 @@ let optimize_unit instrs =
filter_instrs non_pointless_copy (map_instr_list unit_instr instrs)
let flat_counter = ref 0
-let flat_id () =
- let id = mk_id ("local#" ^ string_of_int !flat_counter) in
+let flat_id orig_id =
+ let id = mk_id (string_of_name ~zencode:false orig_id ^ "_l#" ^ string_of_int !flat_counter) in
incr flat_counter;
name id
let rec flatten_instrs = function
| I_aux (I_decl (ctyp, decl_id), aux) :: instrs ->
- let fid = flat_id () in
+ let fid = flat_id decl_id in
I_aux (I_decl (ctyp, fid), aux) :: flatten_instrs (instrs_rename decl_id fid instrs)
+ | I_aux (I_init (ctyp, decl_id, cval), aux) :: instrs ->
+ let fid = flat_id decl_id in
+ I_aux (I_init (ctyp, fid, cval), aux) :: flatten_instrs (instrs_rename decl_id fid instrs)
+
| I_aux ((I_block block | I_try_block block), _) :: instrs ->
flatten_instrs block @ flatten_instrs instrs
@@ -152,20 +156,38 @@ let unique_per_function_ids cdefs =
in
List.mapi unique_cdef cdefs
-let rec frag_subst id subst = function
- | F_id id' -> if Name.compare id id' = 0 then subst else F_id id'
- | F_ref reg_id -> F_ref reg_id
- | F_lit vl -> F_lit vl
- | F_op (frag1, op, frag2) -> F_op (frag_subst id subst frag1, op, frag_subst id subst frag2)
- | F_unary (op, frag) -> F_unary (op, frag_subst id subst frag)
- | F_call (op, frags) -> F_call (op, List.map (frag_subst id subst) frags)
- | F_field (frag, field) -> F_field (frag_subst id subst frag, field)
- | F_raw str -> F_raw str
- | F_ctor_kind (frag, ctor, unifiers, ctyp) -> F_ctor_kind (frag_subst id subst frag, ctor, unifiers, ctyp)
- | F_ctor_unwrap (ctor, unifiers, frag) -> F_ctor_unwrap (ctor, unifiers, frag_subst id subst frag)
- | F_poly frag -> F_poly (frag_subst id subst frag)
-
-let cval_subst id subst (frag, ctyp) = frag_subst id subst frag, ctyp
+let rec cval_subst id subst = function
+ | V_id (id', ctyp) -> if Name.compare id id' = 0 then subst else V_id (id', ctyp)
+ | V_ref (reg_id, ctyp) -> V_ref (reg_id, ctyp)
+ | V_lit (vl, ctyp) -> V_lit (vl, ctyp)
+ | V_op (cval1, op, cval2) -> V_op (cval_subst id subst cval1, op, cval_subst id subst cval2)
+ | V_unary (op, cval) -> V_unary (op, cval_subst id subst cval)
+ | V_call (op, cvals) -> V_call (op, List.map (cval_subst id subst) cvals)
+ | V_field (cval, field) -> V_field (cval_subst id subst cval, field)
+ | V_tuple_member (cval, len, n) -> V_tuple_member (cval_subst id subst cval, len, n)
+ | V_ctor_kind (cval, ctor, unifiers, ctyp) -> V_ctor_kind (cval_subst id subst cval, ctor, unifiers, ctyp)
+ | V_ctor_unwrap (ctor, cval, unifiers, ctyp) -> V_ctor_unwrap (ctor, cval_subst id subst cval, unifiers, ctyp)
+ | V_struct (fields, ctyp) -> V_struct (List.map (fun (field, cval) -> field, cval_subst id subst cval) fields, ctyp)
+ | V_poly (cval, ctyp) -> V_poly (cval_subst id subst cval, ctyp)
+ | V_hd cval -> V_hd (cval_subst id subst cval)
+ | V_tl cval -> V_tl (cval_subst id subst cval)
+
+let rec cval_map_id f = function
+ | V_id (id, ctyp) -> V_id (f id, ctyp)
+ | V_ref (id, ctyp) -> V_ref (f id, ctyp)
+ | V_lit (vl, ctyp) -> V_lit (vl, ctyp)
+ | V_call (call, cvals) -> V_call (call, List.map (cval_map_id f) cvals)
+ | V_op (cval1, op, cval2) -> V_op (cval_map_id f cval1, op, cval_map_id f cval2)
+ | V_unary (op, cval) -> V_unary (op, cval_map_id f cval)
+ | V_field (cval, field) -> V_field (cval_map_id f cval, field)
+ | V_tuple_member (cval, len, n) -> V_tuple_member (cval_map_id f cval, len, n)
+ | V_ctor_kind (cval, ctor, unifiers, ctyp) -> V_ctor_kind (cval_map_id f cval, ctor, unifiers, ctyp)
+ | V_ctor_unwrap (ctor, cval, unifiers, ctyp) -> V_ctor_unwrap (ctor, cval_map_id f cval, unifiers, ctyp)
+ | V_hd cval -> V_hd (cval_map_id f cval)
+ | V_tl cval -> V_tl (cval_map_id f cval)
+ | V_struct (fields, ctyp) ->
+ V_struct (List.map (fun (field, cval) -> field, cval_map_id f cval) fields, ctyp)
+ | V_poly (cval, ctyp) -> V_poly (cval_map_id f cval, ctyp)
let rec instrs_subst id subst =
function
@@ -202,7 +224,7 @@ let rec instrs_subst id subst =
| I_throw cval -> I_throw (cval_subst id subst cval)
| I_comment str -> I_comment str
| I_raw str -> I_raw str
- | I_return cval -> I_return cval
+ | I_return cval -> I_return (cval_subst id subst cval)
| I_reset (ctyp, id') -> I_reset (ctyp, id')
| I_reinit (ctyp, id', cval) -> I_reinit (ctyp, id', cval_subst id subst cval)
in
@@ -211,14 +233,13 @@ let rec instrs_subst id subst =
| [] -> []
let rec clexp_subst id subst = function
- | CL_id (id', ctyp) when Name.compare id id' = 0 ->
- assert (ctyp_equal ctyp (clexp_ctyp subst));
- subst
+ | CL_id (id', ctyp) when Name.compare id id' = 0 -> subst
| CL_id (id', ctyp) -> CL_id (id', ctyp)
| CL_field (clexp, field) -> CL_field (clexp_subst id subst clexp, field)
| CL_addr clexp -> CL_addr (clexp_subst id subst clexp)
| CL_tuple (clexp, n) -> CL_tuple (clexp_subst id subst clexp, n)
| CL_void -> CL_void
+ | CL_rmw _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Cannot substitute into read-modify-write construct"
let rec find_function fid = function
| CDEF_fundef (fid', heap_return, args, body) :: _ when Id.compare fid fid' = 0 ->
@@ -228,8 +249,15 @@ let rec find_function fid = function
| [] -> None
+let ssa_name i = function
+ | Name (id, _) -> Name (id, i)
+ | Have_exception _ -> Have_exception i
+ | Current_exception _ -> Current_exception i
+ | Return _ -> Return i
+
let inline cdefs should_inline instrs =
let inlines = ref (-1) in
+ let label_count = ref (-1) in
let replace_return subst = function
| I_aux (I_funcall (clexp, extern, fid, args), aux) ->
@@ -245,13 +273,51 @@ let inline cdefs should_inline instrs =
| instr -> instr
in
+ let fix_labels =
+ let fix_label l = "inline" ^ string_of_int !label_count ^ "_" ^ l in
+ function
+ | I_aux (I_goto label, aux) -> I_aux (I_goto (fix_label label), aux)
+ | I_aux (I_label label, aux) -> I_aux (I_label (fix_label label), aux)
+ | I_aux (I_jump (cval, label), aux) -> I_aux (I_jump (cval, fix_label label), aux)
+ | instr -> instr
+ in
+
+ let fix_substs =
+ let f = cval_map_id (ssa_name (-1)) in
+ function
+ | I_aux (I_init (ctyp, id, cval), aux) ->
+ I_aux (I_init (ctyp, id, f cval), aux)
+ | I_aux (I_jump (cval, label), aux) ->
+ I_aux (I_jump (f cval, label), aux)
+ | I_aux (I_funcall (clexp, extern, function_id, args), aux) ->
+ I_aux (I_funcall (clexp, extern, function_id, List.map f args), aux)
+ | I_aux (I_if (cval, then_instrs, else_instrs, ctyp), aux) ->
+ I_aux (I_if (f cval, then_instrs, else_instrs, ctyp), aux)
+ | I_aux (I_copy (clexp, cval), aux) ->
+ I_aux (I_copy (clexp, f cval), aux)
+ | I_aux (I_return cval, aux) ->
+ I_aux (I_return (f cval), aux)
+ | I_aux (I_throw cval, aux) ->
+ I_aux (I_throw (f cval), aux)
+ | instr -> instr
+ in
+
let rec inline_instr = function
| I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline function_id ->
begin match find_function function_id cdefs with
| Some (None, ids, body) ->
incr inlines;
+ incr label_count;
let inline_label = label "end_inline_" in
- let body = List.fold_right2 instrs_subst (List.map name ids) (List.map fst args) body in
+ (* For situations where we have e.g. x => x' and x' => y, we
+ use a dummy SSA number turning this into x => x'/-2 and
+ x' => y/-2, ensuring x's won't get turned into y's. This
+ is undone by fix_substs which removes the -2 SSA
+ numbers. *)
+ let args = List.map (cval_map_id (ssa_name (-2))) args in
+ let body = List.fold_right2 instrs_subst (List.map name ids) args body in
+ let body = List.map (map_instr fix_substs) body in
+ let body = List.map (map_instr fix_labels) body in
let body = List.map (map_instr (replace_end inline_label)) body in
let body = List.map (map_instr (replace_return clexp)) body in
I_aux (I_block (body @ [ilabel inline_label]), aux)
@@ -275,3 +341,34 @@ let inline cdefs should_inline instrs =
instrs
in
go instrs
+
+let rec remove_pointless_goto = function
+ | I_aux (I_goto label, _) :: I_aux (I_label label', aux) :: instrs when label = label' ->
+ I_aux (I_label label', aux) :: remove_pointless_goto instrs
+ | I_aux (I_goto label, aux) :: I_aux (I_goto _, _) :: instrs ->
+ I_aux (I_goto label, aux) :: remove_pointless_goto instrs
+ | instr :: instrs ->
+ instr :: remove_pointless_goto instrs
+ | [] -> []
+
+module StringSet = Set.Make(String)
+
+let rec get_used_labels set = function
+ | I_aux (I_goto label, _) :: instrs -> get_used_labels (StringSet.add label set) instrs
+ | I_aux (I_jump (_, label), _) :: instrs -> get_used_labels (StringSet.add label set) instrs
+ | _ :: instrs -> get_used_labels set instrs
+ | [] -> set
+
+let remove_unused_labels instrs =
+ let used = get_used_labels StringSet.empty instrs in
+ let rec go acc = function
+ | I_aux (I_label label, _) :: instrs when not (StringSet.mem label used) -> go acc instrs
+ | instr :: instrs -> go (instr :: acc) instrs
+ | [] -> List.rev acc
+ in
+ go [] instrs
+
+let rec remove_clear = function
+ | I_aux (I_clear _, _) :: instrs -> remove_clear instrs
+ | instr :: instrs -> instr :: remove_clear instrs
+ | [] -> []