1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
|
// See LICENSE for license details.
package firrtl
package transforms
import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
import firrtl.passes.PassException
// Datastructures
import scala.collection.mutable
/** Marks a module that [[DedupModules]] must keep as its own, separate definition.
  *
  * Encoded as a generic [[Annotation]] on the module, owned by [[DedupModules]], whose value is the
  * marker string `"nodedup!"`.
  */
object NoDedupAnnotation {
  /** Builds the no-dedup marker annotation for `target`. */
  def apply(target: ModuleName): Annotation = Annotation(target, classOf[DedupModules], "nodedup!")

  /** Returns the targeted module if `a` carries the no-dedup marker value.
    *
    * Only the target and value are inspected; the owning transform class is not checked here.
    */
  def unapply(a: Annotation): Option[ModuleName] = a match {
    case Annotation(m: ModuleName, _, "nodedup!") => Some(m)
    case _ => None
  }
}
/** Merges structurally identical modules into a single definition.
  *
  * Modules are visited from leaves to root. Before a module is compared, its instance references are
  * rewritten to the already-deduplicated names, and source infos on its ports and statements are
  * stripped. Two modules that differ only in name and source locators therefore collapse into one.
  * The first module of each group (in leaves-to-root order) survives, and instances of the others
  * are redirected to it.
  *
  * The circuit's top module and every module named by a [[NoDedupAnnotation]] are never merged with
  * any other module.
  *
  * Only use on legal FIRRTL: instance loops (and instances of undefined modules) must already have
  * been rejected. If one slips through, module ordering throws a [[PassException]] instead of
  * looping forever.
  */
class DedupModules extends Transform {
  def inputForm = HighForm
  def outputForm = HighForm

  /** Runs deduplication, honoring any [[NoDedupAnnotation]]s addressed to this transform. */
  def execute(state: CircuitState): CircuitState = {
    // Only the module name is used; the annotation's circuit name is not checked
    val noDedups = getMyAnnotations(state).collect { case NoDedupAnnotation(ModuleName(m, _)) => m }
    state.copy(circuit = run(state.circuit, noDedups))
  }

  /** Orders the modules of `c` from leaves to root: a module appears *after* every module it
    * instantiates.
    *
    * Repeatedly sweeps `c.modules`. Each sweep appends every module whose instances all refer to
    * modules already in the order. External modules have no body and are placed on the first sweep.
    *
    * @throws PassException if a sweep places nothing new (instance loop, undefined module, or
    *         duplicate module name), which would otherwise loop forever
    */
  private def buildModuleOrder(c: Circuit): Seq[String] = {
    val moduleOrder = mutable.ArrayBuffer.empty[String]
    // Mirror of moduleOrder for constant-time membership checks
    val placed = mutable.HashSet.empty[String]
    def place(name: String): Unit = {
      moduleOrder += name
      placed += name
    }
    // True if `b` instantiates some module that has not been placed yet
    def hasUnplacedInstance(b: Statement): Boolean = {
      var found = false
      def onStmt(s: Statement): Statement = s map onStmt match {
        case DefInstance(_, _, m) =>
          if (!placed.contains(m)) found = true
          s
        case WDefInstance(_, _, m, _) =>
          if (!placed.contains(m)) found = true
          s
        case _ => s
      }
      onStmt(b)
      found
    }
    def addModule(m: DefModule): Unit = m match {
      case Module(_, _, _, b) => if (!hasUnplacedInstance(b)) place(m.name)
      case _: ExtModule => place(m.name)
      case _ => // unknown module kinds are never placed; the progress check below reports them
    }
    while (moduleOrder.size < c.modules.size) {
      val sizeBefore = moduleOrder.size
      c.modules.foreach(m => if (!placed.contains(m.name)) addModule(m))
      if (moduleOrder.size == sizeBefore) {
        val stuck = c.modules.map(_.name).filterNot(placed.contains)
        throw new PassException(
          "DedupModules cannot order modules (instance loop, undefined module, or duplicate " +
            s"module name): ${stuck.mkString(", ")}")
      }
    }
    moduleOrder
  }

