1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
|
// See LICENSE for license details.
package firrtl
package transforms
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
import firrtl.WrappedExpression._
import firrtl.graph.{MutableDiGraph, CyclicException}
import firrtl.options.Dependency
import scala.collection.mutable
import scala.util.{Try, Success, Failure}
/** Replaces wires with nodes in a legal, flow-forward order.
 *
 * This pass must run after LowerTypes because aggregate-type
 * wires have multiple connections that may be impossible to order in a
 * flow-forward way.
 */
class RemoveWires extends Transform with DependencyAPIMigration {

  override def prerequisites = firrtl.stage.Forms.MidForm ++
    Seq(
      Dependency(passes.LowerTypes),
      Dependency(passes.Legalize),
      Dependency(transforms.RemoveReset),
      Dependency[transforms.CheckCombLoops]
    )

  override def optionalPrerequisites = Seq(Dependency[checks.CheckResets])

  override def optionalPrerequisiteOf = Seq.empty

  override def invalidates(a: Transform) = false

  /** Collects every reference to a node, wire, or register inside `expr`.
    *
    * Only [[Mux]], [[DoPrim]], and [[ValidIf]] are descended into. Since this pass operates on
    * LowForm, such references are expected to appear only as plain [[WRef]]s.
    *
    * @param expr the expression to scan
    * @return the matching references, in the order they were encountered
    */
  private def extractNodeWireRegRefs(expr: Expression): Seq[WRef] = {
    val refs = mutable.ArrayBuffer.empty[WRef]
    def rec(e: Expression): Expression = {
      e match {
        case ref @ WRef(_, _, WireKind | NodeKind | RegKind, _) => refs += ref
        case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.foreach(rec)
        case _ => // Anything else (e.g. literals, references of other kinds) adds no dependency
      }
      e
    }
    rec(expr)
    refs
  }

