1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
// SPDX-License-Identifier: Apache-2.0
package firrtl
package transforms
import firrtl.ir._
import firrtl.Mappers._
import firrtl.options.Dependency
import firrtl.PrimOps.{Bits, Head, Shr, Tail}
import firrtl.Utils.{isBitExtract, isTemp}
import firrtl.WrappedExpression._
import scala.collection.mutable
object InlineBitExtractionsTransform {

  // Checks if an Expression is made up of only bit extractions (any op accepted by
  // Utils.isBitExtract, e.g. Bits/Head/Tail/Shr) terminated by a WRef, Literal, or WSubField.
  // private because it's not clear if this definition of "Simple Expression" would be useful elsewhere.
  // Note that this can have false negatives but MUST NOT have false positives.
  private def isSimpleExpr(expr: Expression): Boolean = expr match {
    case _: WRef | _: Literal | _: WSubField          => true
    case DoPrim(op, args, _, _) if isBitExtract(op) => args.forall(isSimpleExpr)
    case _                                          => false
  }

  // Replace Head/Tail/Shr with an equivalent Bits for easier back-to-back bit extractions.
  // The rewrite is only produced when the resulting [msb, lsb] window is well-formed
  // (0 <= lsb <= msb < width). Otherwise (e.g. shr by >= the operand width, which FIRRTL
  // defines as a 1-bit result rather than a slice, or a zero-width tail) expr is returned
  // unchanged, and callers treat a non-Bits result as "not a candidate".
  private def lowerToDoPrimOpBits(expr: Expression): Expression = expr match {
    case DoPrim(Head, rhs, c, tpe) if isSimpleExpr(expr) =>
      val width = bitWidth(rhs.head.tpe)
      val n = c.head
      // head(e, n) == bits(e, w-1, w-n)
      if (n > 0 && n <= width) DoPrim(Bits, rhs, Seq(width - 1, width - n), tpe) else expr
    case DoPrim(Tail, rhs, c, tpe) if isSimpleExpr(expr) =>
      val width = bitWidth(rhs.head.tpe)
      val n = c.head
      // tail(e, n) == bits(e, w-n-1, 0)
      if (n >= 0 && n < width) DoPrim(Bits, rhs, Seq(width - n - 1, 0), tpe) else expr
    case DoPrim(Shr, rhs, c, tpe) if isSimpleExpr(expr) =>
      val width = bitWidth(rhs.head.tpe)
      val n = c.head
      // shr(e, n) == bits(e, w-1, n), but only while n < w
      if (n >= 0 && n < width) DoPrim(Bits, rhs, Seq(width - 1, n), tpe) else expr
    case _ => expr // Not a candidate
  }

  /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */
  type Netlist = mutable.HashMap[WrappedExpression, Expression]

  /** Recursively replace [[WRef]]s with new [[firrtl.ir.Expression Expression]]s
    *
    * Children are rewritten first, so any nested bit extractions are already merged by the time
    * the enclosing expression is examined. A [[WRef]] whose netlist entry is a bit extraction is
    * replaced by that entry, and two directly-nested simple bit extractions are merged into one.
    *
    * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
    *                [[firrtl.ir.Expression Expression]]s. It is '''not''' mutated in this function
    * @param expr the Expression being transformed
    * @return Returns expr with Bits inlined
    */
  def onExpr(netlist: Netlist)(expr: Expression): Expression =
    expr.map(onExpr(netlist)) match {
      // Only netlist values that are themselves bit extractions are inlined
      case e @ WRef(_, _, _, _) =>
        netlist
          .get(we(e))
          .filter(isBitExtract)
          .getOrElse(e)
      // replace back-to-back Bits Extractions
      case lhs @ DoPrim(lop, ival, lc, ltpe) if isSimpleExpr(lhs) =>
        ival.head match {
          case of @ DoPrim(rop, rhs, rc, _) if isSimpleExpr(of) =>
            (lop, rop) match {
              // head(head(x, a), b) == head(x, min(a, b))
              case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head.min(rc.head)), ltpe)
              // tail(tail(x, a), b) == tail(x, a + b)
              case (Tail, Tail) => DoPrim(Tail, rhs, Seq(lc.head + rc.head), ltpe)
              // shr(shr(x, a), b) == shr(x, a + b)
              case (Shr, Shr) => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe)
              // Any other pairing: normalize both to Bits and compose the windows. The outer
              // window is relative to the inner one, so it is offset by the inner lsb.
              case (_, _) =>
                (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match {
                  case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(_, rlsb), _)) =>
                    DoPrim(Bits, rhs, Seq(lmsb + rlsb, llsb + rlsb), ltpe)
                  case (_, _) => lhs // Not a candidate (e.g. an extraction that cannot be lowered to Bits)
                }
            }
          case _ => lhs // Not a candidate
        }
      case other => other // Not a candidate
    }

  /** Inline bits in a Statement
    *
    * Sub-statements are processed first, then this statement's expressions are rewritten with
    * [[onExpr]]. After rewriting, every [[firrtl.ir.DefNode DefNode]] with a temporary name is
    * recorded in the netlist regardless of its value; only bit-extraction values are actually
    * inlined, because [[onExpr]] filters on lookup.
    *
    * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
    *                [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it if stmt
    *                (or any sub-statement) is a [[firrtl.ir.DefNode DefNode]] with a temporary name
    * @param stmt the Statement being searched for nodes and transformed
    * @return Returns stmt with Bits inlined
    */
  def onStmt(netlist: Netlist)(stmt: Statement): Statement =
    stmt.map(onStmt(netlist)).map(onExpr(netlist)) match {
      case node @ DefNode(_, name, value) if isTemp(name) =>
        netlist(we(WRef(name))) = value
        node
      case other => other
    }

  /** Inline bit extractions in a Module, using a fresh netlist per module */
  def onMod(mod: DefModule): DefModule = mod.map(onStmt(new Netlist))
}
/** Inlines temporary nodes whose values are simple bit extractions.
  *
  * Requires the minimum-optimized low form plus [[BlackBoxSourceHelper]], [[FixAddingNegativeLiterals]]
  * and [[ReplaceTruncatingArithmetic]]; lists the optimized low form as an optional prerequisite,
  * and reports that it invalidates no other transform.
  */
class InlineBitExtractionsTransform extends Transform with DependencyAPIMigration {

  override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
    Seq(
      Dependency[BlackBoxSourceHelper],
      Dependency[FixAddingNegativeLiterals],
      Dependency[ReplaceTruncatingArithmetic]
    )

  override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized

  override def optionalPrerequisiteOf = Seq.empty

  override def invalidates(a: Transform) = false

  /** Runs the per-module inlining over every module and rebuilds the circuit from the results. */
  def execute(state: CircuitState): CircuitState = {
    val circuit = state.circuit
    val inlinedModules = circuit.modules.map(m => InlineBitExtractionsTransform.onMod(m))
    state.copy(circuit = circuit.copy(modules = inlinedModules))
  }
}
|