diff options
| author | Schuyler Eldridge | 2020-03-11 14:32:32 -0400 |
|---|---|---|
| committer | GitHub | 2020-03-11 14:32:32 -0400 |
| commit | 026c18dd76d4e2121c7f6c582d15e4d5a3ab842b (patch) | |
| tree | 0537dff3091db3da167c0fffc3388a5966c46204 /src | |
| parent | 646c91e71b8bfb1b0d0f22e81ca113147637ce71 (diff) | |
| parent | abf226471249a1cbb8de33d0c4bc8526f9aafa70 (diff) | |
Merge pull request #1123 from freechipsproject/dependency-api-2
- Use Dependency API for transform scheduling
- Add tests that old order/behavior is preserved
Or: "Now you're thinking with dependencies."
Diffstat (limited to 'src')
77 files changed, 1957 insertions, 406 deletions
diff --git a/src/main/scala/firrtl/AddDescriptionNodes.scala b/src/main/scala/firrtl/AddDescriptionNodes.scala index 2cd8b9f7..dac0e513 100644 --- a/src/main/scala/firrtl/AddDescriptionNodes.scala +++ b/src/main/scala/firrtl/AddDescriptionNodes.scala @@ -5,6 +5,7 @@ package firrtl import firrtl.ir._ import firrtl.annotations._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} case class DescriptionAnnotation(named: Named, description: String) extends Annotation { def update(renames: RenameMap): Seq[DescriptionAnnotation] = { @@ -67,9 +68,25 @@ private case class DescribedMod(description: Description, * @note should only be used by VerilogEmitter, described nodes will * break other transforms. */ -class AddDescriptionNodes extends Transform { - def inputForm = LowForm - def outputForm = LowForm +class AddDescriptionNodes extends Transform with PreservesAll[Transform] { + def inputForm = UnknownForm + def outputForm = UnknownForm + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename], + Dependency(firrtl.passes.VerilogPrep) ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty def onStmt(compMap: Map[String, Seq[String]])(stmt: Statement): Statement = { stmt.map(onStmt(compMap)) match { diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 2c5600de..3d42e9d5 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -7,11 +7,13 @@ import java.io.Writer import scala.collection.mutable + import firrtl.annotations._ import firrtl.ir.Circuit import firrtl.Utils.throwInternalError import firrtl.annotations.transforms.{EliminateTargetPaths, ResolvePaths} -import firrtl.options.{StageUtils, TransformLike} +import firrtl.options.{DependencyAPI, Dependency, PreservesAll, StageUtils, TransformLike} +import firrtl.stage.transforms.CatchCustomTransformExceptions /** Container of all annotations for a Firrtl compiler */ class AnnotationSeq private (private[firrtl] val underlying: List[Annotation]) { @@ -98,6 +100,7 @@ object CircuitState { * strictly supersets of the "lower" forms. Thus, that any transform that * operates on [[HighForm]] can also operate on [[MidForm]] or [[LowForm]] */ +@deprecated("CircuitForm will be removed in 1.3. Switch to Seq[TransformDependency] to specify dependencies.", "1.2") sealed abstract class CircuitForm(private val value: Int) extends Ordered[CircuitForm] { // Note that value is used only to allow comparisons def compare(that: CircuitForm): Int = this.value - that.value @@ -116,6 +119,7 @@ sealed abstract class CircuitForm(private val value: Int) extends Ordered[Circui * * See [[CDefMemory]] and [[CDefMPort]] */ +@deprecated("Form-based dependencies will be removed in 1.3. Please migrate to the new Dependency API.", "1.2") final case object ChirrtlForm extends CircuitForm(value = 3) { val outputSuffix: String = ".fir" } @@ -127,6 +131,7 @@ final case object ChirrtlForm extends CircuitForm(value = 3) { * * Also see [[firrtl.ir]] */ +@deprecated("Form-based dependencies will be removed in 1.3. Please migrate to the new Dependency API.", "1.2") final case object HighForm extends CircuitForm(2) { val outputSuffix: String = ".hi.fir" } @@ -138,6 +143,7 @@ final case object HighForm extends CircuitForm(2) { * - All whens must be removed * - There can only be a single connection to any element */ +@deprecated("Form-based dependencies will be removed in 1.3. Please migrate to the new Dependency API.", "1.2") final case object MidForm extends CircuitForm(1) { val outputSuffix: String = ".mid.fir" } @@ -148,6 +154,7 @@ final case object MidForm extends CircuitForm(1) { * - All aggregate types (vector/bundle) must have been removed * - All implicit truncations must be made explicit */ +@deprecated("Form-based dependencies will be removed in 1.3. Please migrate to the new Dependency API.", "1.2") final case object LowForm extends CircuitForm(0) { val outputSuffix: String = ".lo.fir" } @@ -163,6 +170,7 @@ final case object LowForm extends CircuitForm(0) { * TODO(azidar): Replace with PreviousForm, which more explicitly encodes * this requirement. */ +@deprecated("Form-based dependencies will be removed in 1.3. Please migrate to the new Dependency API.", "1.2") final case object UnknownForm extends CircuitForm(-1) { override def compare(that: CircuitForm): Int = { sys.error("Illegal to compare UnknownForm"); 0 } @@ -171,13 +179,22 @@ final case object UnknownForm extends CircuitForm(-1) { // scalastyle:on magic.number /** The basic unit of operating on a Firrtl AST */ -abstract class Transform extends TransformLike[CircuitState] { +trait Transform extends TransformLike[CircuitState] with DependencyAPI[Transform] { + /** A convenience function useful for debugging and error messages */ def name: String = this.getClass.getName /** The [[firrtl.CircuitForm]] that this transform requires to operate on */ + @deprecated( + "InputForm/OutputForm will be removed in 1.3. Use DependencyAPI methods (prerequisites, dependents, invalidates)", + "1.2") def inputForm: CircuitForm + /** The [[firrtl.CircuitForm]] that this transform outputs */ + @deprecated( + "InputForm/OutputForm will be removed in 1.3. Use DependencyAPI methods (prerequisites, dependents, invalidates)", + "1.2") def outputForm: CircuitForm + /** Perform the transform, encode renaming with RenameMap, and can * delete annotations * Called by [[runTransform]]. @@ -185,10 +202,65 @@ abstract class Transform extends TransformLike[CircuitState] { * @param state Input Firrtl AST * @return A transformed Firrtl AST */ - protected def execute(state: CircuitState): CircuitState + def execute(state: CircuitState): CircuitState def transform(state: CircuitState): CircuitState = execute(state) + import firrtl.{ChirrtlForm => C, HighForm => H, MidForm => M, LowForm => L, UnknownForm => U} + import firrtl.stage.Forms + + override def prerequisites: Seq[Dependency[Transform]] = inputForm match { + case C => Nil + case H => Forms.Deduped + case M => Forms.MidForm + case L => Forms.LowForm + case U => Nil + } + + override def optionalPrerequisites: Seq[Dependency[Transform]] = inputForm match { + case L => Forms.LowFormOptimized + case _ => Seq.empty + } + + private lazy val fullCompilerSet = new mutable.LinkedHashSet[Dependency[Transform]] ++ Forms.VerilogOptimized + + override def dependents: Seq[Dependency[Transform]] = { + val lowEmitters = Dependency[LowFirrtlEmitter] :: Dependency[VerilogEmitter] :: Dependency[MinimumVerilogEmitter] :: + Dependency[SystemVerilogEmitter] :: Nil + + val emitters = inputForm match { + case C => Dependency[ChirrtlEmitter] :: Dependency[HighFirrtlEmitter] :: Dependency[MiddleFirrtlEmitter] :: lowEmitters + case H => Dependency[HighFirrtlEmitter] :: Dependency[MiddleFirrtlEmitter] :: lowEmitters + case M => Dependency[MiddleFirrtlEmitter] :: lowEmitters + case L => lowEmitters + case U => Nil + } + + val selfDep = Dependency.fromTransform(this) + + inputForm match { + case C => (fullCompilerSet ++ emitters - selfDep).toSeq + case H => (fullCompilerSet -- Forms.Deduped ++ emitters - selfDep).toSeq + case M => (fullCompilerSet -- Forms.MidForm ++ emitters - selfDep).toSeq + case L => (fullCompilerSet -- Forms.LowFormOptimized ++ emitters - selfDep).toSeq + case U => Nil + } + } + + private lazy val highOutputInvalidates = fullCompilerSet -- Forms.MinimalHighForm + private lazy val midOutputInvalidates = fullCompilerSet -- Forms.MidForm + + override def invalidates(a: Transform): Boolean = { + (inputForm, outputForm) match { + case (U, _) | (_, U) => true // invalidate everything + case (i, o) if i >= o => false // invalidate nothing + case (_, C) => true // invalidate everything + case (_, H) => highOutputInvalidates(Dependency.fromTransform(a)) + case (_, M) => midOutputInvalidates(Dependency.fromTransform(a)) + case (_, L) => false // invalidate nothing + } + } + /** Convenience method to get annotations relevant to this Transform * * @param state The [[CircuitState]] form which to extract annotations @@ -309,7 +381,7 @@ trait ResolvedAnnotationPaths { } /** Defines old API for Emission. Deprecated */ -trait Emitter extends Transform { +trait Emitter extends Transform with PreservesAll[Transform] { @deprecated("Use emission annotations instead", "firrtl 1.0") def emit(state: CircuitState, writer: Writer): Unit @@ -324,6 +396,7 @@ object CompilerUtils extends LazyLogging { * @param outputForm [[CircuitForm]] to lower to * @return Sequence of transforms that will lower if outputForm is lower than inputForm */ + @deprecated("Use a TransformManager requesting which transforms you want to run. This will be removed in 1.3.", "1.2") def getLoweringTransforms(inputForm: CircuitForm, outputForm: CircuitForm): Seq[Transform] = { // If outputForm is equal-to or higher than inputForm, nothing to lower if (outputForm >= inputForm) { @@ -333,9 +406,8 @@ object CompilerUtils extends LazyLogging { case ChirrtlForm => Seq(new ChirrtlToHighFirrtl) ++ getLoweringTransforms(HighForm, outputForm) case HighForm => - Seq(new IRToWorkingIR, new ResolveAndCheck, - new transforms.DedupModules, new HighFirrtlToMiddleFirrtl) ++ - getLoweringTransforms(MidForm, outputForm) + Seq(new IRToWorkingIR, new ResolveAndCheck, new transforms.DedupModules, new HighFirrtlToMiddleFirrtl) ++ + getLoweringTransforms(MidForm, outputForm) case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm) case LowForm => throwInternalError("getLoweringTransforms - LowForm") // should be caught by if above case UnknownForm => throwInternalError("getLoweringTransforms - UnknownForm") // should be caught by if above @@ -374,6 +446,7 @@ object CompilerUtils extends LazyLogging { * inputForm of a latter transforms is equal to or lower than the outputForm * of the previous transform. */ + @deprecated("Use a TransformManager with custom targets. This will be removed in 1.3.", "1.2") def mergeTransforms(lowering: Seq[Transform], custom: Seq[Transform]): Seq[Transform] = { custom .sortWith{ @@ -392,6 +465,7 @@ object CompilerUtils extends LazyLogging { } +@deprecated("Use a TransformManager requesting which transforms you want to run. This will be removed in 1.3.", "1.2") trait Compiler extends LazyLogging { def emitter: Emitter @@ -455,15 +529,6 @@ trait Compiler extends LazyLogging { compile(state.copy(annotations = emitAnno +: state.annotations), emitter +: customTransforms) } - private def isCustomTransform(xform: Transform): Boolean = { - def getTopPackage(pack: java.lang.Package): java.lang.Package = - Package.getPackage(pack.getName.split('.').head) - // We use the top package of the Driver to get the top firrtl package - Option(xform.getClass.getPackage).map { p => - getTopPackage(p) != firrtl.Driver.getClass.getPackage - }.getOrElse(true) - } - /** Perform compilation * * Emission will only be performed if [[EmitAnnotation]]s are present @@ -482,7 +547,8 @@ trait Compiler extends LazyLogging { xform.runTransform(in) } catch { // Wrap exceptions from custom transforms so they are reported as such - case e: Exception if isCustomTransform(xform) => throw CustomTransformException(e) + case e: Exception if CatchCustomTransformExceptions.isCustomTransform(xform) => + throw CustomTransformException(e) } } } diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index babcd406..d9d35ebb 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -12,7 +12,7 @@ import firrtl.transforms._ import firrtl.Utils.throwInternalError import firrtl.stage.{FirrtlExecutionResultView, FirrtlStage} import firrtl.stage.phases.DriverCompatibility -import firrtl.options.{StageUtils, Phase, Viewer} +import firrtl.options.{Dependency, Phase, PhaseManager, StageUtils, Viewer} import firrtl.options.phases.DeletedWrapper @@ -210,13 +210,17 @@ object Driver { val annos = optionsManager.firrtlOptions.toAnnotations ++ optionsManager.commonOptions.toAnnotations - val phases: Seq[Phase] = - Seq( new DriverCompatibility.AddImplicitAnnotationFile, - new DriverCompatibility.AddImplicitFirrtlFile, - new DriverCompatibility.AddImplicitOutputFile, - new DriverCompatibility.AddImplicitEmitter, - new FirrtlStage ) + val phases: Seq[Phase] = { + import DriverCompatibility._ + new PhaseManager( + Seq( Dependency[AddImplicitFirrtlFile], + Dependency[AddImplicitAnnotationFile], + Dependency[AddImplicitOutputFile], + Dependency[AddImplicitEmitter], + Dependency[FirrtlStage] )) + .transformOrder .map(DeletedWrapper(_)) + } val annosx = try { phases.foldLeft(annos)( (a, p) => p.transform(a) ) @@ -257,4 +261,3 @@ object Driver { execute(args) } } - diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 12ef17c2..36734d81 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -16,7 +16,7 @@ import firrtl.WrappedExpression._ import Utils._ import MemPortUtils.{memPortField, memType} import firrtl.options.{HasShellOptions, ShellOption, StageUtils, PhaseException, Unserializable} -import firrtl.stage.RunFirrtlTransformAnnotation +import firrtl.stage.{RunFirrtlTransformAnnotation, TransformManager} // Datastructures import scala.collection.mutable.ArrayBuffer @@ -180,6 +180,11 @@ case class VRandom(width: BigInt) extends Expression { class VerilogEmitter extends SeqTransform with Emitter { def inputForm = LowForm def outputForm = LowForm + + override val prerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty + val outputSuffix = ".v" val tab = " " def AND(e1: WrappedExpression, e2: WrappedExpression): Expression = { @@ -973,19 +978,7 @@ class VerilogEmitter extends SeqTransform with Emitter { } /** Preamble for every emitted Verilog file */ - def transforms = Seq( - new BlackBoxSourceHelper, - new FixAddingNegativeLiterals, - new ReplaceTruncatingArithmetic, - new InlineBitExtractionsTransform, - new InlineCastsTransform, - new LegalizeClocksTransform, - new FlattenRegUpdate, - new DeadCodeElimination, - passes.VerilogModulusCleanup, - new VerilogRename, - passes.VerilogPrep, - new AddDescriptionNodes) + def transforms = new TransformManager(firrtl.stage.Forms.VerilogOptimized, prerequisites).flattenedTransformOrder def emit(state: CircuitState, writer: Writer): Unit = { val circuit = runTransforms(state).circuit @@ -1033,16 +1026,18 @@ class VerilogEmitter extends SeqTransform with Emitter { class MinimumVerilogEmitter extends VerilogEmitter with Emitter { + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized - override def transforms = super.transforms.filter{ - case _: DeadCodeElimination => false - case _ => true - } + override def transforms = new TransformManager(firrtl.stage.Forms.VerilogMinimumOptimized, prerequisites) + .flattenedTransformOrder } class SystemVerilogEmitter extends VerilogEmitter { - StageUtils.dramaticWarning("SystemVerilog Emitter is the same as the Verilog Emitter!") - override val outputSuffix: String = ".sv" + + override def execute(state: CircuitState): CircuitState = { + StageUtils.dramaticWarning("SystemVerilog Emitter is the same as the Verilog Emitter!") + super.execute(state) + } } diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 14a8e637..b3d7d087 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -4,53 +4,39 @@ package firrtl import firrtl.transforms.IdentityTransform import firrtl.options.StageUtils +import firrtl.stage.{Forms, TransformManager} +@deprecated("Use a TransformManager or some other Stage/Phase class. Will be removed in 1.3.", "1.2") sealed abstract class CoreTransform extends SeqTransform /** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting * circuit has only IR nodes, not WIR. */ +@deprecated("Use a TransformManager to handle lowering. Will be removed in 1.3.", "1.2") class ChirrtlToHighFirrtl extends CoreTransform { def inputForm = ChirrtlForm def outputForm = HighForm - def transforms = Seq( - passes.CheckChirrtl, - passes.CInferTypes, - passes.CInferMDir, - passes.RemoveCHIRRTL) + def transforms = new TransformManager(Forms.MinimalHighForm, Forms.ChirrtlForm).flattenedTransformOrder } /** Converts from the bare intermediate representation (ir.scala) * to a working representation (WIR.scala) */ +@deprecated("Use a TransformManager to handle lowering. Will be removed in 1.3.", "1.2") class IRToWorkingIR extends CoreTransform { def inputForm = HighForm def outputForm = HighForm - def transforms = Seq(passes.ToWorkingIR) + def transforms = new TransformManager(Forms.WorkingIR, Forms.MinimalHighForm).flattenedTransformOrder } /** Resolves types, kinds, and flows, and checks the circuit legality. * Operates on working IR nodes and high Firrtl. */ +@deprecated("Use a TransformManager to handle lowering. Will be removed in 1.3.", "1.2") class ResolveAndCheck extends CoreTransform { def inputForm = HighForm def outputForm = HighForm - def transforms = Seq( - passes.CheckHighForm, - passes.ResolveKinds, - passes.InferTypes, - passes.CheckTypes, - passes.Uniquify, - passes.ResolveKinds, - passes.InferTypes, - passes.ResolveFlows, - passes.CheckFlows, - new passes.InferBinaryPoints(), - new passes.TrimIntervals(), - new passes.InferWidths, - passes.CheckWidths, - new firrtl.transforms.InferResets, - passes.CheckTypes) + def transforms = new TransformManager(Forms.Resolved, Forms.WorkingIR).flattenedTransformOrder } /** Expands aggregate connects, removes dynamic accesses, and when @@ -58,78 +44,40 @@ class ResolveAndCheck extends CoreTransform { * well-formed graph. * Operates on working IR nodes. */ +@deprecated("Use a TransformManager to handle lowering. Will be removed in 1.3.", "1.2") class HighFirrtlToMiddleFirrtl extends CoreTransform { def inputForm = HighForm def outputForm = MidForm - def transforms = Seq( - passes.PullMuxes, - passes.ReplaceAccesses, - passes.ExpandConnects, - passes.RemoveAccesses, - passes.Uniquify, - passes.ExpandWhens, - passes.CheckInitialization, - passes.ResolveKinds, - passes.InferTypes, - passes.CheckTypes, - passes.ResolveFlows, - new passes.InferWidths, - passes.CheckWidths, - new passes.RemoveIntervals(), - passes.ConvertFixedToSInt, - passes.ZeroWidth, - passes.InferTypes) + def transforms = new TransformManager(Forms.MidForm, Forms.Deduped).flattenedTransformOrder } /** Expands all aggregate types into many ground-typed components. Must * accept a well-formed graph of only middle Firrtl features. * Operates on working IR nodes. */ +@deprecated("Use a TransformManager to handle lowering. Will be removed in 1.3.", "1.2") class MiddleFirrtlToLowFirrtl extends CoreTransform { def inputForm = MidForm def outputForm = LowForm - def transforms = Seq( - passes.LowerTypes, - passes.ResolveKinds, - passes.InferTypes, - passes.ResolveFlows, - new passes.InferWidths, - passes.Legalize, - new firrtl.transforms.RemoveReset, - new firrtl.transforms.CheckCombLoops, - new checks.CheckResets, - new firrtl.transforms.RemoveWires) + def transforms = new TransformManager(Forms.LowForm, Forms.MidForm).flattenedTransformOrder } /** Runs a series of optimization passes on LowFirrtl * @note This is currently required for correct Verilog emission * TODO Fix the above note */ +@deprecated("Use a TransformManager to handle lowering. Will be removed in 1.3.", "1.2") class LowFirrtlOptimization extends CoreTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq( - passes.RemoveValidIf, - new firrtl.transforms.ConstantPropagation, - passes.PadWidths, - new firrtl.transforms.ConstantPropagation, - passes.Legalize, - passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter - new firrtl.transforms.ConstantPropagation, - passes.SplitExpressions, - new firrtl.transforms.CombineCats, - passes.CommonSubexpressionElimination, - new firrtl.transforms.DeadCodeElimination) + def transforms = new TransformManager(Forms.LowFormOptimized, Forms.LowForm).flattenedTransformOrder } /** Runs runs only the optimization passes needed for Verilog emission */ +@deprecated("Use a TransformManager to handle lowering. Will be removed in 1.3.", "1.2") class MinimumLowFirrtlOptimization extends CoreTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq( - passes.RemoveValidIf, - passes.Legalize, - passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter - passes.SplitExpressions) + def transforms = new TransformManager(Forms.LowFormMinimumOptimized, Forms.LowForm).flattenedTransformOrder } diff --git a/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala b/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala index 5a729225..e741a32b 100644 --- a/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala +++ b/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala @@ -122,7 +122,7 @@ class EliminateTargetPaths extends Transform { (cir.copy(modules = finalModuleList), renameMap) } - override protected def execute(state: CircuitState): CircuitState = { + override def execute(state: CircuitState): CircuitState = { val (annotations, annotationsx) = state.annotations.partition{ case a: ResolvePaths => true diff --git a/src/main/scala/firrtl/checks/CheckResets.scala b/src/main/scala/firrtl/checks/CheckResets.scala index d6337f9e..406b7f62 100644 --- a/src/main/scala/firrtl/checks/CheckResets.scala +++ b/src/main/scala/firrtl/checks/CheckResets.scala @@ -3,6 +3,7 @@ package firrtl.checks import firrtl._ +import firrtl.options.{Dependency, PreservesAll} import firrtl.passes.{Errors, PassException} import firrtl.ir._ import firrtl.traversals.Foreachers._ @@ -25,10 +26,19 @@ object CheckResets { // Must run after ExpandWhens // Requires // - static single connections of ground types -class CheckResets extends Transform { +class CheckResets extends Transform with PreservesAll[Transform] { def inputForm: CircuitForm = MidForm def outputForm: CircuitForm = MidForm + override val prerequisites = + Seq( Dependency(passes.LowerTypes), + Dependency(passes.Legalize), + Dependency(firrtl.transforms.RemoveReset) ) ++ firrtl.stage.Forms.MidForm + + override val optionalPrerequisites = Seq(Dependency[firrtl.transforms.CheckCombLoops]) + + override val dependents = Seq.empty + import CheckResets._ private def onStmt(regCheck: RegCheckList, drivers: DirectDriverMap)(stmt: Statement): Unit = { @@ -72,4 +82,3 @@ class CheckResets extends Transform { state } } - diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index 08237ab2..08c127da 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -2,8 +2,16 @@ package firrtl.passes +import firrtl.Transform import firrtl.ir._ +import firrtl.options.{Dependency, PreservesAll} + +object CheckChirrtl extends Pass with CheckHighFormLike with PreservesAll[Transform] { + + override val dependents = firrtl.stage.Forms.ChirrtlForm ++ + Seq( Dependency(CInferTypes), + Dependency(CInferMDir), + Dependency(RemoveCHIRRTL) ) -object CheckChirrtl extends Pass with CheckHighFormLike { def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] = None } diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala index 9fbf3eeb..63790564 100644 --- a/src/main/scala/firrtl/passes/CheckInitialization.scala +++ b/src/main/scala/firrtl/passes/CheckInitialization.scala @@ -6,6 +6,7 @@ import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.traversals.Foreachers._ +import firrtl.options.PreservesAll import annotation.tailrec @@ -14,7 +15,10 @@ import annotation.tailrec * @note This pass looks for [[firrtl.WVoid]]s left behind by [[ExpandWhens]] * @note Assumes single connection (ie. no last connect semantics) */ -object CheckInitialization extends Pass { +object CheckInitialization extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.Resolved + private case class VoidExpr(stmt: Statement, voidDeps: Seq[Expression]) class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) extends PassException( diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 6ceac032..b750196a 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -9,8 +9,14 @@ import firrtl.traversals.Foreachers._ import firrtl.Utils._ import firrtl.constraint.IsKnown import firrtl.annotations.{CircuitTarget, ModuleTarget, Target, TargetToken} +import firrtl.options.{Dependency, PreservesAll} + +object CheckWidths extends Pass with PreservesAll[Transform] { + + override val prerequisites = Dependency[passes.InferWidths] +: firrtl.stage.Forms.WorkingIR + + override val dependents = Seq(Dependency[transforms.InferResets]) -object CheckWidths extends Pass { /** The maximum allowed width for any circuit element */ val MaxWidth = 1000000 val DshlMaxWidth = getUIntWidth(MaxWidth) diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index a5f66a55..e176bcc4 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -9,8 +9,9 @@ import firrtl.Utils._ import firrtl.traversals.Foreachers._ import firrtl.WrappedType._ import firrtl.constraint.{Constraint, IsKnown} +import firrtl.options.{Dependency, PreservesAll} -trait CheckHighFormLike { +trait CheckHighFormLike { this: Pass => type NameSet = collection.mutable.HashSet[String] // Custom Exceptions @@ -267,7 +268,18 @@ trait CheckHighFormLike { } } -object CheckHighForm extends Pass with CheckHighFormLike { +object CheckHighForm extends Pass with CheckHighFormLike with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.WorkingIR + + override val dependents = + Seq( Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.Uniquify), + Dependency(passes.ResolveFlows), + Dependency[passes.InferWidths], + Dependency[transforms.InferResets] ) + class IllegalChirrtlMemException(info: Info, mname: String, name: String) extends PassException( s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.") @@ -279,7 +291,17 @@ object CheckHighForm extends Pass with CheckHighFormLike { Some(new IllegalChirrtlMemException(info, mname, memName)) } } -object CheckTypes extends Pass { + +object CheckTypes extends Pass with PreservesAll[Transform] { + + override val prerequisites = Dependency(InferTypes) +: firrtl.stage.Forms.WorkingIR + + override val dependents = + Seq( Dependency(passes.Uniquify), + Dependency(passes.ResolveFlows), + Dependency(passes.CheckFlows), + Dependency[passes.InferWidths], + Dependency(passes.CheckWidths) ) // Custom Exceptions class SubfieldNotInBundle(info: Info, mname: String, name: String) extends PassException( @@ -583,7 +605,16 @@ object CheckTypes extends Pass { } } -object CheckFlows extends Pass { +object CheckFlows extends Pass with PreservesAll[Transform] { + + override val prerequisites = Dependency(passes.ResolveFlows) +: firrtl.stage.Forms.WorkingIR + + override val dependents = + Seq( Dependency[passes.InferBinaryPoints], + Dependency[passes.TrimIntervals], + Dependency[passes.InferWidths], + Dependency[transforms.InferResets] ) + type FlowMap = collection.mutable.HashMap[String, Flow] implicit def toStr(g: Flow): String = g match { diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index 3ba12b2d..d54d8088 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -5,9 +5,21 @@ package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} +object CommonSubexpressionElimination extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.LowForm ++ + Seq( Dependency(firrtl.passes.RemoveValidIf), + Dependency[firrtl.transforms.ConstantPropagation], + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[firrtl.transforms.CombineCats] ) + + override val dependents = + Seq( Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] ) -object CommonSubexpressionElimination extends Pass { private def cse(s: Statement): Statement = { val expressions = collection.mutable.HashMap[MemoizedHash[Expression], String]() val nodes = collection.mutable.HashMap[String, Expression]() diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 05a000c5..7e65bdd1 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -8,10 +8,20 @@ import firrtl.ir._ import firrtl._ import firrtl.Mappers._ import firrtl.Utils.{sub_type, module_type, field_type, max, throwInternalError} +import firrtl.options.{Dependency, PreservesAll} /** Replaces FixedType with SIntType, and correctly aligns all binary points */ -object ConvertFixedToSInt extends Pass { +object ConvertFixedToSInt extends Pass with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency[ExpandWhensAndCheck], + Dependency[RemoveIntervals] ) ++ firrtl.stage.Forms.Deduped + def alignArg(e: Expression, point: BigInt): Expression = e.tpe match { case FixedType(IntWidth(w), IntWidth(p)) => // assert(point >= p) if((point - p) > 0) { @@ -83,29 +93,29 @@ object ConvertFixedToSInt extends Pass { types(name) = newType newStmt case WDefInstance(info, name, module, tpe) => - val newType = moduleTypes(module) + val newType = moduleTypes(module) types(name) = newType WDefInstance(info, name, module, newType) - case Connect(info, loc, exp) => + case Connect(info, loc, exp) => val point = calcPoint(Seq(loc)) val newExp = alignArg(exp, point) Connect(info, loc, newExp) map updateExpType - case PartialConnect(info, loc, exp) => + case PartialConnect(info, loc, exp) => val point = calcPoint(Seq(loc)) val newExp = alignArg(exp, point) PartialConnect(info, loc, newExp) map updateExpType // check Connect case, need to shl case s => (s map updateStmtType) map updateExpType } - + m.ports.foreach(p => types(p.name) = p.tpe) m match { case Module(info, name, ports, body) => Module(info,name,ports,updateStmtType(body)) case m:ExtModule => m } } - - val newModules = for(m <- c.modules) yield { + + val newModules = for(m <- c.modules) yield { val newPorts = m.ports.map(p => Port(p.info,p.name,p.direction,toSIntType(p.tpe))) m match { case Module(info, name, ports, body) => Module(info,name,newPorts,body) @@ -113,8 +123,13 @@ object ConvertFixedToSInt extends Pass { } } newModules.foreach(m => moduleTypes(m.name) = module_type(m)) - firrtl.passes.InferTypes.run(Circuit(c.info, newModules.map(onModule(_)), c.main )) + + /* @todo This should be moved outside */ + (firrtl.passes.InferTypes).run(Circuit(c.info, newModules.map(onModule(_)), c.main )) } } + + + // vim: set ts=4 sw=4 et: diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 12aa9690..aaf3d9b4 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -8,6 +8,7 @@ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.PrimOps._ import firrtl.WrappedExpression._ +import firrtl.options.Dependency import annotation.tailrec import collection.mutable @@ -24,6 +25,19 @@ import collection.mutable * @note Assumes all references are declared */ object ExpandWhens extends Pass { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency(Uniquify) ) ++ firrtl.stage.Forms.Resolved + + override def invalidates(a: Transform): Boolean = a match { + case CheckInitialization | ResolveKinds | InferTypes => true + case _ => false + } + /** Returns circuit with when and last connection semantics resolved */ def run(c: Circuit): Circuit = { val modulesx = c.modules map { @@ -287,3 +301,24 @@ object ExpandWhens extends Pass { private def NOT(e: Expression) = DoPrim(Eq, Seq(e, zero), Nil, BoolType) } + +class ExpandWhensAndCheck extends SeqTransform { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency(Uniquify) ) ++ firrtl.stage.Forms.Deduped + + override def invalidates(a: Transform): Boolean = a match { + case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true + case _ => false + } + + override def inputForm = UnknownForm + override def outputForm = UnknownForm + + override val transforms = Seq(ExpandWhens, CheckInitialization) + +} diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala index 258c9697..86bc36fc 100644 --- a/src/main/scala/firrtl/passes/InferBinaryPoints.scala +++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala @@ -7,8 +7,19 @@ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target} import firrtl.constraint.ConstraintSolver +import firrtl.Transform +import firrtl.options.{Dependency, PreservesAll} + +class InferBinaryPoints extends Pass with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(ResolveKinds), + Dependency(InferTypes), + Dependency(Uniquify), + Dependency(ResolveFlows) ) + + override val dependents = Seq.empty -class InferBinaryPoints extends Pass { private val constraintSolver = new ConstraintSolver() private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match { @@ -71,14 +82,14 @@ class InferBinaryPoints extends Pass { case _ => sys.error("Shouldn't be here") } private def fixType(t: Type): Type = t map fixType map fixWidth match { - case IntervalType(l, u, p) => + case IntervalType(l, u, p) => val px = constraintSolver.get(p) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) case None => p case _ => sys.error("Shouldn't be here") } IntervalType(l, u, px) - case FixedType(w, p) => + case FixedType(w, p) => val px = constraintSolver.get(p) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) case None => p diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 3c5cf7fb..d625b626 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -6,8 +6,12 @@ import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} + +object InferTypes extends Pass with PreservesAll[Transform] { + + override val prerequisites = Dependency(ResolveKinds) +: firrtl.stage.Forms.WorkingIR -object InferTypes extends Pass { type TypeMap = collection.mutable.LinkedHashMap[String, Type] def run(c: Circuit): Circuit = { @@ -79,12 +83,15 @@ object InferTypes extends Pass { val types = new TypeMap m map infer_types_p(types) map infer_types_s(types) } - + c copy (modules = c.modules map infer_types) } } -object CInferTypes extends Pass { +object CInferTypes extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.ChirrtlForm + type TypeMap = collection.mutable.LinkedHashMap[String, Type] def run(c: Circuit): Circuit = { @@ -133,12 +140,12 @@ object CInferTypes extends Pass { types(p.name) = p.tpe p } - + def infer_types(m: DefModule): DefModule = { val types = new TypeMap m map infer_types_p(types) map infer_types_s(types) } - + c copy (modules = c.modules map infer_types) } } diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 2211d238..29936ca0 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -11,6 +11,8 @@ import firrtl.Mappers._ import firrtl.Implicits.width2constraint import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target} import firrtl.constraint.{ConstraintSolver, IsMax} +import firrtl.options.{Dependency, PreservesAll} +import firrtl.traversals.Foreachers._ object InferWidths { def apply(): InferWidths = new InferWidths() @@ -60,7 +62,16 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg * * Uses firrtl.constraint package to infer widths */ -class InferWidths extends Transform with ResolvedAnnotationPaths { +class InferWidths extends Transform with ResolvedAnnotationPaths with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.Uniquify), + Dependency(passes.ResolveFlows), + Dependency[passes.InferBinaryPoints], + Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR + def inputForm: CircuitForm = UnknownForm def outputForm: CircuitForm = UnknownForm @@ -108,12 +119,12 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { val n = get_size(c.loc.tpe) val locs = create_exps(c.loc) val exps = create_exps(c.expr) - (locs zip exps).foreach { case (loc, exp) => - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } - } + (locs zip exps).foreach { case (loc, exp) => + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } + } c case pc: PartialConnect => val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default) @@ -142,8 +153,8 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { } a case c: Conditionally => - addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1))) - c map addStmtConstraints(mt) + addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1))) + c map addStmtConstraints(mt) case x => x map addStmtConstraints(mt) } private def fixWidth(w: Width): Width = constraintSolver.get(w) match { @@ -152,7 +163,7 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { case _ => sys.error("Shouldn't be here") } private def fixType(t: Type): Type = t map fixType map fixWidth match { - case IntervalType(l, u, p) => + case IntervalType(l, u, p) => val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match { case (Some(x: Bound), Some(y: Bound)) => (x, y) case (None, None) => (l, u) @@ -174,8 +185,8 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name))) constraintSolver.solve() val ret = InferTypes.run(c.copy(modules = c.modules map (_ - map fixPort - map fixStmt))) + map fixPort + map fixStmt))) constraintSolver.clear() ret } @@ -212,11 +223,11 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal => val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target => val baseType = typeMap.getOrElse(target.copy(component = Seq.empty), - throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint())) + throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint())) val leafType = target.componentType(baseType) if (leafType.isInstanceOf[AggregateType]) { throw new Exception(s"Target below is an AggregateType, which " + - "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) + "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) } leafType diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index f52e1e6b..73ef8a22 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -26,6 +26,15 @@ object LowerTypes extends Transform { def inputForm = UnknownForm def outputForm = UnknownForm + override val prerequisites = firrtl.stage.Forms.MidForm + + override val dependents = Seq.empty + + override def invalidates(a: Transform): Boolean = a match { + case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true + case _ => false + } + /** Delimiter used in lowering names */ val delim = "_" /** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index cbd8250a..0b318511 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -6,9 +6,30 @@ package passes import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Mappers._ +import firrtl.options.Dependency + +import scala.collection.mutable // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { + + override val prerequisites = + ((new mutable.LinkedHashSet()) + ++ firrtl.stage.Forms.LowForm + - Dependency(firrtl.passes.Legalize) + + Dependency(firrtl.passes.RemoveValidIf) + + Dependency[firrtl.transforms.ConstantPropagation]).toSeq + + override val dependents = + Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] ) + + override def invalidates(a: Transform): Boolean = a match { + case _: firrtl.transforms.ConstantPropagation | Legalize => true + case _ => false + } + private def width(t: Type): Int = bitWidth(t).toInt private def width(e: Expression): Int = width(e.tpe) // Returns an expression with the correct integer width diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 9a644dc8..a8d37758 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -7,6 +7,7 @@ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.PrimOps._ +import firrtl.options.{Dependency, PreservesAll} import firrtl.transforms.ConstantPropagation import scala.collection.mutable @@ -46,7 +47,10 @@ class Errors { } // These should be distributed into separate files -object ToWorkingIR extends Pass { +object ToWorkingIR extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.MinimalHighForm + def toExp(e: Expression): Expression = e map toExp match { case ex: Reference => WRef(ex.name, ex.tpe, UnknownKind, UnknownFlow) case ex: SubField => WSubField(ex.expr, ex.name, ex.tpe, UnknownFlow) @@ -64,8 +68,11 @@ object ToWorkingIR extends Pass { c copy (modules = c.modules map (_ map toStmt)) } -object PullMuxes extends Pass { - def run(c: Circuit): Circuit = { +object PullMuxes extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.Deduped + + def run(c: Circuit): Circuit = { def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match { case ex: WSubField => ex.expr match { case exx: Mux => Mux(exx.cond, @@ -102,7 +109,12 @@ object PullMuxes extends Pass { } } -object ExpandConnects extends Pass { +object ExpandConnects extends Pass with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped + def run(c: Circuit): Circuit = { def expand_connects(m: Module): Module = { val flows = collection.mutable.LinkedHashMap[String,Flow]() @@ -179,7 +191,14 @@ object ExpandConnects extends Pass { // Replace shr by amount >= arg width with 0 for UInts and MSB for SInts // TODO replace UInt with zero-width wire instead -object Legalize extends Pass { +object Legalize extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes) + + override val optionalPrerequisites = Seq.empty + + override val dependents = Seq.empty + private def legalizeShiftRight(e: DoPrim): Expression = { require(e.op == Shr) e.args.head match { @@ -260,7 +279,22 @@ object Legalize extends Pass { * * @note The result of this pass is NOT legal Firrtl */ -object VerilogPrep extends Pass { +object VerilogPrep extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename] ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty type AttachSourceMap = Map[WrappedExpression, Expression] diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index 1c2dc096..ac5d8a4e 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -2,18 +2,30 @@ package firrtl.passes -import firrtl.{WRef, WSubAccess, WSubIndex, WSubField, Namespace} +import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubIndex, WSubField} import firrtl.PrimOps.{And, Eq} import firrtl.ir._ import firrtl.Mappers._ import firrtl.Utils._ import firrtl.WrappedExpression._ -import scala.collection.mutable +import firrtl.options.Dependency +import scala.collection.mutable /** Removes all [[firrtl.WSubAccess]] from circuit */ -class RemoveAccesses extends Pass { +object RemoveAccesses extends Pass { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects) ) ++ firrtl.stage.Forms.Deduped + + override def invalidates(a: Transform): Boolean = a match { + case Uniquify => true + case _ => false + } + private def AND(e1: Expression, e2: Expression) = if(e1 == one) e2 else if(e2 == one) e1 @@ -166,14 +178,3 @@ class RemoveAccesses extends Pass { }) } } - -object RemoveAccesses extends Pass { - def apply: Pass = { - new RemoveAccesses() - } - - def run(c: Circuit): Circuit = { - val t = new RemoveAccesses - t.run(c) - } -} diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index 921ec3c7..05dd8bd9 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -8,12 +8,18 @@ import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} case class MPort(name: String, clk: Expression) case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort]) case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean) -object RemoveCHIRRTL extends Transform { +object RemoveCHIRRTL extends Transform with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.ChirrtlForm ++ + Seq( Dependency(passes.CInferTypes), + Dependency(passes.CInferMDir) ) + def inputForm: CircuitForm = UnknownForm def outputForm: CircuitForm = UnknownForm val ut = UnknownType diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala index 73f59b59..cf3d2ff2 100644 --- a/src/main/scala/firrtl/passes/RemoveIntervals.scala +++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala @@ -8,6 +8,7 @@ import firrtl._ import firrtl.Mappers._ import Implicits.{bigint2WInt} import firrtl.constraint.IsKnown +import firrtl.options.{Dependency, PreservesAll} import scala.math.BigDecimal.RoundingMode._ @@ -35,7 +36,14 @@ class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim) * c. replace with SIntType * 3) Run InferTypes */ -class RemoveIntervals extends Pass { +class RemoveIntervals extends Pass with PreservesAll[Transform] { + + override val prerequisites: Seq[Dependency[Transform]] = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency[ExpandWhensAndCheck] ) ++ firrtl.stage.Forms.Deduped def run(c: Circuit): Circuit = { val alignedCircuit = c diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 42eae7e5..3b5499ac 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -2,9 +2,11 @@ package firrtl package passes + import firrtl.Mappers._ import firrtl.ir._ import Utils.throwInternalError +import firrtl.options.Dependency /** Remove [[firrtl.ir.ValidIf ValidIf]] and replace [[firrtl.ir.IsInvalid IsInvalid]] with a connection to zero */ object RemoveValidIf extends Pass { @@ -27,6 +29,17 @@ object RemoveValidIf extends Pass { case other => throwInternalError(s"Unexpected type $other") } + override val prerequisites = firrtl.stage.Forms.LowForm + + override val dependents = + Seq( Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] ) + + override def invalidates(a: Transform): Boolean = a match { + case Legalize | _: firrtl.transforms.ConstantPropagation => true + case _ => false + } + // Recursive. Removes ValidIfs private def onExp(e: Expression): Expression = { e map onExp match { diff --git a/src/main/scala/firrtl/passes/ReplaceAccesses.scala b/src/main/scala/firrtl/passes/ReplaceAccesses.scala index 2ec035f3..75cca77a 100644 --- a/src/main/scala/firrtl/passes/ReplaceAccesses.scala +++ b/src/main/scala/firrtl/passes/ReplaceAccesses.scala @@ -2,22 +2,26 @@ package firrtl.passes +import firrtl.Transform import firrtl.ir._ import firrtl.{WSubAccess, WSubIndex} import firrtl.Mappers._ - +import firrtl.options.{Dependency, PreservesAll} /** Replaces constant [[firrtl.WSubAccess]] with [[firrtl.WSubIndex]] * TODO Fold in to High Firrtl Const Prop */ -object ReplaceAccesses extends Pass { +object ReplaceAccesses extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.Deduped :+ Dependency(PullMuxes) + def run(c: Circuit): Circuit = { def onStmt(s: Statement): Statement = s map onStmt map onExp def onExp(e: Expression): Expression = e match { case WSubAccess(ex, UIntLiteral(value, width), t, g) => WSubIndex(onExp(ex), value.toInt, t, g) case _ => e map onExp } - + c copy (modules = c.modules map (_ map onStmt)) } } diff --git a/src/main/scala/firrtl/passes/Resolves.scala b/src/main/scala/firrtl/passes/Resolves.scala index 97cc4bb3..15750b76 100644 --- a/src/main/scala/firrtl/passes/Resolves.scala +++ b/src/main/scala/firrtl/passes/Resolves.scala @@ -5,9 +5,14 @@ package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} import Utils.throwInternalError -object ResolveKinds extends Pass { + +object ResolveKinds extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.WorkingIR + type KindMap = collection.mutable.LinkedHashMap[String, Kind] def find_port(kinds: KindMap)(p: Port): Port = { @@ -45,7 +50,13 @@ object ResolveKinds extends Pass { c copy (modules = c.modules map resolve_kinds) } -object ResolveFlows extends Pass { +object ResolveFlows extends Pass with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.Uniquify) ) ++ firrtl.stage.Forms.WorkingIR + def resolve_e(g: Flow)(e: Expression): Expression = e match { case ex: WRef => ex copy (flow = g) case WSubField(exp, name, tpe, _) => WSubField( @@ -88,7 +99,10 @@ object ResolveGenders extends Pass { } -object CInferMDir extends Pass { +object CInferMDir extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.ChirrtlForm :+ Dependency(CInferTypes) + type MPortDirMap = collection.mutable.LinkedHashMap[String, MPortDir] def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = e match { diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index de955c9a..43d0ed34 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -3,7 +3,9 @@ package firrtl package passes +import firrtl.{SystemVerilogEmitter, VerilogEmitter} import firrtl.ir._ +import firrtl.options.{Dependency, PreservesAll} import firrtl.Mappers._ import firrtl.Utils.{kind, flow, get_info} @@ -12,7 +14,16 @@ import scala.collection.mutable // Splits compound expressions into simple expressions // and named intermediate nodes -object SplitExpressions extends Pass { +object SplitExpressions extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.LowForm ++ + Seq( Dependency(firrtl.passes.RemoveValidIf), + Dependency(firrtl.passes.memlib.VerilogMemDelays) ) + + override val dependents = + Seq( Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] ) + private def onModule(m: Module): Module = { val namespace = Namespace(m) def onStmt(s: Statement): Statement = { diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala index f659815e..4e558e2a 100644 --- a/src/main/scala/firrtl/passes/TrimIntervals.scala +++ b/src/main/scala/firrtl/passes/TrimIntervals.scala @@ -6,6 +6,8 @@ import firrtl.PrimOps._ import firrtl.ir._ import firrtl.Mappers._ import firrtl.constraint.{IsFloor, IsKnown, IsMul} +import firrtl.options.{Dependency, PreservesAll} +import firrtl.Transform /** Replaces IntervalType with SIntType, three AST walks: * 1) Align binary points @@ -18,7 +20,17 @@ import firrtl.constraint.{IsFloor, IsKnown, IsMul} * c. replace with SIntType * 3) Run InferTypes */ -class TrimIntervals extends Pass { +class TrimIntervals extends Pass with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(ResolveKinds), + Dependency(InferTypes), + Dependency(Uniquify), + Dependency(ResolveFlows), + Dependency[InferBinaryPoints] ) + + override val dependents = Seq.empty + def run(c: Circuit): Circuit = { // Open -> closed val firstPass = InferTypes.run(c map replaceModuleInterval) @@ -80,7 +92,7 @@ class TrimIntervals extends Pass { val shiftMul = Closed(BigDecimal(1) / shiftGain) val bpGain = BigDecimal(BigInt(1) << current.toInt) // BP is inferred at this point - // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt) + // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt) val newBPRes = Closed(shiftGain / bpGain) val bpResInv = Closed(bpGain) val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), newBPRes) diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 978ccc66..1268cac2 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -8,8 +8,9 @@ import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ -import MemPortUtils.memType +import firrtl.options.Dependency +import MemPortUtils.memType /** Resolve name collisions that would occur in [[LowerTypes]] * @@ -32,6 +33,16 @@ import MemPortUtils.memType * to rename a */ object Uniquify extends Transform { + + override val prerequisites = + Seq( Dependency(ResolveKinds), + Dependency(InferTypes) ) ++ firrtl.stage.Forms.WorkingIR + + override def invalidates(a: Transform): Boolean = a match { + case ResolveKinds | InferTypes => true + case _ => false + } + def inputForm = UnknownForm def outputForm = UnknownForm private case class UniquifyException(msg: String) extends FirrtlInternalException(msg) @@ -41,9 +52,9 @@ object Uniquify extends Transform { // For creation of rename map private case class NameMapNode(name: String, elts: Map[String, NameMapNode]) - // Appends delim to prefix until no collisions of prefix + elts in names - // We don't add an _ in the collision check because elts could be Seq("") - // In this case, we're just really checking if prefix itself collides + /** Appends delim to prefix until no collisions of prefix + elts in names We don't add an _ in the collision check + * because elts could be Seq("") In this case, we're just really checking if prefix itself collides + */ @tailrec def findValidPrefix( prefix: String, @@ -55,10 +66,12 @@ object Uniquify extends Transform { } } - // Enumerates all possible names for a given type - // eg. foo : { bar : { a, b }[2], c } - // => foo, foo bar, foo bar 0, foo bar 1, foo bar 0 a, foo bar 0 b, - // foo bar 1 a, foo bar 1 b, foo c + /** Enumerates all possible names for a given type. For example: + * {{{ + * foo : { bar : { a, b }[2], c } + * => foo, foo bar, foo bar 0, foo bar 1, foo bar 0 a, foo bar 0 b, foo bar 1 a, foo bar 1 b, foo c + * }}} + */ private [firrtl] def enumerateNames(tpe: Type): Seq[Seq[String]] = tpe match { case t: BundleType => t.fields flatMap { f => @@ -72,6 +85,36 @@ object Uniquify extends Transform { case _ => Seq() } + /** Creates a Bundle Type from a Stmt */ + def stmtToType(s: Statement)(implicit sinfo: Info, mname: String): BundleType = { + // Recursive helper + def recStmtToType(s: Statement): Seq[Field] = s match { + case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe)) + case sx: DefRegister => Seq(Field(sx.name, Default, sx.tpe)) + case sx: WDefInstance => Seq(Field(sx.name, Default, sx.tpe)) + case sx: DefMemory => sx.dataType match { + case (_: UIntType | _: SIntType | _: FixedType) => + Seq(Field(sx.name, Default, memType(sx))) + case tpe: BundleType => + val newFields = tpe.fields map ( f => + DefMemory(sx.info, f.name, f.tpe, sx.depth, sx.writeLatency, + sx.readLatency, sx.readers, sx.writers, sx.readwriters) + ) flatMap recStmtToType + Seq(Field(sx.name, Default, BundleType(newFields))) + case tpe: VectorType => + val newFields = (0 until tpe.size) map ( i => + sx.copy(name = i.toString, dataType = tpe.tpe) + ) flatMap recStmtToType + Seq(Field(sx.name, Default, BundleType(newFields))) + } + case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe)) + case sx: Conditionally => recStmtToType(sx.conseq) ++ recStmtToType(sx.alt) + case sx: Block => (sx.stmts map recStmtToType).flatten + case sx => Seq() + } + BundleType(recStmtToType(s)) + } + // Accepts a Type and an initial namespace // Returns new Type with uniquified names private def uniquifyNames( @@ -202,36 +245,6 @@ object Uniquify extends Transform { case t => t } - // Creates a Bundle Type from a Stmt - def stmtToType(s: Statement)(implicit sinfo: Info, mname: String): BundleType = { - // Recursive helper - def recStmtToType(s: Statement): Seq[Field] = s match { - case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe)) - case sx: DefRegister => Seq(Field(sx.name, Default, sx.tpe)) - case sx: WDefInstance => Seq(Field(sx.name, Default, sx.tpe)) - case sx: DefMemory => sx.dataType match { - case (_: UIntType | _: SIntType | _: FixedType) => - Seq(Field(sx.name, Default, memType(sx))) - case tpe: BundleType => - val newFields = tpe.fields map ( f => - DefMemory(sx.info, f.name, f.tpe, sx.depth, sx.writeLatency, - sx.readLatency, sx.readers, sx.writers, sx.readwriters) - ) flatMap recStmtToType - Seq(Field(sx.name, Default, BundleType(newFields))) - case tpe: VectorType => - val newFields = (0 until tpe.size) map ( i => - sx.copy(name = i.toString, dataType = tpe.tpe) - ) flatMap recStmtToType - Seq(Field(sx.name, Default, BundleType(newFields))) - } - case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe)) - case sx: Conditionally => recStmtToType(sx.conseq) ++ recStmtToType(sx.alt) - case sx: Block => (sx.stmts map recStmtToType).flatten - case sx => Seq() - } - BundleType(recStmtToType(s)) - } - // Everything wrapped in run so that it's thread safe def execute(state: CircuitState): CircuitState = { val c = state.circuit diff --git a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala index fdc81797..f47ddfbd 100644 --- a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala +++ b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala @@ -7,6 +7,7 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.PrimOps.{Bits, Rem} import firrtl.Utils._ +import firrtl.options.{Dependency, PreservesAll} import scala.collection.mutable @@ -23,7 +24,20 @@ import scala.collection.mutable * This is technically incorrect firrtl, but allows the verilog emitter * to emit correct verilog without needing to add temporary nodes */ -object VerilogModulusCleanup extends Pass { +object VerilogModulusCleanup extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate] ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty private def onModule(m: Module): Module = { val namespace = Namespace(m) diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala index e01cfffc..e60d76d1 100644 --- a/src/main/scala/firrtl/passes/ZeroWidth.scala +++ b/src/main/scala/firrtl/passes/ZeroWidth.scala @@ -6,8 +6,24 @@ import firrtl.PrimOps._ import firrtl.ir._ import firrtl._ import firrtl.Mappers._ +import firrtl.options.Dependency object ZeroWidth extends Transform { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency(Uniquify), + Dependency[ExpandWhensAndCheck], + Dependency(ConvertFixedToSInt) ) ++ firrtl.stage.Forms.Deduped + + override def invalidates(a: Transform): Boolean = a match { + case InferTypes => true + case _ => false + } + def inputForm: CircuitForm = UnknownForm def outputForm: CircuitForm = UnknownForm diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index 80b5cbb8..e5e6d6d4 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -8,6 +8,8 @@ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ +import firrtl.transforms +import firrtl.options.Dependency import MemPortUtils._ import WrappedExpression._ @@ -69,7 +71,7 @@ object MemDelayAndReadwriteTransformer { * read and write ports while simultaneously compiling memory latencies to combinational-read * memories with delay pipelines. It is represented as a class that takes a module as a constructor * argument, as it encapsulates the mutable state required to analyze and transform one module. - * + * * @note The final transformed module is found in the (sole public) field [[transformed]] */ class MemDelayAndReadwriteTransformer(m: DefModule) { @@ -165,6 +167,18 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { } object VerilogMemDelays extends Pass { + + override val prerequisites = firrtl.stage.Forms.LowForm :+ Dependency(firrtl.passes.RemoveValidIf) + + override val dependents = + Seq( Dependency[VerilogEmitter], + Dependency[SystemVerilogEmitter] ) + + override def invalidates(a: Transform): Boolean = a match { + case _: transforms.ConstantPropagation => true + case _ => false + } + def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) } diff --git a/src/main/scala/firrtl/stage/FirrtlStage.scala b/src/main/scala/firrtl/stage/FirrtlStage.scala index 2f7f0d11..94c4a896 100644 --- a/src/main/scala/firrtl/stage/FirrtlStage.scala +++ b/src/main/scala/firrtl/stage/FirrtlStage.scala @@ -2,37 +2,32 @@ package firrtl.stage -import firrtl.{AnnotationSeq, CustomTransformException, FirrtlInternalException, - FirrtlUserException, FIRRTLException, Utils} -import firrtl.options.{Stage, Phase, PhaseException, Shell, OptionsException, StageMain} +import firrtl.AnnotationSeq +import firrtl.options.{Dependency, Phase, PhaseManager, PreservesAll, Shell, Stage, StageMain} import firrtl.options.phases.DeletedWrapper +import firrtl.stage.phases.CatchExceptions -import scala.util.control.ControlThrowable +class FirrtlPhase + extends PhaseManager(targets=Seq(Dependency[firrtl.stage.phases.Compiler], Dependency[firrtl.stage.phases.WriteEmitted])) + with PreservesAll[Phase] { + override val wrappers = Seq(CatchExceptions(_: Phase), DeletedWrapper(_: Phase)) + +} class FirrtlStage extends Stage { + + lazy val phase = new FirrtlPhase + + override lazy val prerequisites = phase.prerequisites + + override lazy val dependents = phase.dependents + + override def invalidates(a: Phase): Boolean = phase.invalidates(a) + val shell: Shell = new Shell("firrtl") with FirrtlCli - private val phases: Seq[Phase] = - Seq( new firrtl.stage.phases.AddDefaults, - new firrtl.stage.phases.AddImplicitEmitter, - new firrtl.stage.phases.Checks, - new firrtl.stage.phases.AddCircuit, - new firrtl.stage.phases.AddImplicitOutputFile, - new firrtl.stage.phases.Compiler, - new firrtl.stage.phases.WriteEmitted ) - .map(DeletedWrapper(_)) - - def run(annotations: AnnotationSeq): AnnotationSeq = try { - phases.foldLeft(annotations)((a, f) => f.transform(a)) - } catch { - /* Rethrow the exceptions which are expected or due to the runtime environment (out of memory, stack overflow, etc.). - * Any UNEXPECTED exceptions should be treated as internal errors. */ - case p @ (_: ControlThrowable | _: FIRRTLException | _: OptionsException | _: FirrtlUserException - | _: FirrtlInternalException | _: PhaseException) => throw p - case CustomTransformException(cause) => throw cause - case e: Exception => Utils.throwInternalError(exception = Some(e)) - } + def run(annotations: AnnotationSeq): AnnotationSeq = phase.transform(annotations) } diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala new file mode 100644 index 00000000..f3eabd23 --- /dev/null +++ b/src/main/scala/firrtl/stage/Forms.scala @@ -0,0 +1,94 @@ +// See LICENSE for license details. + +package firrtl.stage + +import firrtl._ +import firrtl.options.Dependency +import firrtl.stage.TransformManager.TransformDependency + +/* + * - InferWidths should have InferTypes split out + * - ConvertFixedToSInt should have InferTypes split out + * - Move InferTypes out of ZeroWidth + */ + +object Forms { + + lazy val ChirrtlForm: Seq[TransformDependency] = Seq.empty + + lazy val MinimalHighForm: Seq[TransformDependency] = ChirrtlForm ++ + Seq( Dependency(passes.CheckChirrtl), + Dependency(passes.CInferTypes), + Dependency(passes.CInferMDir), + Dependency(passes.RemoveCHIRRTL) ) + + lazy val WorkingIR: Seq[TransformDependency] = MinimalHighForm :+ Dependency(passes.ToWorkingIR) + + lazy val Resolved: Seq[TransformDependency] = WorkingIR ++ + Seq( Dependency(passes.CheckHighForm), + Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.CheckTypes), + Dependency(passes.Uniquify), + Dependency(passes.ResolveFlows), + Dependency(passes.CheckFlows), + Dependency[passes.InferBinaryPoints], + Dependency[passes.TrimIntervals], + Dependency[passes.InferWidths], + Dependency(passes.CheckWidths), + Dependency[firrtl.transforms.InferResets] ) + + lazy val Deduped: Seq[TransformDependency] = Resolved :+ Dependency[firrtl.transforms.DedupModules] + + lazy val HighForm: Seq[TransformDependency] = ChirrtlForm ++ + MinimalHighForm ++ + WorkingIR ++ + Resolved ++ + Deduped + + lazy val MidForm: Seq[TransformDependency] = HighForm ++ + Seq( Dependency(passes.PullMuxes), + Dependency(passes.ReplaceAccesses), + Dependency(passes.ExpandConnects), + Dependency(passes.RemoveAccesses), + Dependency[passes.ExpandWhensAndCheck], + Dependency[passes.RemoveIntervals], + Dependency(passes.ConvertFixedToSInt), + Dependency(passes.ZeroWidth) ) + + lazy val LowForm: Seq[TransformDependency] = MidForm ++ + Seq( Dependency(passes.LowerTypes), + Dependency(passes.Legalize), + Dependency(firrtl.transforms.RemoveReset), + Dependency[firrtl.transforms.CheckCombLoops], + Dependency[checks.CheckResets], + Dependency[firrtl.transforms.RemoveWires] ) + + lazy val LowFormMinimumOptimized: Seq[TransformDependency] = LowForm ++ + Seq( Dependency(passes.RemoveValidIf), + Dependency(passes.memlib.VerilogMemDelays), + Dependency(passes.SplitExpressions) ) + + lazy val LowFormOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ + Seq( Dependency[firrtl.transforms.ConstantPropagation], + Dependency(passes.PadWidths), + Dependency[firrtl.transforms.CombineCats], + Dependency(passes.CommonSubexpressionElimination), + Dependency[firrtl.transforms.DeadCodeElimination] ) + + lazy val VerilogMinimumOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ + Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename], + Dependency(passes.VerilogPrep), + Dependency[firrtl.AddDescriptionNodes] ) + + lazy val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ VerilogMinimumOptimized + +} diff --git a/src/main/scala/firrtl/stage/TransformManager.scala b/src/main/scala/firrtl/stage/TransformManager.scala new file mode 100644 index 00000000..95878c91 --- /dev/null +++ b/src/main/scala/firrtl/stage/TransformManager.scala @@ -0,0 +1,34 @@ +// See LICENSE for license details. + +package firrtl.stage + +import firrtl.{CircuitForm, CircuitState, Transform, UnknownForm} +import firrtl.options.{Dependency, DependencyManager} + +/** A [[Transform]] that ensures some other [[Transform]]s and their prerequisites are executed. + * + * @param targets the transforms you want to run + * @param currentState the transforms that have already run + * @param knownObjects existing transform objects that have already been constructed + */ +class TransformManager( + val targets: Seq[TransformManager.TransformDependency], + val currentState: Seq[TransformManager.TransformDependency] = Seq.empty, + val knownObjects: Set[Transform] = Set.empty) extends Transform with DependencyManager[CircuitState, Transform] { + + override def inputForm: CircuitForm = UnknownForm + + override def outputForm: CircuitForm = UnknownForm + + override def execute(state: CircuitState): CircuitState = transform(state) + + override protected def copy(a: Seq[Dependency[Transform]], b: Seq[Dependency[Transform]], c: Set[Transform]) = new TransformManager(a, b, c) + +} + +object TransformManager { + + /** The type used to represent dependencies between [[Transform]]s */ + type TransformDependency = Dependency[Transform] + +} diff --git a/src/main/scala/firrtl/stage/phases/AddCircuit.scala b/src/main/scala/firrtl/stage/phases/AddCircuit.scala index 30c23098..5f1e381e 100644 --- a/src/main/scala/firrtl/stage/phases/AddCircuit.scala +++ b/src/main/scala/firrtl/stage/phases/AddCircuit.scala @@ -5,7 +5,7 @@ package firrtl.stage.phases import firrtl.stage._ import firrtl.{AnnotationSeq, Parser} -import firrtl.options.{Phase, PhasePrerequisiteException} +import firrtl.options.{Dependency, Phase, PhasePrerequisiteException, PreservesAll} /** [[firrtl.options.Phase Phase]] that expands [[FirrtlFileAnnotation]]/[[FirrtlSourceAnnotation]] into * [[FirrtlCircuitAnnotation]]s and deletes the originals. This is part of the preprocessing done on an input @@ -25,7 +25,11 @@ import firrtl.options.{Phase, PhasePrerequisiteException} * an [[InfoModeAnnotation]].'''. * @define infoModeException firrtl.options.PhasePrerequisiteException if no [[InfoModeAnnotation]] is present */ -class AddCircuit extends Phase { +class AddCircuit extends Phase with PreservesAll[Phase] { + + override val prerequisites = Seq(Dependency[AddDefaults], Dependency[Checks]) + + override val dependents = Seq.empty /** Extract the info mode from an [[AnnotationSeq]] or use the default info mode if no annotation exists * @param annotations some annotations diff --git a/src/main/scala/firrtl/stage/phases/AddDefaults.scala b/src/main/scala/firrtl/stage/phases/AddDefaults.scala index dc8dc025..17d7778d 100644 --- a/src/main/scala/firrtl/stage/phases/AddDefaults.scala +++ b/src/main/scala/firrtl/stage/phases/AddDefaults.scala @@ -3,14 +3,18 @@ package firrtl.stage.phases import firrtl.AnnotationSeq -import firrtl.options.{Phase, TargetDirAnnotation} +import firrtl.options.{Phase, PreservesAll, TargetDirAnnotation} import firrtl.transforms.BlackBoxTargetDirAnno import firrtl.stage.{CompilerAnnotation, InfoModeAnnotation, FirrtlOptions} /** [[firrtl.options.Phase Phase]] that adds default [[FirrtlOption]] [[firrtl.annotations.Annotation Annotation]]s. * This is a part of the preprocessing done by [[FirrtlStage]]. */ -class AddDefaults extends Phase { +class AddDefaults extends Phase with PreservesAll[Phase] { + + override val prerequisites = Seq.empty + + override val dependents = Seq.empty /** Append any missing default annotations to an annotation sequence */ def transform(annotations: AnnotationSeq): AnnotationSeq = { diff --git a/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala b/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala index 84f98cdb..7b7a6382 100644 --- a/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala +++ b/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala @@ -4,12 +4,16 @@ package firrtl.stage.phases import firrtl.{AnnotationSeq, EmitAnnotation, EmitCircuitAnnotation} import firrtl.stage.{CompilerAnnotation, RunFirrtlTransformAnnotation} -import firrtl.options.Phase +import firrtl.options.{Dependency, Phase, PreservesAll} /** [[firrtl.options.Phase Phase]] that adds a [[firrtl.EmitCircuitAnnotation EmitCircuitAnnotation]] derived from a * [[firrtl.stage.CompilerAnnotation CompilerAnnotation]] if one does not already exist. */ -class AddImplicitEmitter extends Phase { +class AddImplicitEmitter extends Phase with PreservesAll[Phase] { + + override val prerequisites = Seq(Dependency[AddDefaults]) + + override val dependents = Seq.empty def transform(annos: AnnotationSeq): AnnotationSeq = { val emitter = annos.collectFirst{ case a: EmitAnnotation => a } diff --git a/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala b/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala index 9ef32ab1..75e97c9b 100644 --- a/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala +++ b/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala @@ -3,22 +3,28 @@ package firrtl.stage.phases import firrtl.{AnnotationSeq, EmitAllModulesAnnotation} -import firrtl.options.{Phase, Viewer} +import firrtl.options.{Dependency, Phase, PreservesAll, Viewer} import firrtl.stage.{FirrtlOptions, OutputFileAnnotation} /** [[firrtl.options.Phase Phase]] that adds an [[OutputFileAnnotation]] if one does not already exist. * * To determine the [[OutputFileAnnotation]], the following precedence is used. Whichever happens first succeeds: * - Do nothing if an [[OutputFileAnnotation]] or [[EmitAllModulesAnnotation]] exist - * - Use the main in the first discovered [[FirrtlCircuitAnnotation]] (see note below) + * - Use the main in the first discovered [[firrtl.stage.FirrtlCircuitAnnotation FirrtlCircuitAnnotation]] (see note + * below) * - Use "a" * * The file suffix may or may not be specified, but this may be arbitrarily changed by the [[Emitter]]. * - * @note This [[firrtl.options.Phase Phase]] has a dependency on [[AddCircuit]]. Only a [[FirrtlCircuitAnnotation]] - * will be used to implicitly set the [[OutputFileAnnotation]] (not other [[CircuitOption]] subclasses). + * @note This [[firrtl.options.Phase Phase]] has a dependency on [[AddCircuit]]. Only a + * [[firrtl.stage.FirrtlCircuitAnnotation FirrtlCircuitAnnotation]] will be used to implicitly set the + * [[OutputFileAnnotation]] (not other [[firrtl.stage.CircuitOption CircuitOption]] subclasses). */ -class AddImplicitOutputFile extends Phase { +class AddImplicitOutputFile extends Phase with PreservesAll[Phase] { + + override val prerequisites = Seq(Dependency[AddCircuit]) + + override val dependents = Seq.empty /** Add an [[OutputFileAnnotation]] to an [[AnnotationSeq]] */ def transform(annotations: AnnotationSeq): AnnotationSeq = diff --git a/src/main/scala/firrtl/stage/phases/CatchExceptions.scala b/src/main/scala/firrtl/stage/phases/CatchExceptions.scala new file mode 100644 index 00000000..a7f84b73 --- /dev/null +++ b/src/main/scala/firrtl/stage/phases/CatchExceptions.scala @@ -0,0 +1,37 @@ +// See LICENSE for license details. + +package firrtl.stage.phases + +import firrtl.{AnnotationSeq, CustomTransformException, FIRRTLException, FirrtlInternalException, FirrtlUserException, + Utils} +import firrtl.options.{DependencyManagerException, Phase, PhaseException, OptionsException} +import firrtl.passes.{PassException, PassExceptions} + +import scala.util.control.ControlThrowable + +class CatchExceptions(val underlying: Phase) extends Phase { + + override final val prerequisites = underlying.prerequisites + override final val dependents = underlying.dependents + override final def invalidates(a: Phase): Boolean = underlying.invalidates(a) + override final lazy val name = underlying.name + + override def transform(a: AnnotationSeq): AnnotationSeq = try { + underlying.transform(a) + } catch { + /* Rethrow the exceptions which are expected or due to the runtime environment (out of memory, stack overflow, etc.). + * Any UNEXPECTED exceptions should be treated as internal errors. */ + case p @ (_: ControlThrowable | _: FIRRTLException | _: OptionsException | _: FirrtlUserException + | _: FirrtlInternalException | _: PhaseException | _: DependencyManagerException) => throw p + case CustomTransformException(cause) => throw cause + case e: Exception => Utils.throwInternalError(exception = Some(e)) + } + +} + + +object CatchExceptions { + + def apply(p: Phase): CatchExceptions = new CatchExceptions(p) + +} diff --git a/src/main/scala/firrtl/stage/phases/Checks.scala b/src/main/scala/firrtl/stage/phases/Checks.scala index 9ee2b854..ebf4c04f 100644 --- a/src/main/scala/firrtl/stage/phases/Checks.scala +++ b/src/main/scala/firrtl/stage/phases/Checks.scala @@ -6,7 +6,7 @@ import firrtl.stage._ import firrtl.{AnnotationSeq, EmitAllModulesAnnotation, EmitCircuitAnnotation} import firrtl.annotations.Annotation -import firrtl.options.{OptionsException, Phase} +import firrtl.options.{Dependency, OptionsException, Phase, PreservesAll} /** [[firrtl.options.Phase Phase]] that strictly validates an [[AnnotationSeq]]. The checks applied are intended to be * extremeley strict. Nothing is inferred or assumed to take a default value (for default value resolution see @@ -16,11 +16,15 @@ import firrtl.options.{OptionsException, Phase} * certain that other [[firrtl.options.Phase Phase]]s or views will succeed. See [[FirrtlStage]] for a list of * [[firrtl.options.Phase Phase]] that commonly run before this. */ -class Checks extends Phase { +class Checks extends Phase with PreservesAll[Phase] { + + override val prerequisites = Seq(Dependency[AddDefaults], Dependency[AddImplicitEmitter]) + + override val dependents = Seq.empty /** Determine if annotations are sane * - * @param annos a sequence of [[annotation.Annotation]] + * @param annos a sequence of [[firrtl.annotations.Annotation Annotation]] * @return true if all checks pass * @throws firrtl.options.OptionsException if any checks fail */ diff --git a/src/main/scala/firrtl/stage/phases/Compiler.scala b/src/main/scala/firrtl/stage/phases/Compiler.scala index 917e1a2c..3850d0a8 100644 --- a/src/main/scala/firrtl/stage/phases/Compiler.scala +++ b/src/main/scala/firrtl/stage/phases/Compiler.scala @@ -3,9 +3,9 @@ package firrtl.stage.phases import firrtl.{AnnotationSeq, ChirrtlForm, CircuitState, Compiler => FirrtlCompiler, Transform, seqToAnnoSeq} -import firrtl.options.{Phase, PhasePrerequisiteException, Translator} -import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, - RunFirrtlTransformAnnotation} +import firrtl.options.{Dependency, Phase, PhasePrerequisiteException, PreservesAll, Translator} +import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, Forms, RunFirrtlTransformAnnotation} +import firrtl.stage.TransformManager.TransformDependency import scala.collection.mutable @@ -23,13 +23,13 @@ private [stage] case class Defaults( compiler: Option[FirrtlCompiler] = None) /** Runs the FIRRTL compilers on an [[AnnotationSeq]]. If the input [[AnnotationSeq]] contains more than one circuit - * (i.e., more than one [[FirrtlCircuitAnnotation]]), then annotations will be broken up and each run will be executed - * in parallel. + * (i.e., more than one [[firrtl.stage.FirrtlCircuitAnnotation FirrtlCircuitAnnotation]]), then annotations will be + * broken up and each run will be executed in parallel. * * The [[AnnotationSeq]] will be chunked up into compiler runs using the following algorithm. All annotations that - * occur before the first [[FirrtlCircuitAnnotation]] are treated as global annotations that apply to all circuits. - * Annotations after a circuit are only associated with their closest preceeding circuit. E.g., for the following - * annotations (where A, B, and C are some annotations): + * occur before the first [[firrtl.stage.FirrtlCircuitAnnotation FirrtlCircuitAnnotation]] are treated as global + * annotations that apply to all circuits. Annotations after a circuit are only associated with their closest + * preceeding circuit. E.g., for the following annotations (where A, B, and C are some annotations): * * A(a), FirrtlCircuitAnnotation(x), B, FirrtlCircuitAnnotation(y), A(b), C, FirrtlCircuitAnnotation(z) * @@ -42,7 +42,16 @@ private [stage] case class Defaults( * FirrtlCircuitAnnotation(y). Note: A(b) ''may'' overwrite A(a) if this is a CompilerAnnotation. * FirrtlCircuitAnnotation(z) has no annotations, so it only gets the default A(a). */ -class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { +class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] with PreservesAll[Phase] { + + override val prerequisites = + Seq(Dependency[AddDefaults], + Dependency[AddImplicitEmitter], + Dependency[Checks], + Dependency[AddCircuit], + Dependency[AddImplicitOutputFile]) + + override val dependents = Seq(Dependency[WriteEmitted]) /** Convert an [[AnnotationSeq]] into a sequence of compiler runs. */ protected def aToB(a: AnnotationSeq): Seq[CompilerRun] = { @@ -85,15 +94,32 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { */ protected def internalTransform(b: Seq[CompilerRun]): Seq[CompilerRun] = { def f(c: CompilerRun): CompilerRun = { - val statex = c - .compiler - .getOrElse { throw new PhasePrerequisiteException("No compiler specified!") } - .compile(c.stateIn, c.transforms.reverse) - c.copy(stateOut = Some(statex)) + val targets = c.compiler match { + case Some(d) => c.transforms.reverse.map(Dependency.fromTransform(_)) ++ compilerToTransforms(d) + case None => throw new PhasePrerequisiteException("No compiler specified!") } + val tm = new firrtl.stage.transforms.Compiler(targets) + /* Transform order is lazily evaluated. Force it here to remove its resolution time from actual compilation. */ + val (timeResolveDependencies, _) = firrtl.Utils.time { tm.flattenedTransformOrder } + logger.error(f"Computed transform order in: $timeResolveDependencies%.1f ms") + /* Show the determined transform order */ + logger.info("Determined Transform order that will be executed:\n" + tm.prettyPrint(" ")) + /* Run all determined transforms tracking how long everything takes to run */ + val (timeExecute, annotationsOut) = firrtl.Utils.time { tm.transform(c.stateIn) } + logger.error(f"Total FIRRTL Compile Time: $timeExecute%.1f ms") + c.copy(stateOut = Some(annotationsOut)) } if (b.size <= 1) { b.map(f) } else { b.par.map(f).seq } } + private def compilerToTransforms(a: FirrtlCompiler): Seq[TransformDependency] = a match { + case _: firrtl.NoneCompiler => Forms.ChirrtlForm + case _: firrtl.HighFirrtlCompiler => Forms.HighForm + case _: firrtl.MiddleFirrtlCompiler => Forms.MidForm + case _: firrtl.LowFirrtlCompiler => Forms.LowForm + case _: firrtl.VerilogCompiler | _: firrtl.SystemVerilogCompiler => Forms.LowFormOptimized + case _: firrtl.MinimumVerilogCompiler => Forms.LowFormMinimumOptimized + } + } diff --git a/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala b/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala index 665861ef..40640fb1 100644 --- a/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala +++ b/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala @@ -8,9 +8,9 @@ import firrtl.{AnnotationSeq, EmitAllModulesAnnotation, EmitCircuitAnnotation, F import firrtl.annotations.NoTargetAnnotation import firrtl.FileUtils import firrtl.proto.FromProto -import firrtl.options.{InputAnnotationFileAnnotation, OptionsException, Phase, - StageOptions, StageUtils} +import firrtl.options.{InputAnnotationFileAnnotation, OptionsException, Phase, PreservesAll, StageOptions, StageUtils} import firrtl.options.Viewer +import firrtl.options.Dependency import scopt.OptionParser @@ -63,7 +63,7 @@ object DriverCompatibility { /** Indicates that the implicit emitter, derived from a [[CompilerAnnotation]] should be an [[EmitAllModulesAnnotation]] * as opposed to an [[EmitCircuitAnnotation]]. */ - private [firrtl] case object EmitOneFilePerModuleAnnotation extends NoTargetAnnotation { + case object EmitOneFilePerModuleAnnotation extends NoTargetAnnotation { def addOptions(p: OptionParser[AnnotationSeq]): Unit = p .opt[Unit]("split-modules") @@ -105,10 +105,11 @@ object DriverCompatibility { * [[firrtl.options.InputAnnotationFileAnnotation InputAnnotationFileAnnotation]] is present. * * The implicit annotation file is determined through the following complicated semantics: - * - If an [[InputAnnotationFileAnnotation]] already exists, then nothing is modified + * - If an [[firrtl.options.InputAnnotationFileAnnotation InputAnnotationFileAnnotation]] already exists, then + * nothing is modified * - If the derived topName (the `main` in a [[firrtl.ir.Circuit Circuit]]) is ''discernable'' (see below) and a * file called `topName.anno` (exactly, not `topName.anno.json`) exists, then this will add an - * [[InputAnnotationFileAnnotation]] using that `topName.anno` + * [[firrtl.options.InputAnnotationFileAnnotation InputAnnotationFileAnnotation]] using that `topName.anno` * - If any of this doesn't work, then the the [[AnnotationSeq]] is unmodified * * The precedence for determining the `topName` is the following (first one wins): @@ -121,9 +122,14 @@ object DriverCompatibility { * @param annos input annotations * @return output annotations */ - class AddImplicitAnnotationFile extends Phase { + class AddImplicitAnnotationFile extends Phase with PreservesAll[Phase] { - /** Try to add an [[InputAnnotationFileAnnotation]] implicitly specified by an [[AnnotationSeq]]. */ + override val prerequisites = Seq(Dependency[AddImplicitFirrtlFile]) + + override val dependents = Seq(Dependency[FirrtlPhase], Dependency[FirrtlStage]) + + /** Try to add an [[firrtl.options.InputAnnotationFileAnnotation InputAnnotationFileAnnotation]] implicitly specified by + * an [[AnnotationSeq]]. */ def transform(annotations: AnnotationSeq): AnnotationSeq = annotations .collectFirst{ case a: InputAnnotationFileAnnotation => a } match { case Some(_) => annotations @@ -155,7 +161,11 @@ object DriverCompatibility { * @param annotations input annotations * @return */ - class AddImplicitFirrtlFile extends Phase { + class AddImplicitFirrtlFile extends Phase with PreservesAll[Phase] { + + override val prerequisites = Seq.empty + + override val dependents = Seq(Dependency[FirrtlPhase], Dependency[FirrtlStage]) /** Try to add a [[FirrtlFileAnnotation]] implicitly specified by an [[AnnotationSeq]]. */ def transform(annotations: AnnotationSeq): AnnotationSeq = { @@ -174,7 +184,7 @@ object DriverCompatibility { } } - /** Adds an [[EmitAnnotation]] for each [[CompilerAnnotation]]. + /** Adds an [[firrtl.EmitAnnotation EmitAnnotation]] for each [[CompilerAnnotation]]. * * If an [[EmitOneFilePerModuleAnnotation]] exists, then this will add an [[EmitAllModulesAnnotation]]. Otherwise, * this adds an [[EmitCircuitAnnotation]]. This replicates old behavior where specifying a compiler automatically @@ -182,7 +192,11 @@ object DriverCompatibility { */ @deprecated("""AddImplicitEmitter should only be used to build Driver compatibility wrappers. Switch to Stage.""", "1.2") - class AddImplicitEmitter extends Phase { + class AddImplicitEmitter extends Phase with PreservesAll[Phase] { + + override val prerequisites = Seq.empty + + override val dependents = Seq(Dependency[FirrtlPhase], Dependency[FirrtlStage]) /** Add one [[EmitAnnotation]] foreach [[CompilerAnnotation]]. */ def transform(annotations: AnnotationSeq): AnnotationSeq = { @@ -204,7 +218,11 @@ object DriverCompatibility { */ @deprecated("""AddImplicitOutputFile should only be used to build Driver compatibility wrappers. Switch to Stage.""", "1.2") - class AddImplicitOutputFile extends Phase { + class AddImplicitOutputFile extends Phase with PreservesAll[Phase] { + + override val prerequisites = Seq(Dependency[AddImplicitFirrtlFile]) + + override val dependents = Seq(Dependency[FirrtlPhase], Dependency[FirrtlStage]) /** Add an [[OutputFileAnnotation]] derived from a [[TopNameAnnotation]] if needed. */ def transform(annotations: AnnotationSeq): AnnotationSeq = { diff --git a/src/main/scala/firrtl/stage/phases/WriteEmitted.scala b/src/main/scala/firrtl/stage/phases/WriteEmitted.scala index 7c38ebbf..dcef4629 100644 --- a/src/main/scala/firrtl/stage/phases/WriteEmitted.scala +++ b/src/main/scala/firrtl/stage/phases/WriteEmitted.scala @@ -3,7 +3,7 @@ package firrtl.stage.phases import firrtl.{AnnotationSeq, EmittedModuleAnnotation, EmittedCircuitAnnotation} -import firrtl.options.{Phase, StageOptions, Viewer} +import firrtl.options.{Phase, PreservesAll, StageOptions, Viewer} import firrtl.stage.FirrtlOptions import java.io.PrintWriter @@ -24,7 +24,11 @@ import java.io.PrintWriter * * Any annotations written to files will be deleted. */ -class WriteEmitted extends Phase { +class WriteEmitted extends Phase with PreservesAll[Phase] { + + override val prerequisites = Seq.empty + + override val dependents = Seq.empty /** Write any [[EmittedAnnotation]]s in an [[AnnotationSeq]] to files. Written [[EmittedAnnotation]]s are deleted. */ def transform(annotations: AnnotationSeq): AnnotationSeq = { diff --git a/src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala b/src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala new file mode 100644 index 00000000..8b8b8368 --- /dev/null +++ b/src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala @@ -0,0 +1,30 @@ +// See LICENSE for license details. + +package firrtl.stage.transforms + +import firrtl.{CircuitState, CustomTransformException, Transform} + +class CatchCustomTransformExceptions(val underlying: Transform) extends Transform with WrappedTransform { + + override def execute(c: CircuitState): CircuitState = try { + underlying.execute(c) + } catch { + case e: Exception if CatchCustomTransformExceptions.isCustomTransform(trueUnderlying) => throw CustomTransformException(e) + } + +} + +object CatchCustomTransformExceptions { + + private[firrtl] def isCustomTransform(xform: Transform): Boolean = { + def getTopPackage(pack: java.lang.Package): java.lang.Package = + Package.getPackage(pack.getName.split('.').head) + // We use the top package of the Driver to get the top firrtl package + Option(xform.getClass.getPackage).map { p => + getTopPackage(p) != firrtl.Driver.getClass.getPackage + }.getOrElse(true) + } + + def apply(a: Transform): CatchCustomTransformExceptions = new CatchCustomTransformExceptions(a) + +} diff --git a/src/main/scala/firrtl/stage/transforms/Compiler.scala b/src/main/scala/firrtl/stage/transforms/Compiler.scala new file mode 100644 index 00000000..ded50ce6 --- /dev/null +++ b/src/main/scala/firrtl/stage/transforms/Compiler.scala @@ -0,0 +1,47 @@ +// See LICENSE for license details. + +package firrtl.stage.transforms + +import firrtl.{CircuitState, Transform, VerilogEmitter} +import firrtl.options.DependencyManagerUtils.CharSet +import firrtl.stage.TransformManager + +class Compiler( + targets: Seq[TransformManager.TransformDependency], + currentState: Seq[TransformManager.TransformDependency] = Seq.empty, + knownObjects: Set[Transform] = Set.empty) extends TransformManager(targets, currentState, knownObjects) { + + override val wrappers = Seq( + (a: Transform) => CatchCustomTransformExceptions(a), + (a: Transform) => UpdateAnnotations(a) + ) + + override def customPrintHandling( + tab: String, + charSet: CharSet, + size: Int): Option[PartialFunction[(Transform, Int), Seq[String]]] = { + + val (l, n, c) = (charSet.lastNode, charSet.notLastNode, charSet.continuation) + val last = size - 1 + + val f: PartialFunction[(Transform, Int), Seq[String]] = { + { + case (a: VerilogEmitter, `last`) => + val firstTransforms = a.transforms.dropRight(1) + val lastTransform = a.transforms.last + Seq(s"$tab$l ${a.name}") ++ + firstTransforms.map(t => s"""$tab${" " * c.size} $n ${t.name}""") :+ + s"""$tab${" " * c.size} $l ${lastTransform.name}""" + case (a: VerilogEmitter, _) => + val firstTransforms = a.transforms.dropRight(1) + val lastTransform = a.transforms.last + Seq(s"$tab$n ${a.name}") ++ + firstTransforms.map(t => s"""$tab$c $n ${t.name}""") :+ + s"""$tab$c $l ${lastTransform.name}""" + } + } + + Some(f) + } + +} diff --git a/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala b/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala new file mode 100644 index 00000000..6c77cd1d --- /dev/null +++ b/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala @@ -0,0 +1,69 @@ +// See LICENSE for license details. + +package firrtl.stage.transforms + +import firrtl.{AnnotationSeq, CircuitState, Transform} +import firrtl.annotations.NoTargetAnnotation +import firrtl.options.{Dependency, DependencyManagerException} + +case class TransformHistoryAnnotation(history: Seq[Transform], state: Set[Transform]) extends NoTargetAnnotation { + + def add(transform: Transform, + invalidates: (Transform) => Boolean = (a: Transform) => false): TransformHistoryAnnotation = + this.copy( + history = transform +: this.history, + state = (this.state + transform).filterNot(invalidates) + ) + +} + +object TransformHistoryAnnotation { + + def apply(transform: Transform): TransformHistoryAnnotation = TransformHistoryAnnotation( + history = Seq(transform), + state = Set(transform) + ) + +} + +class TrackTransforms(val underlying: Transform) extends Transform with WrappedTransform { + + private def updateState(annotations: AnnotationSeq): AnnotationSeq = { + var foundAnnotation = false + val annotationsx = annotations.map { + case x: TransformHistoryAnnotation => + foundAnnotation = true + x.add(trueUnderlying) + case x => x + } + if (!foundAnnotation) { + TransformHistoryAnnotation(trueUnderlying) +: annotationsx + } else { + annotationsx + } + } + + override def execute(c: CircuitState): CircuitState = { + val state = c.annotations + .collectFirst{ case TransformHistoryAnnotation(_, state) => state } + .getOrElse(Set.empty[Transform]) + .map(Dependency.fromTransform(_)) + + if (!trueUnderlying.prerequisites.toSet.subsetOf(state)) { + throw new DependencyManagerException( + s"""|Tried to execute Transform '$trueUnderlying' for which run-time prerequisites were not satisfied: + | state: ${state.mkString("\n -", "\n -", "")} + | prerequisites: ${trueUnderlying.prerequisites.mkString("\n -", "\n -", "")}""".stripMargin) + } + + val out = underlying.execute(c) + out.copy(annotations = updateState(out.annotations)) + } + +} + +object TrackTransforms { + + def apply(a: Transform): Transform = new TrackTransforms(a) + +} diff --git a/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala b/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala new file mode 100644 index 00000000..aad0ab48 --- /dev/null +++ b/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala @@ -0,0 +1,95 @@ +// See LICENSE for license details. + +package firrtl.stage.transforms + +import firrtl.{AnnotationSeq, CircuitState, RenameMap, Transform, Utils} +import firrtl.annotations.{Annotation, DeletedAnnotation} +import firrtl.options.Translator + +import scala.collection.mutable + +class UpdateAnnotations(val underlying: Transform) extends Transform with WrappedTransform + with Translator[CircuitState, (CircuitState, CircuitState)] { + + override def execute(c: CircuitState): CircuitState = underlying.execute(c) + + def aToB(a: CircuitState): (CircuitState, CircuitState) = (a, a) + + def bToA(b: (CircuitState, CircuitState)): CircuitState = { + val (state, result) = (b._1, b._2) + + val remappedAnnotations = propagateAnnotations(state.annotations, result.annotations, result.renames) + + logger.info(s"Form: ${result.form}") + + logger.debug(s"Annotations:") + remappedAnnotations.foreach( a => logger.debug(a.serialize) ) + + logger.trace(s"Circuit:\n${result.circuit.serialize}") + logger.info(s"======== Finished Transform $name ========\n") + + CircuitState(result.circuit, result.form, remappedAnnotations, None) + } + + def internalTransform(b: (CircuitState, CircuitState)): (CircuitState, CircuitState) = { + logger.info(s"======== Starting Transform $name ========") + + /* @todo: prepare should likely be factored out of this */ + val (timeMillis, result) = Utils.time { execute( trueUnderlying.prepare(b._2) ) } + + logger.info(s"""----------------------------${"-" * name.size}---------\n""") + logger.info(f"Time: $timeMillis%.1f ms") + + (b._1, result) + } + + /** Propagate annotations and update their names. + * + * @param inAnno input AnnotationSeq + * @param resAnno result AnnotationSeq + * @param renameOpt result RenameMap + * @return the updated annotations + */ + private[firrtl] def propagateAnnotations( + inAnno: AnnotationSeq, + resAnno: AnnotationSeq, + renameOpt: Option[RenameMap]): AnnotationSeq = { + val newAnnotations = { + val inSet = mutable.LinkedHashSet() ++ inAnno + val resSet = mutable.LinkedHashSet() ++ resAnno + val deleted = (inSet -- resSet).map { + case DeletedAnnotation(xFormName, delAnno) => DeletedAnnotation(s"$xFormName+$name", delAnno) + case anno => DeletedAnnotation(name, anno) + } + val created = resSet -- inSet + val unchanged = resSet & inSet + (deleted ++ created ++ unchanged) + } + + // For each annotation, rename all annotations. + val renames = renameOpt.getOrElse(RenameMap()) + val remapped2original = mutable.LinkedHashMap[Annotation, mutable.LinkedHashSet[Annotation]]() + val keysOfNote = mutable.LinkedHashSet[Annotation]() + val finalAnnotations = newAnnotations.flatMap { anno => + val remappedAnnos = anno.update(renames) + remappedAnnos.foreach { remapped => + val set = remapped2original.getOrElseUpdate(remapped, mutable.LinkedHashSet.empty[Annotation]) + set += anno + if(set.size > 1) keysOfNote += remapped + } + remappedAnnos + }.toSeq + keysOfNote.foreach { key => + logger.debug(s"""The following original annotations are renamed to the same new annotation.""") + logger.debug(s"""Original Annotations:\n ${remapped2original(key).mkString("\n ")}""") + logger.debug(s"""New Annotation:\n $key""") + } + finalAnnotations + } +} + +object UpdateAnnotations { + + def apply(a: Transform): UpdateAnnotations = new UpdateAnnotations(a) + +} diff --git a/src/main/scala/firrtl/stage/transforms/WrappedTransform.scala b/src/main/scala/firrtl/stage/transforms/WrappedTransform.scala new file mode 100644 index 00000000..5fcfa250 --- /dev/null +++ b/src/main/scala/firrtl/stage/transforms/WrappedTransform.scala @@ -0,0 +1,33 @@ +// See LICENSE for license details. + +package firrtl.stage.transforms + +import firrtl.Transform + +/** A [[firrtl.Transform]] that "wraps" a second [[firrtl.Transform Transform]] to do some work before and after the + * second [[firrtl.Transform Transform]]. + * + * This is intended to synergize with the [[firrtl.options.DependencyManager.wrappers]] method. + * @see [[firrtl.stage.transforms.CatchCustomTransformExceptions]] + * @see [[firrtl.stage.transforms.TrackTransforms]] + * @see [[firrtl.stage.transforms.UpdateAnnotations]] + */ +trait WrappedTransform { this: Transform => + + /** The underlying [[firrtl.Transform]] */ + val underlying: Transform + + /** Return the original [[firrtl.Transform]] if this wrapper is wrapping other wrappers. */ + lazy final val trueUnderlying: Transform = underlying match { + case a: WrappedTransform => a.trueUnderlying + case _ => underlying + } + + override final val inputForm = underlying.inputForm + override final val outputForm = underlying.outputForm + override final val prerequisites = underlying.prerequisites + override final val dependents = underlying.dependents + override final def invalidates(b: Transform): Boolean = underlying.invalidates(b) + override final lazy val name = underlying.name + +} diff --git a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala index b62cf7a1..07cf09b0 100644 --- a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala +++ b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala @@ -6,6 +6,7 @@ import java.io.{File, FileNotFoundException, FileInputStream, FileOutputStream, import firrtl._ import firrtl.annotations._ +import firrtl.options.PreservesAll import scala.collection.immutable.ListSet @@ -54,12 +55,18 @@ class BlackBoxNotFoundException(fileName: String, message: String) extends Firrt * will set the directory where the Verilog will be written. This annotation is typically be * set by the execution harness, or directly in the tests */ -class BlackBoxSourceHelper extends firrtl.Transform { +class BlackBoxSourceHelper extends firrtl.Transform with PreservesAll[Transform] { import BlackBoxSourceHelper._ private val DefaultTargetDir = new File(".") override def inputForm: CircuitForm = LowForm override def outputForm: CircuitForm = LowForm + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty + /** Collect BlackBoxHelperAnnos and and find the target dir if specified * @param annos a list of generic annotations for this transform * @return BlackBoxHelperAnnos and target directory diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index bb5d88e7..b7ba5c5e 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -4,7 +4,6 @@ package firrtl.transforms import scala.collection.mutable - import firrtl._ import firrtl.ir._ import firrtl.passes.{Errors, PassException} @@ -13,7 +12,7 @@ import firrtl.annotations._ import firrtl.Utils.throwInternalError import firrtl.graph._ import firrtl.analyses.InstanceGraph -import firrtl.options.{RegisteredTransform, ShellOption} +import firrtl.options.{Dependency, PreservesAll, RegisteredTransform, ShellOption} /* * A case class that represents a net in the circuit. This is @@ -95,10 +94,19 @@ case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget * @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules * @note The pass will throw exceptions on "false paths" */ -class CheckCombLoops extends Transform with RegisteredTransform { +class CheckCombLoops extends Transform with RegisteredTransform with PreservesAll[Transform] { def inputForm = LowForm def outputForm = LowForm + override val prerequisites = firrtl.stage.Forms.MidForm ++ + Seq( Dependency(passes.LowerTypes), + Dependency(passes.Legalize), + Dependency(firrtl.transforms.RemoveReset) ) + + override val optionalPrerequisites = Seq.empty + + override val dependents = Seq.empty + import CheckCombLoops._ val options = Seq( diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala index ac8fc5fb..8f5972e1 100644 --- a/src/main/scala/firrtl/transforms/CombineCats.scala +++ b/src/main/scala/firrtl/transforms/CombineCats.scala @@ -7,6 +7,8 @@ import firrtl.Mappers._ import firrtl.PrimOps._ import firrtl.WrappedExpression._ import firrtl.annotations.NoTargetAnnotation +import firrtl.options.PreservesAll +import firrtl.options.Dependency import scala.collection.mutable @@ -51,9 +53,22 @@ object CombineCats { * Use [[MaxCatLenAnnotation]] to limit the number of elements that can be concatenated. * The default maximum number of elements is 10. */ -class CombineCats extends Transform { +class CombineCats extends Transform with PreservesAll[Transform] { def inputForm: LowForm.type = LowForm def outputForm: LowForm.type = LowForm + + override val prerequisites = firrtl.stage.Forms.LowForm ++ + Seq( Dependency(passes.RemoveValidIf), + Dependency[firrtl.transforms.ConstantPropagation], + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions) ) + + override val optionalPrerequisites = Seq.empty + + override val dependents = Seq( + Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] ) + val defaultMaxCatLen = 10 def execute(state: CircuitState): CircuitState = { diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 55c897b3..c11bc44d 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -13,6 +13,7 @@ import firrtl.PrimOps._ import firrtl.graph.DiGraph import firrtl.analyses.InstanceGraph import firrtl.annotations.TargetToken.Ref +import firrtl.options.Dependency import annotation.tailrec import collection.mutable @@ -102,6 +103,25 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { def inputForm = LowForm def outputForm = LowForm + override val prerequisites = + ((new mutable.LinkedHashSet()) + ++ firrtl.stage.Forms.LowForm + - Dependency(firrtl.passes.Legalize) + + Dependency(firrtl.passes.RemoveValidIf)).toSeq + + override val optionalPrerequisites = Seq.empty + + override val dependents = + Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] ) + + override def invalidates(a: Transform): Boolean = a match { + case firrtl.passes.Legalize => true + case _ => false + } + override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DontTouchAnnotation]) sealed trait SimplifyBinaryOp { diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index 983f1048..04f1c7d2 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -10,7 +10,7 @@ import firrtl.analyses.InstanceGraph import firrtl.Mappers._ import firrtl.Utils.{throwInternalError, kind} import firrtl.MemoizedHash._ -import firrtl.options.{RegisteredTransform, ShellOption} +import firrtl.options.{Dependency, PreservesAll, RegisteredTransform, ShellOption} import collection.mutable @@ -29,9 +29,29 @@ import collection.mutable * circumstances of their instantiation in their parent module, they will still not be removed. To * remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication. */ -class DeadCodeElimination extends Transform with ResolvedAnnotationPaths with RegisteredTransform { - def inputForm = LowForm - def outputForm = LowForm +class DeadCodeElimination extends Transform with ResolvedAnnotationPaths with RegisteredTransform + with PreservesAll[Transform] { + def inputForm = UnknownForm + def outputForm = UnknownForm + + override val prerequisites = firrtl.stage.Forms.LowForm ++ + Seq( Dependency(firrtl.passes.RemoveValidIf), + Dependency[firrtl.transforms.ConstantPropagation], + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[firrtl.transforms.CombineCats], + Dependency(passes.CommonSubexpressionElimination) ) + + override val optionalPrerequisites = Seq.empty + + override val dependents = + Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename], + Dependency(passes.VerilogPrep), + Dependency[firrtl.AddDescriptionNodes] ) val options = Seq( new ShellOption[Unit]( diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index 0667a184..5caa9228 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -9,7 +9,7 @@ import firrtl.analyses.InstanceGraph import firrtl.annotations._ import firrtl.passes.{InferTypes, MemPortUtils} import firrtl.Utils.throwInternalError -import firrtl.options.{HasShellOptions, ShellOption} +import firrtl.options.{HasShellOptions, PreservesAll, ShellOption} // Datastructures import scala.collection.mutable @@ -39,10 +39,14 @@ case object NoCircuitDedupAnnotation extends NoTargetAnnotation with HasShellOpt * Specifically, the restriction of instance loops must have been checked, or else this pass can * infinitely recurse */ -class DedupModules extends Transform { +class DedupModules extends Transform with PreservesAll[Transform] { def inputForm: CircuitForm = HighForm def outputForm: CircuitForm = HighForm + override val prerequisites = firrtl.stage.Forms.Resolved + + override val dependents = Seq.empty + /** Deduplicate a Circuit * @param state Input Firrtl AST * @return A transformed Firrtl AST diff --git a/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala b/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala index 08bf4af4..59d14ab2 100644 --- a/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala +++ b/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala @@ -2,10 +2,12 @@ package firrtl.transforms -import firrtl.{CircuitState, LowForm, Namespace, PrimOps, Transform, Utils, WRef} +import firrtl.{CircuitState, Namespace, PrimOps, Transform, UnknownForm, Utils, WRef} import firrtl.ir._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} import firrtl.PrimOps.{Add, AsSInt, Sub, Tail} +import firrtl.stage.Forms import scala.collection.mutable @@ -105,9 +107,15 @@ object FixAddingNegativeLiterals { * the literal and thus not all expressions in the add are the same. This is fixed here when we directly * subtract the literal instead. */ -class FixAddingNegativeLiterals extends Transform { - def inputForm = LowForm - def outputForm = LowForm +class FixAddingNegativeLiterals extends Transform with PreservesAll[Transform] { + def inputForm = UnknownForm + def outputForm = UnknownForm + + override val prerequisites = Forms.LowFormMinimumOptimized :+ Dependency[BlackBoxSourceHelper] + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty def execute(state: CircuitState): CircuitState = { val modulesx = state.circuit.modules.map(FixAddingNegativeLiterals.fixupModule) diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala index 2d04dc89..eadbb0cb 100644 --- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -6,6 +6,7 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ import firrtl.Utils._ +import firrtl.options.Dependency import scala.collection.mutable @@ -105,8 +106,25 @@ object FlattenRegUpdate { */ // TODO Preserve source locators class FlattenRegUpdate extends Transform { - def inputForm = MidForm - def outputForm = MidForm + def inputForm = UnknownForm + def outputForm = UnknownForm + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform], + Dependency[InlineCastsTransform], + Dependency[LegalizeClocksTransform] ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty + + override def invalidates(a: Transform): Boolean = a match { + case _: DeadCodeElimination => true + case _ => false + } def execute(state: CircuitState): CircuitState = { val modulesx = state.circuit.modules.map { diff --git a/src/main/scala/firrtl/transforms/GroupComponents.scala b/src/main/scala/firrtl/transforms/GroupComponents.scala index c617e685..9e3d639d 100644 --- a/src/main/scala/firrtl/transforms/GroupComponents.scala +++ b/src/main/scala/firrtl/transforms/GroupComponents.scala @@ -61,6 +61,7 @@ class GroupComponents extends firrtl.Transform { case other => Seq(other) } val cs = state.copy(circuit = state.circuit.copy(modules = newModules)) + /* @todo move ResolveKinds and InferTypes out */ val csx = ResolveKinds.execute(InferTypes.execute(cs)) csx } diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala index 026b15fc..72724b27 100644 --- a/src/main/scala/firrtl/transforms/InferResets.scala +++ b/src/main/scala/firrtl/transforms/InferResets.scala @@ -8,6 +8,7 @@ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ import firrtl.annotations.{ReferenceTarget, TargetToken} import firrtl.Utils.{toTarget, throwInternalError} +import firrtl.options.Dependency import firrtl.passes.{Pass, PassException, InferTypes} import firrtl.graph.MutableDiGraph @@ -110,8 +111,21 @@ object InferResets { */ // TODO should we error if a DefMemory is of type AsyncReset? In CheckTypes? class InferResets extends Transform { - def inputForm: CircuitForm = HighForm - def outputForm: CircuitForm = HighForm + + def inputForm: CircuitForm = UnknownForm + def outputForm: CircuitForm = UnknownForm + + override val prerequisites = + Seq( Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.Uniquify), + Dependency(passes.ResolveFlows), + Dependency[passes.InferWidths] ) ++ stage.Forms.WorkingIR + + override def invalidates(a: Transform): Boolean = a match { + case _: checks.CheckResets | passes.CheckTypes => true + case _ => false + } import InferResets._ diff --git a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala index c4f40700..617dff96 100644 --- a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala +++ b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala @@ -3,6 +3,7 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} import firrtl.PrimOps.{Bits, Head, Tail, Shr} import firrtl.Utils.{isBitExtract, isTemp} import firrtl.WrappedExpression._ @@ -91,10 +92,19 @@ object InlineBitExtractionsTransform { } /** Inline nodes that are simple bits */ -class InlineBitExtractionsTransform extends Transform { +class InlineBitExtractionsTransform extends Transform with PreservesAll[Transform] { def inputForm = UnknownForm def outputForm = UnknownForm + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic] ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty + def execute(state: CircuitState): CircuitState = { val modulesx = state.circuit.modules.map(InlineBitExtractionsTransform.onMod(_)) state.copy(circuit = state.circuit.copy(modules = modulesx)) diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala index e504eb70..91ba7578 100644 --- a/src/main/scala/firrtl/transforms/InlineCasts.scala +++ b/src/main/scala/firrtl/transforms/InlineCasts.scala @@ -3,6 +3,7 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} import firrtl.Utils.{isCast, NodeMap} @@ -59,9 +60,19 @@ object InlineCastsTransform { } /** Inline nodes that are simple casts */ -class InlineCastsTransform extends Transform { - def inputForm = LowForm - def outputForm = LowForm +class InlineCastsTransform extends Transform with PreservesAll[Transform] { + def inputForm = UnknownForm + def outputForm = UnknownForm + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform] ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty def execute(state: CircuitState): CircuitState = { val modulesx = state.circuit.modules.map(InlineCastsTransform.onMod(_)) diff --git a/src/main/scala/firrtl/transforms/LegalizeClocks.scala b/src/main/scala/firrtl/transforms/LegalizeClocks.scala index 1c2fc045..d87cd735 100644 --- a/src/main/scala/firrtl/transforms/LegalizeClocks.scala +++ b/src/main/scala/firrtl/transforms/LegalizeClocks.scala @@ -3,6 +3,7 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} import firrtl.Utils.isCast // Fixup otherwise legal Verilog that lint tools and other tools don't like @@ -58,9 +59,20 @@ object LegalizeClocksTransform { } /** Ensure Clocks to be emitted are legal Verilog */ -class LegalizeClocksTransform extends Transform { - def inputForm = LowForm - def outputForm = LowForm +class LegalizeClocksTransform extends Transform with PreservesAll[Transform] { + def inputForm = UnknownForm + def outputForm = UnknownForm + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform], + Dependency[InlineCastsTransform] ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty def execute(state: CircuitState): CircuitState = { val modulesx = state.circuit.modules.map(LegalizeClocksTransform.onMod(_)) diff --git a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala index 1f0202d1..fdb0090e 100644 --- a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala +++ b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala @@ -10,6 +10,8 @@ import firrtl.ir import firrtl.passes.{Uniquify, PassException} import firrtl.Utils.v_keywords import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} + import scala.collection.mutable /** Transform that removes collisions with reserved keywords @@ -19,8 +21,8 @@ import scala.collection.mutable * @define implicitScope @param scope the enclosing scope of this name. If [[None]], then this is a [[Circuit]] name */ class RemoveKeywordCollisions(keywords: Set[String]) extends Transform { - val inputForm: CircuitForm = LowForm - val outputForm: CircuitForm = LowForm + val inputForm: CircuitForm = UnknownForm + val outputForm: CircuitForm = UnknownForm private type ModuleType = mutable.HashMap[String, ir.Type] private val inlineDelim = "_" @@ -231,4 +233,20 @@ class RemoveKeywordCollisions(keywords: Set[String]) extends Transform { } /** Transform that removes collisions with Verilog keywords */ -class VerilogRename extends RemoveKeywordCollisions(v_keywords) +class VerilogRename extends RemoveKeywordCollisions(v_keywords) with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform], + Dependency[InlineCastsTransform], + Dependency[LegalizeClocksTransform], + Dependency[FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup) ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty + +} diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index ed1baf7d..75d64b76 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -7,6 +7,7 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ import firrtl.WrappedExpression.we +import firrtl.options.Dependency import scala.collection.{immutable, mutable} @@ -14,9 +15,22 @@ import scala.collection.{immutable, mutable} * * @note This pass must run after LowerTypes */ -class RemoveReset extends Transform { - def inputForm = MidForm - def outputForm = MidForm +object RemoveReset extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + override val prerequisites = firrtl.stage.Forms.MidForm ++ + Seq( Dependency(passes.LowerTypes), + Dependency(passes.Legalize) ) + + override val optionalPrerequisites = Seq.empty + + override val dependents = Seq.empty + + override def invalidates(a: Transform): Boolean = a match { + case firrtl.passes.ResolveFlows => true + case _ => false + } private case class Reset(cond: Expression, value: Expression) diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index 825cdb60..5e6b7910 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -9,6 +9,7 @@ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ import firrtl.WrappedExpression._ import firrtl.graph.{MutableDiGraph, CyclicException} +import firrtl.options.{Dependency, PreservesAll} import scala.collection.mutable import scala.util.{Try, Success, Failure} @@ -19,10 +20,20 @@ import scala.util.{Try, Success, Failure} * wires have multiple connections that may be impossible to order in a * flow-foward way */ -class RemoveWires extends Transform { +class RemoveWires extends Transform with PreservesAll[Transform] { def inputForm = LowForm def outputForm = LowForm + override val prerequisites = firrtl.stage.Forms.MidForm ++ + Seq( Dependency(passes.LowerTypes), + Dependency(passes.Legalize), + Dependency(transforms.RemoveReset), + Dependency[transforms.CheckCombLoops] ) + + override val optionalPrerequisites = Seq(Dependency[checks.CheckResets]) + + override val dependents = Seq.empty + // Extract all expressions that are references to a Node, Wire, or Reg // Since we are operating on LowForm, they can only be WRefs private def extractNodeWireRegRefs(expr: Expression): Seq[WRef] = { @@ -140,6 +151,7 @@ class RemoveWires extends Transform { } } + /* @todo move ResolveKinds outside */ private val cleanup = Seq( passes.ResolveKinds ) diff --git a/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala index 8aa1553a..c8129450 100644 --- a/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala +++ b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala @@ -7,6 +7,7 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.PrimOps._ import firrtl.WrappedExpression._ +import firrtl.options.{Dependency, PreservesAll} import scala.collection.mutable @@ -76,9 +77,17 @@ object ReplaceTruncatingArithmetic { * @note This replaces some FIRRTL primops with ops that are not actually legal FIRRTL. They are * useful for emission to languages that support non-expanding arithmetic (like Verilog) */ -class ReplaceTruncatingArithmetic extends Transform { - def inputForm = LowForm - def outputForm = LowForm +class ReplaceTruncatingArithmetic extends Transform with PreservesAll[Transform] { + def inputForm = UnknownForm + def outputForm = UnknownForm + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals] ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty def execute(state: CircuitState): CircuitState = { val modulesx = state.circuit.modules.map(ReplaceTruncatingArithmetic.onMod(_)) diff --git a/src/main/scala/logger/Logger.scala b/src/main/scala/logger/Logger.scala index e37a45e4..e95ea643 100644 --- a/src/main/scala/logger/Logger.scala +++ b/src/main/scala/logger/Logger.scala @@ -361,7 +361,7 @@ object Logger { */ def setOptions(inputAnnotations: AnnotationSeq): Unit = { val annotations = - Seq( AddDefaults, Checks ) + Seq( new AddDefaults, Checks ) .foldLeft(inputAnnotations)((a, p) => p.transform(a)) val lopts = view[LoggerOptions](annotations) diff --git a/src/main/scala/logger/phases/AddDefaults.scala b/src/main/scala/logger/phases/AddDefaults.scala index f6daa811..3368283e 100644 --- a/src/main/scala/logger/phases/AddDefaults.scala +++ b/src/main/scala/logger/phases/AddDefaults.scala @@ -3,12 +3,15 @@ package logger.phases import firrtl.AnnotationSeq -import firrtl.options.Phase +import firrtl.options.{Phase, PreservesAll} import logger.{LoggerOption, LogLevelAnnotation} /** Add default logger [[Annotation]]s */ -private [logger] object AddDefaults extends Phase { +private [logger] class AddDefaults extends Phase with PreservesAll[Phase] { + + override val prerequisites = Seq.empty + override val dependents = Seq.empty /** Add missing default [[Logger]] [[Annotation]]s to an [[AnnotationSeq]] * @param annotations input annotations diff --git a/src/main/scala/logger/phases/Checks.scala b/src/main/scala/logger/phases/Checks.scala index c706948c..1e1ccfe6 100644 --- a/src/main/scala/logger/phases/Checks.scala +++ b/src/main/scala/logger/phases/Checks.scala @@ -4,7 +4,7 @@ package logger.phases import firrtl.AnnotationSeq import firrtl.annotations.Annotation -import firrtl.options.Phase +import firrtl.options.{Dependency, Phase, PreservesAll} import logger.{LogLevelAnnotation, LogFileAnnotation, LoggerException} @@ -12,7 +12,10 @@ import scala.collection.mutable /** Check that an [[firrtl.AnnotationSeq AnnotationSeq]] has all necessary [[firrtl.annotations.Annotation Annotation]]s * for a [[Logger]] */ -object Checks extends Phase { +object Checks extends Phase with PreservesAll[Phase] { + + override val prerequisites = Seq(Dependency[AddDefaults]) + override val dependents = Seq.empty /** Ensure that an [[firrtl.AnnotationSeq AnnotationSeq]] has necessary [[Logger]] [[firrtl.annotations.Annotation * Annotation]]s diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 03e3198e..a1c6580d 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -627,7 +627,7 @@ class JsonAnnotationTests extends AnnotationTests with BackendCompilationUtiliti override def inputForm: CircuitForm = UnknownForm override def outputForm: CircuitForm = UnknownForm - protected def execute(state: CircuitState): CircuitState = state + def execute(state: CircuitState): CircuitState = state } "annotation order" should "should be preserved" in { diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala index e03d1ab9..715e0cda 100644 --- a/src/test/scala/firrtlTests/CInferMDirSpec.scala +++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala @@ -7,7 +7,7 @@ import firrtl.ir._ import firrtl.passes._ import firrtl.transforms._ -class CInferMDir extends LowTransformSpec { +class CInferMDirSpec extends LowTransformSpec { object CInferMDirCheckPass extends Pass { // finds the memory and check its read port def checkStmt(s: Statement): Boolean = s match { diff --git a/src/test/scala/firrtlTests/CustomTransformSpec.scala b/src/test/scala/firrtlTests/CustomTransformSpec.scala index 04cbf276..809f2b1e 100644 --- a/src/test/scala/firrtlTests/CustomTransformSpec.scala +++ b/src/test/scala/firrtlTests/CustomTransformSpec.scala @@ -6,12 +6,15 @@ import firrtl.ir.Circuit import firrtl._ import firrtl.passes.Pass import firrtl.ir._ -import firrtl.stage.{FirrtlSourceAnnotation, FirrtlStage, RunFirrtlTransformAnnotation} +import firrtl.stage.{FirrtlSourceAnnotation, FirrtlStage, Forms, RunFirrtlTransformAnnotation} +import firrtl.options.Dependency +import firrtl.transforms.IdentityTransform -class CustomTransformSpec extends FirrtlFlatSpec { - behavior of "Custom Transforms" +import scala.reflect.runtime - they should "be able to introduce high firrtl" in { +object CustomTransformSpec { + + class ReplaceExtModuleTransform extends SeqTransform with FirrtlMatchers { // Simple module val delayModuleString = """ |circuit Delay : @@ -31,38 +34,99 @@ class CustomTransformSpec extends FirrtlFlatSpec { val delayModuleCircuit = parse(delayModuleString) val delayModule = delayModuleCircuit.modules.find(_.name == delayModuleCircuit.main).get - class ReplaceExtModuleTransform extends SeqTransform { - class ReplaceExtModule extends Pass { - def run(c: Circuit): Circuit = c.copy( - modules = c.modules map { - case ExtModule(_, "Delay", _, _, _) => delayModule - case other => other - } - ) - } - def transforms = Seq(new ReplaceExtModule) - def inputForm = LowForm - def outputForm = HighForm + class ReplaceExtModule extends Pass { + def run(c: Circuit): Circuit = c.copy( + modules = c.modules map { + case ExtModule(_, "Delay", _, _, _) => delayModule + case other => other + } + ) } - - runFirrtlTest("CustomTransform", "/features", customTransforms = List(new ReplaceExtModuleTransform)) + def transforms = Seq(new ReplaceExtModule) + def inputForm = LowForm + def outputForm = HighForm } - they should "not cause \"Internal Errors\"" in { - val input = """ + val input = """ |circuit test : | module test : | output out : UInt | out <= UInt(123)""".stripMargin - val errorString = "My Custom Transform failed!" - class ErroringTransform extends Transform { + val errorString = "My Custom Transform failed!" + class ErroringTransform extends Transform { + def inputForm = HighForm + def outputForm = HighForm + def execute(state: CircuitState): CircuitState = { + require(false, errorString) + state + } + } + + object MutableState { + var count: Int = 0 + } + + class FirstTransform extends Transform { + def inputForm = HighForm + def outputForm = HighForm + + def execute(state: CircuitState): CircuitState = { + require(MutableState.count == 0, s"Count was ${MutableState.count}, expected 0") + MutableState.count = 1 + state + } + } + + class SecondTransform extends Transform { + def inputForm = HighForm + def outputForm = HighForm + + def execute(state: CircuitState): CircuitState = { + require(MutableState.count == 1, s"Count was ${MutableState.count}, expected 1") + MutableState.count = 2 + state + } + } + + class ThirdTransform extends Transform { + def inputForm = HighForm + def outputForm = HighForm + + def execute(state: CircuitState): CircuitState = { + require(MutableState.count == 2, s"Count was ${MutableState.count}, expected 2") + MutableState.count = 3 + state + } + } + + class IdentityLowForm extends IdentityTransform(LowForm) { + override val name = ">>>>> IdentityLowForm <<<<<" + } + + object Foo { + class A extends Transform { def inputForm = HighForm def outputForm = HighForm - def execute(state: CircuitState): CircuitState = { - require(false, errorString) - state + def execute(s: CircuitState) = { + assert(name.endsWith("A")) + s } } + } + +} + +class CustomTransformSpec extends FirrtlFlatSpec { + + import CustomTransformSpec._ + + behavior of "Custom Transforms" + + they should "be able to introduce high firrtl" in { + runFirrtlTest("CustomTransform", "/features", customTransforms = List(new ReplaceExtModuleTransform)) + } + + they should "not cause \"Internal Errors\"" in { val optionsManager = new ExecutionOptionsManager("test") with HasFirrtlOptions { firrtlOptions = FirrtlExecutionOptions( firrtlSource = Some(input), @@ -73,15 +137,38 @@ class CustomTransformSpec extends FirrtlFlatSpec { }).getMessage should include (errorString) } - object Foo { - class A extends Transform { - def inputForm = HighForm - def outputForm = HighForm - def execute(s: CircuitState) = { - assert(name.endsWith("A")) - s - } + they should "preserve the input order" in { + runFirrtlTest("CustomTransform", "/features", customTransforms = List( + new FirstTransform, + new SecondTransform, + new ThirdTransform, + new ReplaceExtModuleTransform + )) + } + + they should "run right before the emitter when inputForm=LowForm" in { + + val custom = Dependency[IdentityLowForm] + + def testOrder(emitter: Dependency[Emitter], preceders: Seq[Dependency[Transform]]): Unit = { + info(s"""${preceders.map(_.getSimpleName).mkString(" -> ")} -> ${custom.getSimpleName} -> ${emitter.getSimpleName} ok!""") + + val compiler = new firrtl.stage.transforms.Compiler(Seq(custom, emitter)) + info("Transform Order: \n" + compiler.prettyPrint(" ")) + + val expectedSlice = preceders ++ Seq(custom, emitter) + + compiler + .flattenedTransformOrder + .map(Dependency.fromTransform(_)) + .containsSlice(expectedSlice) should be (true) } + + Seq( (Dependency[LowFirrtlEmitter], Seq(Forms.LowForm.last) ), + (Dependency[MinimumVerilogEmitter], Seq(Forms.LowFormMinimumOptimized.last) ), + (Dependency[VerilogEmitter], Seq(Forms.LowFormOptimized.last) ), + (Dependency[SystemVerilogEmitter], Seq(Forms.LowFormOptimized.last) ) + ).foreach((testOrder _).tupled) } they should "work if placed inside an object" in { diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala new file mode 100644 index 00000000..f3183599 --- /dev/null +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -0,0 +1,349 @@ +// See LICENSE for license details. + +package firrtlTests + +import org.scalatest.{FlatSpec, Matchers} + +import firrtl._ +import firrtl.passes +import firrtl.options.Dependency +import firrtl.stage.{Forms, TransformManager} +import firrtl.transforms.IdentityTransform + +sealed trait PatchAction { val line: Int } + +case class Add(line: Int, transforms: Seq[Dependency[Transform]]) extends PatchAction +case class Del(line: Int) extends PatchAction + +object Transforms { + class IdentityTransformDiff(val inputForm: CircuitForm, val outputForm: CircuitForm) extends Transform { + override def execute(state: CircuitState): CircuitState = state + override def name: String = s">>>>> $inputForm -> $outputForm <<<<<" + } + import firrtl.{ChirrtlForm => C, HighForm => H, MidForm => M, LowForm => L, UnknownForm => U} + class ChirrtlToChirrtl extends IdentityTransformDiff(C, C) + class HighToChirrtl extends IdentityTransformDiff(H, C) + class HighToHigh extends IdentityTransformDiff(H, H) + class MidToMid extends IdentityTransformDiff(M, M) + class MidToChirrtl extends IdentityTransformDiff(M, C) + class MidToHigh extends IdentityTransformDiff(M, H) + class LowToChirrtl extends IdentityTransformDiff(L, C) + class LowToHigh extends IdentityTransformDiff(L, H) + class LowToMid extends IdentityTransformDiff(L, M) + class LowToLow extends IdentityTransformDiff(L, L) +} + +class LoweringCompilersSpec extends FlatSpec with Matchers { + + def legacyTransforms(a: CoreTransform): Seq[Transform] = a match { + case _: ChirrtlToHighFirrtl => Seq( + passes.CheckChirrtl, + passes.CInferTypes, + passes.CInferMDir, + passes.RemoveCHIRRTL) + case _: IRToWorkingIR => Seq(passes.ToWorkingIR) + case _: ResolveAndCheck => Seq( + passes.CheckHighForm, + passes.ResolveKinds, + passes.InferTypes, + passes.CheckTypes, + passes.Uniquify, + passes.ResolveKinds, + passes.InferTypes, + passes.ResolveFlows, + passes.CheckFlows, + new passes.InferBinaryPoints, + new passes.TrimIntervals, + new passes.InferWidths, + passes.CheckWidths, + new firrtl.transforms.InferResets) + case _: HighFirrtlToMiddleFirrtl => Seq( + passes.PullMuxes, + passes.ReplaceAccesses, + passes.ExpandConnects, + passes.RemoveAccesses, + passes.Uniquify, + passes.ExpandWhens, + passes.CheckInitialization, + passes.ResolveKinds, + passes.InferTypes, + passes.CheckTypes, + passes.ResolveFlows, + new passes.InferWidths, + passes.CheckWidths, + new passes.RemoveIntervals, + passes.ConvertFixedToSInt, + passes.ZeroWidth, + passes.InferTypes) + case _: MiddleFirrtlToLowFirrtl => Seq( + passes.LowerTypes, + passes.ResolveKinds, + passes.InferTypes, + passes.ResolveFlows, + new passes.InferWidths, + passes.Legalize, + firrtl.transforms.RemoveReset, + passes.ResolveFlows, + new firrtl.transforms.CheckCombLoops, + new checks.CheckResets, + new firrtl.transforms.RemoveWires) + case _: LowFirrtlOptimization => Seq( + passes.RemoveValidIf, + new firrtl.transforms.ConstantPropagation, + passes.PadWidths, + new firrtl.transforms.ConstantPropagation, + passes.Legalize, + passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter + new firrtl.transforms.ConstantPropagation, + passes.SplitExpressions, + new firrtl.transforms.CombineCats, + passes.CommonSubexpressionElimination, + new firrtl.transforms.DeadCodeElimination) + case _: MinimumLowFirrtlOptimization => Seq( + passes.RemoveValidIf, + passes.Legalize, + passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter + passes.SplitExpressions) + } + + def compare(a: Seq[Transform], b: TransformManager, patches: Seq[PatchAction] = Seq.empty): Unit = { + info(s"""Transform Order:\n${b.prettyPrint(" ")}""") + + val m = new scala.collection.mutable.HashMap[Int, Seq[Dependency[Transform]]].withDefault(_ => Seq.empty) + a.map(Dependency.fromTransform).zipWithIndex.foreach{ case (t, idx) => m(idx) = Seq(t) } + + patches.foreach { + case Add(line, txs) => m(line - 1) = m(line - 1) ++ txs + case Del(line) => m.remove(line - 1) + } + + val patched = scala.collection.immutable.TreeMap(m.toArray:_*).values.flatten + + patched + .zip(b.flattenedTransformOrder.map(Dependency.fromTransform)) + .foreach{ case (aa, bb) => bb should be (aa) } + + info(s"found ${b.flattenedTransformOrder.size} transforms") + patched.size should be (b.flattenedTransformOrder.size) + } + + behavior of "ChirrtlToHighFirrtl" + + it should "replicate the old order" in { + val tm = new TransformManager(Forms.MinimalHighForm, Forms.ChirrtlForm) + compare(legacyTransforms(new firrtl.ChirrtlToHighFirrtl), tm) + } + + behavior of "IRToWorkingIR" + + it should "replicate the old order" in { + val tm = new TransformManager(Forms.WorkingIR, Forms.MinimalHighForm) + compare(legacyTransforms(new firrtl.IRToWorkingIR), tm) + } + + behavior of "ResolveAndCheck" + + it should "replicate the old order" in { + val tm = new TransformManager(Forms.Resolved, Forms.WorkingIR) + val patches = Seq( + Add(14, Seq(Dependency.fromTransform(firrtl.passes.CheckTypes))) + ) + compare(legacyTransforms(new ResolveAndCheck), tm, patches) + } + + behavior of "HighFirrtlToMiddleFirrtl" + + it should "replicate the old order" in { + val tm = new TransformManager(Forms.MidForm, Forms.Deduped) + val patches = Seq( + Add(5, Seq(Dependency(firrtl.passes.ResolveKinds), + Dependency(firrtl.passes.InferTypes))), + Del(6), + Del(7), + Add(6, Seq(Dependency[firrtl.passes.ExpandWhensAndCheck])), + Del(10), + Del(11), + Del(12), + Add(11, Seq(Dependency(firrtl.passes.ResolveFlows), + Dependency[firrtl.passes.InferWidths])), + Del(13) + ) + compare(legacyTransforms(new HighFirrtlToMiddleFirrtl), tm, patches) + } + + behavior of "MiddleFirrtlToLowFirrtl" + + it should "replicate the old order" in { + val tm = new TransformManager(Forms.LowForm, Forms.MidForm) + compare(legacyTransforms(new MiddleFirrtlToLowFirrtl), tm) + } + + behavior of "MinimumLowFirrtlOptimization" + + it should "replicate the old order" in { + val tm = new TransformManager(Forms.LowFormMinimumOptimized, Forms.LowForm) + compare(legacyTransforms(new MinimumLowFirrtlOptimization), tm) + } + + behavior of "LowFirrtlOptimization" + + it should "replicate the old order" in { + val tm = new TransformManager(Forms.LowFormOptimized, Forms.LowForm) + val patches = Seq( + Add(7, Seq(Dependency(firrtl.passes.Legalize))) + ) + compare(legacyTransforms(new LowFirrtlOptimization), tm, patches) + } + + behavior of "VerilogMinimumOptimized" + + it should "replicate the old order" in { + val legacy = Seq( + new firrtl.transforms.BlackBoxSourceHelper, + new firrtl.transforms.FixAddingNegativeLiterals, + new firrtl.transforms.ReplaceTruncatingArithmetic, + new firrtl.transforms.InlineBitExtractionsTransform, + new firrtl.transforms.InlineCastsTransform, + new firrtl.transforms.LegalizeClocksTransform, + new firrtl.transforms.FlattenRegUpdate, + firrtl.passes.VerilogModulusCleanup, + new firrtl.transforms.VerilogRename, + firrtl.passes.VerilogPrep, + new firrtl.AddDescriptionNodes) + val tm = new TransformManager(Forms.VerilogMinimumOptimized, (new firrtl.VerilogEmitter).prerequisites) + compare(legacy, tm) + } + + behavior of "VerilogOptimized" + + it should "replicate the old order" in { + val legacy = Seq( + new firrtl.transforms.BlackBoxSourceHelper, + new firrtl.transforms.FixAddingNegativeLiterals, + new firrtl.transforms.ReplaceTruncatingArithmetic, + new firrtl.transforms.InlineBitExtractionsTransform, + new firrtl.transforms.InlineCastsTransform, + new firrtl.transforms.LegalizeClocksTransform, + new firrtl.transforms.FlattenRegUpdate, + new firrtl.transforms.DeadCodeElimination, + firrtl.passes.VerilogModulusCleanup, + new firrtl.transforms.VerilogRename, + firrtl.passes.VerilogPrep, + new firrtl.AddDescriptionNodes) + val tm = new TransformManager(Forms.VerilogOptimized, Forms.LowFormOptimized) + compare(legacy, tm) + } + + behavior of "Legacy Custom Transforms" + + it should "work for Chirrtl -> Chirrtl" in { + val expected = new Transforms.ChirrtlToChirrtl :: new firrtl.ChirrtlEmitter :: Nil + val tm = new TransformManager(Dependency[firrtl.ChirrtlEmitter] :: Dependency[Transforms.ChirrtlToChirrtl] :: Nil) + compare(expected, tm) + } + + it should "work for High -> High" in { + val expected = + new TransformManager(Forms.HighForm).flattenedTransformOrder ++ + Some(new Transforms.HighToHigh) ++ + (new TransformManager(Forms.MidForm, Forms.HighForm).flattenedTransformOrder) + val tm = new TransformManager(Forms.MidForm :+ Dependency[Transforms.HighToHigh]) + compare(expected, tm) + } + + it should "work for High -> Chirrtl" in { + val expected = + new TransformManager(Forms.HighForm).flattenedTransformOrder ++ + Some(new Transforms.HighToChirrtl) ++ + (new TransformManager(Forms.HighForm, Forms.ChirrtlForm).flattenedTransformOrder) + val tm = new TransformManager(Forms.HighForm :+ Dependency[Transforms.HighToChirrtl]) + compare(expected, tm) + } + + it should "work for Mid -> Mid" in { + val expected = + new TransformManager(Forms.MidForm).flattenedTransformOrder ++ + Some(new Transforms.MidToMid) ++ + (new TransformManager(Forms.LowForm, Forms.MidForm).flattenedTransformOrder) + val tm = new TransformManager(Forms.LowForm :+ Dependency[Transforms.MidToMid]) + compare(expected, tm) + } + + it should "work for Mid -> High" in { + val expected = + new TransformManager(Forms.MidForm).flattenedTransformOrder ++ + Some(new Transforms.MidToHigh) ++ + (new TransformManager(Forms.LowForm, Forms.MinimalHighForm).flattenedTransformOrder) + val tm = new TransformManager(Forms.LowForm :+ Dependency[Transforms.MidToHigh]) + compare(expected, tm) + } + + it should "work for Mid -> Chirrtl" in { + val expected = + new TransformManager(Forms.MidForm).flattenedTransformOrder ++ + Some(new Transforms.MidToChirrtl) ++ + (new TransformManager(Forms.LowForm, Forms.ChirrtlForm).flattenedTransformOrder) + val tm = new TransformManager(Forms.LowForm :+ Dependency[Transforms.MidToChirrtl]) + compare(expected, tm) + } + + it should "work for Low -> Low" in { + val expected = + new TransformManager(Forms.LowFormOptimized).flattenedTransformOrder ++ + Seq(new Transforms.LowToLow) + val tm = new TransformManager(Forms.LowFormOptimized :+ Dependency[Transforms.LowToLow]) + compare(expected, tm) + } + + it should "work for Low -> Mid" in { + val expected = + new TransformManager(Forms.LowFormOptimized).flattenedTransformOrder ++ + Seq(new Transforms.LowToMid) ++ + (new TransformManager(Forms.LowFormOptimized, Forms.MidForm).flattenedTransformOrder) + val tm = new TransformManager(Forms.LowFormOptimized :+ Dependency[Transforms.LowToMid]) + compare(expected, tm) + } + + it should "work for Low -> High" in { + val expected = + new TransformManager(Forms.LowFormOptimized).flattenedTransformOrder ++ + Seq(new Transforms.LowToHigh) ++ + (new TransformManager(Forms.LowFormOptimized, Forms.MinimalHighForm).flattenedTransformOrder) + val tm = new TransformManager(Forms.LowFormOptimized :+ Dependency[Transforms.LowToHigh]) + compare(expected, tm) + } + + it should "work for Low -> Chirrtl" in { + val expected = + new TransformManager(Forms.LowFormOptimized).flattenedTransformOrder ++ + Seq(new Transforms.LowToChirrtl) ++ + (new TransformManager(Forms.LowFormOptimized, Forms.ChirrtlForm).flattenedTransformOrder) + val tm = new TransformManager(Forms.LowFormOptimized :+ Dependency[Transforms.LowToChirrtl]) + compare(expected, tm) + } + + it should "schedule inputForm=LowForm after MiddleFirrtlToLowFirrtl for the LowFirrtlEmitter" in { + val expected = + new TransformManager(Forms.LowForm).flattenedTransformOrder ++ + Seq(new Transforms.LowToLow, new firrtl.LowFirrtlEmitter) + val tm = (new TransformManager(Seq(Dependency[firrtl.LowFirrtlEmitter], Dependency[Transforms.LowToLow]))) + compare(expected, tm) + } + + it should "schedule inputForm=LowForm after MinimumLowFirrtlOptimizations for the MinimalVerilogEmitter" in { + val expected = + new TransformManager(Forms.LowFormMinimumOptimized).flattenedTransformOrder ++ + Seq(new Transforms.LowToLow, new firrtl.MinimumVerilogEmitter) + val tm = (new TransformManager(Seq(Dependency[firrtl.MinimumVerilogEmitter], Dependency[Transforms.LowToLow]))) + compare(expected, tm) + } + + it should "schedule inputForm=LowForm after LowFirrtlOptimizations for the VerilogEmitter" in { + val expected = + new TransformManager(Forms.LowFormOptimized).flattenedTransformOrder ++ + Seq(new Transforms.LowToLow, new firrtl.VerilogEmitter) + val tm = (new TransformManager(Seq(Dependency[firrtl.VerilogEmitter], Dependency[Transforms.LowToLow]))) + compare(expected, tm) + } + +} diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala index c5d0eacc..825d706f 100644 --- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala +++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala @@ -13,12 +13,12 @@ class DoPrimVerilog extends FirrtlFlatSpec { "Xorr" should "emit correctly" in { val compiler = new VerilogCompiler val input = - """circuit Xorr : - | module Xorr : + """circuit Xorr : + | module Xorr : | input a: UInt<4> | output b: UInt<1> | b <= xorr(a)""".stripMargin - val check = + val check = """module Xorr( | input [3:0] a, | output b @@ -31,12 +31,12 @@ class DoPrimVerilog extends FirrtlFlatSpec { "Andr" should "emit correctly" in { val compiler = new VerilogCompiler val input = - """circuit Andr : - | module Andr : + """circuit Andr : + | module Andr : | input a: UInt<4> | output b: UInt<1> | b <= andr(a)""".stripMargin - val check = + val check = """module Andr( | input [3:0] a, | output b @@ -49,12 +49,12 @@ class DoPrimVerilog extends FirrtlFlatSpec { "Orr" should "emit correctly" in { val compiler = new VerilogCompiler val input = - """circuit Orr : - | module Orr : + """circuit Orr : + | module Orr : | input a: UInt<4> | output b: UInt<1> | b <= orr(a)""".stripMargin - val check = + val check = """module Orr( | input [3:0] a, | output b @@ -187,8 +187,8 @@ class DoPrimVerilog extends FirrtlFlatSpec { |""".stripMargin val check = """module Test( - | input [7:0] in, - | output out + | input [7:0] in, + | output out |); | wire [7:0] _GEN_0; | assign out = _GEN_0[0]; diff --git a/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala b/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala index eb061d8f..ea4127bc 100644 --- a/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala @@ -5,24 +5,25 @@ package annotationTests import firrtlTests._ import firrtl._ +import firrtl.annotations.{Annotation, NoTargetAnnotation} + +case object FoundTargetDirTransformRanAnnotation extends NoTargetAnnotation +case object FoundTargetDirTransformFoundTargetDirAnnotation extends NoTargetAnnotation /** Looks for [[TargetDirAnnotation]] */ -class FindTargetDirTransform(expected: String) extends Transform { +class FindTargetDirTransform extends Transform { def inputForm = HighForm def outputForm = HighForm - var foundTargetDir = false - var run = false + def execute(state: CircuitState): CircuitState = { - run = true - state.annotations.collectFirst { - case TargetDirAnnotation(expected) => - foundTargetDir = true - } - state + val a: Option[Annotation] = state.annotations.collectFirst { + case TargetDirAnnotation("a/b/c") => FoundTargetDirTransformFoundTargetDirAnnotation } + state.copy(annotations = state.annotations ++ a ++ Some(FoundTargetDirTransformRanAnnotation)) } } class TargetDirAnnotationSpec extends FirrtlFlatSpec { + behavior of "The target directory" val input = @@ -35,7 +36,7 @@ class TargetDirAnnotationSpec extends FirrtlFlatSpec { val targetDir = "a/b/c" it should "be available as an annotation when using execution options" in { - val findTargetDir = new FindTargetDirTransform(targetDir) // looks for the annotation + val findTargetDir = new FindTargetDirTransform // looks for the annotation val optionsManager = new ExecutionOptionsManager("TargetDir") with HasFirrtlOptions { commonOptions = commonOptions.copy(targetDirName = targetDir, @@ -44,11 +45,13 @@ class TargetDirAnnotationSpec extends FirrtlFlatSpec { firrtlSource = Some(input), customTransforms = Seq(findTargetDir)) } - Driver.execute(optionsManager) + val annotations: Seq[Annotation] = Driver.execute(optionsManager) match { + case a: FirrtlExecutionSuccess => a.circuitState.annotations + case _ => fail + } - // Check that FindTargetDirTransform transform is run and finds the annotation - findTargetDir.run should be (true) - findTargetDir.foundTargetDir should be (true) + annotations should contain (FoundTargetDirTransformRanAnnotation) + annotations should contain (FoundTargetDirTransformFoundTargetDirAnnotation) // Delete created directory val dir = new java.io.File(targetDir) @@ -57,13 +60,16 @@ class TargetDirAnnotationSpec extends FirrtlFlatSpec { } it should "NOT be available as an annotation when using a raw compiler" in { - val findTargetDir = new FindTargetDirTransform(targetDir) // looks for the annotation + val findTargetDir = new FindTargetDirTransform // looks for the annotation val compiler = new VerilogCompiler val circuit = Parser.parse(input split "\n") - compiler.compileAndEmit(CircuitState(circuit, HighForm), Seq(findTargetDir)) + + val annotations: Seq[Annotation] = compiler + .compileAndEmit(CircuitState(circuit, HighForm), Seq(findTargetDir)) + .annotations // Check that FindTargetDirTransform does not find the annotation - findTargetDir.run should be (true) - findTargetDir.foundTargetDir should be (false) + annotations should contain (FoundTargetDirTransformRanAnnotation) + annotations should not contain (FoundTargetDirTransformFoundTargetDirAnnotation) } } diff --git a/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala b/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala index a8176316..f2620051 100644 --- a/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala @@ -4,9 +4,11 @@ package firrtlTests.stage.phases import org.scalatest.{FlatSpec, Matchers} +import scala.collection.mutable + import firrtl.{Compiler => _, _} -import firrtl.options.Phase -import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, RunFirrtlTransformAnnotation} +import firrtl.options.{Phase, PreservesAll} +import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, Forms, RunFirrtlTransformAnnotation} import firrtl.stage.phases.Compiler class CompilerSpec extends FlatSpec with Matchers { @@ -40,7 +42,7 @@ class CompilerSpec extends FlatSpec with Matchers { val expected = Seq(FirrtlCircuitAnnotation(circuitOut)) - phase.transform(input).toSeq should be (expected) + phase.transform(input).collect{ case a: FirrtlCircuitAnnotation => a }.toSeq should be (expected) } it should "compile multiple FirrtlCircuitAnnotations" in new Fixture { @@ -140,4 +142,37 @@ class CompilerSpec extends FlatSpec with Matchers { .size should be (20) } + it should "run transforms in sequential order" in new Fixture { + import CompilerSpec.{FirstTransform, SecondTransform} + + val circuitIn = Parser.parse(chirrtl("top")) + val annotations = + Seq( FirrtlCircuitAnnotation(circuitIn), + CompilerAnnotation(new VerilogCompiler), + RunFirrtlTransformAnnotation(new FirstTransform), + RunFirrtlTransformAnnotation(new SecondTransform) ) + phase.transform(annotations) + + CompilerSpec.globalState.toSeq should be (Seq(classOf[FirstTransform], classOf[SecondTransform])) + } + +} + +object CompilerSpec { + + private[CompilerSpec] val globalState: mutable.Queue[Class[_ <: Transform]] = mutable.Queue.empty[Class[_ <: Transform]] + + class LoggingTransform extends Transform with PreservesAll[Transform] { + override def inputForm = UnknownForm + override def outputForm = UnknownForm + override def prerequisites = Forms.HighForm + def execute(c: CircuitState): CircuitState = { + globalState += this.getClass + c + } + } + + class FirstTransform extends LoggingTransform + class SecondTransform extends LoggingTransform + } diff --git a/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala b/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala index 213ead77..feba5a24 100644 --- a/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala +++ b/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala @@ -100,11 +100,6 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.defaultFileListName}").exists should be (true) } - "verilog compiler" should "have BlackBoxSourceHelper transform" in { - val verilogCompiler = new VerilogEmitter - verilogCompiler.transforms.map { x => x.getClass } should contain (classOf[BlackBoxSourceHelper]) - } - "verilog header files" should "be available but not mentioned in the file list" in { // Issue #917 - We don't want to list Verilog header files ("*.vh") in our file list. // We don't actually verify that the generated verilog code works, |
