aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/Compiler.scala
blob: 106c973facf56e48f7eae1be54edfacee6ffec7b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
// See LICENSE for license details.

package firrtl

import logger.LazyLogging
import java.io.Writer
import annotations._

import firrtl.ir.Circuit
import passes.Pass

/**
 * RenameMap maps old names to modified names.  Generated by transformations
 * that modify names.
 *
 * @param map Associates each original [[Named]] with the sequence of
 *   [[Named]] values that replace it
 */
case class RenameMap(map: Map[Named, Seq[Named]])

/**
 * Container of all annotations for a Firrtl compiler.
 *
 * @param annotations The annotations held by this container
 */
case class AnnotationMap(annotations: Seq[Annotation]) {
  /** Annotations whose `transform` equals the class `id`
    *
    * @param id The transform class to look up
    * @return The matching annotations, in their original order
    */
  def get(id: Class[_]): Seq[Annotation] = annotations.filter(a => a.transform == id)

  /** Annotations attached to `named`
    *
    * The comparison is made against each annotation's target. Comparing an
    * [[Annotation]] itself with a [[Named]] can never be true because the two
    * are unrelated types, so such a lookup would always come back empty.
    *
    * @param named The name to look up
    * @return The annotations targeting `named`, in their original order
    */
  def get(named: Named): Seq[Annotation] = annotations.filter(a => a.target == named)
}

/** Current State of the Circuit
  *
  * @constructor Creates a CircuitState object
  * @param circuit The current state of the Firrtl AST
  * @param form The current form of the circuit
  * @param annotations The current collection of [[Annotation]]s, wrapped in an
  *   [[AnnotationMap]]; defaults to None
  * @param renames A map of [[Named]] things that have been renamed; defaults to None.
  *   Generally only a return value from [[Transform]]s
  */
case class CircuitState(
  circuit: Circuit,
  form: CircuitForm,
  annotations: Option[AnnotationMap] = None,
  renames: Option[RenameMap] = None)

/** Current form of the Firrtl Circuit
  *
  * Form is a measure of additional restrictions on the legality of a Firrtl
  * circuit.  There is a notion of "highness" and "lowness" implemented in the
  * compiler by extending scala.math.Ordered. "Lower" forms add additional
  * restrictions compared to "higher" forms. This means that "higher" forms are
  * strictly supersets of the "lower" forms. Thus, any transform that
  * operates on [[HighForm]] can also operate on [[MidForm]] or [[LowForm]]
  *
  * @param value Rank of the form; a larger number means a "higher" form
  */
sealed abstract class CircuitForm(private val value: Int) extends Ordered[CircuitForm] {
  // Note that value is used only to allow comparisons
  // The result is negative, zero or positive when this form is respectively
  // lower than, equal to or higher than `that`
  def compare(that: CircuitForm): Int = this.value - that.value
}
/** Chirrtl Form
  *
  * The form of the circuit emitted by Chisel. Not a true Firrtl form.
  * Includes cmem, smem, and mport IR nodes which enable declaring memories
  * separately from their ports. A "higher" form than [[HighForm]]; it has the
  * largest rank of all [[CircuitForm]]s declared here.
  *
  * See [[CDefMemory]] and [[CDefMPort]]
  */
final case object ChirrtlForm extends CircuitForm(3)
/** High Form
  *
  * As detailed in the Firrtl specification
  * [[https://github.com/ucb-bar/firrtl/blob/master/spec/spec.pdf]]
  *
  * A "lower" form than [[ChirrtlForm]] and a "higher" form than [[MidForm]].
  *
  * Also see [[firrtl.ir]]
  */
final case object HighForm extends CircuitForm(2)
/** Middle Form
  *
  * A "lower" form than [[HighForm]] with the following restrictions:
  *  - All widths must be explicit
  *  - All whens must be removed
  *  - There can only be a single connection to any element
  *
  * It is in turn a "higher" form than [[LowForm]].
  */
final case object MidForm extends CircuitForm(1)
/** Low Form
  *
  * The "lowest" form (it compares below every other [[CircuitForm]]).
  * In addition to the restrictions in [[MidForm]]:
  *  - All aggregate types (vector/bundle) must have been removed
  *  - All implicit truncations must be made explicit
  */
final case object LowForm extends CircuitForm(0)

/** The basic unit of operating on a Firrtl AST */
abstract class Transform {
  /** Simple class name of this transform; handy for debugging and error messages */
  def name: String = getClass.getSimpleName
  /** The [[CircuitForm]] this transform expects its input to be in */
  def inputForm: CircuitForm
  /** The [[CircuitForm]] of the circuit this transform produces */
  def outputForm: CircuitForm
  /** Run the transform
    *
    * @param state Input Firrtl AST together with its form, annotations and renames
    * @return The transformed [[CircuitState]]
    */
  def execute(state: CircuitState): CircuitState
  /** Look up the annotations addressed to this Transform
    *
    * @param state The [[CircuitState]] from which to extract annotations
    * @return The annotations registered for this transform's class, or an
    *   empty collection when `state` carries no annotations at all
    */
  final def getMyAnnotations(state: CircuitState): Seq[Annotation] =
    state.annotations.map(annoMap => annoMap.get(this.getClass)).getOrElse(Nil)
}

