diff options
| author | jackkoenig | 2015-12-04 18:15:03 -0800 |
|---|---|---|
| committer | jackkoenig | 2015-12-04 18:15:03 -0800 |
| commit | e8ac783706cca1f7ee65d799b5d8be445b6a5c5d (patch) | |
| tree | f709f4f522f1e54c41c70ae733334646d2ef17af | |
| parent | 4d88455c66bd3aa7fd549cdec4f1d05ede83fea2 (diff) | |
Everything is broken, need Translator to work on files without a circuit, need to parse queue module text in midas/Utils.scala, need to create (src, dst) -> Module mapping in midas/Fame.scala
| -rw-r--r-- | src/main/scala/firrtl/Parser.scala | 23 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Translator.scala | 24 | ||||
| -rw-r--r-- | src/main/scala/midas/Fame.scala | 57 | ||||
| -rw-r--r-- | src/main/scala/midas/Utils.scala | 187 |
4 files changed, 205 insertions, 86 deletions
diff --git a/src/main/scala/firrtl/Parser.scala b/src/main/scala/firrtl/Parser.scala index 35e41222..40956ab7 100644 --- a/src/main/scala/firrtl/Parser.scala +++ b/src/main/scala/firrtl/Parser.scala @@ -11,6 +11,29 @@ import antlr._ object Parser { + def parseModule(string: String): Module = { + val fixedInput = Translator.addBrackets(Iterator(string)) + val antlrStream = new ANTLRInputStream(fixedInput.result) + val lexer = new FIRRTLLexer(antlrStream) + val tokens = new CommonTokenStream(lexer) + val parser = new FIRRTLParser(tokens) + + // FIXME Dangerous + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + + // Concrete Syntax Tree + val cst = parser.module + + val visitor = new Visitor("none") + //val ast = visitor.visitCircuit(cst) match { + val ast = visitor.visit(cst) match { + case m: Module => m + case x => throw new ClassCastException("Error! AST not rooted with Module node!") + } + + ast + + } /** Takes a firrtl filename, returns AST (root node is Circuit) * diff --git a/src/main/scala/firrtl/Translator.scala b/src/main/scala/firrtl/Translator.scala index 152cd88e..e7bd6821 100644 --- a/src/main/scala/firrtl/Translator.scala +++ b/src/main/scala/firrtl/Translator.scala @@ -31,19 +31,21 @@ object Translator if( !it.hasNext ) throw new Exception("Empty file!") - // Find circuit before starting scope checks - var line = it.next - while ( it.hasNext && !line._1.contains("circuit") ) { - ret ++= line._1 + "\n" - line = it.next - } - ret ++= line._1 + " { \n" - if( !it.hasNext ) throw new Exception("No circuit in file!") + //// Find circuit before starting scope checks + //var line = it.next + //while ( it.hasNext && !line._1.contains("circuit") ) { + // ret ++= line._1 + "\n" + // line = it.next + //} + //ret ++= line._1 + " { \n" + //if( !it.hasNext ) throw new Exception("No circuit in file!") val scope = Stack[Int]() - scope.push(countSpaces(line._1)) - var newScope = true // indicates if increasing scope spacing is legal on next line + scope.push(0) + var newScope = false + //scope.push(countSpaces(line._1)) + //var newScope = true // indicates if increasing scope spacing is legal on next line while( it.hasNext ) { it.next match { case (lineText, lineNum) => @@ -52,7 +54,7 @@ object Translator val l = if (text.length > spaces ) { // Check that line has text in it if (newScope) { - if( spaces == scope.top ) scope.push(spaces+2) // Hack for one-line scopes + if( spaces <= scope.top ) scope.push(spaces+2) // Hack for one-line scopes else scope.push(spaces) } diff --git a/src/main/scala/midas/Fame.scala b/src/main/scala/midas/Fame.scala index 84373aa2..169193f0 100644 --- a/src/main/scala/midas/Fame.scala +++ b/src/main/scala/midas/Fame.scala @@ -142,6 +142,9 @@ object Fame1 { findPortConn(initConnMap, topStmts) } + private def wrapName(name: String): String = "SimWrap_" + name + private def unwrapName(name: String): String = name.stripPrefix("SimWrap_") + // ***** genWrapperModule ***** // Generates FAME-1 Decoupled wrappers for simulation module instances private val hostReady = Field("hostReady", Reverse, UIntType(IntWidth(1))) @@ -149,14 +152,15 @@ object Fame1 { private val hostClock = Port(NoInfo, "hostClock", Input, ClockType) private def genWrapperModule(inst: DefInst, portMap: PortMap): Module = { - + println(s"Wrapping ${inst.name}") + println(portMap) val instIO = getDefInstType(inst) val nameToField = (instIO.fields map (f => f.name -> f)).toMap val connections = (portMap map(_._2)).toSeq.distinct // modules this inst connects to // Build simPort for each connecting module // TODO This whole chunk really ought to be rewritten or made a function - val connPorts = connections map { c => + val connPorts = connections map { c => // Get ports that connect to this particular module as fields val fields = (portMap filter (_._2 == c)).keySet.toSeq.sorted map (nameToField(_)) val noClock = fields filter (_.tpe != ClockType) // Remove clock @@ -170,7 +174,7 @@ object Fame1 { ) ++ (if (outputSet.isEmpty) Seq() else - Seq(Field("hostOut", Reverse, BundleType(Seq(hostReady, hostValid) :+ + Seq(Field("hostOut", Default, BundleType(Seq(hostReady, hostValid) :+ Field("hostBits", Default, BundleType(outputSet))))) ) )) @@ -196,22 +200,23 @@ object Fame1 { // As a simple RTL module, we're always ready val inputsReady = (connPorts map { port => - getFields(port) filter (_.dir == Reverse) map { field => + getFields(port) filter (_.dir == Reverse) map { field => // filter to only take inputs Connect(inst.info, buildExp(Seq(port.name, field.name, hostReady.name)), UIntValue(1, IntWidth(1))) } }).flatten // Outputs are valid on cycles where we fire val outputsValid = (connPorts map { port => - getFields(port) filter (_.dir == Default) map { field => + getFields(port) filter (_.dir == Default) map { field => // filter to only take outputs Connect(inst.info, buildExp(Seq(port.name, field.name, hostValid.name)), buildExp(targetFire.name)) } }).flatten // Connect up all of the IO of the RTL module to sim module IO, except clock which should be connected // This currently assumes naming things that are also done above when generating connPorts - val instIOConnect = instIO.fields map { field => - field.tpe match { + val connectedInstIOFields = instIO.fields filter(field => portMap.contains(field.name)) // skip unconnected IO + val instIOConnect = connectedInstIOFields map { field => + field.tpe match { case ClockType => Connect(inst.info, Ref(field.name, field.tpe), Ref(targetClock.name, ClockType)) case _ => field.dir match { case Default => Connect(inst.info, buildExp(Seq(portMap(field.name), "hostOut", field.name)), @@ -224,7 +229,32 @@ object Fame1 { //val stmts = Block(Seq(simFire, simClock, inst) ++ inputsReady ++ outputsValid ++ instIOConnect) val stmts = Block(Seq(targetFire, targetClock) ++ inputsReady ++ outputsValid ++ Seq(inst) ++ instIOConnect) - Module(inst.info, s"SimWrap_${inst.name}", ports, stmts) + Module(inst.info, wrapName(inst.name), ports, stmts) + } + + + + // ***** generateSimQueues ***** + // Takes Seq of SimWrapper modules + // Returns Map of (src, dest) -> SimQueue + def generateSimQueues(wrappers: Seq[Module]): Map[(String, String), Module] = { + def rec(wrappers: Seq[Module], map: Map[(String, String), Module]): Map[(String, String), Module] = { + if (wrappers.isEmpty) map + else { + val w = wrappers.head + val name = unwrapName(w.name) + val newMap = w.ports filter(isSimPort) map { port => + splitSimPort(port) map { field => + val (src, dst) = if (field.dir == Default) (name, port.name) else (port.name, name) + if (map.contains((src, dst))) Map[(String, String), Module]() + else (src, dst) -> buildSimQueue(s"SimQueue_${src}_${dst}", getHostBits(field).tpe) + } + } + println(newMap) + rec(wrappers.tail, map ++ Map()) + } + } + rec(wrappers, Map[(String, String), Module]()) } // ***** transform ***** @@ -241,7 +271,16 @@ object Fame1 { val wrappers = insts map (inst => genWrapperModule(inst, portConn(inst.name))) - wrappers foreach (w => println(w.serialize)) + generateSimQueues(wrappers) + //val w = wrappers.head + //println(w.serialize) + //w.ports filter (isSimPort) map { p => + // splitSimPort(p) foreach { f => + // val queueName = if (f.dir == Default) s"${w.name}_${p.name}" else s"${p.name}_${w.name}" + // println(buildSimQueue("SimQueue_" + queueName, getHostBits(f).tpe)) + // } + // //println(buildSimQueue(,p.tpe)) + //} c } diff --git a/src/main/scala/midas/Utils.scala b/src/main/scala/midas/Utils.scala index e9c68201..d2c61c1b 100644 --- a/src/main/scala/midas/Utils.scala +++ b/src/main/scala/midas/Utils.scala @@ -2,6 +2,7 @@ package midas import firrtl._ +import firrtl.Utils._ object Utils { @@ -22,7 +23,7 @@ object Utils { else Index(rec(names.tail), head, UnknownType) // String -> Ref/Subfield case head: String => - if( names.tail.isEmpty ) Ref(head, UnknownType) + if( names.tail.isEmpty ) Ref("head", UnknownType) else Subfield(rec(names.tail), head, UnknownType) case _ => throw new Exception("Invalid argument type to buildExp! " + names) } @@ -36,6 +37,42 @@ object Utils { else DoPrimop(op, Seq(args.head, genPrimopReduce(op, args.tail)), Seq(), UnknownType) } + // Checks if a firrtl.Port matches the MIDAS SimPort pattern + // This currently just checks that the port is of type bundle with ONLY the members + // hostIn and/or hostOut with correct directions + def isSimPort(port: Port): Boolean = { + //println("isSimPort called on port " + port.serialize) + port.tpe match { + case b: BundleType => { + b.fields map { field => + if( field.name == "hostIn" ) field.dir == Reverse + else if( field.name == "hostOut" ) field.dir == Default + else false + } reduce ( _ & _ ) + } + case _ => false + } + } + + def splitSimPort(port: Port): Seq[Field] = { + try { + val b = port.tpe.asInstanceOf[BundleType] + Seq(b.fields.find(_.name == "hostIn"), b.fields.find(_.name == "hostOut")).flatten + } catch { + case e: Exception => throw new Exception("Invalid SimPort " + port.serialize) + } + } + + // From simulation host decoupled, return hostbits field + def getHostBits(field: Field): Field = { + try { + val b = field.tpe.asInstanceOf[BundleType] + b.fields.find(_.name == "hostBits").get + } catch { + case e: Exception => throw new Exception("Invalid SimField " + field.serialize) + } + } + // For a port that is known to be of type BundleType, return the fields of that bundle def getFields(port: Port): Seq[Field] = { port.tpe match { @@ -45,70 +82,88 @@ object Utils { } // Queue - /* - module Queue : - input clock : Clock - input reset : UInt<1> - output io : {flip enq : {flip ready : UInt<1>, valid : UInt<1>, bits : UInt<32>}, deq : {flip ready : UInt<1>, valid : UInt<1>, bits : UInt<32>}, count : UInt<3>} - - io.count := UInt<1>("h00") - io.deq.bits := UInt<1>("h00") - io.deq.valid := UInt<1>("h00") - io.enq.ready := UInt<1>("h00") - cmem ram : UInt<32>[4], clock - reg T_26 : UInt<2>, clock, reset - onreset T_26 := UInt<2>("h00") - reg T_28 : UInt<2>, clock, reset - onreset T_28 := UInt<2>("h00") - reg maybe_full : UInt<1>, clock, reset - onreset maybe_full := UInt<1>("h00") - node ptr_match = eq(T_26, T_28) - node T_33 = eq(maybe_full, UInt<1>("h00")) - node empty = and(ptr_match, T_33) - node full = and(ptr_match, maybe_full) - node maybe_flow = and(UInt<1>("h00"), empty) - node do_flow = and(maybe_flow, io.deq.ready) - node T_39 = and(io.enq.ready, io.enq.valid) - node T_41 = eq(do_flow, UInt<1>("h00")) - node do_enq = and(T_39, T_41) - node T_43 = and(io.deq.ready, io.deq.valid) - node T_45 = eq(do_flow, UInt<1>("h00")) - node do_deq = and(T_43, T_45) - when do_enq : - infer accessor T_47 = ram[T_26] - T_47 := io.enq.bits - node T_49 = eq(T_26, UInt<2>("h03")) - node T_51 = and(UInt<1>("h00"), T_49) - node T_54 = addw(T_26, UInt<1>("h01")) - node T_55 = mux(T_51, UInt<1>("h00"), T_54) - T_26 := T_55 - skip - when do_deq : - node T_57 = eq(T_28, UInt<2>("h03")) - node T_59 = and(UInt<1>("h00"), T_57) - node T_62 = addw(T_28, UInt<1>("h01")) - node T_63 = mux(T_59, UInt<1>("h00"), T_62) - T_28 := T_63 - skip - node T_64 = neq(do_enq, do_deq) - when T_64 : - maybe_full := do_enq - skip - node T_66 = eq(empty, UInt<1>("h00")) - node T_68 = and(UInt<1>("h00"), io.enq.valid) - node T_69 = or(T_66, T_68) - io.deq.valid := T_69 - node T_71 = eq(full, UInt<1>("h00")) - node T_73 = and(UInt<1>("h00"), io.deq.ready) - node T_74 = or(T_71, T_73) - io.enq.ready := T_74 - infer accessor T_75 = ram[T_28] - node T_76 = mux(maybe_flow, io.enq.bits, T_75) - io.deq.bits := T_76 - node ptr_diff = subw(T_26, T_28) - node T_78 = and(maybe_full, ptr_match) - node T_79 = cat(T_78, ptr_diff) - io.count := T_79 - */ + def buildSimQueue(name: String, tpe: Type): Module = { + val templatedQueue = + """ + module `NAME : + input clock : Clock + input reset : UInt<1> + output io : {flip enq : {flip ready : UInt<1>, valid : UInt<1>, bits : `TYPE}, deq : {flip ready : UInt<1>, valid : UInt<1>, bits : `TYPE}, count : UInt<3>} + + io.count := UInt<1>("h00") + io.deq.bits.surprise.no := UInt<1>("h00") + io.deq.bits.surprise.yes := UInt<1>("h00") + io.deq.bits.store := UInt<1>("h00") + io.deq.bits.data := UInt<1>("h00") + io.deq.bits.addr := UInt<1>("h00") + io.deq.valid := UInt<1>("h00") + io.enq.ready := UInt<1>("h00") + cmem ram : `TYPE[4], clock + reg T_80 : UInt<2>, clock, reset + onreset T_80 := UInt<2>("h00") + reg T_82 : UInt<2>, clock, reset + onreset T_82 := UInt<2>("h00") + reg maybe_full : UInt<1>, clock, reset + onreset maybe_full := UInt<1>("h00") + node ptr_match = eq(T_80, T_82) + node T_87 = eq(maybe_full, UInt<1>("h00")) + node empty = and(ptr_match, T_87) + node full = and(ptr_match, maybe_full) + node maybe_flow = and(UInt<1>("h00"), empty) + node do_flow = and(maybe_flow, io.deq.ready) + node T_93 = and(io.enq.ready, io.enq.valid) + node T_95 = eq(do_flow, UInt<1>("h00")) + node do_enq = and(T_93, T_95) + node T_97 = and(io.deq.ready, io.deq.valid) + node T_99 = eq(do_flow, UInt<1>("h00")) + node do_deq = and(T_97, T_99) + when do_enq : + infer accessor T_101 = ram[T_80] + T_101 <> io.enq.bits + node T_109 = eq(T_80, UInt<2>("h03")) + node T_111 = and(UInt<1>("h00"), T_109) + node T_114 = addw(T_80, UInt<1>("h01")) + node T_115 = mux(T_111, UInt<1>("h00"), T_114) + T_80 := T_115 + skip + when do_deq : + node T_117 = eq(T_82, UInt<2>("h03")) + node T_119 = and(UInt<1>("h00"), T_117) + node T_122 = addw(T_82, UInt<1>("h01")) + node T_123 = mux(T_119, UInt<1>("h00"), T_122) + T_82 := T_123 + skip + node T_124 = neq(do_enq, do_deq) + when T_124 : + maybe_full := do_enq + skip + node T_126 = eq(empty, UInt<1>("h00")) + node T_128 = and(UInt<1>("h00"), io.enq.valid) + node T_129 = or(T_126, T_128) + io.deq.valid := T_129 + node T_131 = eq(full, UInt<1>("h00")) + node T_133 = and(UInt<1>("h00"), io.deq.ready) + node T_134 = or(T_131, T_133) + io.enq.ready := T_134 + infer accessor T_135 = ram[T_82] + wire T_149 : `TYPE + T_149 <> T_135 + when maybe_flow : + T_149 <> io.enq.bits + skip + io.deq.bits <> T_149 + node ptr_diff = subw(T_80, T_82) + node T_157 = and(maybe_full, ptr_match) + node T_158 = cat(T_157, ptr_diff) + io.count := T_158 + """ + //def buildQueue(name: String, tpe: Type): Module = { + val concreteQueue = templatedQueue.replaceAllLiterally("`NAME", name). + replaceAllLiterally("`TYPE", tpe.serialize) + // Generate initial values + //val bitsField = Field("bits", Default, tpe) + println(concreteQueue.stripMargin) + firrtl.Parser.parseModule(concreteQueue) + } } |
