aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/DedupAnnotations.scala
blob: 9aad2fee113dcf99cef0be71c0a2e59a8307fc69 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
// SPDX-License-Identifier: Apache-2.0

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.options.Dependency
import firrtl.Utils.{groupByIntoSeq, BoolType}
import firrtl.annotations.Annotation
import scala.collection.mutable.Buffer
import firrtl.annotations.MemoryFileInlineAnnotation
import firrtl.passes.PassException
import firrtl.annotations.ReferenceTarget
import firrtl.annotations._
import firrtl.analyses.InstanceKeyGraph

import scala.collection.mutable.ArrayBuffer

object DedupAnnotationsTransform {

  /** Error describing two reference targets whose annotations disagree.
    *
    * The constructor is private; instances are created through the companion's `apply`.
    * Nothing inside this object throws it.
    */
  final class DifferingModuleAnnotationsException private (msg: String) extends PassException(msg)
  object DifferingModuleAnnotationsException {

    /** Builds the exception with a message naming both serialized targets.
      *
      * @param left  first of the two conflicting targets
      * @param right second of the two conflicting targets
      */
    def apply(left: ReferenceTarget, right: ReferenceTarget): DifferingModuleAnnotationsException = {
      val msg = s"${left.serialize} and ${right.serialize} have differing module binaries"
      new DifferingModuleAnnotationsException(msg)
    }
  }

  /** The pieces of a dedupable annotation, bundled together with the annotation they came from.
    *
    * @param dedupKey       grouping key; annotations sharing a key are candidates for merging
    * @param deduped        the annotation to emit if the whole group can be merged
    * @param original       the annotation as it was supplied
    * @param absoluteTarget target whose instance path locates the annotated instance
    */
  private case class DedupableRepr(
    dedupKey:       Any,
    deduped:        Annotation,
    original:       Annotation,
    absoluteTarget: ReferenceTarget)
  private object DedupableRepr {

    /** Wraps `annotation` when its `dedup` member yields a result, otherwise returns `None`. */
    def apply(annotation: Annotation): Option[DedupableRepr] =
      annotation.dedup.map {
        case (dedupKey, dedupedAnno, absoluteTarget) =>
          new DedupableRepr(dedupKey, dedupedAnno, annotation, absoluteTarget)
      }
  }

  private type InstancePath = Seq[(TargetToken.Instance, TargetToken.OfModule)]

  /** Returns true when the number of instances of `module` found in `graph` equals the number of
    * supplied instance paths.
    *
    * NOTE(review): only the counts are compared, so repeated entries in `absolutePaths` are
    * counted individually -- confirm callers never supply the same path twice.
    */
  private def checkInstanceGraph(
    module:        String,
    graph:         InstanceKeyGraph,
    absolutePaths: Seq[InstancePath]
  ): Boolean = graph.findInstancesInHierarchy(module).size == absolutePaths.size

  /** Merges groups of per-instance annotations into a single annotation where that is legal.
    *
    * Annotations whose `dedup` member yields nothing are passed through untouched and appear first
    * in the result, in their input order. The remaining annotations are grouped by dedup key (in
    * first-seen order) and each group is appended as either its single merged annotation or its
    * original annotations.
    *
    * @param annotations the annotations to process
    * @param graph       instance graph used to check that a group covers every instance of a module
    * @return the pass-through annotations followed by the result for each group
    */
  def dedupAnnotations(annotations: Seq[Annotation], graph: InstanceKeyGraph): Seq[Annotation] = {
    val canDedup = ArrayBuffer.empty[DedupableRepr]
    val outAnnos = ArrayBuffer.empty[Annotation]

    // Extract the annotations which can be deduplicated
    annotations.foreach { anno =>
      DedupableRepr(anno) match {
        case Some(repr) => canDedup += repr
        case None       => outAnnos += anno
      }
    }

    // Partition the dedupable annotations into groups that *should* deduplicate into the same annotation
    val shouldDedup: Seq[(Any, Seq[DedupableRepr])] = groupByIntoSeq(canDedup)(_.dedupKey)
    shouldDedup.foreach {
      case ((target: ReferenceTarget, _), dedupableAnnos) =>
        val originalAnnos = dedupableAnnos.map(_.original)
        val uniqueDedupedAnnos = dedupableAnnos.map(_.deduped).distinct
        // TODO: Extend this to support multi-target annotations
        val instancePaths = dedupableAnnos.map(_.absoluteTarget.path).toSeq
        // The annotation deduplication is only legal if it applies to *all* instances of a
        // deduplicated module -- requires an instance graph check
        if (uniqueDedupedAnnos.size == 1 && checkInstanceGraph(target.encapsulatingModule, graph, instancePaths))
          outAnnos += uniqueDedupedAnnos.head
        else
          outAnnos ++= originalAnnos
      case (_, dedupableAnnos) =>
        // The key is typed `Any`, so it may not be a pair led by a ReferenceTarget. Without a
        // module to check against the instance graph, keep the originals instead of failing
        // with a MatchError.
        outAnnos ++= dedupableAnnos.map(_.original)
    }

    outAnnos.toSeq
  }
}

/** Transform that deduplicates memory annotations.
  *
  * It declares no prerequisites, is not an optional prerequisite of anything, and invalidates no
  * other transform. The circuit, form and renames of the incoming state are carried over as they
  * are; only the annotations are replaced.
  */
class DedupAnnotationsTransform extends Transform with DependencyAPIMigration {

  override def prerequisites = Nil

  override def optionalPrerequisites = Nil

  override def optionalPrerequisiteOf = Nil

  override def invalidates(a: Transform) = false

  /** Replaces the state's annotations with the result of
    * `DedupAnnotationsTransform.dedupAnnotations`, using an instance graph built from the
    * state's circuit.
    */
  def execute(state: CircuitState): CircuitState = {
    val instanceGraph = InstanceKeyGraph(state.circuit)
    val dedupedAnnos = DedupAnnotationsTransform.dedupAnnotations(state.annotations.toSeq, instanceGraph)
    CircuitState(state.circuit, state.form, dedupedAnnos, state.renames)
  }
}