aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/DeadCodeElimination.scala
blob: 54ac76fe71ea02cac84f59ae13b1857a22b43ba5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// See LICENSE for license details.

package firrtl.passes

import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._

import annotation.tailrec

/** Removes register, wire and node definitions whose names are never referenced.
  *
  * Each pass over a module body first collects every name that appears in a reference expression
  * anywhere in the body, then replaces each [[DefRegister]], [[DefWire]] and [[DefNode]] whose name
  * was not collected with [[EmptyStmt]]. Removing a definition can leave the definitions it read
  * unused, so passes repeat until one removes nothing.
  *
  * Every eliminated name is recorded as deleted in the [[RenameMap]] returned with the resulting
  * [[CircuitState]].
  *
  * NOTE(review): uses are collected from every expression a statement exposes through `map`. If that
  * includes connection sinks (as firrtl's `Connect.mapExpr` does), then a wire or register that is
  * only ever driven still counts as "referenced" and is kept. In that case the pass is conservative:
  * it only removes definitions whose name is never mentioned at all. Confirm before relying on
  * stronger elimination.
  */
object DeadCodeElimination extends Transform {
  def inputForm = UnknownForm
  def outputForm = UnknownForm

  /** Runs a single elimination pass over one module body.
    *
    * @param renames rename map in which every eliminated name is recorded as deleted
    * @param s the module body to clean
    * @return the body with unreferenced definitions replaced by [[EmptyStmt]], paired with the
    *         number of definitions removed in this pass
    */
  private def dceOnce(renames: RenameMap)(s: Statement): (Statement, Long) = {
    val referenced = collection.mutable.HashSet[String]()
    var nEliminated = 0L

    // Records every name read by `e` or any of its subexpressions, and returns `e` unchanged.
    // Both working-IR `WRef`s and plain IR `Reference`s count as uses. The input form is
    // UnknownForm, so a circuit that was never converted to working IR may still contain
    // `Reference`s. Ignoring them would delete definitions that are still in use.
    def checkExpressionUse(e: Expression): Expression = {
      e match {
        case ref: WRef      => referenced += ref.name
        case ref: Reference => referenced += ref.name
        case _              => e map checkExpressionUse
      }
      e
    }

    // Walks every nested statement and every expression it holds, purely to populate `referenced`.
    def checkUse(stmt: Statement): Statement = stmt map checkUse map checkExpressionUse

    // Keeps `x` if `name` was seen as a use. Otherwise counts it, records `name` as deleted, and
    // drops the definition.
    def maybeEliminate(x: Statement, name: String): Statement =
      if (referenced(name)) x
      else {
        nEliminated += 1
        renames.delete(name)
        EmptyStmt
      }

    // Only these three definition kinds are candidates for removal. Every other statement is kept,
    // and its nested statements (e.g. inside a Block) are visited.
    def removeUnused(stmt: Statement): Statement = stmt match {
      case x: DefRegister => maybeEliminate(x, x.name)
      case x: DefWire     => maybeEliminate(x, x.name)
      case x: DefNode     => maybeEliminate(x, x.name)
      case other          => other map removeUnused
    }

    checkUse(s)
    (removeUnused(s), nEliminated)
  }

  /** Repeats [[dceOnce]] on a module body until a pass eliminates nothing.
    *
    * Multiple passes are needed because deleting a definition removes the uses it contributed.
    * For example, in a chain of dead nodes, each node becomes removable only after its reader is
    * gone. Tail recursion keeps long chains from growing the stack.
    *
    * @param renames rename map in which every eliminated name is recorded as deleted
    * @param s the module body to clean
    * @return the body once no further definitions can be removed
    */
  @tailrec
  private def dce(renames: RenameMap)(s: Statement): Statement = {
    val (res, n) = dceOnce(renames)(s)
    if (n > 0) dce(renames)(res) else res
  }

  /** Eliminates unreferenced definitions in every module of the circuit.
    *
    * [[ExtModule]]s have no body and pass through unchanged. The shared [[RenameMap]] is scoped to
    * the circuit's main name, and then to each module in turn before that module is processed, so
    * deletions are recorded against the module being cleaned.
    *
    * @param state the circuit to transform
    * @return a state holding the cleaned circuit, this transform's output form, the unchanged
    *         annotations, and the rename map of deleted names
    */
  def execute(state: CircuitState): CircuitState = {
    val c = state.circuit
    val renames = RenameMap()
    renames.setCircuit(c.main)
    val modulesx = c.modules.map {
      case m: ExtModule => m
      case m: Module =>
        renames.setModule(m.name)
        m.copy(body = dce(renames)(m.body))
    }
    CircuitState(c.copy(modules = modulesx), outputForm, state.annotations, Some(renames))
  }
}