aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/AddDescriptionNodes.scala
blob: 2cd8b9f7fb268b5c0e5ec880ab64f4396cc31b2b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// See LICENSE for license details.

package firrtl

import firrtl.ir._
import firrtl.annotations._
import firrtl.Mappers._

/** Attaches a free-form `description` to the circuit element identified by `named`.
  *
  * Consumed by [[AddDescriptionNodes]], which wraps the described module, port, or declaration so that the
  * description travels with it.
  */
case class DescriptionAnnotation(named: Named, description: String) extends Annotation {

  /** Follows `named` through `renames`.
    *
    * If the target was not renamed the annotation is kept as-is. Otherwise one copy is produced per new target. If
    * `renames` maps the target to an empty sequence, the annotation is dropped.
    */
  def update(renames: RenameMap): Seq[DescriptionAnnotation] =
    renames.get(named).fold(Seq(this))(targets => targets.map(target => copy(named = target)))
}

/** Implemented by IR nodes that carry a [[Description]]. */
private sealed trait HasDescription {
  def description: Description
}

/** Base class of the documentation payloads attached to described IR nodes. */
private abstract class Description extends FirrtlNode

/** A description serialized as `@[<string>]`. */
private case class DocString(string: StringLit) extends Description {
  def serialize: String = s"@[${string.serialize}]"
}

/** Placeholder used when a node has no description of its own (e.g. a module that only has port descriptions).
  * Serializes to the empty string.
  */
private case object EmptyDescription extends Description {
  def serialize: String = ""
}

/** Pairs a statement `stmt` with its `description`.
  *
  * Expression, type, string and info mappers are forwarded to `stmt`, and the result is re-wrapped with the same
  * description. The matching `foreach` traversals delegate to `stmt`. `mapStmt` and `foreachStmt` instead apply
  * `f` to `stmt` itself.
  *
  * NOTE(review): `mapStmt` returns `f(stmt)` without re-wrapping, so the description is lost whenever a pass maps
  * over this node's children. `DescribedMod.mapStmt` keeps its wrapper. Confirm whether this asymmetry is intended
  * before relying on descriptions surviving later passes.
  */
private case class DescribedStmt(description: Description, stmt: Statement) extends Statement with HasDescription {
  /** Rebuilds this node around `s`, keeping the current description. */
  private def withStmt(s: Statement): Statement = DescribedStmt(description, s)

  def serialize: String = Seq(description.serialize, stmt.serialize).mkString("\n")

  def mapStmt(f: Statement => Statement): Statement = f(stmt)
  def mapExpr(f: Expression => Expression): Statement = withStmt(stmt.mapExpr(f))
  def mapType(f: Type => Type): Statement = withStmt(stmt.mapType(f))
  def mapString(f: String => String): Statement = withStmt(stmt.mapString(f))
  def mapInfo(f: Info => Info): Statement = withStmt(stmt.mapInfo(f))

  def foreachStmt(f: Statement => Unit): Unit = f(stmt)
  def foreachExpr(f: Expression => Unit): Unit = stmt.foreachExpr(f)
  def foreachType(f: Type => Unit): Unit = stmt.foreachType(f)
  def foreachString(f: String => Unit): Unit = stmt.foreachString(f)
  def foreachInfo(f: Info => Unit): Unit = stmt.foreachInfo(f)
}

/** Pairs module `mod` with a module-level `description` and per-port descriptions keyed by port name.
  *
  * `info`, `name` and `ports` are read from `mod`. Every mapper is forwarded to `mod` and the result re-wrapped
  * with the same descriptions; the `foreach` traversals delegate to `mod`.
  *
  * NOTE(review): `portDescriptions` keys are not updated by `mapPort`/`mapString`, so renaming ports through those
  * mappers leaves stale keys. Confirm callers never rename ports after this wrapper is created.
  */
private case class DescribedMod(description: Description,
  portDescriptions: Map[String, Description],
  mod: DefModule) extends DefModule with HasDescription {
  val info = mod.info
  val name = mod.name
  val ports = mod.ports

  /** Rebuilds this node around `m`, keeping both the module and port descriptions. */
  private def withMod(m: DefModule): DefModule = DescribedMod(description, portDescriptions, m)

  def serialize: String = Seq(description.serialize, mod.serialize).mkString("\n")

  def mapStmt(f: Statement => Statement): DefModule = withMod(mod.mapStmt(f))
  def mapPort(f: Port => Port): DefModule = withMod(mod.mapPort(f))
  def mapString(f: String => String): DefModule = withMod(mod.mapString(f))
  def mapInfo(f: Info => Info): DefModule = withMod(mod.mapInfo(f))

  def foreachStmt(f: Statement => Unit): Unit = mod.foreachStmt(f)
  def foreachPort(f: Port => Unit): Unit = mod.foreachPort(f)
  def foreachString(f: String => Unit): Unit = mod.foreachString(f)
  def foreachInfo(f: Info => Unit): Unit = mod.foreachInfo(f)
}

