1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
// See LICENSE for license details.
package firrtl
package transforms
import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import scala.collection.mutable
object FlattenRegUpdate {

  /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */
  type Netlist = mutable.HashMap[WrappedExpression, Expression]

  /** Build a [[Netlist]] from a Module's connections and Nodes
    *
    * Each [[firrtl.ir.Connect Connect]] is recorded as `lhs -> rhs` and each
    * [[firrtl.ir.DefNode DefNode]] as `WRef(name) -> rhs`; every other statement is ignored.
    *
    * This assumes [[firrtl.LowForm LowForm]]
    *
    * @param mod [[firrtl.ir.Module Module]] from which to build a [[Netlist]]
    * @return [[Netlist]] of the module's connections and nodes
    * @note an [[firrtl.ir.IsInvalid IsInvalid]] is reported via `throwInternalError`, since earlier
    *       passes are expected to have removed all of them
    */
  def buildNetlist(mod: Module): Netlist = {
    val netlist = new Netlist()
    // Visit nested statements first, then record this one if it drives something.
    // The statement is returned unchanged: this traversal only collects information.
    def onStmt(stmt: Statement): Statement = {
      stmt.map(onStmt) match {
        case Connect(_, lhs, rhs) =>
          netlist(lhs) = rhs
        case DefNode(_, nname, rhs) =>
          netlist(WRef(nname)) = rhs
        case _: IsInvalid => throwInternalError("Unexpected IsInvalid, should have been removed by now")
        case _ => // Do nothing
      }
      stmt
    }
    mod.map(onStmt)
    netlist
  }

  /** Flatten Register Updates
    *
    * Constructs nested mux trees (up to a certain arbitrary threshold) for register updates. This
    * can result in dead code that this function does NOT remove.
    *
    * The expression driving each register is looked up in the module's [[Netlist]]. Any
    * [[firrtl.ir.Mux Mux]] reached through node or wire references is rebuilt inline, so the update
    * becomes the right-hand side of a single connection. The original connections to registers are
    * replaced by `EmptyStmt`. The new connections are appended after the module body with `NoInfo`,
    * so source locators are currently not carried over.
    *
    * @param mod [[firrtl.ir.Module Module]] to transform
    * @return [[firrtl.ir.Module Module]] with register updates flattened
    */
  def flattenReg(mod: Module): Module = {
    // We want to flatten Mux trees for reg updates into if-trees for
    // improved QoR for conditional updates. However, unbounded recursion
    // would take exponential time, so don't redundantly flatten the same
    // Mux more than a bounded number of times, preserving linear runtime.
    // The threshold is empirical but ample.
    val flattenThreshold = 4
    // Keyed by Mux value (case-class equality), so structurally identical muxes share one count
    val numTimesFlattened = mutable.HashMap[Mux, Int]()
    // Records one more flattening attempt for `m` and reports whether it was still under the limit
    def canFlatten(m: Mux): Boolean = {
      val n = numTimesFlattened.getOrElse(m, 0)
      numTimesFlattened(m) = n + 1
      n < flattenThreshold
    }

    val regUpdates = mutable.ArrayBuffer.empty[Connect]
    val netlist = buildNetlist(mod)

    def constructRegUpdate(e: Expression): Expression = {
      // Only walk netlist for nodes and wires, NOT registers or other state
      val expr = kind(e) match {
        case NodeKind | WireKind => netlist.getOrElse(e, e)
        case _ => e
      }
      expr match {
        case mux: Mux if canFlatten(mux) =>
          // tval is visited before fval. canFlatten is stateful, so this order determines
          // which muxes get expanded once the threshold is reached.
          val tvalx = constructRegUpdate(mux.tval)
          val fvalx = constructRegUpdate(mux.fval)
          mux.copy(tval = tvalx, fval = fvalx)
        // Return the original expression to end flattening
        case _ => e
      }
    }

    def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {
      case reg @ DefRegister(_, _, _, _, resetCond, _) =>
        assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero,
          "Synchronous reset should have already been made explicit!")
        val ref = WRef(reg)
        // A register with no entry in the netlist is connected to itself
        val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref)))
        regUpdates += update
        reg
      // Remove connections to Registers so we preserve LowFirrtl single-connection semantics
      case Connect(_, lhs, _) if kind(lhs) == RegKind => EmptyStmt
      case other => other
    }

    val bodyx = onStmt(mod.body)
    // Snapshot the buffer into an immutable List so the IR node does not hold a mutable
    // collection (and so this also type-checks where scala.Seq is immutable, i.e. 2.13+).
    mod.copy(body = Block(bodyx +: regUpdates.toList))
  }
}
/** Flatten register updates
  *
  * This transform flattens the update logic of every register into a single expression on the
  * right-hand side of one connection to that register. Each [[firrtl.ir.Module Module]] is handled
  * by `FlattenRegUpdate.flattenReg`; [[firrtl.ir.ExtModule ExtModule]]s pass through untouched.
  * Both the input and output form are MidForm.
  */
// TODO Preserve source locators
class FlattenRegUpdate extends Transform {
  def inputForm = MidForm
  def outputForm = MidForm

  def execute(state: CircuitState): CircuitState = {
    val circuit = state.circuit
    // External modules have no body to rewrite; only real modules are flattened
    val flattenedModules = circuit.modules.map {
      case ext: ExtModule => ext
      case mod: Module => FlattenRegUpdate.flattenReg(mod)
    }
    state.copy(circuit = circuit.copy(modules = flattenedModules))
  }
}
|