diff options
| author | David Biancolin | 2019-10-29 19:17:44 -0700 |
|---|---|---|
| committer | David Biancolin | 2019-10-29 19:17:44 -0700 |
| commit | 7d16fb5f03812b56addfc7450f0808ccc54530c0 (patch) | |
| tree | c8a8e8365c2377d39807dc140928688bb4de5ede /src | |
| parent | 79d7287d91443cfa8faa98b20fa0f2a5a261f237 (diff) | |
Try implementing recursive typeHint look up
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/annotations/JsonProtocol.scala | 41 |
1 files changed, 15 insertions, 26 deletions
diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index db366395..02ec041a 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -16,14 +16,6 @@ trait HasSerializationHints { def typeHints(): Seq[Class[_]] } -/** - * An annotation carrying fully qualified class names for which we'd like the - * deserialization protocol to generate type hints for. - * - * @param typeTags the the additional class names - */ -case class DeserializationTypeHintsAnnotation(typeTags: Seq[String]) extends NoTargetAnnotation - object JsonProtocol { class TransformClassSerializer extends CustomSerializer[Class[_ <: Transform]](format => ( { case JString(s) => Class.forName(s).asInstanceOf[Class[_ <: Transform]] }, @@ -101,18 +93,13 @@ object JsonProtocol { def serialize(annos: Seq[Annotation]): String = serializeTry(annos).get def serializeTry(annos: Seq[Annotation]): Try[String] = { - val annotationTags = annos.map(_.getClass).distinct - val additionalTags = annos.collect({ - case anno: HasSerializationHints => anno.typeHints + val tags = annos.collect({ + case anno: HasSerializationHints => anno.getClass +: anno.typeHints + case anno => Seq(anno.getClass) }).flatten.distinct - val typeHintAnno = additionalTags match { - case Nil => None - case tags => Some(DeserializationTypeHintsAnnotation(tags.map(_.getName))) - } - - implicit val formats = jsonFormat(classOf[DeserializationTypeHintsAnnotation] +: (annotationTags ++ additionalTags)) - Try(writePretty(typeHintAnno ++ annos)) + implicit val formats = jsonFormat(tags) + Try(writePretty(annos)) } def deserialize(in: JsonInput): Seq[Annotation] = deserializeTry(in).get @@ -124,15 +111,17 @@ object JsonProtocol { case x => throw new InvalidAnnotationJSONException( s"Annotations must be serialized as a JArray, got ${x.getClass.getSimpleName} instead!") } - // Gather classes so we can deserialize arbitrary Annotations - val typeHintAnnoName = classOf[DeserializationTypeHintsAnnotation].getName - val classes = annos.flatMap({ - case JObject(("class", JString(name)) :: ("typeTags", JArray(classes)) :: Nil) if name == typeHintAnnoName => - typeHintAnnoName +: classes.collect({ case JString(className) => className }) - case JObject(("class", JString(c)) :: tail) => Seq(c) - case obj => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") + // Recursively gather typeHints by pulling the "class" field from JObjects + // Json4s should emit this as the first field in all serialized classes + def findTypeHints(classInst: Seq[JValue]): Seq[String] = classInst.flatMap({ + case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2)) + case obj: JObject => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") + case JArray(arr) => findTypeHints(arr) + case oJValue => Seq() }).distinct - val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_ <: Annotation]]) + + val classes = findTypeHints(annos) + val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_]]) implicit val formats = jsonFormat(loaded) read[List[Annotation]](in) }).recoverWith { |
