aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/tutorial/AnalyzeCircuit.scala
blob: 30a88cfd76f93f1669525dc9b0ea4a475f30c272 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package tutorial

// Compiler Infrastructure
import firrtl.{Transform, LowForm, CircuitState}
// Firrtl IR classes
import firrtl.ir.{Circuit, DefModule, Statement, Expression, Mux}
// Map functions
import firrtl.Mappers._
// Scala's mutable collections
import scala.collection.mutable

/** Ledger
  *
  * Use for tracking [[Circuit]] statistics: it accumulates the number of
  * muxes found in each module.
  * See [[AnalyzeCircuit]]
  *
  * Usage protocol: call `setModuleName` when entering a module, then call
  * `foundMux` once for every mux discovered while that module is current.
  */
class Ledger {
  // Module currently being counted; stays None until setModuleName is called.
  private var moduleName: Option[String] = None
  // Insertion-ordered so that serialize lists modules in the order they were
  // first counted, which keeps the report deterministic between runs.
  private val moduleMuxMap = mutable.LinkedHashMap[String, Int]()

  /** Records one more mux for the current module.
    *
    * @throws java.lang.RuntimeException if no module name has been set yet
    */
  def foundMux: Unit = moduleName match {
    // sys.error replaces the long-deprecated Predef.error; it throws the same
    // RuntimeException with the same message.
    case None => sys.error("Module name not defined in Ledger!")
    case Some(name) => moduleMuxMap(name) = moduleMuxMap.getOrElse(name, 0) + 1
  }

  /** Selects the module that subsequent `foundMux` calls are attributed to.
    *
    * @param name name of the module being entered
    */
  def setModuleName(name: String): Unit = {
    moduleName = Some(name)
  }

  /** Renders the collected counts, one line per module.
    *
    * Modules for which `foundMux` was never called do not appear.
    *
    * @return newline-separated lines of the form `<module> => <count> muxes!`,
    *         or the empty string when nothing has been counted
    */
  def serialize: String =
    moduleMuxMap.map { case (module, nMux) => s"$module => $nMux muxes!" }.mkString("\n")
}

/** AnalyzeCircuit Transform
  *
  * Traverses a [[Circuit]] and tallies, module by module, how many muxes it
  * contains. The tallies are printed; the circuit state is returned untouched.
  *
  * Background reading:
  * Firrtl's IR:
  *   - https://github.com/ucb-bar/firrtl/wiki/Understanding-Firrtl-Intermediate-Representation
  * Traversing a circuit:
  *   - https://github.com/ucb-bar/firrtl/wiki/traversing-a-circuit
  * Common Pass Idioms:
  *   - https://github.com/ucb-bar/firrtl/wiki/Common-Pass-Idioms
  */
class AnalyzeCircuit extends Transform {
  // The incoming [[Circuit]] must be in "low" form.
  def inputForm = LowForm
  // The outgoing [[Circuit]] is reported as "low" form as well.
  def outputForm = LowForm

  // Entry point of the pass. Counts the muxes of state's circuit, prints the
  // per-module totals and hands back the very same [[CircuitState]].
  def execute(state: CircuitState): CircuitState = {
    val ledger = new Ledger()

    // Apply walkModule(ledger) to each [[DefModule]] of the circuit.
    //   - walkModule is curried, so supplying only the ledger yields the
    //     DefModule => DefModule function that map expects
    //   - map builds a new [[Circuit]]; it is thrown away on purpose because
    //     the walk exists only for its effect on the ledger
    state.circuit.map(walkModule(ledger))

    // Report what was counted.
    println(ledger.serialize)

    // Nothing was changed, so the input state is also the output.
    state
  }

  // Visits every [[Statement]] directly under m (and, through walkStatement,
  // everything nested below them).
  def walkModule(ledger: Ledger)(m: DefModule): DefModule = {
    // From here on, muxes are attributed to this module.
    ledger.setModuleName(m.name)

    // Apply walkStatement(ledger) to each [[Statement]] of m and return the
    // [[DefModule]] that map produces.
    m.map(walkStatement(ledger))
  }

  // Visits every [[Expression]] and every nested [[Statement]] of s.
  def walkStatement(ledger: Ledger)(s: Statement): Statement = {

    // First the expressions held by s; the mapped [[Statement]] is dropped
    // because only the ledger update matters.
    s.map(walkExpression(ledger))

    // Then recurse into the statements held by s; this mapped [[Statement]]
    // is the method's result.
    s.map(walkStatement(ledger))
  }

  // Visits every [[Expression]] nested in e.
  //   - post-order: the children of e are walked before e itself is examined
  def walkExpression(ledger: Ledger)(e: Expression): Expression = {

    // Recurse into the child expressions of e first.
    val walked = e.map(walkExpression(ledger))

    walked match {
      // A [[Mux]]: count it and give back the original expression.
      case _: Mux =>
        ledger.foundMux
        e
      // Anything else: give back the mapped expression as is.
      case other => other
    }
  }
}