aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/CombineCats.scala
blob: 4f678826ce663f7c93fa3b3963269af342d6a6c7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.PrimOps._
import firrtl.WrappedExpression._
import firrtl.annotations.NoTargetAnnotation
import firrtl.options.PreservesAll
import firrtl.options.Dependency

import scala.collection.mutable

/** Limits how many leaf elements the CombineCats transform may merge into one combined Cat expression
  *
  * When this annotation is absent, CombineCats falls back to its default limit.
  *
  * @param maxCatLen the maximum number of leaf elements (non-Cat operands) in a combined Cat; if several
  *                  of these annotations are present, only the first one found in the annotation sequence is used
  */
case class MaxCatLenAnnotation(maxCatLen: Int) extends NoTargetAnnotation

object CombineCats {
  /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them paired with their Cat length
    *
    * The "Cat length" is the number of leaf operands (non-Cat expressions) in the, possibly nested, Cat tree
    * driving the reference. Nodes whose value is not a Cat are recorded with a length of 1.
    */
  type Netlist = mutable.HashMap[WrappedExpression, (Int, Expression)]

  /** Expands `expr` by inlining references to previously recorded Cat nodes, subject to a leaf budget
    *
    * A Cat is walked operand by operand. A reference is replaced by the recorded Cat it names only when that
    * Cat's recorded length fits within the budget available at that position; otherwise it counts as one leaf.
    *
    * @param maxCatLen leaf-operand budget available for `expr`
    * @param netlist   previously recorded nodes, keyed by reference (only read here, never written)
    * @param expr      the expression to expand
    * @return the number of leaf operands in the result, paired with the (possibly) expanded expression
    */
  def expandCatArgs(maxCatLen: Int, netlist: Netlist)(expr: Expression): (Int, Expression) = expr match {
    case cat @ DoPrim(Cat, args, _, _) =>
      // Reserve one slot for the second operand so the first cannot consume the whole budget
      val (a0Len, a0Expanded) = expandCatArgs(maxCatLen - 1, netlist)(args.head)
      // The second operand may use whatever budget the first one left over
      val (a1Len, a1Expanded) = expandCatArgs(maxCatLen - a0Len, netlist)(args(1))
      // Type ascription (not a cast): DoPrim is already an Expression
      val expanded: Expression = cat.copy(args = Seq(a0Expanded, a1Expanded))
      (a0Len + a1Len, expanded)
    case other =>
      // Inline only references to recorded Cats whose length fits; anything else stays a single leaf
      netlist.get(we(other)).collect {
        case (len, cat @ DoPrim(Cat, _, _, _)) if maxCatLen >= len => expandCatArgs(maxCatLen, netlist)(cat)
      }.getOrElse((1, other))
  }

  /** Combines Cats in `stmt` and all of its sub-statements
    *
    * Sub-statements are processed before `stmt` itself is examined. Every [[firrtl.ir.DefNode DefNode]] is
    * recorded in `netlist` (together with its Cat length) so that Cats processed afterwards can inline it. A node
    * whose value is a Cat has that value replaced by its expansion. Other statements are returned unchanged apart
    * from their rewritten sub-statements.
    *
    * NOTE(review): inlining relies on nodes being visited before their uses, which is assumed to hold for the
    * LowForm input this runs on.
    *
    * @param maxCatLen leaf-operand budget for each node's combined Cat
    * @param netlist   nodes recorded so far; mutated by this method
    * @param stmt      the statement to process
    * @return the rewritten statement
    */
  def onStmt(maxCatLen: Int, netlist: Netlist)(stmt: Statement): Statement = {
    stmt.map(onStmt(maxCatLen, netlist)) match {
      case node @ DefNode(_, name, value) =>
        val catLenAndVal = value match {
          case cat @ DoPrim(Cat, _, _, _) => expandCatArgs(maxCatLen, netlist)(cat)
          case other => (1, other)
        }
        netlist(we(WRef(name))) = catLenAndVal
        node.copy(value = catLenAndVal._2)
      case other => other
    }
  }

  /** Runs [[onStmt]] over the statements of `mod` using a fresh netlist, so nodes never leak across modules
    *
    * @param maxCatLen leaf-operand budget for each combined Cat
    * @param mod       the module to process
    * @return the module with its node-valued Cats combined
    */
  def onMod(maxCatLen: Int)(mod: DefModule): DefModule = mod.map(onStmt(maxCatLen, new Netlist))
}

/** Combine Cat DoPrims
  *
  * For every node whose value is a Cat DoPrim, operands that refer to other Cat nodes are replaced by those nodes'
  * values (subject to the length limit below), folding chains of concatenations into one nested Cat expression.
  * Cat DoPrims that are not node values are left untouched.
  *
  * The number of leaf elements in a combined Cat is capped by [[MaxCatLenAnnotation]]; when no such annotation is
  * present the cap is 10.
  */
class CombineCats extends Transform with DependencyAPIMigration with PreservesAll[Transform] {

  override def prerequisites =
    firrtl.stage.Forms.LowForm ++ Seq(
      Dependency(passes.RemoveValidIf),
      Dependency[firrtl.transforms.ConstantPropagation],
      Dependency(firrtl.passes.memlib.VerilogMemDelays),
      Dependency(firrtl.passes.SplitExpressions)
    )

  override def optionalPrerequisites = Seq.empty

  override def dependents = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])

  /** Leaf-element cap used when no [[MaxCatLenAnnotation]] is supplied */
  val defaultMaxCatLen = 10

  /** Combines node-valued Cats in every module of the circuit
    *
    * @param state the circuit and annotations to process
    * @return `state` with each module rewritten; annotations are passed through unchanged
    */
  def execute(state: CircuitState): CircuitState = {
    // Only the first MaxCatLenAnnotation is honoured
    val limit = state.annotations
      .collectFirst { case MaxCatLenAnnotation(n) => n }
      .getOrElse(defaultMaxCatLen)
    val combined = state.circuit.modules.map(m => CombineCats.onMod(limit)(m))
    state.copy(circuit = state.circuit.copy(modules = combined))
  }
}