diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/transforms/MustDedup.scala | 245 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/transforms/MustDedupSpec.scala | 267 |
2 files changed, 512 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/MustDedup.scala b/src/main/scala/firrtl/transforms/MustDedup.scala new file mode 100644 index 00000000..3e7629cd --- /dev/null +++ b/src/main/scala/firrtl/transforms/MustDedup.scala @@ -0,0 +1,245 @@ +// See LICENSE for license details. + +package firrtl.transforms + +import firrtl._ +import firrtl.annotations._ +import firrtl.annotations.TargetToken.OfModule +import firrtl.analyses.InstanceKeyGraph +import firrtl.analyses.InstanceKeyGraph.InstanceKey +import firrtl.options.Dependency +import firrtl.stage.Forms +import firrtl.graph.DiGraph + +import java.io.{File, FileWriter} + +/** Marks modules as "must deduplicate" */ +case class MustDeduplicateAnnotation(modules: Seq[IsModule]) extends MultiTargetAnnotation { + def targets: Seq[Seq[IsModule]] = modules.map(Seq(_)) + + def duplicate(n: Seq[Seq[Target]]): MustDeduplicateAnnotation = { + val newModules = n.map { + case Seq(mod: IsModule) => mod + case _ => + val msg = "Something went wrong! This anno should only rename to single IsModules! " + + s"Got: $modules -> $n" + throw new Exception(msg) + } + MustDeduplicateAnnotation(newModules) + } +} + +/** Specifies the directory where errors for modules that "must deduplicate" will be reported */ +case class MustDeduplicateReportDirectory(directory: String) extends NoTargetAnnotation + +object MustDeduplicateTransform { + sealed trait DedupFailureCandidate { + def message: String + def modules: Seq[OfModule] + } + case class LikelyShouldMatch(a: OfModule, b: OfModule) extends DedupFailureCandidate { + def message: String = s"Modules '${a.value}' and '${b.value}' likely should dedup but do not." + def modules = Seq(a, b) + } + object DisjointChildren { + sealed trait Reason + case object Left extends Reason + case object Right extends Reason + case object Both extends Reason + } + import DisjointChildren._ + case class DisjointChildren(a: OfModule, b: OfModule, reason: Reason) extends DedupFailureCandidate { + def message: String = { + def helper(x: OfModule, y: OfModule): String = s"'${x.value}' contains instances not found in '${y.value}'" + val why = reason match { + case Left => helper(a, b) + case Right => helper(b, a) + case Both => s"${helper(a, b)} and ${helper(b, a)}" + } + s"Modules '${a.value}' and '${b.value}' cannot be deduplicated because $why." + } + def modules = Seq(a, b) + } + + final class DeduplicationFailureException(msg: String) extends FirrtlUserException(msg) + + case class DedupFailure( + shouldDedup: Seq[OfModule], + relevantMods: Set[OfModule], + candidates: Seq[DedupFailureCandidate]) + + /** Reports deduplication failures two Modules + * + * @return (Set of Modules that only appear in one hierarchy or the other, candidate pairs of Module names) + */ + def findDedupFailures(shouldDedup: Seq[OfModule], graph: InstanceKeyGraph): DedupFailure = { + val instLookup = graph.getChildInstances.toMap + def recurse(a: OfModule, b: OfModule): Seq[DedupFailureCandidate] = { + val as = instLookup(a.value) + val bs = instLookup(b.value) + if (as.length != bs.length) { + val aa = as.toSet + val bb = bs.toSet + val reason = (aa.diff(bb).nonEmpty, bb.diff(aa).nonEmpty) match { + case (true, true) => Both + case (true, false) => Left + case (false, true) => Right + case _ => Utils.error("Impossible!") + } + Seq(DisjointChildren(a, b, reason)) + } else { + val fromChildren = as.zip(bs).flatMap { + case (ax, bx) => recurse(ax.OfModule, bx.OfModule) + } + if (fromChildren.nonEmpty) fromChildren + else if (a != b) Seq(LikelyShouldMatch(a, b)) + else Nil + } + } + + val allMismatches = { + // Recalculating this every time is a little wasteful, but we're on a failure path anyway + val digraph = graph.graph.transformNodes(_.OfModule) + val froms = shouldDedup.map(x => digraph.reachableFrom(x) + x) + val union = froms.reduce(_ union _) + val intersection = froms.reduce(_ intersect _) + union.diff(intersection) + }.toSet + val pairs = shouldDedup.tail.map(n => (shouldDedup.head, n)) + val candidates = pairs.flatMap { case (a, b) => recurse(a, b) } + DedupFailure(shouldDedup, allMismatches, candidates) + } + + // Find the minimal number of vertices in the graph to show paths from "mustDedup" to failure + // candidates and their context (eg. children for DisjoinChildren) + private def findNodesToKeep(failure: DedupFailure, graph: DiGraph[String]): collection.Set[String] = { + val shouldDedup = failure.shouldDedup.map(_.value).toSet + val nodeOfInterest: Set[String] = + shouldDedup ++ failure.candidates.flatMap { + case LikelyShouldMatch(OfModule(a), OfModule(b)) => Seq(a, b) + case DisjointChildren(OfModule(a), OfModule(b), _) => + Seq(a, b) ++ graph.getEdges(a) ++ graph.getEdges(b) + } + // Depth-first search looking for relevant nodes + def dfs(node: String): collection.Set[String] = { + val deeper = graph.getEdges(node).flatMap(dfs) + if (deeper.nonEmpty || nodeOfInterest(node)) deeper + node else deeper + } + shouldDedup.flatMap(dfs) + } + + /** Turn a [[DedupFailure]] into a pretty graph for visualization + * + * @param failure Failure to visualize + * @param graph DiGraph of module names (no instance information) + */ + def makeDedupFailureDiGraph(failure: DedupFailure, graph: DiGraph[String]): DiGraph[String] = { + // Recalculating this every time is a little wasteful, but we're on a failure path anyway + // Lookup the parent Module name of any Module + val getParents: String => Seq[String] = + graph.reverse.getEdgeMap + .mapValues(_.toSeq) + + val candidates = failure.candidates + val shouldDedup = failure.shouldDedup.map(_.value) + val shouldDedupSet = shouldDedup.toSet + val mygraph = { + // Create a graph of paths from "shouldDedup" nodes to the candidates + // rooted at the "shouldDedup" nodes + val nodesToKeep = findNodesToKeep(failure, graph) + graph.subgraph(nodesToKeep) + + // Add fake nodes to represent parents of the "shouldDedup" nodes + DiGraph(shouldDedup.map(n => getParents(n).mkString(", ") -> n): _*) + } + // Gather candidate modules and assign indices for reference + val candidateIdx: Map[String, Int] = + candidates.zipWithIndex.flatMap { case (c, idx) => c.modules.map(_.value -> idx) }.toMap + // Now mark the graph for modules of interest + val markedGraph = mygraph.transformNodes { n => + val next = if (shouldDedupSet(n)) s"($n)" else n + candidateIdx + .get(n) + .map(i => s"$next [$i]") + .getOrElse(next) + } + markedGraph + } +} + +/** Checks for modules that have been marked as "must deduplicate" + * + * In cases where marked modules did not deduplicate, this transform attempts to provide context on + * what went wrong for debugging. + */ +class MustDeduplicateTransform extends Transform with DependencyAPIMigration { + import MustDeduplicateTransform._ + + override def prerequisites = Seq(Dependency[DedupModules]) + + // Make this run as soon after Dedup as possible + override def optionalPrerequisiteOf = (Forms.MidForm.toSet -- Forms.HighForm).toSeq + + override def invalidates(a: Transform) = false + + def execute(state: CircuitState): CircuitState = { + + lazy val igraph = InstanceKeyGraph(state.circuit) + + val dedupFailures: Seq[DedupFailure] = + state.annotations.flatMap { + case MustDeduplicateAnnotation(mods) => + val moduleNames = mods.map(_.leafModule).distinct + if (moduleNames.size <= 1) None + else { + val modNames = moduleNames.map(OfModule) + Some(findDedupFailures(modNames, igraph)) + } + case _ => None + } + if (dedupFailures.nonEmpty) { + val modgraph = igraph.graph.transformNodes(_.module) + // Create and log reports + val reports = dedupFailures.map { + case fail @ DedupFailure(shouldDedup, _, candidates) => + val graph = makeDedupFailureDiGraph(fail, modgraph).prettyTree() + val mods = shouldDedup.map("'" + _.value + "'").mkString(", ") + val msg = + s"""===== $mods are marked as "must deduplicate", but did not deduplicate. ===== + |$graph + |Failure candidates: + |${candidates.zipWithIndex.map { case (c, i) => s" - [$i] " + c.message }.mkString("\n")} + |""".stripMargin + logger.error(msg) + msg + } + + // Write reports and modules to disk + val dirName = state.annotations.collectFirst { case MustDeduplicateReportDirectory(dir) => dir } + .getOrElse("dedup_failures") + val dir = new File(dirName) + logger.error(s"Writing error report(s) to ${dir}...") + FileUtils.makeDirectory(dir.toString) + for ((report, idx) <- reports.zipWithIndex) { + val f = new File(dir, s"report_$idx.rpt") + logger.error(s"Writing $f...") + val fw = new FileWriter(f) + fw.write(report) + fw.close() + } + + val modsDir = new File(dir, "modules") + FileUtils.makeDirectory(modsDir.toString) + logger.error(s"Writing relevant modules to $modsDir...") + val relevantModule = dedupFailures.flatMap(_.relevantMods.map(_.value)).toSet + for (mod <- state.circuit.modules if relevantModule(mod.name)) { + val fw = new FileWriter(new File(modsDir, s"${mod.name}.fir")) + fw.write(mod.serialize) + fw.close() + } + + val msg = s"Modules marked 'must deduplicate' failed to deduplicate! See error reports in $dirName" + throw new DeduplicationFailureException(msg) + } + state + } +} diff --git a/src/test/scala/firrtlTests/transforms/MustDedupSpec.scala b/src/test/scala/firrtlTests/transforms/MustDedupSpec.scala new file mode 100644 index 00000000..2f633e0e --- /dev/null +++ b/src/test/scala/firrtlTests/transforms/MustDedupSpec.scala @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests.transforms + +import org.scalatest.featurespec.AnyFeatureSpec +import org.scalatest.GivenWhenThen +import firrtl.testutils.FirrtlMatchers +import java.io.File + +import firrtl.graph.DiGraph +import firrtl.analyses.InstanceKeyGraph +import firrtl.annotations.CircuitTarget +import firrtl.annotations.TargetToken.OfModule +import firrtl.transforms._ +import firrtl.transforms.MustDeduplicateTransform._ +import firrtl.transforms.MustDeduplicateTransform.DisjointChildren._ +import firrtl.util.BackendCompilationUtilities.createTestDirectory +import firrtl.stage.{FirrtlSourceAnnotation, RunFirrtlTransformAnnotation} +import firrtl.options.{TargetDirAnnotation} +import logger.{LogLevel, LogLevelAnnotation, Logger} + +class MustDedupSpec extends AnyFeatureSpec with FirrtlMatchers with GivenWhenThen { + + Feature("When you have a simple non-deduping hierarcy") { + val text = """ + |circuit A : + | module C : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= io.in + | module C_1 : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= and(io.in, UInt("hff")) + | module B : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst c of C + | io <= c.io + | module B_1 : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst c of C_1 + | io <= c.io + | module A : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst b of B + | inst b_1 of B_1 + | io.out <= and(b.io.out, b_1.io.out) + | b.io.in <= io.in + | b_1.io.in <= io.in + """.stripMargin + val top = CircuitTarget("A") + val bdedup = MustDeduplicateAnnotation(Seq(top.module("B"), top.module("B_1"))) + val igraph = InstanceKeyGraph(parse(text)) + + Scenario("Full compilation should fail and dump reports to disk") { + val testDir = createTestDirectory("must_dedup") + val reportDir = new File(testDir, "reports") + val annos = Seq( + TargetDirAnnotation(testDir.toString), + FirrtlSourceAnnotation(text), + RunFirrtlTransformAnnotation(new MustDeduplicateTransform), + MustDeduplicateReportDirectory(reportDir.toString), + bdedup + ) + + a[DeduplicationFailureException] shouldBe thrownBy { + (new firrtl.stage.FirrtlPhase).transform(annos) + } + + reportDir should exist + + val report0 = new File(reportDir, "report_0.rpt") + report0 should exist + + val expectedModules = Seq("B", "B_1", "C", "C_1") + for (mod <- expectedModules) { + new File(reportDir, s"modules/$mod.fir") should exist + } + } + + Scenario("Non-deduping children should give actionable debug information") { + When("Finding dedup failures") + val failure = findDedupFailures(Seq(OfModule("B"), OfModule("B_1")), igraph) + + Then("The children should appear as a failure candidate") + failure.candidates should be(Seq(LikelyShouldMatch(OfModule("C"), OfModule("C_1")))) + + And("There should be a pretty DiGraph showing context") + val got = makeDedupFailureDiGraph(failure, igraph.graph.transformNodes(_.module)) + val expected = DiGraph("A" -> "(B)", "A" -> "(B_1)", "(B)" -> "C [0]", "(B_1)" -> "C_1 [0]") + // DiGraph uses referential equality so compare serialized form + got.prettyTree() should be(expected.prettyTree()) + } + + Scenario("Unrelated hierarchies should give actionable debug information") { + When("Finding dedup failures") + val failure = findDedupFailures(Seq(OfModule("B"), OfModule("C_1")), igraph) + + Then("The failure should note the hierarchies don't match") + failure.candidates should be(Seq(DisjointChildren(OfModule("B"), OfModule("C_1"), Left))) + + And("There should be a pretty DiGraph showing context") + val got = makeDedupFailureDiGraph(failure, igraph.graph.transformNodes(_.module)) + val expected = DiGraph("A" -> "(B) [0]", "(B) [0]" -> "C", "B_1" -> "(C_1) [0]") + // DiGraph uses referential equality so compare serialized form + got.prettyTree() should be(expected.prettyTree()) + } + } + + Feature("When you have a deep, non-deduping hierarchy") { + // Shadow hierarchy just to get an InstanceKeyGraph which can only be made from a circuit + val text = parse(""" + |circuit A : + | module E: + | skip + | module F : + | skip + | module F_1 : + | inst e of E + | module D : + | skip + | module D_1 : + | skip + | module C : + | inst d of D + | inst f of F + | module C_1 : + | inst d of D_1 + | inst f of F_1 + | module B : + | inst c of C + | inst e of E + | module B_1 : + | inst c of C_1 + | inst e of E + | module A : + | inst b of B + | inst b_1 of B_1 + |""".stripMargin) + val igraph = InstanceKeyGraph(text) + + Scenario("Non-deduping children should give actionable debug information") { + When("Finding dedup failures") + val failure = findDedupFailures(Seq(OfModule("B"), OfModule("B_1")), igraph) + + Then("The children should appear as a failure candidate") + failure.candidates should be( + Seq(LikelyShouldMatch(OfModule("D"), OfModule("D_1")), DisjointChildren(OfModule("F"), OfModule("F_1"), Right)) + ) + + And("There should be a pretty DiGraph showing context") + val got = makeDedupFailureDiGraph(failure, igraph.graph.transformNodes(_.module)) + val expected = DiGraph( + "A" -> "(B)", + "A" -> "(B_1)", + "(B)" -> "C", + "C" -> "D [0]", + "C" -> "F [1]", + "(B_1)" -> "C_1", + "C_1" -> "D_1 [0]", + "C_1" -> "F_1 [1]", + "F_1 [1]" -> "E", + // These last 2 are undesirable but E is included because it's a submodule of disjoint F and F_1 + "(B)" -> "E", + "(B_1)" -> "E" + ) + // DiGraph uses referential equality so compare serialized form + got.prettyTree() should be(expected.prettyTree()) + } + } + + Feature("When you have multiple modules that should dedup, but don't") { + // Shadow hierarchy just to get an InstanceKeyGraph which can only be made from a circuit + val text = parse(""" + |circuit A : + | module D : + | skip + | module D_1 : + | skip + | module C : + | skip + | module C_1 : + | skip + | module B : + | inst c of C + | inst d of D + | module B_1 : + | inst c of C_1 + | inst d of D + | module B_2 : + | inst c of C + | inst d of D_1 + | module A : + | inst b of B + | inst b_1 of B_1 + | inst b_2 of B_2 + |""".stripMargin) + val igraph = InstanceKeyGraph(text) + + Scenario("Non-deduping children should give actionable debug information") { + When("Finding dedup failures") + val failure = findDedupFailures(Seq(OfModule("B"), OfModule("B_1"), OfModule("B_2")), igraph) + + Then("The children should appear as a failure candidate") + failure.candidates should be( + Seq(LikelyShouldMatch(OfModule("C"), OfModule("C_1")), LikelyShouldMatch(OfModule("D"), OfModule("D_1"))) + ) + + And("There should be a pretty DiGraph showing context") + val got = makeDedupFailureDiGraph(failure, igraph.graph.transformNodes(_.module)) + val expected = DiGraph( + "A" -> "(B)", + "A" -> "(B_1)", + "A" -> "(B_2)", + "(B)" -> "C [0]", + "(B)" -> "D [1]", + "(B_1)" -> "C_1 [0]", + "(B_1)" -> "D [1]", + "(B_2)" -> "C [0]", + "(B_2)" -> "D_1 [1]" + ) + // DiGraph uses referential equality so compare serialized form + got.prettyTree() should be(expected.prettyTree()) + } + } + + Feature("When you have modules that should dedup, and they do") { + val text = """ + |circuit A : + | module C : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= io.in + | module C_1 : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= io.in + | module B : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst c of C + | io <= c.io + | module B_1 : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst c of C_1 + | io <= c.io + | module A : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst b of B + | inst b_1 of B_1 + | io.out <= and(b.io.out, b_1.io.out) + | b.io.in <= io.in + | b_1.io.in <= io.in + """.stripMargin + val top = CircuitTarget("A") + val bdedup = MustDeduplicateAnnotation(Seq(top.module("B"), top.module("B_1"))) + + Scenario("Full compilation should succeed") { + val testDir = createTestDirectory("must_dedup") + val reportDir = new File(testDir, "reports") + val annos = Seq( + TargetDirAnnotation(testDir.toString), + FirrtlSourceAnnotation(text), + RunFirrtlTransformAnnotation(new MustDeduplicateTransform), + MustDeduplicateReportDirectory(reportDir.toString), + bdedup + ) + + (new firrtl.stage.FirrtlPhase).transform(annos) + } + } +} |
