aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
blob: f2dffc4c5b4092e2e8fe6bad9fae24e37449852a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
// SPDX-License-Identifier: Apache-2.0

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.options.Dependency
import firrtl.InfoExpr.{orElse, unwrap}

import scala.collection.mutable

object FlattenRegUpdate {

  /** Merges the source locator of a mux with the locators of its two (already processed) branches.
    *
    * The mux's own info is split with `MultiInfo.demux` into three components. The first is kept
    * as is; the second and third are passed to `orElse` as the alternatives for `tinfo` and
    * `finfo` respectively.
    *
    * @param muxInfo info attached to the mux itself
    * @param tinfo info obtained for the true branch
    * @param finfo info obtained for the false branch
    * @return a [[firrtl.ir.MultiInfo MultiInfo]] holding exactly three entries
    */
  private def combineInfos(muxInfo: Info, tinfo: Info, finfo: Info): Info = {
    val (condInfo, trueFallback, falseFallback) = MultiInfo.demux(muxInfo)
    val trueInfo = orElse(tinfo, trueFallback)
    val falseInfo = orElse(finfo, falseFallback)
    // Call the constructor directly so that NoInfo entries are preserved
    new MultiInfo(List(condInfo, trueInfo, falseInfo))
  }

  /** Map from a reference to the [[firrtl.ir.Expression Expression]] that drives it */
  type Netlist = mutable.HashMap[WrappedExpression, Expression]

  /** Collects the drivers of every connection and node in a Module into a [[Netlist]]
    *
    * The input is expected to be in [[firrtl.LowForm LowForm]]. Whenever the originating statement
    * carries an info other than `NoInfo`, the stored driver is wrapped in an `InfoExpr` so that
    * the locator travels with the expression.
    *
    * @param mod [[firrtl.ir.Module Module]] whose statements are scanned
    * @return [[Netlist]] keyed by connection left-hand sides and node references
    * @note raises an internal error if an `IsInvalid` statement is encountered
    */
  def buildNetlist(mod: Module): Netlist = {
    val drivers = new Netlist()

    // Attach the statement's locator to its right-hand side, unless there is none
    def tagged(info: Info, value: Expression): Expression =
      if (info == NoInfo) value else InfoExpr(info, value)

    def visit(stmt: Statement): Statement = {
      stmt.map(visit) match {
        case DefNode(info, name, value) =>
          drivers(WRef(name)) = tagged(info, value)
        case Connect(info, loc, value) =>
          drivers(loc) = tagged(info, value)
        case _: IsInvalid =>
          throwInternalError("Unexpected IsInvalid, should have been removed by now")
        case _ => // Nothing to record
      }
      // The traversal is only used for its side effect; statements are returned untouched
      stmt
    }

    mod.map(visit)
    drivers
  }

