diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/graph/DiGraph.scala | 3 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala | 30 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/graph/DiGraphTests.scala | 11 |
3 files changed, 43 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala index 7e56919c..b869982d 100644 --- a/src/main/scala/firrtl/graph/DiGraph.scala +++ b/src/main/scala/firrtl/graph/DiGraph.scala @@ -302,7 +302,8 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link * @return a transformed DiGraph[Q] */ def transformNodes[Q](f: (T) => Q): DiGraph[Q] = { - val eprime = edges.map({ case (k, v) => (f(k), v.map(f(_))) }) + val eprime = edges.map({ case (k, _) => (f(k), new LinkedHashSet[Q]) }) + edges.foreach({ case (k, v) => eprime(f(k)) ++= v.map(f(_)) }) new DiGraph(eprime) } diff --git a/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala b/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala index 3e517079..f01de1f4 100644 --- a/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala +++ b/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala @@ -65,4 +65,34 @@ circuit Top : val graph = new InstanceGraph(circuit).graph.transformNodes(_.module) getEdgeSet(graph) shouldBe Map("Top" -> Set("Child1"), "Top2" -> Set("Child2", "Child3"), "Child2" -> Set("Child2a", "Child2b"), "Child1" -> Set(), "Child2a" -> Set(), "Child2b" -> Set(), "Child3" -> Set()) } + + it should "not drop duplicate nodes when they collide as a result of transformNodes" in { + val input = +"""circuit Top : + module Buzz : + skip + module Fizz : + inst b of Buzz + module Foo : + inst f1 of Fizz + module Bar : + inst f2 of Fizz + module Top : + inst f of Foo + inst b of Bar +""" + val circuit = ToWorkingIR.run(parse(input)) + val graph = (new InstanceGraph(circuit)).graph + + // Create graphs with edges from child to parent module + // g1 has collisions on parents to children, ie. it combines: + // (f1, Fizz) -> (b, Buzz) and (f2, Fizz) -> (b, Buzz) + val g1 = graph.transformNodes(_.module).reverse + g1.getEdges("Fizz") shouldBe Set("Foo", "Bar") + + val g2 = graph.reverse.transformNodes(_.module) + // g2 combines + // (f1, Fizz) -> (f, Foo) and (f2, Fizz) -> (b, Bar) + g2.getEdges("Fizz") shouldBe Set("Foo", "Bar") + } } diff --git a/src/test/scala/firrtlTests/graph/DiGraphTests.scala b/src/test/scala/firrtlTests/graph/DiGraphTests.scala index 84122f83..b9f51699 100644 --- a/src/test/scala/firrtlTests/graph/DiGraphTests.scala +++ b/src/test/scala/firrtlTests/graph/DiGraphTests.scala @@ -29,6 +29,13 @@ class DiGraphTests extends FirrtlFlatSpec { "c" -> Set("d"), "d" -> Set("a"))) + val tupleGraph = DiGraph(Map( + ("a", 0) -> Set(("b", 2)), + ("a", 1) -> Set(("c", 3)), + ("b", 2) -> Set.empty[(String, Int)], + ("c", 3) -> Set.empty[(String, Int)] + )) + val degenerateGraph = DiGraph(Map("a" -> Set.empty[String])) acyclicGraph.findSCCs.filter(_.length > 1) shouldBe empty @@ -47,4 +54,8 @@ class DiGraphTests extends FirrtlFlatSpec { degenerateGraph.getEdgeMap should equal (degenerateGraph.reverse.getEdgeMap) + "transformNodes" should "combine vertices that collide, not drop them" in { + tupleGraph.transformNodes(_._1).getEdgeMap should contain ("a" -> Set("b", "c")) + } + } |
