diff options
Diffstat (limited to 'src/main/scala/midas/Fame.scala')
| -rw-r--r-- | src/main/scala/midas/Fame.scala | 220 |
1 files changed, 127 insertions, 93 deletions
diff --git a/src/main/scala/midas/Fame.scala b/src/main/scala/midas/Fame.scala index f53ab56d..49cdb613 100644 --- a/src/main/scala/midas/Fame.scala +++ b/src/main/scala/midas/Fame.scala @@ -30,16 +30,22 @@ import firrtl.Utils._ * - All and only instances in NewTop are sim tagged * - No black boxes in design * 2. Simulation Transformation - * a. Iteratively transform each inst in Top (see example firrtl in dram_midas top dir) + * a. Perform Decoupled Transformation on every RTL Module + * i. Add targetFire signal input + * ii. Find all DefInst nodes and propogate targetFire to the instances + * iii. Find all registers and add when statement connecting regIn to regOut when !targetFire + * b. Iteratively transform each inst in Top (see example firrtl in dram_midas top dir) * i. Create wrapper class * ii. Create input and output (ready, valid) pairs for every other sim module this module connects to * * Note that TopIO counts as a "sim module" - * iii. Create simFire as AND of inputs.valid and outputs.ready - * iv. Create [target] simClock as AND of simFire and [host] clock + * iii. Create targetFire as AND of inputs.valid and outputs.ready + * iv. Connect targetFire to targetFire input of target rtl inst * v. Connect target IO to wrapper IO, except connect target clock to simClock * * TODO + * - Change from clock gating to reg enable, dont' forget to change sequential memory read enable * - Implement Flatten RTL + * - Refactor important strings/naming to API (eg. "topIO" needs to be a constant defined somewhere or something) * - Check that circuit is in LowFIRRTL? * * NOTES @@ -61,49 +67,35 @@ import firrtl.Utils._ */ object Fame1 { - private def getDefInsts(s: Stmt): Seq[DefInst] = { - s match { - case i: DefInst => Seq(i) - case b: Block => b.stmts.map(getDefInsts).flatten - case _ => Seq() - } - } - private def getDefInsts(m: Module): Seq[DefInst] = getDefInsts(m.stmt) + // Constants, common nodes, and common types used throughout + private type PortMap = Map[String, Seq[String]] + private val PortMap = Map[String, Seq[String]]() + private type ConnMap = Map[String, PortMap] + private val ConnMap = Map[String, PortMap]() + private type SimQMap = Map[(String, String), Module] + private val SimQMap = Map[(String, String), Module]() - private def getDefInstRef(inst: DefInst): Ref = { - inst.module match { - case ref: Ref => ref - case _ => throw new Exception("Invalid module expression for DefInst: " + inst.serialize) - } - } + private val hostReady = Field("hostReady", Reverse, UIntType(IntWidth(1))) + private val hostValid = Field("hostValid", Default, UIntType(IntWidth(1))) + private val hostClock = Port(NoInfo, "hostClock", Input, ClockType) + private val hostReset = Port(NoInfo, "hostReset", Input, UIntType(IntWidth(1))) + private val targetFire = Port(NoInfo, "targetFire", Input, UIntType(IntWidth(1))) - // DefInsts have an Expression for the module, this expression should be a reference this - // reference has a tpe that should be a bundle representing the IO of that module class - private def getDefInstType(inst: DefInst): BundleType = { - val ref = getDefInstRef(inst) - ref.tpe match { - case b: BundleType => b - case _ => throw new Exception("Invalid reference type for DefInst: " + inst.serialize) - } - } + private def wrapName(name: String): String = s"SimWrap_${name}" + private def unwrapName(name: String): String = name.stripPrefix("SimWrap_") + private def queueName(src: String, dst: String): String = s"SimQueue_${src}_${dst}" + private def instName(name: String): String = s"inst_${name}" + private def unInstName(name: String): String = name.stripPrefix("inst_") - private def getModuleFromDefInst(nameToModule: Map[String, Module], inst: DefInst): Module = { - val instModule = getDefInstRef(inst) - if(!nameToModule.contains(instModule.name)) - throw new Exception(s"Module ${instModule.name} not found in circuit!") - else - nameToModule(instModule.name) + private def genHostDecoupled(fields: Seq[Field]): BundleType = { + BundleType(Seq(hostReady, hostValid) :+ Field("hostBits", Default, BundleType(fields))) } - // ***** findPortConn ***** + // ********** findPortConn ********** // This takes lowFIRRTL top module that follows invariants described above and returns a connection Map // of instanceName -> (instanctPorts -> portEndpoint) // It honestly feels kind of brittle given it assumes there will be no intermediate nodes or anything in // the way of direct connections between IO of module instances - private type PortMap = Map[String, Seq[String]] - private val PortMap = Map[String, Seq[String]]() - private type ConnMap = Map[String, PortMap] - private val ConnMap = Map[String, PortMap]() private def processConnectExp(exp: Exp): (String, String) = { val unsupportedExp = new Exception("Unsupported Exp for finding port connections: " + exp) exp match { @@ -143,27 +135,33 @@ object Fame1 { findPortConn(initConnMap, topStmts) } - // ***** Name Translation functions to help make naming consistent and easy to change ***** - private def wrapName(name: String): String = s"SimWrap_${name}" - private def unwrapName(name: String): String = name.stripPrefix("SimWrap_") - private def queueName(src: String, dst: String): String = s"SimQueue_${src}_${dst}" - private def instName(name: String): String = s"inst_${name}" - private def unInstName(name: String): String = name.stripPrefix("inst_") - - // ***** genWrapperModule ***** - // Generates FAME-1 Decoupled wrappers for simulation module instances - private val hostReady = Field("hostReady", Reverse, UIntType(IntWidth(1))) - private val hostValid = Field("hostValid", Default, UIntType(IntWidth(1))) - private val hostClock = Port(NoInfo, "hostClock", Input, ClockType) - private val hostReset = Port(NoInfo, "hostReset", Input, UIntType(IntWidth(1))) + // Removes clocks from a portmap + private def scrubClocks(ports: Seq[Port], portMap: PortMap): PortMap = { + val clocks = ports filter (_.tpe == ClockType) map (_.name) + portMap filter { case (portName, _) => !clocks.contains(portName) } + } - private def genHostDecoupled(fields: Seq[Field]): BundleType = { - BundleType(Seq(hostReady, hostValid) :+ Field("hostBits", Default, BundleType(fields))) + // ********** transformRTL ********** + // Takes an RTL module and give it targetFire input, propogates targetFire to all child instances, + // puts targetFire on regEnable for all registers + // TODO + // - Add smem support + private def transformRTL(m: Module): Module = { + val ports = m.ports :+ targetFire + val instProp = getDefInsts(m) map { inst => + Connect(NoInfo, buildExp(Seq(inst.name, targetFire.name)), buildExp(targetFire.name)) + } + val regEn = getDefRegs(m) map { reg => + When(NoInfo, DoPrimop(Not, Seq(buildExp(targetFire.name)), Seq(), UnknownType), + Connect(NoInfo, buildExp(reg.name), buildExp(reg.name)), EmptyStmt) + } + Module(m.info, m.name, ports, Block(m.stmt +: (instProp ++ regEn))) } + // ********** genWrapperModule ********** + // Generates FAME-1 Decoupled wrappers for simulation module instances private def genWrapperModule(inst: DefInst, portMap: PortMap): Module = { - println(s"Wrapping ${inst.name}") - println(portMap) + val instIO = getDefInstType(inst) val nameToField = (instIO.fields map (f => f.name -> f)).toMap @@ -181,7 +179,7 @@ object Fame1 { else Seq(Field("hostIn", Reverse, genHostDecoupled(inputSet))) ) ++ (if (outputSet.isEmpty) Seq() - else Seq(Field("hostOut", Default, genHostDecoupled(inputSet))) + else Seq(Field("hostOut", Default, genHostDecoupled(outputSet))) ) )) } @@ -198,11 +196,8 @@ object Fame1 { } }).flatten - val targetFire = DefNode(inst.info, "targetFire", genPrimopReduce(And, targetFireInputs)) - // targetClock is the simple AND of targetFire and the hostClock so that the rtl module only executes when data - // is available and outputs are ready - val targetClock = DefNode(inst.info, "targetClock", DoPrimop(And, - Seq(buildExp(targetFire.name), buildExp(hostClock.name)), Seq(), UnknownType)) + val defTargetFire = DefNode(inst.info, targetFire.name, genPrimopReduce(And, targetFireInputs)) + val connectTargetFire = Connect(NoInfo, buildExp(Seq(inst.name, targetFire.name)), buildExp(targetFire.name)) // As a simple RTL module, we're always ready val inputsReady = (connPorts map { port => @@ -224,37 +219,33 @@ object Fame1 { val instIOConnect = (connectedInstIOFields map { field => field.tpe match { case ClockType => Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)), - Ref(targetClock.name, ClockType))) + Ref(hostClock.name, ClockType))) case _ => field.dir match { case Default => portMap(field.name) map { endpoint => - Connect(inst.info, buildExp(Seq(endpoint, "hostOut", field.name)), + Connect(inst.info, buildExp(Seq(endpoint, "hostOut", "hostBits", field.name)), buildExp(Seq(inst.name, field.name))) } case Reverse => { if (portMap(field.name).length > 1) throw new Exception("It is illegal to have more than 1 connection to a single input" + field) Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)), - buildExp(Seq(portMap(field.name).head, "hostIn", field.name)))) + buildExp(Seq(portMap(field.name).head, "hostIn", "hostBits", field.name)))) } } } }).flatten - //val stmts = Block(Seq(simFire, simClock, inst) ++ inputsReady ++ outputsValid ++ instIOConnect) - val stmts = Block(Seq(targetFire, targetClock) ++ inputsReady ++ outputsValid ++ Seq(inst) ++ instIOConnect) + val stmts = Block(Seq(defTargetFire) ++ inputsReady ++ outputsValid ++ Seq(inst) ++ + Seq(connectTargetFire) ++ instIOConnect) Module(inst.info, wrapName(inst.name), ports, stmts) } - - - // ***** generateSimQueues ***** + // ********** generateSimQueues ********** // Takes Seq of SimWrapper modules // Returns Map of (src, dest) -> SimQueue // To prevent duplicates, instead of creating a map with (src, dest) as the key, we could instead // only one direction of the queue for each simport of each module. The only problem with this is // it won't create queues for TopIO since that isn't a real module - private type SimQMap = Map[(String, String), Module] - private val SimQMap = Map[(String, String), Module]() private def generateSimQueues(wrappers: Seq[Module]): SimQMap = { def rec(wrappers: Seq[Module], map: SimQMap): SimQMap = { if (wrappers.isEmpty) map @@ -274,58 +265,101 @@ object Fame1 { rec(wrappers, SimQMap) } - // ***** generateSimTop ***** + // ********** generateSimTop ********** // Creates the Simulation Top module where all sim modules and sim queues are instantiated and connected - private def generateSimTop(wrappers: Seq[Module], simQueues: SimQMap, rtlTop: Module): Module = { + private def transformTopIO(ports: Seq[Port]): Seq[Port] = { + val noClock = ports filter (_.tpe != ClockType) + val inputs = noClock filter (_.dir == Input) map (_.toField.flip) // Flip because wrapping port is input + val outputs = noClock filter (_.dir == Output) map (_.toField) + + Seq(Port(NoInfo, "io", Output, BundleType(Seq(Field("hostIn", Reverse, genHostDecoupled(inputs)), + Field("hostOut", Default, genHostDecoupled(outputs)))))) + } + private def generateSimTop(wrappers: Seq[Module], simQueues: SimQMap, portMap: PortMap, rtlTop: Module): Module = { val insts = (wrappers map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) }) ++ (simQueues.values map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) }) val connectClocks = (wrappers ++ simQueues.values) map { m => - Connect(NoInfo, buildExp(Seq(m.name, hostClock.name)), buildExp(hostClock.name)) + Connect(NoInfo, buildExp(Seq(instName(m.name), hostClock.name)), buildExp(hostClock.name)) } val connectResets = (wrappers ++ simQueues.values) map { m => - Connect(NoInfo, buildExp(Seq(m.name, hostReset.name)), buildExp(hostReset.name)) - } - val connectQueues = simQueues map { case (key, queue) => - (key._1, key._2) match { - //case (src, "topIO") => EmptyStmt - //case ("topIO", dst) => EmptyStmt - case (src, dst) => Block(Seq(BulkConnect(NoInfo, buildExp(Seq(queue.name, "io", "enq")), - buildExp(Seq(instName(wrapName(src)), dst, "hostOut"))), - BulkConnect(NoInfo, buildExp(Seq(queue.name, "io", "deq")), - buildExp(Seq(instName(wrapName(dst)), src, "hostIn"))))) - } + Connect(NoInfo, buildExp(Seq(instName(m.name), hostReset.name)), buildExp(hostReset.name)) } - val stmts = Block(insts ++ connectClocks ++ connectResets ++ connectQueues) - val ports = Seq(hostClock, hostReset) ++ rtlTop.ports + // Connect queues to simulation modules (excludes IO) + val connectQueues = (simQueues map { case ((src, dst), queue) => + (if (src == "topIO") Seq() + else Seq(BulkConnect(NoInfo, buildExp(Seq(instName(queue.name), "io", "enq")), + buildExp(Seq(instName(wrapName(src)), dst, "hostOut")))) + ) ++ + (if (dst == "topIO") Seq() + else Seq(BulkConnect(NoInfo, buildExp(Seq(instName(wrapName(dst)), src, "hostIn")), + buildExp(Seq(instName(queue.name), "io", "deq")))) + ) + }).flatten + // Connect IO queues, Src means input, Dst means output (ie. the outside word is the Src or Dst) + val ioSrcQueues = (simQueues filter {case ((src, dst), _) => src == "topIO"} map {case (_, queue) => queue}).toSeq + val ioDstQueues = (simQueues filter {case ((src, dst), _) => dst == "topIO"} map {case (_, queue) => queue}).toSeq + val ioSrcSignals = rtlTop.ports filter (sig => sig.tpe != ClockType && sig.dir == Input) map (_.name) + val ioDstSignals = rtlTop.ports filter (sig => sig.tpe != ClockType && sig.dir == Output) map (_.name) + + val ioSrcQueueConnect = if (ioSrcQueues.length > 0) { + val readySignals = ioSrcQueues map (queue => buildExp(Seq(instName(queue.name), "io", "enq", hostReady.name))) + val validSignals = ioSrcQueues map (queue => buildExp(Seq(instName(queue.name), "io", "enq", hostValid.name))) + + (ioSrcSignals map { sig => + (portMap(sig) map { dst => + Connect(NoInfo, buildExp(Seq(instName(queueName("topIO", dst)), "io", "enq", "hostBits", sig)), + buildExp(Seq("io", "hostIn", "hostBits", sig))) + }) + }).flatten ++ + (validSignals map (sig => Connect(NoInfo, buildExp(sig), buildExp(Seq("io", "hostIn", hostValid.name))))) :+ + Connect(NoInfo, buildExp(Seq("io", "hostIn", hostReady.name)), genPrimopReduce(And, readySignals)) + } else Seq(EmptyStmt) + + val ioDstQueueConnect = if (ioDstQueues.length > 0) { + val readySignals = ioDstQueues map (queue => buildExp(Seq(instName(queue.name), "io", "deq", hostReady.name))) + val validSignals = ioDstQueues map (queue => buildExp(Seq(instName(queue.name), "io", "deq", hostValid.name))) + + (ioDstSignals map { sig => + (portMap(sig) map { src => + Connect(NoInfo, buildExp(Seq("io", "hostOut", "hostBits", sig)), + buildExp(Seq(instName(queueName(src, "topIO")), "io", "deq", "hostBits", sig))) + }) + }).flatten ++ + (readySignals map (sig => Connect(NoInfo, buildExp(sig), buildExp(Seq("io", "hostOut", hostReady.name))))) :+ + Connect(NoInfo, buildExp(Seq("io", "hostOut", hostValid.name)), genPrimopReduce(And, validSignals)) + } else Seq(EmptyStmt) + + val stmts = Block(insts ++ connectClocks ++ connectResets ++ connectQueues ++ ioSrcQueueConnect ++ ioDstQueueConnect) + val ports = Seq(hostClock, hostReset) ++ transformTopIO(rtlTop.ports) Module(NoInfo, "SimTop", ports, stmts) } - // ***** transform ***** + // ********** transform ********** // Perform FAME-1 Transformation for MIDAS def transform(c: Circuit): Circuit = { // We should check that the invariants mentioned above are true val nameToModule = (c.modules map (m => m.name -> m))(collection.breakOut): Map[String, Module] val top = nameToModule(c.name) + val rtlModules = c.modules filter (_.name != top.name) map (transformRTL) + val insts = getDefInsts(top) val portConn = findPortConn(top, insts) + // Check that port Connections include all ports for each instance? - println(portConn) val simWrappers = insts map (inst => genWrapperModule(inst, portConn(inst.name))) val simQueues = generateSimQueues(simWrappers) - // Remove duplicate simWrapper and simQueue modules + // Remove duplicate simWrapper and simQueue modules? - val simTop = generateSimTop(simWrappers, simQueues, top) + val simTop = generateSimTop(simWrappers, simQueues, portConn("topIO"), top) - val modules = (c.modules filter (_.name != top.name)) ++ simWrappers ++ simQueues.values.toSeq ++ Seq(simTop) + val modules = rtlModules ++ simWrappers ++ simQueues.values.toSeq ++ Seq(simTop) - println(simTop.serialize) Circuit(c.info, simTop.name, modules) - //c } } |
