aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/Dedup.scala
blob: 5fa2c0360f5c7b53543d3f93fb397672dedaff48 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// See LICENSE for license details.

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
import firrtl.passes.PassException

// Datastructures
import scala.collection.mutable


/** Marks a module that [[DedupModules]] must exclude from deduplication.
  *
  * The annotated module is kept as its own module: it is not merged with any other module,
  * even one that is structurally identical.
  */
object NoDedupAnnotation {
  /** Creates an annotation, addressed to [[DedupModules]], excluding `target` from deduplication. */
  def apply(target: ModuleName): Annotation = Annotation(target, classOf[DedupModules], "nodedup!")

  /** Matches any annotation on a [[ModuleName]] whose value is `"nodedup!"`, yielding that module.
    *
    * The annotation's transform class is not checked here.
    */
  def unapply(a: Annotation): Option[ModuleName] = a match {
    case Annotation(m: ModuleName, _, "nodedup!") => Some(m)
    case _ => None
  }
}


// Deduplicates structurally identical modules.
//
// Only use on legal Firrtl. Specifically, instance loops and instances of
//  undefined modules must have been ruled out beforehand; if they are present,
//  module ordering cannot make progress and this pass throws a PassException.
class DedupModules extends Transform {
  def inputForm = HighForm
  def outputForm = HighForm

  /** Deduplicates the circuit, honoring any [[NoDedupAnnotation]]s addressed to this transform. */
  def execute(state: CircuitState): CircuitState = {
    val noDedups = getMyAnnotations(state).collect { case NoDedupAnnotation(ModuleName(m, _)) => m }
    state.copy(circuit = run(state.circuit, noDedups))
  }

  /** Orders the modules of a circuit from leaves to root.
    *
    * A module appears *after* all modules it instantiates. Each sweep considers modules in
    * `c.modules` order.
    *
    * @throws PassException if some modules can never be ordered (instance loop, instance of an
    *   undefined module, duplicate module names, or an unsupported module kind)
    */
  private def buildModuleOrder(c: Circuit): Seq[String] = {
    val moduleOrder = mutable.ArrayBuffer.empty[String]
    // Same contents as moduleOrder; gives O(1) membership tests
    val ordered = mutable.HashSet.empty[String]

    // True if `b` instantiates any module that has not been ordered yet
    def hasUnorderedInstance(b: Statement): Boolean = {
      var has = false
      def onStmt(s: Statement): Statement = s map onStmt match {
        case DefInstance(_, _, m) =>
          if (!ordered.contains(m)) has = true
          s
        case WDefInstance(_, _, m, _) =>
          if (!ordered.contains(m)) has = true
          s
        case _ => s
      }
      onStmt(b)
      has
    }

    // True if `m` may be appended to the order right now
    def isReady(m: DefModule): Boolean = m match {
      case Module(_, _, _, b) => !hasUnorderedInstance(b)
      case _: ExtModule => true
      case _ => false
    }

    while (moduleOrder.size < c.modules.size) {
      val sizeBefore = moduleOrder.size
      c.modules.foreach { m =>
        if (!ordered.contains(m.name) && isReady(m)) {
          moduleOrder += m.name
          ordered += m.name
        }
      }
      // A sweep that orders nothing leaves the state unchanged, so every later sweep would
      // order nothing too: fail loudly instead of looping forever
      if (moduleOrder.size == sizeBefore) {
        val unordered = c.modules.map(_.name).filterNot(ordered.contains)
        throw new PassException(
          s"Cannot order modules for deduplication: ${moduleOrder.size} of ${c.modules.size} ordered; " +
            s"never ordered: ${unordered.mkString(", ")} " +
            "(instance loop, undefined module, or duplicate module name?)")
      }
    }
    moduleOrder
  }

