1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
|
// See LICENSE for license details.
package firrtl
package passes
import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
// Data structures
import scala.collection.mutable
/** Annotation marking its target as something that should be inlined.
  *
  * @param target the named thing to inline
  */
case class InlineAnnotation(target: Named) extends SingleTargetAnnotation[Named] {
  /** Builds the same annotation, re-pointed at `n`. */
  def duplicate(n: Named): InlineAnnotation = copy(target = n)
}
// Only use on legal FIRRTL. Specifically, the absence of instance
// loops must already have been checked; otherwise this pass can
// recurse infinitely.
/** Transform that inlines every module/instance selected by an [[InlineAnnotation]].
  *
  * Inlined instances are replaced by the body of the instantiated module; every name
  * coming from that body is prefixed with the instance name followed by [[inlineDelim]].
  * Modules annotated for inlining are removed from the resulting circuit.
  */
class InlineInstances extends Transform {
  def inputForm = LowForm
  def outputForm = LowForm

  /** Separator inserted between an inlined instance's name and the names it prefixes. */
  val inlineDelim = "$"

  /** Gathers the inlining targets named by the [[InlineAnnotation]]s in `anns`.
    *
    * - a circuit target selects every non-external module except the circuit's main module
    * - a module target selects that module (i.e. all of its instances)
    * - a component target selects that single instance
    *
    * Targets accumulate across annotations, so the result does not depend on annotation order.
    *
    * @param circuit circuit the annotations refer to
    * @param anns    annotations to scan; anything that is not an [[InlineAnnotation]] is ignored
    * @return the set of modules to inline and the set of individual instances to inline
    */
  private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
    anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) {
      case ((modNames, instNames), InlineAnnotation(CircuitName(c))) =>
        val allButMain = circuit.modules.collect {
          case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
        }
        // Union with what was collected so far instead of discarding earlier module targets
        (modNames ++ allButMain, instNames)
      case ((modNames, instNames), InlineAnnotation(mn: ModuleName)) => (modNames + mn, instNames)
      case ((modNames, instNames), InlineAnnotation(cn: ComponentName)) => (modNames, instNames + cn)
      case (acc, _) => acc
    }

  /** Runs the inlining if any inline annotation is present; otherwise returns `state` unchanged. */
  def execute(state: CircuitState): CircuitState = {
    // TODO Add error check for more than one annotation for inlining
    val (modNames, instNames) = collectAnns(state.circuit, state.annotations)
    if (modNames.isEmpty && instNames.isEmpty) state
    else run(state.circuit, modNames, instNames, state.annotations)
  }

  /** Validates the inlining targets against the circuit.
    *
    * Checks the following properties, in this order, stopping at the first stage that fails:
    *  1. All annotated modules exist
    *  1. No annotated module is an external module (those cannot be inlined)
    *  1. All annotated instances exist, and the modules they instantiate can be inlined
    *
    * @param c             circuit to check against
    * @param moduleNames   modules annotated for inlining
    * @param instanceNames instances annotated for inlining
    * @throws PassExceptions carrying every error found in the first failing stage
    */
  def check(c: Circuit, moduleNames: Set[ModuleName], instanceNames: Set[ComponentName]): Unit = {
    val errors = mutable.ArrayBuffer[PassException]()
    val moduleMap = c.modules.map(m => m.name -> m).toMap

    def throwIfErrors(): Unit =
      if (errors.nonEmpty) throw new PassExceptions(errors)

    def checkExists(name: String): Unit =
      if (!moduleMap.contains(name))
        errors += new PassException(s"Annotated module does not exist: $name")

    // Precondition: `name` is a module of the circuit
    def checkExternal(name: String): Unit = moduleMap(name) match {
      case _: ExtModule => errors += new PassException(s"Annotated module cannot be an external module: $name")
      case _ =>
    }

    def checkInstance(cn: ComponentName): Unit = {
      def reportMissing(): Unit =
        errors += new PassException(s"Annotated instance does not exist: ${cn.module.name}.${cn.name}")

      moduleMap.get(cn.module.name) match {
        case Some(parent: Module) =>
          var containsCN = false
          // Walks the parent body looking for an instance declaration named cn.name
          def onStmt(s: Statement): Statement = {
            s match {
              case WDefInstance(_, instName, moduleName, _) if instName == cn.name =>
                containsCN = true
                checkExternal(moduleName)
              case _ =>
            }
            s map onStmt
          }
          onStmt(parent.body)
          if (!containsCN) reportMissing()
        // The parent module is missing or is an external module (which has no body):
        // either way the instance cannot exist, so report it instead of crashing on
        // a failed map lookup or a failed cast.
        case _ => reportMissing()
      }
    }

    moduleNames.foreach(mn => checkExists(mn.name))
    throwIfErrors()
    moduleNames.foreach(mn => checkExternal(mn.name))
    throwIfErrors()
    instanceNames.foreach(checkInstance)
    throwIfErrors()
  }

  /** Inlines the requested modules and instances and returns the resulting state.
    *
    * @param c             circuit to transform
    * @param modsToInline  modules whose every instance is inlined; these modules are dropped
    *                      from the output circuit
    * @param instsToInline individual instances to inline
    * @param annos         annotations carried over, unchanged, into the returned state
    * @return a LowForm [[CircuitState]] holding the flattened circuit
    * @throws PassExceptions if [[check]] rejects the targets
    * @throws PassException  if an instance selected for inlining instantiates an external module
    */
  def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName], annos: AnnotationSeq): CircuitState = {
    // Collects "Parent.inst" for every instance, inside any non-external module,
    // whose instantiated module is one of `modules`
    def getInstancesOf(c: Circuit, modules: Set[String]): Set[String] =
      c.modules.foldLeft(Set[String]()) { (set, d) =>
        d match {
          case _: ExtModule => set
          case m: Module =>
            val instances = mutable.HashSet[String]()
            def findInstances(s: Statement): Statement = s match {
              case WDefInstance(_, instName, moduleName, _) if modules.contains(moduleName) =>
                instances += m.name + "." + instName
                s
              case sx => sx map findInstances
            }
            findInstances(m.body)
            instances.toSet ++ set
        }
      }

    // Check annotations and circuit match up
    check(c, modsToInline, instsToInline)

    val flatModules = modsToInline.map(m => m.name)
    // Every instance to inline, keyed as "Parent.inst"
    val flatInstances = instsToInline.map(i => i.module.name + "." + i.name) ++ getInstancesOf(c, flatModules)
    val moduleMap = c.modules.map(m => m.name -> m).toMap

    def appendNamePrefix(prefix: String)(name: String): String = prefix + name

    // Rewrites references in inlined statements from ref to <prefix>ref. A port access on an
    // instance that is itself being inlined (inst.field) becomes a reference to the wire
    // <prefix>inst<inlineDelim>field that replaces that port.
    def appendRefPrefix(prefix: String, currentModule: String)(e: Expression): Expression = e match {
      case WSubField(WRef(ref, _, InstanceKind, _), field, tpe, gen) if flatInstances.contains(currentModule + "." + ref) =>
        WRef(prefix + ref + inlineDelim + field, tpe, WireKind, gen)
      case WRef(name, tpe, kind, gen) => WRef(prefix + name, tpe, kind, gen)
      case ex => ex map appendRefPrefix(prefix, currentModule)
    }

    def onStmt(prefix: String, currentModule: String)(s: Statement): Statement = s match {
      case WDefInstance(info, instName, moduleName, instTpe) =>
        val shouldInline = flatInstances.contains(currentModule + "." + instName)
        if (shouldInline) {
          val toInline = moduleMap(moduleName) match {
            case _: ExtModule => throw new PassException("Cannot inline external module")
            case m: Module => m
          }
          // The instance is replaced by one wire per port followed by the module body; the
          // recursive call prefixes everything with <prefix>inst<inlineDelim> and inlines
          // nested instances relative to the instantiated module.
          val stmts = toInline.ports.map(p => DefWire(p.info, p.name, p.tpe)) :+ toInline.body
          onStmt(prefix + instName + inlineDelim, moduleName)(Block(stmts))
        } else WDefInstance(info, prefix + instName, moduleName, instTpe)
      case sx => sx map appendRefPrefix(prefix, currentModule) map onStmt(prefix, currentModule) map appendNamePrefix(prefix)
    }

    val flatCircuit = c.copy(modules = c.modules.flatMap {
      // Modules annotated for inlining do not survive as standalone modules
      case m if flatModules.contains(m.name) => None
      case m =>
        Some(m map onStmt("", m.name))
    })
    CircuitState(flatCircuit, LowForm, annos, None)
  }
}
|