aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjackkoenig2015-12-08 12:39:49 -0800
committerjackkoenig2015-12-08 12:39:49 -0800
commit2b04de4567e0ade01fdbfa6921fae91537180461 (patch)
tree1c49094b17bc92c456d60d3c291430e1968da501
parentf93f0b2c941960943d84c03ec4a9f0f0ba6c98b5 (diff)
Refactored MIDAS code into its own repo
-rw-r--r--src/main/scala/firrtl/Driver.scala57
-rw-r--r--src/main/scala/firrtl/Passes.scala13
-rw-r--r--src/main/scala/midas/Fame.scala367
-rw-r--r--src/main/scala/midas/Utils.scala329
4 files changed, 21 insertions, 745 deletions
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala
index 7887d87f..141a326a 100644
--- a/src/main/scala/firrtl/Driver.scala
+++ b/src/main/scala/firrtl/Driver.scala
@@ -7,7 +7,6 @@ import scala.io.Source
import Utils._
import DebugUtils._
import Passes._
-import midas.Fame1
object Driver
{
@@ -39,55 +38,15 @@ object Driver
logger.printlnDebug(ast)
}
- def toVerilogWithFame(input: String, output: String)
- {
- val logger = Logger(new PrintWriter(System.err, true))
-
- val stanzaPreTransform = List("rem-spec-chars", "high-form-check",
- "temp-elim", "to-working-ir", "resolve-kinds", "infer-types",
- "resolve-genders", "check-genders", "check-kinds", "check-types",
- "expand-accessors", "lower-to-ground", "inline-indexers", "infer-types",
- "check-genders", "expand-whens", "infer-widths", "real-ir", "width-check",
- "pad-widths", "const-prop", "split-expressions", "width-check",
- "high-form-check", "low-form-check", "check-init")
- val stanzaPostTransform = List("rem-spec-chars", "high-form-check",
- "temp-elim", "to-working-ir", "resolve-kinds", "infer-types",
- "resolve-genders", "check-genders", "check-kinds", "check-types",
- "expand-accessors", "lower-to-ground", "inline-indexers", "infer-types",
- "check-genders", "expand-whens", "infer-widths", "real-ir", "width-check",
- "pad-widths", "const-prop", "split-expressions", "width-check",
- "high-form-check", "low-form-check", "check-init")
+ // Should we just remove logger?
+ private def executePassesWithLogger(ast: Circuit, passes: Seq[Circuit => Circuit])(implicit logger: Logger): Circuit = {
+ if (passes.isEmpty) ast
+ else executePasses(passes.head(ast), passes.tail)
+ }
- //// Don't lower
- //val temp1 = genTempFilename(input)
- //val ast = Parser.parse(Source.fromFile(input).getLines)
- //val writer = new PrintWriter(new File(temp1))
- //val ast2 = fame1Transform(ast)
- //writer.write(ast2.serialize())
- //writer.close()
-
- // Lower-to-Ground with Stanza FIRRTL
- //val temp1 = genTempFilename(input)
- val temp1 = input + ".1.tmp"
- val preCmd = Seq("firrtl-stanza", "-i", input, "-o", temp1, "-b", "firrtl") ++ stanzaPreTransform.flatMap(Seq("-x", _))
- println(preCmd.mkString(" "))
- preCmd.!
-
- // Read in and execute infer-types
- val ast = Parser.parse(input, Source.fromFile(temp1).getLines)
- val ast2 = inferTypes(ast)(logger)
-
- // FAME-1 Transformation
- //val temp2 = genTempFilename(input)
- val temp2 = input + ".2.tmp"
- val writer = new PrintWriter(new File(temp2))
- val ast3 = Fame1.transform(ast2)
- writer.write(ast3.serialize())
- writer.close()
-
- val postCmd = Seq("firrtl-stanza", "-i", temp2, "-o", output, "-X", "verilog")
- println(postCmd.mkString(" "))
- postCmd.!
+ def executePasses(ast: Circuit, passes: Seq[Circuit => Circuit]): Circuit = {
+ implicit val logger = Logger() // No logging
+ executePassesWithLogger(ast, passes)
}
private def verilog(input: String, output: String)(implicit logger: Logger)
diff --git a/src/main/scala/firrtl/Passes.scala b/src/main/scala/firrtl/Passes.scala
index 39e6b64e..f64d67bb 100644
--- a/src/main/scala/firrtl/Passes.scala
+++ b/src/main/scala/firrtl/Passes.scala
@@ -7,6 +7,19 @@ import Primops._
object Passes {
+ // TODO Perhaps we should get rid of Logger since this map would be nice
+ ////private val defaultLogger = Logger()
+ //private def mapNameToPass = Map[String, Circuit => Circuit] (
+ // "infer-types" -> inferTypes
+ //)
+ def nameToPass(name: String): Circuit => Circuit = {
+ implicit val logger = Logger() // throw logging away
+ //mapNameToPass.getOrElse(name, throw new Exception("No Standard FIRRTL Pass of name " + name))
+ name match {
+ case "infer-types" => inferTypes
+ }
+ }
+
private def toField(p: Port)(implicit logger: Logger): Field = {
logger.trace(s"toField called on port ${p.serialize}")
p.dir match {
diff --git a/src/main/scala/midas/Fame.scala b/src/main/scala/midas/Fame.scala
deleted file mode 100644
index aedadba5..00000000
--- a/src/main/scala/midas/Fame.scala
+++ /dev/null
@@ -1,367 +0,0 @@
-
-package midas
-
-import Utils._
-import firrtl._
-import firrtl.Utils._
-
-/** FAME-1 Transformation
- *
- * This pass takes a lowered-to-ground circuit and performs a FAME-1 (Decoupled) transformation
- * to the circuit
- * It does this by creating a simulation module wrapper around the circuit, if we can gate the
- * clock, then there need be no modification to the target RTL, if we can't, then the target
- * RTL will have to be modified (by adding a midasFire input and use this signal to gate
- * register enable
- *
- * ALGORITHM
- * 1. Flatten RTL
- * a. Create NewTop
- * b. Instantiate Top in NewTop
- * c. Iteratively pull all sim tagged instances out of hierarchy to NewTop
- * i. Move instance declaration from child module to parent module
- * ii. Create io in child module corresponding to io of instance
- * iii. Connect original instance io in child to new child io
- * iv. Connect child io in parent to instance io
- * v. Repeate until instance is in SimTop
- * (if black box, repeat until completely removed from design)
- * * Post-flattening invariants
- * - No combinational logic on NewTop
- * - All and only instances in NewTop are sim tagged
- * - No black boxes in design
- * 2. Simulation Transformation
- * a. Perform Decoupled Transformation on every RTL Module
- * i. Add targetFire signal input
- * ii. Find all DefInst nodes and propogate targetFire to the instances
- * iii. Find all registers and add when statement connecting regIn to regOut when !targetFire
- * b. Iteratively transform each inst in Top (see example firrtl in dram_midas top dir)
- * i. Create wrapper class
- * ii. Create input and output (ready, valid) pairs for every other sim module this module connects to
- * * Note that TopIO counts as a "sim module"
- * iii. Create targetFire as AND of inputs.valid and outputs.ready
- * iv. Connect targetFire to targetFire input of target rtl inst
- * v. Connect target IO to wrapper IO, except connect target clock to simClock
- *
- * TODO
- * - Is it okay to have ready signals for input queues depend on valid signals for those queues? This is generally bad
- * - Change sequential memory read enable to work with targetFire
- * - Implement Flatten RTL
- * - Refactor important strings/naming to API (eg. "topIO" needs to be a constant defined somewhere or something)
- * - Check that circuit is in LowFIRRTL?
- *
- * NOTES
- * - How do we only transform the necessary modules? Should there be a MIDAS list of modules
- * or something?
- * * YES, this will be done by requiring the user to instantiate modules that should be split
- * with something like: val module = new MIDASModule(class... etc.)
- * - There cannot be nested DecoupledIO or ValidIO
- * - How do output consumes tie in to MIDAS fire? If all of our outputs are not consumed
- * in a given cycle, do we block midas$fire on the next cycle? Perhaps there should be
- * a register for not having consumed all outputs last cycle
- * - If our outputs are not consumed we also need to be sure not to consume out inputs,
- * so the logic for this must depend on the previous cycle being consumed as well
- * - We also need a way to determine the difference between the MIDAS modules and their
- * connecting Queues, perhaps they should be MIDAS queues, which then perhaps prints
- * out a listing of all queues so that they can be properly transformed
- * * What do these MIDAS queues look like since we're enforcing true decoupled
- * interfaces?
- */
-object Fame1 {
-
- // Constants, common nodes, and common types used throughout
- private type PortMap = Map[String, Seq[String]]
- private val PortMap = Map[String, Seq[String]]()
- private type ConnMap = Map[String, PortMap]
- private val ConnMap = Map[String, PortMap]()
- private type SimQMap = Map[(String, String), Module]
- private val SimQMap = Map[(String, String), Module]()
-
- private val hostReady = Field("hostReady", Reverse, UIntType(IntWidth(1)))
- private val hostValid = Field("hostValid", Default, UIntType(IntWidth(1)))
- private val hostClock = Port(NoInfo, "hostClock", Input, ClockType)
- private val hostReset = Port(NoInfo, "hostReset", Input, UIntType(IntWidth(1)))
- private val targetFire = Port(NoInfo, "targetFire", Input, UIntType(IntWidth(1)))
-
- private def wrapName(name: String): String = s"SimWrap_${name}"
- private def unwrapName(name: String): String = name.stripPrefix("SimWrap_")
- private def queueName(src: String, dst: String): String = s"SimQueue_${src}_${dst}"
- private def instName(name: String): String = s"inst_${name}"
- private def unInstName(name: String): String = name.stripPrefix("inst_")
-
- private def genHostDecoupled(fields: Seq[Field]): BundleType = {
- BundleType(Seq(hostReady, hostValid) :+ Field("hostBits", Default, BundleType(fields)))
- }
-
- // ********** findPortConn **********
- // This takes lowFIRRTL top module that follows invariants described above and returns a connection Map
- // of instanceName -> (instanctPorts -> portEndpoint)
- // It honestly feels kind of brittle given it assumes there will be no intermediate nodes or anything in
- // the way of direct connections between IO of module instances
- private def processConnectExp(exp: Exp): (String, String) = {
- val unsupportedExp = new Exception("Unsupported Exp for finding port connections: " + exp)
- exp match {
- case ref: Ref => ("topIO", ref.name)
- case sub: Subfield =>
- sub.exp match {
- case ref: Ref => (ref.name, sub.name)
- case _ => throw unsupportedExp
- }
- case exp: Exp => throw unsupportedExp
- }
- }
- private def processConnect(conn: Connect): ConnMap = {
- val lhs = processConnectExp(conn.lhs)
- val rhs = processConnectExp(conn.rhs)
- Map(lhs._1 -> Map(lhs._2 -> Seq(rhs._1)), rhs._1 -> Map(rhs._2 -> Seq(lhs._1))).withDefaultValue(PortMap)
- }
- private def findPortConn(connMap: ConnMap, stmts: Seq[Stmt]): ConnMap = {
- if (stmts.isEmpty) connMap
- else {
- stmts.head match {
- case conn: Connect => {
- val newConnMap = processConnect(conn)
- findPortConn((connMap map { case (k,v) =>
- k -> merge(Seq(v, newConnMap(k))) { (_, v1, v2) => v1 ++ v2 }}), stmts.tail )
- }
- case _ => findPortConn(connMap, stmts.tail)
- }
- }
- }
- private def findPortConn(top: Module, insts: Seq[DefInst]): ConnMap = {
- val initConnMap = (insts map ( _.name -> PortMap )).toMap ++ Map("topIO" -> PortMap)
- val topStmts = top.stmt match {
- case b: Block => b.stmts
- case s: Stmt => Seq(s) // This honestly shouldn't happen but let's be safe
- }
- findPortConn(initConnMap, topStmts)
- }
-
- // Removes clocks from a portmap
- private def scrubClocks(ports: Seq[Port], portMap: PortMap): PortMap = {
- val clocks = ports filter (_.tpe == ClockType) map (_.name)
- portMap filter { case (portName, _) => !clocks.contains(portName) }
- }
-
- // ********** transformRTL **********
- // Takes an RTL module and give it targetFire input, propogates targetFire to all child instances,
- // puts targetFire on regEnable for all registers
- // TODO
- // - Add smem support
- private def transformRTL(m: Module): Module = {
- val ports = m.ports :+ targetFire
- val instProp = getDefInsts(m) map { inst =>
- Connect(NoInfo, buildExp(Seq(inst.name, targetFire.name)), buildExp(targetFire.name))
- }
- val regEn = getDefRegs(m) map { reg =>
- When(NoInfo, DoPrimop(Not, Seq(buildExp(targetFire.name)), Seq(), UnknownType),
- Connect(NoInfo, buildExp(reg.name), buildExp(reg.name)), EmptyStmt)
- }
- Module(m.info, m.name, ports, Block(m.stmt +: (instProp ++ regEn)))
- }
-
- // ********** genWrapperModule **********
- // Generates FAME-1 Decoupled wrappers for simulation module instances
- private def genWrapperModule(inst: DefInst, portMap: PortMap): Module = {
-
- val instIO = getDefInstType(inst)
- val nameToField = (instIO.fields map (f => f.name -> f)).toMap
-
- val connections = (portMap map (_._2)).toSeq.flatten.distinct // modules this inst connects to
- // Build simPort for each connecting module
- // TODO This whole chunk really ought to be rewritten or made a function
- val connPorts = connections map { c =>
- // Get ports that connect to this particular module as fields
- val fields = (portMap filter (_._2.contains(c))).keySet.toSeq.sorted map (nameToField(_))
- val noClock = fields filter (_.tpe != ClockType) // Remove clock
- val inputSet = noClock filter (_.dir == Reverse) map (f => Field(f.name, Default, f.tpe))
- val outputSet = noClock filter (_.dir == Default)
- Port(inst.info, c, Output, BundleType(
- (if (inputSet.isEmpty) Seq()
- else Seq(Field("hostIn", Reverse, genHostDecoupled(inputSet)))
- ) ++
- (if (outputSet.isEmpty) Seq()
- else Seq(Field("hostOut", Default, genHostDecoupled(outputSet)))
- )
- ))
- }
- val ports = hostClock +: hostReset +: connPorts // Add host and host reset
-
- // targetFire is signal to indicate when a simulation module can execute, this is indicated by all of its inputs
- // being valid and all of its outputs being ready
- val targetFireInputs = (connPorts map { port =>
- getFields(port) map { field =>
- field.dir match {
- case Reverse => buildExp(Seq(port.name, field.name, hostValid.name))
- case Default => buildExp(Seq(port.name, field.name, hostReady.name))
- }
- }
- }).flatten
-
- val defTargetFire = DefNode(inst.info, targetFire.name, genPrimopReduce(And, targetFireInputs))
- val connectTargetFire = Connect(NoInfo, buildExp(Seq(inst.name, targetFire.name)), buildExp(targetFire.name))
-
- // Only consume tokens when the module fires
- // TODO is it bad to have the input readys depend on the input valid signals?
- val inputsReady = (connPorts map { port =>
- getFields(port) filter (_.dir == Reverse) map { field => // filter to only take inputs
- Connect(inst.info, buildExp(Seq(port.name, field.name, hostReady.name)), buildExp(targetFire.name))
- }
- }).flatten
-
- // Outputs are valid on cycles where we fire
- val outputsValid = (connPorts map { port =>
- getFields(port) filter (_.dir == Default) map { field => // filter to only take outputs
- Connect(inst.info, buildExp(Seq(port.name, field.name, hostValid.name)), buildExp(targetFire.name))
- }
- }).flatten
-
- // Connect up all of the IO of the RTL module to sim module IO, except clock which should be connected
- // This currently assumes naming things that are also done above when generating connPorts
- val connectedInstIOFields = instIO.fields filter(field => portMap.contains(field.name)) // skip unconnected IO
- val instIOConnect = (connectedInstIOFields map { field =>
- field.tpe match {
- case ClockType => Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)),
- Ref(hostClock.name, ClockType)))
- case _ => field.dir match {
- case Default => portMap(field.name) map { endpoint =>
- Connect(inst.info, buildExp(Seq(endpoint, "hostOut", "hostBits", field.name)),
- buildExp(Seq(inst.name, field.name)))
- }
- case Reverse => {
- if (portMap(field.name).length > 1)
- throw new Exception("It is illegal to have more than 1 connection to a single input" + field)
- Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)),
- buildExp(Seq(portMap(field.name).head, "hostIn", "hostBits", field.name))))
- }
- }
- }
- }).flatten
- val stmts = Block(Seq(defTargetFire) ++ inputsReady ++ outputsValid ++ Seq(inst) ++
- Seq(connectTargetFire) ++ instIOConnect)
-
- Module(inst.info, wrapName(inst.name), ports, stmts)
- }
-
- // ********** generateSimQueues **********
- // Takes Seq of SimWrapper modules
- // Returns Map of (src, dest) -> SimQueue
- // To prevent duplicates, instead of creating a map with (src, dest) as the key, we could instead
- // only one direction of the queue for each simport of each module. The only problem with this is
- // it won't create queues for TopIO since that isn't a real module
- private def generateSimQueues(wrappers: Seq[Module]): SimQMap = {
- def rec(wrappers: Seq[Module], map: SimQMap): SimQMap = {
- if (wrappers.isEmpty) map
- else {
- val w = wrappers.head
- val name = unwrapName(w.name)
- val newMap = (w.ports filter(isSimPort) map { port =>
- (splitSimPort(port) map { field =>
- val (src, dst) = if (field.dir == Default) (name, port.name) else (port.name, name)
- if (map.contains((src, dst))) SimQMap
- else Map((src, dst) -> buildSimQueue(queueName(src, dst), getHostBits(field).tpe))
- }).flatten.toMap
- }).flatten.toMap
- rec(wrappers.tail, map ++ newMap)
- }
- }
- rec(wrappers, SimQMap)
- }
-
- // ********** generateSimTop **********
- // Creates the Simulation Top module where all sim modules and sim queues are instantiated and connected
- private def transformTopIO(ports: Seq[Port]): Seq[Port] = {
- val noClock = ports filter (_.tpe != ClockType)
- val inputs = noClock filter (_.dir == Input) map (_.toField.flip) // Flip because wrapping port is input
- val outputs = noClock filter (_.dir == Output) map (_.toField)
-
- Seq(Port(NoInfo, "io", Output, BundleType(Seq(Field("hostIn", Reverse, genHostDecoupled(inputs)),
- Field("hostOut", Default, genHostDecoupled(outputs))))))
- }
- private def generateSimTop(wrappers: Seq[Module], simQueues: SimQMap, portMap: PortMap, rtlTop: Module): Module = {
- val insts = (wrappers map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) }) ++
- (simQueues.values map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) })
- val connectClocks = (wrappers ++ simQueues.values) map { m =>
- Connect(NoInfo, buildExp(Seq(instName(m.name), hostClock.name)), buildExp(hostClock.name))
- }
- val connectResets = (wrappers ++ simQueues.values) map { m =>
- Connect(NoInfo, buildExp(Seq(instName(m.name), hostReset.name)), buildExp(hostReset.name))
- }
- // Connect queues to simulation modules (excludes IO)
- val connectQueues = (simQueues map { case ((src, dst), queue) =>
- (if (src == "topIO") Seq()
- else Seq(BulkConnect(NoInfo, buildExp(Seq(instName(queue.name), "io", "enq")),
- buildExp(Seq(instName(wrapName(src)), dst, "hostOut"))))
- ) ++
- (if (dst == "topIO") Seq()
- else Seq(BulkConnect(NoInfo, buildExp(Seq(instName(wrapName(dst)), src, "hostIn")),
- buildExp(Seq(instName(queue.name), "io", "deq"))))
- )
- }).flatten
- // Connect IO queues, Src means input, Dst means output (ie. the outside word is the Src or Dst)
- val ioSrcQueues = (simQueues filter {case ((src, dst), _) => src == "topIO"} map {case (_, queue) => queue}).toSeq
- val ioDstQueues = (simQueues filter {case ((src, dst), _) => dst == "topIO"} map {case (_, queue) => queue}).toSeq
- val ioSrcSignals = rtlTop.ports filter (sig => sig.tpe != ClockType && sig.dir == Input) map (_.name)
- val ioDstSignals = rtlTop.ports filter (sig => sig.tpe != ClockType && sig.dir == Output) map (_.name)
-
- val ioSrcQueueConnect = if (ioSrcQueues.length > 0) {
- val readySignals = ioSrcQueues map (queue => buildExp(Seq(instName(queue.name), "io", "enq", hostReady.name)))
- val validSignals = ioSrcQueues map (queue => buildExp(Seq(instName(queue.name), "io", "enq", hostValid.name)))
-
- (ioSrcSignals map { sig =>
- (portMap(sig) map { dst =>
- Connect(NoInfo, buildExp(Seq(instName(queueName("topIO", dst)), "io", "enq", "hostBits", sig)),
- buildExp(Seq("io", "hostIn", "hostBits", sig)))
- })
- }).flatten ++
- (validSignals map (sig => Connect(NoInfo, buildExp(sig), buildExp(Seq("io", "hostIn", hostValid.name))))) :+
- Connect(NoInfo, buildExp(Seq("io", "hostIn", hostReady.name)), genPrimopReduce(And, readySignals))
- } else Seq(EmptyStmt)
-
- val ioDstQueueConnect = if (ioDstQueues.length > 0) {
- val readySignals = ioDstQueues map (queue => buildExp(Seq(instName(queue.name), "io", "deq", hostReady.name)))
- val validSignals = ioDstQueues map (queue => buildExp(Seq(instName(queue.name), "io", "deq", hostValid.name)))
-
- (ioDstSignals map { sig =>
- (portMap(sig) map { src =>
- Connect(NoInfo, buildExp(Seq("io", "hostOut", "hostBits", sig)),
- buildExp(Seq(instName(queueName(src, "topIO")), "io", "deq", "hostBits", sig)))
- })
- }).flatten ++
- (readySignals map (sig => Connect(NoInfo, buildExp(sig), buildExp(Seq("io", "hostOut", hostReady.name))))) :+
- Connect(NoInfo, buildExp(Seq("io", "hostOut", hostValid.name)), genPrimopReduce(And, validSignals))
- } else Seq(EmptyStmt)
-
- val stmts = Block(insts ++ connectClocks ++ connectResets ++ connectQueues ++ ioSrcQueueConnect ++ ioDstQueueConnect)
- val ports = Seq(hostClock, hostReset) ++ transformTopIO(rtlTop.ports)
- Module(NoInfo, "SimTop", ports, stmts)
- }
-
- // ********** transform **********
- // Perform FAME-1 Transformation for MIDAS
- def transform(c: Circuit): Circuit = {
- // We should check that the invariants mentioned above are true
- val nameToModule = (c.modules map (m => m.name -> m))(collection.breakOut): Map[String, Module]
- val top = nameToModule(c.name)
-
- val rtlModules = c.modules filter (_.name != top.name) map (transformRTL)
-
- val insts = getDefInsts(top)
-
- val portConn = findPortConn(top, insts)
-
- // Check that port Connections include all ports for each instance?
-
- val simWrappers = insts map (inst => genWrapperModule(inst, portConn(inst.name)))
-
- val simQueues = generateSimQueues(simWrappers)
-
- // Remove duplicate simWrapper and simQueue modules?
-
- val simTop = generateSimTop(simWrappers, simQueues, portConn("topIO"), top)
-
- val modules = rtlModules ++ simWrappers ++ simQueues.values.toSeq ++ Seq(simTop)
-
- Circuit(c.info, simTop.name, modules)
- }
-
-}
diff --git a/src/main/scala/midas/Utils.scala b/src/main/scala/midas/Utils.scala
deleted file mode 100644
index 003f22bb..00000000
--- a/src/main/scala/midas/Utils.scala
+++ /dev/null
@@ -1,329 +0,0 @@
-
-package midas
-
-import firrtl._
-import firrtl.Utils._
-
-object Utils {
-
- // Merges a sequence of maps via the provided function f
- // Taken from: https://groups.google.com/forum/#!topic/scala-user/HaQ4fVRjlnU
- def merge[K, V](maps: Seq[Map[K, V]])(f: (K, V, V) => V): Map[K, V] = {
- maps.foldLeft(Map.empty[K, V]) { case (merged, m) =>
- m.foldLeft(merged) { case (acc, (k, v)) =>
- acc.get(k) match {
- case Some(existing) => acc.updated(k, f(k, existing, v))
- case None => acc.updated(k, v)
- }
- }
- }
- }
-
- // This doesn't work because of Type Erasure >.<
- //private def getStmts[A <: Stmt](s: Stmt): Seq[A] = {
- // s match {
- // case a: A => Seq(a)
- // case b: Block => b.stmts.map(getStmts[A]).flatten
- // case _ => Seq()
- // }
- //}
- //private def getStmts[A <: Stmt](m: Module): Seq[A] = getStmts[A](m.stmt)
-
- def getDefRegs(s: Stmt): Seq[DefReg] = {
- s match {
- case r: DefReg => Seq(r)
- case b: Block => b.stmts.map(getDefRegs).flatten
- case _ => Seq()
- }
- }
- def getDefRegs(m: Module): Seq[DefReg] = getDefRegs(m.stmt)
-
- def getDefInsts(s: Stmt): Seq[DefInst] = {
- s match {
- case i: DefInst => Seq(i)
- case b: Block => b.stmts.map(getDefInsts).flatten
- case _ => Seq()
- }
- }
- def getDefInsts(m: Module): Seq[DefInst] = getDefInsts(m.stmt)
-
- def getDefInstRef(inst: DefInst): Ref = {
- inst.module match {
- case ref: Ref => ref
- case _ => throw new Exception("Invalid module expression for DefInst: " + inst.serialize)
- }
- }
-
- // DefInsts have an Expression for the module, this expression should be a reference this
- // reference has a tpe that should be a bundle representing the IO of that module class
- def getDefInstType(inst: DefInst): BundleType = {
- val ref = getDefInstRef(inst)
- ref.tpe match {
- case b: BundleType => b
- case _ => throw new Exception("Invalid reference type for DefInst: " + inst.serialize)
- }
- }
-
- def getModuleFromDefInst(nameToModule: Map[String, Module], inst: DefInst): Module = {
- val instModule = getDefInstRef(inst)
- if(!nameToModule.contains(instModule.name))
- throw new Exception(s"Module ${instModule.name} not found in circuit!")
- else
- nameToModule(instModule.name)
- }
-
- // Takes a set of strings or ints and returns equivalent expression node
- // Strings correspond to subfields/references, ints correspond to indexes
- // eg. Seq(io, port, ready) => io.port.ready
- // Seq(io, port, 5, valid) => io.port[5].valid
- // Seq(3) => UInt("h3")
- def buildExp(names: Seq[Any]): Exp = {
- def rec(names: Seq[Any]): Exp = {
- names.head match {
- // Useful for adding on indexes or subfields
- case head: Exp => head
- // Int -> UInt/SInt/Index
- case head: Int =>
- if( names.tail.isEmpty ) // Is the UInt/SInt inference good enough?
- if( head > 0 ) UIntValue(head, UnknownWidth) else SIntValue(head, UnknownWidth)
- else Index(rec(names.tail), head, UnknownType)
- // String -> Ref/Subfield
- case head: String =>
- if( names.tail.isEmpty ) Ref(head, UnknownType)
- else Subfield(rec(names.tail), head, UnknownType)
- case _ => throw new Exception("Invalid argument type to buildExp! " + names)
- }
- }
- rec(names.reverse) // Let user specify in more natural format
- }
- def buildExp(name: Any): Exp = buildExp(Seq(name))
-
- def genPrimopReduce(op: Primop, args: Seq[Exp]): Exp = {
- if( args.length == 0 ) throw new Exception("genPrimopReduce called on empty sequence!")
- else if( args.length == 1 ) args.head
- else if( args.length == 2 ) DoPrimop(op, Seq(args.head, args.last), Seq(), UnknownType)
- else DoPrimop(op, Seq(args.head, genPrimopReduce(op, args.tail)), Seq(), UnknownType)
- }
-
- // Checks if a firrtl.Port matches the MIDAS SimPort pattern
- // This currently just checks that the port is of type bundle with ONLY the members
- // hostIn and/or hostOut with correct directions
- def isSimPort(port: Port): Boolean = {
- //println("isSimPort called on port " + port.serialize)
- port.tpe match {
- case b: BundleType => {
- b.fields map { field =>
- if( field.name == "hostIn" ) field.dir == Reverse
- else if( field.name == "hostOut" ) field.dir == Default
- else false
- } reduce ( _ & _ )
- }
- case _ => false
- }
- }
-
- def splitSimPort(port: Port): Seq[Field] = {
- try {
- val b = port.tpe.asInstanceOf[BundleType]
- Seq(b.fields.find(_.name == "hostIn"), b.fields.find(_.name == "hostOut")).flatten
- } catch {
- case e: Exception => throw new Exception("Invalid SimPort " + port.serialize)
- }
- }
-
- // From simulation host decoupled, return hostbits field
- def getHostBits(field: Field): Field = {
- try {
- val b = field.tpe.asInstanceOf[BundleType]
- b.fields.find(_.name == "hostBits").get
- } catch {
- case e: Exception => throw new Exception("Invalid SimField " + field.serialize)
- }
- }
-
- // For a port that is known to be of type BundleType, return the fields of that bundle
- def getFields(port: Port): Seq[Field] = {
- port.tpe match {
- case b: BundleType => b.fields
- case _ => throw new Exception("getFields called on invalid port " + port)
- }
- }
-
- // Recursively iterates through firrtl.Type returning sequence of names to address signals
- // * Intended for use with recursive bundle types
- def enumerateMembers(tpe: Type): Seq[Seq[Any]] = {
- def rec(tpe: Type, path: Seq[Any], members: Seq[Seq[Any]]): Seq[Seq[Any]] = {
- tpe match {
- case b: BundleType => (b.fields map ( f => rec(f.tpe, path :+ f.name, members) )).flatten
- case v: VectorType => (Seq.tabulate(v.size.toInt) ( i => rec(v.tpe, path :+ i, members) )).flatten
- case _ => members :+ path
- }
- }
- rec(tpe, Seq[Any](), Seq[Seq[Any]]())
- }
-
- // Queue
- // TODO
- // - Insert null tokens upon hostReset (or should this be elsewhere?)
- def buildSimQueue(name: String, tpe: Type): Module = {
- val scopeSpaces = " " * 4 // Spaces before lines in module scope, for default assignments
- val templatedQueue =
-// """
-// circuit `NAME:
-// module `NAME :
-// input hostClock : Clock
-// input hostReset : UInt<1>
-// output io : {flip enq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, deq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, count : UInt<3>}
-//
-// io.count := UInt<1>("h00")
-// `DEFAULT_ASSIGN
-// io.deq.hostValid := UInt<1>("h00")
-// io.enq.hostReady := UInt<1>("h00")
-// cmem ram : `TYPE[4], hostClock
-// reg T_80 : UInt<2>, hostClock, hostReset
-// onreset T_80 := UInt<2>("h00")
-// reg T_82 : UInt<2>, hostClock, hostReset
-// onreset T_82 := UInt<2>("h00")
-// reg maybe_full : UInt<1>, hostClock, hostReset
-// onreset maybe_full := UInt<1>("h00")
-// node ptr_match = eq(T_80, T_82)
-// node T_87 = eq(maybe_full, UInt<1>("h00"))
-// node empty = and(ptr_match, T_87)
-// node full = and(ptr_match, maybe_full)
-// node maybe_flow = and(UInt<1>("h00"), empty)
-// node do_flow = and(maybe_flow, io.deq.hostReady)
-// node T_93 = and(io.enq.hostReady, io.enq.hostValid)
-// node T_95 = eq(do_flow, UInt<1>("h00"))
-// node do_enq = and(T_93, T_95)
-// node T_97 = and(io.deq.hostReady, io.deq.hostValid)
-// node T_99 = eq(do_flow, UInt<1>("h00"))
-// node do_deq = and(T_97, T_99)
-// when do_enq :
-// infer accessor T_101 = ram[T_80]
-// T_101 <> io.enq.hostBits
-// node T_109 = eq(T_80, UInt<2>("h03"))
-// node T_111 = and(UInt<1>("h00"), T_109)
-// node T_114 = addw(T_80, UInt<1>("h01"))
-// node T_115 = mux(T_111, UInt<1>("h00"), T_114)
-// T_80 := T_115
-// skip
-// when do_deq :
-// node T_117 = eq(T_82, UInt<2>("h03"))
-// node T_119 = and(UInt<1>("h00"), T_117)
-// node T_122 = addw(T_82, UInt<1>("h01"))
-// node T_123 = mux(T_119, UInt<1>("h00"), T_122)
-// T_82 := T_123
-// skip
-// node T_124 = neq(do_enq, do_deq)
-// when T_124 :
-// maybe_full := do_enq
-// skip
-// node T_126 = eq(empty, UInt<1>("h00"))
-// node T_128 = and(UInt<1>("h00"), io.enq.hostValid)
-// node T_129 = or(T_126, T_128)
-// io.deq.hostValid := T_129
-// node T_131 = eq(full, UInt<1>("h00"))
-// node T_133 = and(UInt<1>("h00"), io.deq.hostReady)
-// node T_134 = or(T_131, T_133)
-// io.enq.hostReady := T_134
-// infer accessor T_135 = ram[T_82]
-// wire T_149 : `TYPE
-// T_149 <> T_135
-// when maybe_flow :
-// T_149 <> io.enq.hostBits
-// skip
-// io.deq.hostBits <> T_149
-// node ptr_diff = subw(T_80, T_82)
-// node T_157 = and(maybe_full, ptr_match)
-// node T_158 = cat(T_157, ptr_diff)
-// io.count := T_158
-// """
- """
-circuit `NAME:
- module `NAME :
- input hostClock : Clock
- input hostReset : UInt<1>
- output io : {flip enq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, deq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, count : UInt<3>}
-
- io.count := UInt<1>("h00")
- `DEFAULT_ASSIGN
- io.deq.hostValid := UInt<1>("h00")
- io.enq.hostReady := UInt<1>("h00")
- cmem ram : `TYPE[4], hostClock
- reg T_404 : UInt<2>, hostClock, hostReset
- onreset T_404 := UInt<2>("h00")
- reg T_406 : UInt<2>, hostClock, hostReset
- onreset T_406 := UInt<2>("h00")
- reg maybe_full : UInt<1>, hostClock, hostReset
- onreset maybe_full := UInt<1>("h00")
- reg add_token_on_reset : UInt<1>, hostClock, hostReset
- onreset add_token_on_reset := UInt<1>("h01")
- add_token_on_reset := UInt<1>("h00")
- node ptr_match = eq(T_404, T_406)
- node T_414 = eq(maybe_full, UInt<1>("h00"))
- node empty = and(ptr_match, T_414)
- node full = and(ptr_match, maybe_full)
- node maybe_flow = and(UInt<1>("h00"), empty)
- node do_flow = and(maybe_flow, io.deq.hostReady)
- node T_420 = and(io.enq.hostReady, io.enq.hostValid)
- node T_422 = eq(do_flow, UInt<1>("h00"))
- node do_enq = and(T_420, T_422)
- node T_424 = and(io.deq.hostReady, io.deq.hostValid)
- node T_426 = eq(do_flow, UInt<1>("h00"))
- node do_deq = and(T_424, T_426)
- node T_428 = or(do_enq, add_token_on_reset)
- when T_428 :
- infer accessor T_443 = ram[T_404]
- T_443 := io.enq.hostBits
- node T_473 = eq(T_404, UInt<2>("h03"))
- node T_475 = and(UInt<1>("h00"), T_473)
- node T_478 = addw(T_404, UInt<1>("h01"))
- node T_479 = mux(T_475, UInt<1>("h00"), T_478)
- T_404 := T_479
- skip
- when do_deq :
- node T_481 = eq(T_406, UInt<2>("h03"))
- node T_483 = and(UInt<1>("h00"), T_481)
- node T_486 = addw(T_406, UInt<1>("h01"))
- node T_487 = mux(T_483, UInt<1>("h00"), T_486)
- T_406 := T_487
- skip
- node T_488 = neq(do_enq, do_deq)
- when T_488 :
- maybe_full := do_enq
- skip
- node T_490 = eq(empty, UInt<1>("h00"))
- node T_492 = and(UInt<1>("h00"), io.enq.hostValid)
- node T_493 = or(T_490, T_492)
- io.deq.hostValid := T_493
- node T_495 = eq(full, UInt<1>("h00"))
- node T_497 = and(UInt<1>("h00"), io.deq.hostReady)
- node T_498 = or(T_495, T_497)
- io.enq.hostReady := T_498
- infer accessor T_513 = ram[T_406]
- wire T_599 : `TYPE
- T_599 := T_513
- when maybe_flow :
- T_599 := io.enq.hostBits
- skip
- io.deq.hostBits := T_599
- node ptr_diff = subw(T_404, T_406)
- node T_629 = and(maybe_full, ptr_match)
- node T_630 = cat(T_629, ptr_diff)
- io.count := T_630
- """
- // Generate initial values
- val signals = enumerateMembers(tpe) map ( Seq("io", "deq", "hostBits") ++ _ )
- val defaultAssign = signals map { sig =>
- scopeSpaces + Connect(NoInfo, buildExp(sig), UIntValue(0, UnknownWidth)).serialize
- }
-
- val concreteQueue = templatedQueue.replaceAllLiterally("`NAME", name).
- replaceAllLiterally("`TYPE", tpe.serialize).
- replaceAllLiterally(scopeSpaces+"`DEFAULT_ASSIGN", defaultAssign.mkString("\n"))
-
- val ast = firrtl.Parser.parse(concreteQueue.split("\n"))
- ast.modules.head
- }
-
-}