  /** Finds duplicate modules.
    *
    * Modules are visited in `moduleOrder` (leaves first), so each module's instances are
    * redirected to the deduplicated names of its children before its own structure is compared.
    * Two modules are duplicates when their ports and body (for external modules: ports, defname
    * and parameters) serialize identically after source locators are stripped from ports and
    * from the statement kinds handled by `removeInfo`. Modules named in `noDedups` are never
    * merged with any other module.
    *
    * @return surviving module name -> all modules merged into it, survivor first; every returned
    *   module has its instances redirected to deduplicated module names
    */
  private def findDups(
      moduleOrder: Seq[String],
      moduleMap: Map[String, DefModule],
      noDedups: Seq[String]): Map[String, Seq[DefModule]] = {
    val noDedupSet = noDedups.toSet
    // Structural key -> name of the first module seen with that structure
    val dedupModules = mutable.HashMap.empty[Seq[Seq[String]], String]
    // Old module name -> name of the module it was merged into
    val dedupMap = mutable.HashMap.empty[String, String]
    // Surviving module name -> all modules merged into it
    val oldModuleMap = mutable.HashMap.empty[String, Seq[DefModule]]

    // Points every instance at the surviving copy of the instantiated module
    def fixInstance(s: Statement): Statement = s map fixInstance match {
      case DefInstance(i, n, m) => DefInstance(i, n, dedupMap.getOrElse(m, m))
      case WDefInstance(i, n, m, t) => WDefInstance(i, n, dedupMap.getOrElse(m, m), t)
      case x => x
    }

    // Strips source locators so they do not prevent otherwise identical modules from matching
    def removeInfo(stmt: Statement): Statement = stmt map removeInfo match {
      case s: DefWire => s.copy(info = NoInfo)
      case s: DefNode => s.copy(info = NoInfo)
      case s: DefRegister => s.copy(info = NoInfo)
      case s: DefInstance => s.copy(info = NoInfo)
      case s: WDefInstance => s.copy(info = NoInfo)
      case s: DefMemory => s.copy(info = NoInfo)
      case s: Connect => s.copy(info = NoInfo)
      case s: PartialConnect => s.copy(info = NoInfo)
      case s: IsInvalid => s.copy(info = NoInfo)
      case s: Attach => s.copy(info = NoInfo)
      case s: Stop => s.copy(info = NoInfo)
      case s: Print => s.copy(info = NoInfo)
      case s: Conditionally => s.copy(info = NoInfo)
      // Any other statement is left as-is rather than failing with a MatchError
      case other => other
    }

    def removePortInfo(p: Port): Port = p.copy(info = NoInfo)

    // Structural fingerprint of an info-stripped module. Kept as nested sequences rather than
    // one concatenated string so distinct modules cannot collide through ambiguous
    // concatenation, and tagged by kind so a Module never matches an ExtModule.
    def structuralKey(m: DefModule): Seq[Seq[String]] = m match {
      case Module(_, _, ps, b) =>
        Seq(Seq("module"), ps.map(_.serialize), Seq(b.serialize))
      case ExtModule(_, _, ps, dn, p) =>
        Seq(Seq("extmodule", dn), ps.map(_.serialize), p.map(_.serialize))
    }

    def onModule(m: DefModule): Unit = {
      val mx = m map fixInstance
      if (noDedupSet.contains(mx.name)) {
        // Excluded modules neither absorb nor get absorbed by any other module
        oldModuleMap(mx.name) = Seq(mx)
      } else {
        val key = structuralKey((mx map removeInfo) map removePortInfo)
        dedupModules.get(key) match {
          case Some(dupname) =>
            dedupMap(mx.name) = dupname
            oldModuleMap(dupname) = oldModuleMap(dupname) :+ mx
          case None =>
            dedupModules(key) = mx.name
            oldModuleMap(mx.name) = Seq(mx)
        }
      }
    }

    moduleOrder.foreach(n => onModule(moduleMap(n)))
    oldModuleMap.toMap
  }

  /** Deduplicates the modules of `c`.
    *
    * @param c the circuit to deduplicate
    * @param noDedups names of modules that must not be merged with any other module
    * @return `c` with duplicate modules removed, survivors kept in their original relative
    *   order, and every instance pointing at a surviving module
    * @throws PassException if the modules of `c` cannot be ordered leaves-to-root
    */
  def run(c: Circuit, noDedups: Seq[String]): Circuit = {
    val moduleOrder = buildModuleOrder(c)
    val moduleMap = c.modules.map(m => m.name -> m).toMap

    // The main module is never merged away: if it were deduplicated into an identical module
    // it would be dropped from the circuit and c.main would name a missing module
    val oldModuleMap = findDups(moduleOrder, moduleMap, noDedups :+ c.main)

    // Use old module list to preserve ordering
    val dedupedModules = c.modules.flatMap(m => oldModuleMap.get(m.name).map(_.head))

    c.copy(modules = dedupedModules)
  }
}