aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala
blob: 7e2fff48f650065349a1cc61a987dcd2573813f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
// SPDX-License-Identifier: Apache-2.0

package firrtl.analyses

import firrtl.annotations._
import firrtl.annotations.TargetToken._
import firrtl.graph.{DiGraph, EulerTour, MutableDiGraph}
import firrtl.ir

import scala.collection.mutable

/** A class representing the instance hierarchy of a firrtl Circuit.
  *
  * This is a faster version of the old `InstanceGraph`: vertices are plain pairs of
  * instance name and module name ([[InstanceKeyGraph.InstanceKey]]) instead of `WDefInstance`
  * nodes, whose hash includes the instance type and causes performance issues.
  *
  * The constructor is private; use `InstanceKeyGraph(circuit)` to build a graph.
  */
class InstanceKeyGraph private (c: ir.Circuit) {
  import InstanceKeyGraph._

  /** Lookup from module name to its definition. */
  private val nameToModule: Map[String, ir.DefModule] = c.modules.map({ m => (m.name, m) }).toMap

  /** For every module (in circuit order), the instances declared directly inside it. */
  private val childInstances: Seq[(String, Seq[InstanceKey])] = c.modules.map { m =>
    m.name -> InstanceKeyGraph.collectInstances(m)
  }

  /** Names of all modules that are instantiated by at least one instance statement. */
  private val instantiated = childInstances.flatMap(_._2).map(_.module).toSet

  /** Names of modules that are never instantiated; each one is a root of the graph. */
  private val roots = c.modules.map(_.name).filterNot(instantiated)
  private val internalGraph = buildGraph(childInstances, roots)
  private val circuitTopInstance = topKey(c.main)
  // cache vertices to speed up repeat calls to findInstancesInHierarchy
  private lazy val vertices = internalGraph.getVertices

  /** A list of absolute paths (each represented by a Seq of instances) of all module instances in the Circuit. */
  private lazy val cachedFullHierarchy: mutable.LinkedHashMap[InstanceKey, Seq[Seq[InstanceKey]]] =
    internalGraph.pathsInDAG(circuitTopInstance)

  /** Modules reachable from the main module as well as the main module itself.
    * Each module appears once, even if it is instantiated several times below the top.
    */
  private lazy val cachedReachableModules: Seq[OfModule] = {
    val belowTop = internalGraph.reachableFrom(circuitTopInstance).toSeq.map(_.OfModule)
    // distinct instances can share a module, so de-duplicate while keeping first-seen order
    (circuitTopInstance.OfModule +: belowTop).distinct
  }

  /** Modules of the circuit that are not in `cachedReachableModules`, in circuit order. */
  private lazy val cachedUnreachableModules: Seq[OfModule] = {
    val all = mutable.LinkedHashSet(childInstances.map { case (name, _) => OfModule(name) }*)
    val reachable = mutable.LinkedHashSet(cachedReachableModules*)
    all.diff(reachable).toSeq
  }

  /** returns the underlying graph */
  def graph: DiGraph[InstanceKey] = internalGraph

  /** returns the key to the top (main) module */
  def top: InstanceKey = circuitTopInstance

  /** maps module names to the DefModule node */
  def moduleMap: Map[String, ir.DefModule] = nameToModule

  /** Module order from highest module to leaf module */
  def moduleOrder: Seq[ir.DefModule] = internalGraph.transformNodes(_.module).linearize.map(nameToModule(_))

  /** Returns a sequence that can be turned into a map from module name to instances defined in said module. */
  def getChildInstances: Seq[(String, Seq[InstanceKey])] = childInstances

  /** A count of the *static* number of instances of each module. For any module other than the top (main) module,
    * this is equivalent to the number of inst statements in the circuit instantiating each module,
    * irrespective of the number of times (if any) the enclosing module appears in the hierarchy.
    * Note that top module of the circuit has an associated count of one, even though it is never directly instantiated.
    * Any modules *not* instantiated at all will have a count of zero.
    */
  def staticInstanceCount: Map[OfModule, Int] = cachedStaticInstanceCount

  private lazy val cachedStaticInstanceCount: Map[OfModule, Int] = {
    val counts = mutable.LinkedHashMap.empty[OfModule, Int]
    // seed every declared module: the main module starts at one, all others at zero
    childInstances.foreach {
      case (main, _) if main == c.main => counts += main.OfModule -> 1
      case (other, _)                  => counts += other.OfModule -> 0
    }
    // add one for every instance statement targeting the module
    childInstances.flatMap(_._2).map(_.OfModule).foreach { mod =>
      counts += mod -> (counts(mod) + 1)
    }
    counts.toMap
  }

