aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/Dedup.scala
blob: 36bc6f66c53f3292c1ebef23ca3e744d9c92c58a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
// See LICENSE for license details.

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
import firrtl.passes.PassException

// Datastructures
import scala.collection.mutable


/** Factory and extractor for the annotation that tells [[DedupModules]] to
  * leave a module alone.
  *
  * `apply` builds an annotation on the given module, addressed to
  * [[DedupModules]], carrying the value "nodedup!". `unapply` yields the
  * targeted module for any annotation whose target is a module and whose value
  * is "nodedup!"; the transform field of the annotation is not inspected.
  */
object NoDedupAnnotation {
  def apply(target: ModuleName): Annotation =
    Annotation(target, classOf[DedupModules], "nodedup!")

  def unapply(a: Annotation): Option[ModuleName] =
    PartialFunction.condOpt(a) {
      case Annotation(module: ModuleName, _, "nodedup!") => module
    }
}


/** Merges modules that are structurally identical into a single module.
  *
  * Two modules are considered identical when their ports and body (for an
  * external module: ports, defname and parameters) serialize to the same string
  * after every source locator has been blanked and every child instance has
  * been redirected to the module it was already deduplicated to. The surviving
  * module keeps the name of the first such module in bottom-up order and
  * carries the concatenated source locators of all modules merged into it.
  *
  * Modules named by a [[NoDedupAnnotation]], and the top module of the circuit,
  * are never merged with another module.
  *
  * Only use on legal Firrtl. Specifically, the restriction of instance loops
  * must have been checked: if the modules cannot be ordered bottom-up (instance
  * loop, instance of an undefined module, duplicate module name) this transform
  * throws a [[firrtl.passes.PassException]].
  */
class DedupModules extends Transform {
  def inputForm = HighForm
  def outputForm = HighForm

  /** Collects the module names marked with [[NoDedupAnnotation]] and runs the
    * deduplication on the circuit, keeping the incoming circuit form.
    */
  def execute(state: CircuitState): CircuitState = {
    // collect on an empty annotation list already yields an empty list, so no
    // special case is needed when there are no annotations
    val noDedups = getMyAnnotations(state).collect { case NoDedupAnnotation(ModuleName(m, _)) => m }
    CircuitState(run(state.circuit, noDedups), state.form)
  }

