From c8745fca352c79f886a1661d4985adc79e57c36d Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Mon, 19 Sep 2022 13:37:36 -0700 Subject: Add Serializer.lazily for buffered serialization (#2554) This is needed for emitting modules that serialize to Strings larger than 2 GiB (the maximum String size on the JVM). It includes micro-optimized logic for turning arbitrarily nested when scopes into Iterable[String].--- src/main/scala/firrtl/ir/Serializer.scala | 205 +++++++++++++++++++++++++----- 1 file changed, 174 insertions(+), 31 deletions(-) diff --git a/src/main/scala/firrtl/ir/Serializer.scala b/src/main/scala/firrtl/ir/Serializer.scala index 0666a4b1..21b08b65 100644 --- a/src/main/scala/firrtl/ir/Serializer.scala +++ b/src/main/scala/firrtl/ir/Serializer.scala @@ -33,7 +33,7 @@ object Serializer { case n: Info => s(n)(builder, indent) case n: StringLit => s(n)(builder, indent) case n: Expression => s(n)(builder, indent) - case n: Statement => s(n)(builder, indent) + case n: Statement => builder ++= lazily(n, indent).mkString case n: Width => s(n)(builder, indent) case n: Orientation => s(n)(builder, indent) case n: Field => s(n)(builder, indent) @@ -41,13 +41,36 @@ object Serializer { case n: Direction => s(n)(builder, indent) case n: Port => s(n)(builder, indent) case n: Param => s(n)(builder, indent) - case n: DefModule => s(n)(builder, indent) - case n: Circuit => s(n)(builder, indent) + case n: DefModule => builder ++= lazily(n, indent).mkString + case n: Circuit => builder ++= lazily(n, indent).mkString case other => builder ++= other.serialize // Handle user-defined nodes } builder.toString() } + /** Converts a `FirrtlNode` to an Iterable of Strings + * + * The Strings in the Iterable can be concatenated to give the String representation of the + * `FirrtlNode`. This is useful for buffered emission, especially for large Circuits that + * encroach on the JVM limit on String size (2 GiB). + */ + def lazily(node: FirrtlNode): Iterable[String] = lazily(node, 0) + + /** Converts a `FirrtlNode` to an Iterable of Strings + * + * The Strings in the Iterable can be concatenated to give the String representation of the + * `FirrtlNode`. This is useful for buffered emission, especially for large Circuits that + * encroach on the JVM limit on String size (2 GiB). + */ + def lazily(node: FirrtlNode, indent: Int): Iterable[String] = new Iterable[String] { + def iterator = node match { + case n: Statement => sIt(n)(indent) + case n: DefModule => sIt(n)(indent) + case n: Circuit => sIt(n)(indent) + case other => Iterator(serialize(other, indent)) + } + } + /** Converts a `Constraint` into its string representation. */ def serialize(con: Constraint): String = { val builder = new StringBuilder() @@ -101,24 +124,130 @@ object Serializer { case other => b ++= other.serialize // Handle user-defined nodes } + // Helper for some not-real Statements that only exist for Serialization + private abstract class PseudoStatement extends Statement { + def foreachExpr(f: Expression => Unit): Unit = ??? + def foreachInfo(f: Info => Unit): Unit = ??? + def foreachStmt(f: Statement => Unit): Unit = ??? + def foreachString(f: String => Unit): Unit = ??? + def foreachType(f: Type => Unit): Unit = ??? + def mapExpr(f: Expression => Expression): Statement = ??? + def mapInfo(f: Info => Info): Statement = ??? + def mapStmt(f: Statement => Statement): Statement = ??? + def mapString(f: String => String): Statement = ??? + def mapType(f: Type => Type): Statement = ??? + def serialize: String = ??? + } + + // To treat Statments as Iterable, we need to flatten out when scoping + private case class WhenBegin(info: Info, pred: Expression) extends PseudoStatement + private case object AltBegin extends PseudoStatement + private case object WhenEnd extends PseudoStatement + + // This does not extend Iterator[Statement] because + // 1. It is extended by StmtsSerializer which extends Iterator[String] + // 2. Flattening out whens introduces fake Statements needed for [un]indenting + private abstract class FlatStmtsIterator(stmts: Iterable[Statement]) { + private var underlying: Iterator[Statement] = stmts.iterator + + protected def hasNextStmt = underlying.hasNext + + protected def nextStmt(): Statement = { + var next: Statement = null + while (next == null && hasNextStmt) { + val head = underlying + head.next() match { + case b: Block => + val first = b.stmts.iterator + val last = underlying + underlying = first ++ last + case Conditionally(info, pred, conseq, alt) => + val begin = WhenBegin(info, pred) + val stmts = if (alt == EmptyStmt) { + Iterator(begin, conseq, WhenEnd) + } else { + Iterator(begin, conseq, AltBegin, alt, WhenEnd) + } + val last = underlying + underlying = stmts ++ last + case other => + next = other + } + } + next + } + } + + // Extend FlatStmtsIterator directly (rather than wrapping a FlatStmtsIterator object) to reduce + // the boxing overhead + private class StmtsSerializer(stmts: Iterable[Statement], initialIndent: Int) + extends FlatStmtsIterator(stmts) + with Iterator[String] { + + private def bufferSize = 2048 + + // We could initialze the StringBuilder size, but this is bad for small modules which may not + // even reach the bufferSize. + private implicit val b = new StringBuilder + + // The flattening of Whens into WhenBegin and friends requires us to keep track of the + // indention level + private implicit var indent: Int = initialIndent + + def hasNext: Boolean = this.hasNextStmt + + def next(): String = { + def consumeStmt(stmt: Statement): Unit = { + stmt match { + case wb: WhenBegin => + newLineAndIndent() + b ++= "when "; s(wb.pred); b ++= " :"; s(wb.info) + indent += 1 + case AltBegin => + indent -= 1 + newLineAndIndent() + b ++= "else :" + indent += 1 + case WhenEnd => + indent -= 1 + case other => + s(other) + } + if (this.hasNext) { + newLineAndIndent() + } + } + b.clear() + // There must always be at least 1 Statement because we're nonEmpty + var stmt: Statement = nextStmt() + while (stmt != null && b.size < bufferSize) { + consumeStmt(stmt) + stmt = nextStmt() + } + if (stmt != null) { + consumeStmt(stmt) + } + b.toString + } + } + + private def sIt(node: Statement)(implicit indent: Int): Iterator[String] = node match { + case b: Block => + if (b.stmts.isEmpty) Iterator("skip") + else new StmtsSerializer(b.stmts, indent) + case cond: Conditionally => new StmtsSerializer(Seq(cond), indent) + case other => + implicit val b = new StringBuilder + s(other) + Iterator(b.toString) + } + private def s(node: Statement)(implicit b: StringBuilder, indent: Int): Unit = node match { case DefNode(info, name, value) => b ++= "node "; b ++= name; b ++= " = "; s(value); s(info) case Connect(info, loc, expr) => s(loc); b ++= " <= "; s(expr); s(info) - case Conditionally(info, pred, conseq, alt) => - b ++= "when "; s(pred); b ++= " :"; s(info) - newLineAndIndent(1); s(conseq)(b, indent + 1) - if (alt != EmptyStmt) { - newLineAndIndent(); b ++= "else :" - newLineAndIndent(1); s(alt)(b, indent + 1) - } - case EmptyStmt => b ++= "skip" - case Block(Seq()) => b ++= "skip" - case Block(stmts) => - val it = stmts.iterator - while (it.hasNext) { - s(it.next()) - if (it.hasNext) newLineAndIndent() - } + case c: Conditionally => b ++= sIt(c).mkString + case EmptyStmt => b ++= "skip" + case bb: Block => b ++= sIt(bb).mkString case stop @ Stop(info, ret, clk, en) => b ++= "stop("; s(clk); b ++= ", "; s(en); b ++= ", "; b ++= ret.toString; b += ')' sStmtName(stop.name); s(info) @@ -247,29 +376,43 @@ object Serializer { case other => b ++= other.serialize // Handle user-defined nodes } - private def s(node: DefModule)(implicit b: StringBuilder, indent: Int): Unit = node match { + private def sIt(node: DefModule)(implicit indent: Int): Iterator[String] = node match { case Module(info, name, ports, body) => - doIndent(0); b ++= "module "; b ++= name; b ++= " :"; s(info) - ports.foreach { p => newLineAndIndent(1); s(p) } - newLineNoIndent() // add a new line between port declaration and body - newLineAndIndent(1); s(body)(b, indent + 1) + val start = { + implicit val b = new StringBuilder + doIndent(0); b ++= "module "; b ++= name; b ++= " :"; s(info) + ports.foreach { p => newLineAndIndent(1); s(p) } + newLineNoIndent() // add a blank line between port declaration and body + newLineAndIndent(1) // also indent before body + b.toString + } + Iterator(start) ++ sIt(body)(indent + 1) case ExtModule(info, name, ports, defname, params) => + implicit val b = new StringBuilder doIndent(0); b ++= "extmodule "; b ++= name; b ++= " :"; s(info) ports.foreach { p => newLineAndIndent(1); s(p) } newLineAndIndent(1); b ++= "defname = "; b ++= defname params.foreach { p => newLineAndIndent(1); s(p) } - case other => doIndent(0); b ++= other.serialize // Handle user-defined nodes + Iterator(b.toString) + case other => + Iterator(Indent * indent, other.serialize) // Handle user-defined nodes } - private def s(node: Circuit)(implicit b: StringBuilder, indent: Int): Unit = node match { + private def sIt(node: Circuit)(implicit indent: Int): Iterator[String] = node match { case Circuit(info, modules, main) => - b ++= s"FIRRTL version ${version.serialize}\n" - b ++= "circuit "; b ++= main; b ++= " :"; s(info) - if (modules.nonEmpty) { - newLineNoIndent(); s(modules.head)(b, indent + 1) - modules.drop(1).foreach { m => newLineNoIndent(); newLineNoIndent(); s(m)(b, indent + 1) } + val prelude = { + implicit val b = new StringBuilder // Scope this so we don't accidentally pass it anywhere + b ++= s"FIRRTL version ${version.serialize}\n" + b ++= "circuit "; b ++= main; b ++= " :"; s(info) + b.toString } - newLineNoIndent() + Iterator(prelude) ++ + modules.iterator.zipWithIndex.flatMap { + case (m, i) => + val newline = Iterator(if (i == 0) s"$NewLine" else s"${NewLine}${NewLine}") + newline ++ sIt(m)(indent + 1) + } ++ + Iterator(s"$NewLine") } // serialize constraints -- cgit v1.2.3