aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/Flatten.scala
blob: 7a7c7338ab9ca35e02328bfd3630d05ad0726a29 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// See LICENSE for license details.

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
import scala.collection.mutable
import firrtl.options.PreservesAll
import firrtl.passes.{InlineInstances,PassException}
import firrtl.stage.Forms

/** Marks `target` (a circuit, module, or instance name) for consumption by the [[Flatten]] transform.
  *
  * @param target the named object this annotation is attached to
  */
case class FlattenAnnotation(target: Named) extends SingleTargetAnnotation[Named] {
  /** Builds an equivalent annotation attached to `n` instead of the current target. */
  def duplicate(n: Named): FlattenAnnotation = copy(target = n)
}

/**
  * Takes flatten annotations for module instances and modules and inlines the entire hierarchy of
  * modules down from the annotations. This transformation instantiates and is based on the
  * InlineInstances transformation.
  *
  * @note Flattening a module means inlining all its fully-defined child instances
  * @note Instances of extmodules are not (and cannot be) inlined
  */
class Flatten extends Transform with DependencyAPIMigration with PreservesAll[Transform] {

   override def prerequisites = Forms.LowForm
   override def optionalPrerequisites = Seq.empty
   override def optionalPrerequisiteOf = Forms.LowEmitters

   /** The inlining transform that performs the actual inlining once the hierarchy has been replicated. */
   val inlineTransform = new InlineInstances

   /**
    *  Splits the given annotations into module-level and instance-level flatten targets.
    *
    *  A circuit-level annotation selects every `Module` of the circuit except the main module
    *  (extmodules are not matched). The selection is added to the targets gathered so far, so the
    *  result does not depend on the order in which the annotations are visited.
    *
    *  @param circuit the circuit the annotations refer to
    *  @param anns annotations to inspect; each one must be a [[FlattenAnnotation]]
    *  @return the module names and the instance (component) names to flatten
    *  @throws PassException if an annotation is not a [[FlattenAnnotation]]
    */
   private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
     anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) {
       case ((modNames, instNames), ann) => ann match {
         case FlattenAnnotation(CircuitName(c)) =>
           // Union with the accumulator: replacing it would silently drop earlier module targets.
           val allButMain = circuit.modules.collect {
             case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
           }
           (modNames ++ allButMain, instNames)
         case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
         case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
         case _ => throw new PassException("Annotation must be a FlattenAnnotation")
       }
     }

   /**
    *  Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and
    *  by rewriting the original circuit to refer to the new modules that will be inlined later.
    *
    *  @param c the circuit to modify
    *  @param mods modules whose child instances must all be flattened
    *  @param insts individual instances to flatten
    *  @return modified circuit and ModuleNames to inline
    */
   def duplicateSubCircuitsFromAnno(c: Circuit, mods: Set[ModuleName], insts: Set[ComponentName]): (Circuit, Set[ModuleName]) = {
     val modMap = c.modules.map(m => m.name->m).toMap
     // original module name -> name of its replica, for modules instantiated directly by an annotated target
     val seedMods = mutable.Map.empty[String, String]
     val newModDefs = mutable.Set.empty[DefModule]
     val nsp = Namespace(c)

     /**
      *  We start by rewriting the DefInstances in the annotated modules to refer to replicated modules that are
      *  created later. This populates seedMods, which maps the original module name of each annotated instance
      *  to the name of the new module that will be created as a replica of the original one.
      *  Note: instances are redirected to replicas so that other instances of the same module are left unchanged.
      */
     def rewriteMod(parent: DefModule)(x: Statement): Statement = x match {
       case _: Block => x map rewriteMod(parent)
       case WDefInstance(info, instName, moduleName, instTpe) =>
         val parentName = ModuleName(parent.name, CircuitName(c.main))
         if (insts.contains(ComponentName(instName, parentName)) || mods.contains(parentName)) {
           // Reuse the replica name when this module was already seen, otherwise reserve a fresh one.
           val newModName = seedMods.getOrElseUpdate(moduleName, nsp.newName(moduleName + "_TO_FLATTEN"))
           WDefInstance(info, instName, newModName, instTpe)
         } else x
       case _ => x
     }

     val modifMods = c.modules map { m => m map rewriteMod(m) }

     /**
      *  Recursively replicates the modules of the hierarchy, starting from the modules in seedMods
      *  (which originate from the annotations). Each call replicates the modules in `pending` and
      *  collects the modules they instantiate, which are then replicated by the next call.
      *  Populates newModDefs with the replicated modules.
      *
      *  @param pending original module name -> replica name, for the modules to replicate in this round
      */
     def recDupMods(pending: Map[String, String]): Unit = {
       // original module name -> replica name, for the children discovered in this round
       val replMods = mutable.Map.empty[String, String]

       def dupMod(x: Statement): Statement = x match {
         case _: Block => x map dupMod
         case WDefInstance(info, instName, moduleName, instTpe) => modMap(moduleName) match {
           case _: Module =>
             val newModName = replMods.getOrElseUpdate(moduleName, nsp.newName(moduleName + "_TO_FLATTEN"))
             WDefInstance(info, instName, newModName, instTpe)
           case _ => x // Ignore extmodules
         }
         case _ => x
       }

       // Renames a replicated module from its original name to the reserved replica name.
       def dupName(name: String): String = pending(name)
       val newMods = pending map { case (origName, _) => modMap(origName) map dupMod map dupName }

       newModDefs ++= newMods

       if (replMods.nonEmpty) recDupMods(replMods.toMap)

     }
     recDupMods(seedMods.toMap)

     // convert newly created modules to ModuleName for inlining next (outside this function)
     val modsToInline = newModDefs map { m => ModuleName(m.name, CircuitName(c.main)) }
     (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet)
   }

   /**
    *  Flattens the hierarchy under every object carrying a [[FlattenAnnotation]].
    *
    *  @param state the circuit state to transform
    *  @return the unchanged state when there is no flatten annotation, otherwise the result of running
    *          the inline transform on the circuit with the replicated hierarchy
    */
   override def execute(state: CircuitState): CircuitState = {
     val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a }
     if (annos.isEmpty) state
     else {
       val (modNames, instNames) = collectAnns(state.circuit, annos)
       // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline
       val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames)
       inlineTransform.run(newc, modsToInline, Set.empty[ComponentName], state.annotations)
     }
   }
}