aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorJack Koenig2021-02-16 14:03:55 -0800
committerGitHub2021-02-16 14:03:55 -0800
commit856226416cfa2d770c7205efad5331297c2e3a32 (patch)
tree49feb47fa0634bf527261fc108ad16141901bb58 /src/main
parent6e0e760526090c694ce6507db71122654ffc3000 (diff)
parent5903e6a3bfce415c57c19865675db131b733c159 (diff)
Merge pull request #2077 from chipsalliance/must-dedup
Add "Must Deduplicate" API
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/graph/DiGraph.scala47
-rw-r--r--src/main/scala/firrtl/transforms/MustDedup.scala245
2 files changed, 292 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala
index 3a08d05e..9f6ffeb2 100644
--- a/src/main/scala/firrtl/graph/DiGraph.scala
+++ b/src/main/scala/firrtl/graph/DiGraph.scala
@@ -4,6 +4,7 @@ package firrtl.graph
import scala.collection.{mutable, Map, Set}
import scala.collection.mutable.{LinkedHashMap, LinkedHashSet}
+import firrtl.options.DependencyManagerUtils.{CharSet, PrettyCharSet}
/** An exception that is raised when an assumed DAG has a cycle */
class CyclicException(val node: Any) extends Exception(s"No valid linearization for cyclic graph, found at $node")
@@ -31,6 +32,16 @@ object DiGraph {
}
new DiGraph(edgeDataCopy)
}
+
+ /** Create a DiGraph from edges */
+ def apply[T](edges: (T, T)*): DiGraph[T] = {
+ val edgeMap = new LinkedHashMap[T, LinkedHashSet[T]]
+ for ((from, to) <- edges) {
+ val set = edgeMap.getOrElseUpdate(from, new LinkedHashSet[T])
+ set += to
+ }
+ new DiGraph(edgeMap)
+ }
}
/** Represents common behavior of all directed graphs */
@@ -386,6 +397,42 @@ class DiGraph[T](private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) {
that.edges.foreach({ case (k, v) => eprime.getOrElseUpdate(k, new LinkedHashSet[T]) ++= v })
new DiGraph(eprime)
}
+
+ /** Serializes a `DiGraph[String]` as a pretty tree
+ *
+ * Multiple roots are supported, but cycles are not.
+ */
+ def prettyTree(charSet: CharSet = PrettyCharSet)(implicit ev: T =:= String): String = {
+ // Set up characters for building the tree
+ val (l, n, c) = (charSet.lastNode, charSet.notLastNode, charSet.continuation)
+ val ctab = " " * c.size + " "
+
+ // Recursively adds each node of the DiGraph to accumulating List[String]
+ // Uses List because prepend is cheap and this prevents quadratic behavior of String
+ // concatenations or even flatMapping on Seqs
+ def rec(tab: String, node: T, mark: String, prev: List[String]): List[String] = {
+ val here = s"$mark$node"
+ val children = this.getEdges(node)
+ val last = children.size - 1
+ children.toList // Convert LinkedHashSet to List to avoid determinism issues
+ .zipWithIndex // Find last
+ .foldLeft(here :: prev) {
+ case (acc, (nodex, idx)) =>
+ val nextTab = if (idx == last) tab + ctab else tab + c + " "
+ val nextMark = if (idx == last) tab + l else tab + n
+ rec(nextTab, nodex, nextMark + " ", acc)
+ }
+ }
+ this.findSources
+ .toList // Convert LinkedHashSet to List to avoid determinism issues
+ .sortBy(_.toString) // Make order deterministic
+ .foldLeft(Nil: List[String]) {
+ case (acc, root) => rec("", root, "", acc)
+ }
+ .reverse
+ .mkString("\n")
+ }
+
}
class MutableDiGraph[T] extends DiGraph[T](new LinkedHashMap[T, LinkedHashSet[T]]) {
diff --git a/src/main/scala/firrtl/transforms/MustDedup.scala b/src/main/scala/firrtl/transforms/MustDedup.scala
new file mode 100644
index 00000000..3e7629cd
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/MustDedup.scala
@@ -0,0 +1,245 @@
+// See LICENSE for license details.
+
+package firrtl.transforms
+
+import firrtl._
+import firrtl.annotations._
+import firrtl.annotations.TargetToken.OfModule
+import firrtl.analyses.InstanceKeyGraph
+import firrtl.analyses.InstanceKeyGraph.InstanceKey
+import firrtl.options.Dependency
+import firrtl.stage.Forms
+import firrtl.graph.DiGraph
+
+import java.io.{File, FileWriter}
+
+/** Marks modules as "must deduplicate" */
+case class MustDeduplicateAnnotation(modules: Seq[IsModule]) extends MultiTargetAnnotation {
+ def targets: Seq[Seq[IsModule]] = modules.map(Seq(_))
+
+ def duplicate(n: Seq[Seq[Target]]): MustDeduplicateAnnotation = {
+ val newModules = n.map {
+ case Seq(mod: IsModule) => mod
+ case _ =>
+ val msg = "Something went wrong! This anno should only rename to single IsModules! " +
+ s"Got: $modules -> $n"
+ throw new Exception(msg)
+ }
+ MustDeduplicateAnnotation(newModules)
+ }
+}
+
+/** Specifies the directory where errors for modules that "must deduplicate" will be reported */
+case class MustDeduplicateReportDirectory(directory: String) extends NoTargetAnnotation
+
+object MustDeduplicateTransform {
+ sealed trait DedupFailureCandidate {
+ def message: String
+ def modules: Seq[OfModule]
+ }
+ case class LikelyShouldMatch(a: OfModule, b: OfModule) extends DedupFailureCandidate {
+ def message: String = s"Modules '${a.value}' and '${b.value}' likely should dedup but do not."
+ def modules = Seq(a, b)
+ }
+ object DisjointChildren {
+ sealed trait Reason
+ case object Left extends Reason
+ case object Right extends Reason
+ case object Both extends Reason
+ }
+ import DisjointChildren._
+ case class DisjointChildren(a: OfModule, b: OfModule, reason: Reason) extends DedupFailureCandidate {
+ def message: String = {
+ def helper(x: OfModule, y: OfModule): String = s"'${x.value}' contains instances not found in '${y.value}'"
+ val why = reason match {
+ case Left => helper(a, b)
+ case Right => helper(b, a)
+ case Both => s"${helper(a, b)} and ${helper(b, a)}"
+ }
+ s"Modules '${a.value}' and '${b.value}' cannot be deduplicated because $why."
+ }
+ def modules = Seq(a, b)
+ }
+
+ final class DeduplicationFailureException(msg: String) extends FirrtlUserException(msg)
+
+ case class DedupFailure(
+ shouldDedup: Seq[OfModule],
+ relevantMods: Set[OfModule],
+ candidates: Seq[DedupFailureCandidate])
+
+ /** Reports deduplication failures two Modules
+ *
+ * @return (Set of Modules that only appear in one hierarchy or the other, candidate pairs of Module names)
+ */
+ def findDedupFailures(shouldDedup: Seq[OfModule], graph: InstanceKeyGraph): DedupFailure = {
+ val instLookup = graph.getChildInstances.toMap
+ def recurse(a: OfModule, b: OfModule): Seq[DedupFailureCandidate] = {
+ val as = instLookup(a.value)
+ val bs = instLookup(b.value)
+ if (as.length != bs.length) {
+ val aa = as.toSet
+ val bb = bs.toSet
+ val reason = (aa.diff(bb).nonEmpty, bb.diff(aa).nonEmpty) match {
+ case (true, true) => Both
+ case (true, false) => Left
+ case (false, true) => Right
+ case _ => Utils.error("Impossible!")
+ }
+ Seq(DisjointChildren(a, b, reason))
+ } else {
+ val fromChildren = as.zip(bs).flatMap {
+ case (ax, bx) => recurse(ax.OfModule, bx.OfModule)
+ }
+ if (fromChildren.nonEmpty) fromChildren
+ else if (a != b) Seq(LikelyShouldMatch(a, b))
+ else Nil
+ }
+ }
+
+ val allMismatches = {
+ // Recalculating this every time is a little wasteful, but we're on a failure path anyway
+ val digraph = graph.graph.transformNodes(_.OfModule)
+ val froms = shouldDedup.map(x => digraph.reachableFrom(x) + x)
+ val union = froms.reduce(_ union _)
+ val intersection = froms.reduce(_ intersect _)
+ union.diff(intersection)
+ }.toSet
+ val pairs = shouldDedup.tail.map(n => (shouldDedup.head, n))
+ val candidates = pairs.flatMap { case (a, b) => recurse(a, b) }
+ DedupFailure(shouldDedup, allMismatches, candidates)
+ }
+
+ // Find the minimal number of vertices in the graph to show paths from "mustDedup" to failure
+ // candidates and their context (eg. children for DisjoinChildren)
+ private def findNodesToKeep(failure: DedupFailure, graph: DiGraph[String]): collection.Set[String] = {
+ val shouldDedup = failure.shouldDedup.map(_.value).toSet
+ val nodeOfInterest: Set[String] =
+ shouldDedup ++ failure.candidates.flatMap {
+ case LikelyShouldMatch(OfModule(a), OfModule(b)) => Seq(a, b)
+ case DisjointChildren(OfModule(a), OfModule(b), _) =>
+ Seq(a, b) ++ graph.getEdges(a) ++ graph.getEdges(b)
+ }
+ // Depth-first search looking for relevant nodes
+ def dfs(node: String): collection.Set[String] = {
+ val deeper = graph.getEdges(node).flatMap(dfs)
+ if (deeper.nonEmpty || nodeOfInterest(node)) deeper + node else deeper
+ }
+ shouldDedup.flatMap(dfs)
+ }
+
+ /** Turn a [[DedupFailure]] into a pretty graph for visualization
+ *
+ * @param failure Failure to visualize
+ * @param graph DiGraph of module names (no instance information)
+ */
+ def makeDedupFailureDiGraph(failure: DedupFailure, graph: DiGraph[String]): DiGraph[String] = {
+ // Recalculating this every time is a little wasteful, but we're on a failure path anyway
+ // Lookup the parent Module name of any Module
+ val getParents: String => Seq[String] =
+ graph.reverse.getEdgeMap
+ .mapValues(_.toSeq)
+
+ val candidates = failure.candidates
+ val shouldDedup = failure.shouldDedup.map(_.value)
+ val shouldDedupSet = shouldDedup.toSet
+ val mygraph = {
+ // Create a graph of paths from "shouldDedup" nodes to the candidates
+ // rooted at the "shouldDedup" nodes
+ val nodesToKeep = findNodesToKeep(failure, graph)
+ graph.subgraph(nodesToKeep) +
+ // Add fake nodes to represent parents of the "shouldDedup" nodes
+ DiGraph(shouldDedup.map(n => getParents(n).mkString(", ") -> n): _*)
+ }
+ // Gather candidate modules and assign indices for reference
+ val candidateIdx: Map[String, Int] =
+ candidates.zipWithIndex.flatMap { case (c, idx) => c.modules.map(_.value -> idx) }.toMap
+ // Now mark the graph for modules of interest
+ val markedGraph = mygraph.transformNodes { n =>
+ val next = if (shouldDedupSet(n)) s"($n)" else n
+ candidateIdx
+ .get(n)
+ .map(i => s"$next [$i]")
+ .getOrElse(next)
+ }
+ markedGraph
+ }
+}
+
+/** Checks for modules that have been marked as "must deduplicate"
+ *
+ * In cases where marked modules did not deduplicate, this transform attempts to provide context on
+ * what went wrong for debugging.
+ */
+class MustDeduplicateTransform extends Transform with DependencyAPIMigration {
+ import MustDeduplicateTransform._
+
+ override def prerequisites = Seq(Dependency[DedupModules])
+
+ // Make this run as soon after Dedup as possible
+ override def optionalPrerequisiteOf = (Forms.MidForm.toSet -- Forms.HighForm).toSeq
+
+ override def invalidates(a: Transform) = false
+
+ def execute(state: CircuitState): CircuitState = {
+
+ lazy val igraph = InstanceKeyGraph(state.circuit)
+
+ val dedupFailures: Seq[DedupFailure] =
+ state.annotations.flatMap {
+ case MustDeduplicateAnnotation(mods) =>
+ val moduleNames = mods.map(_.leafModule).distinct
+ if (moduleNames.size <= 1) None
+ else {
+ val modNames = moduleNames.map(OfModule)
+ Some(findDedupFailures(modNames, igraph))
+ }
+ case _ => None
+ }
+ if (dedupFailures.nonEmpty) {
+ val modgraph = igraph.graph.transformNodes(_.module)
+ // Create and log reports
+ val reports = dedupFailures.map {
+ case fail @ DedupFailure(shouldDedup, _, candidates) =>
+ val graph = makeDedupFailureDiGraph(fail, modgraph).prettyTree()
+ val mods = shouldDedup.map("'" + _.value + "'").mkString(", ")
+ val msg =
+ s"""===== $mods are marked as "must deduplicate", but did not deduplicate. =====
+ |$graph
+ |Failure candidates:
+ |${candidates.zipWithIndex.map { case (c, i) => s" - [$i] " + c.message }.mkString("\n")}
+ |""".stripMargin
+ logger.error(msg)
+ msg
+ }
+
+ // Write reports and modules to disk
+ val dirName = state.annotations.collectFirst { case MustDeduplicateReportDirectory(dir) => dir }
+ .getOrElse("dedup_failures")
+ val dir = new File(dirName)
+ logger.error(s"Writing error report(s) to ${dir}...")
+ FileUtils.makeDirectory(dir.toString)
+ for ((report, idx) <- reports.zipWithIndex) {
+ val f = new File(dir, s"report_$idx.rpt")
+ logger.error(s"Writing $f...")
+ val fw = new FileWriter(f)
+ fw.write(report)
+ fw.close()
+ }
+
+ val modsDir = new File(dir, "modules")
+ FileUtils.makeDirectory(modsDir.toString)
+ logger.error(s"Writing relevant modules to $modsDir...")
+ val relevantModule = dedupFailures.flatMap(_.relevantMods.map(_.value)).toSet
+ for (mod <- state.circuit.modules if relevantModule(mod.name)) {
+ val fw = new FileWriter(new File(modsDir, s"${mod.name}.fir"))
+ fw.write(mod.serialize)
+ fw.close()
+ }
+
+ val msg = s"Modules marked 'must deduplicate' failed to deduplicate! See error reports in $dirName"
+ throw new DeduplicationFailureException(msg)
+ }
+ state
+ }
+}