  /** Orders the netlist so that every definition precedes its uses.
    *
    * Builds a graph with an edge from each sink to every node/wire/register it reads, linearizes
    * it, and emits a [[DefNode]] for each wire or node and the original [[DefRegister]] for each
    * register.
    *
    * @param netlist each sink mapped to the expressions it depends on and the [[Info]] to attach
    * @param regInfo the original declaration of every register present in `netlist`
    * @return the ordered statements, or a `Failure` (e.g. wrapping a [[CyclicException]]) when
    *         the graph cannot be linearized
    */
  private def getOrderedNodes(
    netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)],
    regInfo: mutable.Map[WrappedExpression, DefRegister]
  ): Try[Seq[Statement]] = {
    val digraph = new MutableDiGraph[WrappedExpression]
    // Vertices are added in netlist (insertion) order
    for ((sink, (exprs, _)) <- netlist) {
      digraph.addVertex(sink)
      for (expr <- exprs) {
        for (source <- extractNodeWireRegRefs(expr)) {
          digraph.addPairWithEdge(sink, source)
        }
      }
    }
    // We could reverse edge directions and not have to do this reverse, but doing it this way does
    // a MUCH better job of preserving the logic order as expressed by the designer
    // See RemoveWireTests for illustration
    Try {
      val ordered = digraph.linearize.reverse
      ordered.map { key =>
        val WRef(name, _, refKind, _) = key.e1
        refKind match {
          case RegKind => regInfo(key)
          case WireKind | NodeKind =>
            val (Seq(rhs), info) = netlist(key)
            DefNode(info, name, rhs)
        }
      }
    }
  }

  /** Rewrites one module so that all non-Analog wires become nodes placed in dependency order.
    *
    * The new body is: the kept declarations (everything except nodes, registers and non-Analog
    * wires), then the ordered nodes and registers, then all remaining statements (non-wire
    * connects, printfs, stops, attaches). If ordering fails with a [[CyclicException]], a warning
    * is logged and the module is returned unchanged; any other failure is rethrown.
    *
    * @param m the module to transform; [[ExtModule]]s are returned as-is
    * @return the rewritten module
    */
  private def onModule(m: DefModule): DefModule = {
    // Declarations kept verbatim (e.g. instances, memories, Analog wires). Registers are NOT kept
    // here: they go into the netlist and are re-emitted in dependency order with the nodes.
    val decls = mutable.ArrayBuffer.empty[Statement]
    // Store all "other" statements here, non-wire, non-node connections, printfs, etc.
    val otherStmts = mutable.ArrayBuffer.empty[Statement]
    // Nodes, registers, and connected wires mapped to their dependencies and Info;
    // iteration follows insertion order, i.e. the order they appear in the module
    val netlist = mutable.LinkedHashMap.empty[WrappedExpression, (Seq[Expression], Info)]
    // Info at definition of wires for combining into node
    val wireInfo = mutable.HashMap.empty[WrappedExpression, Info]
    // Original declaration of each register, re-emitted by getOrderedNodes
    val regInfo = mutable.HashMap.empty[WrappedExpression, DefRegister]

    def onStmt(stmt: Statement): Statement = {
      stmt match {
        case node: DefNode =>
          netlist(we(WRef(node))) = (Seq(node.value), node.info)
        case wire: DefWire if !wire.tpe.isInstanceOf[AnalogType] => // Remove all non-Analog wires
          wireInfo(WRef(wire)) = wire.info
        case reg: DefRegister =>
          // Only an asynchronous reset is treated as an ordering dependency of the register
          val resetDep = reg.reset.tpe match {
            case AsyncResetType => Some(reg.reset)
            case _ => None
          }
          val regRef = we(WRef(reg))
          // The init value is a dependency unless the register initializes to itself
          val initDep = Some(reg.init).filter(init => regRef != we(init))
          regInfo(regRef) = reg
          netlist(regRef) = (Seq(reg.clock) ++ resetDep ++ initDep, reg.info)
        case decl: IsDeclaration => // Keep all declarations except for nodes and non-Analog wires
          decls += decl
        case con @ Connect(cinfo, lhs, rhs) =>
          kind(lhs) match {
            case WireKind =>
              // Be sure to pad the rhs since nodes get their type from the rhs
              val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe)
              val dinfo = wireInfo(lhs)
              netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo))
            case _ => otherStmts += con // Other connections just pass through
          }
        case invalid @ IsInvalid(info, expr) =>
          kind(expr) match {
            case WireKind =>
              // An invalidated wire becomes a ValidIf guarded by Utils.zero over a zero literal
              val width = expr.tpe match { case GroundType(w) => w } // LowFirrtl: ground types only
              netlist(we(expr)) = (Seq(ValidIf(Utils.zero, UIntLiteral(BigInt(0), width), expr.tpe)), info)
            case _ => otherStmts += invalid
          }
        case other @ (_: Print | _: Stop | _: Attach) =>
          otherStmts += other
        case EmptyStmt => // Don't bother keeping EmptyStmts around
        case block: Block => block.foreach(onStmt)
        case unexpected =>
          throwInternalError(
            s"RemoveWires: unexpected statement in module ${m.name}: ${unexpected.serialize}")
      }
      stmt
    }

    m match {
      case mod @ Module(info, name, ports, body) =>
        onStmt(body)
        getOrderedNodes(netlist, regInfo) match {
          case Success(logic) =>
            Module(info, name, ports, Block(decls ++ logic ++ otherStmts))
          // If we hit a CyclicException, just abort removing wires
          case Failure(c: CyclicException) =>
            val problematicNode = c.node
            logger.warn(s"Cycle found in module $name, " +
              s"wires will not be removed which can prevent optimizations! Problem node: $problematicNode")
            mod
          case Failure(other) => throw other
        }
      case m: ExtModule => m
    }
  }

  /* @todo move ResolveKinds outside */
  // Passes run on the transformed circuit. onModule never rewrites reference expressions, so
  // presumably ResolveKinds is needed to update kinds of references to former wires.
  private val cleanup = Seq(
    passes.ResolveKinds
  )

  /** Applies `onModule` to every module, then runs the `cleanup` passes on the result. */
  def execute(state: CircuitState): CircuitState = {
    val result = state.copy(circuit = state.circuit.map(onModule))
    cleanup.foldLeft(result) { case (in, xform) => xform.execute(in) }
  }
}
|