aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/InlineCasts.scala
blob: 3dac938e6d9554f563e0607e376c9b85df343ca0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
// See LICENSE for license details.

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.PrimOps.Pad
import firrtl.options.Dependency

import firrtl.Utils.{isCast, isBitExtract, NodeMap}

object InlineCastsTransform {

  /** Checks whether an Expression consists solely of casts terminated by leaves.
    *
    * A leaf is a [[WRef]], a [[WSubField]] or a [[firrtl.ir.Literal Literal]]. At least one cast must be
    * present, so a bare leaf is rejected. This check may produce false negatives but MUST NOT produce false
    * positives, because a positive result allows the expression to be inlined by [[onExpr]].
    *
    * @param castSeen whether a cast has already been seen between the root expression and expr
    * @param expr the Expression to inspect
    * @return true iff expr is made up only of casts, with at least one cast, ending in leaves
    */
  private def isSimpleCast(castSeen: Boolean)(expr: Expression): Boolean = expr match {
    case _: WRef | _: Literal | _: WSubField => castSeen
    case DoPrim(op, args, _, _) if isCast(op) => args.forall(isSimpleCast(true))
    case _ => false
  }

  /** Recursively inline node values into an Expression.
    *
    * Sub-expressions are rewritten first, then two rewrites are attempted on the result:
    *  - a [[WRef]] to a node whose recorded value is a simple cast (see `isSimpleCast`) is replaced by that value;
    *  - a cast whose single argument is a [[WRef]] to a recorded node is rewritten to cast the node's value
    *    directly, whatever that value is.
    * Bit-extraction primops and `pad` are returned unchanged, without visiting their arguments, because they may
    * generate a part-select and must therefore keep operating on the original (un-inlined) operands.
    *
    * @param replace a '''mutable''' map from node names to the values recorded for them by [[onStmt]].
    * It is '''not''' mutated by this function
    * @param expr the Expression being transformed
    * @return expr with the rewrites above applied
    */
  def onExpr(replace: NodeMap)(expr: Expression): Expression = expr match {
    // Anything that may generate a part-select should not be inlined!
    case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr
    case e => e.map(onExpr(replace)) match {
      // Only inline a node reference when the node's value is nothing but casts over leaves
      case e @ WRef(name, _, _, _) =>
        replace.get(name)
          .filter(isSimpleCast(castSeen = false))
          .getOrElse(e)
      // A cast of a node: push the cast directly onto the node's value (any value, not just casts)
      case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) =>
        replace.get(name)
          .map(value => e.copy(args = Seq(value)))
          .getOrElse(e)
      case other => other // Not a candidate
    }
  }

  /** Inline casts in a Statement.
    *
    * Child statements are processed first, then this statement's expressions are rewritten with [[onExpr]].
    * If the result is a [[firrtl.ir.DefNode DefNode]], its rewritten value is recorded in netlist under the
    * node's name so that statements visited afterwards can inline it.
    *
    * @param netlist a '''mutable''' map from [[firrtl.ir.DefNode DefNode]] names to their values. This function
    * '''will''' mutate it by recording every DefNode it visits, whether or not the node's value is a cast
    * [[firrtl.ir.PrimOp PrimOp]]
    * @param stmt the Statement being searched for nodes and transformed
    * @return stmt with casts inlined
    */
  def onStmt(netlist: NodeMap)(stmt: Statement): Statement =
    stmt.map(onStmt(netlist)).map(onExpr(netlist)) match {
      case node @ DefNode(_, name, value) =>
        // Record every node, not only casts: onExpr may inline any node value into a cast of that node
        netlist(name) = value
        node
      case other => other
    }

  /** Inline casts in a Module, using a fresh node map for each module.
    *
    * @param mod the module to transform
    * @return mod with casts inlined
    */
  def onMod(mod: DefModule): DefModule = mod.map(onStmt(new NodeMap))
}

/** Inline nodes that are simple casts.
  *
  * The per-module work is delegated to [[InlineCastsTransform.onMod]]; only the circuit's modules are replaced.
  */
class InlineCastsTransform extends Transform with DependencyAPIMigration {

  override def prerequisites =
    firrtl.stage.Forms.LowFormMinimumOptimized ++ Seq(
      Dependency[BlackBoxSourceHelper],
      Dependency[FixAddingNegativeLiterals],
      Dependency[ReplaceTruncatingArithmetic],
      Dependency[InlineBitExtractionsTransform],
      Dependency[PropagatePresetAnnotations]
    )

  override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized

  override def optionalPrerequisiteOf = Seq.empty

  // LegalizeClocksTransform is the only transform reported as invalidated by this one
  override def invalidates(a: Transform): Boolean =
    PartialFunction.cond(a) { case _: LegalizeClocksTransform => true }

  def execute(state: CircuitState): CircuitState = {
    val oldCircuit = state.circuit
    val inlinedModules = oldCircuit.modules.map(m => InlineCastsTransform.onMod(m))
    state.copy(circuit = oldCircuit.copy(modules = inlinedModules))
  }
}