  /** Groups structurally identical modules.
    *
    * As a side effect, each module's instances are rewritten to point at the surviving module of
    * their group. The modules stored in the result are these rewritten versions.
    *
    * @param moduleOrder module names from leaves to root (see `buildModuleOrder`)
    * @param moduleMap   module name -> definition
    * @param noDedups    names of modules that must never be merged with any other module
    * @return surviving module name -> every module merged into it, survivor first
    */
  private def findDups(
      moduleOrder: Seq[String],
      moduleMap: Map[String, DefModule],
      noDedups: Seq[String]): Map[String, Seq[DefModule]] = {
    val noDedupSet = noDedups.toSet
    // Structural key -> name of the first module seen with that key
    val dedupModules = mutable.HashMap.empty[String, String]
    // Old module name -> surviving module name
    val dedupMap = mutable.HashMap.empty[String, String]
    // Surviving module name -> all modules merged into it
    val oldModuleMap = mutable.HashMap.empty[String, Seq[DefModule]]

    // Redirects instances to the surviving module of their group (children are processed first)
    def fixInstance(s: Statement): Statement = s map fixInstance match {
      case DefInstance(i, n, m) => DefInstance(i, n, dedupMap.getOrElse(m, m))
      case WDefInstance(i, n, m, t) => WDefInstance(i, n, dedupMap.getOrElse(m, m), t)
      case x => x
    }
    // Strips source locators so they do not prevent otherwise-identical modules from matching
    def removeInfo(stmt: Statement): Statement = stmt map removeInfo match {
      case s: DefWire => s.copy(info = NoInfo)
      case s: DefNode => s.copy(info = NoInfo)
      case s: DefRegister => s.copy(info = NoInfo)
      case s: DefInstance => s.copy(info = NoInfo)
      case s: WDefInstance => s.copy(info = NoInfo)
      case s: DefMemory => s.copy(info = NoInfo)
      case s: Connect => s.copy(info = NoInfo)
      case s: PartialConnect => s.copy(info = NoInfo)
      case s: IsInvalid => s.copy(info = NoInfo)
      case s: Attach => s.copy(info = NoInfo)
      case s: Stop => s.copy(info = NoInfo)
      case s: Print => s.copy(info = NoInfo)
      case s: Conditionally => s.copy(info = NoInfo)
      // Statements without an info field (or any kind not listed above) are kept as-is
      case s => s
    }
    def removePortInfo(p: Port): Port = p.copy(info = NoInfo)
    // Kind-tagged, newline-separated serialization so a Module can never collide with an ExtModule
    def structuralKey(m: DefModule): String = m match {
      case Module(_, _, ps, b) =>
        (Seq("module") ++ ps.map(_.serialize) ++ Seq(b.serialize)).mkString("\n")
      case ExtModule(_, _, ps, dn, params) =>
        (Seq("extmodule") ++ ps.map(_.serialize) ++ Seq(dn) ++ params.map(_.serialize)).mkString("\n")
    }

    def onModule(m: DefModule): Unit = {
      val fixed = m map fixInstance
      if (noDedupSet.contains(fixed.name)) {
        // Never looked up or registered, so it is neither merged away nor merged into
        oldModuleMap(fixed.name) = Seq(fixed)
      } else {
        val key = structuralKey((fixed map removeInfo) map removePortInfo)
        dedupModules.get(key) match {
          case Some(dupName) =>
            dedupMap(fixed.name) = dupName
            oldModuleMap(dupName) = oldModuleMap(dupName) :+ fixed
          case None =>
            dedupModules(key) = fixed.name
            oldModuleMap(fixed.name) = Seq(fixed)
        }
      }
    }

    moduleOrder.foreach(n => onModule(moduleMap(n)))
    oldModuleMap.toMap
  }

  /** Deduplicates the modules of `c`, keeping each module named in `noDedups` (and the top module)
    * as its own definition.
    */
  def run(c: Circuit, noDedups: Seq[String]): Circuit = {
    val moduleOrder = buildModuleOrder(c)
    val moduleMap = c.modules.map(m => m.name -> m).toMap
    // The top module must survive under its own name, or c.main would dangle
    val oldModuleMap = findDups(moduleOrder, moduleMap, noDedups :+ c.main)
    // Walk the original module list to preserve ordering; each group keeps only its survivor
    val dedupedModules = c.modules.flatMap(m => oldModuleMap.get(m.name).map(_.head))
    c.copy(modules = dedupedModules)
  }
}
|