diff options
Diffstat (limited to 'src/c_backend.ml')
| -rw-r--r-- | src/c_backend.ml | 35 |
1 files changed, 20 insertions, 15 deletions
diff --git a/src/c_backend.ml b/src/c_backend.ml index 841f9eb1..ec5fa580 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -1902,7 +1902,8 @@ let rec specialize_variants ctx = let sort_ctype_defs cdefs = (* Split the cdefs into type definitions and non type definitions *) let is_ctype_def = function CDEF_type _ -> true | _ -> false in - let ctype_defs = List.filter is_ctype_def cdefs in + let unwrap = function CDEF_type ctdef -> ctdef | _ -> assert false in + let ctype_defs = List.map unwrap (List.filter is_ctype_def cdefs) in let cdefs = List.filter (fun cdef -> not (is_ctype_def cdef)) cdefs in let ctdef_id = function @@ -1915,22 +1916,26 @@ let sort_ctype_defs cdefs = List.fold_left (fun ids (_, ctyp) -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctors in - let depends_on cdef1 cdef2 = - match cdef1, cdef2 with - | CDEF_type ctdef1, CDEF_type ctdef2 -> - let ctdef1_ids = ctdef_ids ctdef1 in - let ctdef2_ids = ctdef_ids ctdef2 in - if IdSet.mem (ctdef_id ctdef1) ctdef2_ids && IdSet.mem (ctdef_id ctdef2) ctdef1_ids then - c_error (Printf.sprintf "Post specialization circular dependency between %s and %s found" - (string_of_id (ctdef_id ctdef1)) - (string_of_id (ctdef_id ctdef2))) - else if IdSet.mem (ctdef_id ctdef1) ctdef2_ids then -1 - else 1 - - | _, _ -> assert false (* We only use this on ctype_defs *) + (* Create a reverse id graph of dependencies between types *) + let module IdGraph = Graph.Make(Id) in + + let graph = + List.fold_left (fun g ctdef -> + List.fold_left (fun g id -> IdGraph.add_edge id (ctdef_id ctdef) g) + (IdGraph.add_edges (ctdef_id ctdef) [] g) (* Make sure even types with no dependencies are in graph *) + (IdSet.elements (ctdef_ids ctdef))) + IdGraph.empty + ctype_defs in - List.sort depends_on ctype_defs @ cdefs + (* Then select the ctypes in the correct order as given by the topsort *) + let ids = IdGraph.topsort graph in + let ctype_defs = + List.map (fun id -> CDEF_type (List.find (fun ctdef -> Id.compare (ctdef_id ctdef) id = 0) ctype_defs)) ids + in + + ctype_defs @ cdefs + (* (* When this optimization fires we know we have bytecode of the form |
