diff options
| author | jackkoenig | 2016-10-20 00:19:01 -0700 |
|---|---|---|
| committer | Jack Koenig | 2016-11-04 13:29:09 -0700 |
| commit | 8fa9429a6e916ab2a789f5d81fa803b022805b52 (patch) | |
| tree | fac2efcbd0a68bfb1916f09afc7f003c7a3d6528 /src | |
| parent | 62133264a788f46b319ebab9c31424b7e0536101 (diff) | |
Refactor Compilers and Transforms
* Transform Ids now handled by Class[_ <: Transform] instead of magic numbers
* Transforms define inputForm and outputForm
* Custom transforms can be inserted at runtime into compiler or the Driver
* Current "built-in" custom transforms handled via above mechanism
* Verilog-specific passes moved to the Verilog emitter
Diffstat (limited to 'src')
37 files changed, 778 insertions, 423 deletions
diff --git a/src/main/scala/firrtl/Annotations.scala b/src/main/scala/firrtl/Annotations.scala index d47ce67e..d70732e6 100644 --- a/src/main/scala/firrtl/Annotations.scala +++ b/src/main/scala/firrtl/Annotations.scala @@ -115,12 +115,6 @@ object Annotations { } /** - * Transform ID (TransID) associates an annotation with an instantiated - * Firrtl compiler transform - */ - case class TransID(id: Int) - - /** * Permissibility defines the range of acceptable changes to the annotated component. */ trait Permissibility { @@ -215,7 +209,7 @@ object Annotations { /** * Annotation associates with a given named circuit component (target) and a - * given transformation (tID). Also defined are the legal ranges of changes + * given transformation (transform). Also defined are the legal ranges of changes * to the associated component (Permissibility) and how the annotation * propagates under such changes (Tenacity). Subclasses must implement the * duplicate function to create the same annotation associated with a new @@ -223,7 +217,7 @@ object Annotations { */ trait Annotation extends Permissibility with Tenacity { def target: Named - def tID: TransID + def transform: Class[_ <: Transform] protected def duplicate(n: Named): Annotation def serialize: String = this.toString def update(tos: Seq[Named]): Seq[Annotation] = { @@ -236,23 +230,23 @@ object Annotations { * Container of all annotations for a Firrtl compiler. */ case class AnnotationMap(annotations: Seq[Annotation]) { - type NamedMap = Map[Named, Map[TransID, Annotation]] - type IDMap = Map[TransID, Map[Named, Annotation]] + type NamedMap = Map[Named, Map[Class[_], Annotation]] + type IDMap = Map[Class[_], Map[Named, Annotation]] val (namedMap: NamedMap, idMap:IDMap) = //annotations.foldLeft(Tuple2[NamedMap, IDMap](Map.empty, Map.empty)){ annotations.foldLeft((Map.empty: NamedMap, Map.empty: IDMap)){ (partialMaps: (NamedMap, IDMap), annotation: Annotation) => { - val tIDToAnn = partialMaps._1.getOrElse(annotation.target, Map.empty) - val pNMap = partialMaps._1 + (annotation.target -> (tIDToAnn + (annotation.tID -> annotation))) + val transformToAnn = partialMaps._1.getOrElse(annotation.target, Map.empty) + val pNMap = partialMaps._1 + (annotation.target -> (transformToAnn + (annotation.transform -> annotation))) - val nToAnn = partialMaps._2.getOrElse(annotation.tID, Map.empty) - val ptIDMap = partialMaps._2 + (annotation.tID -> (nToAnn + (annotation.target -> annotation))) - Tuple2(pNMap, ptIDMap) + val nToAnn = partialMaps._2.getOrElse(annotation.transform, Map.empty) + val ptransformMap = partialMaps._2 + (annotation.transform -> (nToAnn + (annotation.target -> annotation))) + Tuple2(pNMap, ptransformMap) } } - def get(id: TransID): Option[Map[Named, Annotation]] = idMap.get(id) - def get(named: Named): Option[Map[TransID, Annotation]] = namedMap.get(named) + def get(id: Class[_]): Option[Map[Named, Annotation]] = idMap.get(id) + def get(named: Named): Option[Map[Class[_], Annotation]] = namedMap.get(named) } } diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index f566544e..9781972e 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -27,48 +27,249 @@ MODIFICATIONS. package firrtl -import com.typesafe.scalalogging.LazyLogging -import scala.collection.mutable +import logger.LazyLogging import java.io.Writer import Annotations._ import firrtl.ir.Circuit +import passes.Pass + /** * RenameMap maps old names to modified names. Generated by transformations * that modify names */ case class RenameMap(map: Map[Named, Seq[Named]]) -// =========================================== -// Transforms -// ------------------------------------------- - -case class TransformResult( +/** Current State of the Circuit + * + * @constructor Creates a CircuitState object + * @param circuit The current state of the Firrtl AST + * @param form The current form of the circuit + * @param annotations The current collection of [[Annotations.Annotation]] + * @param renames A map of [[Annotations.Named]] things that have been renamed. + * Generally only a return value from [[Transform]]s + */ +case class CircuitState( circuit: Circuit, - renames: Option[RenameMap] = None, - annotation: Option[AnnotationMap] = None) + form: CircuitForm, + annotations: Option[AnnotationMap] = None, + renames: Option[RenameMap] = None) -// - Transforms a circuit -// - Can consume multiple CircuitAnnotation's -trait Transform { - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult +/** Current form of the Firrtl Circuit + * + * Form is a measure of addition restrictions on the legality of a Firrtl + * circuit. There is a notion of "highness" and "lowness" implemented in the + * compiler by extending scala.math.Ordered. "Lower" forms add additional + * restrictions compared to "higher" forms. This means that "higher" forms are + * strictly supersets of the "lower" forms. Thus, that any transform that + * operates on [[HighForm]] can also operate on [[MidForm]] or [[LowForm]] + */ +sealed abstract class CircuitForm(private val value: Int) extends Ordered[CircuitForm] { + // Note that value is used only to allow comparisons + def compare(that: CircuitForm): Int = this.value - that.value } +/** Chirrtl Form + * + * The form of the circuit emitted by Chisel. Not a true Firrtl form. + * Includes cmem, smem, and mport IR nodes which enable declaring memories + * separately form their ports. A "Higher" form than [[HighForm]] + * + * See [[CDefMemory]] and [[CDefMPort]] + */ +final case object ChirrtlForm extends CircuitForm(3) +/** High Form + * + * As detailed in the Firrtl specification + * [[https://github.com/ucb-bar/firrtl/blob/master/spec/spec.pdf]] + * + * Also see [[firrtl.ir]] + */ +final case object HighForm extends CircuitForm(2) +/** Middle Form + * + * A "lower" form than [[HighForm]] with the following restrictions: + * - All widths must be explicit + * - All whens must be removed + * - There can only be a single connection to any element + */ +final case object MidForm extends CircuitForm(1) +/** Low Form + * + * The "lowest" form. In addition to the restrictions in [[MidForm]]: + * - All aggregate types (vector/bundle) must have been removed + * - All implicit truncations must be made explicit + */ +final case object LowForm extends CircuitForm(0) +/** The basic unit of operating on a Firrtl AST */ +abstract class Transform { + /** A convenience function useful for debugging and error messages */ + def name: String = this.getClass.getSimpleName + /** The [[CircuitForm]] that this transform requires to operate on */ + def inputForm: CircuitForm + /** The [[CircuitForm]] that this transform outputs */ + def outputForm: CircuitForm + /** Perform the transform + * + * @param state Input Firrtl AST + * @return A transformed Firrtl AST + */ + def execute(state: CircuitState): CircuitState + /** Convenience method to get annotations relevant to this Transform + * + * @param state The [[CircuitState]] form which to extract annotations + * @return A collection of annotations + */ + final def getMyAnnotations(state: CircuitState): Option[Map[Named, Annotation]] = + for { + annotations <- state.annotations + myAnnotations <- annotations.get(this.getClass) + } yield myAnnotations +} -// =========================================== -// Compilers -// ------------------------------------------- +trait SimpleRun extends LazyLogging { + def runPasses(circuit: Circuit, passSeq: Seq[Pass]): Circuit = + passSeq.foldLeft(circuit) { (c: Circuit, pass: Pass) => + val x = Utils.time(pass.name) { pass.run(c) } + logger.debug(x.serialize) + x + } +} + +/** For PassBased Transforms and Emitters + * + * @note passSeq accepts no arguments + * @todo make passes accept CircuitState so annotations can pass data between them + */ +trait PassBased extends SimpleRun { + def passSeq: Seq[Pass] + def runPasses(circuit: Circuit): Circuit = runPasses(circuit, passSeq) +} -case class CompilerResult(circuit: Circuit, annotationMap: AnnotationMap) +/** For transformations that are simply a sequence of passes */ +abstract class PassBasedTransform extends Transform with PassBased { + def execute(state: CircuitState): CircuitState = { + require(state.form <= inputForm, + s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") + CircuitState(runPasses(state.circuit), outputForm) + } +} + +/** Similar to a Transform except that it writes to a Writer instead of returning a + * CircuitState + */ +abstract class Emitter { + def emit(state: CircuitState, writer: Writer): Unit +} + +object CompilerUtils { + /** Generates a sequence of [[Transform]]s to lower a Firrtl circuit + * + * @param inputForm [[CircuitForm]] to lower from + * @param outputForm [[CircuitForm to lower to + * @return Sequence of transforms that will lower if outputForm is lower than inputForm + */ + def getLoweringTransforms(inputForm: CircuitForm, outputForm: CircuitForm): Seq[Transform] = { + // If outputForm is equal-to or higher than inputForm, nothing to lower + if (outputForm >= inputForm) { + Seq.empty + } else { + inputForm match { + case ChirrtlForm => Seq(new ChirrtlToHighFirrtl) ++ getLoweringTransforms(HighForm, outputForm) + case HighForm => Seq(new IRToWorkingIR, new ResolveAndCheck, new HighFirrtlToMiddleFirrtl) ++ + getLoweringTransforms(MidForm, outputForm) + case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm) + case LowForm => error("Internal Error! This shouldn't be possible") // should be caught by if above + } + } + } + + /** Merge a Seq of lowering transforms with custom transforms + * + * Custom Transforms are inserted based on their [[Transform.inputForm]] and + * [[Transform.outputForm]]. Custom transforms are inserted in order at the + * last location in the Seq of transforms where previous.outputForm == + * customTransform.inputForm. If a customTransform outputs a higher form + * than input, [[getLoweringTransforms]] is used to relower the circuit. + * + * @example + * {{{ + * // Let Transforms be represented by CircuitForm => CircuitForm + * val A = HighForm => MidForm + * val B = MidForm => LowForm + * val lowering = List(A, B) // Assume these transforms are used by getLoweringTransforms + * // Some custom transforms + * val C = LowForm => LowForm + * val D = MidForm => MidForm + * val E = LowForm => HighForm + * // All of the following comparisons are true + * mergeTransforms(lowering, List(C)) == List(A, B, C) + * mergeTransforms(lowering, List(D)) == List(A, D, B) + * mergeTransforms(lowering, List(E)) == List(A, B, E, A, B) + * mergeTransforms(lowering, List(C, E)) == List(A, B, C, E, A, B) + * mergeTransforms(lowering, List(E, C)) == List(A, B, E, A, B, C) + * // Notice that in the following, custom transform order is NOT preserved (see note) + * mergeTransforms(lowering, List(C, D)) == List(A, D, B, C) + * }}} + * + * @note Order will be preserved for custom transforms so long as the + * inputForm of a latter transforms is equal to or lower than the outputForm + * of the previous transform. + */ + def mergeTransforms(lowering: Seq[Transform], custom: Seq[Transform]): Seq[Transform] = { + custom.foldLeft(lowering) { case (transforms, xform) => + val index = transforms lastIndexWhere (_.outputForm == xform.inputForm) + assert(index >= 0 || xform.inputForm == ChirrtlForm, // If ChirrtlForm just put at front + s"No transform in $lowering has outputForm ${xform.inputForm} as required by $xform") + val (front, back) = transforms.splitAt(index + 1) // +1 because we want to be AFTER index + front ++ List(xform) ++ getLoweringTransforms(xform.outputForm, xform.inputForm) ++ back + } + } + +} -// - A sequence of transformations -// - Call compile to executes each transformation in sequence onto -// a given circuit. trait Compiler { - def transforms(w: Writer): Seq[Transform] - def compile(circuit: Circuit, annotationMap: AnnotationMap, writer: Writer): CompilerResult = - (transforms(writer) foldLeft CompilerResult(circuit, annotationMap)){ (in, xform) => - val result = xform.execute(in.circuit, in.annotationMap) + def emitter: Emitter + /** The sequence of transforms this compiler will execute + * @note The inputForm of a given transform must be higher than or equal to the ouputForm of the + * preceding transform. See [[CircuitForm]] + */ + def transforms: Seq[Transform] + + // Similar to (input|output)Form on [[Transform]] but derived from this Compiler's transforms + def inputForm = transforms.head.inputForm + def outputForm = transforms.last.outputForm + + private def transformsLegal(xforms: Seq[Transform]): Boolean = + if (xforms.size < 2) { + true + } else { + xforms.sliding(2, 1) + .map { case Seq(p, n) => n.inputForm >= p.outputForm } + .reduce(_ && _) + } + + assert(transformsLegal(transforms), + "Illegal Compiler, each transform must be able to accept the output of the previous transform!") + + /** Perform compilation + * + * @param state The Firrtl AST to compile + * @param writer The java.io.Writer where the output of compilation will be emitted + * @param customTransforms Any custom [[Transform]]s that will be inserted + * into the compilation process by [[CompilerUtils.mergeTransforms]] + */ + def compile(state: CircuitState, + writer: Writer, + customTransforms: Seq[Transform] = Seq.empty): CircuitState = { + val allTransforms = CompilerUtils.mergeTransforms(transforms, customTransforms) + + val finalState = allTransforms.foldLeft(state) { (in, xform) => + val result = Utils.time(s"***${xform.name}***") { xform.execute(in) } + + // Annotation propagation + // TODO: This should be redone + val inAnnotationMap = in.annotations getOrElse AnnotationMap(Seq.empty) val remappedAnnotations: Seq[Annotation] = result.renames match { case Some(RenameMap(rmap)) => // For each key in the rename map (rmap), obtain the @@ -77,18 +278,22 @@ trait Compiler { // annotations with the names in rmap's value. for { (oldName, newNames) <- rmap.toSeq - tID2OldAnnos <- in.annotationMap.get(oldName).toSeq - oldAnno <- tID2OldAnnos.values + transform2OldAnnos <- inAnnotationMap.get(oldName).toSeq + oldAnno <- transform2OldAnnos.values newAnno <- oldAnno.update(newNames) } yield newAnno - case _ => in.annotationMap.annotations + case _ => inAnnotationMap.annotations } - val resultAnnotations: Seq[Annotation] = result.annotation match { + val resultAnnotations: Seq[Annotation] = result.annotations match { case None => Nil case Some(p) => p.annotations } - CompilerResult(result.circuit, - new AnnotationMap(remappedAnnotations ++ resultAnnotations)) + val newAnnotations = AnnotationMap(remappedAnnotations ++ resultAnnotations) + CircuitState(result.circuit, result.form, Some(newAnnotations)) } + + emitter.emit(finalState, writer) + finalState + } } diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index ba5527b4..293ac4fd 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -30,7 +30,8 @@ import scala.collection._ * firrtl.Driver.execute(Array("--top-name Dummy --compiler verilog".split(" +")) * }}} * each approach has its own endearing aspects - * @see firrtlTests.DriverSpec.scala in the test directory for a lot more examples + * @see firrtlTests/DriverSpec.scala in the test directory for a lot more examples + * @see [[CompilerUtils.mergeTransforms]] to see how customTransformations are inserted */ object Driver { @@ -42,11 +43,15 @@ object Driver { output: String, compiler: Compiler, infoMode: InfoMode = IgnoreInfo, + customTransforms: Seq[Transform] = Seq.empty, annotations: AnnotationMap = new AnnotationMap(Seq.empty) ): String = { val parsedInput = Parser.parse(Source.fromFile(input).getLines(), infoMode) val outputBuffer = new java.io.CharArrayWriter - compiler.compile(parsedInput, annotations, outputBuffer) + compiler.compile( + CircuitState(parsedInput, ChirrtlForm, Some(annotations)), + outputBuffer, + customTransforms) val outputFile = new java.io.PrintWriter(output) val outputString = outputBuffer.toString @@ -108,7 +113,11 @@ object Driver { val parsedInput = Parser.parse(firrtlSource, firrtlConfig.infoMode) val outputBuffer = new java.io.CharArrayWriter - firrtlConfig.compiler.compile(parsedInput, new AnnotationMap(firrtlConfig.annotations), outputBuffer) + firrtlConfig.compiler.compile( + CircuitState(parsedInput, ChirrtlForm, Some(new AnnotationMap(firrtlConfig.annotations))), + outputBuffer, + firrtlConfig.customTransforms + ) val outputFileName = firrtlConfig.getOutputFileName(optionsManager) val outputFile = new java.io.PrintWriter(outputFileName) @@ -193,4 +202,4 @@ object FileUtils { } } } -}
\ No newline at end of file +} diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 7b198149..1d64dc91 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -47,12 +47,8 @@ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap, HashSet} case class EmitterException(message: String) extends PassException(message) -trait Emitter extends LazyLogging { - def run(c: Circuit, w: Writer) -} - -object FIRRTLEmitter extends Emitter { - def run(c: Circuit, w: Writer) = w.write(c.serialize) +class FirrtlEmitter extends Emitter { + def emit(state: CircuitState, writer: Writer): Unit = writer.write(state.circuit.serialize) } case class VRandom(width: BigInt) extends Expression { @@ -65,7 +61,7 @@ case class VRandom(width: BigInt) extends Expression { def mapWidth(f: Width => Width): Expression = this } -class VerilogEmitter extends Emitter { +class VerilogEmitter extends Emitter with PassBased { val tab = " " def AND(e1: WrappedExpression, e2: WrappedExpression): Expression = { if (e1 == e2) e1.e1 @@ -590,12 +586,18 @@ class VerilogEmitter extends Emitter { "`endif\n")) } - def run(c: Circuit, w: Writer) = { - emit_preamble(w) - val moduleMap = (c.modules map (m => m.name -> m)).toMap - c.modules foreach { - case (m: Module) => emit_verilog(m, moduleMap)(w) - case (m: ExtModule) => - } - } + def passSeq = Seq( + passes.VerilogWrap, + passes.VerilogRename, + passes.VerilogPrep) + + def emit(state: CircuitState, writer: Writer): Unit = { + val circuit = runPasses(state.circuit) + emit_preamble(writer) + val moduleMap = (circuit.modules map (m => m.name -> m)).toMap + circuit.modules foreach { + case (m: Module) => emit_verilog(m, moduleMap)(writer) + case (m: ExtModule) => + } + } } diff --git a/src/main/scala/firrtl/ExecutionOptionsManager.scala b/src/main/scala/firrtl/ExecutionOptionsManager.scala index e4954610..21a9cc50 100644 --- a/src/main/scala/firrtl/ExecutionOptionsManager.scala +++ b/src/main/scala/firrtl/ExecutionOptionsManager.scala @@ -140,6 +140,7 @@ case class FirrtlExecutionOptions( infoModeName: String = "append", inferRW: Seq[String] = Seq.empty, firrtlSource: Option[String] = None, + customTransforms: Seq[Transform] = List.empty, annotations: List[Annotation] = List.empty) extends ComposableOptions { @@ -249,14 +250,17 @@ trait HasFirrtlOptions { val newAnnotations = x.map { value => value.split('.') match { case Array(circuit) => - passes.InlineAnnotation(CircuitName(circuit), TransID(0)) + passes.InlineAnnotation(CircuitName(circuit)) case Array(circuit, module) => - passes.InlineAnnotation(ModuleName(module, CircuitName(circuit)), TransID(0)) + passes.InlineAnnotation(ModuleName(module, CircuitName(circuit))) case Array(circuit, module, inst) => - passes.InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit))), TransID(0)) + passes.InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit)))) } } - firrtlOptions = firrtlOptions.copy(annotations = firrtlOptions.annotations ++ newAnnotations) + firrtlOptions = firrtlOptions.copy( + annotations = firrtlOptions.annotations ++ newAnnotations, + customTransforms = firrtlOptions.customTransforms :+ new passes.InlineInstances + ) } .text { """Inline one or more module (comma separated, no spaces) module looks like "MyModule" or "MyModule.myinstance""" @@ -267,7 +271,8 @@ trait HasFirrtlOptions { .valueName ("<circuit>") .foreach { x => firrtlOptions = firrtlOptions.copy( - annotations = firrtlOptions.annotations :+ InferReadWriteAnnotation(x, TransID(-1)) + annotations = firrtlOptions.annotations :+ InferReadWriteAnnotation(x), + customTransforms = firrtlOptions.customTransforms :+ new passes.memlib.InferReadWrite ) }.text { "Enable readwrite port inference for the target circuit" @@ -278,7 +283,8 @@ trait HasFirrtlOptions { .valueName ("-c:<circuit>:-i:<filename>:-o:<filename>") .foreach { x => firrtlOptions = firrtlOptions.copy( - annotations = firrtlOptions.annotations :+ ReplSeqMemAnnotation(x, TransID(-2)) + annotations = firrtlOptions.annotations :+ ReplSeqMemAnnotation(x), + customTransforms = firrtlOptions.customTransforms :+ new passes.memlib.ReplSeqMem ) } .text { diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 446df6d0..986ebd9f 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -27,59 +27,38 @@ MODIFICATIONS. package firrtl -import java.io.Writer -import firrtl.passes.Pass -import firrtl.ir.Circuit -import Annotations._ -import logger.LazyLogging +sealed abstract class CoreTransform extends PassBasedTransform -// =========================================== -// Utility Traits -// ------------------------------------------- -// Valid if all passes in transformation: -// 1) Don't produce annotations -// 2) Don't consume annotations -// 3) No component or module names are renamed -trait SimpleRun extends LazyLogging { - def run (circuit: Circuit, passes: Seq[Pass]): TransformResult = { - val result = (passes foldLeft circuit){ (c: Circuit, pass: Pass) => - val name = pass.name - val x = Utils.time(name)(pass.run(c)) - logger.debug(x.serialize) - x - } - TransformResult(result) - } -} - -// =========================================== -// Lowering Transforms -// ------------------------------------------- -// This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting -// circuit has only IR nodes, not WIR. -// TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL -class Chisel3ToHighFirrtl extends Transform with SimpleRun { - val passSeq = Seq( +/** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting + * circuit has only IR nodes, not WIR. + * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL + */ +class ChirrtlToHighFirrtl extends CoreTransform { + def inputForm = ChirrtlForm + def outputForm = HighForm + def passSeq = Seq( passes.CheckChirrtl, passes.CInferTypes, passes.CInferMDir, passes.RemoveCHIRRTL) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) } -// Converts from the bare intermediate representation (ir.scala) -// to a working representation (WIR.scala) -class IRToWorkingIR extends Transform with SimpleRun { - val passSeq = Seq(passes.ToWorkingIR) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) +/** Converts from the bare intermediate representation (ir.scala) + * to a working representation (WIR.scala) + */ +class IRToWorkingIR extends CoreTransform { + def inputForm = HighForm + def outputForm = HighForm + def passSeq = Seq(passes.ToWorkingIR) } -// Resolves types, kinds, and genders, and checks the circuit legality. -// Operates on working IR nodes and high Firrtl. -class ResolveAndCheck extends Transform with SimpleRun { - val passSeq = Seq( +/** Resolves types, kinds, and genders, and checks the circuit legality. + * Operates on working IR nodes and high Firrtl. + */ +class ResolveAndCheck extends CoreTransform { + def inputForm = HighForm + def outputForm = HighForm + def passSeq = Seq( passes.CheckHighForm, passes.ResolveKinds, passes.InferTypes, @@ -91,16 +70,17 @@ class ResolveAndCheck extends Transform with SimpleRun { passes.CheckGenders, passes.InferWidths, passes.CheckWidths) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) } -// Expands aggregate connects, removes dynamic accesses, and when -// statements. Checks for uninitialized values. Must accept a -// well-formed graph. -// Operates on working IR nodes. -class HighFirrtlToMiddleFirrtl extends Transform with SimpleRun { - val passSeq = Seq( +/** Expands aggregate connects, removes dynamic accesses, and when + * statements. Checks for uninitialized values. Must accept a + * well-formed graph. + * Operates on working IR nodes. + */ +class HighFirrtlToMiddleFirrtl extends CoreTransform { + def inputForm = HighForm + def outputForm = MidForm + def passSeq = Seq( passes.PullMuxes, passes.ReplaceAccesses, passes.ExpandConnects, @@ -112,16 +92,17 @@ class HighFirrtlToMiddleFirrtl extends Transform with SimpleRun { passes.ResolveGenders, passes.InferWidths, passes.CheckWidths) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) } -// Expands all aggregate types into many ground-typed components. Must -// accept a well-formed graph of only middle Firrtl features. -// Operates on working IR nodes. -// TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL -class MiddleFirrtlToLowFirrtl extends Transform with SimpleRun { - val passSeq = Seq( +/** Expands all aggregate types into many ground-typed components. Must + * accept a well-formed graph of only middle Firrtl features. + * Operates on working IR nodes. + * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL + */ +class MiddleFirrtlToLowFirrtl extends CoreTransform { + def inputForm = MidForm + def outputForm = LowForm + def passSeq = Seq( passes.LowerTypes, passes.ResolveKinds, passes.InferTypes, @@ -129,87 +110,48 @@ class MiddleFirrtlToLowFirrtl extends Transform with SimpleRun { passes.InferWidths, passes.ConvertFixedToSInt, passes.Legalize) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) } -// Emits Verilog. -// First optimizes for verilog width semantics with custom Primops, -// then splits complex expressions into temporary nodes. Finally, -// renames names that conflict with Verilog keywords. -// Operates on working IR nodes. -// TODO(izraelevitz): Create RenameMap from VerilogRename -class EmitVerilogFromLowFirrtl(val writer: Writer) extends Transform with SimpleRun { - val passSeq = Seq( +/** Runs a series of optimization passes on LowFirrtl + * @note This is currently required for correct Verilog emission + * TODO Fix the above note + */ +class LowFirrtlOptimization extends CoreTransform { + def inputForm = LowForm + def outputForm = LowForm + def passSeq = Seq( passes.RemoveValidIf, passes.ConstProp, passes.PadWidths, passes.ConstProp, passes.Legalize, - passes.VerilogWrap, - passes.memlib.VerilogMemDelays, + passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter passes.ConstProp, passes.SplitExpressions, passes.CommonSubexpressionElimination, - passes.DeadCodeElimination, - passes.VerilogRename, - passes.VerilogPrep) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = { - val result = run(circuit, passSeq) - (new VerilogEmitter).run(result.circuit, writer) - result - } + passes.DeadCodeElimination) } -// Emits Firrtl. -// Operates on WIR/IR nodes. -class EmitFirrtl(val writer: Writer) extends Transform { - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = { - FIRRTLEmitter.run(circuit, writer) - TransformResult(circuit) - } -} +import CompilerUtils.getLoweringTransforms -// =========================================== -// Lowering Compilers -// ------------------------------------------- -// Emits input circuit -// Will replace Chirrtl constructs with Firrtl +/** Emits input circuit + * Will replace Chirrtl constructs with Firrtl + */ class HighFirrtlCompiler extends Compiler { - def transforms(writer: Writer): Seq[Transform] = Seq( - new Chisel3ToHighFirrtl, - new IRToWorkingIR, - new EmitFirrtl(writer) - ) + def emitter = new FirrtlEmitter + def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, HighForm) } -// Emits lowered input circuit +/** Emits lowered input circuit */ class LowFirrtlCompiler extends Compiler { - def transforms(writer: Writer): Seq[Transform] = Seq( - new Chisel3ToHighFirrtl, - new IRToWorkingIR, - new passes.InlineInstances(TransID(0)), - new ResolveAndCheck, - new HighFirrtlToMiddleFirrtl, - new passes.memlib.InferReadWrite(TransID(-1)), - new passes.memlib.ReplSeqMem(TransID(-2)), - new MiddleFirrtlToLowFirrtl, - new EmitFirrtl(writer) - ) + def emitter = new FirrtlEmitter + def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, LowForm) } -// Emits Verilog +/** Emits Verilog */ class VerilogCompiler extends Compiler { - def transforms(writer: Writer): Seq[Transform] = Seq( - new Chisel3ToHighFirrtl, - new IRToWorkingIR, - new ResolveAndCheck, - new HighFirrtlToMiddleFirrtl, - new passes.memlib.InferReadWrite(TransID(-1)), - new passes.memlib.ReplSeqMem(TransID(-2)), - new MiddleFirrtlToLowFirrtl, - new passes.InlineInstances(TransID(0)), - new EmitVerilogFromLowFirrtl(writer) - ) + def emitter = new VerilogEmitter + def transforms: Seq[Transform] = + getLoweringTransforms(ChirrtlForm, LowForm) :+ (new LowFirrtlOptimization) } diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 22a3eac6..7c023ac8 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -44,7 +44,7 @@ import firrtl.WrappedType._ import scala.collection.mutable import scala.collection.mutable.{StringBuilder, ArrayBuffer, LinkedHashMap, HashMap, HashSet} import java.io.PrintWriter -import com.typesafe.scalalogging.LazyLogging +import logger.LazyLogging //import scala.reflect.runtime.universe._ class FIRRTLException(str: String) extends Exception(str) diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 5c80baff..c741dc06 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -9,34 +9,37 @@ import firrtl.Annotations._ import scala.collection.mutable // Tags an annotation to be consumed by this pass -case class InlineAnnotation(target: Named, tID: TransID) extends Annotation with Loose with Unstable { +case class InlineAnnotation(target: Named) extends Annotation with Loose with Unstable { def duplicate(n: Named) = this.copy(target=n) + def transform = classOf[InlineInstances] } // Only use on legal Firrtl. Specifically, the restriction of // instance loops must have been checked, or else this pass can // infinitely recurse -class InlineInstances (transID: TransID) extends Transform { +class InlineInstances extends Transform { + def inputForm = LowForm + def outputForm = LowForm val inlineDelim = "$" - def name = "Inline Instances" - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = { - annotationMap.get(transID) match { - case None => TransformResult(circuit, None, None) - case Some(map) => - val moduleNames = mutable.HashSet[ModuleName]() - val instanceNames = mutable.HashSet[ComponentName]() - map.values.foreach {x: Annotation => x match { - case InlineAnnotation(ModuleName(mod, cir), _) => moduleNames += ModuleName(mod, cir) - case InlineAnnotation(ComponentName(com, mod), _) => instanceNames += ComponentName(com, mod) - case _ => throw new PassException("Annotation must be InlineAnnotation") - }} - check(circuit, moduleNames.toSet, instanceNames.toSet) - run(circuit, moduleNames.toSet, instanceNames.toSet) + override def name = "Inline Instances" - // Default behavior is to error if more than one annotation for inlining - // This could potentially change - case _ => throw new PassException("Found more than one circuit annotation of InlineCAKind!") + private def collectAnns(anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = + anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) { + case ((modNames, instNames), ann) => ann match { + case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) + case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) + case _ => throw new PassException("Annotation must be InlineAnnotation") + } } + + def execute(state: CircuitState): CircuitState = { + // TODO Add error check for more than one annotation for inlining + // TODO Propagate other annotations + val result = for { + myAnnotations <- getMyAnnotations(state) + (modNames, instNames) = collectAnns(myAnnotations.values) + } yield run(state.circuit, modNames, instNames) + result getOrElse state // Return state if nothing to do } // Checks the following properties: @@ -78,7 +81,10 @@ class InlineInstances (transID: TransID) extends Transform { if (errors.nonEmpty) throw new PassExceptions(errors) } - def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): TransformResult = { + def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): CircuitState = { + // Check annotations and circuit match up + check(c, modsToInline, instsToInline) + // ---- Rename functions/data ---- val renameMap = mutable.HashMap[Named,Seq[Named]]() // Updates renameMap with new names @@ -168,6 +174,6 @@ class InlineInstances (transID: TransID) extends Transform { val top = c.modules.find(m => m.name == c.main).get onModule(top) val modulesx = c.modules.map(m => inlinedModules(m.name)) - TransformResult(Circuit(c.info, modulesx, c.main), Some(RenameMap(renameMap.toMap)), None) + CircuitState(Circuit(c.info, modulesx, c.main), LowForm, None, Some(RenameMap(renameMap.toMap))) } } diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala index 10cc8f88..c98dd4ca 100644 --- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala +++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala @@ -5,20 +5,22 @@ import ir._ import Annotations._ import wiring._ -class CreateMemoryAnnotations(reader: Option[YamlFileReader], replaceID: TransID, wiringID: TransID) extends Transform { - def name = "Create Memory Annotations" - def execute(c: Circuit, map: AnnotationMap): TransformResult = reader match { - case None => TransformResult(c) +class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform { + def inputForm = MidForm + def outputForm = MidForm + override def name = "Create Memory Annotations" + def execute(state: CircuitState): CircuitState = reader match { + case None => state case Some(r) => import CustomYAMLProtocol._ r.parse[Config] match { case Seq(config) => - val cN = CircuitName(c.main) - val top = TopAnnotation(ModuleName(config.top.name, cN), wiringID) - val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN)), wiringID) - val pin = PinAnnotation(cN, replaceID, config.pin.name) - TransformResult(c, None, Some(AnnotationMap(Seq(top, source, pin)))) - case Nil => TransformResult(c, None, None) + val cN = CircuitName(state.circuit.main) + val top = TopAnnotation(ModuleName(config.top.name, cN)) + val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN))) + val pin = PinAnnotation(cN, config.pin.name) + state.copy(annotations = Some(AnnotationMap(Seq(top, source, pin)))) + case Nil => state case _ => error("Can only have one config in yaml file") } } diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 28291135..2d6f4e96 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -38,10 +38,10 @@ import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin} import WrappedExpression.weq import Annotations._ -case class InferReadWriteAnnotation(t: String, tID: TransID) - extends Annotation with Loose with Unstable { +case class InferReadWriteAnnotation(t: String) extends Annotation with Loose with Unstable { val target = CircuitName(t) def duplicate(n: Named) = this.copy(t=n.name) + def transform = classOf[InferReadWrite] } // This pass examine the enable signals of the read & write ports of memories @@ -168,7 +168,9 @@ object InferReadWritePass extends Pass { // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite(transID: TransID) extends Transform with SimpleRun { +class InferReadWrite extends Transform with PassBased { + def inputForm = MidForm + def outputForm = MidForm def passSeq = Seq( InferReadWritePass, CheckInitialization, @@ -176,11 +178,12 @@ class InferReadWrite(transID: TransID) extends Transform with SimpleRun { ResolveKinds, ResolveGenders ) - def execute(c: Circuit, map: AnnotationMap) = map get transID match { - case Some(p) => p get CircuitName(c.main) match { - case Some(InferReadWriteAnnotation(_, _)) => run(c, passSeq) - case _ => sys.error("Unexpected annotation for InferReadWrite") - } - case _ => TransformResult(c) + def execute(state: CircuitState): CircuitState = { + val result = for { + myAnnotations <- getMyAnnotations(state) + InferReadWriteAnnotation(_) <- myAnnotations get CircuitName(state.circuit.main) + resCircuit = runPasses(state.circuit) + } yield state.copy(circuit = resCircuit) + result getOrElse state // Return state if nothing to do } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index 9ab496d2..ae872639 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -16,7 +16,8 @@ import wiring._ /** Annotates the name of the pin to add for WiringTransform */ -case class PinAnnotation(target: CircuitName, tID: TransID, pin: String) extends Annotation with Loose with Unstable { +case class PinAnnotation(target: CircuitName, pin: String) extends Annotation with Loose with Unstable { + def transform = classOf[ReplaceMemMacros] def duplicate(n: Named) = n match { case n: CircuitName => this.copy(target = n) case _ => throwInternalError @@ -27,8 +28,10 @@ case class PinAnnotation(target: CircuitName, tID: TransID, pin: String) extends * This will not generate wmask ports if not needed. * Creates the minimum # of black boxes needed by the design. */ -class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) extends Transform { - def name = "Replace Memory Macros" +class ReplaceMemMacros(writer: ConfWriter) extends Transform { + override def name = "Replace Memory Macros" + def inputForm = MidForm + def outputForm = MidForm /** Return true if mask granularity is per bit, false if per byte or unspecified */ @@ -206,7 +209,8 @@ class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) ext map updateStmtRefs(memPortMap)) } - def execute(c: Circuit, map: AnnotationMap): TransformResult = { + def execute(state: CircuitState): CircuitState = { + val c = state.circuit val namespace = Namespace(c) val memMods = new Modules val nameMap = new NameMap @@ -214,15 +218,15 @@ class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) ext val modules = c.modules map updateMemMods(namespace, nameMap, memMods) // print conf writer.serialize() - val pin = map get myID match { - case Some(p) => + val pin = getMyAnnotations(state) match { + case Some(p) => p.values.head match { - case PinAnnotation(c, _, pin) => pin + case PinAnnotation(c, pin) => pin case _ => error(s"Bad Annotations: ${p.values}") } case None => "pin" } - val annos = memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), wiringID, pin) } - TransformResult(c.copy(modules = modules ++ memMods), None, Some(AnnotationMap(annos))) - } + val annos = memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), pin) } + CircuitState(c.copy(modules = modules ++ memMods), inputForm, Some(AnnotationMap(annos))) + } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index 01f020f5..818bd9cc 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -61,8 +61,7 @@ class ConfWriter(filename: String) { } } -case class ReplSeqMemAnnotation(t: String, tID: TransID) - extends Annotation with Loose with Unstable { +case class ReplSeqMemAnnotation(t: String) extends Annotation with Loose with Unstable { val usage = """ [Optional] ReplSeqMem @@ -91,52 +90,60 @@ Optional Arguments: ) val target = CircuitName(passCircuit) def duplicate(n: Named) = this copy (t = t.replace(s"-c:$passCircuit", s"-c:${n.name}")) + def transform = classOf[ReplSeqMem] } -case class SimpleTransform(p: Pass) extends Transform { - def execute(c: Circuit, map: AnnotationMap): TransformResult = - TransformResult(p.run(c)) +class SimpleTransform(p: Pass, form: CircuitForm) extends Transform { + def inputForm = form + def outputForm = form + def execute(state: CircuitState): CircuitState = state.copy(circuit = p.run(state.circuit)) } -class ReplSeqMem(transID: TransID) extends Transform with SimpleRun { +class SimpleMidTransform(p: Pass) extends SimpleTransform(p, MidForm) + +// SimpleRun instead of PassBased because of the arguments to passSeq +class ReplSeqMem extends Transform with SimpleRun { + def inputForm = MidForm + def outputForm = MidForm def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = - Seq(SimpleTransform(Legalize), - SimpleTransform(ToMemIR), - SimpleTransform(ResolveMaskGranularity), - SimpleTransform(RenameAnnotatedMemoryPorts), - SimpleTransform(ResolveMemoryReference), - new CreateMemoryAnnotations(inConfigFile, TransID(-7), TransID(-8)), - new ReplaceMemMacros(outConfigFile, TransID(-7), TransID(-8)), - new WiringTransform(TransID(-8)), - SimpleTransform(RemoveEmpty), - SimpleTransform(CheckInitialization), - SimpleTransform(InferTypes), - SimpleTransform(Uniquify), - SimpleTransform(ResolveKinds), - SimpleTransform(ResolveGenders)) - def run(circuit: Circuit, map: AnnotationMap, xForms: Seq[Transform]): TransformResult = { - (xForms.foldLeft(TransformResult(circuit, None, Some(map)))) { case (tr: TransformResult, xForm: Transform) => - val x = xForm.execute(tr.circuit, tr.annotation.get) - x.annotation match { - case None => TransformResult(x.circuit, None, Some(map)) - case Some(ann) => TransformResult(x.circuit, None, Some( - AnnotationMap(ann.annotations ++ tr.annotation.get.annotations))) + Seq(new SimpleMidTransform(Legalize), + new SimpleMidTransform(ToMemIR), + new SimpleMidTransform(ResolveMaskGranularity), + new SimpleMidTransform(RenameAnnotatedMemoryPorts), + new SimpleMidTransform(ResolveMemoryReference), + new CreateMemoryAnnotations(inConfigFile), + new ReplaceMemMacros(outConfigFile), + new WiringTransform, + new SimpleMidTransform(RemoveEmpty), + new SimpleMidTransform(CheckInitialization), + new SimpleMidTransform(InferTypes), + new SimpleMidTransform(Uniquify), + new SimpleMidTransform(ResolveKinds), + new SimpleMidTransform(ResolveGenders)) + def run(state: CircuitState, xForms: Seq[Transform]): CircuitState = { + xForms.foldLeft(state) { case (curState: CircuitState, xForm: Transform) => + val res = xForm.execute(state) + res.annotations match { + case None => CircuitState(res.circuit, res.form, state.annotations) + case Some(ann) => CircuitState(res.circuit, res.form, Some( + AnnotationMap(ann.annotations ++ curState.annotations.get.annotations))) } } } - def execute(c: Circuit, map: AnnotationMap) = map get transID match { - case Some(p) => p get CircuitName(c.main) match { - case Some(ReplSeqMemAnnotation(t, _)) => - val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "") - val inConfigFile = { - if (inputFileName.isEmpty) None - else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) - else error("Input configuration file does not exist!") - } - val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName)) - run(c, map, passSeq(inConfigFile, outConfigFile)) - case _ => error("Unexpected transform annotation") + def execute(state: CircuitState): CircuitState = + getMyAnnotations(state) match { + case Some(p) => p get CircuitName(state.circuit.main) match { + case Some(ReplSeqMemAnnotation(t)) => + val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "") + val inConfigFile = { + if (inputFileName.isEmpty) None + else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) + else error("Input configuration file does not exist!") + } + val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName)) + run(state, passSeq(inConfigFile, outConfigFile)) + case _ => error("Unexpected transform annotation") + } + case None => state // Do nothing if there are no annotations } - case _ => TransformResult(c) - } } diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index 919948b6..59e76d65 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -11,7 +11,8 @@ import WiringUtils._ /** A component, e.g. register etc. Must be declared only once under the TopAnnotation */ -case class SourceAnnotation(target: ComponentName, tID: TransID) extends Annotation with Loose with Unstable { +case class SourceAnnotation(target: ComponentName) extends Annotation with Loose with Unstable { + def transform = classOf[WiringTransform] def duplicate(n: Named) = n match { case n: ComponentName => this.copy(target = n) case _ => throwInternalError @@ -20,7 +21,8 @@ case class SourceAnnotation(target: ComponentName, tID: TransID) extends Annotat /** A module, e.g. ExtModule etc., that should add the input pin */ -case class SinkAnnotation(target: ModuleName, tID: TransID, pin: String) extends Annotation with Loose with Unstable { +case class SinkAnnotation(target: ModuleName, pin: String) extends Annotation with Loose with Unstable { + def transform = classOf[WiringTransform] def duplicate(n: Named) = n match { case n: ModuleName => this.copy(target = n) case _ => throwInternalError @@ -30,7 +32,8 @@ case class SinkAnnotation(target: ModuleName, tID: TransID, pin: String) extends /** A module under which all sink module must be declared, and there is only * one source component */ -case class TopAnnotation(target: ModuleName, tID: TransID) extends Annotation with Loose with Unstable { +case class TopAnnotation(target: ModuleName) extends Annotation with Loose with Unstable { + def transform = classOf[WiringTransform] def duplicate(n: Named) = n match { case n: ModuleName => this.copy(target = n) case _ => throwInternalError @@ -49,13 +52,15 @@ case class TopAnnotation(target: ModuleName, tID: TransID) extends Annotation wi * Notes: * - No module uniquification occurs (due to imposed restrictions) */ -class WiringTransform(transID: TransID) extends Transform with SimpleRun { +class WiringTransform extends Transform with SimpleRun { + def inputForm = MidForm + def outputForm = MidForm def passSeq(wi: WiringInfo) = Seq(new Wiring(wi), InferTypes, ResolveKinds, ResolveGenders) - def execute(c: Circuit, map: AnnotationMap) = map get transID match { + def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { case Some(p) => val sinks = mutable.HashMap[String, String]() val sources = mutable.Set[String]() @@ -63,18 +68,20 @@ class WiringTransform(transID: TransID) extends Transform with SimpleRun { val comp = mutable.Set[String]() p.values.foreach { a => a match { - case SinkAnnotation(m, _, pin) => sinks(m.name) = pin - case SourceAnnotation(c, _) => + case SinkAnnotation(m, pin) => sinks(m.name) = pin + case SourceAnnotation(c) => sources += c.module.name comp += c.name - case TopAnnotation(m, _) => tops += m.name + case TopAnnotation(m) => tops += m.name } } (sources.size, tops.size, sinks.size, comp.size) match { - case (0, 0, p, 0) => TransformResult(c) - case (1, 1, p, 1) if p > 0 => run(c, passSeq(WiringInfo(sources.head, comp.head, sinks.toMap, tops.head))) + case (0, 0, p, 0) => state + case (1, 1, p, 1) if p > 0 => + val winfo = WiringInfo(sources.head, comp.head, sinks.toMap, tops.head) + state.copy(circuit = runPasses(state.circuit, passSeq(winfo))) case _ => error("Wrong number of sources, tops, or sinks!") } - case None => TransformResult(c) + case None => state } } diff --git a/src/test/resources/features/CustomTransform.fir b/src/test/resources/features/CustomTransform.fir new file mode 100644 index 00000000..941a9e9c --- /dev/null +++ b/src/test/resources/features/CustomTransform.fir @@ -0,0 +1,33 @@ +circuit CustomTransform : + ; Replaced in custom transform by an implementation + extmodule Delay : + input clk : Clock + input reset : UInt<1> + input a : UInt<32> + input en : UInt<1> + output b : UInt<32> + + module CustomTransform : + input clk : Clock + input reset : UInt<1> + + reg cycle : UInt<32>, clk with : (reset => (reset, UInt<32>(0))) + cycle <= tail(add(cycle, UInt<32>(1)), 1) + + inst delay of Delay + delay.clk <= clk + delay.reset <= reset + delay.a <= UInt(0) + delay.en <= UInt(0) + + when eq(cycle, UInt(0)) : + delay.en <= UInt(1) + delay.a <= UInt("hdeadbeef") + when eq(cycle, UInt(1)) : + when neq(delay.b, UInt("hdeadbeef")) : + printf(clk, UInt(1), "Assertion failed!\n") + stop(clk, UInt(1), 1) + when eq(cycle, UInt(2)) : + printf(clk, UInt(1), "Success!\n") + stop(clk, UInt(1), 0) + diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 0312df5d..c395139b 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -9,14 +9,16 @@ import org.scalatest.junit.JUnitRunner import firrtl.ir.Circuit import firrtl.Parser import firrtl.{ + CircuitState, ResolveAndCheck, RenameMap, Compiler, - CompilerResult, - VerilogCompiler + ChirrtlForm, + LowForm, + VerilogCompiler, + Transform } import firrtl.Annotations.{ - TransID, Named, CircuitName, ModuleName, @@ -39,17 +41,17 @@ import firrtl.Annotations.{ */ trait AnnotationSpec extends LowTransformSpec { // Dummy transform - def transform = new ResolveAndCheck() + def transform = new CustomResolveAndCheck(LowForm) // Check if Annotation Exception is thrown override def failingexecute(writer: Writer, annotations: AnnotationMap, input: String) = { intercept[AnnotationException] { - compile(parse(input), annotations, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer) } } def execute(writer: Writer, annotations: AnnotationMap, input: String, check: Annotation) = { - val cr = compile(parse(input), annotations, writer) - (cr.annotationMap.annotations.head) should be (check) + val cr = compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer) + (cr.annotations.get.annotations.head) should be (check) } } @@ -63,7 +65,6 @@ trait AnnotationSpec extends LowTransformSpec { */ class AnnotationTests extends AnnotationSpec with Matchers { def getAMap (a: Annotation): AnnotationMap = new AnnotationMap(Seq(a)) - val tID = TransID(1) val input = """circuit Top : | module Top : @@ -76,11 +77,12 @@ class AnnotationTests extends AnnotationSpec with Matchers { val cName = ComponentName("c", mName) "Loose and Sticky annotation on a node" should "pass through" in { - case class TestAnnotation(target: Named, tID: TransID) extends Annotation with Loose with Sticky { + case class TestAnnotation(target: Named) extends Annotation with Loose with Sticky { def duplicate(to: Named) = this.copy(target=to) + def transform = classOf[Transform] } val w = new StringWriter() - val ta = TestAnnotation(cName, tID) + val ta = TestAnnotation(cName) execute(w, getAMap(ta), input, ta) } } diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala index d1e07eae..3a67bf04 100644 --- a/src/test/scala/firrtlTests/AttachSpec.scala +++ b/src/test/scala/firrtlTests/AttachSpec.scala @@ -37,10 +37,9 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class InoutVerilog extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], compiler: Compiler) = { val writer = new StringWriter() - compiler.compile(parse(input), new AnnotationMap(Seq.empty), writer) + compiler.compile(CircuitState(parse(input), ChirrtlForm), writer) val lines = writer.toString().split("\n") map normalized expected foreach { e => lines should contain(e) @@ -176,7 +175,6 @@ class InoutVerilog extends FirrtlFlatSpec { } class AttachAnalogSpec extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala index 719a3334..51663eaf 100644 --- a/src/test/scala/firrtlTests/CInferMDirSpec.scala +++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala @@ -63,13 +63,12 @@ class CInferMDir extends LowTransformSpec { } } - object CInferMDirCheck extends Transform with SimpleRun { - def execute(c: Circuit, map: AnnotationMap) = - run(c, Seq(ConstProp, CInferMDirCheckPass)) + def transform = new PassBasedTransform { + def inputForm = LowForm + def outputForm = LowForm + def passSeq = Seq(ConstProp, CInferMDirCheckPass) } - def transform = CInferMDirCheck - "Memory" should "have correct mem port directions" in { val input = """ circuit foo : @@ -97,7 +96,7 @@ circuit foo : val annotationMap = AnnotationMap(Nil) val writer = new java.io.StringWriter - compile(parse(input), annotationMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer) // Check correctness of firrtl parse(writer.toString) } diff --git a/src/test/scala/firrtlTests/CheckInitializationSpec.scala b/src/test/scala/firrtlTests/CheckInitializationSpec.scala index e2eaf690..e8dc60ae 100644 --- a/src/test/scala/firrtlTests/CheckInitializationSpec.scala +++ b/src/test/scala/firrtlTests/CheckInitializationSpec.scala @@ -36,7 +36,6 @@ import firrtl.Parser.IgnoreInfo import firrtl.passes._ class CheckInitializationSpec extends FirrtlFlatSpec { - private def parse(input: String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private val passes = Seq( ToWorkingIR, CheckHighForm, diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala index e0691a6b..63397da8 100644 --- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala @@ -76,13 +76,12 @@ class ChirrtlMemSpec extends LowTransformSpec { } } - object MemEnableCheck extends Transform with SimpleRun { - def execute(c: Circuit, map: AnnotationMap) = - run(c, Seq(ConstProp, MemEnableCheckPass)) + def transform = new PassBasedTransform { + def inputForm = LowForm + def outputForm = LowForm + def passSeq = Seq(ConstProp, MemEnableCheckPass) } - def transform = MemEnableCheck - "Sequential Memory" should "have correct enable signals" in { val input = """ circuit foo : @@ -104,7 +103,7 @@ circuit foo : val annotationMap = AnnotationMap(Nil) val writer = new java.io.StringWriter - compile(parse(input), annotationMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer) // Check correctness of firrtl parse(writer.toString) } @@ -131,7 +130,7 @@ circuit foo : val annotationMap = AnnotationMap(Nil) val writer = new java.io.StringWriter - compile(parse(input), annotationMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer) // Check correctness of firrtl parse(writer.toString) } diff --git a/src/test/scala/firrtlTests/CompilerTests.scala b/src/test/scala/firrtlTests/CompilerTests.scala index 2eab6e0f..28d09c2d 100644 --- a/src/test/scala/firrtlTests/CompilerTests.scala +++ b/src/test/scala/firrtlTests/CompilerTests.scala @@ -8,13 +8,14 @@ import org.scalatest.junit.JUnitRunner import firrtl.ir.Circuit import firrtl.{ - HighFirrtlCompiler, - LowFirrtlCompiler, - VerilogCompiler, - Compiler, - Parser + ChirrtlForm, + CircuitState, + Compiler, + HighFirrtlCompiler, + LowFirrtlCompiler, + Parser, + VerilogCompiler } -import firrtl.Annotations.AnnotationMap /** * An example methodology for testing Firrtl compilers. @@ -30,7 +31,7 @@ abstract class CompilerSpec extends FlatSpec { def input: String def check: String def getOutput: String = { - compiler.compile(parse(input), new AnnotationMap(Seq.empty), writer) + compiler.compile(CircuitState(parse(input), ChirrtlForm), writer) writer.toString() } } diff --git a/src/test/scala/firrtlTests/CompilerUtilsSpec.scala b/src/test/scala/firrtlTests/CompilerUtilsSpec.scala new file mode 100644 index 00000000..1d349db1 --- /dev/null +++ b/src/test/scala/firrtlTests/CompilerUtilsSpec.scala @@ -0,0 +1,76 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ +import firrtl.CompilerUtils.mergeTransforms + +class CompilerUtilsSpec extends FirrtlFlatSpec { + + def genTransform(_inputForm: CircuitForm, _outputForm: CircuitForm) = new Transform { + def inputForm = _inputForm + def outputForm = _outputForm + def execute(state: CircuitState): CircuitState = state + } + + // Core lowering transforms + val chirrtlToHigh = genTransform(ChirrtlForm, HighForm) + val highToMid = genTransform(HighForm, MidForm) + val midToLow = genTransform(MidForm, LowForm) + val chirrtlToLowList = List(chirrtlToHigh, highToMid, midToLow) + + // Custom transforms + val chirrtlToChirrtl = genTransform(ChirrtlForm, ChirrtlForm) + val highToHigh = genTransform(HighForm, HighForm) + val midToMid = genTransform(MidForm, MidForm) + val lowToLow = genTransform(LowForm, LowForm) + + val lowToHigh = genTransform(LowForm, HighForm) + + val lowToLowTwo = genTransform(LowForm, LowForm) + + behavior of "mergeTransforms" + + it should "do nothing if there are no custom transforms" in { + mergeTransforms(chirrtlToLowList, List.empty) should be (chirrtlToLowList) + } + + it should "insert transforms at the correct place" in { + mergeTransforms(chirrtlToLowList, List(chirrtlToChirrtl)) should be + (chirrtlToChirrtl +: chirrtlToLowList) + mergeTransforms(chirrtlToLowList, List(highToHigh)) should be + (List(chirrtlToHigh, highToHigh, highToMid, midToLow)) + mergeTransforms(chirrtlToLowList, List(midToMid)) should be + (List(chirrtlToHigh, highToMid, midToMid, midToLow)) + mergeTransforms(chirrtlToLowList, List(lowToLow)) should be + (chirrtlToLowList :+ lowToLow) + } + + it should "insert transforms at the last legal location" in { + lowToLow should not be (lowToLowTwo) // sanity check + mergeTransforms(chirrtlToLowList :+ lowToLow, List(lowToLowTwo)).last should be (lowToLowTwo) + } + + it should "insert multiple transforms correctly" in { + mergeTransforms(chirrtlToLowList, List(highToHigh, lowToLow)) should be + (List(chirrtlToHigh, highToHigh, highToMid, midToLow, lowToLow)) + } + + it should "handle transforms that raise the form" in { + mergeTransforms(chirrtlToLowList, List(lowToHigh)) match { + case chirrtlToHigh :: highToMid :: midToLow :: lowToHigh :: remainder => + // Remainder will be the actual Firrtl lowering transforms + remainder.head.inputForm should be (HighForm) + remainder.last.outputForm should be (LowForm) + case _ => fail() + } + } + + // Order is not always maintained, see note on function Scaladoc + it should "maintain order of custom tranforms" in { + mergeTransforms(chirrtlToLowList, List(lowToLow, lowToLowTwo)) should be + (chirrtlToLowList ++ List(lowToLow, lowToLowTwo)) + } + +} + diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index bfe58a2c..f6bfa5ef 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -22,7 +22,6 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { ResolveGenders, InferWidths, ConstProp) - def parse(input: String): Circuit = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def exec (input: String) = { passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/CustomTransformSpec.scala b/src/test/scala/firrtlTests/CustomTransformSpec.scala new file mode 100644 index 00000000..4a3faf6b --- /dev/null +++ b/src/test/scala/firrtlTests/CustomTransformSpec.scala @@ -0,0 +1,51 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl.ir.Circuit +import firrtl._ +import firrtl.passes.Pass +import firrtl.ir._ + +class CustomTransformSpec extends FirrtlFlatSpec { + behavior of "Custom Transforms" + + they should "be able to introduce high firrtl" in { + // Simple module + val delayModuleString = """ + |circuit Delay : + | module Delay : + | input clk : Clock + | input reset : UInt<1> + | input a : UInt<32> + | input en : UInt<1> + | output b : UInt<32> + | + | reg r : UInt<32>, clk + | r <= r + | when en : + | r <= a + | b <= r + |""".stripMargin + val delayModuleCircuit = parse(delayModuleString) + val delayModule = delayModuleCircuit.modules.find(_.name == delayModuleCircuit.main).get + + class ReplaceExtModuleTransform extends PassBasedTransform { + class ReplaceExtModule extends Pass { + def name = "Replace External Module" + def run(c: Circuit): Circuit = c.copy( + modules = c.modules map { + case ExtModule(_, "Delay", _, _, _) => delayModule + case other => other + } + ) + } + def passSeq = Seq(new ReplaceExtModule) + def inputForm = LowForm + def outputForm = HighForm + } + + runFirrtlTest("CustomTransform", "/features", customTransforms = List(new ReplaceExtModuleTransform)) + } +} + diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala index 8bbecaeb..06963708 100644 --- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala +++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala @@ -36,7 +36,6 @@ import firrtl.ir._ import firrtl.Parser.IgnoreInfo class ExpandWhensSpec extends FirrtlFlatSpec { - private def parse(input: String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, notExpected: String, passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala index f491b0f5..83cccf3b 100644 --- a/src/test/scala/firrtlTests/FirrtlSpec.scala +++ b/src/test/scala/firrtlTests/FirrtlSpec.scala @@ -36,6 +36,7 @@ import org.scalatest.prop._ import scala.io.Source import firrtl._ +import firrtl.Parser.IgnoreInfo import firrtl.Annotations.AnnotationMap // This trait is borrowed from Chisel3, ideally this code should only exist in one location @@ -131,6 +132,7 @@ trait BackendCompilationUtilities { } trait FirrtlRunners extends BackendCompilationUtilities { + def parse(str: String) = Parser.parse(str.split("\n").toIterator, IgnoreInfo) lazy val cppHarness = new File(s"/top.cpp") /** Compile a Firrtl file * @@ -141,6 +143,7 @@ trait FirrtlRunners extends BackendCompilationUtilities { def compileFirrtlTest( prefix: String, srcDir: String, + customTransforms: Seq[Transform] = Seq.empty, annotations: AnnotationMap = new AnnotationMap(Seq.empty)): File = { val testDir = createTempDirectory(prefix) copyResourceToFile(s"${srcDir}/${prefix}.fir", new File(testDir, s"${prefix}.fir")) @@ -150,6 +153,7 @@ trait FirrtlRunners extends BackendCompilationUtilities { s"$testDir/$prefix.v", new VerilogCompiler(), Parser.IgnoreInfo, + customTransforms, annotations) testDir } @@ -164,8 +168,9 @@ trait FirrtlRunners extends BackendCompilationUtilities { prefix: String, srcDir: String, verilogPrefixes: Seq[String] = Seq.empty, + customTransforms: Seq[Transform] = Seq.empty, annotations: AnnotationMap = new AnnotationMap(Seq.empty)) = { - val testDir = compileFirrtlTest(prefix, srcDir, annotations) + val testDir = compileFirrtlTest(prefix, srcDir, customTransforms, annotations) val harness = new File(testDir, s"top.cpp") copyResourceToFile(cppHarness.toString, harness) diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala index be663872..b6e8f726 100644 --- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -61,19 +61,19 @@ class InferReadWriteSpec extends SimpleTransformSpec { } } - object InferReadWriteCheck extends Transform with SimpleRun { - def execute (c: Circuit, map: AnnotationMap) = - run(c, Seq(InferReadWriteCheckPass)) + class InferReadWriteCheck extends PassBasedTransform { + def inputForm = MidForm + def outputForm = MidForm + def passSeq = Seq(InferReadWriteCheckPass) } - def transforms (writer: java.io.Writer) = Seq( - new Chisel3ToHighFirrtl(), - new IRToWorkingIR(), - new ResolveAndCheck(), - new HighFirrtlToMiddleFirrtl(), - new memlib.InferReadWrite(TransID(-1)), - InferReadWriteCheck, - new EmitFirrtl(writer) + def transforms = Seq( + new ChirrtlToHighFirrtl, + new IRToWorkingIR, + new ResolveAndCheck, + new HighFirrtlToMiddleFirrtl, + new memlib.InferReadWrite, + new InferReadWriteCheck ) "Infer ReadWrite Ports" should "infer readwrite ports for the same clock" in { @@ -100,9 +100,9 @@ circuit sram6t : T_5 <= io.wdata """.stripMargin - val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t", TransID(-1)))) + val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t"))) val writer = new java.io.StringWriter - compile(parse(input), annotationMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer) // Check correctness of firrtl parse(writer.toString) } @@ -132,10 +132,10 @@ circuit sram6t : T_5 <= io.wdata """.stripMargin - val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t", TransID(-1)))) + val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t"))) val writer = new java.io.StringWriter intercept[InferReadWriteCheckException] { - compile(parse(input), annotationMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), writer) } } } diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 5f19af5c..f7845cc7 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -14,7 +14,6 @@ import firrtl.Annotations.{ CircuitName, ModuleName, ComponentName, - TransID, Annotation, AnnotationMap } @@ -24,9 +23,8 @@ import firrtl.passes.{InlineInstances, InlineAnnotation} /** * Tests inline instances transformation */ -class InlineInstancesTests extends HighTransformSpec { - val tID = TransID(0) - val transform = new InlineInstances(tID) +class InlineInstancesTests extends LowTransformSpec { + def transform = new InlineInstances "The module Inline" should "be inlined" in { val input = """circuit Top : @@ -48,14 +46,14 @@ class InlineInstancesTests extends HighTransformSpec { | wire i$a : UInt<32> | wire i$b : UInt<32> | i$b <= i$a - | i$a <= a | b <= i$b + | i$a <= a | module Inline : | input a : UInt<32> | output b : UInt<32> | b <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Inline", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Inline", CircuitName("Top"))))) execute(writer, aMap, input, check) } @@ -85,15 +83,15 @@ class InlineInstancesTests extends HighTransformSpec { | wire i1$a : UInt<32> | wire i1$b : UInt<32> | i1$b <= i1$a + | b <= i1$b | i0$a <= a | i1$a <= i0$b - | b <= i1$b | module Simple : | input a : UInt<32> | output b : UInt<32> | b <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Simple", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Simple", CircuitName("Top"))))) execute(writer, aMap, input, check) } @@ -121,15 +119,15 @@ class InlineInstancesTests extends HighTransformSpec { | wire i0$b : UInt<32> | i0$b <= i0$a | inst i1 of Simple + | b <= i1.b | i0$a <= a | i1.a <= i0$b - | b <= i1.b | module Simple : | input a : UInt<32> | output b : UInt<32> | b <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ComponentName("i0",ModuleName("Top", CircuitName("Top"))), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ComponentName("i0",ModuleName("Top", CircuitName("Top")))))) execute(writer, aMap, input, check) } @@ -163,9 +161,9 @@ class InlineInstancesTests extends HighTransformSpec { | wire i0$b : UInt<32> | i0$b <= i0$a | inst i1 of B + | b <= i1.b | i0$a <= a | i1.a <= i0$b - | b <= i1.b | module A : | input a : UInt<32> | output b : UInt<32> @@ -176,10 +174,10 @@ class InlineInstancesTests extends HighTransformSpec { | wire i$a : UInt<32> | wire i$b : UInt<32> | i$b <= i$a - | i$a <= a - | b <= i$b""".stripMargin + | b <= i$b + | i$a <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) execute(writer, aMap, input, check) } @@ -199,7 +197,7 @@ class InlineInstancesTests extends HighTransformSpec { | input a : UInt<32> | output b : UInt<32>""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) failingexecute(writer, aMap, input) } // 2) ext instance @@ -216,7 +214,7 @@ class InlineInstancesTests extends HighTransformSpec { | input a : UInt<32> | output b : UInt<32>""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) failingexecute(writer, aMap, input) } // 3) no module @@ -228,7 +226,7 @@ class InlineInstancesTests extends HighTransformSpec { | output b : UInt<32> | b <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) failingexecute(writer, aMap, input) } // 4) no inst @@ -240,7 +238,7 @@ class InlineInstancesTests extends HighTransformSpec { | output b : UInt<32> | b <= a""".stripMargin val writer = new StringWriter() - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")), tID))) + val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) failingexecute(writer, aMap, input) } } diff --git a/src/test/scala/firrtlTests/MultiThreadingSpec.scala b/src/test/scala/firrtlTests/MultiThreadingSpec.scala index bfaed330..b2934314 100644 --- a/src/test/scala/firrtlTests/MultiThreadingSpec.scala +++ b/src/test/scala/firrtlTests/MultiThreadingSpec.scala @@ -2,6 +2,8 @@ package firrtlTests +import firrtl.{ChirrtlForm, CircuitState, Compiler, Annotations} + import scala.concurrent.{Future, Await, ExecutionContext} import scala.concurrent.duration.Duration @@ -13,7 +15,7 @@ class MultiThreadingSpec extends FirrtlPropSpec { def runCompiler(input: Seq[String], compiler: firrtl.Compiler): String = { val writer = new java.io.StringWriter val parsedInput = firrtl.Parser.parse(input) - compiler.compile(parsedInput,new firrtl.Annotations.AnnotationMap(Seq.empty), writer) + compiler.compile(CircuitState(parsedInput, ChirrtlForm), writer) writer.toString } // The parameters we're testing with diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala index e5269396..e574d31f 100644 --- a/src/test/scala/firrtlTests/PassTests.scala +++ b/src/test/scala/firrtlTests/PassTests.scala @@ -4,21 +4,27 @@ import com.typesafe.scalalogging.LazyLogging import java.io.{StringWriter,Writer} import org.scalatest.{FlatSpec, Matchers} import org.scalatest.junit.JUnitRunner -import firrtl.{Parser,FIRRTLEmitter} import firrtl.ir.Circuit import firrtl.Parser.IgnoreInfo -import firrtl.passes.{Pass, PassExceptions} +import firrtl.passes.{Pass, PassExceptions, RemoveEmpty} import firrtl.{ Transform, - TransformResult, + PassBasedTransform, + CircuitState, + CircuitForm, + ChirrtlForm, + HighForm, + MidForm, + LowForm, SimpleRun, - Chisel3ToHighFirrtl, + ChirrtlToHighFirrtl, IRToWorkingIR, ResolveAndCheck, HighFirrtlToMiddleFirrtl, MiddleFirrtlToLowFirrtl, - EmitFirrtl, - Compiler + FirrtlEmitter, + Compiler, + Parser } import firrtl.Annotations.AnnotationMap @@ -26,58 +32,66 @@ import firrtl.Annotations.AnnotationMap // An example methodology for testing Firrtl Passes // Spec class should extend this class abstract class SimpleTransformSpec extends FlatSpec with Matchers with Compiler with LazyLogging { + def emitter = new FirrtlEmitter + // Utility function def parse(s: String): Circuit = Parser.parse(s.split("\n").toIterator, infoMode = IgnoreInfo) // Executes the test. Call in tests. def execute(writer: Writer, annotations: AnnotationMap, input: String, check: String) = { - compile(parse(input), annotations, writer) - logger.debug(writer.toString) - logger.debug(check) - (parse(writer.toString)) should be (parse(check)) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer) + val actual = RemoveEmpty.run(parse(writer.toString)).serialize + val expected = parse(check).serialize + logger.debug(actual) + logger.debug(expected) + (actual) should be (expected) } // Executes the test, should throw an error def failingexecute(writer: Writer, annotations: AnnotationMap, input: String): Exception = { intercept[PassExceptions] { - compile(parse(input), annotations, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), writer) } } } +class CustomResolveAndCheck(form: CircuitForm) extends PassBasedTransform { + private val wrappedTransform = new ResolveAndCheck + def inputForm = form + def outputForm = form + def passSeq = wrappedTransform.passSeq +} + trait LowTransformSpec extends SimpleTransformSpec { def transform: Transform - def transforms (writer: Writer) = Seq( - new Chisel3ToHighFirrtl(), + def transforms = Seq( + new ChirrtlToHighFirrtl(), new IRToWorkingIR(), new ResolveAndCheck(), new HighFirrtlToMiddleFirrtl(), new MiddleFirrtlToLowFirrtl(), - new ResolveAndCheck(), - transform, - new EmitFirrtl(writer) + new CustomResolveAndCheck(LowForm), + transform ) } trait MiddleTransformSpec extends SimpleTransformSpec { def transform: Transform - def transforms (writer: Writer) = Seq( - new Chisel3ToHighFirrtl(), + def transforms = Seq( + new ChirrtlToHighFirrtl(), new IRToWorkingIR(), new ResolveAndCheck(), new HighFirrtlToMiddleFirrtl(), - new ResolveAndCheck(), - transform, - new EmitFirrtl(writer) + new CustomResolveAndCheck(MidForm), + transform ) } trait HighTransformSpec extends SimpleTransformSpec { def transform: Transform - def transforms (writer: Writer) = Seq( - new Chisel3ToHighFirrtl(), + def transforms = Seq( + new ChirrtlToHighFirrtl(), new IRToWorkingIR(), new ResolveAndCheck(), - transform, - new EmitFirrtl(writer) + transform ) } diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 78b3d9f0..e46230ef 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -6,19 +6,19 @@ import firrtl.passes.memlib._ import Annotations._ class ReplSeqMemSpec extends SimpleTransformSpec { - val passSeq = Seq( - ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty) - def transforms (writer: java.io.Writer) = Seq( - new Chisel3ToHighFirrtl(), + def transforms = Seq( + new ChirrtlToHighFirrtl(), new IRToWorkingIR(), new ResolveAndCheck(), new HighFirrtlToMiddleFirrtl(), - new InferReadWrite(TransID(-1)), - new ReplSeqMem(TransID(-2)), + new InferReadWrite(), + new ReplSeqMem(), new MiddleFirrtlToLowFirrtl(), - (new Transform with SimpleRun { - def execute(c: ir.Circuit, a: AnnotationMap) = run(c, passSeq) } ), - new EmitFirrtl(writer) + new PassBasedTransform { + def inputForm = LowForm + def outputForm = LowForm + def passSeq = Seq(ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty) + } ) "ReplSeqMem" should "generate blackbox wrappers for mems of bundle type" in { @@ -58,9 +58,9 @@ circuit Top : io2.commit_entry.bits.info <- R1 """.stripMargin val confLoc = "ReplSeqMemTests.confTEMP" - val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc, TransID(-2)))) + val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc))) val writer = new java.io.StringWriter - compile(parse(input), aMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer) // Check correctness of firrtl parse(writer.toString) (new java.io.File(confLoc)).delete() @@ -81,9 +81,9 @@ circuit Top : write mport T_155 = mem[p_address], clk """.stripMargin val confLoc = "ReplSeqMemTests.confTEMP" - val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc, TransID(-2)))) + val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc))) val writer = new java.io.StringWriter - compile(parse(input), aMap, writer) + compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer) // Check correctness of firrtl parse(writer.toString) (new java.io.File(confLoc)).delete() diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 245c32e8..1025c02b 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -36,7 +36,6 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class UnitTests extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) @@ -114,7 +113,7 @@ class UnitTests extends FirrtlFlatSpec { (c: Circuit, p: Pass) => p.run(c) } val writer = new StringWriter() - FIRRTLEmitter.run(c_result,writer) + (new FirrtlEmitter).emit(CircuitState(c_result, HighForm), writer) (parse(writer.toString())) should be (parse(check)) } @@ -136,7 +135,7 @@ class UnitTests extends FirrtlFlatSpec { intercept[PassException] { val c = Parser.parse(splitExpTestCode.split("\n").toIterator) val c2 = passes.foldLeft(c)((c, p) => p run c) - new VerilogEmitter().run(c2, new OutputStreamWriter(new ByteArrayOutputStream)) + (new VerilogEmitter).emit(CircuitState(c2, LowForm), new StringWriter) } } @@ -147,7 +146,7 @@ class UnitTests extends FirrtlFlatSpec { InferTypes) val c = Parser.parse(splitExpTestCode.split("\n").toIterator) val c2 = passes.foldLeft(c)((c, p) => p run c) - new VerilogEmitter().run(c2, new OutputStreamWriter(new ByteArrayOutputStream)) + (new VerilogEmitter).emit(CircuitState(c2, LowForm), new StringWriter) } "Simple compound expressions" should "be split" in { diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala index 1f6142bc..e9bf5429 100644 --- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala +++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala @@ -37,10 +37,9 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class DoPrimVerilog extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], compiler: Compiler) = { val writer = new StringWriter() - compiler.compile(parse(input), new AnnotationMap(Seq.empty), writer) + compiler.compile(CircuitState(parse(input), ChirrtlForm), writer) val lines = writer.toString().split("\n") map normalized expected foreach { e => lines should contain(e) diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala index d1b16bb9..74f6432f 100644 --- a/src/test/scala/firrtlTests/WidthSpec.scala +++ b/src/test/scala/firrtlTests/WidthSpec.scala @@ -36,7 +36,6 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class WidthSpec extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala index 5f40d861..309014d4 100644 --- a/src/test/scala/firrtlTests/WiringTests.scala +++ b/src/test/scala/firrtlTests/WiringTests.scala @@ -12,7 +12,6 @@ import wiring.WiringUtils._ import wiring._ class WiringTests extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala index a9a1bb47..4a87290d 100644 --- a/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala @@ -5,12 +5,11 @@ package firrtlTests.fixed import java.io.StringWriter import firrtl.Annotations.AnnotationMap -import firrtl.{LowFirrtlCompiler, Parser} +import firrtl.{CircuitState, ChirrtlForm, LowFirrtlCompiler, Parser} import firrtl.Parser.IgnoreInfo import firrtlTests.FirrtlFlatSpec class FixedPointMathSpec extends FirrtlFlatSpec { - def parse(input: String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) val SumPattern = """.*output sum.*<(\d+)>.*.*""".r val ProductPattern = """.*output product.*<(\d+)>.*""".r @@ -45,7 +44,7 @@ class FixedPointMathSpec extends FirrtlFlatSpec { val writer = new StringWriter() - lowerer.compile(parse(input), new AnnotationMap(Seq.empty), writer) + lowerer.compile(CircuitState(parse(input), ChirrtlForm), writer) val output = writer.toString.split("\n") diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala index 53b4f4c0..3f465361 100644 --- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala @@ -34,7 +34,6 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class FixedTypeInferenceSpec extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala index 6799a367..27d7e172 100644 --- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala +++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala @@ -35,7 +35,6 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class RemoveFixedTypeSpec extends FirrtlFlatSpec { - def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => p.run(c) @@ -204,14 +203,14 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | io_out <= io_in """.stripMargin - class CheckChirrtlTransform extends Transform with SimpleRun { + class CheckChirrtlTransform extends PassBasedTransform { + def inputForm = ChirrtlForm + def outputForm = ChirrtlForm val passSeq = Seq(passes.CheckChirrtl) - def execute (circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) } val chirrtlTransform = new CheckChirrtlTransform - chirrtlTransform.execute(parse(input), new AnnotationMap(Seq.empty)) + chirrtlTransform.execute(CircuitState(parse(input), ChirrtlForm, Some(new AnnotationMap(Seq.empty)))) } } |