  /** Finds the absolute paths (each represented by a Seq of instances
    * representing the chain of hierarchy) of all instances of a particular
    * module. Note that this includes one implicit instance of the top (main)
    * module of the circuit. If the module is not instantiated within the
    * hierarchy of the top module of the circuit, it will return Nil.
    *
    * @param module the name of the selected module
    * @return a Seq[ Seq[InstanceKey] ] of absolute instance paths
    */
  def findInstancesInHierarchy(module: String): Seq[Seq[InstanceKey]] = {
    val instances = vertices.filter(_.module == module).toSeq
    instances.flatMap { i => cachedFullHierarchy.getOrElse(i, Nil) }
  }

  /** Given a circuit, returns a map from module name to a map
    * in turn mapping instances names to corresponding module names
    */
  def getChildInstanceMap: mutable.LinkedHashMap[OfModule, mutable.LinkedHashMap[Instance, OfModule]] =
    mutable.LinkedHashMap(childInstances.map {
      case (k, v) =>
        val instanceToModule: mutable.LinkedHashMap[Instance, OfModule] = mutable.LinkedHashMap(v.map(_.toTokens)*)
        TargetToken.OfModule(k) -> instanceToModule
    }*)

  /** All modules in the circuit reachable from the top module (including the top module itself),
    * each listed exactly once.
    */
  def reachableModules: Seq[OfModule] = cachedReachableModules

  /** All modules *not* reachable from the top module of the circuit */
  def unreachableModules: Seq[OfModule] = cachedUnreachableModules

  /** A list of absolute paths (each represented by a Seq of instances) of all module instances in the Circuit. */
  def fullHierarchy: mutable.LinkedHashMap[InstanceKey, Seq[Seq[InstanceKey]]] = cachedFullHierarchy
}

object InstanceKeyGraph {

  /** Creates the instance graph for circuit `c`. */
  def apply(c: ir.Circuit): InstanceKeyGraph = new InstanceKeyGraph(c)

  /** Graph vertex: an untyped (instance name, module name) pair.
    * Only this untyped version is used as a key because hashing bundle types is expensive.
    *
    * @param name the name of the instance
    * @param module the name of the module that is instantiated
    */
  case class InstanceKey(name: String, module: String) {

    /** The instance name as a target token. */
    def Instance: Instance = TargetToken.Instance(name)

    /** The instantiated module name as a target token. */
    def OfModule: OfModule = TargetToken.OfModule(module)

    /** Both tokens as an (instance, module) pair. */
    def toTokens: (Instance, OfModule) = (this.Instance, this.OfModule)
  }

  /** Collects every instance definition found in the body of a firrtl module.
    *
    * @param m the module to scan; an `ExtModule` yields no instances
    * @return the instances in the order the statement traversal encounters them
    */
  def collectInstances(m: ir.DefModule): Seq[InstanceKey] = m match {
    case _: ir.ExtModule => Seq()
    case ir.Module(_, _, _, body) =>
      val found = mutable.ArrayBuffer[InstanceKey]()
      // record instance statements, recurse into every other kind of statement
      def visit(stmt: ir.Statement): Unit = stmt match {
        case firrtl.WDefInstance(_, instName, modName, _) =>
          found += InstanceKey(instName, modName)
        case ir.DefInstance(_, instName, modName, _) =>
          found += InstanceKey(instName, modName)
        case _: firrtl.WDefInstanceConnector =>
          firrtl.Utils.throwInternalError("Expecting WDefInstance, found a WDefInstanceConnector!")
        case nested =>
          nested.foreachStmt(visit)
      }
      visit(body)
      found.toSeq
  }

  /** The key used for a module that acts as a root: instance name and module name are identical. */
  private def topKey(module: String): InstanceKey = InstanceKey(module, module)

  /** Builds the instance graph by a breadth-first walk from every root module.
    *
    * @param childInstances for each module name, the instances declared directly inside it
    * @param roots names of the modules to start from (those that are not instantiated anywhere)
    * @return a graph with an edge from each visited instance to each instance declared in its module
    */
  private def buildGraph(
    childInstances: Seq[(String, Seq[InstanceKey])],
    roots:          Iterable[String]
  ): DiGraph[InstanceKey] = {
    val graph = new MutableDiGraph[InstanceKey]
    val instancesOf = childInstances.toMap

    // breadth-first expansion of the hierarchy below a single root module
    def expandFrom(rootModule: String): Unit = {
      val pending = mutable.Queue[InstanceKey](topKey(rootModule))
      while (pending.nonEmpty) {
        val parent = pending.dequeue()
        graph.addVertex(parent)
        instancesOf(parent.module).foreach { child =>
          // a key that is already in the graph has been queued before; only add the edge
          if (!graph.contains(child)) {
            pending.enqueue(child)
            graph.addVertex(child)
          }
          graph.addEdge(parent, child)
        }
      }
    }

    roots.foreach(expandFrom)
    graph
  }
}