aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Biancolin2019-11-05 08:58:13 -0700
committerGitHub2019-11-05 08:58:13 -0700
commit543c7658faea82f6b7d4f3c7dacb58b17cbd02c9 (patch)
treeec0c3eba1789733087a2020fd17ea083957e5d34
parent0d7defc81b02c41e416237ad226adc5f1ab0f8f2 (diff)
parentcae20ae9ff51e7ebc2151b4f88853d3ac3859f65 (diff)
Merge pull request #1211 from freechipsproject/serialization-utils
Supply a trait to allow user annotations to provide JsonProtocol type hints
-rw-r--r--src/main/scala/firrtl/annotations/JsonProtocol.scala32
-rw-r--r--src/test/scala/firrtl/JsonProtocolSpec.scala62
2 files changed, 87 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala
index b09155d8..c3853650 100644
--- a/src/main/scala/firrtl/annotations/JsonProtocol.scala
+++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala
@@ -9,6 +9,12 @@ import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization
import org.json4s.native.Serialization.{read, writePretty}
+trait HasSerializationHints {
+ // For serialization of complicated constructor arguments, let the annotation
+ // writer specify additional type hints for relevant classes that might be
+ // contained within
+ def typeHints: Seq[Class[_]]
+}
object JsonProtocol {
class TransformClassSerializer extends CustomSerializer[Class[_ <: Transform]](format => (
@@ -74,7 +80,7 @@ object JsonProtocol {
))
/** Construct Json formatter for annotations */
- def jsonFormat(tags: Seq[Class[_ <: Annotation]]) = {
+ def jsonFormat(tags: Seq[Class[_]]) = {
Serialization.formats(FullTypeHints(tags.toList)).withTypeHintFieldName("class") +
new TransformClassSerializer + new NamedSerializer + new CircuitNameSerializer +
new ModuleNameSerializer + new ComponentNameSerializer + new TargetSerializer +
@@ -87,7 +93,11 @@ object JsonProtocol {
def serialize(annos: Seq[Annotation]): String = serializeTry(annos).get
def serializeTry(annos: Seq[Annotation]): Try[String] = {
- val tags = annos.map(_.getClass).distinct
+ val tags = annos.flatMap({
+ case anno: HasSerializationHints => anno.getClass +: anno.typeHints
+ case anno => Seq(anno.getClass)
+ }).distinct
+
implicit val formats = jsonFormat(tags)
Try(writePretty(annos))
}
@@ -101,12 +111,20 @@ object JsonProtocol {
case x => throw new InvalidAnnotationJSONException(
s"Annotations must be serialized as a JArray, got ${x.getClass.getSimpleName} instead!")
}
- // Gather classes so we can deserialize arbitrary Annotations
- val classes = annos.map({
- case JObject(("class", JString(c)) :: tail) => c
- case obj => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj")
+ // Recursively gather typeHints by pulling the "class" field from JObjects
+ // Json4s should emit this as the first field in all serialized classes
+ // Setting requireClassField mandates that all JObjects must provide a typeHint,
+ // this used on the first invocation to check all annotations do so
+ def findTypeHints(classInst: Seq[JValue], requireClassField: Boolean = false): Seq[String] = classInst.flatMap({
+ case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2))
+ case obj: JObject if requireClassField => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj")
+ case JObject(fields) => findTypeHints(fields.map(_._2))
+ case JArray(arr) => findTypeHints(arr)
+ case oJValue => Seq()
}).distinct
- val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_ <: Annotation]])
+
+ val classes = findTypeHints(annos, true)
+ val loaded = classes.map(Class.forName(_))
implicit val formats = jsonFormat(loaded)
read[List[Annotation]](in)
}).recoverWith {
diff --git a/src/test/scala/firrtl/JsonProtocolSpec.scala b/src/test/scala/firrtl/JsonProtocolSpec.scala
new file mode 100644
index 00000000..955abdc0
--- /dev/null
+++ b/src/test/scala/firrtl/JsonProtocolSpec.scala
@@ -0,0 +1,62 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import org.scalatest.FlatSpec
+import org.json4s._
+import org.json4s.native.JsonMethods._
+
+import firrtl.annotations.{NoTargetAnnotation, JsonProtocol, InvalidAnnotationJSONException, HasSerializationHints, Annotation}
+
+object JsonProtocolTestClasses {
+ trait Parent
+
+ case class ChildA(foo: Int) extends Parent
+ case class ChildB(bar: String) extends Parent
+ case class PolymorphicParameterAnnotation(param: Parent) extends NoTargetAnnotation
+ case class PolymorphicParameterAnnotationWithTypeHints(param: Parent) extends NoTargetAnnotation with HasSerializationHints {
+ def typeHints = Seq(param.getClass)
+ }
+
+ case class TypeParameterizedAnnotation[T](param: T) extends NoTargetAnnotation
+ case class TypeParameterizedAnnotationWithTypeHints[T](param: T) extends NoTargetAnnotation with HasSerializationHints {
+ def typeHints = Seq(param.getClass)
+ }
+}
+
+import JsonProtocolTestClasses._
+
+class JsonProtocolSpec extends FlatSpec {
+ def serializeAndDeserialize(anno: Annotation): Annotation = {
+ val serializedAnno = JsonProtocol.serialize(Seq(anno))
+ JsonProtocol.deserialize(serializedAnno).head
+ }
+
+ "Annotations with polymorphic parameters" should "not serialize and deserialize without type hints" in {
+ val anno = PolymorphicParameterAnnotation(ChildA(1))
+ assertThrows[InvalidAnnotationJSONException] {
+ serializeAndDeserialize(anno)
+ }
+ }
+
+ it should "serialize and deserialize with type hints" in {
+ val anno = PolymorphicParameterAnnotationWithTypeHints(ChildA(1))
+ val deserAnno = serializeAndDeserialize(anno)
+ assert(anno == deserAnno)
+
+ val anno2 = PolymorphicParameterAnnotationWithTypeHints(ChildB("Test"))
+ val deserAnno2 = serializeAndDeserialize(anno2)
+ assert(anno2 == deserAnno2)
+ }
+
+ "Annotations with non-primitive type parameters" should "not serialize and deserialize without type hints" in {
+ val anno = TypeParameterizedAnnotation(ChildA(1))
+ val deserAnno = serializeAndDeserialize(anno)
+ assert (anno != deserAnno)
+ }
+ it should "serialize and deserialize with type hints" in {
+ val anno = TypeParameterizedAnnotationWithTypeHints(ChildA(1))
+ val deserAnno = serializeAndDeserialize(anno)
+ assert (anno == deserAnno)
+ }
+}