diff options
| author | Adam Izraelevitz | 2017-03-23 16:16:24 -0700 |
|---|---|---|
| committer | GitHub | 2017-03-23 16:16:24 -0700 |
| commit | 67eb4e2de6166b8f1eb5190215640117b82e8c48 (patch) | |
| tree | 18cbaf901eff58262d833bf5bc0d75262c9ab57d | |
| parent | 4cffd184397905eeb79e2df0913b4ded97dc8558 (diff) | |
Pass now subclasses Transform (#477)
49 files changed, 165 insertions, 205 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index b0d42332..ea801086 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -110,61 +110,117 @@ final case object MidForm extends CircuitForm(1) * - All implicit truncations must be made explicit */ final case object LowForm extends CircuitForm(0) +/** Unknown Form + * + * Often passes may modify a circuit (e.g. InferTypes), but return + * a circuit in the same form it was given. + * + * For this use case, use UnknownForm. It cannot be compared against other + * forms. + * + * TODO(azidar): Replace with PreviousForm, which more explicitly encodes + * this requirement. + */ +final case object UnknownForm extends CircuitForm(-1) { + override def compare(that: CircuitForm): Int = { error("Illegal to compare UnknownForm"); 0 } +} /** The basic unit of operating on a Firrtl AST */ -abstract class Transform { +abstract class Transform extends LazyLogging { /** A convenience function useful for debugging and error messages */ def name: String = this.getClass.getSimpleName /** The [[CircuitForm]] that this transform requires to operate on */ def inputForm: CircuitForm /** The [[CircuitForm]] that this transform outputs */ def outputForm: CircuitForm - /** Perform the transform + /** Perform the transform, encode renaming with RenameMap, and can + * delete annotations + * Called by [[runTransform]]. * * @param state Input Firrtl AST * @return A transformed Firrtl AST */ - def execute(state: CircuitState): CircuitState + protected def execute(state: CircuitState): CircuitState /** Convenience method to get annotations relevant to this Transform * * @param state The [[CircuitState]] form which to extract annotations * @return A collection of annotations */ final def getMyAnnotations(state: CircuitState): Seq[Annotation] = state.annotations match { - case Some(annotations) => annotations.get(this.getClass) + case Some(annotations) => annotations.get(this.getClass) //TODO(azidar): ++ annotations.get(classOf[Transform]) case None => Nil } -} + /** Perform the transform and update annotations. + * + * @param state Input Firrtl AST + * @return A transformed Firrtl AST + */ + final def runTransform(state: CircuitState): CircuitState = { + logger.info(s"======== Starting Transform $name ========") + + val (timeMillis, result) = Utils.time { execute(state) } + + logger.info(s"""----------------------------${"-" * name.size}---------\n""") + logger.info(f"Time: $timeMillis%.1f ms") + + val remappedAnnotations = propagateAnnotations(state.annotations, result.annotations, result.renames) + + logger.info(s"Form: ${result.form}") + logger.debug(s"Annotations:") + remappedAnnotations.foreach { a => + logger.debug(a.serialize) + } + logger.debug(s"Circuit:\n${result.circuit.serialize}") + logger.info(s"======== Finished Transform $name ========\n") + + CircuitState(result.circuit, result.form, Some(AnnotationMap(remappedAnnotations)), None) + } -trait SimpleRun extends LazyLogging { - def runPasses(circuit: Circuit, passSeq: Seq[Pass]): Circuit = - passSeq.foldLeft(circuit) { (c: Circuit, pass: Pass) => - val name = pass.name - logger.info(s"-------- Starting Pass $name --------") - val (timeMillis, x) = Utils.time { pass.run(c) } - logger.info(f"Time: $timeMillis%.1f ms") - logger.debug(s"Circuit:\n${c.serialize}") - logger.info(s"-------- Finished Pass $name --------") - x + /** Propagate annotations and update their names. + * + * @param inAnno input AnnotationMap + * @param resAnno result AnnotationMap + * @param renameOpt result RenameMap + * @return the updated annotations + */ + final private def propagateAnnotations( + inAnno: Option[AnnotationMap], + resAnno: Option[AnnotationMap], + renameOpt: Option[RenameMap]): Seq[Annotation] = { + val newAnnotations = { + val inSet = inAnno.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet + val resSet = resAnno.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet + val deleted = (inSet -- resSet).map { + case DeletedAnnotation(xFormName, delAnno) => DeletedAnnotation(s"$xFormName+$name", delAnno) + case anno => DeletedAnnotation(name, anno) + } + val created = resSet -- inSet + val unchanged = resSet & inSet + (deleted ++ created ++ unchanged) } + + // For each annotation, rename all annotations. + val renames = renameOpt.getOrElse(RenameMap()).map + for { + anno <- newAnnotations.toSeq + newAnno <- anno.update(renames.getOrElse(anno.target, Seq(anno.target))) + } yield newAnno + } } -/** For PassBased Transforms and Emitters - * - * @note passSeq accepts no arguments - * @todo make passes accept CircuitState so annotations can pass data between them - */ -trait PassBased extends SimpleRun { - def passSeq: Seq[Pass] - def runPasses(circuit: Circuit): Circuit = runPasses(circuit, passSeq) +trait SeqTransformBased { + def transforms: Seq[Transform] + protected def runTransforms(state: CircuitState): CircuitState = + transforms.foldLeft(state) { (in, xform) => xform.runTransform(in) } } -/** For transformations that are simply a sequence of passes */ -abstract class PassBasedTransform extends Transform with PassBased { +/** For transformations that are simply a sequence of transforms */ +abstract class SeqTransform extends Transform with SeqTransformBased { def execute(state: CircuitState): CircuitState = { require(state.form <= inputForm, s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") - CircuitState(runPasses(state.circuit), outputForm, state.annotations) + val ret = runTransforms(state) + CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames) } } @@ -174,7 +230,7 @@ trait Emitter extends Transform { def emit(state: CircuitState, writer: Writer): Unit } -object CompilerUtils { +object CompilerUtils extends LazyLogging { /** Generates a sequence of [[Transform]]s to lower a Firrtl circuit * * @param inputForm [[CircuitForm]] to lower from @@ -194,6 +250,7 @@ object CompilerUtils { new HighFirrtlToMiddleFirrtl) ++ getLoweringTransforms(MidForm, outputForm) case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm) case LowForm => throwInternalError // should be caught by if above + case UnknownForm => throwInternalError // should be caught by if above } } } @@ -318,42 +375,9 @@ trait Compiler extends LazyLogging { def compile(state: CircuitState, customTransforms: Seq[Transform]): CircuitState = { val allTransforms = CompilerUtils.mergeTransforms(transforms, customTransforms) :+ emitter - val finalState = allTransforms.foldLeft(state) { (in, xform) => - logger.info(s"======== Starting Transform ${xform.name} ========") - val (timeMillis, result) = Utils.time { xform.execute(in) } - - logger.info(f"Time: $timeMillis%.1f ms") - - val newAnnotations = { - val inSet = in.annotations.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet - val resSet = result.annotations.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet - val deleted = (inSet -- resSet).map { - case DeletedAnnotation(xFormName, delAnno) => DeletedAnnotation(s"${xFormName}+${xform.name}", delAnno) - case anno => DeletedAnnotation(xform.name, anno) - } - val created = resSet -- inSet - val unchanged = resSet & inSet - (deleted ++ created ++ unchanged) - } - - // For each annotation, rename all annotations. - val renames = result.renames.getOrElse(RenameMap()).map - val remappedAnnotations: Seq[Annotation] = for { - anno <- newAnnotations.toSeq - newAnno <- anno.update(renames.getOrElse(anno.target, Seq(anno.target))) - } yield newAnno - - logger.info(s"Form: ${result.form}") - logger.debug(s"Annotations:") - remappedAnnotations.foreach { a => - logger.debug(a.serialize) - } - logger.debug(s"Circuit:\n${result.circuit.serialize}") - logger.info(s"======== Finished Transform ${xform.name} ========\n") - - CircuitState(result.circuit, result.form, Some(AnnotationMap(remappedAnnotations))) - } + val finalState = allTransforms.foldLeft(state) { (in, xform) => xform.runTransform(in) } finalState } + } diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 1153b1e6..933d98c5 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -139,7 +139,7 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em } } - def execute(state: CircuitState): CircuitState = { + override def execute(state: CircuitState): CircuitState = { val newAnnos = getMyAnnotations(state).flatMap { case EmitCircuitAnnotation() => Seq(EmittedFirrtlCircuitAnnotation.apply( @@ -161,8 +161,8 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em // ***** Start actual Emitters ***** class HighFirrtlEmitter extends FirrtlEmitter(HighForm) -class MiddleFirrtlEmitter extends FirrtlEmitter(HighForm) -class LowFirrtlEmitter extends FirrtlEmitter(HighForm) +class MiddleFirrtlEmitter extends FirrtlEmitter(MidForm) +class LowFirrtlEmitter extends FirrtlEmitter(LowForm) case class VRandom(width: BigInt) extends Expression { def tpe = UIntType(IntWidth(width)) @@ -174,10 +174,9 @@ case class VRandom(width: BigInt) extends Expression { def mapWidth(f: Width => Width): Expression = this } -class VerilogEmitter extends Transform with PassBased with Emitter { +class VerilogEmitter extends SeqTransform with Emitter { def inputForm = LowForm def outputForm = LowForm - val tab = " " def AND(e1: WrappedExpression, e2: WrappedExpression): Expression = { if (e1 == e2) e1.e1 @@ -744,7 +743,7 @@ class VerilogEmitter extends Transform with PassBased with Emitter { | |""".stripMargin - def passSeq = Seq( + def transforms = Seq( passes.VerilogModulusCleanup, passes.VerilogWrap, passes.VerilogRename, @@ -753,7 +752,7 @@ class VerilogEmitter extends Transform with PassBased with Emitter { def emit(state: CircuitState, writer: Writer): Unit = { writer.write(preamble) - val circuit = runPasses(state.circuit) + val circuit = runTransforms(state).circuit val moduleMap = circuit.modules.map(m => m.name -> m).toMap circuit.modules.foreach { case m: Module => emit_verilog(m, moduleMap)(writer) @@ -761,7 +760,7 @@ class VerilogEmitter extends Transform with PassBased with Emitter { } } - def execute(state: CircuitState): CircuitState = { + override def execute(state: CircuitState): CircuitState = { val newAnnos = getMyAnnotations(state).flatMap { case EmitCircuitAnnotation() => val writer = new java.io.StringWriter @@ -769,7 +768,7 @@ class VerilogEmitter extends Transform with PassBased with Emitter { Seq(EmittedVerilogCircuitAnnotation(EmittedVerilogCircuit(state.circuit.main, writer.toString))) case EmitAllModulesAnnotation() => - val circuit = runPasses(state.circuit) + val circuit = runTransforms(state).circuit val moduleMap = circuit.modules.map(m => m.name -> m).toMap circuit.modules flatMap { diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index c2238c2d..b5808e93 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -2,7 +2,7 @@ package firrtl -sealed abstract class CoreTransform extends PassBasedTransform +sealed abstract class CoreTransform extends SeqTransform /** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting * circuit has only IR nodes, not WIR. @@ -11,7 +11,7 @@ sealed abstract class CoreTransform extends PassBasedTransform class ChirrtlToHighFirrtl extends CoreTransform { def inputForm = ChirrtlForm def outputForm = HighForm - def passSeq = Seq( + def transforms = Seq( passes.CheckChirrtl, passes.CInferTypes, passes.CInferMDir, @@ -24,7 +24,7 @@ class ChirrtlToHighFirrtl extends CoreTransform { class IRToWorkingIR extends CoreTransform { def inputForm = HighForm def outputForm = HighForm - def passSeq = Seq(passes.ToWorkingIR) + def transforms = Seq(passes.ToWorkingIR) } /** Resolves types, kinds, and genders, and checks the circuit legality. @@ -33,7 +33,7 @@ class IRToWorkingIR extends CoreTransform { class ResolveAndCheck extends CoreTransform { def inputForm = HighForm def outputForm = HighForm - def passSeq = Seq( + def transforms = Seq( passes.CheckHighForm, passes.ResolveKinds, passes.InferTypes, @@ -55,7 +55,7 @@ class ResolveAndCheck extends CoreTransform { class HighFirrtlToMiddleFirrtl extends CoreTransform { def inputForm = HighForm def outputForm = MidForm - def passSeq = Seq( + def transforms = Seq( passes.PullMuxes, passes.ReplaceAccesses, passes.ExpandConnects, @@ -80,7 +80,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform { class MiddleFirrtlToLowFirrtl extends CoreTransform { def inputForm = MidForm def outputForm = LowForm - def passSeq = Seq( + def transforms = Seq( passes.LowerTypes, passes.ResolveKinds, passes.InferTypes, @@ -96,7 +96,7 @@ class MiddleFirrtlToLowFirrtl extends CoreTransform { class LowFirrtlOptimization extends CoreTransform { def inputForm = LowForm def outputForm = LowForm - def passSeq = Seq( + def transforms = Seq( passes.RemoveValidIf, passes.ConstProp, passes.PadWidths, diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index ef189c11..3722fd0d 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -8,7 +8,6 @@ import firrtl.Utils._ import firrtl.Mappers._ object CheckChirrtl extends Pass { - def name = "Chirrtl Check" type NameSet = collection.mutable.HashSet[String] class NotUniqueException(info: Info, mname: String, name: String) extends PassException( diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala index 84d6b448..4c392510 100644 --- a/src/main/scala/firrtl/passes/CheckInitialization.scala +++ b/src/main/scala/firrtl/passes/CheckInitialization.scala @@ -15,8 +15,6 @@ import annotation.tailrec * @note Assumes single connection (ie. no last connect semantics) */ object CheckInitialization extends Pass { - def name = "Check Initialization" - private case class VoidExpr(stmt: Statement, voidDeps: Seq[Expression]) class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) extends PassException( diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 4b0b1c0d..24735009 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -9,7 +9,6 @@ import firrtl.Mappers._ import firrtl.Utils._ object CheckWidths extends Pass { - def name = "Width Check" /** The maximum allowed width for any circuit element */ val MaxWidth = 1000000 val DshlMaxWidth = ceilLog2(MaxWidth + 1) diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index bd4c7f63..0bebcd18 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -10,7 +10,6 @@ import firrtl.Mappers._ import firrtl.WrappedType._ object CheckHighForm extends Pass { - def name = "High Form Check" type NameSet = collection.mutable.HashSet[String] // Custom Exceptions @@ -202,7 +201,6 @@ object CheckHighForm extends Pass { } object CheckTypes extends Pass { - def name = "Check Types" // Custom Exceptions class SubfieldNotInBundle(info: Info, mname: String, name: String) extends PassException( @@ -463,7 +461,6 @@ object CheckTypes extends Pass { } object CheckGenders extends Pass { - def name = "Check Genders" type GenderMap = collection.mutable.HashMap[String, Gender] implicit def toStr(g: Gender): String = g match { diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index 40d04d07..0abdaa36 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -10,8 +10,6 @@ import firrtl.Mappers._ import annotation.tailrec object CommonSubexpressionElimination extends Pass { - def name = "Common Subexpression Elimination" - private def cseOnce(s: Statement): (Statement, Long) = { var nEliminated = 0L val expressions = collection.mutable.HashMap[MemoizedHash[Expression], String]() diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index a5a8238e..8736ee31 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -11,8 +11,6 @@ import firrtl.PrimOps._ import annotation.tailrec object ConstProp extends Pass { - def name = "Constant Propagation" - private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 823fb7fb..2e151741 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -12,7 +12,6 @@ import firrtl.Utils.{sub_type, module_type, field_type, BoolType, max, min, pow_ /** Replaces FixedType with SIntType, and correctly aligns all binary points */ object ConvertFixedToSInt extends Pass { - def name = "Convert Fixed Types to SInt Types" def alignArg(e: Expression, point: BigInt): Expression = e.tpe match { case FixedType(IntWidth(w), IntWidth(p)) => // assert(point >= p) if((point - p) > 0) { diff --git a/src/main/scala/firrtl/passes/DeadCodeElimination.scala b/src/main/scala/firrtl/passes/DeadCodeElimination.scala index 6f37feae..9f249f35 100644 --- a/src/main/scala/firrtl/passes/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/passes/DeadCodeElimination.scala @@ -10,8 +10,6 @@ import firrtl.Mappers._ import annotation.tailrec object DeadCodeElimination extends Pass { - def name = "Dead Code Elimination" - private def dceOnce(s: Statement): (Statement, Long) = { val referenced = collection.mutable.HashSet[String]() var nEliminated = 0L diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index a2845f43..1f093dd1 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -25,7 +25,6 @@ import collection.immutable.ListSet * @note Assumes all references are declared */ object ExpandWhens extends Pass { - def name = "Expand Whens" type NodeMap = mutable.HashMap[MemoizedHash[Expression], String] type Netlist = mutable.LinkedHashMap[WrappedExpression, Expression] type Simlist = mutable.ArrayBuffer[Statement] diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 0e503115..2de2a76e 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -8,7 +8,6 @@ import firrtl.Utils._ import firrtl.Mappers._ object InferTypes extends Pass { - def name = "Infer Types" type TypeMap = collection.mutable.LinkedHashMap[String, Type] def run(c: Circuit): Circuit = { @@ -76,7 +75,6 @@ object InferTypes extends Pass { } object CInferTypes extends Pass { - def name = "CInfer Types" type TypeMap = collection.mutable.LinkedHashMap[String, Type] def run(c: Circuit): Circuit = { diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index f3b77ec5..11b819ce 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -12,7 +12,6 @@ import firrtl.Utils._ import firrtl.Mappers._ object InferWidths extends Pass { - def name = "Infer Widths" type ConstraintMap = collection.mutable.LinkedHashMap[String, Width] def solve_constraints(l: Seq[WGeq]): ConstraintMap = { diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index f4556733..2e15f09c 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -27,7 +27,6 @@ class InlineInstances extends Transform { def inputForm = LowForm def outputForm = LowForm val inlineDelim = "$" - override def name = "Inline Instances" private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) { diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 23518d14..5826f56e 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -21,8 +21,6 @@ import firrtl.Mappers._ * }}} */ object LowerTypes extends Pass { - def name = "Lower Types" - /** Delimiter used in lowering names */ val delim = "_" /** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 398cc6d7..c9aa1539 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -9,7 +9,6 @@ import firrtl.Mappers._ // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { - def name = "Pad Widths" private def width(t: Type): Int = bitWidth(t).toInt private def width(e: Expression): Int = width(e.tpe) // Returns an expression with the correct integer width diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index c595727e..68f278a9 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -12,9 +12,23 @@ import firrtl.PrimOps._ import scala.collection.mutable -trait Pass extends LazyLogging { - def name: String +/** [[Pass]] is simple transform that is generally part of a larger [[Transform]] + * Has an [[UnknownForm]], because larger [[Transform]] should specify form + */ +trait Pass extends Transform { + def inputForm: CircuitForm = UnknownForm + def outputForm: CircuitForm = UnknownForm def run(c: Circuit): Circuit + def execute(state: CircuitState): CircuitState = { + val result = (state.form, inputForm) match { + case (_, UnknownForm) => run(state.circuit) + case (UnknownForm, _) => run(state.circuit) + case (x, y) if x > y => + error(s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") + case _ => run(state.circuit) + } + CircuitState(result, outputForm, state.annotations, state.renames) + } } // Error handling @@ -34,8 +48,6 @@ class Errors { // These should be distributed into separate files object ToWorkingIR extends Pass { - def name = "Working IR" - def toExp(e: Expression): Expression = e map toExp match { case ex: Reference => WRef(ex.name, ex.tpe, NodeKind, UNKNOWNGENDER) case ex: SubField => WSubField(ex.expr, ex.name, ex.tpe, UNKNOWNGENDER) @@ -54,7 +66,6 @@ object ToWorkingIR extends Pass { } object PullMuxes extends Pass { - def name = "Pull Muxes" def run(c: Circuit): Circuit = { def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match { case ex: WSubField => ex.exp match { @@ -93,7 +104,6 @@ object PullMuxes extends Pass { } object ExpandConnects extends Pass { - def name = "Expand Connects" def run(c: Circuit): Circuit = { def expand_connects(m: Module): Module = { val genders = collection.mutable.LinkedHashMap[String,Gender]() @@ -171,7 +181,6 @@ object ExpandConnects extends Pass { // Replace shr by amount >= arg width with 0 for UInts and MSB for SInts // TODO replace UInt with zero-width wire instead object Legalize extends Pass { - def name = "Legalize" private def legalizeShiftRight(e: DoPrim): Expression = { require(e.op == Shr) val amount = e.consts.head.toInt @@ -244,7 +253,6 @@ object Legalize extends Pass { } object VerilogWrap extends Pass { - def name = "Verilog Wrap" def vWrapE(e: Expression): Expression = e map vWrapE match { case e: DoPrim => e.op match { case Tail => e.args.head match { @@ -271,7 +279,6 @@ object VerilogWrap extends Pass { } object VerilogRename extends Pass { - def name = "Verilog Rename" def verilogRenameN(n: String): String = if (v_keywords(n)) "%s$".format(n) else n @@ -301,7 +308,6 @@ object VerilogRename extends Pass { * @note The result of this pass is NOT legal Firrtl */ object VerilogPrep extends Pass { - def name = "Verilog Prep" type AttachSourceMap = Map[WrappedExpression, Expression] diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index a8bc9fb2..5d74d5ba 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -14,8 +14,6 @@ import scala.collection.mutable /** Removes all [[firrtl.WSubAccess]] from circuit */ object RemoveAccesses extends Pass { - def name = "Remove Accesses" - private def AND(e1: Expression, e2: Expression) = DoPrim(And, Seq(e1, e2), Nil, BoolType) diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index aae4ca80..b072dfa0 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -15,8 +15,6 @@ case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], read case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean) object RemoveCHIRRTL extends Pass { - def name = "Remove CHIRRTL" - val ut = UnknownType type MPortMap = collection.mutable.LinkedHashMap[String, MPorts] type SeqMemSet = collection.mutable.HashSet[String] diff --git a/src/main/scala/firrtl/passes/RemoveEmpty.scala b/src/main/scala/firrtl/passes/RemoveEmpty.scala index 0fdfc4d9..97c86dda 100644 --- a/src/main/scala/firrtl/passes/RemoveEmpty.scala +++ b/src/main/scala/firrtl/passes/RemoveEmpty.scala @@ -8,7 +8,6 @@ import firrtl.Mappers._ import firrtl.ir._ object RemoveEmpty extends Pass { - def name = "Remove Empty Statements" private def onModule(m: DefModule): DefModule = { m match { case m: Module => Module(m.info, m.name, m.ports, Utils.squashEmpty(m.body)) diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 7769eac2..865143a5 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -7,7 +7,6 @@ import firrtl.ir._ // Removes ValidIf as an optimization object RemoveValidIf extends Pass { - def name = "Remove ValidIfs" // Recursive. Removes ValidIf's private def onExp(e: Expression): Expression = { e map onExp match { diff --git a/src/main/scala/firrtl/passes/ReplaceAccesses.scala b/src/main/scala/firrtl/passes/ReplaceAccesses.scala index 13562717..c3a5bd4c 100644 --- a/src/main/scala/firrtl/passes/ReplaceAccesses.scala +++ b/src/main/scala/firrtl/passes/ReplaceAccesses.scala @@ -15,8 +15,6 @@ import scala.collection.mutable * TODO Fold in to High Firrtl Const Prop */ object ReplaceAccesses extends Pass { - def name = "Replace Accesses" - def run(c: Circuit): Circuit = { def onStmt(s: Statement): Statement = s map onStmt map onExp def onExp(e: Expression): Expression = e match { diff --git a/src/main/scala/firrtl/passes/Resolves.scala b/src/main/scala/firrtl/passes/Resolves.scala index e60e0478..c8ba43bf 100644 --- a/src/main/scala/firrtl/passes/Resolves.scala +++ b/src/main/scala/firrtl/passes/Resolves.scala @@ -7,7 +7,6 @@ import firrtl.ir._ import firrtl.Mappers._ object ResolveKinds extends Pass { - def name = "Resolve Kinds" type KindMap = collection.mutable.LinkedHashMap[String, Kind] def find_port(kinds: KindMap)(p: Port): Port = { @@ -46,7 +45,6 @@ object ResolveKinds extends Pass { } object ResolveGenders extends Pass { - def name = "Resolve Genders" def resolve_e(g: Gender)(e: Expression): Expression = e match { case ex: WRef => ex copy (gender = g) case WSubField(exp, name, tpe, _) => WSubField( @@ -79,7 +77,6 @@ object ResolveGenders extends Pass { } object CInferMDir extends Pass { - def name = "CInfer MDir" type MPortDirMap = collection.mutable.LinkedHashMap[String, MPortDir] def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = e match { diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 797292dc..a32f5366 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -13,7 +13,6 @@ import scala.collection.mutable // Splits compound expressions into simple expressions // and named intermediate nodes object SplitExpressions extends Pass { - def name = "Split Expressions" private def onModule(m: Module): Module = { val namespace = Namespace(m) def onStmt(s: Statement): Statement = { diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 40783e21..deddb93e 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -32,8 +32,6 @@ import MemPortUtils.memType * to rename a */ object Uniquify extends Pass { - def name = "Uniquify Identifiers" - private case class UniquifyException(msg: String) extends FIRRTLException(msg) private def error(msg: String)(implicit sinfo: Info, mname: String) = throw new UniquifyException(s"$sinfo: [module $mname] $msg") diff --git a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala index b4df534f..330ca497 100644 --- a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala +++ b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala @@ -24,7 +24,6 @@ import scala.collection.mutable * to emit correct verilog without needing to add temporary nodes */ object VerilogModulusCleanup extends Pass { - def name = "Add temporary nodes with verilog widths for modulus" private def onModule(m: Module): Module = { val namespace = Namespace(m) diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala index a50fdc16..520075fe 100644 --- a/src/main/scala/firrtl/passes/ZeroWidth.scala +++ b/src/main/scala/firrtl/passes/ZeroWidth.scala @@ -11,7 +11,6 @@ import firrtl.Utils.throwInternalError object ZeroWidth extends Pass { - def name = this.getClass.getName private val ZERO = BigInt(0) private def removeZero(t: Type): Option[Type] = t match { case GroundType(IntWidth(ZERO)) => None diff --git a/src/main/scala/firrtl/passes/clocklist/ClockList.scala b/src/main/scala/firrtl/passes/clocklist/ClockList.scala index 66139c49..bd2536ab 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockList.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockList.scala @@ -20,7 +20,6 @@ import Mappers._ * Write the result to writer. */ class ClockList(top: String, writer: Writer) extends Pass { - def name = this.getClass.getSimpleName def run(c: Circuit): Circuit = { // Build useful datastructures val childrenMap = getChildrenMap(c) diff --git a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala index feb7f42e..53787b1d 100644 --- a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala +++ b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala @@ -20,7 +20,6 @@ import Mappers._ * expressions do not relate to ground types. */ object RemoveAllButClocks extends Pass { - def name = this.getClass.getSimpleName def onStmt(s: Statement): Statement = (s map onStmt) match { case DefWire(i, n, ClockType) => s case DefNode(i, n, value) if value.tpe == ClockType => s diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala index 668bc2e5..e48dc8c2 100644 --- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala +++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala @@ -10,7 +10,6 @@ import wiring._ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform { def inputForm = MidForm def outputForm = MidForm - override def name = "Create Memory Annotations" def execute(state: CircuitState): CircuitState = reader match { case None => state case Some(r) => diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 9bd6a4ab..73fec1ee 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -29,7 +29,6 @@ object InferReadWriteAnnotation { // of any product term of the enable signal of the write port, then the readwrite // port is inferred. object InferReadWritePass extends Pass { - def name = "Infer ReadWrite Ports" type Netlist = collection.mutable.HashMap[String, Expression] type Statements = collection.mutable.ArrayBuffer[Statement] @@ -150,10 +149,10 @@ object InferReadWritePass extends Pass { // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite extends Transform with PassBased { +class InferReadWrite extends Transform with SeqTransformBased { def inputForm = MidForm def outputForm = MidForm - def passSeq = Seq( + def transforms = Seq( InferReadWritePass, CheckInitialization, InferTypes, @@ -163,6 +162,7 @@ class InferReadWrite extends Transform with PassBased { def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { case Nil => state case Seq(InferReadWriteAnnotation(CircuitName(state.circuit.main))) => - state.copy(circuit = runPasses(state.circuit)) + val ret = runTransforms(state) + CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames) } } diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala index 57c301b1..9debff7a 100644 --- a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala +++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala @@ -15,9 +15,6 @@ import MemTransformUtils._ /** Changes memory port names to standard port names (i.e. RW0 instead T_408) */ object RenameAnnotatedMemoryPorts extends Pass { - - def name = "Rename Annotated Memory Ports" - /** Renames memory ports to a standard naming scheme: * - R0, R1, ... for each read port * - W0, W1, ... for each write port diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index af6761fd..b18ed289 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -33,7 +33,6 @@ object PinAnnotation { * Creates the minimum # of black boxes needed by the design. */ class ReplaceMemMacros(writer: ConfWriter) extends Transform { - override def name = "Replace Memory Macros" def inputForm = MidForm def outputForm = MidForm @@ -227,11 +226,14 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform { case Seq(PinAnnotation(CircuitName(c), pins)) => pins case _ => throwInternalError } - val annos = pins.foldLeft(Seq[Annotation]()) { (seq, pin) => + val annos = (pins.foldLeft(Seq[Annotation]()) { (seq, pin) => seq ++ memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), pin) } - } + }) ++ (state.annotations match { + case None => Seq.empty + case Some(a) => a.annotations + }) CircuitState(c.copy(modules = modules ++ memMods), inputForm, Some(AnnotationMap(annos))) } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index 0c12d2aa..caaf430b 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -116,10 +116,10 @@ class SimpleTransform(p: Pass, form: CircuitForm) extends Transform { class SimpleMidTransform(p: Pass) extends SimpleTransform(p, MidForm) // SimpleRun instead of PassBased because of the arguments to passSeq -class ReplSeqMem extends Transform with SimpleRun { +class ReplSeqMem extends Transform { def inputForm = MidForm def outputForm = MidForm - def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = + def transforms(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = Seq(new SimpleMidTransform(Legalize), new SimpleMidTransform(ToMemIR), new SimpleMidTransform(ResolveMaskGranularity), @@ -134,31 +134,19 @@ class ReplSeqMem extends Transform with SimpleRun { new SimpleMidTransform(Uniquify), new SimpleMidTransform(ResolveKinds), new SimpleMidTransform(ResolveGenders)) - def run(state: CircuitState, xForms: Seq[Transform]): CircuitState = { - (xForms.foldLeft(state) { case (curState: CircuitState, xForm: Transform) => - val res = xForm.execute(curState) - val newAnnotations = res.annotations match { - case None => curState.annotations - case Some(ann) => - Some(AnnotationMap(ann.annotations ++ curState.annotations.get.annotations)) - } - CircuitState(res.circuit, res.form, newAnnotations) - }) - } - def execute(state: CircuitState): CircuitState = - getMyAnnotations(state) match { - case Nil => state // Do nothing if there are no annotations - case p => (p.collectFirst { case a if (a.target == CircuitName(state.circuit.main)) => a }) match { - case Some(ReplSeqMemAnnotation(target, inputFileName, outputConfig)) => - val inConfigFile = { - if (inputFileName.isEmpty) None - else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) - else error("Input configuration file does not exist!") - } - val outConfigFile = new ConfWriter(outputConfig) - run(state, passSeq(inConfigFile, outConfigFile)) - case _ => error("Unexpected transform annotation") - } + def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { + case Nil => state // Do nothing if there are no annotations + case p => (p.collectFirst { case a if (a.target == CircuitName(state.circuit.main)) => a }) match { + case Some(ReplSeqMemAnnotation(target, inputFileName, outputConfig)) => + val inConfigFile = { + if (inputFileName.isEmpty) None + else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) + else error("Input configuration file does not exist!") + } + val outConfigFile = new ConfWriter(outputConfig) + transforms(inConfigFile, outConfigFile).foldLeft(state) { (in, xform) => xform.runTransform(in) } + case _ => error("Unexpected transform annotation") } + } } diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index 956bdd3c..79ecd9cd 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -90,7 +90,6 @@ object AnalysisUtils { * TODO(shunshou): Add floorplan info? */ object ResolveMaskGranularity extends Pass { - def name = "Resolve Mask Granularity" /** Returns the number of mask bits, if used */ diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala index df555e57..e132e369 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala @@ -51,6 +51,6 @@ class ResolveMemoryReference extends Transform { case annos => annos.collect { case NoDedupMemAnnotation(ComponentName(cn, _)) => cn } } - CircuitState(run(state.circuit, noDedups), state.form) + state.copy(circuit=run(state.circuit, noDedups)) } } diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala index eb9d0859..feb6ae59 100644 --- a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala @@ -13,8 +13,6 @@ import firrtl.ir._ * - zero or one read port */ object ToMemIR extends Pass { - def name = "To Memory IR" - /** Only annotate memories that are candidates for memory macro replacements * i.e. rw, w + r (read, write 1 cycle delay) */ diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index fc126b74..6eefb69e 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -12,7 +12,6 @@ import MemPortUtils._ /** This pass generates delay reigsters for memories for verilog */ object VerilogMemDelays extends Pass { - def name = "Verilog Memory Delays" val ug = UNKNOWNGENDER type Netlist = collection.mutable.HashMap[String, Expression] implicit def expToString(e: Expression): String = e.serialize diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala index f5da4c06..9656abb2 100644 --- a/src/main/scala/firrtl/passes/wiring/Wiring.scala +++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala @@ -17,7 +17,6 @@ case class WiringException(msg: String) extends PassException(msg) case class WiringInfo(source: String, comp: String, sinks: Set[String], pin: String, top: String) class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { - def name = this.getClass.getSimpleName def run(c: Circuit): Circuit = { wiSeq.foldLeft(c) { (circuit, wi) => wire(circuit, wi) } } diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index 2c122943..a8ef5f58 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -60,10 +60,10 @@ object TopAnnotation { * Notes: * - No module uniquification occurs (due to imposed restrictions) */ -class WiringTransform extends Transform with SimpleRun { +class WiringTransform extends Transform { def inputForm = MidForm def outputForm = MidForm - def passSeq(wis: Seq[WiringInfo]) = + def transforms(wis: Seq[WiringInfo]) = Seq(new Wiring(wis), InferTypes, ResolveKinds, @@ -89,7 +89,7 @@ class WiringTransform extends Transform with SimpleRun { val wis = tops.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, top)) => seq :+ WiringInfo(sources(pin), comp(pin), sinks(pin), pin, top) } - state.copy(circuit = runPasses(state.circuit, passSeq(wis))) + transforms(wis).foldLeft(state) { (in, xform) => xform.runTransform(in) } case _ => error("Wrong number of sources, tops, or sinks!") } } diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 77f07781..81d982e4 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -17,7 +17,7 @@ import org.scalatest.Matchers */ trait AnnotationSpec extends LowTransformSpec { // Dummy transform - def transform = new CustomResolveAndCheck(LowForm) + def transform = new ResolveAndCheck // Check if Annotation Exception is thrown override def failingexecute(annotations: AnnotationMap, input: String): Exception = { diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala index 773a0bf3..0d31038a 100644 --- a/src/test/scala/firrtlTests/CInferMDirSpec.scala +++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala @@ -10,8 +10,6 @@ import annotations._ class CInferMDir extends LowTransformSpec { object CInferMDirCheckPass extends Pass { - val name = "Check Enable Signal for Chirrtl Mems" - // finds the memory and check its read port def checkStmt(s: Statement): Boolean = s match { case s: DefMemory if s.name == "indices" => @@ -38,10 +36,10 @@ class CInferMDir extends LowTransformSpec { } } - def transform = new PassBasedTransform { + def transform = new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def passSeq = Seq(ConstProp, CInferMDirCheckPass) + def transforms = Seq(ConstProp, CInferMDirCheckPass) } "Memory" should "have correct mem port directions" in { diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala index fd984661..c963c8ae 100644 --- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala @@ -10,7 +10,6 @@ import annotations._ class ChirrtlMemSpec extends LowTransformSpec { object MemEnableCheckPass extends Pass { - val name = "Check Enable Signal for Chirrtl Mems" type Netlist = collection.mutable.HashMap[String, Expression] def buildNetlist(netlist: Netlist)(s: Statement): Statement = { s match { @@ -51,10 +50,10 @@ class ChirrtlMemSpec extends LowTransformSpec { } } - def transform = new PassBasedTransform { + def transform = new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def passSeq = Seq(ConstProp, MemEnableCheckPass) + def transforms = Seq(ConstProp, MemEnableCheckPass) } "Sequential Memory" should "have correct enable signals" in { diff --git a/src/test/scala/firrtlTests/CustomTransformSpec.scala b/src/test/scala/firrtlTests/CustomTransformSpec.scala index 3a20082f..d1ff6fd1 100644 --- a/src/test/scala/firrtlTests/CustomTransformSpec.scala +++ b/src/test/scala/firrtlTests/CustomTransformSpec.scala @@ -30,9 +30,8 @@ class CustomTransformSpec extends FirrtlFlatSpec { val delayModuleCircuit = parse(delayModuleString) val delayModule = delayModuleCircuit.modules.find(_.name == delayModuleCircuit.main).get - class ReplaceExtModuleTransform extends PassBasedTransform { + class ReplaceExtModuleTransform extends SeqTransform { class ReplaceExtModule extends Pass { - def name = "Replace External Module" def run(c: Circuit): Circuit = c.copy( modules = c.modules map { case ExtModule(_, "Delay", _, _, _) => delayModule @@ -40,7 +39,7 @@ class CustomTransformSpec extends FirrtlFlatSpec { } ) } - def passSeq = Seq(new ReplaceExtModule) + def transforms = Seq(new ReplaceExtModule) def inputForm = LowForm def outputForm = HighForm } diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala index 73fdbe91..82c9d65f 100644 --- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -12,8 +12,9 @@ class InferReadWriteSpec extends SimpleTransformSpec { class InferReadWriteCheckException extends PassException( "Readwrite ports are not found!") - object InferReadWriteCheckPass extends Pass { - val name = "Check Infer ReadWrite Ports" + object InferReadWriteCheck extends Pass { + override def inputForm = MidForm + override def outputForm = MidForm def findReadWrite(s: Statement): Boolean = s match { case s: DefMemory if s.readLatency > 0 && s.readwriters.size == 1 => s.name == "mem" && s.readwriters.head == "rw" @@ -36,12 +37,6 @@ class InferReadWriteSpec extends SimpleTransformSpec { } } - class InferReadWriteCheck extends PassBasedTransform { - def inputForm = MidForm - def outputForm = MidForm - def passSeq = Seq(InferReadWriteCheckPass) - } - def emitter = new MiddleFirrtlEmitter def transforms = Seq( new ChirrtlToHighFirrtl, @@ -49,7 +44,7 @@ class InferReadWriteSpec extends SimpleTransformSpec { new ResolveAndCheck, new HighFirrtlToMiddleFirrtl, new memlib.InferReadWrite, - new InferReadWriteCheck + InferReadWriteCheck ) "Infer ReadWrite Ports" should "infer readwrite ports for the same clock" in { diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala index 589dfd38..e22fd513 100644 --- a/src/test/scala/firrtlTests/PassTests.scala +++ b/src/test/scala/firrtlTests/PassTests.scala @@ -35,17 +35,16 @@ abstract class SimpleTransformSpec extends FlatSpec with Matchers with Compiler } } -class CustomResolveAndCheck(form: CircuitForm) extends PassBasedTransform { - private val wrappedTransform = new ResolveAndCheck +class CustomResolveAndCheck(form: CircuitForm) extends SeqTransform { def inputForm = form def outputForm = form - def passSeq = wrappedTransform.passSeq + def transforms: Seq[Transform] = Seq[Transform](new ResolveAndCheck) } trait LowTransformSpec extends SimpleTransformSpec { def emitter = new LowFirrtlEmitter def transform: Transform - def transforms = Seq( + def transforms: Seq[Transform] = Seq( new ChirrtlToHighFirrtl(), new IRToWorkingIR(), new ResolveAndCheck(), @@ -59,7 +58,7 @@ trait LowTransformSpec extends SimpleTransformSpec { trait MiddleTransformSpec extends SimpleTransformSpec { def emitter = new MiddleFirrtlEmitter def transform: Transform - def transforms = Seq( + def transforms: Seq[Transform] = Seq( new ChirrtlToHighFirrtl(), new IRToWorkingIR(), new ResolveAndCheck(), @@ -75,7 +74,7 @@ trait HighTransformSpec extends SimpleTransformSpec { def transforms = Seq( new ChirrtlToHighFirrtl(), new IRToWorkingIR(), - new ResolveAndCheck(), + new CustomResolveAndCheck(HighForm), transform ) } diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 1a5b44e6..0831bb31 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -18,10 +18,10 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new InferReadWrite(), new ReplSeqMem(), new MiddleFirrtlToLowFirrtl(), - new PassBasedTransform { + new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def passSeq = Seq(ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty) + def transforms = Seq(ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty) } ) diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala index ce591485..34a22c26 100644 --- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala +++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala @@ -178,10 +178,10 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | io_out <= io_in """.stripMargin - class CheckChirrtlTransform extends PassBasedTransform { + class CheckChirrtlTransform extends SeqTransform { def inputForm = ChirrtlForm def outputForm = ChirrtlForm - val passSeq = Seq(passes.CheckChirrtl) + val transforms = Seq(passes.CheckChirrtl) } val chirrtlTransform = new CheckChirrtlTransform |
