diff options
| author | Jack Koenig | 2021-12-01 12:04:36 -0800 |
|---|---|---|
| committer | GitHub | 2021-12-01 12:04:36 -0800 |
| commit | b14ed79d416883eb858a191e29326ec08c040a2d (patch) | |
| tree | 37b93b1c487f78e39acd17f94faf5f5c3b24b9d8 | |
| parent | a4d13a5024f7488e1d2b9fdd27d3917157a67268 (diff) | |
| parent | 17250fba841ae3129dc798c0bc48d10200be18ae (diff) | |
Merge pull request #2343 from chipsalliance/improve-parser
Improve ANTLR Parser
| -rw-r--r-- | build.sbt | 5 | ||||
| -rw-r--r-- | build.sc | 4 | ||||
| -rwxr-xr-x | scripts/formal_equiv.sh | 2 | ||||
| -rw-r--r-- | src/main/antlr4/FIRRTL.g4 | 22 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Parser.scala | 18 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Visitor.scala | 72 | ||||
| -rw-r--r-- | src/main/scala/firrtl/parser/Listener.scala | 38 |
7 files changed, 110 insertions, 51 deletions
@@ -89,7 +89,7 @@ lazy val testAssemblySettings = Seq( lazy val antlrSettings = Seq( Antlr4 / antlr4GenVisitor := true, - Antlr4 / antlr4GenListener := false, + Antlr4 / antlr4GenListener := true, Antlr4 / antlr4PackageName := Option("firrtl.antlr"), Antlr4 / antlr4Version := "4.9.3", Antlr4 / javaSource := (Compile / sourceManaged).value @@ -130,6 +130,9 @@ lazy val docSettings = Seq( Compile / doc := (ScalaUnidoc / doc).value, autoAPIMappings := true, Compile / doc / scalacOptions ++= Seq( + // ANTLR-generated classes aren't really part of public API and cause + // errors in ScalaDoc generation + "-skip-packages", "firrtl.antlr", "-Xfatal-warnings", "-feature", "-diagrams", @@ -139,7 +139,7 @@ class firrtlCrossModule(val crossScalaVersion: String) antlrSource().path.toString, "-package", "firrtl.antlr", - "-no-listener", + "-listener", "-visitor", antlrSource().path.toString ).call() @@ -152,7 +152,7 @@ class firrtlCrossModule(val crossScalaVersion: String) antlrSource().path.toString, "-package", "firrtl.antlr", - "-no-listener", + "-listener", "-visitor", antlrSource().path.toString ).call() diff --git a/scripts/formal_equiv.sh b/scripts/formal_equiv.sh index c3d45357..deb4884d 100755 --- a/scripts/formal_equiv.sh +++ b/scripts/formal_equiv.sh @@ -28,7 +28,7 @@ make_verilog () { git checkout $1 local filename="$DUT.$1.v" - sbt "runMain firrtl.stage.FirrtlMain -i $DUT.fir -o $filename -X verilog" + sbt "clean; runMain firrtl.stage.FirrtlMain -i $DUT.fir -o $filename -X verilog" RET=$filename } diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index f5116485..d40c6560 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -99,9 +99,9 @@ stmt | mdir 'mport' id '=' id '[' exp ']' exp info? | 'inst' id 'of' id info? | 'node' id '=' exp info? - | exp '<=' exp info? - | exp '<-' exp info? - | exp 'is' 'invalid' info? + | ref '<=' exp info? + | ref '<-' exp info? + | ref 'is' 'invalid' info? | when | 'stop(' exp exp intLit ')' stmtName? info? | 'printf(' exp exp StringLit ( exp)* ')' stmtName? info? @@ -167,16 +167,22 @@ ruw exp : 'UInt' ('<' intLit '>')? '(' intLit ')' | 'SInt' ('<' intLit '>')? '(' intLit ')' - | id // Ref - | exp '.' fieldId - | exp '.' DoubleLit // TODO Workaround for #470 - | exp '[' intLit ']' - | exp '[' exp ']' + | ref | 'mux(' exp exp exp ')' | 'validif(' exp exp ')' | primop exp* intLit* ')' ; +ref + : id subref? + ; + +subref + : '.' fieldId subref? + | '.' DoubleLit subref? // TODO Workaround for #470 + | '[' (intLit | exp) ']' subref? + ; + id : Id | keywordAsId diff --git a/src/main/scala/firrtl/Parser.scala b/src/main/scala/firrtl/Parser.scala index 2d2bd350..bb93511c 100644 --- a/src/main/scala/firrtl/Parser.scala +++ b/src/main/scala/firrtl/Parser.scala @@ -6,6 +6,7 @@ import org.antlr.v4.runtime._ import org.antlr.v4.runtime.atn._ import logger.LazyLogging import firrtl.ir._ +import firrtl.parser.Listener import firrtl.Utils.time import firrtl.antlr.{FIRRTLParser, _} @@ -29,29 +30,24 @@ object Parser extends LazyLogging { /** Parses a org.antlr.v4.runtime.CharStream and returns a parsed [[firrtl.ir.Circuit Circuit]] */ def parseCharStream(charStream: CharStream, infoMode: InfoMode): Circuit = { - val (parseTimeMillis, cst) = time { + val (parseTimeMillis, ast) = time { val parser = { val lexer = new FIRRTLLexer(charStream) new FIRRTLParser(new CommonTokenStream(lexer)) } + val listener = new Listener(infoMode) + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + parser.addParseListener(listener) // Concrete Syntax Tree - val cst = parser.circuit + parser.circuit val numSyntaxErrors = parser.getNumberOfSyntaxErrors if (numSyntaxErrors > 0) throw new SyntaxErrorsException(s"$numSyntaxErrors syntax error(s) detected") - cst - } - val visitor = new Visitor(infoMode) - val (visitTimeMillis, visit) = time { - visitor.visit(cst) - } - val ast = visit match { - case c: Circuit => c - case x => throw new ClassCastException("Error! AST not rooted with Circuit node!") + listener.getCircuit } ast diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index ad2f5121..f3b6837e 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -6,6 +6,7 @@ import org.antlr.v4.runtime.ParserRuleContext import org.antlr.v4.runtime.tree.{AbstractParseTreeVisitor, ParseTreeVisitor, TerminalNode} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.annotation.tailrec import firrtl.antlr._ import PrimOps._ import FIRRTLParser._ @@ -57,7 +58,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w private def string2Int(s: String): Int = string2BigInt(s).toInt - private def visitInfo(ctx: Option[InfoContext], parentCtx: ParserRuleContext): Info = { + private[firrtl] def visitInfo(ctx: Option[InfoContext], parentCtx: ParserRuleContext): Info = { // Convert a compressed FileInfo string into either into a singular FileInfo or a MultiInfo // consisting of several FileInfos def parseCompressedInfo(escaped: String): Info = { @@ -129,7 +130,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w private def visitCircuit(ctx: CircuitContext): Circuit = Circuit(visitInfo(Option(ctx.info), ctx), ctx.module.asScala.map(visitModule).toSeq, ctx.id.getText) - private def visitModule(ctx: ModuleContext): DefModule = { + private[firrtl] def visitModule(ctx: ModuleContext): DefModule = { val info = visitInfo(Option(ctx.info), ctx) ctx.getChild(0).getText match { case "module" => @@ -441,9 +442,9 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w // If we don't match on the first child, try the next one case _ => ctx.getChild(1).getText match { - case "<=" => Connect(info, visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) - case "<-" => PartialConnect(info, visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) - case "is" => IsInvalid(info, visitExp(ctx_exp(0))) + case "<=" => Connect(info, visitRef(ctx.ref), visitExp(ctx_exp(0))) + case "<-" => PartialConnect(info, visitRef(ctx.ref), visitExp(ctx_exp(0))) + case "is" => IsInvalid(info, visitRef(ctx.ref)) case "mport" => CDefMPort( info, @@ -457,32 +458,47 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } } + @tailrec private def visitSubRef(ctx: SubrefContext, inner: Expression): Expression = { + val ref = ctx.getChild(0).getText match { + case "." => + if (ctx.fieldId != null) { + SubField(inner, ctx.fieldId.getText) + } else { + ctx.DoubleLit.getText.split('.') match { + case Array(a, b) if legalId(a) && legalId(b) => SubField(SubField(inner, a), b) + case _ => throw new ParserException(s"Illegal Expression at ${ctx.getText}") + } + } + case "[" => + if (ctx.intLit != null) { + val lit = string2Int(ctx.intLit.getText) + SubIndex(inner, lit, UnknownType) + } else { + val idx = visitExp(ctx.exp) + SubAccess(inner, idx, UnknownType) + } + } + if (ctx.subref != null) { + visitSubRef(ctx.subref, ref) + } else { + ref + } + } + + private def visitRef(ctx: RefContext): Expression = { + val ref = Reference(ctx.getChild(0).getText) + if (ctx.subref != null) { + visitSubRef(ctx.subref, ref) + } else { + ref + } + } + private def visitExp(ctx: ExpContext): Expression = { val ctx_exp = ctx.exp.asScala ctx.getChild(0) match { - case _: IdContext => Reference(ctx.getText, UnknownType) - case _: ExpContext => - ctx.getChild(1).getText match { - case "." => - val expr1 = visitExp(ctx_exp(0)) - // TODO Workaround for #470 - if (ctx.fieldId == null) { - ctx.DoubleLit.getText.split('.') match { - case Array(a, b) if legalId(a) && legalId(b) => - val inner = new SubField(expr1, a, UnknownType) - new SubField(inner, b, UnknownType) - case Array() => throw new ParserException(s"Illegal Expression at ${ctx.getText}") - } - } else { - new SubField(expr1, ctx.fieldId.getText, UnknownType) - } - case "[" => - if (ctx.exp(1) == null) - new SubIndex(visitExp(ctx_exp(0)), string2Int(ctx.intLit(0).getText), UnknownType) - else - new SubAccess(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType) - } - case _: PrimopContext => + case ref: RefContext => visitRef(ref) + case _: PrimopContext => DoPrim( visitPrimop(ctx.primop), ctx_exp.map(visitExp).toSeq, diff --git a/src/main/scala/firrtl/parser/Listener.scala b/src/main/scala/firrtl/parser/Listener.scala new file mode 100644 index 00000000..ffaa22b2 --- /dev/null +++ b/src/main/scala/firrtl/parser/Listener.scala @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.parser + +import firrtl.antlr.{FIRRTLParser, _} +import firrtl.Visitor +import firrtl.Parser.InfoMode +import firrtl.ir._ + +import scala.collection.mutable +import scala.concurrent.{Await, Future} +import scala.concurrent.duration.Duration + +private[firrtl] class Listener(infoMode: InfoMode) extends FIRRTLBaseListener { + private var main: Option[String] = None + private var info: Option[Info] = None + private val modules = mutable.ArrayBuffer.empty[DefModule] + + private val visitor = new Visitor(infoMode) + + override def exitModule(ctx: FIRRTLParser.ModuleContext): Unit = { + val m = visitor.visitModule(ctx) + ctx.children = null // Null out to save memory + modules += m + } + + override def exitCircuit(ctx: FIRRTLParser.CircuitContext): Unit = { + info = Some(visitor.visitInfo(Option(ctx.info), ctx)) + main = Some(ctx.id.getText) + ctx.children = null // Null out to save memory + } + + def getCircuit: Circuit = { + require(main.nonEmpty) + val mods = modules.toSeq + Circuit(info.get, mods, main.get) + } +} |