/** Mixin that runs a sequence of passes over a circuit, logging every intermediate result */
trait SimpleRun extends LazyLogging {
  /** Apply `passSeq` to `circuit` from left to right
    *
    * Each pass is executed inside `Utils.time`, labelled with the pass name, and
    * the serialized circuit it returns is logged at debug level.
    *
    * @param circuit The starting circuit
    * @param passSeq The passes to run, in order
    * @return The circuit returned by the last pass, or `circuit` itself when
    *   `passSeq` is empty
    */
  def runPasses(circuit: Circuit, passSeq: Seq[Pass]): Circuit = {
    // One fold step: run a single pass and log what it produced
    def applyPass(current: Circuit, pass: Pass): Circuit = {
      val produced = Utils.time(pass.name) { pass.run(current) }
      logger.debug(produced.serialize)
      produced
    }
    passSeq.foldLeft(circuit)(applyPass)
  }
}

/** For PassBased Transforms and Emitters
  *
  * @note passSeq accepts no arguments
  * @todo make passes accept CircuitState so annotations can pass data between them
  */
trait PassBased extends SimpleRun {
  /** The passes this object runs, in execution order */
  def passSeq: Seq[Pass]
  /** Runs every pass in [[passSeq]] over `circuit` by delegating to the
    * two-argument `runPasses`
    *
    * @param circuit The circuit to process
    * @return The circuit produced after the last pass
    */
  def runPasses(circuit: Circuit): Circuit = runPasses(circuit, passSeq)
}

/** For transformations that are simply a sequence of passes */
abstract class PassBasedTransform extends Transform with PassBased {
  /** Run every pass in `passSeq` over the circuit held by `state`
    *
    * @param state The incoming state; its form must not be higher than `inputForm`
    * @return A new [[CircuitState]] holding the processed circuit tagged with
    *   `outputForm`; annotations and renames are left at their defaults (None)
    * @throws IllegalArgumentException if `state.form` is higher than `inputForm`
    */
  def execute(state: CircuitState): CircuitState = {
    require(state.form <= inputForm,
      s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}")
    val processed: Circuit = runPasses(state.circuit)
    CircuitState(circuit = processed, form = outputForm)
  }
}

/** Similar to a Transform except that it writes to a Writer instead of returning a
  * CircuitState
  */
abstract class Emitter {
  /** Emit the given state
    *
    * @param state The [[CircuitState]] to emit
    * @param writer The java.io.Writer that receives the output
    */
  def emit(state: CircuitState, writer: Writer): Unit
}

object CompilerUtils {
  /** Builds the chain of [[Transform]]s that lowers a Firrtl circuit
    *
    * The chain is assembled recursively, one form at a time, until
    * `outputForm` is reached.
    *
    * @param inputForm [[CircuitForm]] to lower from
    * @param outputForm [[CircuitForm]] to lower to
    * @return The transforms that lower from `inputForm` to `outputForm`; empty
    *   when `outputForm` is equal to or higher than `inputForm`
    */
  def getLoweringTransforms(inputForm: CircuitForm, outputForm: CircuitForm): Seq[Transform] =
    inputForm match {
      // Target is equal-to or higher than where we are: nothing to lower
      case _ if outputForm >= inputForm => Seq.empty
      case ChirrtlForm =>
        val chirrtlToHigh: Seq[Transform] = Seq(new ChirrtlToHighFirrtl)
        chirrtlToHigh ++ getLoweringTransforms(HighForm, outputForm)
      case HighForm =>
        val highToMid: Seq[Transform] = Seq(
          new IRToWorkingIR,
          new ResolveAndCheck,
          new transforms.DedupModules,
          new HighFirrtlToMiddleFirrtl)
        highToMid ++ getLoweringTransforms(MidForm, outputForm)
      case MidForm =>
        val midToLow: Seq[Transform] = Seq(new MiddleFirrtlToLowFirrtl)
        midToLow ++ getLoweringTransforms(LowForm, outputForm)
      // Nothing compares lower than LowForm, so the guard above always wins first
      case LowForm => error("Internal Error! This shouldn't be possible")
    }

