1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
|
// SPDX-License-Identifier: Apache-2.0
package firrtl
package transforms
import firrtl.ir._
import firrtl.Mappers._
import firrtl.PrimOps.Pad
import firrtl.options.Dependency
import firrtl.Utils.{isBitExtract, isCast, NodeMap}
/** Expression-, statement- and module-level helpers that inline node values across cast operations */
object InlineAcrossCastsTransform {

  // Decides whether an Expression is a chain of one or more casts that bottoms out in a
  // reference, sub-field, or literal.
  // `castSeen` says whether a cast has already been met on the way down; it is the answer given
  // once a leaf is reached, which is what enforces "at least one cast".
  // False negatives are tolerated here, but false positives MUST NOT happen.
  private def isSimpleCast(castSeen: Boolean)(expr: Expression): Boolean = expr match {
    case prim: DoPrim if isCast(prim.op) =>
      prim.args.forall(arg => isSimpleCast(castSeen = true)(arg))
    case _: WRef | _: WSubField | _: Literal => castSeen
    case _ => false
  }

  /** Rewrite an [[firrtl.ir.Expression Expression]], substituting recorded values for [[WRef]]s
    *
    * Children are rewritten before their parent. After that, two shapes are candidates:
    *  - a bare [[WRef]] is replaced by its recorded value when that value consists only of casts
    *    (at least one) ending in a reference, sub-field, or literal
    *  - a cast whose single argument is a [[WRef]] gets the recorded value as its new argument when
    *    there is no non-cast expression above it, or when the value consists only of casts,
    *    references, sub-fields, and literals
    *
    * @param replace a '''mutable''' HashMap, looked up by the name of the [[WRef]], holding the
    *                value that may stand in for that reference. It is only read in this function,
    *                '''never''' mutated
    * @param expr the Expression being transformed
    * @return expr with the eligible [[WRef]]s replaced by the values found in replace
    */
  def onExpr(replace: NodeMap)(expr: Expression): Expression = {
    // underNonCast is true once some expression between the root and `e` is not a cast
    def rec(underNonCast: Boolean)(e: Expression): Expression = e match {
      // Bit extractions and pads are returned untouched; inlining literals into pads results in
      // invalid Verilog
      case DoPrim(op, _, _, _) if op == Pad || isBitExtract(op) => e
      case current =>
        val nonCastAbove = underNonCast || !isCast(current)
        current.map(rec(nonCastAbove)) match {
          case ref @ WRef(name, _, _, _) =>
            replace.get(name) match {
              case Some(value) if isSimpleCast(castSeen = false)(value) => value
              case _ => ref
            }
          case cast @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) =>
            replace.get(name) match {
              // Inline only when no non-cast parent exists in the expression tree OR when the
              // value being inlined is made up solely of casts and references
              case Some(value) if !underNonCast || isSimpleCast(castSeen = true)(value) =>
                cast.copy(args = Seq(value))
              case _ => cast
            }
          case untouched => untouched // Not a candidate
        }
    }
    rec(false)(expr)
  }

  /** Inline across casts in a statement
    *
    * Nested statements are handled first, then the expressions of the resulting statement are
    * rewritten with [[onExpr]].
    *
    * @param netlist a '''mutable''' HashMap from node name to the [[firrtl.ir.Expression Expression]]
    *                bound to that node. This function '''will''' mutate it: whenever the rewritten
    *                statement is a [[firrtl.ir.DefNode DefNode]], the node's (already rewritten)
    *                value is stored under the node's name
    * @param stmt the Statement being searched for nodes and transformed
    * @return stmt with casts inlined
    */
  def onStmt(netlist: NodeMap)(stmt: Statement): Statement = {
    val childrenDone = stmt.map(onStmt(netlist))
    val rewritten = childrenDone.map(onExpr(netlist))
    rewritten match {
      case DefNode(_, name, value) => netlist.update(name, value)
      case _ => ()
    }
    rewritten
  }

  /** Inline across casts in a module, using a fresh netlist that is shared by all of its statements */
  def onMod(mod: DefModule): DefModule = {
    val netlist = new NodeMap
    mod.map(onStmt(netlist))
  }
}
/** Inline expressions into casts and inline casts into other expressions
  *
  * Casts are no-ops in the emitted Verilog, so this transform eliminates statements that do
  * nothing but hold a cast. It works greedily, growing expression trees that contain at most one
  * expression that is neither a cast nor a reference-like node.
  */
class InlineAcrossCastsTransform extends Transform with DependencyAPIMigration {

  override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
    Seq(
      Dependency[BlackBoxSourceHelper],
      Dependency[FixAddingNegativeLiterals],
      Dependency[ReplaceTruncatingArithmetic],
      Dependency[InlineBitExtractionsTransform],
      Dependency[PropagatePresetAnnotations]
    )

  override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized

  override def optionalPrerequisiteOf = Seq.empty

  // LegalizeClocksAndAsyncResetsTransform is the only transform reported as invalidated
  override def invalidates(a: Transform): Boolean =
    PartialFunction.cond(a) { case _: LegalizeClocksAndAsyncResetsTransform => true }

  /** Run cast inlining over every module of the circuit
    *
    * @param state the circuit state to transform
    * @return a copy of state whose circuit holds the rewritten modules
    */
  def execute(state: CircuitState): CircuitState = {
    val circuit = state.circuit
    val inlinedModules = circuit.modules.map(mod => InlineAcrossCastsTransform.onMod(mod))
    state.copy(circuit = circuit.copy(modules = inlinedModules))
  }
}
|