diff options
| author | jackkoenig | 2015-12-06 00:36:12 -0800 |
|---|---|---|
| committer | jackkoenig | 2015-12-06 00:36:12 -0800 |
| commit | c5cac5227cd164b17f2a6f02227a71dc89f8cde4 (patch) | |
| tree | 1f6d30b64a58103574bacb770bbc307f8d1e4bbe /src/main/scala/midas | |
| parent | e8ac783706cca1f7ee65d799b5d8be445b6a5c5d (diff) | |
Working on generating SimTop, need to figure out how to split the top-level IO between the sim modules.
Diffstat (limited to 'src/main/scala/midas')
| -rw-r--r-- | src/main/scala/midas/Fame.scala | 131 | ||||
| -rw-r--r-- | src/main/scala/midas/Utils.scala | 69 |
2 files changed, 129 insertions, 71 deletions
diff --git a/src/main/scala/midas/Fame.scala b/src/main/scala/midas/Fame.scala index 169193f0..f53ab56d 100644 --- a/src/main/scala/midas/Fame.scala +++ b/src/main/scala/midas/Fame.scala @@ -100,8 +100,8 @@ object Fame1 { // of instanceName -> (instanctPorts -> portEndpoint) // It honestly feels kind of brittle given it assumes there will be no intermediate nodes or anything in // the way of direct connections between IO of module instances - private type PortMap = Map[String, String] - private val PortMap = Map[String, String]() + private type PortMap = Map[String, Seq[String]] + private val PortMap = Map[String, Seq[String]]() private type ConnMap = Map[String, PortMap] private val ConnMap = Map[String, PortMap]() private def processConnectExp(exp: Exp): (String, String) = { @@ -119,7 +119,7 @@ object Fame1 { private def processConnect(conn: Connect): ConnMap = { val lhs = processConnectExp(conn.lhs) val rhs = processConnectExp(conn.rhs) - Map(lhs._1 -> Map(lhs._2 -> rhs._1), rhs._1 -> Map(rhs._2 -> lhs._1)).withDefaultValue(PortMap) + Map(lhs._1 -> Map(lhs._2 -> Seq(rhs._1)), rhs._1 -> Map(rhs._2 -> Seq(lhs._1))).withDefaultValue(PortMap) } private def findPortConn(connMap: ConnMap, stmts: Seq[Stmt]): ConnMap = { if (stmts.isEmpty) connMap @@ -127,7 +127,8 @@ object Fame1 { stmts.head match { case conn: Connect => { val newConnMap = processConnect(conn) - findPortConn(connMap map {case (k,v) => k -> (v ++ newConnMap(k)) }, stmts.tail) + findPortConn((connMap map { case (k,v) => + k -> merge(Seq(v, newConnMap(k))) { (_, v1, v2) => v1 ++ v2 }}), stmts.tail ) } case _ => findPortConn(connMap, stmts.tail) } @@ -142,14 +143,23 @@ object Fame1 { findPortConn(initConnMap, topStmts) } - private def wrapName(name: String): String = "SimWrap_" + name + // ***** Name Translation functions to help make naming consistent and easy to change ***** + private def wrapName(name: String): String = s"SimWrap_${name}" private def unwrapName(name: String): String = name.stripPrefix("SimWrap_") + private def queueName(src: String, dst: String): String = s"SimQueue_${src}_${dst}" + private def instName(name: String): String = s"inst_${name}" + private def unInstName(name: String): String = name.stripPrefix("inst_") // ***** genWrapperModule ***** // Generates FAME-1 Decoupled wrappers for simulation module instances private val hostReady = Field("hostReady", Reverse, UIntType(IntWidth(1))) private val hostValid = Field("hostValid", Default, UIntType(IntWidth(1))) private val hostClock = Port(NoInfo, "hostClock", Input, ClockType) + private val hostReset = Port(NoInfo, "hostReset", Input, UIntType(IntWidth(1))) + + private def genHostDecoupled(fields: Seq[Field]): BundleType = { + BundleType(Seq(hostReady, hostValid) :+ Field("hostBits", Default, BundleType(fields))) + } private def genWrapperModule(inst: DefInst, portMap: PortMap): Module = { println(s"Wrapping ${inst.name}") @@ -157,29 +167,25 @@ object Fame1 { val instIO = getDefInstType(inst) val nameToField = (instIO.fields map (f => f.name -> f)).toMap - val connections = (portMap map(_._2)).toSeq.distinct // modules this inst connects to + val connections = (portMap map (_._2)).toSeq.flatten.distinct // modules this inst connects to // Build simPort for each connecting module // TODO This whole chunk really ought to be rewritten or made a function val connPorts = connections map { c => // Get ports that connect to this particular module as fields - val fields = (portMap filter (_._2 == c)).keySet.toSeq.sorted map (nameToField(_)) + val fields = (portMap filter (_._2.contains(c))).keySet.toSeq.sorted map (nameToField(_)) val noClock = fields filter (_.tpe != ClockType) // Remove clock val inputSet = noClock filter (_.dir == Reverse) map (f => Field(f.name, Default, f.tpe)) val outputSet = noClock filter (_.dir == Default) Port(inst.info, c, Output, BundleType( (if (inputSet.isEmpty) Seq() - else - Seq(Field("hostIn", Reverse, BundleType(Seq(hostReady, hostValid) :+ - Field("hostBits", Default, BundleType(inputSet))))) + else Seq(Field("hostIn", Reverse, genHostDecoupled(inputSet))) ) ++ (if (outputSet.isEmpty) Seq() - else - Seq(Field("hostOut", Default, BundleType(Seq(hostReady, hostValid) :+ - Field("hostBits", Default, BundleType(outputSet))))) + else Seq(Field("hostOut", Default, genHostDecoupled(inputSet))) ) )) } - val ports = hostClock +: connPorts // Add host Clock + val ports = hostClock +: hostReset +: connPorts // Add host and host reset // targetFire is signal to indicate when a simulation module can execute, this is indicated by all of its inputs // being valid and all of its outputs being ready @@ -215,17 +221,24 @@ object Fame1 { // Connect up all of the IO of the RTL module to sim module IO, except clock which should be connected // This currently assumes naming things that are also done above when generating connPorts val connectedInstIOFields = instIO.fields filter(field => portMap.contains(field.name)) // skip unconnected IO - val instIOConnect = connectedInstIOFields map { field => + val instIOConnect = (connectedInstIOFields map { field => field.tpe match { - case ClockType => Connect(inst.info, Ref(field.name, field.tpe), Ref(targetClock.name, ClockType)) + case ClockType => Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)), + Ref(targetClock.name, ClockType))) case _ => field.dir match { - case Default => Connect(inst.info, buildExp(Seq(portMap(field.name), "hostOut", field.name)), - buildExp(Seq(inst.name, field.name))) - case Reverse => Connect(inst.info, buildExp(Seq(inst.name, field.name)), - buildExp(Seq(portMap(field.name), "hostIn", field.name))) + case Default => portMap(field.name) map { endpoint => + Connect(inst.info, buildExp(Seq(endpoint, "hostOut", field.name)), + buildExp(Seq(inst.name, field.name))) + } + case Reverse => { + if (portMap(field.name).length > 1) + throw new Exception("It is illegal to have more than 1 connection to a single input" + field) + Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)), + buildExp(Seq(portMap(field.name).head, "hostIn", field.name)))) + } } } - } + }).flatten //val stmts = Block(Seq(simFire, simClock, inst) ++ inputsReady ++ outputsValid ++ instIOConnect) val stmts = Block(Seq(targetFire, targetClock) ++ inputsReady ++ outputsValid ++ Seq(inst) ++ instIOConnect) @@ -237,24 +250,54 @@ object Fame1 { // ***** generateSimQueues ***** // Takes Seq of SimWrapper modules // Returns Map of (src, dest) -> SimQueue - def generateSimQueues(wrappers: Seq[Module]): Map[(String, String), Module] = { - def rec(wrappers: Seq[Module], map: Map[(String, String), Module]): Map[(String, String), Module] = { + // To prevent duplicates, instead of creating a map with (src, dest) as the key, we could instead + // only one direction of the queue for each simport of each module. The only problem with this is + // it won't create queues for TopIO since that isn't a real module + private type SimQMap = Map[(String, String), Module] + private val SimQMap = Map[(String, String), Module]() + private def generateSimQueues(wrappers: Seq[Module]): SimQMap = { + def rec(wrappers: Seq[Module], map: SimQMap): SimQMap = { if (wrappers.isEmpty) map else { val w = wrappers.head val name = unwrapName(w.name) - val newMap = w.ports filter(isSimPort) map { port => - splitSimPort(port) map { field => + val newMap = (w.ports filter(isSimPort) map { port => + (splitSimPort(port) map { field => val (src, dst) = if (field.dir == Default) (name, port.name) else (port.name, name) - if (map.contains((src, dst))) Map[(String, String), Module]() - else (src, dst) -> buildSimQueue(s"SimQueue_${src}_${dst}", getHostBits(field).tpe) - } - } - println(newMap) - rec(wrappers.tail, map ++ Map()) + if (map.contains((src, dst))) SimQMap + else Map((src, dst) -> buildSimQueue(queueName(src, dst), getHostBits(field).tpe)) + }).flatten.toMap + }).flatten.toMap + rec(wrappers.tail, map ++ newMap) } } - rec(wrappers, Map[(String, String), Module]()) + rec(wrappers, SimQMap) + } + + // ***** generateSimTop ***** + // Creates the Simulation Top module where all sim modules and sim queues are instantiated and connected + private def generateSimTop(wrappers: Seq[Module], simQueues: SimQMap, rtlTop: Module): Module = { + val insts = (wrappers map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) }) ++ + (simQueues.values map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) }) + val connectClocks = (wrappers ++ simQueues.values) map { m => + Connect(NoInfo, buildExp(Seq(m.name, hostClock.name)), buildExp(hostClock.name)) + } + val connectResets = (wrappers ++ simQueues.values) map { m => + Connect(NoInfo, buildExp(Seq(m.name, hostReset.name)), buildExp(hostReset.name)) + } + val connectQueues = simQueues map { case (key, queue) => + (key._1, key._2) match { + //case (src, "topIO") => EmptyStmt + //case ("topIO", dst) => EmptyStmt + case (src, dst) => Block(Seq(BulkConnect(NoInfo, buildExp(Seq(queue.name, "io", "enq")), + buildExp(Seq(instName(wrapName(src)), dst, "hostOut"))), + BulkConnect(NoInfo, buildExp(Seq(queue.name, "io", "deq")), + buildExp(Seq(instName(wrapName(dst)), src, "hostIn"))))) + } + } + val stmts = Block(insts ++ connectClocks ++ connectResets ++ connectQueues) + val ports = Seq(hostClock, hostReset) ++ rtlTop.ports + Module(NoInfo, "SimTop", ports, stmts) } // ***** transform ***** @@ -268,21 +311,21 @@ object Fame1 { val portConn = findPortConn(top, insts) // Check that port Connections include all ports for each instance? + println(portConn) + + val simWrappers = insts map (inst => genWrapperModule(inst, portConn(inst.name))) + + val simQueues = generateSimQueues(simWrappers) + + // Remove duplicate simWrapper and simQueue modules - val wrappers = insts map (inst => genWrapperModule(inst, portConn(inst.name))) + val simTop = generateSimTop(simWrappers, simQueues, top) - generateSimQueues(wrappers) - //val w = wrappers.head - //println(w.serialize) - //w.ports filter (isSimPort) map { p => - // splitSimPort(p) foreach { f => - // val queueName = if (f.dir == Default) s"${w.name}_${p.name}" else s"${p.name}_${w.name}" - // println(buildSimQueue("SimQueue_" + queueName, getHostBits(f).tpe)) - // } - // //println(buildSimQueue(,p.tpe)) - //} + val modules = (c.modules filter (_.name != top.name)) ++ simWrappers ++ simQueues.values.toSeq ++ Seq(simTop) - c + println(simTop.serialize) + Circuit(c.info, simTop.name, modules) + //c } } diff --git a/src/main/scala/midas/Utils.scala b/src/main/scala/midas/Utils.scala index d2c61c1b..b9006b59 100644 --- a/src/main/scala/midas/Utils.scala +++ b/src/main/scala/midas/Utils.scala @@ -6,6 +6,20 @@ import firrtl.Utils._ object Utils { + // Merges a sequence of maps via the provided function f + // Taken from: https://groups.google.com/forum/#!topic/scala-user/HaQ4fVRjlnU + def merge[K, V](maps: Seq[Map[K, V]])(f: (K, V, V) => V): Map[K, V] = { + maps.foldLeft(Map.empty[K, V]) { case (merged, m) => + m.foldLeft(merged) { case (acc, (k, v)) => + acc.get(k) match { + case Some(existing) => acc.updated(k, f(k, existing, v)) + case None => acc.updated(k, v) + } + } + } + } + + // Takes a set of strings or ints and returns equivalent expression node // Strings correspond to subfields/references, ints correspond to indexes // eg. Seq(io, port, ready) => io.port.ready @@ -23,7 +37,7 @@ object Utils { else Index(rec(names.tail), head, UnknownType) // String -> Ref/Subfield case head: String => - if( names.tail.isEmpty ) Ref("head", UnknownType) + if( names.tail.isEmpty ) Ref(head, UnknownType) else Subfield(rec(names.tail), head, UnknownType) case _ => throw new Exception("Invalid argument type to buildExp! " + names) } @@ -85,41 +99,42 @@ object Utils { def buildSimQueue(name: String, tpe: Type): Module = { val templatedQueue = """ + circuit `NAME: module `NAME : - input clock : Clock - input reset : UInt<1> - output io : {flip enq : {flip ready : UInt<1>, valid : UInt<1>, bits : `TYPE}, deq : {flip ready : UInt<1>, valid : UInt<1>, bits : `TYPE}, count : UInt<3>} + input hostClock : Clock + input hostReset : UInt<1> + output io : {flip enq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, deq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, count : UInt<3>} io.count := UInt<1>("h00") - io.deq.bits.surprise.no := UInt<1>("h00") - io.deq.bits.surprise.yes := UInt<1>("h00") - io.deq.bits.store := UInt<1>("h00") - io.deq.bits.data := UInt<1>("h00") - io.deq.bits.addr := UInt<1>("h00") - io.deq.valid := UInt<1>("h00") - io.enq.ready := UInt<1>("h00") - cmem ram : `TYPE[4], clock - reg T_80 : UInt<2>, clock, reset + //io.deq.hostBits.surprise.no := UInt<1>("h00") + //io.deq.hostBits.surprise.yes := UInt<1>("h00") + //io.deq.hostBits.store := UInt<1>("h00") + //io.deq.hostBits.data := UInt<1>("h00") + //io.deq.hostBits.addr := UInt<1>("h00") + io.deq.hostValid := UInt<1>("h00") + io.enq.hostReady := UInt<1>("h00") + cmem ram : `TYPE[4], hostClock + reg T_80 : UInt<2>, hostClock, hostReset onreset T_80 := UInt<2>("h00") - reg T_82 : UInt<2>, clock, reset + reg T_82 : UInt<2>, hostClock, hostReset onreset T_82 := UInt<2>("h00") - reg maybe_full : UInt<1>, clock, reset + reg maybe_full : UInt<1>, hostClock, hostReset onreset maybe_full := UInt<1>("h00") node ptr_match = eq(T_80, T_82) node T_87 = eq(maybe_full, UInt<1>("h00")) node empty = and(ptr_match, T_87) node full = and(ptr_match, maybe_full) node maybe_flow = and(UInt<1>("h00"), empty) - node do_flow = and(maybe_flow, io.deq.ready) - node T_93 = and(io.enq.ready, io.enq.valid) + node do_flow = and(maybe_flow, io.deq.hostReady) + node T_93 = and(io.enq.hostReady, io.enq.hostValid) node T_95 = eq(do_flow, UInt<1>("h00")) node do_enq = and(T_93, T_95) - node T_97 = and(io.deq.ready, io.deq.valid) + node T_97 = and(io.deq.hostReady, io.deq.hostValid) node T_99 = eq(do_flow, UInt<1>("h00")) node do_deq = and(T_97, T_99) when do_enq : infer accessor T_101 = ram[T_80] - T_101 <> io.enq.bits + T_101 <> io.enq.hostBits node T_109 = eq(T_80, UInt<2>("h03")) node T_111 = and(UInt<1>("h00"), T_109) node T_114 = addw(T_80, UInt<1>("h01")) @@ -138,20 +153,20 @@ object Utils { maybe_full := do_enq skip node T_126 = eq(empty, UInt<1>("h00")) - node T_128 = and(UInt<1>("h00"), io.enq.valid) + node T_128 = and(UInt<1>("h00"), io.enq.hostValid) node T_129 = or(T_126, T_128) - io.deq.valid := T_129 + io.deq.hostValid := T_129 node T_131 = eq(full, UInt<1>("h00")) - node T_133 = and(UInt<1>("h00"), io.deq.ready) + node T_133 = and(UInt<1>("h00"), io.deq.hostReady) node T_134 = or(T_131, T_133) - io.enq.ready := T_134 + io.enq.hostReady := T_134 infer accessor T_135 = ram[T_82] wire T_149 : `TYPE T_149 <> T_135 when maybe_flow : - T_149 <> io.enq.bits + T_149 <> io.enq.hostBits skip - io.deq.bits <> T_149 + io.deq.hostBits <> T_149 node ptr_diff = subw(T_80, T_82) node T_157 = and(maybe_full, ptr_match) node T_158 = cat(T_157, ptr_diff) @@ -162,8 +177,8 @@ object Utils { replaceAllLiterally("`TYPE", tpe.serialize) // Generate initial values //val bitsField = Field("bits", Default, tpe) - println(concreteQueue.stripMargin) - firrtl.Parser.parseModule(concreteQueue) + val ast = firrtl.Parser.parse(concreteQueue.split("\n")) + ast.modules.head } } |
