aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorjackkoenig2016-10-20 00:19:01 -0700
committerJack Koenig2016-11-04 13:29:09 -0700
commit8fa9429a6e916ab2a789f5d81fa803b022805b52 (patch)
treefac2efcbd0a68bfb1916f09afc7f003c7a3d6528 /src
parent62133264a788f46b319ebab9c31424b7e0536101 (diff)
Refactor Compilers and Transforms
* Transform Ids now handled by Class[_ <: Transform] instead of magic numbers * Transforms define inputForm and outputForm * Custom transforms can be inserted at runtime into compiler or the Driver * Current "built-in" custom transforms handled via above mechanism * Verilog-specific passes moved to the Verilog emitter
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Annotations.scala28
-rw-r--r--src/main/scala/firrtl/Compiler.scala265
-rw-r--r--src/main/scala/firrtl/Driver.scala17
-rw-r--r--src/main/scala/firrtl/Emitter.scala32
-rw-r--r--src/main/scala/firrtl/ExecutionOptionsManager.scala18
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala186
-rw-r--r--src/main/scala/firrtl/Utils.scala2
-rw-r--r--src/main/scala/firrtl/passes/Inline.scala48
-rw-r--r--src/main/scala/firrtl/passes/memlib/DecorateMems.scala22
-rw-r--r--src/main/scala/firrtl/passes/memlib/InferReadWrite.scala21
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala24
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala89
-rw-r--r--src/main/scala/firrtl/passes/wiring/WiringTransform.scala29
-rw-r--r--src/test/resources/features/CustomTransform.fir33
-rw-r--r--src/test/scala/firrtlTests/AnnotationTests.scala22
-rw-r--r--src/test/scala/firrtlTests/AttachSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/CInferMDirSpec.scala11
-rw-r--r--src/test/scala/firrtlTests/CheckInitializationSpec.scala1
-rw-r--r--src/test/scala/firrtlTests/ChirrtlMemSpec.scala13
-rw-r--r--src/test/scala/firrtlTests/CompilerTests.scala15
-rw-r--r--src/test/scala/firrtlTests/CompilerUtilsSpec.scala76
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala1
-rw-r--r--src/test/scala/firrtlTests/CustomTransformSpec.scala51
-rw-r--r--src/test/scala/firrtlTests/ExpandWhensSpec.scala1
-rw-r--r--src/test/scala/firrtlTests/FirrtlSpec.scala7
-rw-r--r--src/test/scala/firrtlTests/InferReadWriteSpec.scala30
-rw-r--r--src/test/scala/firrtlTests/InlineInstancesTests.scala34
-rw-r--r--src/test/scala/firrtlTests/MultiThreadingSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/PassTests.scala64
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala26
-rw-r--r--src/test/scala/firrtlTests/UnitTests.scala7
-rw-r--r--src/test/scala/firrtlTests/VerilogEmitterTests.scala3
-rw-r--r--src/test/scala/firrtlTests/WidthSpec.scala1
-rw-r--r--src/test/scala/firrtlTests/WiringTests.scala1
-rw-r--r--src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala5
-rw-r--r--src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala1
-rw-r--r--src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala9
37 files changed, 778 insertions, 423 deletions
diff --git a/src/main/scala/firrtl/Annotations.scala b/src/main/scala/firrtl/Annotations.scala
index d47ce67e..d70732e6 100644
--- a/src/main/scala/firrtl/Annotations.scala
+++ b/src/main/scala/firrtl/Annotations.scala
@@ -115,12 +115,6 @@ object Annotations {
}
/**
- * Transform ID (TransID) associates an annotation with an instantiated
- * Firrtl compiler transform
- */
- case class TransID(id: Int)
-
- /**
* Permissibility defines the range of acceptable changes to the annotated component.
*/
trait Permissibility {
@@ -215,7 +209,7 @@ object Annotations {
/**
* Annotation associates with a given named circuit component (target) and a
- * given transformation (tID). Also defined are the legal ranges of changes
+ * given transformation (transform). Also defined are the legal ranges of changes
* to the associated component (Permissibility) and how the annotation
* propagates under such changes (Tenacity). Subclasses must implement the
* duplicate function to create the same annotation associated with a new
@@ -223,7 +217,7 @@ object Annotations {
*/
trait Annotation extends Permissibility with Tenacity {
def target: Named
- def tID: TransID
+ def transform: Class[_ <: Transform]
protected def duplicate(n: Named): Annotation
def serialize: String = this.toString
def update(tos: Seq[Named]): Seq[Annotation] = {
@@ -236,23 +230,23 @@ object Annotations {
* Container of all annotations for a Firrtl compiler.
*/
case class AnnotationMap(annotations: Seq[Annotation]) {
- type NamedMap = Map[Named, Map[TransID, Annotation]]
- type IDMap = Map[TransID, Map[Named, Annotation]]
+ type NamedMap = Map[Named, Map[Class[_], Annotation]]
+ type IDMap = Map[Class[_], Map[Named, Annotation]]
val (namedMap: NamedMap, idMap:IDMap) =
//annotations.foldLeft(Tuple2[NamedMap, IDMap](Map.empty, Map.empty)){
annotations.foldLeft((Map.empty: NamedMap, Map.empty: IDMap)){
(partialMaps: (NamedMap, IDMap), annotation: Annotation) => {
- val tIDToAnn = partialMaps._1.getOrElse(annotation.target, Map.empty)
- val pNMap = partialMaps._1 + (annotation.target -> (tIDToAnn + (annotation.tID -> annotation)))
+ val transformToAnn = partialMaps._1.getOrElse(annotation.target, Map.empty)
+ val pNMap = partialMaps._1 + (annotation.target -> (transformToAnn + (annotation.transform -> annotation)))
- val nToAnn = partialMaps._2.getOrElse(annotation.tID, Map.empty)
- val ptIDMap = partialMaps._2 + (annotation.tID -> (nToAnn + (annotation.target -> annotation)))
- Tuple2(pNMap, ptIDMap)
+ val nToAnn = partialMaps._2.getOrElse(annotation.transform, Map.empty)
+ val ptransformMap = partialMaps._2 + (annotation.transform -> (nToAnn + (annotation.target -> annotation)))
+ Tuple2(pNMap, ptransformMap)
}
}
- def get(id: TransID): Option[Map[Named, Annotation]] = idMap.get(id)
- def get(named: Named): Option[Map[TransID, Annotation]] = namedMap.get(named)
+ def get(id: Class[_]): Option[Map[Named, Annotation]] = idMap.get(id)
+ def get(named: Named): Option[Map[Class[_], Annotation]] = namedMap.get(named)
}
}
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala
index f566544e..9781972e 100644
--- a/src/main/scala/firrtl/Compiler.scala
+++ b/src/main/scala/firrtl/Compiler.scala
@@ -27,48 +27,249 @@ MODIFICATIONS.
package firrtl
-import com.typesafe.scalalogging.LazyLogging
-import scala.collection.mutable
+import logger.LazyLogging
import java.io.Writer
import Annotations._
import firrtl.ir.Circuit
+import passes.Pass
+
/**
* RenameMap maps old names to modified names. Generated by transformations
* that modify names
*/
case class RenameMap(map: Map[Named, Seq[Named]])
-// ===========================================
-// Transforms
-// -------------------------------------------
-
-case class TransformResult(
+/** Current State of the Circuit
+ *
+ * @constructor Creates a CircuitState object
+ * @param circuit The current state of the Firrtl AST
+ * @param form The current form of the circuit
+ * @param annotations The current collection of [[Annotations.Annotation]]
+ * @param renames A map of [[Annotations.Named]] things that have been renamed.
+ * Generally only a return value from [[Transform]]s
+ */
+case class CircuitState(
circuit: Circuit,
- renames: Option[RenameMap] = None,
- annotation: Option[AnnotationMap] = None)
+ form: CircuitForm,
+ annotations: Option[AnnotationMap] = None,
+ renames: Option[RenameMap] = None)
-// - Transforms a circuit
-// - Can consume multiple CircuitAnnotation's
-trait Transform {
- def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult
+/** Current form of the Firrtl Circuit
+ *
+ * Form is a measure of addition restrictions on the legality of a Firrtl
+ * circuit. There is a notion of "highness" and "lowness" implemented in the
+ * compiler by extending scala.math.Ordered. "Lower" forms add additional
+ * restrictions compared to "higher" forms. This means that "higher" forms are
+ * strictly supersets of the "lower" forms. Thus, that any transform that
+ * operates on [[HighForm]] can also operate on [[MidForm]] or [[LowForm]]
+ */
+sealed abstract class CircuitForm(private val value: Int) extends Ordered[CircuitForm] {
+ // Note that value is used only to allow comparisons
+ def compare(that: CircuitForm): Int = this.value - that.value
}
+/** Chirrtl Form
+ *
+ * The form of the circuit emitted by Chisel. Not a true Firrtl form.
+ * Includes cmem, smem, and mport IR nodes which enable declaring memories
+ * separately form their ports. A "Higher" form than [[HighForm]]
+ *
+ * See [[CDefMemory]] and [[CDefMPort]]
+ */
+final case object ChirrtlForm extends CircuitForm(3)
+/** High Form
+ *
+ * As detailed in the Firrtl specification
+ * [[https://github.com/ucb-bar/firrtl/blob/master/spec/spec.pdf]]
+ *
+ * Also see [[firrtl.ir]]
+ */
+final case object HighForm extends CircuitForm(2)
+/** Middle Form
+ *
+ * A "lower" form than [[HighForm]] with the following restrictions:
+ * - All widths must be explicit
+ * - All whens must be removed
+ * - There can only be a single connection to any element
+ */
+final case object MidForm extends CircuitForm(1)
+/** Low Form
+ *
+ * The "lowest" form. In addition to the restrictions in [[MidForm]]:
+ * - All aggregate types (vector/bundle) must have been removed
+ * - All implicit truncations must be made explicit
+ */
+final case object LowForm extends CircuitForm(0)
+/** The basic unit of operating on a Firrtl AST */
+abstract class Transform {
+ /** A convenience function useful for debugging and error messages */
+ def name: String = this.getClass.getSimpleName
+ /** The [[CircuitForm]] that this transform requires to operate on */
+ def inputForm: CircuitForm
+ /** The [[CircuitForm]] that this transform outputs */
+ def outputForm: CircuitForm
+ /** Perform the transform
+ *
+ * @param state Input Firrtl AST
+ * @return A transformed Firrtl AST
+ */
+ def execute(state: CircuitState): CircuitState
+ /** Convenience method to get annotations relevant to this Transform
+ *
+ * @param state The [[CircuitState]] form which to extract annotations
+ * @return A collection of annotations
+ */
+ final def getMyAnnotations(state: CircuitState): Option[Map[Named, Annotation]] =
+ for {
+ annotations <- state.annotations
+ myAnnotations <- annotations.get(this.getClass)
+ } yield myAnnotations
+}
-// ===========================================
-// Compilers
-// -------------------------------------------
+trait SimpleRun extends LazyLogging {
+ def runPasses(circuit: Circuit, passSeq: Seq[Pass]): Circuit =
+ passSeq.foldLeft(circuit) { (c: Circuit, pass: Pass) =>
+ val x = Utils.time(pass.name) { pass.run(c) }
+ logger.debug(x.serialize)
+ x
+ }
+}
+
+/** For PassBased Transforms and Emitters
+ *
+ * @note passSeq accepts no arguments
+ * @todo make passes accept CircuitState so annotations can pass data between them
+ */
+trait PassBased extends SimpleRun {
+ def passSeq: Seq[Pass]
+ def runPasses(circuit: Circuit): Circuit = runPasses(circuit, passSeq)
+}
-case class CompilerResult(circuit: Circuit, annotationMap: AnnotationMap)
+/** For transformations that are simply a sequence of passes */
+abstract class PassBasedTransform extends Transform with PassBased {
+ def execute(state: CircuitState): CircuitState = {
+ require(state.form <= inputForm,
+ s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}")
+ CircuitState(runPasses(state.circuit), outputForm)
+ }
+}
+
+/** Similar to a Transform except that it writes to a Writer instead of returning a
+ * CircuitState
+ */
+abstract class Emitter {
+ def emit(state: CircuitState, writer: Writer): Unit
+}
+
+object CompilerUtils {
+ /** Generates a sequence of [[Transform]]s to lower a Firrtl circuit
+ *
+ * @param inputForm [[CircuitForm]] to lower from
+ * @param outputForm [[CircuitForm to lower to
+ * @return Sequence of transforms that will lower if outputForm is lower than inputForm
+ */
+ def getLoweringTransforms(inputForm: CircuitForm, outputForm: CircuitForm): Seq[Transform] = {
+ // If outputForm is equal-to or higher than inputForm, nothing to lower
+ if (outputForm >= inputForm) {
+ Seq.empty
+ } else {
+ inputForm match {
+ case ChirrtlForm => Seq(new ChirrtlToHighFirrtl) ++ getLoweringTransforms(HighForm, outputForm)
+ case HighForm => Seq(new IRToWorkingIR, new ResolveAndCheck, new HighFirrtlToMiddleFirrtl) ++
+ getLoweringTransforms(MidForm, outputForm)
+ case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm)
+ case LowForm => error("Internal Error! This shouldn't be possible") // should be caught by if above
+ }
+ }
+ }
+
+ /** Merge a Seq of lowering transforms with custom transforms
+ *
+ * Custom Transforms are inserted based on their [[Transform.inputForm]] and
+ * [[Transform.outputForm]]. Custom transforms are inserted in order at the
+ * last location in the Seq of transforms where previous.outputForm ==
+ * customTransform.inputForm. If a customTransform outputs a higher form
+ * than input, [[getLoweringTransforms]] is used to relower the circuit.
+ *
+ * @example
+ * {{{
+ * // Let Transforms be represented by CircuitForm => CircuitForm
+ * val A = HighForm => MidForm
+ * val B = MidForm => LowForm
+ * val lowering = List(A, B) // Assume these transforms are used by getLoweringTransforms
+ * // Some custom transforms
+ * val C = LowForm => LowForm
+ * val D = MidForm => MidForm
+ * val E = LowForm => HighForm
+ * // All of the following comparisons are true
+ * mergeTransforms(lowering, List(C)) == List(A, B, C)
+ * mergeTransforms(lowering, List(D)) == List(A, D, B)
+ * mergeTransforms(lowering, List(E)) == List(A, B, E, A, B)
+ * mergeTransforms(lowering, List(C, E)) == List(A, B, C, E, A, B)
+ * mergeTransforms(lowering, List(E, C)) == List(A, B, E, A, B, C)
+ * // Notice that in the following, custom transform order is NOT preserved (see note)
+ * mergeTransforms(lowering, List(C, D)) == List(A, D, B, C)
+ * }}}
+ *
+ * @note Order will be preserved for custom transforms so long as the
+ * inputForm of a latter transforms is equal to or lower than the outputForm
+ * of the previous transform.
+ */
+ def mergeTransforms(lowering: Seq[Transform], custom: Seq[Transform]): Seq[Transform] = {
+ custom.foldLeft(lowering) { case (transforms, xform) =>
+ val index = transforms lastIndexWhere (_.outputForm == xform.inputForm)
+ assert(index >= 0 || xform.inputForm == ChirrtlForm, // If ChirrtlForm just put at front
+ s"No transform in $lowering has outputForm ${xform.inputForm} as required by $xform")
+ val (front, back) = transforms.splitAt(index + 1) // +1 because we want to be AFTER index
+ front ++ List(xform) ++ getLoweringTransforms(xform.outputForm, xform.inputForm) ++ back
+ }
+ }
+
+}
-// - A sequence of transformations
-// - Call compile to executes each transformation in sequence onto
-// a given circuit.
trait Compiler {
- def transforms(w: Writer): Seq[Transform]
- def compile(circuit: Circuit, annotationMap: AnnotationMap, writer: Writer): CompilerResult =
- (transforms(writer) foldLeft CompilerResult(circuit, annotationMap)){ (in, xform) =>
- val result = xform.execute(in.circuit, in.annotationMap)
+ def emitter: Emitter
+ /** The sequence of transforms this compiler will execute
+ * @note The inputForm of a given transform must be higher than or equal to the ouputForm of the
+ * preceding transform. See [[CircuitForm]]
+ */
+ def transforms: Seq[Transform]
+
+ // Similar to (input|output)Form on [[Transform]] but derived from this Compiler's transforms
+ def inputForm = transforms.head.inputForm
+ def outputForm = transforms.last.outputForm
+
+ private def transformsLegal(xforms: Seq[Transform]): Boolean =
+ if (xforms.size < 2) {
+ true
+ } else {
+ xforms.sliding(2, 1)
+ .map { case Seq(p, n) => n.inputForm >= p.outputForm }
+ .reduce(_ && _)
+ }
+
+ assert(transformsLegal(transforms),
+ "Illegal Compiler, each transform must be able to accept the output of the previous transform!")
+
+ /** Perform compilation
+ *
+ * @param state The Firrtl AST to compile
+ * @param writer The java.io.Writer where the output of compilation will be emitted
+ * @param customTransforms Any custom [[Transform]]s that will be inserted
+ * into the compilation process by [[CompilerUtils.mergeTransforms]]
+ */
+ def compile(state: CircuitState,
+ writer: Writer,
+ customTransforms: Seq[Transform] = Seq.empty): CircuitState = {
+ val allTransforms = CompilerUtils.mergeTransforms(transforms, customTransforms)
+
+ val finalState = allTransforms.foldLeft(state) { (in, xform) =>
+ val result = Utils.time(s"***${xform.name}***") { xform.execute(in) }
+
+ // Annotation propagation
+ // TODO: This should be redone
+ val inAnnotationMap = in.annotations getOrElse AnnotationMap(Seq.empty)
val remappedAnnotations: Seq[Annotation] = result.renames match {
case Some(RenameMap(rmap)) =>
// For each key in the rename map (rmap), obtain the
@@ -77,18 +278,22 @@ trait Compiler {
// annotations with the names in rmap's value.
for {
(oldName, newNames) <- rmap.toSeq
- tID2OldAnnos <- in.annotationMap.get(oldName).toSeq
- oldAnno <- tID2OldAnnos.values
+ transform2OldAnnos <- inAnnotationMap.get(oldName).toSeq
+ oldAnno <- transform2OldAnnos.values
newAnno <- oldAnno.update(newNames)
} yield newAnno
- case _ => in.annotationMap.annotations
+ case _ => inAnnotationMap.annotations
}
- val resultAnnotations: Seq[Annotation] = result.annotation match {
+ val resultAnnotations: Seq[Annotation] = result.annotations match {
case None => Nil
case Some(p) => p.annotations
}
- CompilerResult(result.circuit,
- new AnnotationMap(remappedAnnotations ++ resultAnnotations))
+ val newAnnotations = AnnotationMap(remappedAnnotations ++ resultAnnotations)
+ CircuitState(result.circuit, result.form, Some(newAnnotations))
}
+
+ emitter.emit(finalState, writer)
+ finalState
+ }
}
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala
index ba5527b4..293ac4fd 100644
--- a/src/main/scala/firrtl/Driver.scala
+++ b/src/main/scala/firrtl/Driver.scala
@@ -30,7 +30,8 @@ import scala.collection._
* firrtl.Driver.execute(Array("--top-name Dummy --compiler verilog".split(" +"))
* }}}
* each approach has its own endearing aspects
- * @see firrtlTests.DriverSpec.scala in the test directory for a lot more examples
+ * @see firrtlTests/DriverSpec.scala in the test directory for a lot more examples
+ * @see [[CompilerUtils.mergeTransforms]] to see how customTransformations are inserted
*/
object Driver {
@@ -42,11 +43,15 @@ object Driver {
output: String,
compiler: Compiler,
infoMode: InfoMode = IgnoreInfo,
+ customTransforms: Seq[Transform] = Seq.empty,
annotations: AnnotationMap = new AnnotationMap(Seq.empty)
): String = {
val parsedInput = Parser.parse(Source.fromFile(input).getLines(), infoMode)
val outputBuffer = new java.io.CharArrayWriter
- compiler.compile(parsedInput, annotations, outputBuffer)
+ compiler.compile(
+ CircuitState(parsedInput, ChirrtlForm, Some(annotations)),
+ outputBuffer,
+ customTransforms)
val outputFile = new java.io.PrintWriter(output)
val outputString = outputBuffer.toString
@@ -108,7 +113,11 @@ object Driver {
val parsedInput = Parser.parse(firrtlSource, firrtlConfig.infoMode)
val outputBuffer = new java.io.CharArrayWriter
- firrtlConfig.compiler.compile(parsedInput, new AnnotationMap(firrtlConfig.annotations), outputBuffer)
+ firrtlConfig.compiler.compile(
+ CircuitState(parsedInput, ChirrtlForm, Some(new AnnotationMap(firrtlConfig.annotations))),
+ outputBuffer,
+ firrtlConfig.customTransforms
+ )
val outputFileName = firrtlConfig.getOutputFileName(optionsManager)
val outputFile = new java.io.PrintWriter(outputFileName)
@@ -193,4 +202,4 @@ object FileUtils {
}
}
}
-} \ No newline at end of file
+}
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index 7b198149..1d64dc91 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -47,12 +47,8 @@ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap, HashSet}
case class EmitterException(message: String) extends PassException(message)
-trait Emitter extends LazyLogging {
- def run(c: Circuit, w: Writer)
-}
-
-object FIRRTLEmitter extends Emitter {
- def run(c: Circuit, w: Writer) = w.write(c.serialize)
+class FirrtlEmitter extends Emitter {
+ def emit(state: CircuitState, writer: Writer): Unit = writer.write(state.circuit.serialize)
}
case class VRandom(width: BigInt) extends Expression {
@@ -65,7 +61,7 @@ case class VRandom(width: BigInt) extends Expression {
def mapWidth(f: Width => Width): Expression = this
}
-class VerilogEmitter extends Emitter {
+class VerilogEmitter extends Emitter with PassBased {
val tab = " "
def AND(e1: WrappedExpression, e2: WrappedExpression): Expression = {
if (e1 == e2) e1.e1
@@ -590,12 +586,18 @@ class VerilogEmitter extends Emitter {
"`endif\n"))
}
- def run(c: Circuit, w: Writer) = {
- emit_preamble(w)
- val moduleMap = (c.modules map (m => m.name -> m)).toMap
- c.modules foreach {
- case (m: Module) => emit_verilog(m, moduleMap)(w)
- case (m: ExtModule) =>
- }
- }
+ def passSeq = Seq(
+ passes.VerilogWrap,
+ passes.VerilogRename,
+ passes.VerilogPrep)
+
+ def emit(state: CircuitState, writer: Writer): Unit = {
+ val circuit = runPasses(state.circuit)
+ emit_preamble(writer)
+ val moduleMap = (circuit.modules map (m => m.name -> m)).toMap
+ circuit.modules foreach {
+ case (m: Module) => emit_verilog(m, moduleMap)(writer)
+ case (m: ExtModule) =>
+ }
+ }
}
diff --git a/src/main/scala/firrtl/ExecutionOptionsManager.scala b/src/main/scala/firrtl/ExecutionOptionsManager.scala
index e4954610..21a9cc50 100644
--- a/src/main/scala/firrtl/ExecutionOptionsManager.scala
+++ b/src/main/scala/firrtl/ExecutionOptionsManager.scala
@@ -140,6 +140,7 @@ case class FirrtlExecutionOptions(
infoModeName: String = "append",
inferRW: Seq[String] = Seq.empty,
firrtlSource: Option[String] = None,
+ customTransforms: Seq[Transform] = List.empty,
annotations: List[Annotation] = List.empty)
extends ComposableOptions {
@@ -249,14 +250,17 @@ trait HasFirrtlOptions {
val newAnnotations = x.map { value =>
value.split('.') match {
case Array(circuit) =>
- passes.InlineAnnotation(CircuitName(circuit), TransID(0))
+ passes.InlineAnnotation(CircuitName(circuit))
case Array(circuit, module) =>
- passes.InlineAnnotation(ModuleName(module, CircuitName(circuit)), TransID(0))
+ passes.InlineAnnotation(ModuleName(module, CircuitName(circuit)))
case Array(circuit, module, inst) =>
- passes.InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit))), TransID(0))
+ passes.InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit))))
}
}
- firrtlOptions = firrtlOptions.copy(annotations = firrtlOptions.annotations ++ newAnnotations)
+ firrtlOptions = firrtlOptions.copy(
+ annotations = firrtlOptions.annotations ++ newAnnotations,
+ customTransforms = firrtlOptions.customTransforms :+ new passes.InlineInstances
+ )
}
.text {
"""Inline one or more module (comma separated, no spaces) module looks like "MyModule" or "MyModule.myinstance"""
@@ -267,7 +271,8 @@ trait HasFirrtlOptions {
.valueName ("<circuit>")
.foreach { x =>
firrtlOptions = firrtlOptions.copy(
- annotations = firrtlOptions.annotations :+ InferReadWriteAnnotation(x, TransID(-1))
+ annotations = firrtlOptions.annotations :+ InferReadWriteAnnotation(x),
+ customTransforms = firrtlOptions.customTransforms :+ new passes.memlib.InferReadWrite
)
}.text {
"Enable readwrite port inference for the target circuit"
@@ -278,7 +283,8 @@ trait HasFirrtlOptions {
.valueName ("-c:<circuit>:-i:<filename>:-o:<filename>")
.foreach { x =>
firrtlOptions = firrtlOptions.copy(
- annotations = firrtlOptions.annotations :+ ReplSeqMemAnnotation(x, TransID(-2))
+ annotations = firrtlOptions.annotations :+ ReplSeqMemAnnotation(x),
+ customTransforms = firrtlOptions.customTransforms :+ new passes.memlib.ReplSeqMem
)
}
.text {
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 446df6d0..986ebd9f 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -27,59 +27,38 @@ MODIFICATIONS.
package firrtl
-import java.io.Writer
-import firrtl.passes.Pass
-import firrtl.ir.Circuit
-import Annotations._
-import logger.LazyLogging
+sealed abstract class CoreTransform extends PassBasedTransform
-// ===========================================
-// Utility Traits
-// -------------------------------------------
-// Valid if all passes in transformation:
-// 1) Don't produce annotations
-// 2) Don't consume annotations
-// 3) No component or module names are renamed
-trait SimpleRun extends LazyLogging {
- def run (circuit: Circuit, passes: Seq[Pass]): TransformResult = {
- val result = (passes foldLeft circuit){ (c: Circuit, pass: Pass) =>
- val name = pass.name
- val x = Utils.time(name)(pass.run(c))
- logger.debug(x.serialize)
- x
- }
- TransformResult(result)
- }
-}
-
-// ===========================================
-// Lowering Transforms
-// -------------------------------------------
-// This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting
-// circuit has only IR nodes, not WIR.
-// TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL
-class Chisel3ToHighFirrtl extends Transform with SimpleRun {
- val passSeq = Seq(
+/** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting
+ * circuit has only IR nodes, not WIR.
+ * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL
+ */
+class ChirrtlToHighFirrtl extends CoreTransform {
+ def inputForm = ChirrtlForm
+ def outputForm = HighForm
+ def passSeq = Seq(
passes.CheckChirrtl,
passes.CInferTypes,
passes.CInferMDir,
passes.RemoveCHIRRTL)
- def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult =
- run(circuit, passSeq)
}
-// Converts from the bare intermediate representation (ir.scala)
-// to a working representation (WIR.scala)
-class IRToWorkingIR extends Transform with SimpleRun {
- val passSeq = Seq(passes.ToWorkingIR)
- def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult =
- run(circuit, passSeq)
+/** Converts from the bare intermediate representation (ir.scala)
+ * to a working representation (WIR.scala)
+ */
+class IRToWorkingIR extends CoreTransform {
+ def inputForm = HighForm
+ def outputForm = HighForm
+ def passSeq = Seq(passes.ToWorkingIR)
}
-// Resolves types, kinds, and genders, and checks the circuit legality.
-// Operates on working IR nodes and high Firrtl.
-class ResolveAndCheck extends Transform with SimpleRun {
- val passSeq = Seq(
+/** Resolves types, kinds, and genders, and checks the circuit legality.
+ * Operates on working IR nodes and high Firrtl.
+ */
+class ResolveAndCheck extends CoreTransform {
+ def inputForm = HighForm
+ def outputForm = HighForm
+ def passSeq = Seq(
passes.CheckHighForm,
passes.ResolveKinds,
passes.InferTypes,
@@ -91,16 +70,17 @@ class ResolveAndCheck extends Transform with SimpleRun {
passes.CheckGenders,
passes.InferWidths,
passes.CheckWidths)
- def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult =
- run(circuit, passSeq)
}
-// Expands aggregate connects, removes dynamic accesses, and when
-// statements. Checks for uninitialized values. Must accept a
-// well-formed graph.
-// Operates on working IR nodes.
-class HighFirrtlToMiddleFirrtl extends Transform with SimpleRun {
- val passSeq = Seq(
+/** Expands aggregate connects, removes dynamic accesses, and when
+ * statements. Checks for uninitialized values. Must accept a
+ * well-formed graph.
+ * Operates on working IR nodes.
+ */
+class HighFirrtlToMiddleFirrtl extends CoreTransform {
+ def inputForm = HighForm
+ def outputForm = MidForm
+ def passSeq = Seq(
passes.PullMuxes,
passes.ReplaceAccesses,
passes.ExpandConnects,
@@ -112,16 +92,17 @@ class HighFirrtlToMiddleFirrtl extends Transform with SimpleRun {
passes.ResolveGenders,
passes.InferWidths,
passes.CheckWidths)
- def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult =
- run(circuit, passSeq)
}
-// Expands all aggregate types into many ground-typed components. Must
-// accept a well-formed graph of only middle Firrtl features.
-// Operates on working IR nodes.
-// TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL
-class MiddleFirrtlToLowFirrtl extends Transform with SimpleRun {
- val passSeq = Seq(
+/** Expands all aggregate types into many ground-typed components. Must
+ * accept a well-formed graph of only middle Firrtl features.
+ * Operates on working IR nodes.
+ * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL
+ */
+class MiddleFirrtlToLowFirrtl extends CoreTransform {
+ def inputForm = MidForm
+ def outputForm = LowForm
+ def passSeq = Seq(
passes.LowerTypes,
passes.ResolveKinds,
passes.InferTypes,
@@ -129,87 +110,48 @@ class MiddleFirrtlToLowFirrtl extends Transform with SimpleRun {
passes.InferWidths,
passes.ConvertFixedToSInt,
passes.Legalize)
- def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult =
- run(circuit, passSeq)
}
-// Emits Verilog.
-// First optimizes for verilog width semantics with custom Primops,
-// then splits complex expressions into temporary nodes. Finally,
-// renames names that conflict with Verilog keywords.
-// Operates on working IR nodes.
-// TODO(izraelevitz): Create RenameMap from VerilogRename
-class EmitVerilogFromLowFirrtl(val writer: Writer) extends Transform with SimpleRun {
- val passSeq = Seq(
+/** Runs a series of optimization passes on LowFirrtl
+ * @note This is currently required for correct Verilog emission
+ * TODO Fix the above note
+ */
+class LowFirrtlOptimization extends CoreTransform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+ def passSeq = Seq(
passes.RemoveValidIf,
passes.ConstProp,
passes.PadWidths,
passes.ConstProp,
passes.Legalize,
- passes.VerilogWrap,
- passes.memlib.VerilogMemDelays,
+ passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
passes.ConstProp,
passes.SplitExpressions,
passes.CommonSubexpressionElimination,
- passes.DeadCodeElimination,
- passes.VerilogRename,
- passes.VerilogPrep)
- def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = {
- val result = run(circuit, passSeq)
- (new VerilogEmitter).run(result.circuit, writer)
- result
- }
+ passes.DeadCodeElimination)
}
-// Emits Firrtl.
-// Operates on WIR/IR nodes.
-class EmitFirrtl(val writer: Writer) extends Transform {
- def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = {
- FIRRTLEmitter.run(circuit, writer)
- TransformResult(circuit)
- }
-}
+import CompilerUtils.getLoweringTransforms
-// ===========================================
-// Lowering Compilers
-// -------------------------------------------
-// Emits input circuit
-// Will replace Chirrtl constructs with Firrtl
+/** Emits input circuit
+ * Will replace Chirrtl constructs with Firrtl
+ */
class HighFirrtlCompiler extends Compiler {
- def transforms(writer: Writer): Seq[Transform] = Seq(
- new Chisel3ToHighFirrtl,
- new IRToWorkingIR,
- new EmitFirrtl(writer)
- )
+ def emitter = new FirrtlEmitter
+ def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, HighForm)
}
-// Emits lowered input circuit
+/** Emits lowered input circuit */
class LowFirrtlCompiler extends Compiler {
- def transforms(writer: Writer): Seq[Transform] = Seq(
- new Chisel3ToHighFirrtl,
- new IRToWorkingIR,
- new passes.InlineInstances(TransID(0)),
- new ResolveAndCheck,
- new HighFirrtlToMiddleFirrtl,
- new passes.memlib.InferReadWrite(TransID(-1)),
- new passes.memlib.ReplSeqMem(TransID(-2)),
- new MiddleFirrtlToLowFirrtl,
- new EmitFirrtl(writer)
- )
+ def emitter = new FirrtlEmitter
+ def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, LowForm)
}
-// Emits Verilog
+/** Emits Verilog */
class VerilogCompiler extends Compiler {
- def transforms(writer: Writer): Seq[Transform] = Seq(
- new Chisel3ToHighFirrtl,
- new IRToWorkingIR,
- new ResolveAndCheck,
- new HighFirrtlToMiddleFirrtl,
- new passes.memlib.InferReadWrite(TransID(-1)),
- new passes.memlib.ReplSeqMem(TransID(-2)),
- new MiddleFirrtlToLowFirrtl,
- new passes.InlineInstances(TransID(0)),
- new EmitVerilogFromLowFirrtl(writer)
- )
+ def emitter = new VerilogEmitter
+ def transforms: Seq[Transform] =
+ getLoweringTransforms(ChirrtlForm, LowForm) :+ (new LowFirrtlOptimization)
}
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 22a3eac6..7c023ac8 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -44,7 +44,7 @@ import firrtl.WrappedType._
import scala.collection.mutable
import scala.collection.mutable.{StringBuilder, ArrayBuffer, LinkedHashMap, HashMap, HashSet}
import java.io.PrintWriter
-import com.typesafe.scalalogging.LazyLogging
+import logger.LazyLogging
//import scala.reflect.runtime.universe._
class FIRRTLException(str: String) extends Exception(str)
diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala
index 5c80baff..c741dc06 100644
--- a/src/main/scala/firrtl/passes/Inline.scala
+++ b/src/main/scala/firrtl/passes/Inline.scala
@@ -9,34 +9,37 @@ import firrtl.Annotations._
import scala.collection.mutable
// Tags an annotation to be consumed by this pass
-case class InlineAnnotation(target: Named, tID: TransID) extends Annotation with Loose with Unstable {
+case class InlineAnnotation(target: Named) extends Annotation with Loose with Unstable {
def duplicate(n: Named) = this.copy(target=n)
+ def transform = classOf[InlineInstances]
}
// Only use on legal Firrtl. Specifically, the restriction of
// instance loops must have been checked, or else this pass can
// infinitely recurse
-class InlineInstances (transID: TransID) extends Transform {
+class InlineInstances extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
val inlineDelim = "$"
- def name = "Inline Instances"
- def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = {
- annotationMap.get(transID) match {
- case None => TransformResult(circuit, None, None)
- case Some(map) =>
- val moduleNames = mutable.HashSet[ModuleName]()
- val instanceNames = mutable.HashSet[ComponentName]()
- map.values.foreach {x: Annotation => x match {
- case InlineAnnotation(ModuleName(mod, cir), _) => moduleNames += ModuleName(mod, cir)
- case InlineAnnotation(ComponentName(com, mod), _) => instanceNames += ComponentName(com, mod)
- case _ => throw new PassException("Annotation must be InlineAnnotation")
- }}
- check(circuit, moduleNames.toSet, instanceNames.toSet)
- run(circuit, moduleNames.toSet, instanceNames.toSet)
+ override def name = "Inline Instances"
- // Default behavior is to error if more than one annotation for inlining
- // This could potentially change
- case _ => throw new PassException("Found more than one circuit annotation of InlineCAKind!")
+ private def collectAnns(anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
+ anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) {
+ case ((modNames, instNames), ann) => ann match {
+ case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
+ case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
+ case _ => throw new PassException("Annotation must be InlineAnnotation")
+ }
}
+
+ def execute(state: CircuitState): CircuitState = {
+ // TODO Add error check for more than one annotation for inlining
+ // TODO Propagate other annotations
+ val result = for {
+ myAnnotations <- getMyAnnotations(state)
+ (modNames, instNames) = collectAnns(myAnnotations.values)
+ } yield run(state.circuit, modNames, instNames)
+ result getOrElse state // Return state if nothing to do
}
// Checks the following properties:
@@ -78,7 +81,10 @@ class InlineInstances (transID: TransID) extends Transform {
if (errors.nonEmpty) throw new PassExceptions(errors)
}
- def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): TransformResult = {
+ def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): CircuitState = {
+ // Check annotations and circuit match up
+ check(c, modsToInline, instsToInline)
+
// ---- Rename functions/data ----
val renameMap = mutable.HashMap[Named,Seq[Named]]()
// Updates renameMap with new names
@@ -168,6 +174,6 @@ class InlineInstances (transID: TransID) extends Transform {
val top = c.modules.find(m => m.name == c.main).get
onModule(top)
val modulesx = c.modules.map(m => inlinedModules(m.name))
- TransformResult(Circuit(c.info, modulesx, c.main), Some(RenameMap(renameMap.toMap)), None)
+ CircuitState(Circuit(c.info, modulesx, c.main), LowForm, None, Some(RenameMap(renameMap.toMap)))
}
}
diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
index 10cc8f88..c98dd4ca 100644
--- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
+++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
@@ -5,20 +5,22 @@ import ir._
import Annotations._
import wiring._
-class CreateMemoryAnnotations(reader: Option[YamlFileReader], replaceID: TransID, wiringID: TransID) extends Transform {
- def name = "Create Memory Annotations"
- def execute(c: Circuit, map: AnnotationMap): TransformResult = reader match {
- case None => TransformResult(c)
+class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform {
+ def inputForm = MidForm
+ def outputForm = MidForm
+ override def name = "Create Memory Annotations"
+ def execute(state: CircuitState): CircuitState = reader match {
+ case None => state
case Some(r) =>
import CustomYAMLProtocol._
r.parse[Config] match {
case Seq(config) =>
- val cN = CircuitName(c.main)
- val top = TopAnnotation(ModuleName(config.top.name, cN), wiringID)
- val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN)), wiringID)
- val pin = PinAnnotation(cN, replaceID, config.pin.name)
- TransformResult(c, None, Some(AnnotationMap(Seq(top, source, pin))))
- case Nil => TransformResult(c, None, None)
+ val cN = CircuitName(state.circuit.main)
+ val top = TopAnnotation(ModuleName(config.top.name, cN))
+ val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN)))
+ val pin = PinAnnotation(cN, config.pin.name)
+ state.copy(annotations = Some(AnnotationMap(Seq(top, source, pin))))
+ case Nil => state
case _ => error("Can only have one config in yaml file")
}
}
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
index 28291135..2d6f4e96 100644
--- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
@@ -38,10 +38,10 @@ import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin}
import WrappedExpression.weq
import Annotations._
-case class InferReadWriteAnnotation(t: String, tID: TransID)
- extends Annotation with Loose with Unstable {
+case class InferReadWriteAnnotation(t: String) extends Annotation with Loose with Unstable {
val target = CircuitName(t)
def duplicate(n: Named) = this.copy(t=n.name)
+ def transform = classOf[InferReadWrite]
}
// This pass examine the enable signals of the read & write ports of memories
@@ -168,7 +168,9 @@ object InferReadWritePass extends Pass {
// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl"
// To use this transform, circuit name should be annotated with its TransId.
-class InferReadWrite(transID: TransID) extends Transform with SimpleRun {
+class InferReadWrite extends Transform with PassBased {
+ def inputForm = MidForm
+ def outputForm = MidForm
def passSeq = Seq(
InferReadWritePass,
CheckInitialization,
@@ -176,11 +178,12 @@ class InferReadWrite(transID: TransID) extends Transform with SimpleRun {
ResolveKinds,
ResolveGenders
)
- def execute(c: Circuit, map: AnnotationMap) = map get transID match {
- case Some(p) => p get CircuitName(c.main) match {
- case Some(InferReadWriteAnnotation(_, _)) => run(c, passSeq)
- case _ => sys.error("Unexpected annotation for InferReadWrite")
- }
- case _ => TransformResult(c)
+ def execute(state: CircuitState): CircuitState = {
+ val result = for {
+ myAnnotations <- getMyAnnotations(state)
+ InferReadWriteAnnotation(_) <- myAnnotations get CircuitName(state.circuit.main)
+ resCircuit = runPasses(state.circuit)
+ } yield state.copy(circuit = resCircuit)
+ result getOrElse state // Return state if nothing to do
}
}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
index 9ab496d2..ae872639 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
@@ -16,7 +16,8 @@ import wiring._
/** Annotates the name of the pin to add for WiringTransform
*/
-case class PinAnnotation(target: CircuitName, tID: TransID, pin: String) extends Annotation with Loose with Unstable {
+case class PinAnnotation(target: CircuitName, pin: String) extends Annotation with Loose with Unstable {
+ def transform = classOf[ReplaceMemMacros]
def duplicate(n: Named) = n match {
case n: CircuitName => this.copy(target = n)
case _ => throwInternalError
@@ -27,8 +28,10 @@ case class PinAnnotation(target: CircuitName, tID: TransID, pin: String) extends
* This will not generate wmask ports if not needed.
* Creates the minimum # of black boxes needed by the design.
*/
-class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) extends Transform {
- def name = "Replace Memory Macros"
+class ReplaceMemMacros(writer: ConfWriter) extends Transform {
+ override def name = "Replace Memory Macros"
+ def inputForm = MidForm
+ def outputForm = MidForm
/** Return true if mask granularity is per bit, false if per byte or unspecified
*/
@@ -206,7 +209,8 @@ class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) ext
map updateStmtRefs(memPortMap))
}
- def execute(c: Circuit, map: AnnotationMap): TransformResult = {
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
val namespace = Namespace(c)
val memMods = new Modules
val nameMap = new NameMap
@@ -214,15 +218,15 @@ class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) ext
val modules = c.modules map updateMemMods(namespace, nameMap, memMods)
// print conf
writer.serialize()
- val pin = map get myID match {
- case Some(p) =>
+ val pin = getMyAnnotations(state) match {
+ case Some(p) =>
p.values.head match {
- case PinAnnotation(c, _, pin) => pin
+ case PinAnnotation(c, pin) => pin
case _ => error(s"Bad Annotations: ${p.values}")
}
case None => "pin"
}
- val annos = memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), wiringID, pin) }
- TransformResult(c.copy(modules = modules ++ memMods), None, Some(AnnotationMap(annos)))
- }
+ val annos = memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), pin) }
+ CircuitState(c.copy(modules = modules ++ memMods), inputForm, Some(AnnotationMap(annos)))
+ }
}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index 01f020f5..818bd9cc 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -61,8 +61,7 @@ class ConfWriter(filename: String) {
}
}
-case class ReplSeqMemAnnotation(t: String, tID: TransID)
- extends Annotation with Loose with Unstable {
+case class ReplSeqMemAnnotation(t: String) extends Annotation with Loose with Unstable {
val usage = """
[Optional] ReplSeqMem
@@ -91,52 +90,60 @@ Optional Arguments:
)
val target = CircuitName(passCircuit)
def duplicate(n: Named) = this copy (t = t.replace(s"-c:$passCircuit", s"-c:${n.name}"))
+ def transform = classOf[ReplSeqMem]
}
-case class SimpleTransform(p: Pass) extends Transform {
- def execute(c: Circuit, map: AnnotationMap): TransformResult =
- TransformResult(p.run(c))
+class SimpleTransform(p: Pass, form: CircuitForm) extends Transform {
+ def inputForm = form
+ def outputForm = form
+ def execute(state: CircuitState): CircuitState = state.copy(circuit = p.run(state.circuit))
}
-class ReplSeqMem(transID: TransID) extends Transform with SimpleRun {
+class SimpleMidTransform(p: Pass) extends SimpleTransform(p, MidForm)
+
+// SimpleRun instead of PassBased because of the arguments to passSeq
+class ReplSeqMem extends Transform with SimpleRun {
+ def inputForm = MidForm
+ def outputForm = MidForm
def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] =
- Seq(SimpleTransform(Legalize),
- SimpleTransform(ToMemIR),
- SimpleTransform(ResolveMaskGranularity),
- SimpleTransform(RenameAnnotatedMemoryPorts),
- SimpleTransform(ResolveMemoryReference),
- new CreateMemoryAnnotations(inConfigFile, TransID(-7), TransID(-8)),
- new ReplaceMemMacros(outConfigFile, TransID(-7), TransID(-8)),
- new WiringTransform(TransID(-8)),
- SimpleTransform(RemoveEmpty),
- SimpleTransform(CheckInitialization),
- SimpleTransform(InferTypes),
- SimpleTransform(Uniquify),
- SimpleTransform(ResolveKinds),
- SimpleTransform(ResolveGenders))
- def run(circuit: Circuit, map: AnnotationMap, xForms: Seq[Transform]): TransformResult = {
- (xForms.foldLeft(TransformResult(circuit, None, Some(map)))) { case (tr: TransformResult, xForm: Transform) =>
- val x = xForm.execute(tr.circuit, tr.annotation.get)
- x.annotation match {
- case None => TransformResult(x.circuit, None, Some(map))
- case Some(ann) => TransformResult(x.circuit, None, Some(
- AnnotationMap(ann.annotations ++ tr.annotation.get.annotations)))
+ Seq(new SimpleMidTransform(Legalize),
+ new SimpleMidTransform(ToMemIR),
+ new SimpleMidTransform(ResolveMaskGranularity),
+ new SimpleMidTransform(RenameAnnotatedMemoryPorts),
+ new SimpleMidTransform(ResolveMemoryReference),
+ new CreateMemoryAnnotations(inConfigFile),
+ new ReplaceMemMacros(outConfigFile),
+ new WiringTransform,
+ new SimpleMidTransform(RemoveEmpty),
+ new SimpleMidTransform(CheckInitialization),
+ new SimpleMidTransform(InferTypes),
+ new SimpleMidTransform(Uniquify),
+ new SimpleMidTransform(ResolveKinds),
+ new SimpleMidTransform(ResolveGenders))
+ def run(state: CircuitState, xForms: Seq[Transform]): CircuitState = {
+ xForms.foldLeft(state) { case (curState: CircuitState, xForm: Transform) =>
+ val res = xForm.execute(state)
+ res.annotations match {
+ case None => CircuitState(res.circuit, res.form, state.annotations)
+ case Some(ann) => CircuitState(res.circuit, res.form, Some(
+ AnnotationMap(ann.annotations ++ curState.annotations.get.annotations)))
}
}
}
- def execute(c: Circuit, map: AnnotationMap) = map get transID match {
- case Some(p) => p get CircuitName(c.main) match {
- case Some(ReplSeqMemAnnotation(t, _)) =>
- val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "")
- val inConfigFile = {
- if (inputFileName.isEmpty) None
- else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName))
- else error("Input configuration file does not exist!")
- }
- val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName))
- run(c, map, passSeq(inConfigFile, outConfigFile))
- case _ => error("Unexpected transform annotation")
+ def execute(state: CircuitState): CircuitState =
+ getMyAnnotations(state) match {
+ case Some(p) => p get CircuitName(state.circuit.main) match {
+ case Some(ReplSeqMemAnnotation(t)) =>
+ val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "")
+ val inConfigFile = {
+ if (inputFileName.isEmpty) None
+ else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName))
+ else error("Input configuration file does not exist!")
+ }
+ val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName))
+ run(state, passSeq(inConfigFile, outConfigFile))
+ case _ => error("Unexpected transform annotation")
+ }
+ case None => state // Do nothing if there are no annotations
}
- case _ => TransformResult(c)
- }
}
diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
index 919948b6..59e76d65 100644
--- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
+++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
@@ -11,7 +11,8 @@ import WiringUtils._
/** A component, e.g. register etc. Must be declared only once under the TopAnnotation
*/
-case class SourceAnnotation(target: ComponentName, tID: TransID) extends Annotation with Loose with Unstable {
+case class SourceAnnotation(target: ComponentName) extends Annotation with Loose with Unstable {
+ def transform = classOf[WiringTransform]
def duplicate(n: Named) = n match {
case n: ComponentName => this.copy(target = n)
case _ => throwInternalError
@@ -20,7 +21,8 @@ case class SourceAnnotation(target: ComponentName, tID: TransID) extends Annotat
/** A module, e.g. ExtModule etc., that should add the input pin
*/
-case class SinkAnnotation(target: ModuleName, tID: TransID, pin: String) extends Annotation with Loose with Unstable {
+case class SinkAnnotation(target: ModuleName, pin: String) extends Annotation with Loose with Unstable {
+ def transform = classOf[WiringTransform]
def duplicate(n: Named) = n match {
case n: ModuleName => this.copy(target = n)
case _ => throwInternalError
@@ -30,7 +32,8 @@ case class SinkAnnotation(target: ModuleName, tID: TransID, pin: String) extends
/** A module under which all sink module must be declared, and there is only
* one source component
*/
-case class TopAnnotation(target: ModuleName, tID: TransID) extends Annotation with Loose with Unstable {
+case class TopAnnotation(target: ModuleName) extends Annotation with Loose with Unstable {
+ def transform = classOf[WiringTransform]
def duplicate(n: Named) = n match {
case n: ModuleName => this.copy(target = n)
case _ => throwInternalError
@@ -49,13 +52,15 @@ case class TopAnnotation(target: ModuleName, tID: TransID) extends Annotation wi
* Notes:
* - No module uniquification occurs (due to imposed restrictions)
*/
-class WiringTransform(transID: TransID) extends Transform with SimpleRun {
+class WiringTransform extends Transform with SimpleRun {
+ def inputForm = MidForm
+ def outputForm = MidForm
def passSeq(wi: WiringInfo) =
Seq(new Wiring(wi),
InferTypes,
ResolveKinds,
ResolveGenders)
- def execute(c: Circuit, map: AnnotationMap) = map get transID match {
+ def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match {
case Some(p) =>
val sinks = mutable.HashMap[String, String]()
val sources = mutable.Set[String]()
@@ -63,18 +68,20 @@ class WiringTransform(transID: TransID) extends Transform with SimpleRun {
val comp = mutable.Set[String]()
p.values.foreach { a =>
a match {
- case SinkAnnotation(m, _, pin) => sinks(m.name) = pin
- case SourceAnnotation(c, _) =>
+ case SinkAnnotation(m, pin) => sinks(m.name) = pin
+ case SourceAnnotation(c) =>
sources += c.module.name
comp += c.name
- case TopAnnotation(m, _) => tops += m.name
+ case TopAnnotation(m) => tops += m.name
}
}
(sources.size, tops.size, sinks.size, comp.size) match {
- case (0, 0, p, 0) => TransformResult(c)
- case (1, 1, p, 1) if p > 0 => run(c, passSeq(WiringInfo(sources.head, comp.head, sinks.toMap, tops.head)))
+ case (0, 0, p, 0) => state
+ case (1, 1, p, 1) if p > 0 =>
+ val winfo = WiringInfo(sources.head, comp.head, sinks.toMap, tops.head)
+ state.copy(circuit = runPasses(state.circuit, passSeq(winfo)))
case _ => error("Wrong number of sources, tops, or sinks!")
}
- case None => TransformResult(c)
+ case None => state
}
}
diff --git a/src/test/resources/features/CustomTransform.fir b/src/test/resources/features/CustomTransform.fir
new file mode 100644
index 00000000..941a9e9c
--- /dev/null
+++ b/src/test/resources/features/CustomTransform.fir
@@ -0,0 +1,33 @@
+circuit CustomTransform :
+ ; Replaced in custom transform by an implementation
+ extmodule Delay :
+ input clk : Clock
+ input reset : UInt<1>
+ input a : UInt<32>
+ input en : UInt<1>
+ output b : UInt<32>
+
+ module CustomTransform :
+ input clk : Clock
+ input reset : UInt<1>
+
+ reg cycle : UInt<32>, clk with : (reset => (reset, UInt<32>(0)))
+ cycle <= tail(add(cycle, UInt<32>(1)), 1)
+
+ inst delay of Delay
+ delay.clk <= clk
+ delay.reset <= reset
+ delay.a <= UInt(0)
+ delay.en <= UInt(0)
+
+ when eq(cycle, UInt(0)) :
+ delay.en <= UInt(1)
+ delay.a <= UInt("hdeadbeef")
+ when eq(cycle, UInt(1)) :
+ when neq(delay.b, UInt("hdeadbeef")) :
+ printf(clk, UInt(1), "Assertion failed!\n")
+ stop(clk, UInt(1), 1)
+ when eq(cycle, UInt(2)) :
+ printf(clk, UInt(1), "Success!\n")
+ stop(clk, UInt(1), 0)
+
diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala
index 0312df5d..c395139b 100644
--- a/src/test/scala/firrtlTests/AnnotationTests.scala
+++ b/src/test/scala/firrtlTests/AnnotationTests.scala
@@ -9,14 +9,16 @@ import org.scalatest.junit.JUnitRunner
import firrtl.ir.Circuit
import firrtl.Parser
import firrtl.{
+ CircuitState,
ResolveAndCheck,
RenameMap,
Compiler,
- CompilerResult,
- VerilogCompiler
+ ChirrtlForm,
+ LowForm,
+ VerilogCompiler,
+ Transform
}
import firrtl.Annotations.{
- TransID,
Named,
CircuitName,
ModuleName,
@@ -39,17 +41,17 @@ import firrtl.Annotations.{
*/
trait AnnotationSpec extends LowTransformSpec {
// Dummy transform
- def transform = new ResolveAndCheck()
+ def transform = new CustomResolveAndCheck(LowForm)
// Check if Annotation Exception is thrown
override def failingexecute(writer: Writer, annotations: AnnotationMap, input: String) = {
intercept[AnnotationException] {
- compile(parse(input), annotations, writer)
+ compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer)
}
}
def execute(writer: Writer, annotations: AnnotationMap, input: String, check: Annotation) = {
- val cr = compile(parse(input), annotations, writer)
- (cr.annotationMap.annotations.head) should be (check)
+ val cr = compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer)
+ (cr.annotations.get.annotations.head) should be (check)
}
}
@@ -63,7 +65,6 @@ trait AnnotationSpec extends LowTransformSpec {
*/
class AnnotationTests extends AnnotationSpec with Matchers {
def getAMap (a: Annotation): AnnotationMap = new AnnotationMap(Seq(a))
- val tID = TransID(1)
val input =
"""circuit Top :
| module Top :
@@ -76,11 +77,12 @@ class AnnotationTests extends AnnotationSpec with Matchers {
val cName = ComponentName("c", mName)
"Loose and Sticky annotation on a node" should "pass through" in {
- case class TestAnnotation(target: Named, tID: TransID) extends Annotation with Loose with Sticky {
+ case class TestAnnotation(target: Named) extends Annotation with Loose with Sticky {
def duplicate(to: Named) = this.copy(target=to)
+ def transform = classOf[Transform]
}
val w = new StringWriter()
- val ta = TestAnnotation(cName, tID)
+ val ta = TestAnnotation(cName)
execute(w, getAMap(ta), input, ta)
}
}
diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala
index d1e07eae..3a67bf04 100644
--- a/src/test/scala/firrtlTests/AttachSpec.scala
+++ b/src/test/scala/firrtlTests/AttachSpec.scala
@@ -37,10 +37,9 @@ import firrtl.passes._
import firrtl.Parser.IgnoreInfo
class InoutVerilog extends FirrtlFlatSpec {
- def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
private def executeTest(input: String, expected: Seq[String], compiler: Compiler) = {
val writer = new StringWriter()
- compiler.compile(parse(input), new AnnotationMap(Seq.empty), writer)
+ compiler.compile(CircuitState(parse(input), ChirrtlForm), writer)
val lines = writer.toString().split("\n") map normalized
expected foreach { e =>
lines should contain(e)
@@ -176,7 +175,6 @@ class InoutVerilog extends FirrtlFlatSpec {
}
class AttachAnalogSpec extends FirrtlFlatSpec {
- def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
(c: Circuit, p: Pass) => p.run(c)
diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala
index 719a3334..51663eaf 100644
--- a/src/test/scala/firrtlTests/CInferMDirSpec.scala
+++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala
@@ -63,13 +63,12 @@ class CInferMDir extends LowTransformSpec {
}
}
- object CInferMDirCheck extends Transform with SimpleRun {
- def execute(c: Circuit, map: AnnotationMap) =
- run(c, Seq(ConstProp, CInferMDirCheckPass))
+ def transform = new PassBasedTransform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+ def passSeq = Seq(ConstProp, CInferMDirCheckPass)
}
- def transform = CInferMDirCheck
-
"Memory" should "have correct mem port directions" in {
val input = """
circuit foo :
@@ -97,7 +96,7 @@ circuit foo :
val annotationMap = AnnotationMap(Nil)
val writer = new java.io.StringWriter
- compile(parse(input), annotationMap, writer)
+ compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer)
// Check correctness of firrtl
parse(writer.toString)
}
diff --git a/src/test/scala/firrtlTests/CheckInitializationSpec.scala b/src/test/scala/firrtlTests/CheckInitializationSpec.scala
index e2eaf690..e8dc60ae 100644
--- a/src/test/scala/firrtlTests/CheckInitializationSpec.scala
+++ b/src/test/scala/firrtlTests/CheckInitializationSpec.scala
@@ -36,7 +36,6 @@ import firrtl.Parser.IgnoreInfo
import firrtl.passes._
class CheckInitializationSpec extends FirrtlFlatSpec {
- private def parse(input: String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
private val passes = Seq(
ToWorkingIR,
CheckHighForm,
diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
index e0691a6b..63397da8 100644
--- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
+++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
@@ -76,13 +76,12 @@ class ChirrtlMemSpec extends LowTransformSpec {
}
}
- object MemEnableCheck extends Transform with SimpleRun {
- def execute(c: Circuit, map: AnnotationMap) =
- run(c, Seq(ConstProp, MemEnableCheckPass))
+ def transform = new PassBasedTransform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+ def passSeq = Seq(ConstProp, MemEnableCheckPass)
}
- def transform = MemEnableCheck
-
"Sequential Memory" should "have correct enable signals" in {
val input = """
circuit foo :
@@ -104,7 +103,7 @@ circuit foo :
val annotationMap = AnnotationMap(Nil)
val writer = new java.io.StringWriter
- compile(parse(input), annotationMap, writer)
+ compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer)
// Check correctness of firrtl
parse(writer.toString)
}
@@ -131,7 +130,7 @@ circuit foo :
val annotationMap = AnnotationMap(Nil)
val writer = new java.io.StringWriter
- compile(parse(input), annotationMap, writer)
+ compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer)
// Check correctness of firrtl
parse(writer.toString)
}
diff --git a/src/test/scala/firrtlTests/CompilerTests.scala b/src/test/scala/firrtlTests/CompilerTests.scala
index 2eab6e0f..28d09c2d 100644
--- a/src/test/scala/firrtlTests/CompilerTests.scala
+++ b/src/test/scala/firrtlTests/CompilerTests.scala
@@ -8,13 +8,14 @@ import org.scalatest.junit.JUnitRunner
import firrtl.ir.Circuit
import firrtl.{
- HighFirrtlCompiler,
- LowFirrtlCompiler,
- VerilogCompiler,
- Compiler,
- Parser
+ ChirrtlForm,
+ CircuitState,
+ Compiler,
+ HighFirrtlCompiler,
+ LowFirrtlCompiler,
+ Parser,
+ VerilogCompiler
}
-import firrtl.Annotations.AnnotationMap
/**
* An example methodology for testing Firrtl compilers.
@@ -30,7 +31,7 @@ abstract class CompilerSpec extends FlatSpec {
def input: String
def check: String
def getOutput: String = {
- compiler.compile(parse(input), new AnnotationMap(Seq.empty), writer)
+ compiler.compile(CircuitState(parse(input), ChirrtlForm), writer)
writer.toString()
}
}
diff --git a/src/test/scala/firrtlTests/CompilerUtilsSpec.scala b/src/test/scala/firrtlTests/CompilerUtilsSpec.scala
new file mode 100644
index 00000000..1d349db1
--- /dev/null
+++ b/src/test/scala/firrtlTests/CompilerUtilsSpec.scala
@@ -0,0 +1,76 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl._
+import firrtl.CompilerUtils.mergeTransforms
+
+class CompilerUtilsSpec extends FirrtlFlatSpec {
+
+ def genTransform(_inputForm: CircuitForm, _outputForm: CircuitForm) = new Transform {
+ def inputForm = _inputForm
+ def outputForm = _outputForm
+ def execute(state: CircuitState): CircuitState = state
+ }
+
+ // Core lowering transforms
+ val chirrtlToHigh = genTransform(ChirrtlForm, HighForm)
+ val highToMid = genTransform(HighForm, MidForm)
+ val midToLow = genTransform(MidForm, LowForm)
+ val chirrtlToLowList = List(chirrtlToHigh, highToMid, midToLow)
+
+ // Custom transforms
+ val chirrtlToChirrtl = genTransform(ChirrtlForm, ChirrtlForm)
+ val highToHigh = genTransform(HighForm, HighForm)
+ val midToMid = genTransform(MidForm, MidForm)
+ val lowToLow = genTransform(LowForm, LowForm)
+
+ val lowToHigh = genTransform(LowForm, HighForm)
+
+ val lowToLowTwo = genTransform(LowForm, LowForm)
+
+ behavior of "mergeTransforms"
+
+ it should "do nothing if there are no custom transforms" in {
+ mergeTransforms(chirrtlToLowList, List.empty) should be (chirrtlToLowList)
+ }
+
+ it should "insert transforms at the correct place" in {
+ mergeTransforms(chirrtlToLowList, List(chirrtlToChirrtl)) should be
+ (chirrtlToChirrtl +: chirrtlToLowList)
+ mergeTransforms(chirrtlToLowList, List(highToHigh)) should be
+ (List(chirrtlToHigh, highToHigh, highToMid, midToLow))
+ mergeTransforms(chirrtlToLowList, List(midToMid)) should be
+ (List(chirrtlToHigh, highToMid, midToMid, midToLow))
+ mergeTransforms(chirrtlToLowList, List(lowToLow)) should be
+ (chirrtlToLowList :+ lowToLow)
+ }
+
+ it should "insert transforms at the last legal location" in {
+ lowToLow should not be (lowToLowTwo) // sanity check
+ mergeTransforms(chirrtlToLowList :+ lowToLow, List(lowToLowTwo)).last should be (lowToLowTwo)
+ }
+
+ it should "insert multiple transforms correctly" in {
+ mergeTransforms(chirrtlToLowList, List(highToHigh, lowToLow)) should be
+ (List(chirrtlToHigh, highToHigh, highToMid, midToLow, lowToLow))
+ }
+
+ it should "handle transforms that raise the form" in {
+ mergeTransforms(chirrtlToLowList, List(lowToHigh)) match {
+ case chirrtlToHigh :: highToMid :: midToLow :: lowToHigh :: remainder =>
+ // Remainder will be the actual Firrtl lowering transforms
+ remainder.head.inputForm should be (HighForm)
+ remainder.last.outputForm should be (LowForm)
+ case _ => fail()
+ }
+ }
+
+ // Order is not always maintained, see note on function Scaladoc
+ it should "maintain order of custom tranforms" in {
+ mergeTransforms(chirrtlToLowList, List(lowToLow, lowToLowTwo)) should be
+ (chirrtlToLowList ++ List(lowToLow, lowToLowTwo))
+ }
+
+}
+
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index bfe58a2c..f6bfa5ef 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -22,7 +22,6 @@ class ConstantPropagationSpec extends FirrtlFlatSpec {
ResolveGenders,
InferWidths,
ConstProp)
- def parse(input: String): Circuit = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
private def exec (input: String) = {
passes.foldLeft(parse(input)) {
(c: Circuit, p: Pass) => p.run(c)
diff --git a/src/test/scala/firrtlTests/CustomTransformSpec.scala b/src/test/scala/firrtlTests/CustomTransformSpec.scala
new file mode 100644
index 00000000..4a3faf6b
--- /dev/null
+++ b/src/test/scala/firrtlTests/CustomTransformSpec.scala
@@ -0,0 +1,51 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl.ir.Circuit
+import firrtl._
+import firrtl.passes.Pass
+import firrtl.ir._
+
+class CustomTransformSpec extends FirrtlFlatSpec {
+ behavior of "Custom Transforms"
+
+ they should "be able to introduce high firrtl" in {
+ // Simple module
+ val delayModuleString = """
+ |circuit Delay :
+ | module Delay :
+ | input clk : Clock
+ | input reset : UInt<1>
+ | input a : UInt<32>
+ | input en : UInt<1>
+ | output b : UInt<32>
+ |
+ | reg r : UInt<32>, clk
+ | r <= r
+ | when en :
+ | r <= a
+ | b <= r
+ |""".stripMargin
+ val delayModuleCircuit = parse(delayModuleString)
+ val delayModule = delayModuleCircuit.modules.find(_.name == delayModuleCircuit.main).get
+
+ class ReplaceExtModuleTransform extends PassBasedTransform {
+ class ReplaceExtModule extends Pass {
+ def name = "Replace External Module"
+ def run(c: Circuit): Circuit = c.copy(
+ modules = c.modules map {
+ case ExtModule(_, "Delay", _, _, _) => delayModule
+ case other => other
+ }
+ )
+ }
+ def passSeq = Seq(new ReplaceExtModule)
+ def inputForm = LowForm
+ def outputForm = HighForm
+ }
+
+ runFirrtlTest("CustomTransform", "/features", customTransforms = List(new ReplaceExtModuleTransform))
+ }
+}
+
diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala
index 8bbecaeb..06963708 100644
--- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala
+++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala
@@ -36,7 +36,6 @@ import firrtl.ir._
import firrtl.Parser.IgnoreInfo
class ExpandWhensSpec extends FirrtlFlatSpec {
- private def parse(input: String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
private def executeTest(input: String, notExpected: String, passes: Seq[Pass]) = {
val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
(c: Circuit, p: Pass) => p.run(c)
diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala
index f491b0f5..83cccf3b 100644
--- a/src/test/scala/firrtlTests/FirrtlSpec.scala
+++ b/src/test/scala/firrtlTests/FirrtlSpec.scala
@@ -36,6 +36,7 @@ import org.scalatest.prop._
import scala.io.Source
import firrtl._
+import firrtl.Parser.IgnoreInfo
import firrtl.Annotations.AnnotationMap
// This trait is borrowed from Chisel3, ideally this code should only exist in one location
@@ -131,6 +132,7 @@ trait BackendCompilationUtilities {
}
trait FirrtlRunners extends BackendCompilationUtilities {
+ def parse(str: String) = Parser.parse(str.split("\n").toIterator, IgnoreInfo)
lazy val cppHarness = new File(s"/top.cpp")
/** Compile a Firrtl file
*
@@ -141,6 +143,7 @@ trait FirrtlRunners extends BackendCompilationUtilities {
def compileFirrtlTest(
prefix: String,
srcDir: String,
+ customTransforms: Seq[Transform] = Seq.empty,
annotations: AnnotationMap = new AnnotationMap(Seq.empty)): File = {
val testDir = createTempDirectory(prefix)
copyResourceToFile(s"${srcDir}/${prefix}.fir", new File(testDir, s"${prefix}.fir"))
@@ -150,6 +153,7 @@ trait FirrtlRunners extends BackendCompilationUtilities {
s"$testDir/$prefix.v",
new VerilogCompiler(),
Parser.IgnoreInfo,
+ customTransforms,
annotations)
testDir
}
@@ -164,8 +168,9 @@ trait FirrtlRunners extends BackendCompilationUtilities {
prefix: String,
srcDir: String,
verilogPrefixes: Seq[String] = Seq.empty,
+ customTransforms: Seq[Transform] = Seq.empty,
annotations: AnnotationMap = new AnnotationMap(Seq.empty)) = {
- val testDir = compileFirrtlTest(prefix, srcDir, annotations)
+ val testDir = compileFirrtlTest(prefix, srcDir, customTransforms, annotations)
val harness = new File(testDir, s"top.cpp")
copyResourceToFile(cppHarness.toString, harness)
diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
index be663872..b6e8f726 100644
--- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala
+++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
@@ -61,19 +61,19 @@ class InferReadWriteSpec extends SimpleTransformSpec {
}
}
- object InferReadWriteCheck extends Transform with SimpleRun {
- def execute (c: Circuit, map: AnnotationMap) =
- run(c, Seq(InferReadWriteCheckPass))
+ class InferReadWriteCheck extends PassBasedTransform {
+ def inputForm = MidForm
+ def outputForm = MidForm
+ def passSeq = Seq(InferReadWriteCheckPass)
}
- def transforms (writer: java.io.Writer) = Seq(
- new Chisel3ToHighFirrtl(),
- new IRToWorkingIR(),
- new ResolveAndCheck(),
- new HighFirrtlToMiddleFirrtl(),
- new memlib.InferReadWrite(TransID(-1)),
- InferReadWriteCheck,
- new EmitFirrtl(writer)
+ def transforms = Seq(
+ new ChirrtlToHighFirrtl,
+ new IRToWorkingIR,
+ new ResolveAndCheck,
+ new HighFirrtlToMiddleFirrtl,
+ new memlib.InferReadWrite,
+ new InferReadWriteCheck
)
"Infer ReadWrite Ports" should "infer readwrite ports for the same clock" in {
@@ -100,9 +100,9 @@ circuit sram6t :
T_5 <= io.wdata
""".stripMargin
- val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t", TransID(-1))))
+ val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t")))
val writer = new java.io.StringWriter
- compile(parse(input), annotationMap, writer)
+ compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer)
// Check correctness of firrtl
parse(writer.toString)
}
@@ -132,10 +132,10 @@ circuit sram6t :
T_5 <= io.wdata
""".stripMargin
- val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t", TransID(-1))))
+ val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t")))
val writer = new java.io.StringWriter
intercept[InferReadWriteCheckException] {
- compile(parse(input), annotationMap, writer)
+ compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer)
}
}
}
diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala
index 5f19af5c..f7845cc7 100644
--- a/src/test/scala/firrtlTests/InlineInstancesTests.scala
+++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala
@@ -14,7 +14,6 @@ import firrtl.Annotations.{
CircuitName,
ModuleName,
ComponentName,
- TransID,
Annotation,
AnnotationMap
}
@@ -24,9 +23,8 @@ import firrtl.passes.{InlineInstances, InlineAnnotation}
/**
* Tests inline instances transformation
*/
-class InlineInstancesTests extends HighTransformSpec {
- val tID = TransID(0)
- val transform = new InlineInstances(tID)
+class InlineInstancesTests extends LowTransformSpec {
+ def transform = new InlineInstances
"The module Inline" should "be inlined" in {
val input =
"""circuit Top :
@@ -48,14 +46,14 @@ class InlineInstancesTests extends HighTransformSpec {
| wire i$a : UInt<32>
| wire i$b : UInt<32>
| i$b <= i$a
- | i$a <= a
| b <= i$b
+ | i$a <= a
| module Inline :
| input a : UInt<32>
| output b : UInt<32>
| b <= a""".stripMargin
val writer = new StringWriter()
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Inline", CircuitName("Top")), tID)))
+ val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Inline", CircuitName("Top")))))
execute(writer, aMap, input, check)
}
@@ -85,15 +83,15 @@ class InlineInstancesTests extends HighTransformSpec {
| wire i1$a : UInt<32>
| wire i1$b : UInt<32>
| i1$b <= i1$a
+ | b <= i1$b
| i0$a <= a
| i1$a <= i0$b
- | b <= i1$b
| module Simple :
| input a : UInt<32>
| output b : UInt<32>
| b <= a""".stripMargin
val writer = new StringWriter()
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Simple", CircuitName("Top")), tID)))
+ val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Simple", CircuitName("Top")))))
execute(writer, aMap, input, check)
}
@@ -121,15 +119,15 @@ class InlineInstancesTests extends HighTransformSpec {
| wire i0$b : UInt<32>
| i0$b <= i0$a
| inst i1 of Simple
+ | b <= i1.b
| i0$a <= a
| i1.a <= i0$b
- | b <= i1.b
| module Simple :
| input a : UInt<32>
| output b : UInt<32>
| b <= a""".stripMargin
val writer = new StringWriter()
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ComponentName("i0",ModuleName("Top", CircuitName("Top"))), tID)))
+ val aMap = new AnnotationMap(Seq(InlineAnnotation(ComponentName("i0",ModuleName("Top", CircuitName("Top"))))))
execute(writer, aMap, input, check)
}
@@ -163,9 +161,9 @@ class InlineInstancesTests extends HighTransformSpec {
| wire i0$b : UInt<32>
| i0$b <= i0$a
| inst i1 of B
+ | b <= i1.b
| i0$a <= a
| i1.a <= i0$b
- | b <= i1.b
| module A :
| input a : UInt<32>
| output b : UInt<32>
@@ -176,10 +174,10 @@ class InlineInstancesTests extends HighTransformSpec {
| wire i$a : UInt<32>
| wire i$b : UInt<32>
| i$b <= i$a
- | i$a <= a
- | b <= i$b""".stripMargin
+ | b <= i$b
+ | i$a <= a""".stripMargin
val writer = new StringWriter()
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID)))
+ val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")))))
execute(writer, aMap, input, check)
}
@@ -199,7 +197,7 @@ class InlineInstancesTests extends HighTransformSpec {
| input a : UInt<32>
| output b : UInt<32>""".stripMargin
val writer = new StringWriter()
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID)))
+ val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")))))
failingexecute(writer, aMap, input)
}
// 2) ext instance
@@ -216,7 +214,7 @@ class InlineInstancesTests extends HighTransformSpec {
| input a : UInt<32>
| output b : UInt<32>""".stripMargin
val writer = new StringWriter()
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID)))
+ val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")))))
failingexecute(writer, aMap, input)
}
// 3) no module
@@ -228,7 +226,7 @@ class InlineInstancesTests extends HighTransformSpec {
| output b : UInt<32>
| b <= a""".stripMargin
val writer = new StringWriter()
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID)))
+ val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")))))
failingexecute(writer, aMap, input)
}
// 4) no inst
@@ -240,7 +238,7 @@ class InlineInstancesTests extends HighTransformSpec {
| output b : UInt<32>
| b <= a""".stripMargin
val writer = new StringWriter()
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID)))
+ val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")))))
failingexecute(writer, aMap, input)
}
}
diff --git a/src/test/scala/firrtlTests/MultiThreadingSpec.scala b/src/test/scala/firrtlTests/MultiThreadingSpec.scala
index bfaed330..b2934314 100644
--- a/src/test/scala/firrtlTests/MultiThreadingSpec.scala
+++ b/src/test/scala/firrtlTests/MultiThreadingSpec.scala
@@ -2,6 +2,8 @@
package firrtlTests
+import firrtl.{ChirrtlForm, CircuitState, Compiler, Annotations}
+
import scala.concurrent.{Future, Await, ExecutionContext}
import scala.concurrent.duration.Duration
@@ -13,7 +15,7 @@ class MultiThreadingSpec extends FirrtlPropSpec {
def runCompiler(input: Seq[String], compiler: firrtl.Compiler): String = {
val writer = new java.io.StringWriter
val parsedInput = firrtl.Parser.parse(input)
- compiler.compile(parsedInput,new firrtl.Annotations.AnnotationMap(Seq.empty), writer)
+ compiler.compile(CircuitState(parsedInput, ChirrtlForm), writer)
writer.toString
}
// The parameters we're testing with
diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala
index e5269396..e574d31f 100644
--- a/src/test/scala/firrtlTests/PassTests.scala
+++ b/src/test/scala/firrtlTests/PassTests.scala
@@ -4,21 +4,27 @@ import com.typesafe.scalalogging.LazyLogging
import java.io.{StringWriter,Writer}
import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.junit.JUnitRunner
-import firrtl.{Parser,FIRRTLEmitter}
import firrtl.ir.Circuit
import firrtl.Parser.IgnoreInfo
-import firrtl.passes.{Pass, PassExceptions}
+import firrtl.passes.{Pass, PassExceptions, RemoveEmpty}
import firrtl.{
Transform,
- TransformResult,
+ PassBasedTransform,
+ CircuitState,
+ CircuitForm,
+ ChirrtlForm,
+ HighForm,
+ MidForm,
+ LowForm,
SimpleRun,
- Chisel3ToHighFirrtl,
+ ChirrtlToHighFirrtl,
IRToWorkingIR,
ResolveAndCheck,
HighFirrtlToMiddleFirrtl,
MiddleFirrtlToLowFirrtl,
- EmitFirrtl,
- Compiler
+ FirrtlEmitter,
+ Compiler,
+ Parser
}
import firrtl.Annotations.AnnotationMap
@@ -26,58 +32,66 @@ import firrtl.Annotations.AnnotationMap
// An example methodology for testing Firrtl Passes
// Spec class should extend this class
abstract class SimpleTransformSpec extends FlatSpec with Matchers with Compiler with LazyLogging {
+ def emitter = new FirrtlEmitter
+
// Utility function
def parse(s: String): Circuit = Parser.parse(s.split("\n").toIterator, infoMode = IgnoreInfo)
// Executes the test. Call in tests.
def execute(writer: Writer, annotations: AnnotationMap, input: String, check: String) = {
- compile(parse(input), annotations, writer)
- logger.debug(writer.toString)
- logger.debug(check)
- (parse(writer.toString)) should be (parse(check))
+ compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer)
+ val actual = RemoveEmpty.run(parse(writer.toString)).serialize
+ val expected = parse(check).serialize
+ logger.debug(actual)
+ logger.debug(expected)
+ (actual) should be (expected)
}
// Executes the test, should throw an error
def failingexecute(writer: Writer, annotations: AnnotationMap, input: String): Exception = {
intercept[PassExceptions] {
- compile(parse(input), annotations, writer)
+ compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer)
}
}
}
+class CustomResolveAndCheck(form: CircuitForm) extends PassBasedTransform {
+ private val wrappedTransform = new ResolveAndCheck
+ def inputForm = form
+ def outputForm = form
+ def passSeq = wrappedTransform.passSeq
+}
+
trait LowTransformSpec extends SimpleTransformSpec {
def transform: Transform
- def transforms (writer: Writer) = Seq(
- new Chisel3ToHighFirrtl(),
+ def transforms = Seq(
+ new ChirrtlToHighFirrtl(),
new IRToWorkingIR(),
new ResolveAndCheck(),
new HighFirrtlToMiddleFirrtl(),
new MiddleFirrtlToLowFirrtl(),
- new ResolveAndCheck(),
- transform,
- new EmitFirrtl(writer)
+ new CustomResolveAndCheck(LowForm),
+ transform
)
}
trait MiddleTransformSpec extends SimpleTransformSpec {
def transform: Transform
- def transforms (writer: Writer) = Seq(
- new Chisel3ToHighFirrtl(),
+ def transforms = Seq(
+ new ChirrtlToHighFirrtl(),
new IRToWorkingIR(),
new ResolveAndCheck(),
new HighFirrtlToMiddleFirrtl(),
- new ResolveAndCheck(),
- transform,
- new EmitFirrtl(writer)
+ new CustomResolveAndCheck(MidForm),
+ transform
)
}
trait HighTransformSpec extends SimpleTransformSpec {
def transform: Transform
- def transforms (writer: Writer) = Seq(
- new Chisel3ToHighFirrtl(),
+ def transforms = Seq(
+ new ChirrtlToHighFirrtl(),
new IRToWorkingIR(),
new ResolveAndCheck(),
- transform,
- new EmitFirrtl(writer)
+ transform
)
}
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 78b3d9f0..e46230ef 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -6,19 +6,19 @@ import firrtl.passes.memlib._
import Annotations._
class ReplSeqMemSpec extends SimpleTransformSpec {
- val passSeq = Seq(
- ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty)
- def transforms (writer: java.io.Writer) = Seq(
- new Chisel3ToHighFirrtl(),
+ def transforms = Seq(
+ new ChirrtlToHighFirrtl(),
new IRToWorkingIR(),
new ResolveAndCheck(),
new HighFirrtlToMiddleFirrtl(),
- new InferReadWrite(TransID(-1)),
- new ReplSeqMem(TransID(-2)),
+ new InferReadWrite(),
+ new ReplSeqMem(),
new MiddleFirrtlToLowFirrtl(),
- (new Transform with SimpleRun {
- def execute(c: ir.Circuit, a: AnnotationMap) = run(c, passSeq) } ),
- new EmitFirrtl(writer)
+ new PassBasedTransform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+ def passSeq = Seq(ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty)
+ }
)
"ReplSeqMem" should "generate blackbox wrappers for mems of bundle type" in {
@@ -58,9 +58,9 @@ circuit Top :
io2.commit_entry.bits.info <- R1
""".stripMargin
val confLoc = "ReplSeqMemTests.confTEMP"
- val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc, TransID(-2))))
+ val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc)))
val writer = new java.io.StringWriter
- compile(parse(input), aMap, writer)
+ compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer)
// Check correctness of firrtl
parse(writer.toString)
(new java.io.File(confLoc)).delete()
@@ -81,9 +81,9 @@ circuit Top :
write mport T_155 = mem[p_address], clk
""".stripMargin
val confLoc = "ReplSeqMemTests.confTEMP"
- val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc, TransID(-2))))
+ val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc)))
val writer = new java.io.StringWriter
- compile(parse(input), aMap, writer)
+ compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer)
// Check correctness of firrtl
parse(writer.toString)
(new java.io.File(confLoc)).delete()
diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala
index 245c32e8..1025c02b 100644
--- a/src/test/scala/firrtlTests/UnitTests.scala
+++ b/src/test/scala/firrtlTests/UnitTests.scala
@@ -36,7 +36,6 @@ import firrtl.passes._
import firrtl.Parser.IgnoreInfo
class UnitTests extends FirrtlFlatSpec {
- def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
(c: Circuit, p: Pass) => p.run(c)
@@ -114,7 +113,7 @@ class UnitTests extends FirrtlFlatSpec {
(c: Circuit, p: Pass) => p.run(c)
}
val writer = new StringWriter()
- FIRRTLEmitter.run(c_result,writer)
+ (new FirrtlEmitter).emit(CircuitState(c_result, HighForm), writer)
(parse(writer.toString())) should be (parse(check))
}
@@ -136,7 +135,7 @@ class UnitTests extends FirrtlFlatSpec {
intercept[PassException] {
val c = Parser.parse(splitExpTestCode.split("\n").toIterator)
val c2 = passes.foldLeft(c)((c, p) => p run c)
- new VerilogEmitter().run(c2, new OutputStreamWriter(new ByteArrayOutputStream))
+ (new VerilogEmitter).emit(CircuitState(c2, LowForm), new StringWriter)
}
}
@@ -147,7 +146,7 @@ class UnitTests extends FirrtlFlatSpec {
InferTypes)
val c = Parser.parse(splitExpTestCode.split("\n").toIterator)
val c2 = passes.foldLeft(c)((c, p) => p run c)
- new VerilogEmitter().run(c2, new OutputStreamWriter(new ByteArrayOutputStream))
+ (new VerilogEmitter).emit(CircuitState(c2, LowForm), new StringWriter)
}
"Simple compound expressions" should "be split" in {
diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala
index 1f6142bc..e9bf5429 100644
--- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala
+++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala
@@ -37,10 +37,9 @@ import firrtl.passes._
import firrtl.Parser.IgnoreInfo
class DoPrimVerilog extends FirrtlFlatSpec {
- def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
private def executeTest(input: String, expected: Seq[String], compiler: Compiler) = {
val writer = new StringWriter()
- compiler.compile(parse(input), new AnnotationMap(Seq.empty), writer)
+ compiler.compile(CircuitState(parse(input), ChirrtlForm), writer)
val lines = writer.toString().split("\n") map normalized
expected foreach { e =>
lines should contain(e)
diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala
index d1b16bb9..74f6432f 100644
--- a/src/test/scala/firrtlTests/WidthSpec.scala
+++ b/src/test/scala/firrtlTests/WidthSpec.scala
@@ -36,7 +36,6 @@ import firrtl.passes._
import firrtl.Parser.IgnoreInfo
class WidthSpec extends FirrtlFlatSpec {
- def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
(c: Circuit, p: Pass) => p.run(c)
diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala
index 5f40d861..309014d4 100644
--- a/src/test/scala/firrtlTests/WiringTests.scala
+++ b/src/test/scala/firrtlTests/WiringTests.scala
@@ -12,7 +12,6 @@ import wiring.WiringUtils._
import wiring._
class WiringTests extends FirrtlFlatSpec {
- def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
(c: Circuit, p: Pass) => p.run(c)
diff --git a/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala
index a9a1bb47..4a87290d 100644
--- a/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala
+++ b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala
@@ -5,12 +5,11 @@ package firrtlTests.fixed
import java.io.StringWriter
import firrtl.Annotations.AnnotationMap
-import firrtl.{LowFirrtlCompiler, Parser}
+import firrtl.{CircuitState, ChirrtlForm, LowFirrtlCompiler, Parser}
import firrtl.Parser.IgnoreInfo
import firrtlTests.FirrtlFlatSpec
class FixedPointMathSpec extends FirrtlFlatSpec {
- def parse(input: String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
val SumPattern = """.*output sum.*<(\d+)>.*.*""".r
val ProductPattern = """.*output product.*<(\d+)>.*""".r
@@ -45,7 +44,7 @@ class FixedPointMathSpec extends FirrtlFlatSpec {
val writer = new StringWriter()
- lowerer.compile(parse(input), new AnnotationMap(Seq.empty), writer)
+ lowerer.compile(CircuitState(parse(input), ChirrtlForm), writer)
val output = writer.toString.split("\n")
diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
index 53b4f4c0..3f465361 100644
--- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
+++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
@@ -34,7 +34,6 @@ import firrtl.passes._
import firrtl.Parser.IgnoreInfo
class FixedTypeInferenceSpec extends FirrtlFlatSpec {
- def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
(c: Circuit, p: Pass) => p.run(c)
diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
index 6799a367..27d7e172 100644
--- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
+++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
@@ -35,7 +35,6 @@ import firrtl.passes._
import firrtl.Parser.IgnoreInfo
class RemoveFixedTypeSpec extends FirrtlFlatSpec {
- def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
(c: Circuit, p: Pass) => p.run(c)
@@ -204,14 +203,14 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
| io_out <= io_in
""".stripMargin
- class CheckChirrtlTransform extends Transform with SimpleRun {
+ class CheckChirrtlTransform extends PassBasedTransform {
+ def inputForm = ChirrtlForm
+ def outputForm = ChirrtlForm
val passSeq = Seq(passes.CheckChirrtl)
- def execute (circuit: Circuit, annotationMap: AnnotationMap): TransformResult =
- run(circuit, passSeq)
}
val chirrtlTransform = new CheckChirrtlTransform
- chirrtlTransform.execute(parse(input), new AnnotationMap(Seq.empty))
+ chirrtlTransform.execute(CircuitState(parse(input), ChirrtlForm, Some(new AnnotationMap(Seq.empty))))
}
}