aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/tutorial/lesson2-ir-fields/AnalyzeCircuit.scala
blob: a75c19dcbbde952a405d3ae98d067d4659e3e652 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// SPDX-License-Identifier: Apache-2.0

package tutorial
package lesson2

// Compiler Infrastructure
import firrtl.{CircuitState, LowForm, Transform}
// Firrtl IR classes
import firrtl.ir.{DefInstance, DefModule, Expression, Mux, Statement}
// Map functions
import firrtl.Mappers._
// Scala's mutable collections
import scala.collection.mutable

/** Ledger tracks [[firrtl.ir.Circuit]] statistics
  *
  * In this lesson, we want to calculate the number of muxes, not just in a module, but also in any instances it has of
  * other modules, etc.
  *
  * To do this, we need to update our Ledger class to keep track of this module instance information.
  *
  * See [[lesson2.AnalyzeCircuit]]
  */
class Ledger {
  // Module currently being walked; None until setModuleName is first called.
  private var moduleName: Option[String] = None
  // Every module name ever passed to setModuleName.
  private val modules = mutable.Set[String]()
  // Muxes found directly inside each module (not counting its instances).
  private val moduleMuxMap = mutable.Map[String, Int]()
  // For each module, the module names of its instances, one entry per instance (duplicates kept).
  private val moduleInstanceMap = mutable.Map[String, Seq[String]]()

  /** Returns the name of the module currently being walked.
    *
    * @throws RuntimeException if [[setModuleName]] has not been called yet
    */
  def getModuleName: String = moduleName match {
    case None       => sys.error("Module name not defined in Ledger!")
    case Some(name) => name
  }

  /** Registers `myName` as a known module and makes it the target of subsequent `found*` calls. */
  def setModuleName(myName: String): Unit = {
    modules += myName
    moduleName = Some(myName)
  }

  /** Counts one mux found directly inside the current module. */
  def foundMux(): Unit = {
    val myName = getModuleName
    moduleMuxMap(myName) = moduleMuxMap.getOrElse(myName, 0) + 1
  }

  /** Records that the current module instantiates the module named `name`. */
  def foundInstance(name: String): Unit = {
    val myName = getModuleName
    // Vector appends in effectively constant time; `:+` on a List is O(n).
    moduleInstanceMap(myName) = moduleInstanceMap.getOrElse(myName, Vector.empty[String]) :+ name
  }

  // Counts muxes in a module plus all of its instances (recursively).
  // Totals are memoized in `memo`, so a module that is instantiated many times, or shared
  // across the hierarchy, is only expanded once. Without this, a hierarchy where each level
  // instantiates the next one twice takes exponential time.
  private def countMux(myName: String, memo: mutable.Map[String, Int]): Int =
    memo.get(myName) match {
      case Some(total) => total
      case None =>
        val myMuxes = moduleMuxMap.getOrElse(myName, 0)
        val myInstanceMuxes =
          moduleInstanceMap.getOrElse(myName, Nil).iterator.map(countMux(_, memo)).sum
        val total = myMuxes + myInstanceMuxes
        memo(myName) = total
        total
    }

  /** Renders one line per known module, `<name> => <recursive mux total> muxes!`, joined by newlines. */
  def serialize: String = {
    // Fresh cache per call, because the ledger may still be updated after a serialize.
    val memo = mutable.Map[String, Int]()
    modules.map { myName => s"$myName => ${countMux(myName, memo)} muxes!" }.mkString("\n")
  }
}

/** AnalyzeCircuit Transform
  *
  * Walks [[firrtl.ir.Circuit]] and records the number of muxes and instances it finds, per module.
  *
  * While the Firrtl IR specification describes a written format, the AST nodes used internally by the
  * implementation have additional "analysis" fields.
  *
  * Take a look at [[firrtl.ir.Reference Reference]] to see how a reference to a component name is
  * augmented with relevant type, kind (memory, wire, etc), and flow information.
  *
  * Future lessons will explain the IR's additional fields. For now, it is enough to know that declaring
  * [[firrtl.stage.Forms.Resolved]] as a prerequisite is a handy shorthand for ensuring that all of these
  * fields will be populated with accurate information before your transform runs. (This transform itself
  * uses the `inputForm`/`outputForm` declarations below and asks for `LowForm`.) If you create new IR
  * nodes and do not wish to calculate the proper final values for all these fields, you can populate them
  * with default 'unknown' values.
  *   - Kind -> ExpKind
  *   - Flow -> UnknownFlow
  *   - Type -> UnknownType
  */
class AnalyzeCircuit extends Transform {
  // The circuit is consumed and returned in LowForm; this transform only reads it.
  def inputForm = LowForm
  def outputForm = LowForm

  // Called by [[Compiler]] to run your pass.
  def execute(state: CircuitState): CircuitState = {
    val ledger = new Ledger()
    val circuit = state.circuit

    // Execute the function walkModule(ledger) on all [[DefModule]] in circuit.
    // The mapped circuit is discarded: the walk is run only to fill in the ledger.
    circuit.map(walkModule(ledger))

    // Print our ledger
    println(ledger.serialize)

    // Return an unchanged [[CircuitState]]
    state
  }

  // Deeply visits every [[Statement]] in m, recording its findings under m's name.
  def walkModule(ledger: Ledger)(m: DefModule): DefModule = {
    // Set ledger to current module name so foundMux/foundInstance are attributed to m
    ledger.setModuleName(m.name)

    // Execute the function walkStatement(ledger) on every [[Statement]] in m.
    m.map(walkStatement(ledger))
  }

  // Deeply visits every [[Statement]] and [[Expression]] in s, recording each instance found.
  def walkStatement(ledger: Ledger)(s: Statement): Statement = {
    // Recurse into nested statements first, then into this statement's own expressions.
    val visited = s.map(walkStatement(ledger)).map(walkExpression(ledger))
    visited match {
      // Match on the type and read only `module`, so this does not depend on
      // DefInstance's full field list and binds no unused names.
      case inst: DefInstance => ledger.foundInstance(inst.module)
      case _                 => ()
    }
    visited
  }

  // Deeply visits every [[Expression]] in e, counting each mux found.
  def walkExpression(ledger: Ledger)(e: Expression): Expression = {
    // Execute the function walkExpression(ledger) on every [[Expression]] in e,
    //  then handle if a [[Mux]].
    e.map(walkExpression(ledger)) match {
      case mux: Mux =>
        ledger.foundMux()
        mux
      case notmux => notmux
    }
  }
}