aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/InlineBitExtractions.scala
blob: 4b2773c506a73693ec952d46393993cda2e1cda9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// SPDX-License-Identifier: Apache-2.0

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.options.Dependency
import firrtl.PrimOps.{Bits, Head, Shr, Tail}
import firrtl.Utils.{isBitExtract, isTemp}
import firrtl.WrappedExpression._

import scala.collection.mutable

object InlineBitExtractionsTransform {

  // Checks if an Expression is made up of only bit extractions (as classified by Utils.isBitExtract)
  // terminated by a Literal, a WRef or a WSubField.
  // private because it's not clear if this definition of "Simple Expression" would be useful elsewhere.
  // Note that this can have false negatives but MUST NOT have false positives.
  private def isSimpleExpr(expr: Expression): Boolean = expr match {
    case _: WRef | _: Literal | _: WSubField => true
    case DoPrim(op, args, _, _) if isBitExtract(op) => args.forall(isSimpleExpr)
    case _                                          => false
  }

  /** Replace Head/Tail/Shr with an equivalent Bits for easier back-to-back bit extractions.
    *
    * A Bits is only produced when its range is well formed on the operand, i.e.
    * `0 <= lsb <= msb < width(operand)`. Otherwise (e.g. a shr by at least the operand width,
    * which has no in-range Bits equivalent) the expression is returned unchanged, so callers
    * treat it as "not a candidate".
    *
    * @param expr the expression to lower
    * @return an equivalent Bits DoPrim, or `expr` itself if no safe lowering exists
    */
  private def lowerToDoPrimOpBits(expr: Expression): Expression = {
    // Build Bits(msb, lsb) over rhs only if the selected range lies within the operand.
    def bitsOrOriginal(rhs: Seq[Expression], msb: BigInt, lsb: BigInt, tpe: Type): Expression =
      if (lsb >= 0 && lsb <= msb && msb < bitWidth(rhs.head.tpe)) DoPrim(Bits, rhs, Seq(msb, lsb), tpe)
      else expr

    expr match {
      case DoPrim(Head, rhs, c, tpe) if isSimpleExpr(expr) =>
        val width = bitWidth(rhs.head.tpe)
        bitsOrOriginal(rhs, width - 1, width - c.head, tpe)
      case DoPrim(Tail, rhs, c, tpe) if isSimpleExpr(expr) =>
        bitsOrOriginal(rhs, bitWidth(rhs.head.tpe) - c.head - 1, 0, tpe)
      case DoPrim(Shr, rhs, c, tpe) if isSimpleExpr(expr) =>
        bitsOrOriginal(rhs, bitWidth(rhs.head.tpe) - 1, c.head, tpe)
      case _ => expr // Not a candidate
    }
  }

  /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */
  type Netlist = mutable.HashMap[WrappedExpression, Expression]

  /** Recursively inline bit extractions into an [[firrtl.ir.Expression Expression]].
    *
    * Two rewrites are applied bottom-up:
    *  - a [[WRef]] whose netlist entry is a bit extraction is replaced by that entry;
    *  - a simple bit extraction applied to another simple bit extraction is folded into a single
    *    extraction on the innermost operand (head/head, tail/tail and shr/shr are merged directly;
    *    other combinations are merged via an equivalent Bits when both sides can be lowered).
    *
    * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
    * [[firrtl.ir.Expression Expression]]s. It is '''not''' mutated in this function
    * @param expr the Expression being transformed
    * @return Returns expr with bit extractions inlined and folded
    */
  def onExpr(netlist: Netlist)(expr: Expression): Expression = {
    expr.map(onExpr(netlist)) match {
      case ref: WRef =>
        netlist
          .get(we(ref))
          .filter(isBitExtract)
          .getOrElse(ref)
      // replace back-to-back Bits Extractions
      case lhs @ DoPrim(lop, ival, lc, ltpe) if isSimpleExpr(lhs) =>
        ival.head match {
          case of @ DoPrim(rop, rhs, rc, _) if isSimpleExpr(of) =>
            (lop, rop) match {
              case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head.min(rc.head)), ltpe)
              case (Tail, Tail) => DoPrim(Tail, rhs, Seq(lc.head + rc.head), ltpe)
              case (Shr, Shr)   => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe)
              case (_, _) =>
                (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match {
                  // Both ranges are known to be in bounds, so offsetting the outer range by the
                  // inner lsb selects the same bits directly from rhs.
                  case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(_, rlsb), _)) =>
                    DoPrim(Bits, rhs, Seq(lmsb + rlsb, llsb + rlsb), ltpe)
                  case (_, _) => lhs // Not a candidate
                }
            }
          case _ => lhs // Not a candidate
        }
      case other => other // Not a candidate
    }
  }

  /** Inline bits in a Statement
    *
    * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
    * [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it: every
    * [[firrtl.ir.DefNode DefNode]] with a temporary name has its (already transformed) value recorded.
    * Only values that are bit extractions are later inlined by [[onExpr]].
    * @param stmt the Statement being searched for nodes and transformed
    * @return Returns stmt with Bits inlined
    */
  def onStmt(netlist: Netlist)(stmt: Statement): Statement =
    stmt.map(onStmt(netlist)).map(onExpr(netlist)) match {
      case node @ DefNode(_, name, value) if isTemp(name) =>
        netlist(we(WRef(name))) = value
        node
      case other => other
    }

  /** Replaces bits in a Module, using a fresh netlist per module */
  def onMod(mod: DefModule): DefModule = mod.map(onStmt(new Netlist))
}

/** Transform that inlines temporary nodes whose value is a bit extraction into their use
  * sites, folding chained extractions along the way.
  *
  * The per-module work is delegated to `InlineBitExtractionsTransform.onMod`; this class only
  * declares scheduling dependencies and applies that rewrite to every module of the circuit.
  */
class InlineBitExtractionsTransform extends Transform with DependencyAPIMigration {

  override def prerequisites =
    firrtl.stage.Forms.LowFormMinimumOptimized ++ Seq(
      Dependency[BlackBoxSourceHelper],
      Dependency[FixAddingNegativeLiterals],
      Dependency[ReplaceTruncatingArithmetic]
    )

  override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized

  override def optionalPrerequisiteOf = Seq.empty

  // Running this transform does not invalidate any other transform.
  override def invalidates(a: Transform) = false

  def execute(state: CircuitState): CircuitState = {
    val circuit = state.circuit
    val inlinedModules = circuit.modules.map(InlineBitExtractionsTransform.onMod)
    val updatedCircuit = circuit.copy(modules = inlinedModules)
    state.copy(circuit = updatedCircuit)
  }
}