diff options
| author | jackkoenig | 2015-12-06 00:36:12 -0800 |
|---|---|---|
| committer | jackkoenig | 2015-12-06 00:36:12 -0800 |
| commit | c5cac5227cd164b17f2a6f02227a71dc89f8cde4 (patch) | |
| tree | 1f6d30b64a58103574bacb770bbc307f8d1e4bbe /src | |
| parent | e8ac783706cca1f7ee65d799b5d8be445b6a5c5d (diff) | |
Working on generating SimTop, need to figure out how to split the top-level IO between the sim modules.
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/Example.scala | 19 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Driver.scala | 9 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Parser.scala | 34 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Translator.scala | 25 | ||||
| -rw-r--r-- | src/main/scala/midas/Fame.scala | 131 | ||||
| -rw-r--r-- | src/main/scala/midas/Utils.scala | 69 |
6 files changed, 151 insertions, 136 deletions
diff --git a/src/main/scala/Example.scala b/src/main/scala/Example.scala deleted file mode 100644 index 210a50fb..00000000 --- a/src/main/scala/Example.scala +++ /dev/null @@ -1,19 +0,0 @@ -import java.io._ -import firrtl._ -import firrtl.Utils._ - -object Example -{ - // Example use of Scala FIRRTL parser and serialization - def main(args: Array[String]) - { - val inputFile = args(0) - - // Parse file - val ast = firrtl.Parser.parse(inputFile) - - val writer = new PrintWriter(new File(args(1))) - writer.write(ast.serialize) // serialize returns String - writer.close() - } -} diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index c748f92e..ce8d2b1d 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -3,6 +3,7 @@ package firrtl import java.io._ import scala.sys.process._ import java.nio.file.{Paths, Files} +import scala.io.Source import Utils._ import DebugUtils._ import Passes._ @@ -31,7 +32,7 @@ object Driver // Parse input file and print to output private def firrtl(input: String, output: String)(implicit logger: Logger) { - val ast = Parser.parse(input) + val ast = Parser.parse(input, Source.fromFile(input).getLines) val writer = new PrintWriter(new File(output)) writer.write(ast.serialize()) writer.close() @@ -59,7 +60,7 @@ object Driver //// Don't lower //val temp1 = genTempFilename(input) - //val ast = Parser.parse(input) + //val ast = Parser.parse(Source.fromFile(input).getLines) //val writer = new PrintWriter(new File(temp1)) //val ast2 = fame1Transform(ast) //writer.write(ast2.serialize()) @@ -72,7 +73,7 @@ object Driver preCmd.! // Read in and execute infer-types - val ast = Parser.parse(temp1) + val ast = Parser.parse(input, Source.fromFile(temp1).getLines) val ast2 = inferTypes(ast)(logger) // FAME-1 Transformation @@ -125,7 +126,7 @@ object Driver // if( scalaPass.isEmpty ) { // scala2Stanza = stanza2Scala // } else { - // var ast = Parser.parse(stanza2Scala) + // var ast = Parser.parse(input, stanza2Scala) // //scalaPass.foreach( f => (ast = f(ast)) ) // Does this work? // for ( f <- scalaPass ) yield { ast = mapString2Pass(f)(ast) } diff --git a/src/main/scala/firrtl/Parser.scala b/src/main/scala/firrtl/Parser.scala index 40956ab7..00cd110e 100644 --- a/src/main/scala/firrtl/Parser.scala +++ b/src/main/scala/firrtl/Parser.scala @@ -11,38 +11,12 @@ import antlr._ object Parser { - def parseModule(string: String): Module = { - val fixedInput = Translator.addBrackets(Iterator(string)) - val antlrStream = new ANTLRInputStream(fixedInput.result) - val lexer = new FIRRTLLexer(antlrStream) - val tokens = new CommonTokenStream(lexer) - val parser = new FIRRTLParser(tokens) - - // FIXME Dangerous - parser.getInterpreter.setPredictionMode(PredictionMode.SLL) - - // Concrete Syntax Tree - val cst = parser.module - - val visitor = new Visitor("none") - //val ast = visitor.visitCircuit(cst) match { - val ast = visitor.visit(cst) match { - case m: Module => m - case x => throw new ClassCastException("Error! AST not rooted with Module node!") - } - - ast - - } - - /** Takes a firrtl filename, returns AST (root node is Circuit) + /** Takes Iterator over lines of FIRRTL, returns AST (root node is Circuit) * - * Currently must be standard FIRRTL file * Parser performs conversion to machine firrtl */ - def parse(filename: String): Circuit = { - //val antlrStream = new ANTLRInputStream(input.reader) - val fixedInput = Translator.addBrackets(Source.fromFile(filename).getLines) + def parse(filename: String, lines: Iterator[String]): Circuit = { + val fixedInput = Translator.addBrackets(lines) val antlrStream = new ANTLRInputStream(fixedInput.result) val lexer = new FIRRTLLexer(antlrStream) val tokens = new CommonTokenStream(lexer) @@ -64,4 +38,6 @@ object Parser ast } + def parse(lines: Seq[String]): Circuit = parse("<None>", lines.iterator) + } diff --git a/src/main/scala/firrtl/Translator.scala b/src/main/scala/firrtl/Translator.scala index e7bd6821..9fe40af8 100644 --- a/src/main/scala/firrtl/Translator.scala +++ b/src/main/scala/firrtl/Translator.scala @@ -31,21 +31,20 @@ object Translator if( !it.hasNext ) throw new Exception("Empty file!") - //// Find circuit before starting scope checks - //var line = it.next - //while ( it.hasNext && !line._1.contains("circuit") ) { - // ret ++= line._1 + "\n" - // line = it.next - //} - //ret ++= line._1 + " { \n" - //if( !it.hasNext ) throw new Exception("No circuit in file!") + // Find circuit before starting scope checks + var line = it.next + while ( it.hasNext && !line._1.contains("circuit") ) { + ret ++= line._1 + "\n" + line = it.next + } + ret ++= line._1 + " { \n" + if( !it.hasNext ) throw new Exception("No circuit in file!") val scope = Stack[Int]() - scope.push(0) - var newScope = false - //scope.push(countSpaces(line._1)) - //var newScope = true // indicates if increasing scope spacing is legal on next line + val lowestScope = countSpaces(line._1) + scope.push(lowestScope) + var newScope = true // indicates if increasing scope spacing is legal on next line while( it.hasNext ) { it.next match { case (lineText, lineNum) => @@ -94,7 +93,7 @@ object Translator } // while( it.hasNext ) // Print any closing braces - while( scope.top > 0 ) { + while( scope.top > lowestScope ) { scope.pop() ret.deleteCharAt(ret.lastIndexOf("\n")) // Put on previous line ret ++= " }\n" diff --git a/src/main/scala/midas/Fame.scala b/src/main/scala/midas/Fame.scala index 169193f0..f53ab56d 100644 --- a/src/main/scala/midas/Fame.scala +++ b/src/main/scala/midas/Fame.scala @@ -100,8 +100,8 @@ object Fame1 { // of instanceName -> (instanctPorts -> portEndpoint) // It honestly feels kind of brittle given it assumes there will be no intermediate nodes or anything in // the way of direct connections between IO of module instances - private type PortMap = Map[String, String] - private val PortMap = Map[String, String]() + private type PortMap = Map[String, Seq[String]] + private val PortMap = Map[String, Seq[String]]() private type ConnMap = Map[String, PortMap] private val ConnMap = Map[String, PortMap]() private def processConnectExp(exp: Exp): (String, String) = { @@ -119,7 +119,7 @@ object Fame1 { private def processConnect(conn: Connect): ConnMap = { val lhs = processConnectExp(conn.lhs) val rhs = processConnectExp(conn.rhs) - Map(lhs._1 -> Map(lhs._2 -> rhs._1), rhs._1 -> Map(rhs._2 -> lhs._1)).withDefaultValue(PortMap) + Map(lhs._1 -> Map(lhs._2 -> Seq(rhs._1)), rhs._1 -> Map(rhs._2 -> Seq(lhs._1))).withDefaultValue(PortMap) } private def findPortConn(connMap: ConnMap, stmts: Seq[Stmt]): ConnMap = { if (stmts.isEmpty) connMap @@ -127,7 +127,8 @@ object Fame1 { stmts.head match { case conn: Connect => { val newConnMap = processConnect(conn) - findPortConn(connMap map {case (k,v) => k -> (v ++ newConnMap(k)) }, stmts.tail) + findPortConn((connMap map { case (k,v) => + k -> merge(Seq(v, newConnMap(k))) { (_, v1, v2) => v1 ++ v2 }}), stmts.tail ) } case _ => findPortConn(connMap, stmts.tail) } @@ -142,14 +143,23 @@ object Fame1 { findPortConn(initConnMap, topStmts) } - private def wrapName(name: String): String = "SimWrap_" + name + // ***** Name Translation functions to help make naming consistent and easy to change ***** + private def wrapName(name: String): String = s"SimWrap_${name}" private def unwrapName(name: String): String = name.stripPrefix("SimWrap_") + private def queueName(src: String, dst: String): String = s"SimQueue_${src}_${dst}" + private def instName(name: String): String = s"inst_${name}" + private def unInstName(name: String): String = name.stripPrefix("inst_") // ***** genWrapperModule ***** // Generates FAME-1 Decoupled wrappers for simulation module instances private val hostReady = Field("hostReady", Reverse, UIntType(IntWidth(1))) private val hostValid = Field("hostValid", Default, UIntType(IntWidth(1))) private val hostClock = Port(NoInfo, "hostClock", Input, ClockType) + private val hostReset = Port(NoInfo, "hostReset", Input, UIntType(IntWidth(1))) + + private def genHostDecoupled(fields: Seq[Field]): BundleType = { + BundleType(Seq(hostReady, hostValid) :+ Field("hostBits", Default, BundleType(fields))) + } private def genWrapperModule(inst: DefInst, portMap: PortMap): Module = { println(s"Wrapping ${inst.name}") @@ -157,29 +167,25 @@ object Fame1 { val instIO = getDefInstType(inst) val nameToField = (instIO.fields map (f => f.name -> f)).toMap - val connections = (portMap map(_._2)).toSeq.distinct // modules this inst connects to + val connections = (portMap map (_._2)).toSeq.flatten.distinct // modules this inst connects to // Build simPort for each connecting module // TODO This whole chunk really ought to be rewritten or made a function val connPorts = connections map { c => // Get ports that connect to this particular module as fields - val fields = (portMap filter (_._2 == c)).keySet.toSeq.sorted map (nameToField(_)) + val fields = (portMap filter (_._2.contains(c))).keySet.toSeq.sorted map (nameToField(_)) val noClock = fields filter (_.tpe != ClockType) // Remove clock val inputSet = noClock filter (_.dir == Reverse) map (f => Field(f.name, Default, f.tpe)) val outputSet = noClock filter (_.dir == Default) Port(inst.info, c, Output, BundleType( (if (inputSet.isEmpty) Seq() - else - Seq(Field("hostIn", Reverse, BundleType(Seq(hostReady, hostValid) :+ - Field("hostBits", Default, BundleType(inputSet))))) + else Seq(Field("hostIn", Reverse, genHostDecoupled(inputSet))) ) ++ (if (outputSet.isEmpty) Seq() - else - Seq(Field("hostOut", Default, BundleType(Seq(hostReady, hostValid) :+ - Field("hostBits", Default, BundleType(outputSet))))) + else Seq(Field("hostOut", Default, genHostDecoupled(inputSet))) ) )) } - val ports = hostClock +: connPorts // Add host Clock + val ports = hostClock +: hostReset +: connPorts // Add host and host reset // targetFire is signal to indicate when a simulation module can execute, this is indicated by all of its inputs // being valid and all of its outputs being ready @@ -215,17 +221,24 @@ object Fame1 { // Connect up all of the IO of the RTL module to sim module IO, except clock which should be connected // This currently assumes naming things that are also done above when generating connPorts val connectedInstIOFields = instIO.fields filter(field => portMap.contains(field.name)) // skip unconnected IO - val instIOConnect = connectedInstIOFields map { field => + val instIOConnect = (connectedInstIOFields map { field => field.tpe match { - case ClockType => Connect(inst.info, Ref(field.name, field.tpe), Ref(targetClock.name, ClockType)) + case ClockType => Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)), + Ref(targetClock.name, ClockType))) case _ => field.dir match { - case Default => Connect(inst.info, buildExp(Seq(portMap(field.name), "hostOut", field.name)), - buildExp(Seq(inst.name, field.name))) - case Reverse => Connect(inst.info, buildExp(Seq(inst.name, field.name)), - buildExp(Seq(portMap(field.name), "hostIn", field.name))) + case Default => portMap(field.name) map { endpoint => + Connect(inst.info, buildExp(Seq(endpoint, "hostOut", field.name)), + buildExp(Seq(inst.name, field.name))) + } + case Reverse => { + if (portMap(field.name).length > 1) + throw new Exception("It is illegal to have more than 1 connection to a single input" + field) + Seq(Connect(inst.info, buildExp(Seq(inst.name, field.name)), + buildExp(Seq(portMap(field.name).head, "hostIn", field.name)))) + } } } - } + }).flatten //val stmts = Block(Seq(simFire, simClock, inst) ++ inputsReady ++ outputsValid ++ instIOConnect) val stmts = Block(Seq(targetFire, targetClock) ++ inputsReady ++ outputsValid ++ Seq(inst) ++ instIOConnect) @@ -237,24 +250,54 @@ object Fame1 { // ***** generateSimQueues ***** // Takes Seq of SimWrapper modules // Returns Map of (src, dest) -> SimQueue - def generateSimQueues(wrappers: Seq[Module]): Map[(String, String), Module] = { - def rec(wrappers: Seq[Module], map: Map[(String, String), Module]): Map[(String, String), Module] = { + // To prevent duplicates, instead of creating a map with (src, dest) as the key, we could instead + // only one direction of the queue for each simport of each module. The only problem with this is + // it won't create queues for TopIO since that isn't a real module + private type SimQMap = Map[(String, String), Module] + private val SimQMap = Map[(String, String), Module]() + private def generateSimQueues(wrappers: Seq[Module]): SimQMap = { + def rec(wrappers: Seq[Module], map: SimQMap): SimQMap = { if (wrappers.isEmpty) map else { val w = wrappers.head val name = unwrapName(w.name) - val newMap = w.ports filter(isSimPort) map { port => - splitSimPort(port) map { field => + val newMap = (w.ports filter(isSimPort) map { port => + (splitSimPort(port) map { field => val (src, dst) = if (field.dir == Default) (name, port.name) else (port.name, name) - if (map.contains((src, dst))) Map[(String, String), Module]() - else (src, dst) -> buildSimQueue(s"SimQueue_${src}_${dst}", getHostBits(field).tpe) - } - } - println(newMap) - rec(wrappers.tail, map ++ Map()) + if (map.contains((src, dst))) SimQMap + else Map((src, dst) -> buildSimQueue(queueName(src, dst), getHostBits(field).tpe)) + }).flatten.toMap + }).flatten.toMap + rec(wrappers.tail, map ++ newMap) } } - rec(wrappers, Map[(String, String), Module]()) + rec(wrappers, SimQMap) + } + + // ***** generateSimTop ***** + // Creates the Simulation Top module where all sim modules and sim queues are instantiated and connected + private def generateSimTop(wrappers: Seq[Module], simQueues: SimQMap, rtlTop: Module): Module = { + val insts = (wrappers map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) }) ++ + (simQueues.values map { m => DefInst(NoInfo, instName(m.name), buildExp(m.name)) }) + val connectClocks = (wrappers ++ simQueues.values) map { m => + Connect(NoInfo, buildExp(Seq(m.name, hostClock.name)), buildExp(hostClock.name)) + } + val connectResets = (wrappers ++ simQueues.values) map { m => + Connect(NoInfo, buildExp(Seq(m.name, hostReset.name)), buildExp(hostReset.name)) + } + val connectQueues = simQueues map { case (key, queue) => + (key._1, key._2) match { + //case (src, "topIO") => EmptyStmt + //case ("topIO", dst) => EmptyStmt + case (src, dst) => Block(Seq(BulkConnect(NoInfo, buildExp(Seq(queue.name, "io", "enq")), + buildExp(Seq(instName(wrapName(src)), dst, "hostOut"))), + BulkConnect(NoInfo, buildExp(Seq(queue.name, "io", "deq")), + buildExp(Seq(instName(wrapName(dst)), src, "hostIn"))))) + } + } + val stmts = Block(insts ++ connectClocks ++ connectResets ++ connectQueues) + val ports = Seq(hostClock, hostReset) ++ rtlTop.ports + Module(NoInfo, "SimTop", ports, stmts) } // ***** transform ***** @@ -268,21 +311,21 @@ object Fame1 { val portConn = findPortConn(top, insts) // Check that port Connections include all ports for each instance? + println(portConn) + + val simWrappers = insts map (inst => genWrapperModule(inst, portConn(inst.name))) + + val simQueues = generateSimQueues(simWrappers) + + // Remove duplicate simWrapper and simQueue modules - val wrappers = insts map (inst => genWrapperModule(inst, portConn(inst.name))) + val simTop = generateSimTop(simWrappers, simQueues, top) - generateSimQueues(wrappers) - //val w = wrappers.head - //println(w.serialize) - //w.ports filter (isSimPort) map { p => - // splitSimPort(p) foreach { f => - // val queueName = if (f.dir == Default) s"${w.name}_${p.name}" else s"${p.name}_${w.name}" - // println(buildSimQueue("SimQueue_" + queueName, getHostBits(f).tpe)) - // } - // //println(buildSimQueue(,p.tpe)) - //} + val modules = (c.modules filter (_.name != top.name)) ++ simWrappers ++ simQueues.values.toSeq ++ Seq(simTop) - c + println(simTop.serialize) + Circuit(c.info, simTop.name, modules) + //c } } diff --git a/src/main/scala/midas/Utils.scala b/src/main/scala/midas/Utils.scala index d2c61c1b..b9006b59 100644 --- a/src/main/scala/midas/Utils.scala +++ b/src/main/scala/midas/Utils.scala @@ -6,6 +6,20 @@ import firrtl.Utils._ object Utils { + // Merges a sequence of maps via the provided function f + // Taken from: https://groups.google.com/forum/#!topic/scala-user/HaQ4fVRjlnU + def merge[K, V](maps: Seq[Map[K, V]])(f: (K, V, V) => V): Map[K, V] = { + maps.foldLeft(Map.empty[K, V]) { case (merged, m) => + m.foldLeft(merged) { case (acc, (k, v)) => + acc.get(k) match { + case Some(existing) => acc.updated(k, f(k, existing, v)) + case None => acc.updated(k, v) + } + } + } + } + + // Takes a set of strings or ints and returns equivalent expression node // Strings correspond to subfields/references, ints correspond to indexes // eg. Seq(io, port, ready) => io.port.ready @@ -23,7 +37,7 @@ object Utils { else Index(rec(names.tail), head, UnknownType) // String -> Ref/Subfield case head: String => - if( names.tail.isEmpty ) Ref("head", UnknownType) + if( names.tail.isEmpty ) Ref(head, UnknownType) else Subfield(rec(names.tail), head, UnknownType) case _ => throw new Exception("Invalid argument type to buildExp! " + names) } @@ -85,41 +99,42 @@ object Utils { def buildSimQueue(name: String, tpe: Type): Module = { val templatedQueue = """ + circuit `NAME: module `NAME : - input clock : Clock - input reset : UInt<1> - output io : {flip enq : {flip ready : UInt<1>, valid : UInt<1>, bits : `TYPE}, deq : {flip ready : UInt<1>, valid : UInt<1>, bits : `TYPE}, count : UInt<3>} + input hostClock : Clock + input hostReset : UInt<1> + output io : {flip enq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, deq : {flip hostReady : UInt<1>, hostValid : UInt<1>, hostBits : `TYPE}, count : UInt<3>} io.count := UInt<1>("h00") - io.deq.bits.surprise.no := UInt<1>("h00") - io.deq.bits.surprise.yes := UInt<1>("h00") - io.deq.bits.store := UInt<1>("h00") - io.deq.bits.data := UInt<1>("h00") - io.deq.bits.addr := UInt<1>("h00") - io.deq.valid := UInt<1>("h00") - io.enq.ready := UInt<1>("h00") - cmem ram : `TYPE[4], clock - reg T_80 : UInt<2>, clock, reset + //io.deq.hostBits.surprise.no := UInt<1>("h00") + //io.deq.hostBits.surprise.yes := UInt<1>("h00") + //io.deq.hostBits.store := UInt<1>("h00") + //io.deq.hostBits.data := UInt<1>("h00") + //io.deq.hostBits.addr := UInt<1>("h00") + io.deq.hostValid := UInt<1>("h00") + io.enq.hostReady := UInt<1>("h00") + cmem ram : `TYPE[4], hostClock + reg T_80 : UInt<2>, hostClock, hostReset onreset T_80 := UInt<2>("h00") - reg T_82 : UInt<2>, clock, reset + reg T_82 : UInt<2>, hostClock, hostReset onreset T_82 := UInt<2>("h00") - reg maybe_full : UInt<1>, clock, reset + reg maybe_full : UInt<1>, hostClock, hostReset onreset maybe_full := UInt<1>("h00") node ptr_match = eq(T_80, T_82) node T_87 = eq(maybe_full, UInt<1>("h00")) node empty = and(ptr_match, T_87) node full = and(ptr_match, maybe_full) node maybe_flow = and(UInt<1>("h00"), empty) - node do_flow = and(maybe_flow, io.deq.ready) - node T_93 = and(io.enq.ready, io.enq.valid) + node do_flow = and(maybe_flow, io.deq.hostReady) + node T_93 = and(io.enq.hostReady, io.enq.hostValid) node T_95 = eq(do_flow, UInt<1>("h00")) node do_enq = and(T_93, T_95) - node T_97 = and(io.deq.ready, io.deq.valid) + node T_97 = and(io.deq.hostReady, io.deq.hostValid) node T_99 = eq(do_flow, UInt<1>("h00")) node do_deq = and(T_97, T_99) when do_enq : infer accessor T_101 = ram[T_80] - T_101 <> io.enq.bits + T_101 <> io.enq.hostBits node T_109 = eq(T_80, UInt<2>("h03")) node T_111 = and(UInt<1>("h00"), T_109) node T_114 = addw(T_80, UInt<1>("h01")) @@ -138,20 +153,20 @@ object Utils { maybe_full := do_enq skip node T_126 = eq(empty, UInt<1>("h00")) - node T_128 = and(UInt<1>("h00"), io.enq.valid) + node T_128 = and(UInt<1>("h00"), io.enq.hostValid) node T_129 = or(T_126, T_128) - io.deq.valid := T_129 + io.deq.hostValid := T_129 node T_131 = eq(full, UInt<1>("h00")) - node T_133 = and(UInt<1>("h00"), io.deq.ready) + node T_133 = and(UInt<1>("h00"), io.deq.hostReady) node T_134 = or(T_131, T_133) - io.enq.ready := T_134 + io.enq.hostReady := T_134 infer accessor T_135 = ram[T_82] wire T_149 : `TYPE T_149 <> T_135 when maybe_flow : - T_149 <> io.enq.bits + T_149 <> io.enq.hostBits skip - io.deq.bits <> T_149 + io.deq.hostBits <> T_149 node ptr_diff = subw(T_80, T_82) node T_157 = and(maybe_full, ptr_match) node T_158 = cat(T_157, ptr_diff) @@ -162,8 +177,8 @@ object Utils { replaceAllLiterally("`TYPE", tpe.serialize) // Generate initial values //val bitsField = Field("bits", Default, tpe) - println(concreteQueue.stripMargin) - firrtl.Parser.parseModule(concreteQueue) + val ast = firrtl.Parser.parse(concreteQueue.split("\n")) + ast.modules.head } } |