  /** Deduplicates the modules of `c`.
    *
    * @param c circuit to process
    * @param noDedups names of modules that must not be merged with any other module
    * @return a copy of `c` in which each group of identical modules is replaced
    *         by one merged module, in the original module order
    * @throws firrtl.passes.PassException if the modules cannot be ordered bottom-up
    */
  def run(c: Circuit, noDedups: Seq[String]): Circuit = {
    val moduleMap = c.modules.map(m => m.name -> m).toMap

    // Bottom-up order: a module is appended only once every module it
    // instantiates has been appended.
    val moduleOrder = mutable.ArrayBuffer.empty[String]
    // Same names as moduleOrder, kept as a set so membership tests stay cheap
    val ordered = mutable.HashSet.empty[String]

    def markOrdered(name: String): Unit = {
      moduleOrder += name
      ordered += name
    }

    // True if the body instantiates a module that has not been ordered yet
    def hasInstance(b: Statement): Boolean = {
      var has = false
      def onStmt(s: Statement): Statement = (s map onStmt) match {
        case DefInstance(_, _, module) =>
          if (!ordered.contains(module)) has = true
          s
        case WDefInstance(_, _, module, _) =>
          if (!ordered.contains(module)) has = true
          s
        case _ => s
      }
      onStmt(b)
      has
    }

    def addModule(m: DefModule): Unit = m match {
      case Module(_, _, _, body) => if (!hasInstance(body)) markOrdered(m.name)
      case _: ExtModule => markOrdered(m.name)
      case _ =>
    }

    while (moduleOrder.size < c.modules.size) {
      val sizeBefore = moduleOrder.size
      c.modules.foreach(m => if (!ordered.contains(m.name)) addModule(m))
      // A full pass that orders nothing can never make progress later either;
      // fail loudly instead of spinning forever.
      if (moduleOrder.size == sizeBefore) {
        val unresolved = c.modules.map(_.name).filterNot(ordered.contains).mkString(", ")
        throw new PassException(
          s"DedupModules cannot order modules bottom-up (unresolved: $unresolved). " +
          "Check for instance loops, instances of undefined modules, or duplicate module names.")
      }
    }

    // The top module is protected as well: were it merged into another module,
    // the circuit's main name would refer to a module that no longer exists.
    val neverDedup: Set[String] = (c.main +: noDedups).toSet

    // Module body -> Module name
    val dedupModules = mutable.HashMap.empty[String, String]
    // Old module name -> dup module name
    val dedupMap = mutable.HashMap.empty[String, String]
    // Dup module name -> all old module names
    val oldModuleMap = mutable.HashMap.empty[String, Seq[DefModule]]

    // Registers one module: either as a duplicate of an earlier module with the
    // same signature string, or as the representative of a new group.
    def onModule(m: DefModule): Unit = {
      // Point instances at the representative of the instantiated module
      def fixInstance(s: Statement): Statement = (s map fixInstance) match {
        case DefInstance(i, n, module) => DefInstance(i, n, dedupMap.getOrElse(module, module))
        case WDefInstance(i, n, module, t) => WDefInstance(i, n, dedupMap.getOrElse(module, module), t)
        case x => x
      }
      // Blank source locators so they do not take part in the comparison.
      // Statement kinds not listed here are left untouched.
      def removeInfo(stmt: Statement): Statement = (stmt map removeInfo) match {
        case s: DefWire => s.copy(info = NoInfo)
        case s: DefNode => s.copy(info = NoInfo)
        case s: DefRegister => s.copy(info = NoInfo)
        case s: DefInstance => s.copy(info = NoInfo)
        case s: WDefInstance => s.copy(info = NoInfo)
        case s: DefMemory => s.copy(info = NoInfo)
        case s: Connect => s.copy(info = NoInfo)
        case s: PartialConnect => s.copy(info = NoInfo)
        case s: IsInvalid => s.copy(info = NoInfo)
        case s: Attach => s.copy(info = NoInfo)
        case s: Stop => s.copy(info = NoInfo)
        case s: Print => s.copy(info = NoInfo)
        case s: Conditionally => s.copy(info = NoInfo)
        case sx => sx
      }
      def removePortInfo(p: Port): Port = p.copy(info = NoInfo)

      // mx keeps its source locators and is what gets stored; mxx is only used
      // to build the comparison string
      val mx = m map fixInstance
      val mxx = (mx map removeInfo) map removePortInfo

      // If shouldn't dedup, just make it fail to be the same to any other modules
      val unique = if (neverDedup.contains(mxx.name)) mxx.name else ""

      val string = mxx match {
        case Module(i, n, ps, b) =>
          ps.map(_.serialize).mkString + b.serialize + unique
        case ExtModule(i, n, ps, dn, p) =>
          ps.map(_.serialize).mkString + dn + p.map(_.serialize).mkString + unique
      }
      dedupModules.get(string) match {
        case Some(dupname) =>
          dedupMap(mx.name) = dupname
          oldModuleMap(dupname) = oldModuleMap(dupname) :+ mx
        case None =>
          dedupModules(string) = mx.name
          oldModuleMap(mx.name) = Seq(mx)
      }
    }

    // Builds the surviving module of a group: structure and name come from the
    // first module, source locators are concatenated across the whole group.
    def mergeModules(ms: Seq[DefModule]): DefModule = {
      def getInfo(s: Any): Info = s match {
        case sx: HasInfo => sx.info
        case _ => NoInfo
      }
      // Turns a sequence of rows into a sequence of columns, sized by the first row
      def invertSeqs[A](seq: Seq[Seq[A]]): Seq[Seq[A]] =
        seq.head.indices.map(j => seq.map(row => row(j)))

      // ss holds the statements found at the same position in every module of
      // the group; the result is the first one with all locators combined.
      def mergeStatements(ss: Seq[Statement]): Statement = {
        def mergedInfo: Info = ss.map(getInfo).reduce(_ ++ _)
        ss.head match {
          case Block(_) =>
            val inverted = invertSeqs(ss.map { case Block(s) => s })
            Block(inverted.map(jStmts => mergeStatements(jStmts)))
          case Conditionally(_, pred, _, _) =>
            val finalConseq = mergeStatements(ss.map { case Conditionally(_, _, conseq, _) => conseq })
            val finalAlt = mergeStatements(ss.map { case Conditionally(_, _, _, alt) => alt })
            val finalInfo = ss.map { case Conditionally(i, _, _, _) => i }.reduce(_ ++ _)
            Conditionally(finalInfo, pred, finalConseq, finalAlt)
          case s: DefWire => s.copy(info = mergedInfo)
          case s: DefNode => s.copy(info = mergedInfo)
          case s: DefRegister => s.copy(info = mergedInfo)
          case s: DefInstance => s.copy(info = mergedInfo)
          case s: WDefInstance => s.copy(info = mergedInfo)
          case s: DefMemory => s.copy(info = mergedInfo)
          case s: Connect => s.copy(info = mergedInfo)
          case s: PartialConnect => s.copy(info = mergedInfo)
          case s: IsInvalid => s.copy(info = mergedInfo)
          case s: Attach => s.copy(info = mergedInfo)
          case s: Stop => s.copy(info = mergedInfo)
          case s: Print => s.copy(info = mergedInfo)
          // Any other statement kind is kept exactly as it appears in the first module
          case s => s
        }
      }

      val finalPorts = invertSeqs(ms.map(_.ports)).map { jPorts =>
        jPorts.head.copy(info = jPorts.map(_.info).reduce(_ ++ _))
      }
      val finalInfo = ms.map(getInfo).reduce(_ ++ _)
      ms.head match {
        case e: ExtModule => ExtModule(finalInfo, e.name, finalPorts, e.defname, e.params)
        case e: Module => Module(finalInfo, e.name, finalPorts, mergeStatements(ms.collect { case m: Module => m.body }))
      }
    }

    moduleOrder.foreach(n => onModule(moduleMap(n)))

    // Use old module list to preserve ordering; modules that were folded into
    // another one have no entry in oldModuleMap and are dropped here
    val dedupedModules = c.modules.flatMap { m =>
      oldModuleMap.get(m.name).map(mergeModules)
    }
    c.copy(modules = dedupedModules)
  }
}