  /** Flatten Register Updates
    *
    * Builds, for every register, a single connection whose right-hand side is a nested mux tree
    * obtained by inlining the nodes and wires that feed it (bounded by an arbitrary depth
    * threshold). The original connections to registers are removed and the new ones are appended
    * to the end of the module body. Inlining can leave dead code behind, which this function does
    * NOT remove.
    *
    * @param mod [[firrtl.ir.Module Module]] to transform
    * @return [[firrtl.ir.Module Module]] with register updates flattened
    */
  def flattenReg(mod: Module): Module = {
    // The goal is to turn the Mux trees feeding registers into if-trees, which gives better QoR
    // for conditional updates. The fan-in of a register sometimes contains repeated
    // sub-expressions that are themselves complex mux structures; inlining every occurrence can
    // make the emitted Verilog explode in size and complexity. Moreover, user code shaped like
    // this tends to have conditions in the sub-trees that are mutually exclusive with the
    // conditions of the muxes nearer to the register input. For example:
    //
    // when a :      ; when 1
    //   r <= foo
    // when b :      ; when 2
    //   when a :
    //     r <= bar  ; when 3
    //
    // Once whens are expanded, when 1 becomes a common sub-expression appearing twice in the mux
    // structure produced by when 2:
    //
    // _GEN_0 = mux(a, foo, r)
    // _GEN_1 = mux(a, bar, _GEN_0)
    // r <= mux(b, _GEN_1, _GEN_0)
    //
    // Inlining _GEN_0 into _GEN_1 would produce unreachable lines in the Verilog. This is *not* a
    // real problem (and could be optimized here), but Verilog metrics assume human-written code
    // and therefore report such lines as unreachable. Not inlining sidesteps the issue and leaves
    // the optimization to synthesis tools, which handle it very well.
    val maxDepth = 4

    val netlist = buildNetlist(mod)
    val regUpdates = mutable.ArrayBuffer.empty[Connect]

    // First pass: find the expressions at which inlining must stop, in particular references that
    // would otherwise be inlined more than once.
    // The walk may go deeper than maxDepth levels when the expression is already a very deeply
    // nested mux, because only the expansion of references is depth-limited.
    def determineEndpoints(expr: Expression): collection.Set[WrappedExpression] = {
      val visited = mutable.HashSet.empty[WrappedExpression]
      val stops = mutable.HashSet.empty[WrappedExpression]

      def walk(depth: Int)(e: Expression): Unit = {
        val isNet = kind(e) match {
          case NodeKind | WireKind => true
          case _                   => false
        }
        // A node or wire is looked up in the netlist only the first time it is met, and only
        // while the depth budget lasts; otherwise the expression is taken as it is.
        val resolved: Expression =
          if (isNet && depth < maxDepth && !visited(e)) {
            visited += e
            netlist.getOrElse(e, e)
          } else {
            e
          }
        val (_, ex) = unwrap(resolved)
        ex match {
          case Mux(_, tval, fval, _) =>
            walk(depth + 1)(tval)
            walk(depth + 1)(fval)
          case _ =>
            // NOTE(review): this records the resolved expression `ex`, not the incoming
            // reference `e`; the two differ when `e` was expanded through the netlist into a
            // non-mux expression. Confirm this is the intended expression to mark.
            stops += ex
        }
      }

      walk(0)(expr)
      stops
    }

    // Second pass: rebuild the mux tree, inlining every node or wire that is not an endpoint
    def constructRegUpdate(start: Expression): (Info, Expression) = {
      val stops = determineEndpoints(start)

      def inline(e: Expression): (Info, Expression) = {
        val source: Expression = kind(e) match {
          case NodeKind | WireKind if !stops(e) => netlist.getOrElse(e, e)
          case _                                => e
        }
        unwrap(source) match {
          case (muxInfo, Mux(cond, tval, fval, tpe)) =>
            val (tinfo, tflat) = inline(tval)
            val (finfo, fflat) = inline(fval)
            (combineInfos(muxInfo, tinfo, finfo), Mux(cond, tflat, fflat, tpe))
          // Anything else stops the flattening; hand back the expression we were given
          case _ => unwrap(e)
        }
      }

      inline(start)
    }

    def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {
      case reg @ DefRegister(_, _, _, _, resetCond, _) =>
        assert(
          resetCond.tpe == AsyncResetType || resetCond == Utils.zero,
          "Synchronous reset should have already been made explicit!"
        )
        val ref = WRef(reg)
        // A register with no recorded driver is treated as being driven by itself
        val (info, rhs) = constructRegUpdate(netlist.getOrElse(ref, ref))
        regUpdates += Connect(info, ref, rhs)
        reg
      // Drop the existing connections to registers so that LowFirrtl single-connection semantics
      // still hold once the flattened connections are appended
      case Connect(_, lhs, _) if kind(lhs) == RegKind => EmptyStmt
      case other                                      => other
    }

    // The traversal must run before regUpdates is read, since it is what fills the buffer
    val strippedBody = onStmt(mod.body)
    val flattenedBody = Block(strippedBody +: regUpdates.toSeq)
    mod.copy(body = flattenedBody)
  }

}

/** Flatten register update
  *
  * Rewrites every [[firrtl.ir.Module Module]] of the circuit so that each register is driven by a
  * single connection whose right-hand side is the flattened update expression. External modules
  * are passed through unchanged. The actual rewriting is done by `FlattenRegUpdate.flattenReg`.
  */
// TODO Preserve source locators
// NOTE(review): flattenReg already threads Info values through the muxes it builds; confirm
// whether the TODO above is still open.
class FlattenRegUpdate extends Transform with DependencyAPIMigration {

  /** Requires minimally optimized low form plus the transforms listed explicitly below */
  override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
    Seq(
      Dependency[BlackBoxSourceHelper],
      Dependency[FixAddingNegativeLiterals],
      Dependency[ReplaceTruncatingArithmetic],
      Dependency[InlineBitExtractionsTransform],
      Dependency[InlineAcrossCastsTransform],
      Dependency[LegalizeClocksAndAsyncResetsTransform]
    )

  /** The optimized low form transforms are declared as optional prerequisites */
  override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized

  /** This transform is not declared as an optional prerequisite of anything */
  override def optionalPrerequisiteOf = Seq.empty

  /** Only [[DeadCodeElimination]] is reported as invalidated by this transform */
  override def invalidates(a: Transform): Boolean = a match {
    case _: DeadCodeElimination => true
    case _ => false
  }

  /** Flattens the register updates of every non-external module in the circuit
    *
    * @param state circuit state to transform
    * @return a copy of `state` whose circuit contains the rewritten modules
    */
  def execute(state: CircuitState): CircuitState = {
    val circuit = state.circuit
    val flattened = circuit.modules.map {
      case ext: ExtModule => ext
      case mod: Module    => FlattenRegUpdate.flattenReg(mod)
    }
    state.copy(circuit = circuit.copy(modules = flattened))
  }
}