aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala
blob: b34b1f6a4603cc0fce2f8ca280aab6bd1065aba3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// SPDX-License-Identifier: Apache-2.0

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.PrimOps._
import firrtl.WrappedExpression._
import firrtl.options.Dependency

import scala.collection.mutable

object ReplaceTruncatingArithmetic {

  /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */
  type Netlist = mutable.HashMap[WrappedExpression, Expression]

  /** Constant list of a `tail` that drops exactly one bit (used as a stable-identifier pattern) */
  private val DropOneBit = Seq(BigInt(1))

  /** Replaces truncating arithmetic in an Expression
    *
    * Sub-expressions are rewritten first. The rewritten expression is then checked for two shapes.
    * In both, an operand may be a reference whose driving expression is recorded in `netlist`:
    *  - `tail(add(a, b), 1)` / `tail(sub(a, b), 1)` where the add/sub has
    *    [[firrtl.ir.UIntType UIntType]]: becomes `addw(a, b)` / `subw(a, b)` with the `tail`'s type
    *  - `asSInt(tail(add(a, b), 1))` / `asSInt(tail(sub(a, b), 1))` where the add/sub has
    *    [[firrtl.ir.SIntType SIntType]]: becomes `addw(a, b)` / `subw(a, b)` with the `asSInt`'s type
    *
    * Any other expression is returned with only its sub-expressions rewritten.
    *
    * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
    * [[firrtl.ir.Expression Expression]]s. It is '''not''' mutated in this function
    * @param expr the Expression being transformed
    * @return expr with truncating arithmetic replaced
    */
  def onExpr(netlist: Netlist)(expr: Expression): Expression = {
    val mapped = expr.map(onExpr(netlist))
    val replacement: Option[Expression] = mapped match {
      // Unsigned wrapping add/sub: the extra msb is dropped by tail(_, 1)
      case DoPrim(Tail, Seq(tailed), DropOneBit, tailTpe) =>
        toWrapping(driverOf(netlist, tailed), tailTpe, signed = false)
      // Signed wrapping add/sub: tail yields a UInt, so the result is cast back with asSInt
      case DoPrim(AsSInt, Seq(casted), _, castTpe) =>
        driverOf(netlist, casted) match {
          case DoPrim(Tail, Seq(tailed), DropOneBit, _) =>
            toWrapping(driverOf(netlist, tailed), castTpe, signed = true)
          case _ => None
        }
      case _ => None
    }
    replacement.getOrElse(mapped)
  }

  /** Returns the expression recorded in `netlist` for `e`, or `e` itself when none is recorded */
  private def driverOf(netlist: Netlist, e: Expression): Expression =
    netlist.getOrElse(we(e), e)

  /** Converts an `add`/`sub` whose type has the requested signedness into `addw`/`subw` typed `tpe`
    *
    * @param driver the candidate add/sub expression
    * @param tpe the type to give the non-expanding replacement
    * @param signed `true` to accept only SInt-typed add/sub, `false` to accept only UInt-typed ones
    * @return the replacement, or `None` if `driver` is not a matching add/sub
    */
  private def toWrapping(driver: Expression, tpe: Type, signed: Boolean): Option[Expression] = {
    def signMatches(opTpe: Type): Boolean = opTpe match {
      case _: SIntType => signed
      case _: UIntType => !signed
      case _           => false
    }
    driver match {
      case DoPrim(Add, args, consts, opTpe) if signMatches(opTpe) => Some(DoPrim(Addw, args, consts, tpe))
      case DoPrim(Sub, args, consts, opTpe) if signMatches(opTpe) => Some(DoPrim(Subw, args, consts, tpe))
      case _                                                      => None
    }
  }

  /** Replaces truncating arithmetic in a Statement
    *
    * Nested statements are processed first, then this statement's expressions. When the result is a
    * [[firrtl.ir.DefNode DefNode]], its (already rewritten) value is recorded in `netlist` under a
    * reference to the node's name. Later lookups through `onExpr` can then see through the node.
    *
    * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
    * [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it if stmt contains a [[firrtl.ir.DefNode
    * DefNode]]
    * @param stmt the Statement being searched for nodes and transformed
    * @return stmt with truncating arithmetic replaced
    */
  def onStmt(netlist: Netlist)(stmt: Statement): Statement = {
    val transformed = stmt.map(onStmt(netlist)).map(onExpr(netlist))
    transformed match {
      case DefNode(_, name, value) => netlist.update(we(WRef(name)), value)
      case _                       => ()
    }
    transformed
  }

  /** Replaces truncating arithmetic in a Module, using one fresh [[Netlist]] for the whole module */
  def onMod(mod: DefModule): DefModule = {
    val netlist = new Netlist
    mod.map(onStmt(netlist))
  }
}

/** Replaces expanding arithmetic whose expanded msb is immediately discarded
  *
  * When the result of `add` or `sub` immediately throws away its expanded msb, this transform
  * replaces the operation with the non-expanding operator `addw` or `subw` respectively. For UInt
  * operands the truncation is `tail(_, 1)`; for SInt operands it is `asSInt(tail(_, 1))`. Operands
  * reached through node references are looked through.
  *
  * @note This replaces some FIRRTL primops with ops that are not actually legal FIRRTL. They are
  * useful for emission to languages that support non-expanding arithmetic (like Verilog)
  */
class ReplaceTruncatingArithmetic extends Transform with DependencyAPIMigration {

  override def prerequisites =
    firrtl.stage.Forms.LowFormMinimumOptimized ++ Seq(
      Dependency[BlackBoxSourceHelper],
      Dependency[FixAddingNegativeLiterals]
    )

  override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized

  override def optionalPrerequisiteOf = Seq.empty

  // Only primops inside expressions are swapped; no other transform needs to be rerun.
  override def invalidates(a: Transform): Boolean = false

  /** Rewrites every module of the circuit independently via [[ReplaceTruncatingArithmetic.onMod]] */
  def execute(state: CircuitState): CircuitState = {
    val circuit = state.circuit
    val rewritten = circuit.modules.map(ReplaceTruncatingArithmetic.onMod)
    state.copy(circuit = circuit.copy(modules = rewritten))
  }
}