diff options
| author | jackkoenig | 2015-12-08 12:39:49 -0800 |
|---|---|---|
| committer | jackkoenig | 2015-12-08 12:39:49 -0800 |
| commit | 2b04de4567e0ade01fdbfa6921fae91537180461 (patch) | |
| tree | 1c49094b17bc92c456d60d3c291430e1968da501 | |
| parent | f93f0b2c941960943d84c03ec4a9f0f0ba6c98b5 (diff) | |
Refactored MIDAS code into its own repo
| -rw-r--r-- | src/main/scala/firrtl/Driver.scala | 57 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Passes.scala | 13 | ||||
| -rw-r--r-- | src/main/scala/midas/Fame.scala | 367 | ||||
| -rw-r--r-- | src/main/scala/midas/Utils.scala | 329 |
4 files changed, 21 insertions, 745 deletions
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index 7887d87f..141a326a 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -7,7 +7,6 @@ import scala.io.Source import Utils._ import DebugUtils._ import Passes._ -import midas.Fame1 object Driver { @@ -39,55 +38,15 @@ object Driver logger.printlnDebug(ast) } - def toVerilogWithFame(input: String, output: String) - { - val logger = Logger(new PrintWriter(System.err, true)) - - val stanzaPreTransform = List("rem-spec-chars", "high-form-check", - "temp-elim", "to-working-ir", "resolve-kinds", "infer-types", - "resolve-genders", "check-genders", "check-kinds", "check-types", - "expand-accessors", "lower-to-ground", "inline-indexers", "infer-types", - "check-genders", "expand-whens", "infer-widths", "real-ir", "width-check", - "pad-widths", "const-prop", "split-expressions", "width-check", - "high-form-check", "low-form-check", "check-init") - val stanzaPostTransform = List("rem-spec-chars", "high-form-check", - "temp-elim", "to-working-ir", "resolve-kinds", "infer-types", - "resolve-genders", "check-genders", "check-kinds", "check-types", - "expand-accessors", "lower-to-ground", "inline-indexers", "infer-types", - "check-genders", "expand-whens", "infer-widths", "real-ir", "width-check", - "pad-widths", "const-prop", "split-expressions", "width-check", - "high-form-check", "low-form-check", "check-init") + // Should we just remove logger? + private def executePassesWithLogger(ast: Circuit, passes: Seq[Circuit => Circuit])(implicit logger: Logger): Circuit = { + if (passes.isEmpty) ast + else executePasses(passes.head(ast), passes.tail) + } - //// Don't lower - //val temp1 = genTempFilename(input) - //val ast = Parser.parse(Source.fromFile(input).getLines) - //val writer = new PrintWriter(new File(temp1)) - //val ast2 = fame1Transform(ast) - //writer.write(ast2.serialize()) - //writer.close() - - // Lower-to-Ground with Stanza FIRRTL - //val temp1 = genTempFilename(input) - val temp1 = input + ".1.tmp" - val preCmd = Seq("firrtl-stanza", "-i", input, "-o", temp1, "-b", "firrtl") ++ stanzaPreTransform.flatMap(Seq("-x", _)) - println(preCmd.mkString(" ")) - preCmd.! - - // Read in and execute infer-types - val ast = Parser.parse(input, Source.fromFile(temp1).getLines) - val ast2 = inferTypes(ast)(logger) - - // FAME-1 Transformation - //val temp2 = genTempFilename(input) - val temp2 = input + ".2.tmp" - val writer = new PrintWriter(new File(temp2)) - val ast3 = Fame1.transform(ast2) - writer.write(ast3.serialize()) - writer.close() - - val postCmd = Seq("firrtl-stanza", "-i", temp2, "-o", output, "-X", "verilog") - println(postCmd.mkString(" ")) - postCmd.! + def executePasses(ast: Circuit, passes: Seq[Circuit => Circuit]): Circuit = { + implicit val logger = Logger() // No logging + executePassesWithLogger(ast, passes) } private def verilog(input: String, output: String)(implicit logger: Logger) diff --git a/src/main/scala/firrtl/Passes.scala b/src/main/scala/firrtl/Passes.scala index 39e6b64e..f64d67bb 100644 --- a/src/main/scala/firrtl/Passes.scala +++ b/src/main/scala/firrtl/Passes.scala @@ -7,6 +7,19 @@ import Primops._ object Passes { + // TODO Perhaps we should get rid of Logger since this map would be nice + ////private val defaultLogger = Logger() + //private def mapNameToPass = Map[String, Circuit => Circuit] ( + // "infer-types" -> inferTypes + //) + def nameToPass(name: String): Circuit => Circuit = { + implicit val logger = Logger() // throw logging away + //mapNameToPass.getOrElse(name, throw new Exception("No Standard FIRRTL Pass of name " + name)) + name match { + case "infer-types" => inferTypes + } + } + private def toField(p: Port)(implicit logger: Logger): Field = { logger.trace(s"toField called on port ${p.serialize}") p.dir match { diff --git a/src/main/scala/midas/Fame.scala b/src/main/scala/midas/Fame.scala deleted file mode 100644 index aedadba5..00000000 --- a/src/main/scala/midas/Fame.scala +++ /dev/null @@ -1,367 +0,0 @@ - -package midas - -import Utils._ -import firrtl._ -import firrtl.Utils._ - -/** FAME-1 Transformation - * - * This pass takes a lowered-to-ground circuit and performs a FAME-1 (Decoupled) transformation - * to the circuit - * It does this by creating a simulation module wrapper around the circuit, if we can gate the - * clock, then there need be no modification to the target RTL, if we can't, then the target - * RTL will have to be modified (by adding a midasFire input and use this signal to gate - * register enable - * - * ALGORITHM - * 1. Flatten RTL - * a. Create NewTop - * b. Instantiate Top in NewTop - * c. Iteratively pull all sim tagged instances out of hierarchy to NewTop - * i. Move instance declaration from child module to parent module - * ii. Create io in child module corresponding to io of instance - * iii. Connect original instance io in child to new child io - * iv. Connect child io in parent to instance io - * v. Repeate until instance is in SimTop - * (if black box, repeat until completely removed from design) - * * Post-flattening invariants - * - No combinational logic on NewTop - * - All and only instances in NewTop are sim tagged - * - No black boxes in design - * 2. Simulation Transformation - * a. Perform Decoupled Transformation on every RTL Module - * i. Add targetFire signal input - * ii. Find all DefInst nodes and propogate targetFire to the instances - * iii. Find all registers and add when statement connecting regIn to regOut when !targetFire - * b. Iteratively transform each inst in Top (see example firrtl in dram_midas top dir) - * i. Create wrapper class - * ii. Create input and output (ready, valid) pairs for every other sim module this module connects to - * * Note that TopIO counts as a "sim module" - * iii. Create targetFire as AND of inputs.valid and outputs.ready - * iv. Connect targetFire to targetFire input of target rtl inst - * v. Connect target IO to wrapper IO, except connect target clock to simClock - * - * TODO - * - Is it okay to have ready signals for input queues depend on valid signals for those queues? This is generally bad - * - Change sequential memory read enable to work with targetFire - * - Implement Flatten RTL - * - Refactor important strings/naming to API (eg. "topIO" needs to be a constant defined somewhere or something) - * - Check that circuit is in LowFIRRTL? - * - * NOTES - * - How do we only transform the necessary modules? Should there be a MIDAS list of modules - * or something? - * * YES, this will be done by requiring the user to instantiate modules that should be split - * with something like: val module = new MIDASModule(class... etc.) - * - There cannot be nested DecoupledIO or ValidIO - * - How do output consumes tie in to MIDAS fire? If all of our outputs are not consumed - * in a given cycle, do we block midas$fire on the next cycle? Perhaps there should be - * a register for not having consumed all outputs last cycle - * - If our outputs are not consumed we also need to be sure not to consume out inputs, - * so the logic for this must depend on the previous cycle being consumed as well - * - We also need a way to determine the difference between the MIDAS modules and their - * connecting Queues, perhaps they should be MIDAS queues, which then perhaps prints - * out a listing of all queues so that they can be properly transformed - * * What do these MIDAS queues look like since we're enforcing true decoupled - * interfaces? - */ -object Fame1 { - - // Constants, common nodes, and common types used throughout - private type PortMap = Map[String, Seq[String]] - private val PortMap = Map[String, Seq[String]]() - private type ConnMap = Map[String, PortMap] - private val ConnMap = Map[String, PortMap]() - private type SimQMap = Map[(String, String), Module] - private val SimQMap = Map[(String, String), Module]() - - private val hostReady = Field("hostReady", Reverse, UIntType(IntWidth(1))) - private val hostValid = Field("hostValid", Default, UIntType(IntWidth(1))) - private val hostClock = Port(NoInfo, "hostClock", Input, ClockType) - private val hostReset = Port(NoInfo, "hostReset", Input, UIntType(IntWidth(1))) - private val targetFire = Port(NoInfo, "targetFire", Input, UIntType(IntWidth(1))) - - private def wrapName(name: String): String = s"SimWrap_${name}" - private def unwrapName(name: String): String = name.stripPrefix("SimWrap_") - private def queueName(src: String, dst: String): String = s"SimQueue_${src}_${dst}" - private def instName(name: String): String = s"inst_${name}" - private def unInstName(name: String): String = name.stripPrefix("inst_") - - private def genHostDecoupled(fields: Seq[Field]): BundleType = { - BundleType(Seq(hostReady, hostValid) :+ Field("hostBits", Default, BundleType(fields))) - } - - // ********** findPortConn ********** - // This takes lowFIRRTL top module that follows invariants described above and returns a connection Map - // of instanceName -> (instanctPorts -> portEndpoint) - // It honestly feels kind of brittle given it assumes there will be no intermediate nodes or anything in - // the way of direct connections between IO of module instances - private def processConnectExp(exp: Exp): (String, String) = { - val unsupportedExp = new Exception("Unsupported Exp for finding port connections: " + exp) - exp match { - case ref: Ref => ("topIO", ref.name) - case sub: Subfield => - sub.exp match { - case ref: Ref => (ref.name, sub.name) - case _ => throw unsupportedExp - } - case exp: Exp => throw unsupportedExp - } - } - private def processConnect(conn: Connect): ConnMap = { - val lhs = processConnectExp(conn.lhs) - val rhs = processConnectExp(conn.rhs) - Map(lhs._1 -> Map(lhs._2 -> Seq(rhs._1)), rhs._1 -> Map(rhs._2 -> Seq(lhs._1))).withDefaultValue(PortMap) - } - private def findPortConn(connMap: ConnMap, stmts: Seq[Stmt]): ConnMap = { - if (stmts.isEmpty) connMap - else { - stmts.head match { - case conn: Connect => { - val newConnMap = processConnect(conn) - findPortConn((connMap map { case (k,v) => - k -> merge(Seq(v, newConnMap(k))) { (_, v1, v2) => v1 ++ v2 }}), stmts.tail ) - } - case _ => findPortConn(connMap, stmts.tail) - } - } - } - private def findPortConn(top: Module, insts: Seq[DefInst]): ConnMap = { - val initConnMap = (insts map ( _.name -> PortMap )).toMap ++ Map("topIO" -> PortMap) - val topStmts = top.stmt match { - case b: Block => b.stmts - case s: Stmt => Seq(s) // This honestly shouldn't happen but let's be safe - } - findPortConn(initConnMap, topStmts) - } - - // Removes clocks from a portmap - private def scrubClocks(ports: Seq[Port], portMap: PortMap): PortMap = { - val clocks = ports filter (_.tpe == ClockType) map (_.name) - portMap filter { case (portName, _) => !clocks.contains(portName) } - } - - // ********** transformRTL ********** - // Takes an RTL module and give it targetFire input, propogates targetFire to all child instances, - // puts targetFire on regEnable for all registers - // TODO - // - Add smem support - private def transformRTL(m: Module): Module = { - val ports = m.ports :+ targetFire - val instProp = getDefInsts(m) map { inst => - Connect(NoInfo, buildExp(Seq(inst.name, targetFire.name)), buildExp(targetFire.name)) - } - val regEn = getDefRegs(m) map { reg => - When(NoInfo, DoPrimop(Not, Seq(buildExp(targetFire.name)), Seq(), UnknownType), - Connect(NoInfo, buildExp(reg.name), buildExp(reg.name)), EmptyStmt) - } - Module(m.info, m.name, ports, Block(m.stmt +: (instProp ++ regEn))) - } - - // ********** genWrapperModule ********** - // Generates FAME-1 Decoupled wrappers for simulation module instances - private def genWrapperModule(inst: DefInst, portMap: PortMap): Module = { - - val instIO = getDefInstType(inst) - val nameToField = (instIO.fields map (f => f.name -> f)).toMap - - val connections = (portMap map (_._2)).toSeq.flatten.distinct // modules this inst connects to - // Build simPort for each connecting module - // TODO This whole chunk really ought to be rewritten or made a function - val connPorts = connections map { c => - // Get ports that connect to this particular module as fields - val fields = (portMap filter (_._2.contains(c))).keySet.toSeq.sorted map (nameToField(_)) - val noClock = fields filter (_.tpe != ClockType) // Remove clock - val inputSet = noClock filter (_.dir == Reverse) map (f => Field(f.name, Default, f.tpe)) - val outputSet = noClock filter (_.dir == Default) - Port(inst.info, c, Output, BundleType( - (if (inputSet.isEmpty) Seq() - else Seq(Field("hostIn", Reverse, genHostDecoupled(inputSet))) - ) ++ - (if (outputSet.isEmpty) Seq() - else Seq(Field("hostOut", Default, genHostDecoupled(outputSet))) - ) - )) - } - val ports = hostClock +: hostReset +: connPorts // Add host and host reset - - // targetFire is signal to indicate when a simulation module can execute, this is indicated by all of its inputs - // being valid and all of its outputs being ready - val targetFireInputs = (connPorts map { port => - getFields(port) map { field => - field.dir match { - case Reverse => buildExp(Seq(port.name, field.name, hostValid.name)) - case Default => buildExp(Seq(port.name, field.name, hostReady.name)) - } - } - }).flatten - - val defTargetFire = DefNode(inst.info, targetFire.name, genPrimopReduce(And, targetFireInputs)) - val connectTargetFire = Connect(NoInfo, buildExp(Seq(inst.name, targetFire.name)), buildExp(targetFire.name)) - - // Only consume tokens when the module fires - // TODO is it bad to have the input readys depend on the input valid signals? - val inputsReady = (connPorts map { port => - getFields(port) filter (_.dir == Reverse) map { field => // filter to only take inputs - Connect(inst.info, buildExp(Seq(port.name, field.name, hostReady.name)), buildExp(targetFire.name)) - } - }).flatten - - // Outputs are valid on cycles where we fire - val outputsValid = (connPorts map { port => - getFields(port) filter (_.dir == Default) map { field => // filter to only take outputs - Connect(inst.info, buildExp(Seq(port.name, field.name, hostValid.name)), buildExp(targetFire.name)) - } - }).flatten - - // Connect up all of the IO of the RTL module to sim module IO, except clock which should be connected - // This currently assumes naming things that are also done above when generating connPorts - val connectedInstIOFields = instIO.fields filter(field => portMap.contains(field.name)) // skip unconnected IO - val instIOConnect = (connectedInstIOFields map { field => - field.tpe match { - case ClockType => Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)), - Ref(hostClock.name, ClockType))) - case _ => field.dir match { - case Default => portMap(field.name) map { endpoint => - Connect(inst.info, buildExp(Seq(endpoint, "hostOut", "hostBits", field.name)), - buildExp(Seq(inst.name, field.name))) - } - case Reverse => { - if (portMap(field.name).length > 1) - throw new Exception("It is illegal to have more than 1 connection to a single input" + field) - Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)), - buildExp(Seq(portMap(field.name).head, "hostIn", "hostBits", field.name)))) - } - } - } - }).flatten - val stmts = Block(Seq(defTargetFire) ++ inputsReady ++ outputsValid ++ Seq(inst) ++ - Seq(connectTargetFire) ++ instIOConnect) - - Module(inst.info, wrapName(inst.name), ports, stmts) - } - - // ********** generateSimQueues ********** - // Takes Seq of SimWrapper modules - // Returns Map of (src, dest) -> SimQueue - // To prevent duplicates, instead of creating a map with (src, dest) as the key, we could instead - // only one direction of the queue for each simport of each module. The only problem with this is - // it won't create queues for TopIO since that isn't a real module - private def generateSimQueues(wrappers: Seq[Module]): SimQMap = { - def rec(wrappers: Seq[Module], map: SimQMap): SimQMap = { - if (wrappers.isEmpty) map - else { - val w = wrappers.head - val name = unwrapName(w.name) - val newMap = (w.ports filter(isSimPort) map { port => - (splitSimPort(port) map { field => - val (src, dst) = if (field.dir == Default) (name, port.name) else (port.name, name) - if (map.contains((src, dst))) SimQMap - else Map((src, dst) -> buildSimQueue(queueName(src, dst), getHostBits(field).tpe)) - }).flatten.toMap - }).flatten.toMap - rec(wrappers.tail, map ++ newMap) - } - } - rec(wrappers, SimQMap) - } - - // ********** generateSimTop ********** - // Creates the Simulation Top module where all sim modules and sim queues are instantiated and connected - private def transformTopIO(ports: Seq[Port]): Seq[Port] = { - val noClock = ports filter (_.tpe != ClockType) - val inputs = noClock filter (_.dir == Input) map (_.toField.flip) // Flip because wrapping port is input - val outputs = noClock filter (_.dir == Output) map (_.toField) - - Seq(Port(NoInfo, "io", Output, BundleType(Seq(Field("hostIn", Reverse, genHostDecoupled(inputs)), - Field("hostOut", Default, genHostDecoupled(outputs)))))) - } - private def generateSimTop(wrappers: Seq[Module], simQueues: SimQMap, portMap: PortMap, rtlTop: Module): Module = { - val insts = (wrappers map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) }) ++ - (simQueues.values map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) }) - val connectClocks = (wrappers ++ simQueues.values) map { m => - Connect(NoInfo, buildExp(Seq(instName(m.name), hostClock.name)), buildExp(hostClock.name)) - } - val connectResets = (wrappers ++ simQueues.values) map { m => - Connect(NoInfo, buildExp(Seq(instName(m.name), hostReset.name)), buildExp(hostReset.name)) - } - // Connect queues to simulation modules (excludes IO) - val connectQueues = (simQueues map { case ((src, dst), queue) => - (if (src == "topIO") Seq() - else Seq(BulkConnect(NoInfo, buildExp(Seq(instName(queue.name), "io", "enq")), - buildExp(Seq(instName(wrapName(src)), dst, "hostOut")))) - ) ++ - (if (dst == "topIO") Seq() - else Seq(BulkConnect(NoInfo, buildExp(Seq(instName(wrapName(dst)), src, "hostIn")), - buildExp(Seq(instName(queue.name), "io", "deq")))) - ) - }).flatten - // Connect IO queues, Src means input, Dst means output (ie. the outside word is the Src or Dst) - val ioSrcQueues = (simQueues filter {case ((src, dst), _) => src == "topIO"} map {case (_, queue) => queue}).toSeq - val ioDstQueues = (simQueues filter {case ((src, dst), _) => dst == "topIO"} map {case (_, queue) => queue}).toSeq - val ioSrcSignals = rtlTop.ports filter (sig => sig.tpe != ClockType && sig.dir == Input) map (_.name) - val ioDstSignals = rtlTop.ports filter (sig => sig.tpe != ClockType && sig.dir == Output) map (_.name) - - val ioSrcQueueConnect = if (ioSrcQueues.length > 0) { - val readySignals = ioSrcQueues map (queue => buildExp(Seq(instName(queue.name), "io", "enq", hostReady.name))) - val validSignals = ioSrcQueues map (queue => buildExp(Seq(instName(queue.name), "io", "enq", hostValid.name))) - - (ioSrcSignals map { sig => - (portMap(sig) map { dst => - Connect(NoInfo, buildExp(Seq(instName(queueName("topIO", dst)), "io", "enq", "hostBits", sig)), - buildExp(Seq("io", "hostIn", "hostBits", sig))) - }) - }).flatten ++ - (validSignals map (sig => Connect(NoInfo, buildExp(sig), buildExp(Seq("io", "hostIn", hostValid.name))))) :+ - Connect(NoInfo, buildExp(Seq("io", "hostIn", hostReady.name)), genPrimopReduce(And, readySignals)) - } else Seq(EmptyStmt) - - val ioDstQueueConnect = if (ioDstQueues.length > 0) { - val readySignals = ioDstQueues map (queue => buildExp(Seq(instName(queue.name), "io", "deq", hostReady.name))) - val validSignals = ioDstQueues map (queue => buildExp(Seq(instName(queue.name), "io", "deq", hostValid.name))) - - (ioDstSignals map { sig => - (portMap(sig) map { src => - Connect(NoInfo, buildExp(Seq("io", "hostOut", "hostBits", sig)), - buildExp(Seq(instName(queueName(src, "topIO")), "io", "deq", "hostBits", sig))) - }) - }).flatten ++ - (readySignals map (sig => Connect(NoInfo, buildExp(sig), buildExp(Seq("io", "hostOut", hostReady.name))))) :+ - Connect(NoInfo, buildExp(Seq("io", "hostOut", hostValid.name)), genPrimopReduce(And, validSignals)) - } else Seq(EmptyStmt) - - val stmts = Block(insts ++ connectClocks ++ connectResets ++ connectQueues ++ ioSrcQueueConnect ++ ioDstQueueConnect) - val ports = Seq(hostClock, hostReset) ++ transformTopIO(rtlTop.ports) - Module(NoInfo, "SimTop", ports, stmts) - } - - // ********** transform ********** - // Perform FAME-1 Transformation for MIDAS - def transform(c: Circuit): Circuit = { - // We should check that the invariants mentioned above are true - val nameToModule = (c.modules map (m => m.name -> m))(collection.breakOut): Map[String, Module] - val top = nameToModule(c.name) - - val rtlModules = c.modules filter (_.name != top.name) map (transformRTL) - - val insts = getDefInsts(top) - - val portConn = findPortConn(top, insts) - - // Check that port Connections include all ports for each instance? - - val simWrappers = insts map (inst => genWrapperModule(inst, portConn(inst.name))) - - val simQueues = generateSimQueues(simWrappers) - - // Remove duplicate simWrapper and simQueue modules? - - val simTop = generateSimTop(simWrappers, simQueues, portConn("topIO"), top) - - val modules = rtlModules ++ simWrappers ++ simQueues.values.toSeq ++ Seq(simTop) - - Circuit(c.info, simTop.name, modules) - } - -} diff --git a/src/main/scala/midas/Utils.scala b/src/main/scala/midas/Utils.scala deleted file mode 100644 index 003f22bb..00000000 --- a/src/main/scala/midas/Utils.scala +++ /dev/null @@ -1,329 +0,0 @@ - -package midas - -import firrtl._ -import firrtl.Utils._ - -object Utils { - - // Merges a sequence of maps via the provided function f - // Taken from: https://groups.google.com/forum/#!topic/scala-user/HaQ4fVRjlnU - def merge[K, V](maps: Seq[Map[K, V]])(f: (K, V, V) => V): Map[K, V] = { - maps.foldLeft(Map.empty[K, V]) { case (merged, m) => - m.foldLeft(merged) { case (acc, (k, v)) => - acc.get(k) match { - case Some(existing) => acc.updated(k, f(k, existing, v)) - case None => acc.updated(k, v) - } - } - } - } - - // This doesn't work because of Type Erasure >.< - //private def getStmts[A <: Stmt](s: Stmt): Seq[A] = { - // s match { - // case a: A => Seq(a) - // case b: Block => b.stmts.map(getStmts[A]).flatten - // case _ => Seq() - // } - //} - //private def getStmts[A <: Stmt](m: Module): Seq[A] = getStmts[A](m.stmt) - - def getDefRegs(s: Stmt): Seq[DefReg] = { - s match { - case r: DefReg => Seq(r) - case b: Block => b.stmts.map(getDefRegs).flatten - case _ => Seq() - } - } - def getDefRegs(m: Module): Seq[DefReg] = getDefRegs(m.stmt) - - def getDefInsts(s: Stmt): Seq[DefInst] = { - s match { - case i: DefInst => Seq(i) - case b: Block => b.stmts.map(getDefInsts).flatten - case _ => Seq() - } - } - def getDefInsts(m: Module): Seq[DefInst] = getDefInsts(m.stmt) - - def getDefInstRef(inst: DefInst): Ref = { - inst.module match { - case ref: Ref => ref - case _ => throw new Exception("Invalid module expression for DefInst: " + inst.serialize) - } - } - - // DefInsts have an Expression for the module, this expression should be a reference this - // reference has a tpe that should be a bundle representing the IO of that module class - def getDefInstType(inst: DefInst): BundleType = { - val ref = getDefInstRef(inst) - ref.tpe match { - case b: BundleType => b - case _ => throw new Exception("Invalid reference type for DefInst: " + inst.serialize) - } - } - - def getModuleFromDefInst(nameToModule: Map[String, Module], inst: DefInst): Module = { - val instModule = getDefInstRef(inst) - if(!nameToModule.contains(instModule.name)) - throw new Exception(s"Module ${instModule.name} not found in circuit!") - else - nameToModule(instModule.name) - } - - // Takes a set of strings or ints and returns equivalent expression node - // Strings correspond to subfields/references, ints correspond to indexes - // eg. Seq(io, port, ready) => io.port.ready - // Seq(io, port, 5, valid) => io.port[5].valid - // Seq(3) => UInt("h3") - def buildExp(names: Seq[Any]): Exp = { - def rec(names: Seq[Any]): Exp = { - names.head match { - // Useful for adding on indexes or subfields - case head: Exp => head - // Int -> UInt/SInt/Index - case head: Int => - if( names.tail.isEmpty ) // Is the UInt/SInt inference good enough? - if( head > 0 ) UIntValue(head, UnknownWidth) else SIntValue(head, UnknownWidth) - else Index(rec(names.tail), head, UnknownType) - // String -> Ref/Subfield - case head: String => - if( names.tail.isEmpty ) Ref(head, UnknownType) - else Subfield(rec(names.tail), head, UnknownType) - case _ => throw new Exception("Invalid argument type to buildExp! " + names) - } - } - rec(names.reverse) // Let user specify in more natural format - } - def buildExp(name: Any): Exp = buildExp(Seq(name)) - - def genPrimopReduce(op: Primop, args: Seq[Exp]): Exp = { - if( args.length == 0 ) throw new Exception("genPrimopReduce called on empty sequence!") - else if( args.length == 1 ) args.head - else if( args.length == 2 ) DoPrimop(op, Seq(args.head, args.last), Seq(), UnknownType) - else DoPrimop(op, Seq(args.head, genPrimopReduce(op, args.tail)), Seq(), UnknownType) - } - - // Checks if a firrtl.Port matches the MIDAS SimPort pattern - // This currently just checks that the port is of type bundle with ONLY the members - // hostIn and/or hostOut with correct directions - def isSimPort(port: Port): Boolean = { - //println("isSimPort called on port " + port.serialize) - port.tpe match { - case b: BundleType => { - b.fields map { field => - if( field.name == "hostIn" ) field.dir == Reverse - else if( field.name == "hostOut" ) field.dir == Default - else false - } reduce ( _ & _ ) - } - case _ => false - } - } - - def splitSimPort(port: Port): Seq[Field] = { - try { - val b = port.tpe.asInstanceOf[BundleType] - Seq(b.fields.find(_.name == "hostIn"), b.fields.find(_.name == "hostOut")).flatten - } catch { - case e: Exception => throw new Exception("Invalid SimPort " + port.serialize) - } - } - - // From simulation host decoupled, return hostbits field - def getHostBits(field: Field): Field = { - try { - val b = field.tpe.asInstanceOf[BundleType] - b.fields.find(_.name == "hostBits").get - } catch { - case e: Exception => throw new Exception("Invalid SimField " + field.serialize) - } - } - - // For a port that is known to be of type BundleType, return the fields of that bundle - def getFields(port: Port): Seq[Field] = { - port.tpe match { - case b: BundleType => b.fields - case _ => throw new Exception("getFields called on invalid port " + port) - } - } - - // Recursively iterates through firrtl.Type returning sequence of names to address signals - // * Intended for use with recursive bundle types - def enumerateMembers(tpe: Type): Seq[Seq[Any]] = { - def rec(tpe: Type, path: Seq[Any], members: Seq[Seq[Any]]): Seq[Seq[Any]] = { - tpe match { - case b: BundleType => (b.fields map ( f => rec(f.tpe, path :+ f.name, members) )).flatten - case v: VectorType => (Seq.tabulate(v.size.toInt) ( i => rec(v.tpe, path :+ i, members) )).flatten - case _ => members :+ path - } - } - rec(tpe, Seq[Any](), Seq[Seq[Any]]()) - } - - // Queue - // TODO - // - Insert null tokens upon hostReset (or should this be elsewhere?) - def buildSimQueue(name: String, tpe: Type): Module = { - val scopeSpaces = " " * 4 // Spaces before lines in module scope, for default assignments - val templatedQueue = -// """ -// circuit `NAME: -// module `NAME : -// input hostClock : Clock -// input hostReset : UInt<1> -// output io : {flip enq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, deq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, count : UInt<3>} -// -// io.count := UInt<1>("h00") -// `DEFAULT_ASSIGN -// io.deq.hostValid := UInt<1>("h00") -// io.enq.hostReady := UInt<1>("h00") -// cmem ram : `TYPE[4], hostClock -// reg T_80 : UInt<2>, hostClock, hostReset -// onreset T_80 := UInt<2>("h00") -// reg T_82 : UInt<2>, hostClock, hostReset -// onreset T_82 := UInt<2>("h00") -// reg maybe_full : UInt<1>, hostClock, hostReset -// onreset maybe_full := UInt<1>("h00") -// node ptr_match = eq(T_80, T_82) -// node T_87 = eq(maybe_full, UInt<1>("h00")) -// node empty = and(ptr_match, T_87) -// node full = and(ptr_match, maybe_full) -// node maybe_flow = and(UInt<1>("h00"), empty) -// node do_flow = and(maybe_flow, io.deq.hostReady) -// node T_93 = and(io.enq.hostReady, io.enq.hostValid) -// node T_95 = eq(do_flow, UInt<1>("h00")) -// node do_enq = and(T_93, T_95) -// node T_97 = and(io.deq.hostReady, io.deq.hostValid) -// node T_99 = eq(do_flow, UInt<1>("h00")) -// node do_deq = and(T_97, T_99) -// when do_enq : -// infer accessor T_101 = ram[T_80] -// T_101 <> io.enq.hostBits -// node T_109 = eq(T_80, UInt<2>("h03")) -// node T_111 = and(UInt<1>("h00"), T_109) -// node T_114 = addw(T_80, UInt<1>("h01")) -// node T_115 = mux(T_111, UInt<1>("h00"), T_114) -// T_80 := T_115 -// skip -// when do_deq : -// node T_117 = eq(T_82, UInt<2>("h03")) -// node T_119 = and(UInt<1>("h00"), T_117) -// node T_122 = addw(T_82, UInt<1>("h01")) -// node T_123 = mux(T_119, UInt<1>("h00"), T_122) -// T_82 := T_123 -// skip -// node T_124 = neq(do_enq, do_deq) -// when T_124 : -// maybe_full := do_enq -// skip -// node T_126 = eq(empty, UInt<1>("h00")) -// node T_128 = and(UInt<1>("h00"), io.enq.hostValid) -// node T_129 = or(T_126, T_128) -// io.deq.hostValid := T_129 -// node T_131 = eq(full, UInt<1>("h00")) -// node T_133 = and(UInt<1>("h00"), io.deq.hostReady) -// node T_134 = or(T_131, T_133) -// io.enq.hostReady := T_134 -// infer accessor T_135 = ram[T_82] -// wire T_149 : `TYPE -// T_149 <> T_135 -// when maybe_flow : -// T_149 <> io.enq.hostBits -// skip -// io.deq.hostBits <> T_149 -// node ptr_diff = subw(T_80, T_82) -// node T_157 = and(maybe_full, ptr_match) -// node T_158 = cat(T_157, ptr_diff) -// io.count := T_158 -// """ - """ -circuit `NAME: - module `NAME : - input hostClock : Clock - input hostReset : UInt<1> - output io : {flip enq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, deq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, count : UInt<3>} - - io.count := UInt<1>("h00") - `DEFAULT_ASSIGN - io.deq.hostValid := UInt<1>("h00") - io.enq.hostReady := UInt<1>("h00") - cmem ram : `TYPE[4], hostClock - reg T_404 : UInt<2>, hostClock, hostReset - onreset T_404 := UInt<2>("h00") - reg T_406 : UInt<2>, hostClock, hostReset - onreset T_406 := UInt<2>("h00") - reg maybe_full : UInt<1>, hostClock, hostReset - onreset maybe_full := UInt<1>("h00") - reg add_token_on_reset : UInt<1>, hostClock, hostReset - onreset add_token_on_reset := UInt<1>("h01") - add_token_on_reset := UInt<1>("h00") - node ptr_match = eq(T_404, T_406) - node T_414 = eq(maybe_full, UInt<1>("h00")) - node empty = and(ptr_match, T_414) - node full = and(ptr_match, maybe_full) - node maybe_flow = and(UInt<1>("h00"), empty) - node do_flow = and(maybe_flow, io.deq.hostReady) - node T_420 = and(io.enq.hostReady, io.enq.hostValid) - node T_422 = eq(do_flow, UInt<1>("h00")) - node do_enq = and(T_420, T_422) - node T_424 = and(io.deq.hostReady, io.deq.hostValid) - node T_426 = eq(do_flow, UInt<1>("h00")) - node do_deq = and(T_424, T_426) - node T_428 = or(do_enq, add_token_on_reset) - when T_428 : - infer accessor T_443 = ram[T_404] - T_443 := io.enq.hostBits - node T_473 = eq(T_404, UInt<2>("h03")) - node T_475 = and(UInt<1>("h00"), T_473) - node T_478 = addw(T_404, UInt<1>("h01")) - node T_479 = mux(T_475, UInt<1>("h00"), T_478) - T_404 := T_479 - skip - when do_deq : - node T_481 = eq(T_406, UInt<2>("h03")) - node T_483 = and(UInt<1>("h00"), T_481) - node T_486 = addw(T_406, UInt<1>("h01")) - node T_487 = mux(T_483, UInt<1>("h00"), T_486) - T_406 := T_487 - skip - node T_488 = neq(do_enq, do_deq) - when T_488 : - maybe_full := do_enq - skip - node T_490 = eq(empty, UInt<1>("h00")) - node T_492 = and(UInt<1>("h00"), io.enq.hostValid) - node T_493 = or(T_490, T_492) - io.deq.hostValid := T_493 - node T_495 = eq(full, UInt<1>("h00")) - node T_497 = and(UInt<1>("h00"), io.deq.hostReady) - node T_498 = or(T_495, T_497) - io.enq.hostReady := T_498 - infer accessor T_513 = ram[T_406] - wire T_599 : `TYPE - T_599 := T_513 - when maybe_flow : - T_599 := io.enq.hostBits - skip - io.deq.hostBits := T_599 - node ptr_diff = subw(T_404, T_406) - node T_629 = and(maybe_full, ptr_match) - node T_630 = cat(T_629, ptr_diff) - io.count := T_630 - """ - // Generate initial values - val signals = enumerateMembers(tpe) map ( Seq("io", "deq", "hostBits") ++ _ ) - val defaultAssign = signals map { sig => - scopeSpaces + Connect(NoInfo, buildExp(sig), UIntValue(0, UnknownWidth)).serialize - } - - val concreteQueue = templatedQueue.replaceAllLiterally("`NAME", name). - replaceAllLiterally("`TYPE", tpe.serialize). - replaceAllLiterally(scopeSpaces+"`DEFAULT_ASSIGN", defaultAssign.mkString("\n")) - - val ast = firrtl.Parser.parse(concreteQueue.split("\n")) - ast.modules.head - } - -} |
