From c097e953ad6f003c6359e276b18c406baa881f46 Mon Sep 17 00:00:00 2001 From: Adam Izraelevitz Date: Fri, 10 Mar 2017 10:33:25 -0800 Subject: Added lesson2 --- src/main/scala/tutorial/AnalyzeCircuit.scala | 115 --------------- .../lesson1-circuit-traversal/AnalyzeCircuit.scala | 140 +++++++++++++++++++ .../lesson2-working-ir/AnalyzeCircuit.scala | 155 +++++++++++++++++++++ 3 files changed, 295 insertions(+), 115 deletions(-) delete mode 100644 src/main/scala/tutorial/AnalyzeCircuit.scala create mode 100644 src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala create mode 100644 src/main/scala/tutorial/lesson2-working-ir/AnalyzeCircuit.scala (limited to 'src/main') diff --git a/src/main/scala/tutorial/AnalyzeCircuit.scala b/src/main/scala/tutorial/AnalyzeCircuit.scala deleted file mode 100644 index 30a88cfd..00000000 --- a/src/main/scala/tutorial/AnalyzeCircuit.scala +++ /dev/null @@ -1,115 +0,0 @@ -package tutorial - -// Compiler Infrastructure -import firrtl.{Transform, LowForm, CircuitState} -// Firrtl IR classes -import firrtl.ir.{Circuit, DefModule, Statement, Expression, Mux} -// Map functions -import firrtl.Mappers._ -// Scala's mutable collections -import scala.collection.mutable - -/** Ledger - * - * Use for tracking [[Circuit]] statistics - * See [[AnalyzeCircuit]] - */ -class Ledger { - private var moduleName: Option[String] = None - private val moduleMuxMap = mutable.Map[String, Int]() - def foundMux: Unit = moduleName match { - case None => error("Module name not defined in Ledger!") - case Some(name) => moduleMuxMap(name) = moduleMuxMap.getOrElse(name, 0) + 1 - } - def setModuleName(name: String): Unit = { - moduleName = Some(name) - } - def serialize: String = { - moduleMuxMap map { case (module, nMux) => s"$module => $nMux muxes!" } mkString "\n" - } -} - -/** AnalyzeCircuit Transform - * - * Walks [[ir.Circuit]], and records the number of muxes it finds, per module. - * - * See the following links for more detailed explanations: - * Firrtl's IR: - * - https://github.com/ucb-bar/firrtl/wiki/Understanding-Firrtl-Intermediate-Representation - * Traversing a circuit: - * - https://github.com/ucb-bar/firrtl/wiki/traversing-a-circuit for more - * Common Pass Idioms: - * - https://github.com/ucb-bar/firrtl/wiki/Common-Pass-Idioms - */ -class AnalyzeCircuit extends Transform { - // Requires the [[Circuit]] form to be "low" - def inputForm = LowForm - // Indicates the output [[Circuit]] form to be "low" - def outputForm = LowForm - - // Called by [[Compiler]] to run your pass. [[CircuitState]] contains - // the circuit and its form, as well as other related data. - def execute(state: CircuitState): CircuitState = { - val ledger = new Ledger() - val circuit = state.circuit - - // Execute the function walkModule(ledger) on every [[DefModule]] in - // circuit, returning a new [[Circuit]] with new [[Seq]] of [[DefModule]]. - // - "higher order functions" - using a function as an object - // - "function currying" - partial argument notation - // - "infix notation" - fancy function calling syntax - // - "map" - classic functional programming concept - // - discard the returned new [[Circuit]] because circuit is unmodified - circuit map walkModule(ledger) - - // Print our ledger - println(ledger.serialize) - - // Return an unchanged [[CircuitState]] - state - } - - // Deeply visits every [[Statement]] in m. - def walkModule(ledger: Ledger)(m: DefModule): DefModule = { - // Set ledger to current module name - ledger.setModuleName(m.name) - - // Execute the function walkStatement(ledger) on every [[Statement]] in m. - // - return the new [[DefModule]] (in this case, its identical to m) - // - if m does not contain [[Statement]], map returns m. - m map walkStatement(ledger) - } - - // Deeply visits every [[Statement]] and [[Expression]] in s. - def walkStatement(ledger: Ledger)(s: Statement): Statement = { - - // Execute the function walkExpression(ledger) on every [[Expression]] in s. - // - discard the new [[Statement]] (in this case, its identical to s) - // - if s does not contain [[Expression]], map returns s. - s map walkExpression(ledger) - - // Execute the function walkStatement(ledger) on every [[Statement]] in s. - // - return the new [[Statement]] (in this case, its identical to s) - // - if s does not contain [[Statement]], map returns s. - s map walkStatement(ledger) - } - - // Deeply visits every [[Expression]] in e. - // - "post-order traversal" - handle e's children [[Expression]] before e - def walkExpression(ledger: Ledger)(e: Expression): Expression = { - - // Execute the function walkExpression(ledger) on every [[Expression]] in e. - // - return the new [[Expression]] (in this case, its identical to e) - // - if s does not contain [[Expression]], map returns e. - val visited = e map walkExpression(ledger) - - visited match { - // If e is a [[Mux]], increment our ledger and return e. - case Mux(cond, tval, fval, tpe) => - ledger.foundMux - e - // If e is not a [[Mux]], return e. - case e => e - } - } -} diff --git a/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala b/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala new file mode 100644 index 00000000..394e6ad8 --- /dev/null +++ b/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala @@ -0,0 +1,140 @@ +package tutorial +package lesson1 + +// Compiler Infrastructure +import firrtl.{Transform, LowForm, CircuitState} +// Firrtl IR classes +import firrtl.ir.{Circuit, DefModule, Statement, Expression, Mux} +// Map functions +import firrtl.Mappers._ +// Scala's mutable collections +import scala.collection.mutable + +/** Ledger tracks [[Circuit]] statistics + * + * In this lesson, we want to count the number of muxes in each + * module in our design. + * + * This [[Ledger]] class will be passed along as we walk our + * circuit, and help us count each [[Mux]] we find. + * + * See [[lesson1.AnalyzeCircuit]] + */ +class Ledger { + private var moduleName: Option[String] = None + private val modules = mutable.Set[String]() + private val moduleMuxMap = mutable.Map[String, Int]() + def foundMux: Unit = moduleName match { + case None => error("Module name not defined in Ledger!") + case Some(name) => moduleMuxMap(name) = moduleMuxMap.getOrElse(name, 0) + 1 + } + def getModuleName: String = moduleName match { + case None => error("Module name not defined in Ledger!") + case Some(name) => name + } + def setModuleName(myName: String): Unit = { + modules += myName + moduleName = Some(myName) + } + def serialize: String = { + modules map { myName => + s"$myName => ${moduleMuxMap.getOrElse(myName, 0)} muxes!" + } mkString "\n" + } +} + +/** AnalyzeCircuit Transform + * + * Walks [[ir.Circuit]], and records the number of muxes it finds, per module. + * + * While some compiler frameworks operate on graphs, we represent a Firrtl + * circuit using a tree representation: + * - A Firrtl [[Circuit]] contains a sequence of [[DefModule]]s. + * - A [[DefModule]] contains a sequence of [[Port]]s, and maybe a [[Statement]]. + * - A [[Statement]] can contain other [[Statement]]s, or [[Expression]]s. + * - A [[Expression]] can contain other [[Expression]]s. + * + * To visit all Firrtl IR nodes in a circuit, we write functions that recursively + * walk down this tree. To record statistics, we will pass along the [[Ledger]] + * class and use it when we come across a [[Mux]]. + * + * See the following links for more detailed explanations: + * Firrtl's IR: + * - https://github.com/ucb-bar/firrtl/wiki/Understanding-Firrtl-Intermediate-Representation + * Traversing a circuit: + * - https://github.com/ucb-bar/firrtl/wiki/traversing-a-circuit for more + * Common Pass Idioms: + * - https://github.com/ucb-bar/firrtl/wiki/Common-Pass-Idioms + */ +class AnalyzeCircuit extends Transform { + // Requires the [[Circuit]] form to be "low" + def inputForm = LowForm + // Indicates the output [[Circuit]] form to be "low" + def outputForm = LowForm + + // Called by [[Compiler]] to run your pass. [[CircuitState]] contains + // the circuit and its form, as well as other related data. + def execute(state: CircuitState): CircuitState = { + val ledger = new Ledger() + val circuit = state.circuit + + // Execute the function walkModule(ledger) on every [[DefModule]] in + // circuit, returning a new [[Circuit]] with new [[Seq]] of [[DefModule]]. + // - "higher order functions" - using a function as an object + // - "function currying" - partial argument notation + // - "infix notation" - fancy function calling syntax + // - "map" - classic functional programming concept + // - discard the returned new [[Circuit]] because circuit is unmodified + circuit map walkModule(ledger) + + // Print our ledger + println(ledger.serialize) + + // Return an unchanged [[CircuitState]] + state + } + + // Deeply visits every [[Statement]] in m. + def walkModule(ledger: Ledger)(m: DefModule): DefModule = { + // Set ledger to current module name + ledger.setModuleName(m.name) + + // Execute the function walkStatement(ledger) on every [[Statement]] in m. + // - return the new [[DefModule]] (in this case, its identical to m) + // - if m does not contain [[Statement]], map returns m. + m map walkStatement(ledger) + } + + // Deeply visits every [[Statement]] and [[Expression]] in s. + def walkStatement(ledger: Ledger)(s: Statement): Statement = { + + // Execute the function walkExpression(ledger) on every [[Expression]] in s. + // - discard the new [[Statement]] (in this case, its identical to s) + // - if s does not contain [[Expression]], map returns s. + s map walkExpression(ledger) + + // Execute the function walkStatement(ledger) on every [[Statement]] in s. + // - return the new [[Statement]] (in this case, its identical to s) + // - if s does not contain [[Statement]], map returns s. + s map walkStatement(ledger) + } + + // Deeply visits every [[Expression]] in e. + // - "post-order traversal" - handle e's children [[Expression]] before e + def walkExpression(ledger: Ledger)(e: Expression): Expression = { + + // Execute the function walkExpression(ledger) on every [[Expression]] in e. + // - return the new [[Expression]] (in this case, its identical to e) + // - if s does not contain [[Expression]], map returns e. + val visited = e map walkExpression(ledger) + + visited match { + // If e is a [[Mux]], increment our ledger and return e. + case Mux(cond, tval, fval, tpe) => + ledger.foundMux + e + // If e is not a [[Mux]], return e. + case e => e + } + } +} diff --git a/src/main/scala/tutorial/lesson2-working-ir/AnalyzeCircuit.scala b/src/main/scala/tutorial/lesson2-working-ir/AnalyzeCircuit.scala new file mode 100644 index 00000000..ba955b7c --- /dev/null +++ b/src/main/scala/tutorial/lesson2-working-ir/AnalyzeCircuit.scala @@ -0,0 +1,155 @@ +package tutorial +package lesson2 + +// Compiler Infrastructure +import firrtl.{Transform, LowForm, CircuitState} +// Firrtl IR classes +import firrtl.ir.{Circuit, DefModule, Statement, DefInstance, Expression, Mux} +// Firrtl compiler's working IR classes (WIR) +import firrtl.WDefInstance +// Map functions +import firrtl.Mappers._ +// Scala's mutable collections +import scala.collection.mutable + +/** Ledger tracks [[Circuit]] statistics + * + * In this lesson, we want to calculate the number of muxes, not just in + * a module, but also in any instances it has of other modules, etc. + * + * To do this, we need to update our Ledger class to keep track of this + * module instance information + * + * See [[lesson2.AnalyzeCircuit]] + */ +class Ledger { + private var moduleName: Option[String] = None + private val modules = mutable.Set[String]() + private val moduleMuxMap = mutable.Map[String, Int]() + private val moduleInstanceMap = mutable.Map[String, Seq[String]]() + def getModuleName: String = moduleName match { + case None => error("Module name not defined in Ledger!") + case Some(name) => name + } + def setModuleName(myName: String): Unit = { + modules += myName + moduleName = Some(myName) + } + def foundMux: Unit = { + val myName = getModuleName + moduleMuxMap(myName) = moduleMuxMap.getOrElse(myName, 0) + 1 + } + // Added this function to track when a module instantiates another module + def foundInstance(name: String): Unit = { + val myName = getModuleName + moduleInstanceMap(myName) = moduleInstanceMap.getOrElse(myName, Nil) :+ name + } + // Counts mux's in a module, and all its instances (recursively). + private def countMux(myName: String): Int = { + val myMuxes = moduleMuxMap.getOrElse(myName, 0) + val myInstanceMuxes = + moduleInstanceMap.getOrElse(myName, Nil).foldLeft(0) { + (total, name) => total + countMux(name) + } + myMuxes + myInstanceMuxes + } + // Display recursive total of muxes + def serialize: String = { + modules map { myName => s"$myName => ${countMux(myName)} muxes!" } mkString "\n" + } +} + +/** AnalyzeCircuit Transform + * + * Walks [[ir.Circuit]], and records the number of muxes and instances it + * finds, per module. + * + * While the Firrtl parser emits a bare form of the IR (located in firrtl.ir._), + * it is often useful to have more information in these case classes. To do this, + * the Firrtl compiler has mirror "working" classes for the following IR + * nodes (which contain additional fields): + * - DefInstance -> WDefInstance + * - SubAccess -> WSubAccess + * - SubIndex -> WSubIndex + * - SubField -> WSubField + * - Reference -> WRef + * + * Take a look at [[ToWorkingIR]] in src/main/scala/firrtl/passes/Passes.scala + * to see how Firrtl IR nodes are replaced with working IR nodes. + * + * Future lessons will explain the WIR's additional fields. For now, it is + * enough to know that the transform [[ResolveAndCheck]] populates these + * fields, and checks the legality of the circuit. If your transform is + * creating new WIR nodes, use the following "unknown" values in the WIR + * node, and then call [[ResolveAndCheck]] at the end of your transform: + * - Kind -> ExpKind + * - Gender -> UNKNOWNGENDER + * - Type -> UnknownType + * + * The following [[CircuitForm]]s require WIR instead of IR nodes: + * - HighForm + * - MidForm + * - LowForm + * + * See the following links for more detailed explanations: + * IR vs Working IR + * - TODO(izraelevitz) + */ +class AnalyzeCircuit extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + // Called by [[Compiler]] to run your pass. + def execute(state: CircuitState): CircuitState = { + val ledger = new Ledger() + val circuit = state.circuit + + // Execute the function walkModule(ledger) on all [[DefModule]] in circuit + circuit map walkModule(ledger) + + // Print our ledger + println(ledger.serialize) + + // Return an unchanged [[CircuitState]] + state + } + + // Deeply visits every [[Statement]] in m. + def walkModule(ledger: Ledger)(m: DefModule): DefModule = { + // Set ledger to current module name + ledger.setModuleName(m.name) + + // Execute the function walkStatement(ledger) on every [[Statement]] in m. + m map walkStatement(ledger) + } + + // Deeply visits every [[Statement]] and [[Expression]] in s. + def walkStatement(ledger: Ledger)(s: Statement): Statement = { + // Map the functions walkStatement(ledger) and walkExpression(ledger) + val visited = s map walkStatement(ledger) map walkExpression(ledger) + visited match { + // IR node [[DefInstance]] is previously replaced by WDefInstance, a + // "working" IR node + case DefInstance(info, name, module) => + error("All DefInstances should have been replaced by WDefInstances") + // Working IR Node [[WDefInstance]] is what the compiler uses + // See src/main/scala/firrtl/WIR.scala for all working IR nodes + case WDefInstance(info, name, module, tpe) => + ledger.foundInstance(module) + visited + case _ => visited + } + } + + // Deeply visits every [[Expression]] in e. + def walkExpression(ledger: Ledger)(e: Expression): Expression = { + // Execute the function walkExpression(ledger) on every [[Expression]] in e, + // then handle if a [[Mux]]. + e map walkExpression(ledger) match { + case Mux(cond, tval, fval, tpe) => + ledger.foundMux + e + case e => e + } + } +} -- cgit v1.2.3