  /** Weave custom transforms into a Seq of lowering transforms
    *
    * Each custom transform is placed, in the order given, directly after the
    * last transform already in the sequence whose [[Transform.outputForm]]
    * equals the custom transform's [[Transform.inputForm]]. When a custom
    * transform emits a higher form than it consumed, the transforms from
    * [[getLoweringTransforms]] are appended right after it so the circuit is
    * brought back down to the form it had before.
    *
    * @example
    *   {{{
    *     // Let Transforms be represented by CircuitForm => CircuitForm
    *     val A = HighForm => MidForm
    *     val B = MidForm => LowForm
    *     val lowering = List(A, B) // Assume these transforms are used by getLoweringTransforms
    *     // Some custom transforms
    *     val C = LowForm => LowForm
    *     val D = MidForm => MidForm
    *     val E = LowForm => HighForm
    *     // All of the following comparisons are true
    *     mergeTransforms(lowering, List(C)) == List(A, B, C)
    *     mergeTransforms(lowering, List(D)) == List(A, D, B)
    *     mergeTransforms(lowering, List(E)) == List(A, B, E, A, B)
    *     mergeTransforms(lowering, List(C, E)) == List(A, B, C, E, A, B)
    *     mergeTransforms(lowering, List(E, C)) == List(A, B, E, A, B, C)
    *     // Notice that in the following, custom transform order is NOT preserved (see note)
    *     mergeTransforms(lowering, List(C, D)) == List(A, D, B, C)
    *   }}}
    *
    * @note Order will be preserved for custom transforms so long as the
    * inputForm of a latter transform is equal to or lower than the outputForm
    * of the previous transform.
    *
    * @param lowering The base sequence of lowering transforms
    * @param custom The custom transforms to insert
    * @return The merged sequence
    */
  def mergeTransforms(lowering: Seq[Transform], custom: Seq[Transform]): Seq[Transform] =
    custom.foldLeft(lowering) { (merged, xform) =>
      // Position of the last transform whose output form is exactly what xform consumes (-1 if none)
      val anchor = merged.lastIndexWhere(placed => placed.outputForm == xform.inputForm)
      assert(anchor >= 0 || xform.inputForm == ChirrtlForm, // If ChirrtlForm just put at front
        s"No transform in $lowering has outputForm ${xform.inputForm} as required by $xform")
      val (before, after) = merged.splitAt(anchor + 1) // +1 so xform lands AFTER the anchor
      val relowering = getLoweringTransforms(xform.outputForm, xform.inputForm)
      (before :+ xform) ++ relowering ++ after
    }

}

/** A fixed sequence of [[Transform]]s followed by an [[Emitter]] */
trait Compiler {
  /** The emitter that writes out the final state */
  def emitter: Emitter
  /** The sequence of transforms this compiler will execute
    * @note The inputForm of a given transform must be higher than or equal to the outputForm of the
    *       preceding transform. See [[CircuitForm]]
    */
  def transforms: Seq[Transform]

  // Similar to (input|output)Form on [[Transform]] but derived from this Compiler's transforms
  def inputForm = transforms.head.inputForm
  def outputForm = transforms.last.outputForm

  // True when every transform accepts the form produced by the one before it.
  // Sequences with fewer than two transforms have no adjacent pairs and are legal.
  private def transformsLegal(xforms: Seq[Transform]): Boolean =
    xforms.zip(xforms.drop(1)).forall { case (prev, next) =>
      next.inputForm >= prev.outputForm
    }

  assert(transformsLegal(transforms),
    "Illegal Compiler, each transform must be able to accept the output of the previous transform!")

  /** Perform compilation
    *
    * Runs every transform (this compiler's own plus any custom ones) in
    * sequence, carrying annotations from one state to the next, then hands the
    * final state to the emitter.
    *
    * @param state The Firrtl AST to compile
    * @param writer The java.io.Writer where the output of compilation will be emitted
    * @param customTransforms Any custom [[Transform]]s that will be inserted
    *   into the compilation process by [[CompilerUtils.mergeTransforms]]
    * @return The state after the last transform, i.e. the one given to the emitter
    */
  def compile(state: CircuitState,
              writer: Writer,
              customTransforms: Seq[Transform] = Seq.empty): CircuitState = {
    val pipeline = CompilerUtils.mergeTransforms(transforms, customTransforms)

    val finalState = pipeline.foldLeft(state) { (in, xform) =>
      val result = Utils.time(s"***${xform.name}***") { xform.execute(in) }

      // Annotation propagation
      // TODO: This should be redone
      val incoming = in.annotations.getOrElse(AnnotationMap(Seq.empty))
      val carriedOver: Seq[Annotation] = result.renames match {
        case Some(RenameMap(rmap)) =>
          // Walk the rename map: for every old name, fetch the annotations
          // stored under it and let each one produce its replacement
          // annotations for the new names.
          rmap.toSeq.flatMap { case (oldName, newNames) =>
            incoming.get(oldName).flatMap(oldAnno => oldAnno.update(newNames))
          }
        // No renames reported: the incoming annotations pass through untouched
        case _ => incoming.annotations
      }
      // Whatever the transform itself attached to its result (none => empty)
      val produced: Seq[Annotation] = result.annotations.map(_.annotations).getOrElse(Nil)
      val merged = AnnotationMap(carriedOver ++ produced)
      CircuitState(result.circuit, result.form, Some(merged))
    }

    emitter.emit(finalState, writer)
    finalState
  }
}