/** Wraps modules or statements with their respective described nodes. Descriptions come from [[DescriptionAnnotation]].
  * Describing a module or any of its ports will turn it into a `DescribedMod`. Describing a Statement will turn it into
  * a (private) `DescribedStmt`.
  *
  * When several descriptions target the same element, they are joined with a blank line (`"\n\n"`) in annotation
  * order and then passed through `StringLit.unescape`.
  *
  * @note should only be used by VerilogEmitter, described nodes will
  *       break other transforms.
  * @note the circuit name in an annotation's target is not checked; descriptions are matched by module name and
  *       component name only.
  */
class AddDescriptionNodes extends Transform {
  def inputForm = LowForm
  def outputForm = LowForm

  /** Joins every description attached to one target into a single `DocString`. */
  private def toDocString(descs: Seq[String]): DocString =
    DocString(StringLit.unescape(descs.mkString("\n\n")))

  /** Wraps each declaration whose name appears in `compMap` in a `DescribedStmt`.
    *
    * Child statements are processed before `stmt` itself is examined.
    *
    * @param compMap component name -> descriptions, for the module containing `stmt`
    * @param stmt statement to process
    * @return `stmt` with every described declaration wrapped
    */
  def onStmt(compMap: Map[String, Seq[String]])(stmt: Statement): Statement = {
    stmt.map(onStmt(compMap)) match {
      case d: IsDeclaration if compMap.contains(d.name) =>
        DescribedStmt(toDocString(compMap(d.name)), d)
      case other => other
    }
  }

  /** Attaches descriptions to `mod`, its ports, and the declarations in its body.
    *
    * @param modMap module name -> module-level descriptions
    * @param compMaps module name -> (component name -> descriptions); components may be ports or declarations
    * @param mod module to process
    * @return `mod` with described declarations wrapped. The module itself is wrapped in a `DescribedMod` when it,
    *         or at least one of its ports, has a description. Otherwise the (possibly statement-rewritten)
    *         module is returned unchanged.
    */
  def onModule(modMap: Map[String, Seq[String]], compMaps: Map[String, Map[String, Seq[String]]])
    (mod: DefModule): DefModule = {
    // Both branches are explicitly typed, so no (unchecked, erased) type pattern is needed when destructuring.
    val (newMod, portDesc) = compMaps.get(mod.name) match {
      case None => (mod, Map.empty[String, Description])
      case Some(compMap) =>
        val describedPorts: Map[String, Description] = mod.ports.collect {
          case Port(_, name, _, _) if compMap.contains(name) => name -> toDocString(compMap(name))
        }.toMap
        (mod.mapStmt(onStmt(compMap)), describedPorts)
    }

    val modDesc = modMap.get(newMod.name).map(toDocString)

    if (portDesc.nonEmpty || modDesc.nonEmpty) {
      DescribedMod(modDesc.getOrElse(EmptyDescription), portDesc, newMod)
    } else {
      newMod
    }
  }

  /** Indexes the [[DescriptionAnnotation]]s in `annos` by target.
    *
    * Annotations of other types, and description annotations whose target is neither a `ModuleName` nor a
    * `ComponentName`, are ignored.
    *
    * @return a pair of
    *         - module name -> module-level descriptions, and
    *         - module name -> (component name -> descriptions).
    *         Descriptions keep their relative order from `annos`.
    */
  def collectMaps(annos: Seq[Annotation]): (Map[String, Seq[String]], Map[String, Map[String, Seq[String]]]) = {
    // Built with `map` rather than `mapValues`: in Scala 2.12 `mapValues` returns a lazy, non-serializable view that
    // re-runs the mapping on every lookup (and it is deprecated in 2.13).
    val modMap: Map[String, Seq[String]] = annos.collect {
      case DescriptionAnnotation(ModuleName(m, CircuitName(_)), desc) => (m, desc)
    }.groupBy(_._1).map { case (m, pairs) => m -> pairs.map(_._2) }

    val compMap: Map[String, Map[String, Seq[String]]] = annos.collect {
      case DescriptionAnnotation(ComponentName(comp, ModuleName(mod, CircuitName(_))), desc) =>
        (mod, comp, desc)
    }.groupBy(_._1).map { case (mod, triples) =>
      mod -> triples.groupBy(_._2).map { case (comp, byComp) => comp -> byComp.map(_._3) }
    }

    (modMap, compMap)
  }

  /** Applies this transform to a single module, using only the descriptions found in `annos`. */
  def executeModule(module: DefModule, annos: Seq[Annotation]): DefModule = {
    val (modMap, compMap) = collectMaps(annos)

    onModule(modMap, compMap)(module)
  }

  /** Applies this transform to every module of `state.circuit`, using the descriptions in `state.annotations`. */
  override def execute(state: CircuitState): CircuitState = {
    val (modMap, compMap) = collectMaps(state.annotations)

    state.copy(circuit = state.circuit.mapModule(onModule(modMap, compMap)))
  }
}