From f6db9d4bd6156bb8af7dfdbcced20cd6a098920f Mon Sep 17 00:00:00 2001 From: David Biancolin Date: Thu, 24 Oct 2019 10:44:04 -0700 Subject: Supply a trait to allow user annotations to provide SERDES type hints --- .../scala/firrtl/annotations/JsonProtocol.scala | 33 ++++++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index b09155d8..9fe575c7 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -9,6 +9,20 @@ import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization import org.json4s.native.Serialization.{read, writePretty} +trait HasSerializationHints { + // For serialization of complicated constuctor arguments, let the annotation + // writer specify additional type hints for relevant classes that might be + // contained within + def typeHints(): Seq[Class[_]] +} + +/** + * An annotation carrying fully qualified class names for which we'd like the + * deserialization protocol to generate type hints for. + * + * @param typeTags the the additional class names + */ +case class DeserializationTypeHintsAnnotation(typeTags: Seq[String]) extends NoTargetAnnotation object JsonProtocol { class TransformClassSerializer extends CustomSerializer[Class[_ <: Transform]](format => ( @@ -74,7 +88,7 @@ object JsonProtocol { )) /** Construct Json formatter for annotations */ - def jsonFormat(tags: Seq[Class[_ <: Annotation]]) = { + def jsonFormat(tags: Seq[Class[_]]) = { Serialization.formats(FullTypeHints(tags.toList)).withTypeHintFieldName("class") + new TransformClassSerializer + new NamedSerializer + new CircuitNameSerializer + new ModuleNameSerializer + new ComponentNameSerializer + new TargetSerializer + @@ -87,9 +101,13 @@ object JsonProtocol { def serialize(annos: Seq[Annotation]): String = serializeTry(annos).get def serializeTry(annos: Seq[Annotation]): Try[String] = { - val tags = annos.map(_.getClass).distinct - implicit val formats = jsonFormat(tags) - Try(writePretty(annos)) + val tags = annos.flatMap({ + case anno: HasSerializationHints => anno.getClass +: anno.typeHints + case other => Seq(other.getClass) + }).distinct + + implicit val formats = jsonFormat(classOf[DeserializationTypeHintsAnnotation] +: tags) + Try(writePretty(DeserializationTypeHintsAnnotation(tags.map(_.getName)) +: annos)) } def deserialize(in: JsonInput): Seq[Annotation] = deserializeTry(in).get @@ -102,8 +120,11 @@ object JsonProtocol { s"Annotations must be serialized as a JArray, got ${x.getClass.getSimpleName} instead!") } // Gather classes so we can deserialize arbitrary Annotations - val classes = annos.map({ - case JObject(("class", JString(c)) :: tail) => c + val typeHintAnnoName = classOf[DeserializationTypeHintsAnnotation].getName + val classes = annos.flatMap({ + case JObject(("class", JString(name)) :: ("typeTags", JArray(classes)) :: Nil) if name == typeHintAnnoName => + typeHintAnnoName +: classes.collect({ case JString(className) => className }) + case JObject(("class", JString(c)) :: tail) => Seq(c) case obj => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") }).distinct val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_ <: Annotation]]) -- cgit v1.2.3 From 79d7287d91443cfa8faa98b20fa0f2a5a261f237 Mon Sep 17 00:00:00 2001 From: David Biancolin Date: Fri, 25 Oct 2019 13:26:15 -0700 Subject: Only emit the DeserilizationTypeHintsAnno when needed --- src/main/scala/firrtl/annotations/JsonProtocol.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index 9fe575c7..db366395 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -101,13 +101,18 @@ object JsonProtocol { def serialize(annos: Seq[Annotation]): String = serializeTry(annos).get def serializeTry(annos: Seq[Annotation]): Try[String] = { - val tags = annos.flatMap({ - case anno: HasSerializationHints => anno.getClass +: anno.typeHints - case other => Seq(other.getClass) - }).distinct + val annotationTags = annos.map(_.getClass).distinct + val additionalTags = annos.collect({ + case anno: HasSerializationHints => anno.typeHints + }).flatten.distinct + + val typeHintAnno = additionalTags match { + case Nil => None + case tags => Some(DeserializationTypeHintsAnnotation(tags.map(_.getName))) + } - implicit val formats = jsonFormat(classOf[DeserializationTypeHintsAnnotation] +: tags) - Try(writePretty(DeserializationTypeHintsAnnotation(tags.map(_.getName)) +: annos)) + implicit val formats = jsonFormat(classOf[DeserializationTypeHintsAnnotation] +: (annotationTags ++ additionalTags)) + Try(writePretty(typeHintAnno ++ annos)) } def deserialize(in: JsonInput): Seq[Annotation] = deserializeTry(in).get -- cgit v1.2.3 From 7d16fb5f03812b56addfc7450f0808ccc54530c0 Mon Sep 17 00:00:00 2001 From: David Biancolin Date: Tue, 29 Oct 2019 19:17:44 -0700 Subject: Try implementing recursive typeHint look up --- .../scala/firrtl/annotations/JsonProtocol.scala | 41 ++++++++-------------- 1 file changed, 15 insertions(+), 26 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index db366395..02ec041a 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -16,14 +16,6 @@ trait HasSerializationHints { def typeHints(): Seq[Class[_]] } -/** - * An annotation carrying fully qualified class names for which we'd like the - * deserialization protocol to generate type hints for. - * - * @param typeTags the the additional class names - */ -case class DeserializationTypeHintsAnnotation(typeTags: Seq[String]) extends NoTargetAnnotation - object JsonProtocol { class TransformClassSerializer extends CustomSerializer[Class[_ <: Transform]](format => ( { case JString(s) => Class.forName(s).asInstanceOf[Class[_ <: Transform]] }, @@ -101,18 +93,13 @@ object JsonProtocol { def serialize(annos: Seq[Annotation]): String = serializeTry(annos).get def serializeTry(annos: Seq[Annotation]): Try[String] = { - val annotationTags = annos.map(_.getClass).distinct - val additionalTags = annos.collect({ - case anno: HasSerializationHints => anno.typeHints + val tags = annos.collect({ + case anno: HasSerializationHints => anno.getClass +: anno.typeHints + case anno => Seq(anno.getClass) }).flatten.distinct - val typeHintAnno = additionalTags match { - case Nil => None - case tags => Some(DeserializationTypeHintsAnnotation(tags.map(_.getName))) - } - - implicit val formats = jsonFormat(classOf[DeserializationTypeHintsAnnotation] +: (annotationTags ++ additionalTags)) - Try(writePretty(typeHintAnno ++ annos)) + implicit val formats = jsonFormat(tags) + Try(writePretty(annos)) } def deserialize(in: JsonInput): Seq[Annotation] = deserializeTry(in).get @@ -124,15 +111,17 @@ object JsonProtocol { case x => throw new InvalidAnnotationJSONException( s"Annotations must be serialized as a JArray, got ${x.getClass.getSimpleName} instead!") } - // Gather classes so we can deserialize arbitrary Annotations - val typeHintAnnoName = classOf[DeserializationTypeHintsAnnotation].getName - val classes = annos.flatMap({ - case JObject(("class", JString(name)) :: ("typeTags", JArray(classes)) :: Nil) if name == typeHintAnnoName => - typeHintAnnoName +: classes.collect({ case JString(className) => className }) - case JObject(("class", JString(c)) :: tail) => Seq(c) - case obj => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") + // Recursively gather typeHints by pulling the "class" field from JObjects + // Json4s should emit this as the first field in all serialized classes + def findTypeHints(classInst: Seq[JValue]): Seq[String] = classInst.flatMap({ + case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2)) + case obj: JObject => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") + case JArray(arr) => findTypeHints(arr) + case oJValue => Seq() }).distinct - val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_ <: Annotation]]) + + val classes = findTypeHints(annos) + val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_]]) implicit val formats = jsonFormat(loaded) read[List[Annotation]](in) }).recoverWith { -- cgit v1.2.3 From 66464f5e4c37bbfa3fc90da3dde964dd0410cfd1 Mon Sep 17 00:00:00 2001 From: David Biancolin Date: Tue, 29 Oct 2019 20:56:01 -0700 Subject: Check that all annotations provide the typeHint --- src/main/scala/firrtl/annotations/JsonProtocol.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index 02ec041a..1c84152e 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -113,14 +113,17 @@ object JsonProtocol { } // Recursively gather typeHints by pulling the "class" field from JObjects // Json4s should emit this as the first field in all serialized classes - def findTypeHints(classInst: Seq[JValue]): Seq[String] = classInst.flatMap({ + // Setting requireClassField mandates that all JObjects must provide a typeHint, + // this used on the first invocation to check all annotations do so + def findTypeHints(classInst: Seq[JValue], requireClassField: Boolean = false): Seq[String] = classInst.flatMap({ case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2)) - case obj: JObject => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") + case obj: JObject if requireClassField => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") + case JObject(fields) => findTypeHints(fields.map(_._2)) case JArray(arr) => findTypeHints(arr) case oJValue => Seq() }).distinct - val classes = findTypeHints(annos) + val classes = findTypeHints(annos, true) val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_]]) implicit val formats = jsonFormat(loaded) read[List[Annotation]](in) -- cgit v1.2.3 From ffd0f22db0fd50188f5394bbc82a9c7d53373f93 Mon Sep 17 00:00:00 2001 From: David Biancolin Date: Tue, 29 Oct 2019 20:57:59 -0700 Subject: Update src/main/scala/firrtl/annotations/JsonProtocol.scala Co-Authored-By: Jack Koenig --- src/main/scala/firrtl/annotations/JsonProtocol.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index 1c84152e..92cf4264 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -13,7 +13,7 @@ trait HasSerializationHints { // For serialization of complicated constuctor arguments, let the annotation // writer specify additional type hints for relevant classes that might be // contained within - def typeHints(): Seq[Class[_]] + def typeHints: Seq[Class[_]] } object JsonProtocol { -- cgit v1.2.3 From db6a6dce09ce2eb1c6c6c0c1c5ec3c881f4c3d77 Mon Sep 17 00:00:00 2001 From: David Biancolin Date: Tue, 29 Oct 2019 21:01:09 -0700 Subject: Some cleanup --- src/main/scala/firrtl/annotations/JsonProtocol.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index 92cf4264..4fb40b76 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -10,7 +10,7 @@ import org.json4s.native.Serialization import org.json4s.native.Serialization.{read, writePretty} trait HasSerializationHints { - // For serialization of complicated constuctor arguments, let the annotation + // For serialization of complicated constructor arguments, let the annotation // writer specify additional type hints for relevant classes that might be // contained within def typeHints: Seq[Class[_]] @@ -93,10 +93,10 @@ object JsonProtocol { def serialize(annos: Seq[Annotation]): String = serializeTry(annos).get def serializeTry(annos: Seq[Annotation]): Try[String] = { - val tags = annos.collect({ + val tags = annos.flatMap({ case anno: HasSerializationHints => anno.getClass +: anno.typeHints case anno => Seq(anno.getClass) - }).flatten.distinct + }).distinct implicit val formats = jsonFormat(tags) Try(writePretty(annos)) -- cgit v1.2.3 From 57335e9cc43777dc206dec570dc877ded4f03bd0 Mon Sep 17 00:00:00 2001 From: David Biancolin Date: Tue, 29 Oct 2019 21:11:01 -0700 Subject: Remove an unneeded cast --- src/main/scala/firrtl/annotations/JsonProtocol.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index 4fb40b76..c3853650 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -124,7 +124,7 @@ object JsonProtocol { }).distinct val classes = findTypeHints(annos, true) - val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_]]) + val loaded = classes.map(Class.forName(_)) implicit val formats = jsonFormat(loaded) read[List[Annotation]](in) }).recoverWith { -- cgit v1.2.3 From cd433e7cd54f53066b7c1f338e828d8e1d0b9d8a Mon Sep 17 00:00:00 2001 From: David Biancolin Date: Wed, 30 Oct 2019 14:08:27 -0700 Subject: Add some simple tests to demonstrate how to provide type hints --- src/test/scala/firrtl/JsonProtocolSpec.scala | 62 ++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 src/test/scala/firrtl/JsonProtocolSpec.scala (limited to 'src') diff --git a/src/test/scala/firrtl/JsonProtocolSpec.scala b/src/test/scala/firrtl/JsonProtocolSpec.scala new file mode 100644 index 00000000..955abdc0 --- /dev/null +++ b/src/test/scala/firrtl/JsonProtocolSpec.scala @@ -0,0 +1,62 @@ +// See LICENSE for license details. + +package firrtlTests + +import org.scalatest.FlatSpec +import org.json4s._ +import org.json4s.native.JsonMethods._ + +import firrtl.annotations.{NoTargetAnnotation, JsonProtocol, InvalidAnnotationJSONException, HasSerializationHints, Annotation} + +object JsonProtocolTestClasses { + trait Parent + + case class ChildA(foo: Int) extends Parent + case class ChildB(bar: String) extends Parent + case class PolymorphicParameterAnnotation(param: Parent) extends NoTargetAnnotation + case class PolymorphicParameterAnnotationWithTypeHints(param: Parent) extends NoTargetAnnotation with HasSerializationHints { + def typeHints = Seq(param.getClass) + } + + case class TypeParameterizedAnnotation[T](param: T) extends NoTargetAnnotation + case class TypeParameterizedAnnotationWithTypeHints[T](param: T) extends NoTargetAnnotation with HasSerializationHints { + def typeHints = Seq(param.getClass) + } +} + +import JsonProtocolTestClasses._ + +class JsonProtocolSpec extends FlatSpec { + def serializeAndDeserialize(anno: Annotation): Annotation = { + val serializedAnno = JsonProtocol.serialize(Seq(anno)) + JsonProtocol.deserialize(serializedAnno).head + } + + "Annotations with polymorphic parameters" should "not serialize and deserialize without type hints" in { + val anno = PolymorphicParameterAnnotation(ChildA(1)) + assertThrows[InvalidAnnotationJSONException] { + serializeAndDeserialize(anno) + } + } + + it should "serialize and deserialize with type hints" in { + val anno = PolymorphicParameterAnnotationWithTypeHints(ChildA(1)) + val deserAnno = serializeAndDeserialize(anno) + assert(anno == deserAnno) + + val anno2 = PolymorphicParameterAnnotationWithTypeHints(ChildB("Test")) + val deserAnno2 = serializeAndDeserialize(anno2) + assert(anno2 == deserAnno2) + } + + "Annotations with non-primitive type parameters" should "not serialize and deserialize without type hints" in { + val anno = TypeParameterizedAnnotation(ChildA(1)) + val deserAnno = serializeAndDeserialize(anno) + assert (anno != deserAnno) + } + it should "serialize and deserialize with type hints" in { + val anno = TypeParameterizedAnnotationWithTypeHints(ChildA(1)) + val deserAnno = serializeAndDeserialize(anno) + assert (anno == deserAnno) + } +} -- cgit v1.2.3