aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/midas/Fame.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/midas/Fame.scala')
-rw-r--r--src/main/scala/midas/Fame.scala220
1 files changed, 127 insertions, 93 deletions
diff --git a/src/main/scala/midas/Fame.scala b/src/main/scala/midas/Fame.scala
index f53ab56d..49cdb613 100644
--- a/src/main/scala/midas/Fame.scala
+++ b/src/main/scala/midas/Fame.scala
@@ -30,16 +30,22 @@ import firrtl.Utils._
* - All and only instances in NewTop are sim tagged
* - No black boxes in design
* 2. Simulation Transformation
- * a. Iteratively transform each inst in Top (see example firrtl in dram_midas top dir)
+ * a. Perform Decoupled Transformation on every RTL Module
+ * i. Add targetFire signal input
+ * ii. Find all DefInst nodes and propogate targetFire to the instances
+ * iii. Find all registers and add when statement connecting regIn to regOut when !targetFire
+ * b. Iteratively transform each inst in Top (see example firrtl in dram_midas top dir)
* i. Create wrapper class
* ii. Create input and output (ready, valid) pairs for every other sim module this module connects to
* * Note that TopIO counts as a "sim module"
- * iii. Create simFire as AND of inputs.valid and outputs.ready
- * iv. Create [target] simClock as AND of simFire and [host] clock
+ * iii. Create targetFire as AND of inputs.valid and outputs.ready
+ * iv. Connect targetFire to targetFire input of target rtl inst
* v. Connect target IO to wrapper IO, except connect target clock to simClock
*
* TODO
+ * - Change from clock gating to reg enable, dont' forget to change sequential memory read enable
* - Implement Flatten RTL
+ * - Refactor important strings/naming to API (eg. "topIO" needs to be a constant defined somewhere or something)
* - Check that circuit is in LowFIRRTL?
*
* NOTES
@@ -61,49 +67,35 @@ import firrtl.Utils._
*/
object Fame1 {
- private def getDefInsts(s: Stmt): Seq[DefInst] = {
- s match {
- case i: DefInst => Seq(i)
- case b: Block => b.stmts.map(getDefInsts).flatten
- case _ => Seq()
- }
- }
- private def getDefInsts(m: Module): Seq[DefInst] = getDefInsts(m.stmt)
+ // Constants, common nodes, and common types used throughout
+ private type PortMap = Map[String, Seq[String]]
+ private val PortMap = Map[String, Seq[String]]()
+ private type ConnMap = Map[String, PortMap]
+ private val ConnMap = Map[String, PortMap]()
+ private type SimQMap = Map[(String, String), Module]
+ private val SimQMap = Map[(String, String), Module]()
- private def getDefInstRef(inst: DefInst): Ref = {
- inst.module match {
- case ref: Ref => ref
- case _ => throw new Exception("Invalid module expression for DefInst: " + inst.serialize)
- }
- }
+ private val hostReady = Field("hostReady", Reverse, UIntType(IntWidth(1)))
+ private val hostValid = Field("hostValid", Default, UIntType(IntWidth(1)))
+ private val hostClock = Port(NoInfo, "hostClock", Input, ClockType)
+ private val hostReset = Port(NoInfo, "hostReset", Input, UIntType(IntWidth(1)))
+ private val targetFire = Port(NoInfo, "targetFire", Input, UIntType(IntWidth(1)))
- // DefInsts have an Expression for the module, this expression should be a reference this
- // reference has a tpe that should be a bundle representing the IO of that module class
- private def getDefInstType(inst: DefInst): BundleType = {
- val ref = getDefInstRef(inst)
- ref.tpe match {
- case b: BundleType => b
- case _ => throw new Exception("Invalid reference type for DefInst: " + inst.serialize)
- }
- }
+ private def wrapName(name: String): String = s"SimWrap_${name}"
+ private def unwrapName(name: String): String = name.stripPrefix("SimWrap_")
+ private def queueName(src: String, dst: String): String = s"SimQueue_${src}_${dst}"
+ private def instName(name: String): String = s"inst_${name}"
+ private def unInstName(name: String): String = name.stripPrefix("inst_")
- private def getModuleFromDefInst(nameToModule: Map[String, Module], inst: DefInst): Module = {
- val instModule = getDefInstRef(inst)
- if(!nameToModule.contains(instModule.name))
- throw new Exception(s"Module ${instModule.name} not found in circuit!")
- else
- nameToModule(instModule.name)
+ private def genHostDecoupled(fields: Seq[Field]): BundleType = {
+ BundleType(Seq(hostReady, hostValid) :+ Field("hostBits", Default, BundleType(fields)))
}
- // ***** findPortConn *****
+ // ********** findPortConn **********
// This takes lowFIRRTL top module that follows invariants described above and returns a connection Map
// of instanceName -> (instanctPorts -> portEndpoint)
// It honestly feels kind of brittle given it assumes there will be no intermediate nodes or anything in
// the way of direct connections between IO of module instances
- private type PortMap = Map[String, Seq[String]]
- private val PortMap = Map[String, Seq[String]]()
- private type ConnMap = Map[String, PortMap]
- private val ConnMap = Map[String, PortMap]()
private def processConnectExp(exp: Exp): (String, String) = {
val unsupportedExp = new Exception("Unsupported Exp for finding port connections: " + exp)
exp match {
@@ -143,27 +135,33 @@ object Fame1 {
findPortConn(initConnMap, topStmts)
}
- // ***** Name Translation functions to help make naming consistent and easy to change *****
- private def wrapName(name: String): String = s"SimWrap_${name}"
- private def unwrapName(name: String): String = name.stripPrefix("SimWrap_")
- private def queueName(src: String, dst: String): String = s"SimQueue_${src}_${dst}"
- private def instName(name: String): String = s"inst_${name}"
- private def unInstName(name: String): String = name.stripPrefix("inst_")
-
- // ***** genWrapperModule *****
- // Generates FAME-1 Decoupled wrappers for simulation module instances
- private val hostReady = Field("hostReady", Reverse, UIntType(IntWidth(1)))
- private val hostValid = Field("hostValid", Default, UIntType(IntWidth(1)))
- private val hostClock = Port(NoInfo, "hostClock", Input, ClockType)
- private val hostReset = Port(NoInfo, "hostReset", Input, UIntType(IntWidth(1)))
+ // Removes clocks from a portmap
+ private def scrubClocks(ports: Seq[Port], portMap: PortMap): PortMap = {
+ val clocks = ports filter (_.tpe == ClockType) map (_.name)
+ portMap filter { case (portName, _) => !clocks.contains(portName) }
+ }
- private def genHostDecoupled(fields: Seq[Field]): BundleType = {
- BundleType(Seq(hostReady, hostValid) :+ Field("hostBits", Default, BundleType(fields)))
+ // ********** transformRTL **********
+ // Takes an RTL module and give it targetFire input, propogates targetFire to all child instances,
+ // puts targetFire on regEnable for all registers
+ // TODO
+ // - Add smem support
+ private def transformRTL(m: Module): Module = {
+ val ports = m.ports :+ targetFire
+ val instProp = getDefInsts(m) map { inst =>
+ Connect(NoInfo, buildExp(Seq(inst.name, targetFire.name)), buildExp(targetFire.name))
+ }
+ val regEn = getDefRegs(m) map { reg =>
+ When(NoInfo, DoPrimop(Not, Seq(buildExp(targetFire.name)), Seq(), UnknownType),
+ Connect(NoInfo, buildExp(reg.name), buildExp(reg.name)), EmptyStmt)
+ }
+ Module(m.info, m.name, ports, Block(m.stmt +: (instProp ++ regEn)))
}
+ // ********** genWrapperModule **********
+ // Generates FAME-1 Decoupled wrappers for simulation module instances
private def genWrapperModule(inst: DefInst, portMap: PortMap): Module = {
- println(s"Wrapping ${inst.name}")
- println(portMap)
+
val instIO = getDefInstType(inst)
val nameToField = (instIO.fields map (f => f.name -> f)).toMap
@@ -181,7 +179,7 @@ object Fame1 {
else Seq(Field("hostIn", Reverse, genHostDecoupled(inputSet)))
) ++
(if (outputSet.isEmpty) Seq()
- else Seq(Field("hostOut", Default, genHostDecoupled(inputSet)))
+ else Seq(Field("hostOut", Default, genHostDecoupled(outputSet)))
)
))
}
@@ -198,11 +196,8 @@ object Fame1 {
}
}).flatten
- val targetFire = DefNode(inst.info, "targetFire", genPrimopReduce(And, targetFireInputs))
- // targetClock is the simple AND of targetFire and the hostClock so that the rtl module only executes when data
- // is available and outputs are ready
- val targetClock = DefNode(inst.info, "targetClock", DoPrimop(And,
- Seq(buildExp(targetFire.name), buildExp(hostClock.name)), Seq(), UnknownType))
+ val defTargetFire = DefNode(inst.info, targetFire.name, genPrimopReduce(And, targetFireInputs))
+ val connectTargetFire = Connect(NoInfo, buildExp(Seq(inst.name, targetFire.name)), buildExp(targetFire.name))
// As a simple RTL module, we're always ready
val inputsReady = (connPorts map { port =>
@@ -224,37 +219,33 @@ object Fame1 {
val instIOConnect = (connectedInstIOFields map { field =>
field.tpe match {
case ClockType => Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)),
- Ref(targetClock.name, ClockType)))
+ Ref(hostClock.name, ClockType)))
case _ => field.dir match {
case Default => portMap(field.name) map { endpoint =>
- Connect(inst.info, buildExp(Seq(endpoint, "hostOut", field.name)),
+ Connect(inst.info, buildExp(Seq(endpoint, "hostOut", "hostBits", field.name)),
buildExp(Seq(inst.name, field.name)))
}
case Reverse => {
if (portMap(field.name).length > 1)
throw new Exception("It is illegal to have more than 1 connection to a single input" + field)
Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)),
- buildExp(Seq(portMap(field.name).head, "hostIn", field.name))))
+ buildExp(Seq(portMap(field.name).head, "hostIn", "hostBits", field.name))))
}
}
}
}).flatten
- //val stmts = Block(Seq(simFire, simClock, inst) ++ inputsReady ++ outputsValid ++ instIOConnect)
- val stmts = Block(Seq(targetFire, targetClock) ++ inputsReady ++ outputsValid ++ Seq(inst) ++ instIOConnect)
+ val stmts = Block(Seq(defTargetFire) ++ inputsReady ++ outputsValid ++ Seq(inst) ++
+ Seq(connectTargetFire) ++ instIOConnect)
Module(inst.info, wrapName(inst.name), ports, stmts)
}
-
-
- // ***** generateSimQueues *****
+ // ********** generateSimQueues **********
// Takes Seq of SimWrapper modules
// Returns Map of (src, dest) -> SimQueue
// To prevent duplicates, instead of creating a map with (src, dest) as the key, we could instead
// only one direction of the queue for each simport of each module. The only problem with this is
// it won't create queues for TopIO since that isn't a real module
- private type SimQMap = Map[(String, String), Module]
- private val SimQMap = Map[(String, String), Module]()
private def generateSimQueues(wrappers: Seq[Module]): SimQMap = {
def rec(wrappers: Seq[Module], map: SimQMap): SimQMap = {
if (wrappers.isEmpty) map
@@ -274,58 +265,101 @@ object Fame1 {
rec(wrappers, SimQMap)
}
- // ***** generateSimTop *****
+ // ********** generateSimTop **********
// Creates the Simulation Top module where all sim modules and sim queues are instantiated and connected
- private def generateSimTop(wrappers: Seq[Module], simQueues: SimQMap, rtlTop: Module): Module = {
+ private def transformTopIO(ports: Seq[Port]): Seq[Port] = {
+ val noClock = ports filter (_.tpe != ClockType)
+ val inputs = noClock filter (_.dir == Input) map (_.toField.flip) // Flip because wrapping port is input
+ val outputs = noClock filter (_.dir == Output) map (_.toField)
+
+ Seq(Port(NoInfo, "io", Output, BundleType(Seq(Field("hostIn", Reverse, genHostDecoupled(inputs)),
+ Field("hostOut", Default, genHostDecoupled(outputs))))))
+ }
+ private def generateSimTop(wrappers: Seq[Module], simQueues: SimQMap, portMap: PortMap, rtlTop: Module): Module = {
val insts = (wrappers map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) }) ++
(simQueues.values map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) })
val connectClocks = (wrappers ++ simQueues.values) map { m =>
- Connect(NoInfo, buildExp(Seq(m.name, hostClock.name)), buildExp(hostClock.name))
+ Connect(NoInfo, buildExp(Seq(instName(m.name), hostClock.name)), buildExp(hostClock.name))
}
val connectResets = (wrappers ++ simQueues.values) map { m =>
- Connect(NoInfo, buildExp(Seq(m.name, hostReset.name)), buildExp(hostReset.name))
- }
- val connectQueues = simQueues map { case (key, queue) =>
- (key._1, key._2) match {
- //case (src, "topIO") => EmptyStmt
- //case ("topIO", dst) => EmptyStmt
- case (src, dst) => Block(Seq(BulkConnect(NoInfo, buildExp(Seq(queue.name, "io", "enq")),
- buildExp(Seq(instName(wrapName(src)), dst, "hostOut"))),
- BulkConnect(NoInfo, buildExp(Seq(queue.name, "io", "deq")),
- buildExp(Seq(instName(wrapName(dst)), src, "hostIn")))))
- }
+ Connect(NoInfo, buildExp(Seq(instName(m.name), hostReset.name)), buildExp(hostReset.name))
}
- val stmts = Block(insts ++ connectClocks ++ connectResets ++ connectQueues)
- val ports = Seq(hostClock, hostReset) ++ rtlTop.ports
+ // Connect queues to simulation modules (excludes IO)
+ val connectQueues = (simQueues map { case ((src, dst), queue) =>
+ (if (src == "topIO") Seq()
+ else Seq(BulkConnect(NoInfo, buildExp(Seq(instName(queue.name), "io", "enq")),
+ buildExp(Seq(instName(wrapName(src)), dst, "hostOut"))))
+ ) ++
+ (if (dst == "topIO") Seq()
+ else Seq(BulkConnect(NoInfo, buildExp(Seq(instName(wrapName(dst)), src, "hostIn")),
+ buildExp(Seq(instName(queue.name), "io", "deq"))))
+ )
+ }).flatten
+ // Connect IO queues, Src means input, Dst means output (ie. the outside word is the Src or Dst)
+ val ioSrcQueues = (simQueues filter {case ((src, dst), _) => src == "topIO"} map {case (_, queue) => queue}).toSeq
+ val ioDstQueues = (simQueues filter {case ((src, dst), _) => dst == "topIO"} map {case (_, queue) => queue}).toSeq
+ val ioSrcSignals = rtlTop.ports filter (sig => sig.tpe != ClockType && sig.dir == Input) map (_.name)
+ val ioDstSignals = rtlTop.ports filter (sig => sig.tpe != ClockType && sig.dir == Output) map (_.name)
+
+ val ioSrcQueueConnect = if (ioSrcQueues.length > 0) {
+ val readySignals = ioSrcQueues map (queue => buildExp(Seq(instName(queue.name), "io", "enq", hostReady.name)))
+ val validSignals = ioSrcQueues map (queue => buildExp(Seq(instName(queue.name), "io", "enq", hostValid.name)))
+
+ (ioSrcSignals map { sig =>
+ (portMap(sig) map { dst =>
+ Connect(NoInfo, buildExp(Seq(instName(queueName("topIO", dst)), "io", "enq", "hostBits", sig)),
+ buildExp(Seq("io", "hostIn", "hostBits", sig)))
+ })
+ }).flatten ++
+ (validSignals map (sig => Connect(NoInfo, buildExp(sig), buildExp(Seq("io", "hostIn", hostValid.name))))) :+
+ Connect(NoInfo, buildExp(Seq("io", "hostIn", hostReady.name)), genPrimopReduce(And, readySignals))
+ } else Seq(EmptyStmt)
+
+ val ioDstQueueConnect = if (ioDstQueues.length > 0) {
+ val readySignals = ioDstQueues map (queue => buildExp(Seq(instName(queue.name), "io", "deq", hostReady.name)))
+ val validSignals = ioDstQueues map (queue => buildExp(Seq(instName(queue.name), "io", "deq", hostValid.name)))
+
+ (ioDstSignals map { sig =>
+ (portMap(sig) map { src =>
+ Connect(NoInfo, buildExp(Seq("io", "hostOut", "hostBits", sig)),
+ buildExp(Seq(instName(queueName(src, "topIO")), "io", "deq", "hostBits", sig)))
+ })
+ }).flatten ++
+ (readySignals map (sig => Connect(NoInfo, buildExp(sig), buildExp(Seq("io", "hostOut", hostReady.name))))) :+
+ Connect(NoInfo, buildExp(Seq("io", "hostOut", hostValid.name)), genPrimopReduce(And, validSignals))
+ } else Seq(EmptyStmt)
+
+ val stmts = Block(insts ++ connectClocks ++ connectResets ++ connectQueues ++ ioSrcQueueConnect ++ ioDstQueueConnect)
+ val ports = Seq(hostClock, hostReset) ++ transformTopIO(rtlTop.ports)
Module(NoInfo, "SimTop", ports, stmts)
}
- // ***** transform *****
+ // ********** transform **********
// Perform FAME-1 Transformation for MIDAS
def transform(c: Circuit): Circuit = {
// We should check that the invariants mentioned above are true
val nameToModule = (c.modules map (m => m.name -> m))(collection.breakOut): Map[String, Module]
val top = nameToModule(c.name)
+ val rtlModules = c.modules filter (_.name != top.name) map (transformRTL)
+
val insts = getDefInsts(top)
val portConn = findPortConn(top, insts)
+
// Check that port Connections include all ports for each instance?
- println(portConn)
val simWrappers = insts map (inst => genWrapperModule(inst, portConn(inst.name)))
val simQueues = generateSimQueues(simWrappers)
- // Remove duplicate simWrapper and simQueue modules
+ // Remove duplicate simWrapper and simQueue modules?
- val simTop = generateSimTop(simWrappers, simQueues, top)
+ val simTop = generateSimTop(simWrappers, simQueues, portConn("topIO"), top)
- val modules = (c.modules filter (_.name != top.name)) ++ simWrappers ++ simQueues.values.toSeq ++ Seq(simTop)
+ val modules = rtlModules ++ simWrappers ++ simQueues.values.toSeq ++ Seq(simTop)
- println(simTop.serialize)
Circuit(c.info, simTop.name, modules)
- //c
}
}