diff options
| author | Kamyar Mohajerani | 2016-07-21 23:40:34 +0430 |
|---|---|---|
| committer | Jack Koenig | 2016-07-21 12:10:34 -0700 |
| commit | ab340febdc7a5418da945f9b79624d36e66e26db (patch) | |
| tree | 04e4aef30081fdd419281d69be4b141fd49b4b1f /src/main | |
| parent | b7de40e23161a7346fea90576f07b5c200c2675b (diff) | |
Indentation support for the ANTLR parser (as discussed in #192) (#194)
Indentation support for the ANTLR parser
- some clean-up of the parser code (TODO: file input could be improved, more clean-up)
- get rid of Translator and specify all syntactic rules in antlr4 grammer
- support for else-when shorthand in the grammar
- rename Begin to Block which makes more sense
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/antlr4/FIRRTL.g4 | 120 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/LexerHelper.scala | 159 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Mappers.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Parser.scala | 47 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Serialize.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Translator.scala | 171 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Visitor.scala | 291 | ||||
| -rw-r--r-- | src/main/scala/firrtl/ir/IR.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandWhens.scala | 8 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Inline.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/LowerTypes.scala | 12 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 28 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/SplitExpressions.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Uniquify.scala | 2 |
17 files changed, 483 insertions, 381 deletions
diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index c452fae1..ef0fd7eb 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -26,6 +26,27 @@ MODIFICATIONS. */ grammar FIRRTL; +tokens { INDENT, DEDENT } + +@lexer::header { +import firrtl.LexerHelper; +} + +@lexer::members { + private final LexerHelper denter = new firrtl.LexerHelper() + { + @Override + public Token pullToken() { + return FIRRTLLexer.super.nextToken(); + } + }; + + @Override + public Token nextToken() { + return denter.nextToken(); + } +} + /*------------------------------------------------------------------ * PARSER RULES *------------------------------------------------------------------*/ @@ -37,16 +58,16 @@ grammar FIRRTL; // Does there have to be at least one module? circuit - : 'circuit' id ':' info? '{' module* '}' + : 'circuit' id ':' info? INDENT module* DEDENT ; module - : 'module' id ':' info? '{' port* block '}' - | 'extmodule' id ':' info? '{' port* '}' + : 'module' id ':' info? INDENT port* moduleBlock DEDENT + | 'extmodule' id ':' info? INDENT port* DEDENT ; port - : dir id ':' type info? + : dir id ':' type info? NEWLINE ; dir @@ -66,25 +87,26 @@ field : 'flip'? id ':' type ; -// Much faster than replacing block with stmt+ -block - : (stmt)* - ; +moduleBlock + : simple_stmt* + ; + +simple_reset0: 'reset' '=>' '(' exp exp ')'; + +simple_reset + : simple_reset0 + | '(' simple_reset0 ')' + ; + +reset_block + : INDENT simple_reset NEWLINE DEDENT + | '(' + simple_reset + ')' + ; stmt : 'wire' id ':' type info? - | 'reg' id ':' type exp ('with' ':' '{' 'reset' '=>' '(' exp exp ')' '}')? info? - | 'mem' id ':' info? '{' - ( 'data-type' '=>' type - | 'depth' '=>' IntLit - | 'read-latency' '=>' IntLit - | 'write-latency' '=>' IntLit - | 'read-under-write' '=>' ruw - | 'reader' '=>' id - | 'writer' '=>' id - | 'readwriter' '=>' id - )* - '}' + | 'reg' id ':' type exp ('with' ':' reset_block)? info? + | 'mem' id ':' info? INDENT memField* DEDENT | 'cmem' id ':' type info? | 'smem' id ':' type info? | mdir 'mport' id '=' id '[' exp ']' exp info? @@ -93,12 +115,43 @@ stmt | exp '<=' exp info? | exp '<-' exp info? | exp 'is' 'invalid' info? - | 'when' exp ':' info? '{' block '}' ( 'else' ':' '{' block '}' )? + | when | 'stop(' exp exp IntLit ')' info? - | 'printf(' exp exp StringLit (exp)* ')' info? + | 'printf(' exp exp StringLit ( exp)* ')' info? | 'skip' info? ; +memField + : 'data-type' '=>' type NEWLINE + | 'depth' '=>' IntLit NEWLINE + | 'read-latency' '=>' IntLit NEWLINE + | 'write-latency' '=>' IntLit NEWLINE + | 'read-under-write' '=>' ruw NEWLINE + | 'reader' '=>' id+ NEWLINE + | 'writer' '=>' id+ NEWLINE + | 'readwriter' '=>' id+ NEWLINE + ; + +simple_stmt + : stmt | NEWLINE + ; + +/* + We should provide syntatctical distinction between a "moduleBody" and a "suite": + - statements require a "suite" which means they can EITHER have a "simple statement" (one-liner) on the same line + OR a group of one or more _indented_ statements after a new-line. A "suite" may _not_ be empty + - modules on the other hand require a group of one or more statements without any indentation to follow "port" + definitions. Let's call that _the_ "moduleBody". A "moduleBody" could possibly be empty +*/ +suite + : simple_stmt + | INDENT simple_stmt+ DEDENT + ; + +when + : 'when' exp ':' info? suite? ('else' ( when | ':' info? suite?) )? + ; + info : FileInfo ; @@ -267,19 +320,18 @@ IdNondigit | [~!@#$%^*\-+=?/] ; -Comment - : ';' ~[\r\n]* - -> skip +fragment COMMENT + : ';' ~[\r\n]* ; -Whitespace - : [ \t,]+ - -> skip - ; - -Newline - : ( '\r'? '\n' )+ - -> skip - ; +fragment WHITESPACE + : [ \t,]+ + ; +SKIP_ + : ( WHITESPACE | COMMENT ) -> skip + ; +NEWLINE + :'\r'? '\n' ' '* + ; diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 041e674d..0c0cc36d 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -603,7 +603,7 @@ class VerilogEmitter extends Emitter { update(wmem_port,datax,clk,AND(tempWRef,wmode)) } } - case (s:Begin) => s map (build_streams) + case (s:Block) => s map (build_streams) } s } diff --git a/src/main/scala/firrtl/LexerHelper.scala b/src/main/scala/firrtl/LexerHelper.scala new file mode 100644 index 00000000..2b2f68a8 --- /dev/null +++ b/src/main/scala/firrtl/LexerHelper.scala @@ -0,0 +1,159 @@ +package firrtl + +import org.antlr.v4.runtime.{CommonToken, Token} + +import scala.annotation.tailrec +import scala.collection.mutable +import firrtl.antlr.FIRRTLParser + +/* + * ANTLR4 helper to handle indentation tokens in Lexer + * code adapted from: "https://github.com/yshavit/antlr-denter" (Yuval Shavit, MIT license) + */ + +abstract class LexerHelper { + + import FIRRTLParser.{NEWLINE, INDENT, DEDENT} + + private val tokenBuffer = mutable.Queue.empty[Token] + private val indentations = mutable.Stack[Int]() + private var reachedEof = false + + private def eofHandler(t: Token): Token = { + // when we reach EOF, unwind all indentations. If there aren't any, insert a NEWLINE. This lets the grammar treat + // un-indented expressions as just being NEWLINE-terminated, rather than NEWLINE|EOF. + val ret = + if (indentations.isEmpty) + createToken(NEWLINE, t) + else + unwindTo(0, t) + + tokenBuffer.enqueue(t) + reachedEof = true + + ret + } + + def nextToken(): Token = { + // first run + if (indentations.isEmpty) { + indentations.push(0) + + @tailrec + def findFirstRead(): Token = { + val t = pullToken() + if (t.getType != NEWLINE) t else findFirstRead() + } + + val firstRealToken = findFirstRead() + + if (firstRealToken.getCharPositionInLine > 0) { + indentations.push(firstRealToken.getCharPositionInLine) + tokenBuffer.enqueue(createToken(INDENT, firstRealToken)) + } + tokenBuffer.enqueue(firstRealToken) + } + + def handleNewlineToken(token: Token): Token = { + @tailrec + def nonNewline(token: Token) : (Token, Token) = { + val nextNext = pullToken() + if(nextNext.getType == NEWLINE) + nonNewline(nextNext) + else + (token, nextNext) + } + val (nxtToken, nextNext) = nonNewline(token) + + if (nextNext.getType == Token.EOF) + eofHandler(nextNext) + else { + val nlText = nxtToken.getText + val indent = + if (nlText.length > 0 && nlText.charAt(0) == '\r') + nlText.length - 2 + else + nlText.length - 1 + + val prevIndent = indentations.head + + val retToken = + if (indent == prevIndent) + nxtToken + else if (indent > prevIndent) { + indentations.push(indent) + createToken(INDENT, nxtToken) + } else { + unwindTo(indent, nxtToken) + } + + tokenBuffer.enqueue(nextNext) + retToken + } + } + + val t = if (tokenBuffer.isEmpty) + pullToken() + else + tokenBuffer.dequeue + + if (reachedEof) + t + else if (t.getType == NEWLINE) + handleNewlineToken(t) + else if (t.getType == Token.EOF) + eofHandler(t) + else + t + } + + // will be overriden to FIRRTLLexer.super.nextToken() in the g4 file + protected def pullToken(): Token + + private def createToken(tokenType: Int, copyFrom: Token): Token = + new CommonToken(copyFrom) { + setType(tokenType) + tokenType match { + case `NEWLINE` => setText("<NEWLINE>") + case `INDENT` => setText("<INDENT>") + case `DEDENT` => setText("<DEDENT>") + } + } + + /** + * Returns a DEDENT token, and also queues up additional DEDENTs as necessary. + * + * @param targetIndent the "size" of the indentation (number of spaces) by the end + * @param copyFrom the triggering token + * @return a DEDENT token + */ + private def unwindTo(targetIndent: Int, copyFrom: Token): Token = { + assert(tokenBuffer.isEmpty, tokenBuffer) + tokenBuffer.enqueue(createToken(NEWLINE, copyFrom)) + // To make things easier, we'll queue up ALL of the dedents, and then pop off the first one. + // For example, here's how some text is analyzed: + // + // Text : Indentation : Action : Indents Deque + // [ baseline ] : 0 : nothing : [0] + // [ foo ] : 2 : INDENT : [0, 2] + // [ bar ] : 3 : INDENT : [0, 2, 3] + // [ baz ] : 0 : DEDENT x2 : [0] + + @tailrec + def doPop(): Unit = { + val prevIndent = indentations.pop() + if (prevIndent < targetIndent) { + indentations.push(prevIndent) + tokenBuffer.enqueue(createToken(INDENT, copyFrom)) + } else if (prevIndent > targetIndent) { + tokenBuffer.enqueue(createToken(DEDENT, copyFrom)) + doPop() + } + } + + doPop() + + indentations.push(targetIndent) + tokenBuffer.dequeue + } +}
\ No newline at end of file diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala index c00ca855..5f073e0d 100644 --- a/src/main/scala/firrtl/Mappers.scala +++ b/src/main/scala/firrtl/Mappers.scala @@ -41,7 +41,7 @@ object Mappers { override def map(stmt: Statement): Statement = { stmt match { case s: Conditionally => Conditionally(s.info, s.pred, f(s.conseq), f(s.alt)) - case s: Begin => Begin(s.stmts.map(f)) + case s: Block => Block(s.stmts.map(f)) case s: Statement => s } } diff --git a/src/main/scala/firrtl/Parser.scala b/src/main/scala/firrtl/Parser.scala index dc8d6875..aa6ea63f 100644 --- a/src/main/scala/firrtl/Parser.scala +++ b/src/main/scala/firrtl/Parser.scala @@ -26,39 +26,53 @@ MODIFICATIONS. */ package firrtl -import org.antlr.v4.runtime._; -import org.antlr.v4.runtime.atn._; +import java.io.{ByteArrayInputStream, SequenceInputStream} + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn._ import com.typesafe.scalalogging.LazyLogging import firrtl.ir._ -import Utils.{time} -import antlr._ +import firrtl.Utils.time +import firrtl.antlr.{FIRRTLParser, _} class ParserException(message: String) extends Exception(message) + case class ParameterNotSpecifiedException(message: String) extends ParserException(message) + case class ParameterRedefinedException(message: String) extends ParserException(message) + case class InvalidStringLitException(message: String) extends ParserException(message) + case class InvalidEscapeCharException(message: String) extends ParserException(message) -object Parser extends LazyLogging -{ + +object Parser extends LazyLogging { /** Takes Iterator over lines of FIRRTL, returns FirrtlNode (root node is Circuit) */ def parse(lines: Iterator[String], infoMode: InfoMode = UseInfo): Circuit = { - val fixedInput = time("Translator") { Translator.addBrackets(lines) } - val antlrStream = new ANTLRInputStream(fixedInput.result) - val lexer = new FIRRTLLexer(antlrStream) - val tokens = new CommonTokenStream(lexer) - val parser = new FIRRTLParser(tokens) - time("ANTLR Parser") { parser.getInterpreter.setPredictionMode(PredictionMode.SLL) } + val parser = { + import scala.collection.JavaConverters._ + val inStream = new SequenceInputStream( + lines.map{s => new ByteArrayInputStream((s + "\n").getBytes("UTF-8")) }.asJavaEnumeration + ) + val lexer = new FIRRTLLexer(new ANTLRInputStream(inStream)) + new FIRRTLParser(new CommonTokenStream(lexer)) + } + + time("ANTLR Parser") { + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + } // Concrete Syntax Tree val cst = parser.circuit val numSyntaxErrors = parser.getNumberOfSyntaxErrors - if (numSyntaxErrors > 0) throw new ParserException(s"${numSyntaxErrors} syntax error(s) detected") + if (numSyntaxErrors > 0) throw new ParserException(s"$numSyntaxErrors syntax error(s) detected") val visitor = new Visitor(infoMode) - val ast = time("Visitor") { visitor.visit(cst) } match { + val ast = time("Visitor") { + visitor.visit(cst) + } match { case c: Circuit => c case x => throw new ClassCastException("Error! AST not rooted with Circuit node!") } @@ -69,8 +83,13 @@ object Parser extends LazyLogging def parse(lines: Seq[String]): Circuit = parse(lines.iterator) sealed abstract class InfoMode + case object IgnoreInfo extends InfoMode + case object UseInfo extends InfoMode + case class GenInfo(filename: String) extends InfoMode + case class AppendInfo(filename: String) extends InfoMode + } diff --git a/src/main/scala/firrtl/Serialize.scala b/src/main/scala/firrtl/Serialize.scala index 2c45c6ec..d28675b0 100644 --- a/src/main/scala/firrtl/Serialize.scala +++ b/src/main/scala/firrtl/Serialize.scala @@ -129,7 +129,7 @@ class Serialize { } } } - case b: Begin => { + case b: Block => { val s = new StringBuilder for (i <- 0 until b.stmts.size) { if (i != 0) s ++= newline ++ serialize(b.stmts(i)) diff --git a/src/main/scala/firrtl/Translator.scala b/src/main/scala/firrtl/Translator.scala deleted file mode 100644 index 4b0bd1e7..00000000 --- a/src/main/scala/firrtl/Translator.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* -Copyright (c) 2014 - 2016 The Regents of the University of -California (Regents). All Rights Reserved. Redistribution and use in -source and binary forms, with or without modification, are permitted -provided that the following conditions are met: - * Redistributions of source code must retain the above - copyright notice, this list of conditions and the following - two paragraphs of disclaimer. - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - two paragraphs of disclaimer in the documentation and/or other materials - provided with the distribution. - * Neither the name of the Regents nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. -IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, -SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, -ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF -REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF -ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION -TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR -MODIFICATIONS. -*/ - -/* TODO - * - Add better error messages for illformed FIRRTL - * - Add support for files that do not have a circuit (like a module by itself in a file) - * - Improve performance? Replace regex? - * - Wrap in Reader subclass. This would have less memory footprint than creating a large string - */ - -package firrtl - -import scala.io.Source -import scala.collection.mutable.Stack -import scala.collection.mutable.StringBuilder -import java.io._ - - -object Translator -{ - - def addBrackets(inputIt: Iterator[String]): StringBuilder = { - def countSpaces(s: String): Int = s.prefixLength(_ == ' ') - def stripComments(s: String): String = { - // Delete anything after first semicolon that's not in a comment - var done = false - var inComment = false - var escape = false - var i = 0 - while (!done && i < s.length) { - val c = s(i) - if (c == ';') { - if (!inComment) { - done = true - i = i - 1 // remove semicolon as well as what follows - } - } else { - if (c == '"' && !escape) inComment = !inComment - escape = if (c == '\\' && !escape) true else false - } - i += 1 - } - s.take(i) - } - def extractFileInfo(str: String): (String, String) = str span (_ != '@') - - val scopers = """(circuit|module|when|else|mem|with)""" - val MultiLineScope = ("""(.*""" + scopers + """)(.*:\s*)""").r - val OneLineScope = ("""(.*(with)\s*:\s*)\((.*)\)\s*""").r - - // Function start - val it = inputIt.zipWithIndex - var ret = new StringBuilder() - - if( !it.hasNext ) throw new Exception("Empty file!") - - // Find circuit before starting scope checks - var line = it.next - while ( it.hasNext && !stripComments(line._1).contains("circuit") ) { - ret ++= line._1 + "\n" - line = it.next - } - ret ++= line._1 + " { \n" - if( !it.hasNext ) throw new Exception("No circuit in file!") - - - val scope = Stack[Int]() - val lowestScope = countSpaces(line._1) - scope.push(lowestScope) - var newScope = true // indicates if increasing scope spacing is legal on next line - - while( it.hasNext ) { - it.next match { case (lineText, lineNum) => - val text = stripComments(lineText) - val (code, fileInfo) = extractFileInfo(text) - val spaces = countSpaces(text) - - val l = if (text.length > spaces ) { // Check that line has text in it - if (newScope) { - if( spaces <= scope.top ) scope.push(spaces+2) // Hack for one-line scopes - else scope.push(spaces) - } - - // Check if change in current scope - if( spaces < scope.top ) { - while( spaces < scope.top ) { - // Close scopes (adding brackets as we go) - scope.pop() - ret.deleteCharAt(ret.lastIndexOf("\n")) // Put on previous line - ret ++= " }\n" - } - if( spaces != scope.top ) - throw new Exception("Spacing does not match scope on line : " + lineNum + " : " + scope.top) - } - else if( spaces > scope.top ) - throw new Exception("Invalid increase in scope on line " + lineNum) - - // Now match on legal scope increasers - code match { - case OneLineScope(head, keyword, body) => { - newScope = false - head + "{" + body + "} " + fileInfo - } - case MultiLineScope(head, keyword, tail) => { - newScope = true - text + " { " - } - case _ => { - newScope = false - text - } - } - } // if( text.length > spaces ) - else { - text // empty lines - } - - ret ++= l + "\n" - } // it.next match - } // while( it.hasNext ) - - // Print any closing braces - while( scope.top > lowestScope ) { - scope.pop() - ret.deleteCharAt(ret.lastIndexOf("\n")) // Put on previous line - ret ++= " }\n" - } - - ret - } - - def main(args: Array[String]) { - - try { - val translation = addBrackets(Source.fromFile(args(0)).getLines) - - val writer = new PrintWriter(new File(args(1))) - writer.write(translation.result) - writer.close() - } catch { - case e: Exception => { - throw new Exception("USAGE: Translator <input file> <output file>\n" + e) - } - } - } - -} diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index a5253e84..2053a70d 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -458,7 +458,7 @@ object Utils extends LazyLogging { case s:IsInvalid => s.info case s:Stop => s.info case s:Print => s.info - case s:Begin => NoInfo + case s:Block => NoInfo case EmptyStmt => NoInfo } } @@ -484,7 +484,7 @@ object Utils extends LazyLogging { case s:DefNode => MALE case s:DefInstance => MALE case s:DefMemory => MALE - case s:Begin => UNKNOWNGENDER + case s:Block => UNKNOWNGENDER case s:Connect => UNKNOWNGENDER case s:PartialConnect => UNKNOWNGENDER case s:Stop => UNKNOWNGENDER @@ -630,7 +630,7 @@ object Utils extends LazyLogging { case (None, Some(decl)) => Some(decl) case (None, None) => None } - case begin: Begin => + case begin: Block => val stmts = begin.stmts flatMap getRootDecl(name) // can we short circuit? if (stmts.nonEmpty) Some(stmts.head) else None case _ => None diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index 91f9a0ce..05202555 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -32,55 +32,54 @@ MODIFICATIONS. package firrtl -import org.antlr.v4.runtime.tree.AbstractParseTreeVisitor; import org.antlr.v4.runtime.ParserRuleContext -import org.antlr.v4.runtime.tree.ParseTree -import org.antlr.v4.runtime.tree.ErrorNode import org.antlr.v4.runtime.tree.TerminalNode import scala.collection.JavaConversions._ -import antlr._ +import scala.collection.mutable +import firrtl.antlr._ import PrimOps._ import FIRRTLParser._ -import Parser.{InfoMode, IgnoreInfo, UseInfo, GenInfo, AppendInfo} +import Parser.{AppendInfo, GenInfo, IgnoreInfo, InfoMode, UseInfo} import firrtl.ir._ -import scala.annotation.tailrec -class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] -{ + +class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] { // Strip file path - private def stripPath(filename: String) = filename.drop(filename.lastIndexOf("/")+1) + private def stripPath(filename: String) = filename.drop(filename.lastIndexOf("/") + 1) def visit[FirrtlNode](ctx: FIRRTLParser.CircuitContext): Circuit = visitCircuit(ctx) // These regex have to change if grammar changes private def string2BigInt(s: String): BigInt = { // private define legal patterns - val HexPattern = """\"*h([a-zA-Z0-9]+)\"*""".r + val HexPattern = + """\"*h([a-zA-Z0-9]+)\"*""".r val DecPattern = """(\+|-)?([1-9]\d*)""".r val ZeroPattern = "0".r val NegPattern = "(89AaBbCcDdEeFf)".r s match { case ZeroPattern(_*) => BigInt(0) - case HexPattern(hexdigits) => - hexdigits(0) match { - case NegPattern(_) =>{ - BigInt("-" + hexdigits,16) - } - case _ => BigInt(hexdigits, 16) - } + case HexPattern(hexdigits) => + hexdigits(0) match { + case NegPattern(_) => { + BigInt("-" + hexdigits, 16) + } + case _ => BigInt(hexdigits, 16) + } case DecPattern(sign, num) => { - if (sign != null) BigInt(sign + num,10) - else BigInt(num,10) + if (sign != null) BigInt(sign + num, 10) + else BigInt(num, 10) } - case _ => throw new Exception("Invalid String for conversion to BigInt " + s) + case _ => throw new Exception("Invalid String for conversion to BigInt " + s) } } + private def string2Int(s: String): Int = string2BigInt(s).toInt private def visitInfo(ctx: Option[FIRRTLParser.InfoContext], parentCtx: ParserRuleContext): Info = { def genInfo(filename: String): String = stripPath(filename) + "@" + parentCtx.getStart.getLine + "." + - parentCtx.getStart.getCharPositionInLine + parentCtx.getStart.getCharPositionInLine lazy val useInfo: String = ctx match { case Some(info) => info.getText.drop(2).init // remove surrounding @[ ... ] case None => "" @@ -98,25 +97,30 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] } } - private def visitCircuit[FirrtlNode](ctx: FIRRTLParser.CircuitContext): Circuit = - Circuit(visitInfo(Option(ctx.info), ctx), ctx.module.map(visitModule), (ctx.id.getText)) - + private def visitCircuit[FirrtlNode](ctx: FIRRTLParser.CircuitContext): Circuit = + Circuit(visitInfo(Option(ctx.info), ctx), ctx.module.map(visitModule), ctx.id.getText) + private def visitModule[FirrtlNode](ctx: FIRRTLParser.ModuleContext): DefModule = { val info = visitInfo(Option(ctx.info), ctx) ctx.getChild(0).getText match { - case "module" => Module(info, ctx.id.getText, ctx.port.map(visitPort), visitBlock(ctx.block)) + case "module" => Module(info, ctx.id.getText, ctx.port.map(visitPort), + if (ctx.moduleBlock() != null) + visitBlock(ctx.moduleBlock()) + else EmptyStmt) case "extmodule" => ExtModule(info, ctx.id.getText, ctx.port.map(visitPort)) } } private def visitPort[FirrtlNode](ctx: FIRRTLParser.PortContext): Port = { - Port(visitInfo(Option(ctx.info), ctx), (ctx.id.getText), visitDir(ctx.dir), visitType(ctx.`type`)) + Port(visitInfo(Option(ctx.info), ctx), ctx.id.getText, visitDir(ctx.dir), visitType(ctx.`type`)) } + private def visitDir[FirrtlNode](ctx: FIRRTLParser.DirContext): Direction = ctx.getText match { case "input" => Input case "output" => Output } + private def visitMdir[FirrtlNode](ctx: FIRRTLParser.MdirContext): MPortDir = ctx.getText match { case "infer" => MInfer @@ -128,64 +132,85 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] // Match on a type instead of on strings? private def visitType[FirrtlNode](ctx: FIRRTLParser.TypeContext): Type = { ctx.getChild(0) match { - case term: TerminalNode => + case term: TerminalNode => term.getText match { - case "UInt" => if (ctx.getChildCount > 1) UIntType(IntWidth(string2BigInt(ctx.IntLit.getText))) - else UIntType( UnknownWidth ) - case "SInt" => if (ctx.getChildCount > 1) SIntType(IntWidth(string2BigInt(ctx.IntLit.getText))) - else SIntType( UnknownWidth ) + case "UInt" => if (ctx.getChildCount > 1) UIntType(IntWidth(string2BigInt(ctx.IntLit.getText))) + else UIntType(UnknownWidth) + case "SInt" => if (ctx.getChildCount > 1) SIntType(IntWidth(string2BigInt(ctx.IntLit.getText))) + else SIntType(UnknownWidth) case "Clock" => ClockType case "{" => BundleType(ctx.field.map(visitField)) } - case tpe: TypeContext => new VectorType(visitType(ctx.`type`), string2Int(ctx.IntLit.getText)) + case typeContext: TypeContext => new VectorType(visitType(ctx.`type`), string2Int(ctx.IntLit.getText)) } } - - private def visitField[FirrtlNode](ctx: FIRRTLParser.FieldContext): Field = { - val flip = if(ctx.getChild(0).getText == "flip") Flip else Default - Field((ctx.id.getText), flip, visitType(ctx.`type`)) + + private def visitField[FirrtlNode](ctx: FIRRTLParser.FieldContext): Field = { + val flip = if (ctx.getChild(0).getText == "flip") Flip else Default + Field(ctx.id.getText, flip, visitType(ctx.`type`)) } - - // visitBlock - private def visitBlock[FirrtlNode](ctx: FIRRTLParser.BlockContext): Statement = - Begin(ctx.stmt.map(visitStmt)) + private def visitBlock[FirrtlNode](ctx: FIRRTLParser.ModuleBlockContext): Statement = + Block(ctx.simple_stmt().map(_.stmt).filter(_ != null).map(visitStmt)) + + private def visitSuite[FirrtlNode](ctx: FIRRTLParser.SuiteContext): Statement = + Block(ctx.simple_stmt().map(_.stmt).filter(_ != null).map(visitStmt)) + // Memories are fairly complicated to translate thus have a dedicated method private def visitMem[FirrtlNode](ctx: FIRRTLParser.StmtContext): Statement = { - def parseChildren(children: Seq[ParseTree], map: Map[String, Seq[ParseTree]]): Map[String, Seq[ParseTree]] = { - val field = children(0).getText - if (field == "}") map - else { - val newMap = - if (field == "reader" || field == "writer" || field == "readwriter") { - val seq = map getOrElse (field, Seq()) - map + (field -> (seq :+ children(2))) - } else { // data-type, depth, read-latency, write-latency, read-under-write - if (map.contains(field)) throw new ParameterRedefinedException(s"Redefinition of ${field}") - else map + (field -> Seq(children(2))) - } - parseChildren(children.drop(3), newMap) // We consume tokens in groups of three (eg. 'depth' '=>' 5) + val readers = mutable.ArrayBuffer.empty[String] + val writers = mutable.ArrayBuffer.empty[String] + val readwriters = mutable.ArrayBuffer.empty[String] + case class ParamValue(typ: Option[Type] = None, lit: Option[Int] = None, ruw: Option[String] = None, unique: Boolean = true) + val fieldMap = mutable.HashMap[String, ParamValue]() + + def parseMemFields(memFields: Seq[MemFieldContext]): Unit = + memFields.foreach { field => + val fieldName = field.children(0).getText + + fieldName match { + case "reader" => readers ++= field.id().map(_.getText) + case "writer" => writers ++= field.id().map(_.getText) + case "readwriter" => readwriters ++= field.id().map(_.getText) + case _ => + val paramDef = fieldName match { + case "data-type" => ParamValue(typ = Some(visitType(field.`type`()))) + case "read-under-write" => ParamValue(ruw = Some(field.ruw().getText)) // TODO + case _ => ParamValue(lit = Some(field.IntLit().getText.toInt)) + } + if (fieldMap.contains(fieldName)) + throw new ParameterRedefinedException(s"Redefinition of $fieldName in FIRRTL line:${field.start.getLine}") + else + fieldMap += fieldName -> paramDef + } } - } val info = visitInfo(Option(ctx.info), ctx) + // Build map of different Memory fields to their values - val map = try { - parseChildren(ctx.children.drop(4), Map[String, Seq[ParseTree]]()) // First 4 tokens are 'mem' id ':' '{', skip to fields - } catch { // attach line number - case e: ParameterRedefinedException => throw new ParameterRedefinedException(s"[${info}] ${e.message}") + try { + parseMemFields(ctx.memField()) + } catch { + // attach line number + case e: ParameterRedefinedException => throw new ParameterRedefinedException(s"[$info] ${e.message}") } + // Check for required fields Seq("data-type", "depth", "read-latency", "write-latency") foreach { field => - map.getOrElse(field, throw new ParameterNotSpecifiedException(s"[${info}] Required mem field ${field} not found")) + fieldMap.getOrElse(field, throw new ParameterNotSpecifiedException(s"[$info] Required mem field $field not found")) } - // Each memory field value has been left as ParseTree type, need to convert - // TODO Improve? Remove dynamic typecast of data-type - DefMemory(info, (ctx.id(0).getText), visitType(map("data-type").head.asInstanceOf[FIRRTLParser.TypeContext]), - string2Int(map("depth").head.getText), string2Int(map("write-latency").head.getText), - string2Int(map("read-latency").head.getText), map.getOrElse("reader", Seq()).map(x => (x.getText)), - map.getOrElse("writer", Seq()).map(x => (x.getText)), map.getOrElse("readwriter", Seq()).map(x => (x.getText))) + + def lit(param: String) = fieldMap(param).lit.get + val ruw = fieldMap.get("read-under-write").map(_.ruw).getOrElse(None) + + DefMemory(info, + name = ctx.id(0).getText, dataType = fieldMap("data-type").typ.get, + depth = lit("depth"), + writeLatency = lit("write-latency"), readLatency = lit("read-latency"), + readers = readers, writers = writers, readwriters = readwriters, + readUnderWrite = ruw + ) } // visitStringLit @@ -194,106 +219,122 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] FIRRTLStringLitHandler.unescape(raw) } + private def visitWhen[FirrtlNode](ctx: WhenContext): Conditionally = { + val info = visitInfo(Option(ctx.info(0)), ctx) + + val alt: Statement = + if (ctx.when() != null) + visitWhen(ctx.when()) + else if (ctx.suite().length > 1) + visitSuite(ctx.suite(1)) + else + EmptyStmt + + Conditionally(info, visitExp(ctx.exp()), visitSuite(ctx.suite(0)), alt) + } + // visitStmt private def visitStmt[FirrtlNode](ctx: FIRRTLParser.StmtContext): Statement = { val info = visitInfo(Option(ctx.info), ctx) ctx.getChild(0) match { + case when: WhenContext => visitWhen(when) case term: TerminalNode => term.getText match { - case "wire" => DefWire(info, (ctx.id(0).getText), visitType(ctx.`type`(0))) - case "reg" => { - val name = (ctx.id(0).getText) - val tpe = visitType(ctx.`type`(0)) - val reset = if (ctx.exp(1) != null) visitExp(ctx.exp(1)) else UIntLiteral(0, IntWidth(1)) - val init = if (ctx.exp(2) != null) visitExp(ctx.exp(2)) else Reference(name, tpe) + case "wire" => DefWire(info, ctx.id(0).getText, visitType(ctx.`type`())) + case "reg" => + val name = ctx.id(0).getText + val tpe = visitType(ctx.`type`()) + val (reset, init) = { + val rb = ctx.reset_block() + if (rb != null) { + val sr = rb.simple_reset(0).simple_reset0() + (visitExp(sr.exp(0)), visitExp(sr.exp(1))) + } + else + (UIntLiteral(0, IntWidth(1)), Reference(name, tpe)) + } DefRegister(info, name, tpe, visitExp(ctx.exp(0)), reset, init) - } case "mem" => visitMem(ctx) - case "cmem" => { - val t = visitType(ctx.`type`(0)) - t match { - case (t:VectorType) => CDefMemory(info,ctx.id(0).getText,t.tpe,t.size,false) - case _ => throw new ParserException(s"${info}: Must provide cmem with vector type") - } - } - case "smem" => { - val t = visitType(ctx.`type`(0)) - t match { - case (t:VectorType) => CDefMemory(info,ctx.id(0).getText,t.tpe,t.size,true) - case _ => throw new ParserException(s"${info}: Must provide cmem with vector type") - } - } - case "inst" => DefInstance(info, (ctx.id(0).getText), (ctx.id(1).getText)) - case "node" => DefNode(info, (ctx.id(0).getText), visitExp(ctx.exp(0))) - case "when" => { - val alt = if (ctx.block.length > 1) visitBlock(ctx.block(1)) else EmptyStmt - Conditionally(info, visitExp(ctx.exp(0)), visitBlock(ctx.block(0)), alt) - } - case "stop(" => Stop(info, string2Int(ctx.IntLit(0).getText), visitExp(ctx.exp(0)), visitExp(ctx.exp(1))) + case "cmem" => + val t = visitType(ctx.`type`()) + t match { + case (t: VectorType) => CDefMemory(info, ctx.id(0).getText, t.tpe, t.size, seq = false) + case _ => throw new ParserException(s"${ + info + }: Must provide cmem with vector type") + } + case "smem" => + val t = visitType(ctx.`type`()) + t match { + case (t: VectorType) => CDefMemory(info, ctx.id(0).getText, t.tpe, t.size, seq = true) + case _ => throw new ParserException(s"${ + info + }: Must provide cmem with vector type") + } + case "inst" => DefInstance(info, ctx.id(0).getText, ctx.id(1).getText) + case "node" => DefNode(info, ctx.id(0).getText, visitExp(ctx.exp(0))) + + case "stop(" => Stop(info, string2Int(ctx.IntLit().getText), visitExp(ctx.exp(0)), visitExp(ctx.exp(1))) case "printf(" => Print(info, visitStringLit(ctx.StringLit), ctx.exp.drop(2).map(visitExp), - visitExp(ctx.exp(0)), visitExp(ctx.exp(1))) + visitExp(ctx.exp(0)), visitExp(ctx.exp(1))) case "skip" => EmptyStmt } // If we don't match on the first child, try the next one - case _ => { + case _ => ctx.getChild(1).getText match { - case "<=" => Connect(info, visitExp(ctx.exp(0)), visitExp(ctx.exp(1)) ) - case "<-" => PartialConnect(info, visitExp(ctx.exp(0)), visitExp(ctx.exp(1)) ) + case "<=" => Connect(info, visitExp(ctx.exp(0)), visitExp(ctx.exp(1))) + case "<-" => PartialConnect(info, visitExp(ctx.exp(0)), visitExp(ctx.exp(1))) case "is" => IsInvalid(info, visitExp(ctx.exp(0))) - case "mport" => CDefMPort(info, ctx.id(0).getText, UnknownType,ctx.id(1).getText,Seq(visitExp(ctx.exp(0)),visitExp(ctx.exp(1))),visitMdir(ctx.mdir)) + case "mport" => CDefMPort(info, ctx.id(0).getText, UnknownType, ctx.id(1).getText, Seq(visitExp(ctx.exp(0)), visitExp(ctx.exp(1))), visitMdir(ctx.mdir)) } - } } } - - // add visitRuw ? - //T visitRuw(FIRRTLParser.RuwContext ctx); - //private def visitRuw[FirrtlNode](ctx: FIRRTLParser.RuwContext): - // TODO + // TODO // - Add mux // - Add validif - private def visitExp[FirrtlNode](ctx: FIRRTLParser.ExpContext): Expression = - if( ctx.getChildCount == 1 ) - Reference((ctx.getText), UnknownType) + private def visitExp[FirrtlNode](ctx: FIRRTLParser.ExpContext): Expression = + if (ctx.getChildCount == 1) + Reference(ctx.getText, UnknownType) else ctx.getChild(0).getText match { - case "UInt" => { // This could be better - val (width, value) = - if (ctx.getChildCount > 4) + case "UInt" => { + // This could be better + val (width, value) = + if (ctx.getChildCount > 4) (IntWidth(string2BigInt(ctx.IntLit(0).getText)), string2BigInt(ctx.IntLit(1).getText)) else { - val bigint = string2BigInt(ctx.IntLit(0).getText) - (IntWidth(BigInt(scala.math.max(bigint.bitLength,1))),bigint) + val bigint = string2BigInt(ctx.IntLit(0).getText) + (IntWidth(BigInt(scala.math.max(bigint.bitLength, 1))), bigint) } UIntLiteral(value, width) } case "SInt" => { - val (width, value) = - if (ctx.getChildCount > 4) + val (width, value) = + if (ctx.getChildCount > 4) (IntWidth(string2BigInt(ctx.IntLit(0).getText)), string2BigInt(ctx.IntLit(1).getText)) else { - val bigint = string2BigInt(ctx.IntLit(0).getText) - (IntWidth(BigInt(bigint.bitLength + 1)),bigint) + val bigint = string2BigInt(ctx.IntLit(0).getText) + (IntWidth(BigInt(bigint.bitLength + 1)), bigint) } SIntLiteral(value, width) } case "validif(" => ValidIf(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType) case "mux(" => Mux(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), visitExp(ctx.exp(2)), UnknownType) - case _ => + case _ => ctx.getChild(1).getText match { - case "." => new SubField(visitExp(ctx.exp(0)), (ctx.id.getText), UnknownType) - case "[" => if (ctx.exp(1) == null) - new SubIndex(visitExp(ctx.exp(0)), string2Int(ctx.IntLit(0).getText), UnknownType) - else new SubAccess(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType) + case "." => new SubField(visitExp(ctx.exp(0)), ctx.id.getText, UnknownType) + case "[" => if (ctx.exp(1) == null) + new SubIndex(visitExp(ctx.exp(0)), string2Int(ctx.IntLit(0).getText), UnknownType) + else new SubAccess(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType) // Assume primop case _ => DoPrim(visitPrimop(ctx.primop), ctx.exp.map(visitExp), - ctx.IntLit.map(x => string2BigInt(x.getText)), UnknownType) + ctx.IntLit.map(x => string2BigInt(x.getText)), UnknownType) } } - - // stripSuffix("(") is included because in ANTLR concrete syntax we have to include open parentheses, + + // stripSuffix("(") is included because in ANTLR concrete syntax we have to include open parentheses, // see grammar file for more details - private def visitPrimop[FirrtlNode](ctx: FIRRTLParser.PrimopContext): PrimOp = fromString(ctx.getText.stripSuffix("(")) + private def visitPrimop[FirrtlNode](ctx: FIRRTLParser.PrimopContext): PrimOp = fromString(ctx.getText.stripSuffix("(")) // visit Id and Keyword? } diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index f25ab144..fc5e26b8 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -97,14 +97,16 @@ case class DefMemory( readLatency: Int, readers: Seq[String], writers: Seq[String], - readwriters: Seq[String]) extends Statement with IsDeclaration + readwriters: Seq[String], + // TODO: handle read-under-write + readUnderWrite: Option[String] = None) extends Statement with IsDeclaration case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration case class Conditionally( info: Info, pred: Expression, conseq: Statement, alt: Statement) extends Statement with HasInfo -case class Begin(stmts: Seq[Statement]) extends Statement +case class Block(stmts: Seq[Statement]) extends Statement case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index b9096e1e..94d509ed 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -619,7 +619,7 @@ object CheckGenders extends Pass { check_gender(s.info,genders,MALE)(s.en) check_gender(s.info,genders,MALE)(s.clk) } - case (_:Begin|_:IsInvalid) => false + case (_:Block | _:IsInvalid) => false } s } diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index b6e090f4..6df7664b 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -73,12 +73,12 @@ object ExpandWhens extends Pass { } private def squashEmpty(s: Statement): Statement = { s map squashEmpty match { - case Begin(stmts) => + case Block(stmts) => val newStmts = stmts filter (_ != EmptyStmt) newStmts.size match { case 0 => EmptyStmt case 1 => newStmts.head - case _ => Begin(newStmts) + case _ => Block(newStmts) } case s => s } @@ -157,7 +157,7 @@ object ExpandWhens extends Pass { memos += memoNode netlist(lvalue) = memoExpr } - Begin(Seq(conseqStmt, altStmt) ++ memos) + Block(Seq(conseqStmt, altStmt) ++ memos) case s: Print => if(weq(p, one)) { @@ -191,7 +191,7 @@ object ExpandWhens extends Pass { case m: ExtModule => m case m: Module => val (netlist, simlist, bodyx) = expandWhens(m) - val newBody = Begin(Seq(bodyx map squashEmpty) ++ expandNetlist(netlist) ++ simlist) + val newBody = Block(Seq(bodyx map squashEmpty) ++ expandNetlist(netlist) ++ simlist) Module(m.info, m.name, m.ports, newBody) } } diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 786de0eb..3801f8cb 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -145,7 +145,7 @@ object InlineInstances extends Transform { stmts += DefWire(p.info, rename(p.name), p.tpe) } stmts += renameStmt(instInModule.body) - Begin(stmts.toSeq) + Block(stmts.toSeq) } else s } case s => s map onExp map onStmt diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index b86b0651..7ab3333a 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -163,7 +163,7 @@ object LowerTypes extends Pass { } else { val exps = create_exps(s.name, s.tpe) val stmts = exps map (e => DefWire(s.info, loweredName(e), tpe(e))) - Begin(stmts) + Block(stmts) } case s: DefRegister => sinfo = s.info @@ -177,7 +177,7 @@ object LowerTypes extends Pass { val stmts = es zip inits map { case (e, i) => DefRegister(s.info, loweredName(e), tpe(e), clock, reset, i) } - Begin(stmts) + Block(stmts) } // Could instead just save the type of each Module as it gets processed case s: WDefInstance => @@ -206,7 +206,7 @@ object LowerTypes extends Pass { s.writeLatency, s.readLatency, s.readers, s.writers, s.readwriters) } - Begin(stmts) + Block(stmts) } // wire foo : { a , b } // node x = foo @@ -222,13 +222,13 @@ object LowerTypes extends Pass { val stmts = names zip exps map { case (n, e) => DefNode(s.info, loweredName(n), e) } - Begin(stmts) + Block(stmts) case s: IsInvalid => sinfo = s.info kind(s.expr) match { case k: MemKind => val exps = lowerTypesMemExp(s.expr) - Begin(exps map (exp => IsInvalid(s.info, exp))) + Block(exps map (exp => IsInvalid(s.info, exp))) case _ => s map (lowerTypesExp) } case s: Connect => @@ -237,7 +237,7 @@ object LowerTypes extends Pass { case k: MemKind => val exp = lowerTypesExp(s.expr) val locs = lowerTypesMemExp(s.loc) - Begin(locs map (loc => Connect(s.info, loc, exp))) + Block(locs map (loc => Connect(s.info, loc, exp))) case _ => s map (lowerTypesExp) } case s => s map (lowerTypesExp) diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 6b88c514..44de3542 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -690,7 +690,7 @@ object ExpandConnects extends Pass { EmptyStmt } else if (invalids.length == 1) { invalids(0) - } else Begin(invalids) + } else Block(invalids) } case (s:Connect) => { val n = get_size(tpe(s.loc)) @@ -706,7 +706,7 @@ object ExpandConnects extends Pass { } connects += sx } - Begin(connects) + Block(connects) } case (s:PartialConnect) => { val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default) @@ -722,7 +722,7 @@ object ExpandConnects extends Pass { } connects += sx }} - Begin(connects) + Block(connects) } case (s) => s map (expand_s) } @@ -873,7 +873,7 @@ object RemoveAccesses extends Pass { case (s) => s map (remove_e) map (remove_s) } stmts += sx - if (stmts.size != 1) Begin(stmts) else stmts(0) + if (stmts.size != 1) Block(stmts) else stmts(0) } Module(m.info,m.name,m.ports,remove_s(m.body)) } @@ -1053,7 +1053,7 @@ object CInferTypes extends Pass { def infer_types (m:DefModule) : DefModule = { val types = LinkedHashMap[String,Type]() def infer_types_e (e:Expression) : Expression = { - (e map (infer_types_e)) match { + e map infer_types_e match { case (e:Reference) => Reference(e.name, types.getOrElse(e.name,UnknownType)) case (e:SubField) => SubField(e.expr,e.name,field_type(tpe(e.expr),e.name)) case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(tpe(e.expr))) @@ -1065,10 +1065,10 @@ object CInferTypes extends Pass { } } def infer_types_s (s:Statement) : Statement = { - (s) match { + s match { case (s:DefRegister) => { types(s.name) = s.tpe - s map (infer_types_e) + s map infer_types_e s } case (s:DefWire) => { @@ -1076,7 +1076,7 @@ object CInferTypes extends Pass { s } case (s:DefNode) => { - val sx = s map (infer_types_e) + val sx = s map infer_types_e val t = get_type(sx) types(s.name) = t sx @@ -1098,13 +1098,13 @@ object CInferTypes extends Pass { types(s.name) = module_types.getOrElse(s.module,UnknownType) s } - case (s) => s map(infer_types_s) map (infer_types_e) + case (s) => s map infer_types_s map infer_types_e } } for (p <- m.ports) { types(p.name) = p.tpe } - (m) match { + m match { case (m:Module) => Module(m.info,m.name,m.ports,infer_types_s(m.body)) case (m:ExtModule) => m } @@ -1287,7 +1287,7 @@ object RemoveCHIRRTL extends Pass { set_write(rws,"data","mask") val read_l = if (s.seq) 1 else 0 val mem = DefMemory(s.info,s.name,s.tpe,s.size,1,read_l,rds.map(_.name),wrs.map(_.name),rws.map(_.name)) - Begin(Seq(mem,Begin(stmts))) + Block(Seq(mem,Block(stmts))) } case (s:CDefMPort) => { mport_types(s.name) = mport_types(s.mem) @@ -1327,7 +1327,7 @@ object RemoveCHIRRTL extends Pass { for (x <- ens ) { stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),one) } - Begin(stmts) + Block(stmts) } case (s) => s map (collect_refs) } @@ -1383,7 +1383,7 @@ object RemoveCHIRRTL extends Pass { stmts += Connect(s.info,wmode,one) } } - if (stmts.size > 1) Begin(stmts) + if (stmts.size > 1) Block(stmts) else stmts(0) } case (s:PartialConnect) => { @@ -1403,7 +1403,7 @@ object RemoveCHIRRTL extends Pass { stmts += Connect(s.info,wmode,one) } } - if (stmts.size > 1) Begin(stmts) + if (stmts.size > 1) Block(stmts) else stmts(0) } case (s) => s map (remove_chirrtl_s) map (remove_chirrtl_e(MALE)) diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 973e1be9..1c9674e1 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -45,11 +45,11 @@ object SplitExpressions extends Pass { } val x = s map onExp x match { - case x: Begin => x map onStmt + case x: Block => x map onStmt case EmptyStmt => x case x => { v += x - if (v.size > 1) Begin(v.toVector) + if (v.size > 1) Block(v.toVector) else v(0) } } diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index aa2c1d5d..b1a20fdd 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -244,7 +244,7 @@ object Uniquify extends Pass { } case s: DefNode => Seq(Field(s.name, Default, get_type(s))) case s: Conditionally => recStmtToType(s.conseq) ++ recStmtToType(s.alt) - case s: Begin => (s.stmts map (recStmtToType)).flatten + case s: Block => (s.stmts map (recStmtToType)).flatten case s => Seq() } BundleType(recStmtToType(s)) |
