aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl')
-rw-r--r--src/main/scala/firrtl/annotations/JsonProtocol.scala32
1 files changed, 25 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala
index b09155d8..c3853650 100644
--- a/src/main/scala/firrtl/annotations/JsonProtocol.scala
+++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala
@@ -9,6 +9,12 @@ import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization
import org.json4s.native.Serialization.{read, writePretty}
+trait HasSerializationHints {
+ // For serialization of complicated constructor arguments, let the annotation
+ // writer specify additional type hints for relevant classes that might be
+ // contained within
+ def typeHints: Seq[Class[_]]
+}
object JsonProtocol {
class TransformClassSerializer extends CustomSerializer[Class[_ <: Transform]](format => (
@@ -74,7 +80,7 @@ object JsonProtocol {
))
/** Construct Json formatter for annotations */
- def jsonFormat(tags: Seq[Class[_ <: Annotation]]) = {
+ def jsonFormat(tags: Seq[Class[_]]) = {
Serialization.formats(FullTypeHints(tags.toList)).withTypeHintFieldName("class") +
new TransformClassSerializer + new NamedSerializer + new CircuitNameSerializer +
new ModuleNameSerializer + new ComponentNameSerializer + new TargetSerializer +
@@ -87,7 +93,11 @@ object JsonProtocol {
def serialize(annos: Seq[Annotation]): String = serializeTry(annos).get
def serializeTry(annos: Seq[Annotation]): Try[String] = {
- val tags = annos.map(_.getClass).distinct
+ val tags = annos.flatMap({
+ case anno: HasSerializationHints => anno.getClass +: anno.typeHints
+ case anno => Seq(anno.getClass)
+ }).distinct
+
implicit val formats = jsonFormat(tags)
Try(writePretty(annos))
}
@@ -101,12 +111,20 @@ object JsonProtocol {
case x => throw new InvalidAnnotationJSONException(
s"Annotations must be serialized as a JArray, got ${x.getClass.getSimpleName} instead!")
}
- // Gather classes so we can deserialize arbitrary Annotations
- val classes = annos.map({
- case JObject(("class", JString(c)) :: tail) => c
- case obj => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj")
+ // Recursively gather typeHints by pulling the "class" field from JObjects
+ // Json4s should emit this as the first field in all serialized classes
+ // Setting requireClassField mandates that all JObjects must provide a typeHint,
+ // this used on the first invocation to check all annotations do so
+ def findTypeHints(classInst: Seq[JValue], requireClassField: Boolean = false): Seq[String] = classInst.flatMap({
+ case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2))
+ case obj: JObject if requireClassField => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj")
+ case JObject(fields) => findTypeHints(fields.map(_._2))
+ case JArray(arr) => findTypeHints(arr)
+ case oJValue => Seq()
}).distinct
- val loaded = classes.map(Class.forName(_).asInstanceOf[Class[_ <: Annotation]])
+
+ val classes = findTypeHints(annos, true)
+ val loaded = classes.map(Class.forName(_))
implicit val formats = jsonFormat(loaded)
read[List[Annotation]](in)
}).recoverWith {