diff options
| author | jackkoenig | 2016-10-20 00:19:01 -0700 |
|---|---|---|
| committer | Jack Koenig | 2016-11-04 13:29:09 -0700 |
| commit | 8fa9429a6e916ab2a789f5d81fa803b022805b52 (patch) | |
| tree | fac2efcbd0a68bfb1916f09afc7f003c7a3d6528 /src/main | |
| parent | 62133264a788f46b319ebab9c31424b7e0536101 (diff) | |
Refactor Compilers and Transforms
* Transform Ids now handled by Class[_ <: Transform] instead of magic numbers
* Transforms define inputForm and outputForm
* Custom transforms can be inserted at runtime into compiler or the Driver
* Current "built-in" custom transforms handled via above mechanism
* Verilog-specific passes moved to the Verilog emitter
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/Annotations.scala | 28 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Compiler.scala | 265 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Driver.scala | 17 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 32 | ||||
| -rw-r--r-- | src/main/scala/firrtl/ExecutionOptionsManager.scala | 18 | ||||
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 186 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Inline.scala | 48 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/DecorateMems.scala | 22 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/InferReadWrite.scala | 21 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala | 24 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala | 89 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/wiring/WiringTransform.scala | 29 |
13 files changed, 484 insertions, 297 deletions
diff --git a/src/main/scala/firrtl/Annotations.scala b/src/main/scala/firrtl/Annotations.scala index d47ce67e..d70732e6 100644 --- a/src/main/scala/firrtl/Annotations.scala +++ b/src/main/scala/firrtl/Annotations.scala @@ -115,12 +115,6 @@ object Annotations { } /** - * Transform ID (TransID) associates an annotation with an instantiated - * Firrtl compiler transform - */ - case class TransID(id: Int) - - /** * Permissibility defines the range of acceptable changes to the annotated component. */ trait Permissibility { @@ -215,7 +209,7 @@ object Annotations { /** * Annotation associates with a given named circuit component (target) and a - * given transformation (tID). Also defined are the legal ranges of changes + * given transformation (transform). Also defined are the legal ranges of changes * to the associated component (Permissibility) and how the annotation * propagates under such changes (Tenacity). Subclasses must implement the * duplicate function to create the same annotation associated with a new @@ -223,7 +217,7 @@ object Annotations { */ trait Annotation extends Permissibility with Tenacity { def target: Named - def tID: TransID + def transform: Class[_ <: Transform] protected def duplicate(n: Named): Annotation def serialize: String = this.toString def update(tos: Seq[Named]): Seq[Annotation] = { @@ -236,23 +230,23 @@ object Annotations { * Container of all annotations for a Firrtl compiler. */ case class AnnotationMap(annotations: Seq[Annotation]) { - type NamedMap = Map[Named, Map[TransID, Annotation]] - type IDMap = Map[TransID, Map[Named, Annotation]] + type NamedMap = Map[Named, Map[Class[_], Annotation]] + type IDMap = Map[Class[_], Map[Named, Annotation]] val (namedMap: NamedMap, idMap:IDMap) = //annotations.foldLeft(Tuple2[NamedMap, IDMap](Map.empty, Map.empty)){ annotations.foldLeft((Map.empty: NamedMap, Map.empty: IDMap)){ (partialMaps: (NamedMap, IDMap), annotation: Annotation) => { - val tIDToAnn = partialMaps._1.getOrElse(annotation.target, Map.empty) - val pNMap = partialMaps._1 + (annotation.target -> (tIDToAnn + (annotation.tID -> annotation))) + val transformToAnn = partialMaps._1.getOrElse(annotation.target, Map.empty) + val pNMap = partialMaps._1 + (annotation.target -> (transformToAnn + (annotation.transform -> annotation))) - val nToAnn = partialMaps._2.getOrElse(annotation.tID, Map.empty) - val ptIDMap = partialMaps._2 + (annotation.tID -> (nToAnn + (annotation.target -> annotation))) - Tuple2(pNMap, ptIDMap) + val nToAnn = partialMaps._2.getOrElse(annotation.transform, Map.empty) + val ptransformMap = partialMaps._2 + (annotation.transform -> (nToAnn + (annotation.target -> annotation))) + Tuple2(pNMap, ptransformMap) } } - def get(id: TransID): Option[Map[Named, Annotation]] = idMap.get(id) - def get(named: Named): Option[Map[TransID, Annotation]] = namedMap.get(named) + def get(id: Class[_]): Option[Map[Named, Annotation]] = idMap.get(id) + def get(named: Named): Option[Map[Class[_], Annotation]] = namedMap.get(named) } } diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index f566544e..9781972e 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -27,48 +27,249 @@ MODIFICATIONS. package firrtl -import com.typesafe.scalalogging.LazyLogging -import scala.collection.mutable +import logger.LazyLogging import java.io.Writer import Annotations._ import firrtl.ir.Circuit +import passes.Pass + /** * RenameMap maps old names to modified names. Generated by transformations * that modify names */ case class RenameMap(map: Map[Named, Seq[Named]]) -// =========================================== -// Transforms -// ------------------------------------------- - -case class TransformResult( +/** Current State of the Circuit + * + * @constructor Creates a CircuitState object + * @param circuit The current state of the Firrtl AST + * @param form The current form of the circuit + * @param annotations The current collection of [[Annotations.Annotation]] + * @param renames A map of [[Annotations.Named]] things that have been renamed. + * Generally only a return value from [[Transform]]s + */ +case class CircuitState( circuit: Circuit, - renames: Option[RenameMap] = None, - annotation: Option[AnnotationMap] = None) + form: CircuitForm, + annotations: Option[AnnotationMap] = None, + renames: Option[RenameMap] = None) -// - Transforms a circuit -// - Can consume multiple CircuitAnnotation's -trait Transform { - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult +/** Current form of the Firrtl Circuit + * + * Form is a measure of addition restrictions on the legality of a Firrtl + * circuit. There is a notion of "highness" and "lowness" implemented in the + * compiler by extending scala.math.Ordered. "Lower" forms add additional + * restrictions compared to "higher" forms. This means that "higher" forms are + * strictly supersets of the "lower" forms. Thus, that any transform that + * operates on [[HighForm]] can also operate on [[MidForm]] or [[LowForm]] + */ +sealed abstract class CircuitForm(private val value: Int) extends Ordered[CircuitForm] { + // Note that value is used only to allow comparisons + def compare(that: CircuitForm): Int = this.value - that.value } +/** Chirrtl Form + * + * The form of the circuit emitted by Chisel. Not a true Firrtl form. + * Includes cmem, smem, and mport IR nodes which enable declaring memories + * separately form their ports. A "Higher" form than [[HighForm]] + * + * See [[CDefMemory]] and [[CDefMPort]] + */ +final case object ChirrtlForm extends CircuitForm(3) +/** High Form + * + * As detailed in the Firrtl specification + * [[https://github.com/ucb-bar/firrtl/blob/master/spec/spec.pdf]] + * + * Also see [[firrtl.ir]] + */ +final case object HighForm extends CircuitForm(2) +/** Middle Form + * + * A "lower" form than [[HighForm]] with the following restrictions: + * - All widths must be explicit + * - All whens must be removed + * - There can only be a single connection to any element + */ +final case object MidForm extends CircuitForm(1) +/** Low Form + * + * The "lowest" form. In addition to the restrictions in [[MidForm]]: + * - All aggregate types (vector/bundle) must have been removed + * - All implicit truncations must be made explicit + */ +final case object LowForm extends CircuitForm(0) +/** The basic unit of operating on a Firrtl AST */ +abstract class Transform { + /** A convenience function useful for debugging and error messages */ + def name: String = this.getClass.getSimpleName + /** The [[CircuitForm]] that this transform requires to operate on */ + def inputForm: CircuitForm + /** The [[CircuitForm]] that this transform outputs */ + def outputForm: CircuitForm + /** Perform the transform + * + * @param state Input Firrtl AST + * @return A transformed Firrtl AST + */ + def execute(state: CircuitState): CircuitState + /** Convenience method to get annotations relevant to this Transform + * + * @param state The [[CircuitState]] form which to extract annotations + * @return A collection of annotations + */ + final def getMyAnnotations(state: CircuitState): Option[Map[Named, Annotation]] = + for { + annotations <- state.annotations + myAnnotations <- annotations.get(this.getClass) + } yield myAnnotations +} -// =========================================== -// Compilers -// ------------------------------------------- +trait SimpleRun extends LazyLogging { + def runPasses(circuit: Circuit, passSeq: Seq[Pass]): Circuit = + passSeq.foldLeft(circuit) { (c: Circuit, pass: Pass) => + val x = Utils.time(pass.name) { pass.run(c) } + logger.debug(x.serialize) + x + } +} + +/** For PassBased Transforms and Emitters + * + * @note passSeq accepts no arguments + * @todo make passes accept CircuitState so annotations can pass data between them + */ +trait PassBased extends SimpleRun { + def passSeq: Seq[Pass] + def runPasses(circuit: Circuit): Circuit = runPasses(circuit, passSeq) +} -case class CompilerResult(circuit: Circuit, annotationMap: AnnotationMap) +/** For transformations that are simply a sequence of passes */ +abstract class PassBasedTransform extends Transform with PassBased { + def execute(state: CircuitState): CircuitState = { + require(state.form <= inputForm, + s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") + CircuitState(runPasses(state.circuit), outputForm) + } +} + +/** Similar to a Transform except that it writes to a Writer instead of returning a + * CircuitState + */ +abstract class Emitter { + def emit(state: CircuitState, writer: Writer): Unit +} + +object CompilerUtils { + /** Generates a sequence of [[Transform]]s to lower a Firrtl circuit + * + * @param inputForm [[CircuitForm]] to lower from + * @param outputForm [[CircuitForm to lower to + * @return Sequence of transforms that will lower if outputForm is lower than inputForm + */ + def getLoweringTransforms(inputForm: CircuitForm, outputForm: CircuitForm): Seq[Transform] = { + // If outputForm is equal-to or higher than inputForm, nothing to lower + if (outputForm >= inputForm) { + Seq.empty + } else { + inputForm match { + case ChirrtlForm => Seq(new ChirrtlToHighFirrtl) ++ getLoweringTransforms(HighForm, outputForm) + case HighForm => Seq(new IRToWorkingIR, new ResolveAndCheck, new HighFirrtlToMiddleFirrtl) ++ + getLoweringTransforms(MidForm, outputForm) + case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm) + case LowForm => error("Internal Error! This shouldn't be possible") // should be caught by if above + } + } + } + + /** Merge a Seq of lowering transforms with custom transforms + * + * Custom Transforms are inserted based on their [[Transform.inputForm]] and + * [[Transform.outputForm]]. Custom transforms are inserted in order at the + * last location in the Seq of transforms where previous.outputForm == + * customTransform.inputForm. If a customTransform outputs a higher form + * than input, [[getLoweringTransforms]] is used to relower the circuit. + * + * @example + * {{{ + * // Let Transforms be represented by CircuitForm => CircuitForm + * val A = HighForm => MidForm + * val B = MidForm => LowForm + * val lowering = List(A, B) // Assume these transforms are used by getLoweringTransforms + * // Some custom transforms + * val C = LowForm => LowForm + * val D = MidForm => MidForm + * val E = LowForm => HighForm + * // All of the following comparisons are true + * mergeTransforms(lowering, List(C)) == List(A, B, C) + * mergeTransforms(lowering, List(D)) == List(A, D, B) + * mergeTransforms(lowering, List(E)) == List(A, B, E, A, B) + * mergeTransforms(lowering, List(C, E)) == List(A, B, C, E, A, B) + * mergeTransforms(lowering, List(E, C)) == List(A, B, E, A, B, C) + * // Notice that in the following, custom transform order is NOT preserved (see note) + * mergeTransforms(lowering, List(C, D)) == List(A, D, B, C) + * }}} + * + * @note Order will be preserved for custom transforms so long as the + * inputForm of a latter transforms is equal to or lower than the outputForm + * of the previous transform. + */ + def mergeTransforms(lowering: Seq[Transform], custom: Seq[Transform]): Seq[Transform] = { + custom.foldLeft(lowering) { case (transforms, xform) => + val index = transforms lastIndexWhere (_.outputForm == xform.inputForm) + assert(index >= 0 || xform.inputForm == ChirrtlForm, // If ChirrtlForm just put at front + s"No transform in $lowering has outputForm ${xform.inputForm} as required by $xform") + val (front, back) = transforms.splitAt(index + 1) // +1 because we want to be AFTER index + front ++ List(xform) ++ getLoweringTransforms(xform.outputForm, xform.inputForm) ++ back + } + } + +} -// - A sequence of transformations -// - Call compile to executes each transformation in sequence onto -// a given circuit. trait Compiler { - def transforms(w: Writer): Seq[Transform] - def compile(circuit: Circuit, annotationMap: AnnotationMap, writer: Writer): CompilerResult = - (transforms(writer) foldLeft CompilerResult(circuit, annotationMap)){ (in, xform) => - val result = xform.execute(in.circuit, in.annotationMap) + def emitter: Emitter + /** The sequence of transforms this compiler will execute + * @note The inputForm of a given transform must be higher than or equal to the ouputForm of the + * preceding transform. See [[CircuitForm]] + */ + def transforms: Seq[Transform] + + // Similar to (input|output)Form on [[Transform]] but derived from this Compiler's transforms + def inputForm = transforms.head.inputForm + def outputForm = transforms.last.outputForm + + private def transformsLegal(xforms: Seq[Transform]): Boolean = + if (xforms.size < 2) { + true + } else { + xforms.sliding(2, 1) + .map { case Seq(p, n) => n.inputForm >= p.outputForm } + .reduce(_ && _) + } + + assert(transformsLegal(transforms), + "Illegal Compiler, each transform must be able to accept the output of the previous transform!") + + /** Perform compilation + * + * @param state The Firrtl AST to compile + * @param writer The java.io.Writer where the output of compilation will be emitted + * @param customTransforms Any custom [[Transform]]s that will be inserted + * into the compilation process by [[CompilerUtils.mergeTransforms]] + */ + def compile(state: CircuitState, + writer: Writer, + customTransforms: Seq[Transform] = Seq.empty): CircuitState = { + val allTransforms = CompilerUtils.mergeTransforms(transforms, customTransforms) + + val finalState = allTransforms.foldLeft(state) { (in, xform) => + val result = Utils.time(s"***${xform.name}***") { xform.execute(in) } + + // Annotation propagation + // TODO: This should be redone + val inAnnotationMap = in.annotations getOrElse AnnotationMap(Seq.empty) val remappedAnnotations: Seq[Annotation] = result.renames match { case Some(RenameMap(rmap)) => // For each key in the rename map (rmap), obtain the @@ -77,18 +278,22 @@ trait Compiler { // annotations with the names in rmap's value. for { (oldName, newNames) <- rmap.toSeq - tID2OldAnnos <- in.annotationMap.get(oldName).toSeq - oldAnno <- tID2OldAnnos.values + transform2OldAnnos <- inAnnotationMap.get(oldName).toSeq + oldAnno <- transform2OldAnnos.values newAnno <- oldAnno.update(newNames) } yield newAnno - case _ => in.annotationMap.annotations + case _ => inAnnotationMap.annotations } - val resultAnnotations: Seq[Annotation] = result.annotation match { + val resultAnnotations: Seq[Annotation] = result.annotations match { case None => Nil case Some(p) => p.annotations } - CompilerResult(result.circuit, - new AnnotationMap(remappedAnnotations ++ resultAnnotations)) + val newAnnotations = AnnotationMap(remappedAnnotations ++ resultAnnotations) + CircuitState(result.circuit, result.form, Some(newAnnotations)) } + + emitter.emit(finalState, writer) + finalState + } } diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index ba5527b4..293ac4fd 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -30,7 +30,8 @@ import scala.collection._ * firrtl.Driver.execute(Array("--top-name Dummy --compiler verilog".split(" +")) * }}} * each approach has its own endearing aspects - * @see firrtlTests.DriverSpec.scala in the test directory for a lot more examples + * @see firrtlTests/DriverSpec.scala in the test directory for a lot more examples + * @see [[CompilerUtils.mergeTransforms]] to see how customTransformations are inserted */ object Driver { @@ -42,11 +43,15 @@ object Driver { output: String, compiler: Compiler, infoMode: InfoMode = IgnoreInfo, + customTransforms: Seq[Transform] = Seq.empty, annotations: AnnotationMap = new AnnotationMap(Seq.empty) ): String = { val parsedInput = Parser.parse(Source.fromFile(input).getLines(), infoMode) val outputBuffer = new java.io.CharArrayWriter - compiler.compile(parsedInput, annotations, outputBuffer) + compiler.compile( + CircuitState(parsedInput, ChirrtlForm, Some(annotations)), + outputBuffer, + customTransforms) val outputFile = new java.io.PrintWriter(output) val outputString = outputBuffer.toString @@ -108,7 +113,11 @@ object Driver { val parsedInput = Parser.parse(firrtlSource, firrtlConfig.infoMode) val outputBuffer = new java.io.CharArrayWriter - firrtlConfig.compiler.compile(parsedInput, new AnnotationMap(firrtlConfig.annotations), outputBuffer) + firrtlConfig.compiler.compile( + CircuitState(parsedInput, ChirrtlForm, Some(new AnnotationMap(firrtlConfig.annotations))), + outputBuffer, + firrtlConfig.customTransforms + ) val outputFileName = firrtlConfig.getOutputFileName(optionsManager) val outputFile = new java.io.PrintWriter(outputFileName) @@ -193,4 +202,4 @@ object FileUtils { } } } -}
\ No newline at end of file +} diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 7b198149..1d64dc91 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -47,12 +47,8 @@ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap, HashSet} case class EmitterException(message: String) extends PassException(message) -trait Emitter extends LazyLogging { - def run(c: Circuit, w: Writer) -} - -object FIRRTLEmitter extends Emitter { - def run(c: Circuit, w: Writer) = w.write(c.serialize) +class FirrtlEmitter extends Emitter { + def emit(state: CircuitState, writer: Writer): Unit = writer.write(state.circuit.serialize) } case class VRandom(width: BigInt) extends Expression { @@ -65,7 +61,7 @@ case class VRandom(width: BigInt) extends Expression { def mapWidth(f: Width => Width): Expression = this } -class VerilogEmitter extends Emitter { +class VerilogEmitter extends Emitter with PassBased { val tab = " " def AND(e1: WrappedExpression, e2: WrappedExpression): Expression = { if (e1 == e2) e1.e1 @@ -590,12 +586,18 @@ class VerilogEmitter extends Emitter { "`endif\n")) } - def run(c: Circuit, w: Writer) = { - emit_preamble(w) - val moduleMap = (c.modules map (m => m.name -> m)).toMap - c.modules foreach { - case (m: Module) => emit_verilog(m, moduleMap)(w) - case (m: ExtModule) => - } - } + def passSeq = Seq( + passes.VerilogWrap, + passes.VerilogRename, + passes.VerilogPrep) + + def emit(state: CircuitState, writer: Writer): Unit = { + val circuit = runPasses(state.circuit) + emit_preamble(writer) + val moduleMap = (circuit.modules map (m => m.name -> m)).toMap + circuit.modules foreach { + case (m: Module) => emit_verilog(m, moduleMap)(writer) + case (m: ExtModule) => + } + } } diff --git a/src/main/scala/firrtl/ExecutionOptionsManager.scala b/src/main/scala/firrtl/ExecutionOptionsManager.scala index e4954610..21a9cc50 100644 --- a/src/main/scala/firrtl/ExecutionOptionsManager.scala +++ b/src/main/scala/firrtl/ExecutionOptionsManager.scala @@ -140,6 +140,7 @@ case class FirrtlExecutionOptions( infoModeName: String = "append", inferRW: Seq[String] = Seq.empty, firrtlSource: Option[String] = None, + customTransforms: Seq[Transform] = List.empty, annotations: List[Annotation] = List.empty) extends ComposableOptions { @@ -249,14 +250,17 @@ trait HasFirrtlOptions { val newAnnotations = x.map { value => value.split('.') match { case Array(circuit) => - passes.InlineAnnotation(CircuitName(circuit), TransID(0)) + passes.InlineAnnotation(CircuitName(circuit)) case Array(circuit, module) => - passes.InlineAnnotation(ModuleName(module, CircuitName(circuit)), TransID(0)) + passes.InlineAnnotation(ModuleName(module, CircuitName(circuit))) case Array(circuit, module, inst) => - passes.InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit))), TransID(0)) + passes.InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit)))) } } - firrtlOptions = firrtlOptions.copy(annotations = firrtlOptions.annotations ++ newAnnotations) + firrtlOptions = firrtlOptions.copy( + annotations = firrtlOptions.annotations ++ newAnnotations, + customTransforms = firrtlOptions.customTransforms :+ new passes.InlineInstances + ) } .text { """Inline one or more module (comma separated, no spaces) module looks like "MyModule" or "MyModule.myinstance""" @@ -267,7 +271,8 @@ trait HasFirrtlOptions { .valueName ("<circuit>") .foreach { x => firrtlOptions = firrtlOptions.copy( - annotations = firrtlOptions.annotations :+ InferReadWriteAnnotation(x, TransID(-1)) + annotations = firrtlOptions.annotations :+ InferReadWriteAnnotation(x), + customTransforms = firrtlOptions.customTransforms :+ new passes.memlib.InferReadWrite ) }.text { "Enable readwrite port inference for the target circuit" @@ -278,7 +283,8 @@ trait HasFirrtlOptions { .valueName ("-c:<circuit>:-i:<filename>:-o:<filename>") .foreach { x => firrtlOptions = firrtlOptions.copy( - annotations = firrtlOptions.annotations :+ ReplSeqMemAnnotation(x, TransID(-2)) + annotations = firrtlOptions.annotations :+ ReplSeqMemAnnotation(x), + customTransforms = firrtlOptions.customTransforms :+ new passes.memlib.ReplSeqMem ) } .text { diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 446df6d0..986ebd9f 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -27,59 +27,38 @@ MODIFICATIONS. package firrtl -import java.io.Writer -import firrtl.passes.Pass -import firrtl.ir.Circuit -import Annotations._ -import logger.LazyLogging +sealed abstract class CoreTransform extends PassBasedTransform -// =========================================== -// Utility Traits -// ------------------------------------------- -// Valid if all passes in transformation: -// 1) Don't produce annotations -// 2) Don't consume annotations -// 3) No component or module names are renamed -trait SimpleRun extends LazyLogging { - def run (circuit: Circuit, passes: Seq[Pass]): TransformResult = { - val result = (passes foldLeft circuit){ (c: Circuit, pass: Pass) => - val name = pass.name - val x = Utils.time(name)(pass.run(c)) - logger.debug(x.serialize) - x - } - TransformResult(result) - } -} - -// =========================================== -// Lowering Transforms -// ------------------------------------------- -// This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting -// circuit has only IR nodes, not WIR. -// TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL -class Chisel3ToHighFirrtl extends Transform with SimpleRun { - val passSeq = Seq( +/** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting + * circuit has only IR nodes, not WIR. + * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL + */ +class ChirrtlToHighFirrtl extends CoreTransform { + def inputForm = ChirrtlForm + def outputForm = HighForm + def passSeq = Seq( passes.CheckChirrtl, passes.CInferTypes, passes.CInferMDir, passes.RemoveCHIRRTL) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) } -// Converts from the bare intermediate representation (ir.scala) -// to a working representation (WIR.scala) -class IRToWorkingIR extends Transform with SimpleRun { - val passSeq = Seq(passes.ToWorkingIR) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) +/** Converts from the bare intermediate representation (ir.scala) + * to a working representation (WIR.scala) + */ +class IRToWorkingIR extends CoreTransform { + def inputForm = HighForm + def outputForm = HighForm + def passSeq = Seq(passes.ToWorkingIR) } -// Resolves types, kinds, and genders, and checks the circuit legality. -// Operates on working IR nodes and high Firrtl. -class ResolveAndCheck extends Transform with SimpleRun { - val passSeq = Seq( +/** Resolves types, kinds, and genders, and checks the circuit legality. + * Operates on working IR nodes and high Firrtl. + */ +class ResolveAndCheck extends CoreTransform { + def inputForm = HighForm + def outputForm = HighForm + def passSeq = Seq( passes.CheckHighForm, passes.ResolveKinds, passes.InferTypes, @@ -91,16 +70,17 @@ class ResolveAndCheck extends Transform with SimpleRun { passes.CheckGenders, passes.InferWidths, passes.CheckWidths) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) } -// Expands aggregate connects, removes dynamic accesses, and when -// statements. Checks for uninitialized values. Must accept a -// well-formed graph. -// Operates on working IR nodes. -class HighFirrtlToMiddleFirrtl extends Transform with SimpleRun { - val passSeq = Seq( +/** Expands aggregate connects, removes dynamic accesses, and when + * statements. Checks for uninitialized values. Must accept a + * well-formed graph. + * Operates on working IR nodes. + */ +class HighFirrtlToMiddleFirrtl extends CoreTransform { + def inputForm = HighForm + def outputForm = MidForm + def passSeq = Seq( passes.PullMuxes, passes.ReplaceAccesses, passes.ExpandConnects, @@ -112,16 +92,17 @@ class HighFirrtlToMiddleFirrtl extends Transform with SimpleRun { passes.ResolveGenders, passes.InferWidths, passes.CheckWidths) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) } -// Expands all aggregate types into many ground-typed components. Must -// accept a well-formed graph of only middle Firrtl features. -// Operates on working IR nodes. -// TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL -class MiddleFirrtlToLowFirrtl extends Transform with SimpleRun { - val passSeq = Seq( +/** Expands all aggregate types into many ground-typed components. Must + * accept a well-formed graph of only middle Firrtl features. + * Operates on working IR nodes. + * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL + */ +class MiddleFirrtlToLowFirrtl extends CoreTransform { + def inputForm = MidForm + def outputForm = LowForm + def passSeq = Seq( passes.LowerTypes, passes.ResolveKinds, passes.InferTypes, @@ -129,87 +110,48 @@ class MiddleFirrtlToLowFirrtl extends Transform with SimpleRun { passes.InferWidths, passes.ConvertFixedToSInt, passes.Legalize) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = - run(circuit, passSeq) } -// Emits Verilog. -// First optimizes for verilog width semantics with custom Primops, -// then splits complex expressions into temporary nodes. Finally, -// renames names that conflict with Verilog keywords. -// Operates on working IR nodes. -// TODO(izraelevitz): Create RenameMap from VerilogRename -class EmitVerilogFromLowFirrtl(val writer: Writer) extends Transform with SimpleRun { - val passSeq = Seq( +/** Runs a series of optimization passes on LowFirrtl + * @note This is currently required for correct Verilog emission + * TODO Fix the above note + */ +class LowFirrtlOptimization extends CoreTransform { + def inputForm = LowForm + def outputForm = LowForm + def passSeq = Seq( passes.RemoveValidIf, passes.ConstProp, passes.PadWidths, passes.ConstProp, passes.Legalize, - passes.VerilogWrap, - passes.memlib.VerilogMemDelays, + passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter passes.ConstProp, passes.SplitExpressions, passes.CommonSubexpressionElimination, - passes.DeadCodeElimination, - passes.VerilogRename, - passes.VerilogPrep) - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = { - val result = run(circuit, passSeq) - (new VerilogEmitter).run(result.circuit, writer) - result - } + passes.DeadCodeElimination) } -// Emits Firrtl. -// Operates on WIR/IR nodes. -class EmitFirrtl(val writer: Writer) extends Transform { - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = { - FIRRTLEmitter.run(circuit, writer) - TransformResult(circuit) - } -} +import CompilerUtils.getLoweringTransforms -// =========================================== -// Lowering Compilers -// ------------------------------------------- -// Emits input circuit -// Will replace Chirrtl constructs with Firrtl +/** Emits input circuit + * Will replace Chirrtl constructs with Firrtl + */ class HighFirrtlCompiler extends Compiler { - def transforms(writer: Writer): Seq[Transform] = Seq( - new Chisel3ToHighFirrtl, - new IRToWorkingIR, - new EmitFirrtl(writer) - ) + def emitter = new FirrtlEmitter + def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, HighForm) } -// Emits lowered input circuit +/** Emits lowered input circuit */ class LowFirrtlCompiler extends Compiler { - def transforms(writer: Writer): Seq[Transform] = Seq( - new Chisel3ToHighFirrtl, - new IRToWorkingIR, - new passes.InlineInstances(TransID(0)), - new ResolveAndCheck, - new HighFirrtlToMiddleFirrtl, - new passes.memlib.InferReadWrite(TransID(-1)), - new passes.memlib.ReplSeqMem(TransID(-2)), - new MiddleFirrtlToLowFirrtl, - new EmitFirrtl(writer) - ) + def emitter = new FirrtlEmitter + def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, LowForm) } -// Emits Verilog +/** Emits Verilog */ class VerilogCompiler extends Compiler { - def transforms(writer: Writer): Seq[Transform] = Seq( - new Chisel3ToHighFirrtl, - new IRToWorkingIR, - new ResolveAndCheck, - new HighFirrtlToMiddleFirrtl, - new passes.memlib.InferReadWrite(TransID(-1)), - new passes.memlib.ReplSeqMem(TransID(-2)), - new MiddleFirrtlToLowFirrtl, - new passes.InlineInstances(TransID(0)), - new EmitVerilogFromLowFirrtl(writer) - ) + def emitter = new VerilogEmitter + def transforms: Seq[Transform] = + getLoweringTransforms(ChirrtlForm, LowForm) :+ (new LowFirrtlOptimization) } diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 22a3eac6..7c023ac8 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -44,7 +44,7 @@ import firrtl.WrappedType._ import scala.collection.mutable import scala.collection.mutable.{StringBuilder, ArrayBuffer, LinkedHashMap, HashMap, HashSet} import java.io.PrintWriter -import com.typesafe.scalalogging.LazyLogging +import logger.LazyLogging //import scala.reflect.runtime.universe._ class FIRRTLException(str: String) extends Exception(str) diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 5c80baff..c741dc06 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -9,34 +9,37 @@ import firrtl.Annotations._ import scala.collection.mutable // Tags an annotation to be consumed by this pass -case class InlineAnnotation(target: Named, tID: TransID) extends Annotation with Loose with Unstable { +case class InlineAnnotation(target: Named) extends Annotation with Loose with Unstable { def duplicate(n: Named) = this.copy(target=n) + def transform = classOf[InlineInstances] } // Only use on legal Firrtl. Specifically, the restriction of // instance loops must have been checked, or else this pass can // infinitely recurse -class InlineInstances (transID: TransID) extends Transform { +class InlineInstances extends Transform { + def inputForm = LowForm + def outputForm = LowForm val inlineDelim = "$" - def name = "Inline Instances" - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = { - annotationMap.get(transID) match { - case None => TransformResult(circuit, None, None) - case Some(map) => - val moduleNames = mutable.HashSet[ModuleName]() - val instanceNames = mutable.HashSet[ComponentName]() - map.values.foreach {x: Annotation => x match { - case InlineAnnotation(ModuleName(mod, cir), _) => moduleNames += ModuleName(mod, cir) - case InlineAnnotation(ComponentName(com, mod), _) => instanceNames += ComponentName(com, mod) - case _ => throw new PassException("Annotation must be InlineAnnotation") - }} - check(circuit, moduleNames.toSet, instanceNames.toSet) - run(circuit, moduleNames.toSet, instanceNames.toSet) + override def name = "Inline Instances" - // Default behavior is to error if more than one annotation for inlining - // This could potentially change - case _ => throw new PassException("Found more than one circuit annotation of InlineCAKind!") + private def collectAnns(anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = + anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) { + case ((modNames, instNames), ann) => ann match { + case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) + case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) + case _ => throw new PassException("Annotation must be InlineAnnotation") + } } + + def execute(state: CircuitState): CircuitState = { + // TODO Add error check for more than one annotation for inlining + // TODO Propagate other annotations + val result = for { + myAnnotations <- getMyAnnotations(state) + (modNames, instNames) = collectAnns(myAnnotations.values) + } yield run(state.circuit, modNames, instNames) + result getOrElse state // Return state if nothing to do } // Checks the following properties: @@ -78,7 +81,10 @@ class InlineInstances (transID: TransID) extends Transform { if (errors.nonEmpty) throw new PassExceptions(errors) } - def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): TransformResult = { + def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): CircuitState = { + // Check annotations and circuit match up + check(c, modsToInline, instsToInline) + // ---- Rename functions/data ---- val renameMap = mutable.HashMap[Named,Seq[Named]]() // Updates renameMap with new names @@ -168,6 +174,6 @@ class InlineInstances (transID: TransID) extends Transform { val top = c.modules.find(m => m.name == c.main).get onModule(top) val modulesx = c.modules.map(m => inlinedModules(m.name)) - TransformResult(Circuit(c.info, modulesx, c.main), Some(RenameMap(renameMap.toMap)), None) + CircuitState(Circuit(c.info, modulesx, c.main), LowForm, None, Some(RenameMap(renameMap.toMap))) } } diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala index 10cc8f88..c98dd4ca 100644 --- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala +++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala @@ -5,20 +5,22 @@ import ir._ import Annotations._ import wiring._ -class CreateMemoryAnnotations(reader: Option[YamlFileReader], replaceID: TransID, wiringID: TransID) extends Transform { - def name = "Create Memory Annotations" - def execute(c: Circuit, map: AnnotationMap): TransformResult = reader match { - case None => TransformResult(c) +class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform { + def inputForm = MidForm + def outputForm = MidForm + override def name = "Create Memory Annotations" + def execute(state: CircuitState): CircuitState = reader match { + case None => state case Some(r) => import CustomYAMLProtocol._ r.parse[Config] match { case Seq(config) => - val cN = CircuitName(c.main) - val top = TopAnnotation(ModuleName(config.top.name, cN), wiringID) - val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN)), wiringID) - val pin = PinAnnotation(cN, replaceID, config.pin.name) - TransformResult(c, None, Some(AnnotationMap(Seq(top, source, pin)))) - case Nil => TransformResult(c, None, None) + val cN = CircuitName(state.circuit.main) + val top = TopAnnotation(ModuleName(config.top.name, cN)) + val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN))) + val pin = PinAnnotation(cN, config.pin.name) + state.copy(annotations = Some(AnnotationMap(Seq(top, source, pin)))) + case Nil => state case _ => error("Can only have one config in yaml file") } } diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 28291135..2d6f4e96 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -38,10 +38,10 @@ import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin} import WrappedExpression.weq import Annotations._ -case class InferReadWriteAnnotation(t: String, tID: TransID) - extends Annotation with Loose with Unstable { +case class InferReadWriteAnnotation(t: String) extends Annotation with Loose with Unstable { val target = CircuitName(t) def duplicate(n: Named) = this.copy(t=n.name) + def transform = classOf[InferReadWrite] } // This pass examine the enable signals of the read & write ports of memories @@ -168,7 +168,9 @@ object InferReadWritePass extends Pass { // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite(transID: TransID) extends Transform with SimpleRun { +class InferReadWrite extends Transform with PassBased { + def inputForm = MidForm + def outputForm = MidForm def passSeq = Seq( InferReadWritePass, CheckInitialization, @@ -176,11 +178,12 @@ class InferReadWrite(transID: TransID) extends Transform with SimpleRun { ResolveKinds, ResolveGenders ) - def execute(c: Circuit, map: AnnotationMap) = map get transID match { - case Some(p) => p get CircuitName(c.main) match { - case Some(InferReadWriteAnnotation(_, _)) => run(c, passSeq) - case _ => sys.error("Unexpected annotation for InferReadWrite") - } - case _ => TransformResult(c) + def execute(state: CircuitState): CircuitState = { + val result = for { + myAnnotations <- getMyAnnotations(state) + InferReadWriteAnnotation(_) <- myAnnotations get CircuitName(state.circuit.main) + resCircuit = runPasses(state.circuit) + } yield state.copy(circuit = resCircuit) + result getOrElse state // Return state if nothing to do } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index 9ab496d2..ae872639 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -16,7 +16,8 @@ import wiring._ /** Annotates the name of the pin to add for WiringTransform */ -case class PinAnnotation(target: CircuitName, tID: TransID, pin: String) extends Annotation with Loose with Unstable { +case class PinAnnotation(target: CircuitName, pin: String) extends Annotation with Loose with Unstable { + def transform = classOf[ReplaceMemMacros] def duplicate(n: Named) = n match { case n: CircuitName => this.copy(target = n) case _ => throwInternalError @@ -27,8 +28,10 @@ case class PinAnnotation(target: CircuitName, tID: TransID, pin: String) extends * This will not generate wmask ports if not needed. * Creates the minimum # of black boxes needed by the design. */ -class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) extends Transform { - def name = "Replace Memory Macros" +class ReplaceMemMacros(writer: ConfWriter) extends Transform { + override def name = "Replace Memory Macros" + def inputForm = MidForm + def outputForm = MidForm /** Return true if mask granularity is per bit, false if per byte or unspecified */ @@ -206,7 +209,8 @@ class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) ext map updateStmtRefs(memPortMap)) } - def execute(c: Circuit, map: AnnotationMap): TransformResult = { + def execute(state: CircuitState): CircuitState = { + val c = state.circuit val namespace = Namespace(c) val memMods = new Modules val nameMap = new NameMap @@ -214,15 +218,15 @@ class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) ext val modules = c.modules map updateMemMods(namespace, nameMap, memMods) // print conf writer.serialize() - val pin = map get myID match { - case Some(p) => + val pin = getMyAnnotations(state) match { + case Some(p) => p.values.head match { - case PinAnnotation(c, _, pin) => pin + case PinAnnotation(c, pin) => pin case _ => error(s"Bad Annotations: ${p.values}") } case None => "pin" } - val annos = memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), wiringID, pin) } - TransformResult(c.copy(modules = modules ++ memMods), None, Some(AnnotationMap(annos))) - } + val annos = memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), pin) } + CircuitState(c.copy(modules = modules ++ memMods), inputForm, Some(AnnotationMap(annos))) + } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index 01f020f5..818bd9cc 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -61,8 +61,7 @@ class ConfWriter(filename: String) { } } -case class ReplSeqMemAnnotation(t: String, tID: TransID) - extends Annotation with Loose with Unstable { +case class ReplSeqMemAnnotation(t: String) extends Annotation with Loose with Unstable { val usage = """ [Optional] ReplSeqMem @@ -91,52 +90,60 @@ Optional Arguments: ) val target = CircuitName(passCircuit) def duplicate(n: Named) = this copy (t = t.replace(s"-c:$passCircuit", s"-c:${n.name}")) + def transform = classOf[ReplSeqMem] } -case class SimpleTransform(p: Pass) extends Transform { - def execute(c: Circuit, map: AnnotationMap): TransformResult = - TransformResult(p.run(c)) +class SimpleTransform(p: Pass, form: CircuitForm) extends Transform { + def inputForm = form + def outputForm = form + def execute(state: CircuitState): CircuitState = state.copy(circuit = p.run(state.circuit)) } -class ReplSeqMem(transID: TransID) extends Transform with SimpleRun { +class SimpleMidTransform(p: Pass) extends SimpleTransform(p, MidForm) + +// SimpleRun instead of PassBased because of the arguments to passSeq +class ReplSeqMem extends Transform with SimpleRun { + def inputForm = MidForm + def outputForm = MidForm def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = - Seq(SimpleTransform(Legalize), - SimpleTransform(ToMemIR), - SimpleTransform(ResolveMaskGranularity), - SimpleTransform(RenameAnnotatedMemoryPorts), - SimpleTransform(ResolveMemoryReference), - new CreateMemoryAnnotations(inConfigFile, TransID(-7), TransID(-8)), - new ReplaceMemMacros(outConfigFile, TransID(-7), TransID(-8)), - new WiringTransform(TransID(-8)), - SimpleTransform(RemoveEmpty), - SimpleTransform(CheckInitialization), - SimpleTransform(InferTypes), - SimpleTransform(Uniquify), - SimpleTransform(ResolveKinds), - SimpleTransform(ResolveGenders)) - def run(circuit: Circuit, map: AnnotationMap, xForms: Seq[Transform]): TransformResult = { - (xForms.foldLeft(TransformResult(circuit, None, Some(map)))) { case (tr: TransformResult, xForm: Transform) => - val x = xForm.execute(tr.circuit, tr.annotation.get) - x.annotation match { - case None => TransformResult(x.circuit, None, Some(map)) - case Some(ann) => TransformResult(x.circuit, None, Some( - AnnotationMap(ann.annotations ++ tr.annotation.get.annotations))) + Seq(new SimpleMidTransform(Legalize), + new SimpleMidTransform(ToMemIR), + new SimpleMidTransform(ResolveMaskGranularity), + new SimpleMidTransform(RenameAnnotatedMemoryPorts), + new SimpleMidTransform(ResolveMemoryReference), + new CreateMemoryAnnotations(inConfigFile), + new ReplaceMemMacros(outConfigFile), + new WiringTransform, + new SimpleMidTransform(RemoveEmpty), + new SimpleMidTransform(CheckInitialization), + new SimpleMidTransform(InferTypes), + new SimpleMidTransform(Uniquify), + new SimpleMidTransform(ResolveKinds), + new SimpleMidTransform(ResolveGenders)) + def run(state: CircuitState, xForms: Seq[Transform]): CircuitState = { + xForms.foldLeft(state) { case (curState: CircuitState, xForm: Transform) => + val res = xForm.execute(state) + res.annotations match { + case None => CircuitState(res.circuit, res.form, state.annotations) + case Some(ann) => CircuitState(res.circuit, res.form, Some( + AnnotationMap(ann.annotations ++ curState.annotations.get.annotations))) } } } - def execute(c: Circuit, map: AnnotationMap) = map get transID match { - case Some(p) => p get CircuitName(c.main) match { - case Some(ReplSeqMemAnnotation(t, _)) => - val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "") - val inConfigFile = { - if (inputFileName.isEmpty) None - else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) - else error("Input configuration file does not exist!") - } - val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName)) - run(c, map, passSeq(inConfigFile, outConfigFile)) - case _ => error("Unexpected transform annotation") + def execute(state: CircuitState): CircuitState = + getMyAnnotations(state) match { + case Some(p) => p get CircuitName(state.circuit.main) match { + case Some(ReplSeqMemAnnotation(t)) => + val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "") + val inConfigFile = { + if (inputFileName.isEmpty) None + else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) + else error("Input configuration file does not exist!") + } + val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName)) + run(state, passSeq(inConfigFile, outConfigFile)) + case _ => error("Unexpected transform annotation") + } + case None => state // Do nothing if there are no annotations } - case _ => TransformResult(c) - } } diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index 919948b6..59e76d65 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -11,7 +11,8 @@ import WiringUtils._ /** A component, e.g. register etc. Must be declared only once under the TopAnnotation */ -case class SourceAnnotation(target: ComponentName, tID: TransID) extends Annotation with Loose with Unstable { +case class SourceAnnotation(target: ComponentName) extends Annotation with Loose with Unstable { + def transform = classOf[WiringTransform] def duplicate(n: Named) = n match { case n: ComponentName => this.copy(target = n) case _ => throwInternalError @@ -20,7 +21,8 @@ case class SourceAnnotation(target: ComponentName, tID: TransID) extends Annotat /** A module, e.g. ExtModule etc., that should add the input pin */ -case class SinkAnnotation(target: ModuleName, tID: TransID, pin: String) extends Annotation with Loose with Unstable { +case class SinkAnnotation(target: ModuleName, pin: String) extends Annotation with Loose with Unstable { + def transform = classOf[WiringTransform] def duplicate(n: Named) = n match { case n: ModuleName => this.copy(target = n) case _ => throwInternalError @@ -30,7 +32,8 @@ case class SinkAnnotation(target: ModuleName, tID: TransID, pin: String) extends /** A module under which all sink module must be declared, and there is only * one source component */ -case class TopAnnotation(target: ModuleName, tID: TransID) extends Annotation with Loose with Unstable { +case class TopAnnotation(target: ModuleName) extends Annotation with Loose with Unstable { + def transform = classOf[WiringTransform] def duplicate(n: Named) = n match { case n: ModuleName => this.copy(target = n) case _ => throwInternalError @@ -49,13 +52,15 @@ case class TopAnnotation(target: ModuleName, tID: TransID) extends Annotation wi * Notes: * - No module uniquification occurs (due to imposed restrictions) */ -class WiringTransform(transID: TransID) extends Transform with SimpleRun { +class WiringTransform extends Transform with SimpleRun { + def inputForm = MidForm + def outputForm = MidForm def passSeq(wi: WiringInfo) = Seq(new Wiring(wi), InferTypes, ResolveKinds, ResolveGenders) - def execute(c: Circuit, map: AnnotationMap) = map get transID match { + def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { case Some(p) => val sinks = mutable.HashMap[String, String]() val sources = mutable.Set[String]() @@ -63,18 +68,20 @@ class WiringTransform(transID: TransID) extends Transform with SimpleRun { val comp = mutable.Set[String]() p.values.foreach { a => a match { - case SinkAnnotation(m, _, pin) => sinks(m.name) = pin - case SourceAnnotation(c, _) => + case SinkAnnotation(m, pin) => sinks(m.name) = pin + case SourceAnnotation(c) => sources += c.module.name comp += c.name - case TopAnnotation(m, _) => tops += m.name + case TopAnnotation(m) => tops += m.name } } (sources.size, tops.size, sinks.size, comp.size) match { - case (0, 0, p, 0) => TransformResult(c) - case (1, 1, p, 1) if p > 0 => run(c, passSeq(WiringInfo(sources.head, comp.head, sinks.toMap, tops.head))) + case (0, 0, p, 0) => state + case (1, 1, p, 1) if p > 0 => + val winfo = WiringInfo(sources.head, comp.head, sinks.toMap, tops.head) + state.copy(circuit = runPasses(state.circuit, passSeq(winfo))) case _ => error("Wrong number of sources, tops, or sinks!") } - case None => TransformResult(c) + case None => state } } |
