aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDavid Biancolin2019-10-29 19:17:44 -0700
committerDavid Biancolin2019-10-29 19:17:44 -0700
commit7d16fb5f03812b56addfc7450f0808ccc54530c0 (patch)
treec8a8e8365c2377d39807dc140928688bb4de5ede /src
parent79d7287d91443cfa8faa98b20fa0f2a5a261f237 (diff)
Try implementing recursive typeHint look up
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/annotations/JsonProtocol.scala41
1 files changed, 15 insertions, 26 deletions
diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala
index db366395..02ec041a 100644
--- a/src/main/scala/firrtl/annotations/JsonProtocol.scala
+++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala
@@ -16,14 +16,6 @@ trait HasSerializationHints {
def typeHints(): Seq[Class[_]]
}
-/**
- * An annotation carrying fully qualified class names for which we'd like the
- * deserialization protocol to generate type hints for.
- *
- * @param typeTags the the additional class names
- */
-case class DeserializationTypeHintsAnnotation(typeTags: Seq[String]) extends NoTargetAnnotation
-
object JsonProtocol {
class TransformClassSerializer extends CustomSerializer[Class[_ <: Transform]](format => (
{ case JString(s) => Class.forName(s).asInstanceOf[Class[_ <: Transform]] },
@@ -101,18 +93,13 @@ object JsonProtocol {
def serialize(annos: Seq[Annotation]): String = serializeTry(annos).get
def serializeTry(annos: Seq[Annotation]): Try[String] = {
- val annotationTags = annos.map(_.getClass).distinct
- val additionalTags = annos.collect({
- case anno: HasSerializationHints => anno.typeHints
+ val tags = annos.collect({
+ case anno: HasSerializationHints => anno.getClass +: anno.typeHints
+ case anno => Seq(anno.getClass)
}).flatten.distinct
- val typeHintAnno = additionalTags match {
- case Nil => None
- case tags => Some(DeserializationTypeHintsAnnotation(tags.map(_.getName)))
- }
-
- implicit val formats = jsonFormat(classOf[DeserializationTypeHintsAnnotation] +: (annotationTags ++ additionalTags))
- Try(writePretty(typeHintAnno ++ annos))
+ implicit val formats = jsonFormat(tags)
+ Try(writePretty(annos))
}
def deserialize(in: JsonInput): Seq[Annotation] = deserializeTry(in).get
@@ -124,15 +111,17 @@ object JsonProtocol {
case x => throw new InvalidAnnotationJSONException(
s"Annotations must be serialized as a JArray, got ${x.getClass.getSimpleName} instead!")
}
- // Gather classes so we can deserialize arbitrary Annotations
- val typeHintAnnoName = classOf[DeserializationTypeHintsAnnotation].getName
- val classes = annos.flatMap({
- case JObject(("class", JString(name)) :: ("typeTags", JArray(classes)) :: Nil) if name == typeHintAnnoName =>
- typeHintAnnoName +: classes.collect({ case JString(className) => className })
- case JObject(("class", JString(c)) :: tail) => Seq(c)
- case obj => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj")
+ // Recursively gather typeHints by pulling the "class" field from JObjects
+ // Json4s should emit this as the first field in all serialized classes
+ def findTypeHints(classInst: Seq[JValue]): Seq[String] = classInst.flatMap({
+ case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2))
+ case obj: JObject => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj")
+ case JArray(arr) => findTypeHints(arr)
+ case oJValue => Seq()
}).distinct
- val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_ <: Annotation]])
+
+ val classes = findTypeHints(annos)
+ val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_]])
implicit val formats = jsonFormat(loaded)
read[List[Annotation]](in)
}).recoverWith {