aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/Flatten.scala
blob: 0aa155bdb124f602f3ef3669b9abfc59a35cbf89 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// SPDX-License-Identifier: Apache-2.0

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
import scala.collection.mutable
import firrtl.passes.{InlineInstances, PassException}
import firrtl.stage.Forms

/** Marks a circuit, module, or instance whose hierarchy the [[Flatten]] transform should inline.
  *
  * @param target the circuit, module, or instance to flatten
  */
case class FlattenAnnotation(target: Named) extends SingleTargetAnnotation[Named] {

  /** Re-targets this annotation (e.g. after renaming), keeping it a [[FlattenAnnotation]]. */
  def duplicate(n: Named): FlattenAnnotation = copy(target = n)
}

/**
  * Takes flatten annotations for module instances and modules and inlines the entire hierarchy of
  * modules below the annotated targets. This transformation instantiates and is based on the
  * [[firrtl.passes.InlineInstances]] transformation.
  *
  * @note Flattening a module means inlining all its fully-defined child instances
  * @note Instances of extmodules are not (and cannot be) inlined
  */
class Flatten extends Transform with DependencyAPIMigration {

  override def prerequisites = Forms.LowForm
  override def optionalPrerequisites = Seq.empty
  override def optionalPrerequisiteOf = Forms.LowEmitters
  override def invalidates(a: Transform) = false

  /** The [[InlineInstances]] transform that performs the actual inlining of the duplicated modules. */
  val inlineTransform = new InlineInstances

  /**
    * Splits flatten annotations into the set of modules and the set of instances to flatten.
    *
    * A circuit-level annotation contributes every [[Module]] except the top-level one. Its modules are
    * unioned with (not substituted for) those collected from other annotations, so the result does not
    * depend on the order in which annotations appear.
    *
    * @param circuit the circuit the annotations refer to
    * @param anns    the annotations to classify
    * @return (modules to flatten, instances to flatten)
    * @throws PassException if any annotation is not a [[FlattenAnnotation]]
    */
  private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
    anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) {
      case ((modNames, instNames), ann) =>
        ann match {
          case FlattenAnnotation(CircuitName(c)) =>
            val nonTopMods = circuit.modules.collect {
              case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
            }
            // Union rather than replace: an earlier annotation on the top module must not be lost.
            (modNames ++ nonTopMods, instNames)
          case FlattenAnnotation(ModuleName(mod, cir))    => (modNames + ModuleName(mod, cir), instNames)
          case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
          case _                                          => throw new PassException("Annotation must be a FlattenAnnotation")
        }
    }

  /**
    * Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and
    * by rewriting the original circuit to refer to the new modules that will be inlined later.
    *
    * Instances of extmodules are never redirected or duplicated. They have no body to inline, and
    * handing them to [[InlineInstances]] as modules to inline would be rejected.
    *
    * @param c     the circuit to rewrite
    * @param mods  modules whose child instances should be flattened
    * @param insts individual instances that should be flattened
    * @return modified circuit and ModuleNames to inline
    */
  def duplicateSubCircuitsFromAnno(
    c:     Circuit,
    mods:  Set[ModuleName],
    insts: Set[ComponentName]
  ): (Circuit, Set[ModuleName]) = {
    val modMap = c.modules.map(m => m.name -> m).toMap
    val seedMods = mutable.Map.empty[String, String]
    // Insertion-ordered so the replicas are appended in discovery order. Every replica gets a fresh
    // name from nsp, so the de-duplication of a set is not needed.
    val newModDefs = mutable.ArrayBuffer.empty[DefModule]
    val nsp = Namespace(c)

    /** True if instance `instName` inside `parent` is selected by an instance or a module annotation. */
    def isFlattenTarget(parent: DefModule, instName: String): Boolean = {
      val parentName = ModuleName(parent.name, CircuitName(c.main))
      insts.contains(ComponentName(instName, parentName)) || mods.contains(parentName)
    }

    /**
      * First, rewrite the DefInstances selected by the annotations so that they refer to replicated
      * modules, which are created later. This populates seedMods, which maps the original module name
      * of each selected instance to the name of the replica that will be created for it.
      * Note: we redirect instances to replicas (rather than modifying the originals) so that other,
      * unannotated instances of the same module are left unchanged.
      */
    def rewriteMod(parent: DefModule)(x: Statement): Statement = x match {
      case _: Block => x.map(rewriteMod(parent))
      case WDefInstance(info, instName, moduleName, instTpe) if isFlattenTarget(parent, instName) =>
        modMap(moduleName) match {
          case _: Module =>
            val newModName = seedMods.getOrElseUpdate(moduleName, nsp.newName(moduleName + "_TO_FLATTEN"))
            WDefInstance(info, instName, newModName, instTpe)
          case _ => x // Ignore extmodules, exactly as dupMod does below
        }
      case _ => x
    }

    val modifMods = c.modules.map { m => m.map(rewriteMod(m)) }

    /**
      * Recursively replicates the hierarchy, starting with the modules in seedMods (originally from the
      * annotations). Each level's replicas are appended to newModDefs. The instances inside them are
      * redirected to further replicas, which are created by the next recursive call.
      *
      * @param toDup map from original module name to the name of the replica to create for it
      */
    @scala.annotation.tailrec
    def recDupMods(toDup: Map[String, String]): Unit = {
      val replMods = mutable.Map.empty[String, String]

      def dupMod(x: Statement): Statement = x match {
        case _: Block => x.map(dupMod)
        case WDefInstance(info, instName, moduleName, instTpe) =>
          modMap(moduleName) match {
            case _: Module =>
              val newModName = replMods.getOrElseUpdate(moduleName, nsp.newName(moduleName + "_TO_FLATTEN"))
              WDefInstance(info, instName, newModName, instTpe)
            case _ => x // Ignore extmodules
          }
        case _ => x
      }

      // Renames an original module to its replica name.
      def dupName(name: String): String = toDup(name)

      newModDefs ++= toDup.map { case (origName, _) => modMap(origName).map(dupMod).map(dupName) }

      if (replMods.nonEmpty) recDupMods(replMods.toMap)
    }
    recDupMods(seedMods.toMap)

    // Convert the newly created modules to ModuleNames for inlining next (outside this function).
    val modsToInline = newModDefs.map(m => ModuleName(m.name, CircuitName(c.main))).toSet
    (c.copy(modules = modifMods ++ newModDefs), modsToInline)
  }

  /**
    * Flattens every target named by a [[FlattenAnnotation]] and removes those annotations from the
    * result. If there are no such annotations, the state is returned unchanged.
    */
  override def execute(state: CircuitState): CircuitState = {
    val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a }
    if (annos.isEmpty) state
    else {
      val (modNames, instNames) = collectAnns(state.circuit, annos)
      // Replicate the annotated hierarchies, then let InlineInstances inline every replica.
      val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames)
      val flattenedState = inlineTransform.run(newc, modsToInline, Set.empty[ComponentName], state.annotations)

      // The FlattenAnnotations have been consumed; do not pass them on to later transforms.
      val cleanedAnnos = flattenedState.annotations.filterNot {
        case FlattenAnnotation(_) => true
        case _                    => false
      }

      flattenedState.copy(annotations = cleanedAnnos)
    }
  }
}