diff options
| author | David Biancolin | 2019-11-05 08:58:13 -0700 |
|---|---|---|
| committer | GitHub | 2019-11-05 08:58:13 -0700 |
| commit | 543c7658faea82f6b7d4f3c7dacb58b17cbd02c9 (patch) | |
| tree | ec0c3eba1789733087a2020fd17ea083957e5d34 | |
| parent | 0d7defc81b02c41e416237ad226adc5f1ab0f8f2 (diff) | |
| parent | cae20ae9ff51e7ebc2151b4f88853d3ac3859f65 (diff) | |
Merge pull request #1211 from freechipsproject/serialization-utils
Supply a trait to allow user annotations to provide JsonProtocol type hints
| -rw-r--r-- | src/main/scala/firrtl/annotations/JsonProtocol.scala | 32 | ||||
| -rw-r--r-- | src/test/scala/firrtl/JsonProtocolSpec.scala | 62 |
2 files changed, 87 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index b09155d8..c3853650 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -9,6 +9,12 @@ import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization import org.json4s.native.Serialization.{read, writePretty} +trait HasSerializationHints { + // For serialization of complicated constructor arguments, let the annotation + // writer specify additional type hints for relevant classes that might be + // contained within + def typeHints: Seq[Class[_]] +} object JsonProtocol { class TransformClassSerializer extends CustomSerializer[Class[_ <: Transform]](format => ( @@ -74,7 +80,7 @@ object JsonProtocol { )) /** Construct Json formatter for annotations */ - def jsonFormat(tags: Seq[Class[_ <: Annotation]]) = { + def jsonFormat(tags: Seq[Class[_]]) = { Serialization.formats(FullTypeHints(tags.toList)).withTypeHintFieldName("class") + new TransformClassSerializer + new NamedSerializer + new CircuitNameSerializer + new ModuleNameSerializer + new ComponentNameSerializer + new TargetSerializer + @@ -87,7 +93,11 @@ object JsonProtocol { def serialize(annos: Seq[Annotation]): String = serializeTry(annos).get def serializeTry(annos: Seq[Annotation]): Try[String] = { - val tags = annos.map(_.getClass).distinct + val tags = annos.flatMap({ + case anno: HasSerializationHints => anno.getClass +: anno.typeHints + case anno => Seq(anno.getClass) + }).distinct + implicit val formats = jsonFormat(tags) Try(writePretty(annos)) } @@ -101,12 +111,20 @@ object JsonProtocol { case x => throw new InvalidAnnotationJSONException( s"Annotations must be serialized as a JArray, got ${x.getClass.getSimpleName} instead!") } - // Gather classes so we can deserialize arbitrary Annotations - val classes = annos.map({ - case JObject(("class", JString(c)) :: tail) => c - case obj => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") + // Recursively gather typeHints by pulling the "class" field from JObjects + // Json4s should emit this as the first field in all serialized classes + // Setting requireClassField mandates that all JObjects must provide a typeHint, + // this used on the first invocation to check all annotations do so + def findTypeHints(classInst: Seq[JValue], requireClassField: Boolean = false): Seq[String] = classInst.flatMap({ + case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2)) + case obj: JObject if requireClassField => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") + case JObject(fields) => findTypeHints(fields.map(_._2)) + case JArray(arr) => findTypeHints(arr) + case oJValue => Seq() }).distinct - val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_ <: Annotation]]) + + val classes = findTypeHints(annos, true) + val loaded = classes.map(Class.forName(_)) implicit val formats = jsonFormat(loaded) read[List[Annotation]](in) }).recoverWith { diff --git a/src/test/scala/firrtl/JsonProtocolSpec.scala b/src/test/scala/firrtl/JsonProtocolSpec.scala new file mode 100644 index 00000000..955abdc0 --- /dev/null +++ b/src/test/scala/firrtl/JsonProtocolSpec.scala @@ -0,0 +1,62 @@ +// See LICENSE for license details. + +package firrtlTests + +import org.scalatest.FlatSpec +import org.json4s._ +import org.json4s.native.JsonMethods._ + +import firrtl.annotations.{NoTargetAnnotation, JsonProtocol, InvalidAnnotationJSONException, HasSerializationHints, Annotation} + +object JsonProtocolTestClasses { + trait Parent + + case class ChildA(foo: Int) extends Parent + case class ChildB(bar: String) extends Parent + case class PolymorphicParameterAnnotation(param: Parent) extends NoTargetAnnotation + case class PolymorphicParameterAnnotationWithTypeHints(param: Parent) extends NoTargetAnnotation with HasSerializationHints { + def typeHints = Seq(param.getClass) + } + + case class TypeParameterizedAnnotation[T](param: T) extends NoTargetAnnotation + case class TypeParameterizedAnnotationWithTypeHints[T](param: T) extends NoTargetAnnotation with HasSerializationHints { + def typeHints = Seq(param.getClass) + } +} + +import JsonProtocolTestClasses._ + +class JsonProtocolSpec extends FlatSpec { + def serializeAndDeserialize(anno: Annotation): Annotation = { + val serializedAnno = JsonProtocol.serialize(Seq(anno)) + JsonProtocol.deserialize(serializedAnno).head + } + + "Annotations with polymorphic parameters" should "not serialize and deserialize without type hints" in { + val anno = PolymorphicParameterAnnotation(ChildA(1)) + assertThrows[InvalidAnnotationJSONException] { + serializeAndDeserialize(anno) + } + } + + it should "serialize and deserialize with type hints" in { + val anno = PolymorphicParameterAnnotationWithTypeHints(ChildA(1)) + val deserAnno = serializeAndDeserialize(anno) + assert(anno == deserAnno) + + val anno2 = PolymorphicParameterAnnotationWithTypeHints(ChildB("Test")) + val deserAnno2 = serializeAndDeserialize(anno2) + assert(anno2 == deserAnno2) + } + + "Annotations with non-primitive type parameters" should "not serialize and deserialize without type hints" in { + val anno = TypeParameterizedAnnotation(ChildA(1)) + val deserAnno = serializeAndDeserialize(anno) + assert (anno != deserAnno) + } + it should "serialize and deserialize with type hints" in { + val anno = TypeParameterizedAnnotationWithTypeHints(ChildA(1)) + val deserAnno = serializeAndDeserialize(anno) + assert (anno == deserAnno) + } +} |
