aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
blob: ea694719f277e30fea4c8e43a7b876f1dc88826e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// See LICENSE for license details.

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.options.Dependency

import scala.collection.mutable

object FlattenRegUpdate {

  /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */
  type Netlist = mutable.HashMap[WrappedExpression, Expression]

  /** Build a [[Netlist]] from a Module's connections and Nodes
    *
    * This assumes [[firrtl.LowForm LowForm]], where each sink has a single connection. Each
    * `Connect` records its right-hand side under its left-hand side, and each `DefNode` records its
    * value under a reference to the node name. If a sink were connected more than once, the last
    * connection visited would win.
    *
    * @param mod [[firrtl.ir.Module Module]] from which to build a [[Netlist]]
    * @return [[Netlist]] of the module's connections and nodes
    * @note an `IsInvalid` statement is rejected via `throwInternalError`, since it is expected to
    *       have been removed before this point
    */
  def buildNetlist(mod: Module): Netlist = {
    val netlist = new Netlist()
    // Children are visited first (via `map`). The rebuilt statement is discarded and the original
    // returned, because this walk only records drivers and never rewrites the IR.
    def onStmt(stmt: Statement): Statement = {
      stmt.map(onStmt) match {
        case Connect(_, lhs, rhs) =>
          netlist(lhs) = rhs
        case DefNode(_, nname, rhs) =>
          netlist(WRef(nname)) = rhs
        case _: IsInvalid => throwInternalError("Unexpected IsInvalid, should have been removed by now")
        case _ => // Do nothing
      }
      stmt
    }
    mod.map(onStmt)
    netlist
  }

  /** Flatten Register Updates
    *
    * Constructs nested mux trees (up to a certain arbitrary threshold) for register updates. This
    * can result in dead code that this function does NOT remove.
    *
    * Every existing connection to a register is replaced by `EmptyStmt`. A single new `Connect`
    * per register, carrying the flattened update expression, is appended after the original body.
    * A register with no connection in the module is connected to itself.
    *
    * Synchronous resets must already have been made explicit. Only async-reset registers or
    * registers whose reset is the literal zero are accepted (checked with `assert`).
    *
    * @param mod [[firrtl.ir.Module Module]] to transform
    * @return [[firrtl.ir.Module Module]] with register updates flattened
    */
  def flattenReg(mod: Module): Module = {
    // We want to flatten Mux trees for reg updates into if-trees for
    // improved QoR for conditional updates.  However, unbounded recursion
    // would take exponential time, so don't redundantly flatten the same
    // Mux more than a bounded number of times, preserving linear runtime.
    // The threshold is empirical but ample.
    val flattenThreshold = 4
    val numTimesFlattened = mutable.HashMap[Mux, Int]()
    // Counts this visit of `m`, then reports whether it had been visited fewer than
    // `flattenThreshold` times before. The count is bumped even when the answer is false.
    def canFlatten(m: Mux): Boolean = {
      val n = numTimesFlattened.getOrElse(m, 0)
      numTimesFlattened(m) = n + 1
      n < flattenThreshold
    }

    val regUpdates = mutable.ArrayBuffer.empty[Connect]
    val netlist = buildNetlist(mod)

    // Look through node/wire references to their drivers. Recursively expand both arms of any Mux
    // still under the threshold. Anything else is returned as the ORIGINAL expression `e`, so a
    // node/wire whose driver is not flattened stays a reference rather than being inlined.
    def constructRegUpdate(e: Expression): Expression = {
      // Only walk netlist for nodes and wires, NOT registers or other state
      val expr = kind(e) match {
        case NodeKind | WireKind => netlist.getOrElse(e, e)
        case _ => e
      }
      expr match {
        case mux: Mux if canFlatten(mux) =>
          val tvalx = constructRegUpdate(mux.tval)
          val fvalx = constructRegUpdate(mux.fval)
          mux.copy(tval = tvalx, fval = fvalx)
        // Return the original expression to end flattening
        case _ => e
      }
    }

    def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {
      case reg @ DefRegister(_, _, _, _, resetCond, _) =>
        assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero,
          "Synchronous reset should have already been made explicit!")
        val ref = WRef(reg)
        // An unconnected register falls back to driving itself
        val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref)))
        regUpdates += update
        reg
      // Remove connections to Registers so we preserve LowFirrtl single-connection semantics
      case Connect(_, lhs, _) if kind(lhs) == RegKind => EmptyStmt
      case other => other
    }

    val bodyx = onStmt(mod.body)
    // Snapshot the collected updates into an immutable List so the returned IR does not embed a
    // mutable buffer (and type-checks where `Seq` means `immutable.Seq`).
    mod.copy(body = Block(bodyx +: regUpdates.toList))
  }

}

/** Flatten register updates
  *
  * Applies [[FlattenRegUpdate.flattenReg]] to every [[firrtl.ir.Module Module]] of the circuit, so
  * that each register ends up driven by one connection whose right-hand side is its flattened
  * update expression. [[firrtl.ir.ExtModule ExtModule]]s pass through untouched.
  */
// TODO Preserve source locators
class FlattenRegUpdate extends Transform with DependencyAPIMigration {

  override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ Seq(
    Dependency[BlackBoxSourceHelper],
    Dependency[FixAddingNegativeLiterals],
    Dependency[ReplaceTruncatingArithmetic],
    Dependency[InlineBitExtractionsTransform],
    Dependency[InlineCastsTransform],
    Dependency[LegalizeClocksTransform]
  )

  override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized

  override def optionalPrerequisiteOf = Seq.empty

  /** Flattening may leave dead code behind, so only dead-code elimination must be rerun. */
  override def invalidates(a: Transform): Boolean =
    PartialFunction.cond(a) { case _: DeadCodeElimination => true }

  def execute(state: CircuitState): CircuitState = {
    // Rewrites internal modules; external modules have no body to flatten
    def flattenModule(dm: DefModule): DefModule = dm match {
      case ext: ExtModule => ext
      case mod: Module    => FlattenRegUpdate.flattenReg(mod)
    }
    val circuit = state.circuit
    state.copy(circuit = circuit.copy(modules = circuit.modules.map(flattenModule)))
  }
}