aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/ir/Serializer.scala13
-rw-r--r--src/test/scala/firrtlTests/SerializerSpec.scala61
2 files changed, 73 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/ir/Serializer.scala b/src/main/scala/firrtl/ir/Serializer.scala
index bf9a57c1..08ea1445 100644
--- a/src/main/scala/firrtl/ir/Serializer.scala
+++ b/src/main/scala/firrtl/ir/Serializer.scala
@@ -26,6 +26,7 @@ object Serializer {
case n: Param => s(n)(builder, indent)
case n: DefModule => s(n)(builder, indent)
case n: Circuit => s(n)(builder, indent)
+ case other => builder ++= other.serialize // Handle user-defined nodes
}
builder.toString()
}
@@ -54,6 +55,7 @@ object Serializer {
infos.zipWithIndex.foreach { case (f, i) => b ++= f.escaped; if (i < lastId) b += ' ' }
b += ']'
}
+ case other => b ++= other.serialize // Handle user-defined nodes
}
private def s(str: StringLit)(implicit b: StringBuilder, indent: Int): Unit = b ++= str.serialize
@@ -79,6 +81,7 @@ object Serializer {
case firrtl.WVoid => b ++= "VOID"
case firrtl.WInvalid => b ++= "INVALID"
case firrtl.EmptyExpression => b ++= "EMPTY"
+ case other => b ++= other.serialize // Handle user-defined nodes
}
private def s(node: Statement)(implicit b: StringBuilder, indent: Int): Unit = node match {
@@ -149,6 +152,7 @@ object Serializer {
case firrtl.WDefInstanceConnector(info, name, module, tpe, portCons) =>
b ++= "inst "; b ++= name; b ++= " of "; b ++= module; b ++= " with "; s(tpe); b ++= " connected to ("
s(portCons.map(_._2), ", "); b += ')'; s(info)
+ case other => b ++= other.serialize // Handle user-defined nodes
}
private def s(node: Width)(implicit b: StringBuilder, indent: Int): Unit = node match {
@@ -156,6 +160,7 @@ object Serializer {
case UnknownWidth => // empty string
case CalcWidth(arg) => b ++= "calcw("; s(arg); b += ')'
case VarWidth(name) => b += '<'; b ++= name; b += '>'
+ case other => b ++= other.serialize // Handle user-defined nodes
}
private def sPoint(node: Width)(implicit b: StringBuilder, indent: Int): Unit = node match {
@@ -163,11 +168,13 @@ object Serializer {
case UnknownWidth => // empty string
case CalcWidth(arg) => b ++= "calcw("; s(arg); b += ')'
case VarWidth(name) => b ++= "<<"; b ++= name; b ++= ">>"
+ case other => b ++= other.serialize // Handle user-defined nodes
}
private def s(node: Orientation)(implicit b: StringBuilder, indent: Int): Unit = node match {
case Default => // empty string
case Flip => b ++= "flip "
+ case other => b ++= other.serialize // Handle user-defined nodes
}
private def s(node: Field)(implicit b: StringBuilder, indent: Int): Unit = node match {
@@ -188,11 +195,13 @@ object Serializer {
case UnknownType => b += '?'
// the IntervalType has a complicated custom serialization method which does not recurse
case i: IntervalType => b ++= i.serialize
+ case other => b ++= other.serialize // Handle user-defined nodes
}
private def s(node: Direction)(implicit b: StringBuilder, indent: Int): Unit = node match {
case Input => b ++= "input"
case Output => b ++= "output"
+ case other => b ++= other.serialize // Handle user-defined nodes
}
private def s(node: Port)(implicit b: StringBuilder, indent: Int): Unit = node match {
@@ -207,6 +216,7 @@ object Serializer {
case RawStringParam(name, value) =>
b ++= "parameter "; b ++= name; b ++= " = "
b += '\''; b ++= value.replace("'", "\\'"); b += '\''
+ case other => b ++= other.serialize // Handle user-defined nodes
}
private def s(node: DefModule)(implicit b: StringBuilder, indent: Int): Unit = node match {
@@ -220,6 +230,7 @@ object Serializer {
ports.foreach { p => newLineAndIndent(1); s(p) }
newLineAndIndent(1); b ++= "defname = "; b ++= defname
params.foreach { p => newLineAndIndent(1); s(p) }
+ case other => b ++= other.serialize // Handle user-defined nodes
}
private def s(node: Circuit)(implicit b: StringBuilder, indent: Int): Unit = node match {
@@ -239,7 +250,7 @@ object Serializer {
case VarBound(name) => b ++= name
case Open(value) => b ++ "o("; b ++= value.toString; b += ')'
case Closed(value) => b ++ "c("; b ++= value.toString; b += ')'
- case other => other.serialize
+ case other => b ++= other.serialize // Handle user-defined nodes
}
/** create a new line with the appropriate indent */
diff --git a/src/test/scala/firrtlTests/SerializerSpec.scala b/src/test/scala/firrtlTests/SerializerSpec.scala
new file mode 100644
index 00000000..8892de4b
--- /dev/null
+++ b/src/test/scala/firrtlTests/SerializerSpec.scala
@@ -0,0 +1,61 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import org.scalatest._
+import firrtl.ir._
+import firrtl.Utils
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+
+object SerializerSpec {
+ case class WrapStmt(stmt: Statement) extends Statement {
+ def serialize: String = s"wrap(${stmt.serialize})"
+ def foreachExpr(f: Expression => Unit): Unit = stmt.foreachExpr(f)
+ def foreachInfo(f: Info => Unit): Unit = stmt.foreachInfo(f)
+ def foreachStmt(f: Statement => Unit): Unit = stmt.foreachStmt(f)
+ def foreachString(f: String => Unit): Unit = stmt.foreachString(f)
+ def foreachType(f: Type => Unit): Unit = stmt.foreachType(f)
+ def mapExpr(f: Expression => Expression): Statement = this.copy(stmt.mapExpr(f))
+ def mapInfo(f: Info => Info): Statement = this.copy(stmt.mapInfo(f))
+ def mapStmt(f: Statement => Statement): Statement = this.copy(stmt.mapStmt(f))
+ def mapString(f: String => String): Statement = this.copy(stmt.mapString(f))
+ def mapType(f: Type => Type): Statement = this.copy(stmt.mapType(f))
+ }
+
+ case class WrapExpr(expr: Expression) extends Expression {
+ def serialize: String = s"wrap(${expr.serialize})"
+ def tpe: Type = expr.tpe
+ def foreachExpr(f: Expression => Unit): Unit = expr.foreachExpr(f)
+ def foreachType(f: Type => Unit): Unit = expr.foreachType(f)
+ def foreachWidth(f: Width => Unit): Unit = expr.foreachWidth(f)
+ def mapExpr(f: Expression => Expression): Expression = this.copy(expr.mapExpr(f))
+ def mapType(f: Type => Type): Expression = this.copy(expr.mapType(f))
+ def mapWidth(f: Width => Width): Expression = this.copy(expr.mapWidth(f))
+ }
+}
+
+class SerializerSpec extends AnyFlatSpec with Matchers {
+ import SerializerSpec._
+
+ "ir.Serializer" should "support custom Statements" in {
+ val stmt = WrapStmt(DefWire(NoInfo, "myWire", Utils.BoolType))
+ val ser = "wrap(wire myWire : UInt<1>)"
+ Serializer.serialize(stmt) should be (ser)
+ }
+
+ it should "support custom Expression" in {
+ val expr = WrapExpr(Reference("foo"))
+ val ser = "wrap(foo)"
+ Serializer.serialize(expr) should be (ser)
+ }
+
+ it should "support nested custom Statements and Expressions" in {
+ val expr = SubField(WrapExpr(Reference("foo")), "bar")
+ val stmt = WrapStmt(DefNode(NoInfo, "n", expr))
+ val stmts = Block(stmt :: Nil)
+ val ser = "wrap(node n = wrap(foo).bar)"
+ Serializer.serialize(stmts) should be (ser)
+ }
+
+}