diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/Visitor.scala | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/Visitor.scala')
| -rw-r--r-- | src/main/scala/firrtl/Visitor.scala | 299 |
1 files changed, 177 insertions, 122 deletions
diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index 502d021d..b14c39c7 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -13,7 +13,6 @@ import Parser.{AppendInfo, GenInfo, IgnoreInfo, InfoMode, UseInfo} import firrtl.ir._ import Utils.throwInternalError - class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] with ParseTreeVisitor[FirrtlNode] { // Strip file path private def stripPath(filename: String) = filename.drop(filename.lastIndexOf("/") + 1) @@ -21,7 +20,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w // Check if identifier is made of legal characters private def legalId(id: String) = { val legalChars = ('A' to 'Z').toSet ++ ('a' to 'z').toSet ++ ('0' to '9').toSet ++ Set('_', '$') - id forall legalChars + id.forall(legalChars) } def visit(ctx: CircuitContext): Circuit = visitCircuit(ctx) @@ -37,22 +36,22 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w private def string2BigInt(s: String): BigInt = { // private define legal patterns s match { - case ZeroPattern(_*) => BigInt(0) - case HexPattern(hexdigits) => BigInt(hexdigits, 16) - case OctalPattern(octaldigits) => BigInt(octaldigits, 8) + case ZeroPattern(_*) => BigInt(0) + case HexPattern(hexdigits) => BigInt(hexdigits, 16) + case OctalPattern(octaldigits) => BigInt(octaldigits, 8) case BinaryPattern(binarydigits) => BigInt(binarydigits, 2) - case DecPattern(num) => BigInt(num, 10) - case _ => throw new Exception("Invalid String for conversion to BigInt " + s) + case DecPattern(num) => BigInt(num, 10) + case _ => throw new Exception("Invalid String for conversion to BigInt " + s) } } private def string2BigDecimal(s: String): BigDecimal = { // private define legal patterns s match { - case ZeroPattern(_*) => BigDecimal(0) - case DecPattern(num) => BigDecimal(num) + case ZeroPattern(_*) => BigDecimal(0) + case DecPattern(num) => BigDecimal(num) case DecimalPattern(num) => BigDecimal(num) - case _ => throw new Exception("Invalid String for conversion to BigDecimal " + s) + case _ => throw new Exception("Invalid String for conversion to BigDecimal " + s) } } @@ -64,7 +63,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w parentCtx.getStart.getCharPositionInLine lazy val useInfo: String = ctx match { case Some(info) => info.getText.drop(2).init // remove surrounding @[ ... ] - case None => "" + case None => "" } infoMode match { case UseInfo => @@ -88,14 +87,19 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w private def visitModule(ctx: ModuleContext): DefModule = { val info = visitInfo(Option(ctx.info), ctx) ctx.getChild(0).getText match { - case "module" => Module(info, ctx.id.getText, ctx.port.asScala.map(visitPort).toSeq, - if (ctx.moduleBlock() != null) - visitBlock(ctx.moduleBlock()) - else EmptyStmt) + case "module" => + Module( + info, + ctx.id.getText, + ctx.port.asScala.map(visitPort).toSeq, + if (ctx.moduleBlock() != null) + visitBlock(ctx.moduleBlock()) + else EmptyStmt + ) case "extmodule" => val defname = if (ctx.defname != null) ctx.defname.id.getText else ctx.id.getText - val ports = ctx.port.asScala map visitPort - val params = ctx.parameter.asScala map visitParameter + val ports = ctx.port.asScala.map(visitPort) + val params = ctx.parameter.asScala.map(visitParameter) ExtModule(info, ctx.id.getText, ports.toSeq, defname, params.toSeq) } } @@ -111,22 +115,22 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w case (null, str, null, null) => StringParam(name, visitStringLit(str)) case (null, null, dbl, null) => DoubleParam(name, dbl.getText.toDouble) case (null, null, null, raw) => RawStringParam(name, raw.getText.tail.init.replace("\\'", "'")) // Remove "\'"s - case _ => throwInternalError(s"visiting impossible parameter ${ctx.getText}") + case _ => throwInternalError(s"visiting impossible parameter ${ctx.getText}") } } private def visitDir(ctx: DirContext): Direction = ctx.getText match { - case "input" => Input + case "input" => Input case "output" => Output } private def visitMdir(ctx: MdirContext): MPortDir = ctx.getText match { case "infer" => MInfer - case "read" => MRead + case "read" => MRead case "write" => MWrite - case "rdwr" => MReadWrite + case "rdwr" => MReadWrite } // Match on a type instead of on strings? @@ -135,47 +139,53 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w ctx.getChild(0) match { case term: TerminalNode => term.getText match { - case "UInt" => if (ctx.getChildCount > 1) UIntType(getWidth(ctx.intLit(0))) - else UIntType(UnknownWidth) - case "SInt" => if (ctx.getChildCount > 1) SIntType(getWidth(ctx.intLit(0))) - else SIntType(UnknownWidth) - case "Fixed" => ctx.intLit.size match { - case 0 => FixedType(UnknownWidth, UnknownWidth) - case 1 => ctx.getChild(2).getText match { - case "<" => FixedType(UnknownWidth, getWidth(ctx.intLit(0))) - case _ => FixedType(getWidth(ctx.intLit(0)), UnknownWidth) + case "UInt" => + if (ctx.getChildCount > 1) UIntType(getWidth(ctx.intLit(0))) + else UIntType(UnknownWidth) + case "SInt" => + if (ctx.getChildCount > 1) SIntType(getWidth(ctx.intLit(0))) + else SIntType(UnknownWidth) + case "Fixed" => + ctx.intLit.size match { + case 0 => FixedType(UnknownWidth, UnknownWidth) + case 1 => + ctx.getChild(2).getText match { + case "<" => FixedType(UnknownWidth, getWidth(ctx.intLit(0))) + case _ => FixedType(getWidth(ctx.intLit(0)), UnknownWidth) + } + case 2 => FixedType(getWidth(ctx.intLit(0)), getWidth(ctx.intLit(1))) } - case 2 => FixedType(getWidth(ctx.intLit(0)), getWidth(ctx.intLit(1))) - } - case "Interval" => ctx.boundValue.size match { - case 0 => - val point = ctx.intLit.size match { - case 0 => UnknownWidth - case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText)) - } - IntervalType(UnknownBound, UnknownBound, point) - case 2 => - val lower = (ctx.lowerBound.getText, ctx.boundValue(0).getText) match { - case (_, "?") => UnknownBound - case ("(", v) => Open(string2BigDecimal(v)) - case ("[", v) => Closed(string2BigDecimal(v)) - } - val upper = (ctx.upperBound.getText, ctx.boundValue(1).getText) match { - case (_, "?") => UnknownBound - case (")", v) => Open(string2BigDecimal(v)) - case ("]", v) => Closed(string2BigDecimal(v)) - } - val point = ctx.intLit.size match { - case 0 => UnknownWidth - case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText)) - } - IntervalType(lower, upper, point) - } - case "Clock" => ClockType + case "Interval" => + ctx.boundValue.size match { + case 0 => + val point = ctx.intLit.size match { + case 0 => UnknownWidth + case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText)) + } + IntervalType(UnknownBound, UnknownBound, point) + case 2 => + val lower = (ctx.lowerBound.getText, ctx.boundValue(0).getText) match { + case (_, "?") => UnknownBound + case ("(", v) => Open(string2BigDecimal(v)) + case ("[", v) => Closed(string2BigDecimal(v)) + } + val upper = (ctx.upperBound.getText, ctx.boundValue(1).getText) match { + case (_, "?") => UnknownBound + case (")", v) => Open(string2BigDecimal(v)) + case ("]", v) => Closed(string2BigDecimal(v)) + } + val point = ctx.intLit.size match { + case 0 => UnknownWidth + case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText)) + } + IntervalType(lower, upper, point) + } + case "Clock" => ClockType case "AsyncReset" => AsyncResetType - case "Reset" => ResetType - case "Analog" => if (ctx.getChildCount > 1) AnalogType(getWidth(ctx.intLit(0))) - else AnalogType(UnknownWidth) + case "Reset" => ResetType + case "Analog" => + if (ctx.getChildCount > 1) AnalogType(getWidth(ctx.intLit(0))) + else AnalogType(UnknownWidth) case "{" => BundleType(ctx.field.asScala.map(visitField).toSeq) } case typeContext: TypeContext => new VectorType(visitType(ctx.`type`), string2Int(ctx.intLit(0).getText)) @@ -208,11 +218,12 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w private def visitRuw(ctx: Option[RuwContext]): ReadUnderWrite.Value = ctx match { case None => ReadUnderWrite.Undefined - case Some(ctx) => ctx.getText match { - case "undefined" => ReadUnderWrite.Undefined - case "old" => ReadUnderWrite.Old - case "new" => ReadUnderWrite.New - } + case Some(ctx) => + ctx.getText match { + case "undefined" => ReadUnderWrite.Undefined + case "old" => ReadUnderWrite.Old + case "new" => ReadUnderWrite.New + } } // Memories are fairly complicated to translate thus have a dedicated method @@ -220,7 +231,11 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w val readers = mutable.ArrayBuffer.empty[String] val writers = mutable.ArrayBuffer.empty[String] val readwriters = mutable.ArrayBuffer.empty[String] - case class ParamValue(typ: Option[Type] = None, lit: Option[BigInt] = None, ruw: ReadUnderWrite.Value = ReadUnderWrite.Undefined, unique: Boolean = true) + case class ParamValue( + typ: Option[Type] = None, + lit: Option[BigInt] = None, + ruw: ReadUnderWrite.Value = ReadUnderWrite.Undefined, + unique: Boolean = true) val fieldMap = mutable.HashMap[String, ParamValue]() val memName = ctx.id(0).getText def parseMemFields(memFields: Seq[MemFieldContext]): Unit = @@ -228,14 +243,14 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w val fieldName = field.children.asScala(0).getText fieldName match { - case "reader" => readers ++= field.id().asScala.map(_.getText) - case "writer" => writers ++= field.id().asScala.map(_.getText) + case "reader" => readers ++= field.id().asScala.map(_.getText) + case "writer" => writers ++= field.id().asScala.map(_.getText) case "readwriter" => readwriters ++= field.id().asScala.map(_.getText) case _ => val paramDef = fieldName match { - case "data-type" => ParamValue(typ = Some(visitType(field.`type`()))) + case "data-type" => ParamValue(typ = Some(visitType(field.`type`()))) case "read-under-write" => ParamValue(ruw = visitRuw(Option(field.ruw))) - case _ => ParamValue(lit = Some(BigInt(field.intLit().getText))) + case _ => ParamValue(lit = Some(BigInt(field.intLit().getText))) } if (fieldMap.contains(fieldName)) throw new ParameterRedefinedException(s"Redefinition of $fieldName in FIRRTL line:${field.start.getLine}") @@ -255,20 +270,26 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } // Check for required fields - Seq("data-type", "depth", "read-latency", "write-latency") foreach { field => - fieldMap.getOrElse(field, throw new ParameterNotSpecifiedException(s"[$info] Required mem field $field not found")) + Seq("data-type", "depth", "read-latency", "write-latency").foreach { field => + fieldMap.getOrElse( + field, + throw new ParameterNotSpecifiedException(s"[$info] Required mem field $field not found") + ) } def lit(param: String) = fieldMap(param).lit.get val ruw = fieldMap.get("read-under-write").map(_.ruw).getOrElse(ir.ReadUnderWrite.Undefined) - DefMemory(info, + DefMemory( + info, name = memName, dataType = fieldMap("data-type").typ.get, depth = lit("depth"), writeLatency = lit("write-latency").toInt, readLatency = lit("read-latency").toInt, - readers = readers.toSeq, writers = writers.toSeq, readwriters = readwriters.toSeq, + readers = readers.toSeq, + writers = writers.toSeq, + readwriters = readwriters.toSeq, readUnderWrite = ruw ) } @@ -299,56 +320,88 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w val info = visitInfo(Option(ctx.info), ctx) ctx.getChild(0) match { case when: WhenContext => visitWhen(when) - case term: TerminalNode => term.getText match { - case "wire" => DefWire(info, ctx.id(0).getText, visitType(ctx.`type`())) - case "reg" => - val name = ctx.id(0).getText - val tpe = visitType(ctx.`type`()) - val (reset, init, rinfo) = { - val rb = ctx.reset_block() - if (rb != null) { - val sr = rb.simple_reset.simple_reset0() - val innerInfo = if (info == NoInfo) visitInfo(Option(rb.info), ctx) else info - (visitExp(sr.exp(0)), visitExp(sr.exp(1)), innerInfo) + case term: TerminalNode => + term.getText match { + case "wire" => DefWire(info, ctx.id(0).getText, visitType(ctx.`type`())) + case "reg" => + val name = ctx.id(0).getText + val tpe = visitType(ctx.`type`()) + val (reset, init, rinfo) = { + val rb = ctx.reset_block() + if (rb != null) { + val sr = rb.simple_reset.simple_reset0() + val innerInfo = if (info == NoInfo) visitInfo(Option(rb.info), ctx) else info + (visitExp(sr.exp(0)), visitExp(sr.exp(1)), innerInfo) + } else + (UIntLiteral(0, IntWidth(1)), Reference(name, tpe), info) } - else - (UIntLiteral(0, IntWidth(1)), Reference(name, tpe), info) - } - DefRegister(rinfo, name, tpe, visitExp(ctx_exp(0)), reset, init) - case "mem" => visitMem(ctx) - case "cmem" => - val (tpe, size) = visitCMemType(ctx.`type`()) - CDefMemory(info, ctx.id(0).getText, tpe, size, seq = false) - case "smem" => - val (tpe, size) = visitCMemType(ctx.`type`()) - CDefMemory(info, ctx.id(0).getText, tpe, size, seq = true, readUnderWrite = visitRuw(Option(ctx.ruw))) - case "inst" => DefInstance(info, ctx.id(0).getText, ctx.id(1).getText) - case "node" => DefNode(info, ctx.id(0).getText, visitExp(ctx_exp(0))) - - case "stop(" => Stop(info, string2Int(ctx.intLit().getText), visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) - case "attach" => Attach(info, ctx_exp.map(visitExp).toSeq) - case "printf(" => Print(info, visitStringLit(ctx.StringLit), ctx_exp.drop(2).map(visitExp).toSeq, - visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) - // formal - case "assert" => Verification(Formal.Assert, info, visitExp(ctx_exp(0)), - visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), - visitStringLit(ctx.StringLit)) - case "assume" => Verification(Formal.Assume, info, visitExp(ctx_exp(0)), - visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), - visitStringLit(ctx.StringLit)) - case "cover" => Verification(Formal.Cover, info, visitExp(ctx_exp(0)), - visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), - visitStringLit(ctx.StringLit)) - // end formal - case "skip" => EmptyStmt - } + DefRegister(rinfo, name, tpe, visitExp(ctx_exp(0)), reset, init) + case "mem" => visitMem(ctx) + case "cmem" => + val (tpe, size) = visitCMemType(ctx.`type`()) + CDefMemory(info, ctx.id(0).getText, tpe, size, seq = false) + case "smem" => + val (tpe, size) = visitCMemType(ctx.`type`()) + CDefMemory(info, ctx.id(0).getText, tpe, size, seq = true, readUnderWrite = visitRuw(Option(ctx.ruw))) + case "inst" => DefInstance(info, ctx.id(0).getText, ctx.id(1).getText) + case "node" => DefNode(info, ctx.id(0).getText, visitExp(ctx_exp(0))) + + case "stop(" => Stop(info, string2Int(ctx.intLit().getText), visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) + case "attach" => Attach(info, ctx_exp.map(visitExp).toSeq) + case "printf(" => + Print( + info, + visitStringLit(ctx.StringLit), + ctx_exp.drop(2).map(visitExp).toSeq, + visitExp(ctx_exp(0)), + visitExp(ctx_exp(1)) + ) + // formal + case "assert" => + Verification( + Formal.Assert, + info, + visitExp(ctx_exp(0)), + visitExp(ctx_exp(1)), + visitExp(ctx_exp(2)), + visitStringLit(ctx.StringLit) + ) + case "assume" => + Verification( + Formal.Assume, + info, + visitExp(ctx_exp(0)), + visitExp(ctx_exp(1)), + visitExp(ctx_exp(2)), + visitStringLit(ctx.StringLit) + ) + case "cover" => + Verification( + Formal.Cover, + info, + visitExp(ctx_exp(0)), + visitExp(ctx_exp(1)), + visitExp(ctx_exp(2)), + visitStringLit(ctx.StringLit) + ) + // end formal + case "skip" => EmptyStmt + } // If we don't match on the first child, try the next one case _ => ctx.getChild(1).getText match { case "<=" => Connect(info, visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) case "<-" => PartialConnect(info, visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) case "is" => IsInvalid(info, visitExp(ctx_exp(0))) - case "mport" => CDefMPort(info, ctx.id(0).getText, UnknownType, ctx.id(1).getText, Seq(visitExp(ctx_exp(0)), visitExp(ctx_exp(1))), visitMdir(ctx.mdir)) + case "mport" => + CDefMPort( + info, + ctx.id(0).getText, + UnknownType, + ctx.id(1).getText, + Seq(visitExp(ctx_exp(0)), visitExp(ctx_exp(1))), + visitMdir(ctx.mdir) + ) } } } @@ -379,10 +432,12 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w new SubAccess(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType) } case _: PrimopContext => - DoPrim(visitPrimop(ctx.primop), - ctx_exp.map(visitExp).toSeq, - ctx.intLit.asScala.map(x => string2BigInt(x.getText)).toSeq, - UnknownType) + DoPrim( + visitPrimop(ctx.primop), + ctx_exp.map(visitExp).toSeq, + ctx.intLit.asScala.map(x => string2BigInt(x.getText)).toSeq, + UnknownType + ) case _ => ctx.getChild(0).getText match { case "UInt" => @@ -405,7 +460,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w SIntLiteral(value) } case "validif(" => ValidIf(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType) - case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType) + case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType) } } } |
