aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/Flatten.scala
blob: a94f5d5abbb78f1bfc5ab34ac5912e9190a80646 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// See LICENSE for license details.

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
import scala.collection.mutable
import firrtl.passes.{InlineInstances,PassException}

/** Tags an annotation to be consumed by this transform */
case class FlattenAnnotation(target: Named) extends SingleTargetAnnotation[Named] {
  def duplicate(n: Named) = FlattenAnnotation(n)
}

/**
 * Takes flatten annotations for module instances and modules and inline the entire hierarchy of modules down from the annotations.
 * This transformation instantiates and is based on the InlineInstances transformation.  
 * Note: Inlining a module means inlining all its children module instances      
 */
class Flatten extends Transform {
   def inputForm = LowForm
   def outputForm = LowForm

   val inlineTransform = new InlineInstances
   
   private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
     anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) {
       case ((modNames, instNames), ann) => ann match {
         case FlattenAnnotation(CircuitName(c)) =>
           (circuit.modules.collect {
             case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
           }.toSet, instNames)
         case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
         case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
         case _ => throw new PassException("Annotation must be a FlattenAnnotation")
       }
     }

   /**
    *  Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and
    *  by rewriting the original circuit to refer to the new modules that will be inlined later. 
    *  @return modified circuit and ModuleNames to inline
    */
   def duplicateSubCircuitsFromAnno(c: Circuit, mods: Set[ModuleName], insts: Set[ComponentName]): (Circuit, Set[ModuleName]) = {
     val modMap = c.modules.map(m => m.name->m) toMap
     val seedMods = mutable.Map.empty[String, String]
     val newModDefs = mutable.Set.empty[DefModule]
     val nsp = Namespace(c)

     /** 
      *  We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later.  
      *  It populates seedMods where we capture the mapping between the original module name of the instances came from annotation 
      *  to a new module name that we will create as a replica of the original one. 
      *  Note: We replace old modules with it replicas so that other instances of the same module can be left unchanged.
      */
     def rewriteMod(parent: DefModule)(x: Statement): Statement = x match {
       case _: Block => x map rewriteMod(parent)
       case WDefInstance(info, instName, moduleName, instTpe) =>
         if (insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main))))
           || mods.contains(ModuleName(parent.name, CircuitName(c.main)))) {
           val newModName = if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN")
           seedMods += moduleName -> newModName
           WDefInstance(info, instName, newModName, instTpe)
         } else x
       case _ => x
     }
     
     val modifMods = c.modules map { m => m map rewriteMod(m) }
     
     /** 
      *  Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations). 
      *  Populates newModDefs, which are replicated modules used in the subcircuit that we create 
      *  by recursively traversing modules captured inside seedMods and replicating them
      */
     def recDupMods(mods: Map[String, String]): Unit = {
       val replMods = mutable.Map.empty[String, String]

       def dupMod(x: Statement): Statement = x match {
         case _: Block => x map dupMod
         case WDefInstance(info, instName, moduleName, instTpe) =>
           val newModName = if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN")
           replMods += moduleName -> newModName
           WDefInstance(info, instName, newModName, instTpe)
         case _ => x 
       }
       
       def dupName(name: String): String = mods(name)
       val newMods = mods map { case (origName, newName) => modMap(origName) map dupMod map dupName }
       
       newModDefs ++= newMods
       
       if(replMods.size > 0) recDupMods(replMods.toMap)
       
     }
     recDupMods(seedMods.toMap)

     //convert newly created modules to ModuleName for inlining next (outside this function)
     val modsToInline = newModDefs map { m => ModuleName(m.name, CircuitName(c.main)) }
     (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet)
   }
   
   override def execute(state: CircuitState): CircuitState = {
     val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a }
     annos match {
       case Nil => state
       case myAnnotations =>
         val c = state.circuit
         val (modNames, instNames) = collectAnns(state.circuit, myAnnotations)
         // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline
         val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames)
         inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations)
     }
   }
}