diff options
| author | Adam Izraelevitz | 2017-03-23 16:16:24 -0700 |
|---|---|---|
| committer | GitHub | 2017-03-23 16:16:24 -0700 |
| commit | 67eb4e2de6166b8f1eb5190215640117b82e8c48 (patch) | |
| tree | 18cbaf901eff58262d833bf5bc0d75262c9ab57d /src/main | |
| parent | 4cffd184397905eeb79e2df0913b4ded97dc8558 (diff) | |
Pass now subclasses Transform (#477)
Diffstat (limited to 'src/main')
41 files changed, 145 insertions, 175 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index b0d42332..ea801086 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -110,61 +110,117 @@ final case object MidForm extends CircuitForm(1) * - All implicit truncations must be made explicit */ final case object LowForm extends CircuitForm(0) +/** Unknown Form + * + * Often passes may modify a circuit (e.g. InferTypes), but return + * a circuit in the same form it was given. + * + * For this use case, use UnknownForm. It cannot be compared against other + * forms. + * + * TODO(azidar): Replace with PreviousForm, which more explicitly encodes + * this requirement. + */ +final case object UnknownForm extends CircuitForm(-1) { + override def compare(that: CircuitForm): Int = { error("Illegal to compare UnknownForm"); 0 } +} /** The basic unit of operating on a Firrtl AST */ -abstract class Transform { +abstract class Transform extends LazyLogging { /** A convenience function useful for debugging and error messages */ def name: String = this.getClass.getSimpleName /** The [[CircuitForm]] that this transform requires to operate on */ def inputForm: CircuitForm /** The [[CircuitForm]] that this transform outputs */ def outputForm: CircuitForm - /** Perform the transform + /** Perform the transform, encode renaming with RenameMap, and can + * delete annotations + * Called by [[runTransform]]. * * @param state Input Firrtl AST * @return A transformed Firrtl AST */ - def execute(state: CircuitState): CircuitState + protected def execute(state: CircuitState): CircuitState /** Convenience method to get annotations relevant to this Transform * * @param state The [[CircuitState]] form which to extract annotations * @return A collection of annotations */ final def getMyAnnotations(state: CircuitState): Seq[Annotation] = state.annotations match { - case Some(annotations) => annotations.get(this.getClass) + case Some(annotations) => annotations.get(this.getClass) //TODO(azidar): ++ annotations.get(classOf[Transform]) case None => Nil } -} + /** Perform the transform and update annotations. + * + * @param state Input Firrtl AST + * @return A transformed Firrtl AST + */ + final def runTransform(state: CircuitState): CircuitState = { + logger.info(s"======== Starting Transform $name ========") + + val (timeMillis, result) = Utils.time { execute(state) } + + logger.info(s"""----------------------------${"-" * name.size}---------\n""") + logger.info(f"Time: $timeMillis%.1f ms") + + val remappedAnnotations = propagateAnnotations(state.annotations, result.annotations, result.renames) + + logger.info(s"Form: ${result.form}") + logger.debug(s"Annotations:") + remappedAnnotations.foreach { a => + logger.debug(a.serialize) + } + logger.debug(s"Circuit:\n${result.circuit.serialize}") + logger.info(s"======== Finished Transform $name ========\n") + + CircuitState(result.circuit, result.form, Some(AnnotationMap(remappedAnnotations)), None) + } -trait SimpleRun extends LazyLogging { - def runPasses(circuit: Circuit, passSeq: Seq[Pass]): Circuit = - passSeq.foldLeft(circuit) { (c: Circuit, pass: Pass) => - val name = pass.name - logger.info(s"-------- Starting Pass $name --------") - val (timeMillis, x) = Utils.time { pass.run(c) } - logger.info(f"Time: $timeMillis%.1f ms") - logger.debug(s"Circuit:\n${c.serialize}") - logger.info(s"-------- Finished Pass $name --------") - x + /** Propagate annotations and update their names. + * + * @param inAnno input AnnotationMap + * @param resAnno result AnnotationMap + * @param renameOpt result RenameMap + * @return the updated annotations + */ + final private def propagateAnnotations( + inAnno: Option[AnnotationMap], + resAnno: Option[AnnotationMap], + renameOpt: Option[RenameMap]): Seq[Annotation] = { + val newAnnotations = { + val inSet = inAnno.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet + val resSet = resAnno.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet + val deleted = (inSet -- resSet).map { + case DeletedAnnotation(xFormName, delAnno) => DeletedAnnotation(s"$xFormName+$name", delAnno) + case anno => DeletedAnnotation(name, anno) + } + val created = resSet -- inSet + val unchanged = resSet & inSet + (deleted ++ created ++ unchanged) } + + // For each annotation, rename all annotations. + val renames = renameOpt.getOrElse(RenameMap()).map + for { + anno <- newAnnotations.toSeq + newAnno <- anno.update(renames.getOrElse(anno.target, Seq(anno.target))) + } yield newAnno + } } -/** For PassBased Transforms and Emitters - * - * @note passSeq accepts no arguments - * @todo make passes accept CircuitState so annotations can pass data between them - */ -trait PassBased extends SimpleRun { - def passSeq: Seq[Pass] - def runPasses(circuit: Circuit): Circuit = runPasses(circuit, passSeq) +trait SeqTransformBased { + def transforms: Seq[Transform] + protected def runTransforms(state: CircuitState): CircuitState = + transforms.foldLeft(state) { (in, xform) => xform.runTransform(in) } } -/** For transformations that are simply a sequence of passes */ -abstract class PassBasedTransform extends Transform with PassBased { +/** For transformations that are simply a sequence of transforms */ +abstract class SeqTransform extends Transform with SeqTransformBased { def execute(state: CircuitState): CircuitState = { require(state.form <= inputForm, s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") - CircuitState(runPasses(state.circuit), outputForm, state.annotations) + val ret = runTransforms(state) + CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames) } } @@ -174,7 +230,7 @@ trait Emitter extends Transform { def emit(state: CircuitState, writer: Writer): Unit } -object CompilerUtils { +object CompilerUtils extends LazyLogging { /** Generates a sequence of [[Transform]]s to lower a Firrtl circuit * * @param inputForm [[CircuitForm]] to lower from @@ -194,6 +250,7 @@ object CompilerUtils { new HighFirrtlToMiddleFirrtl) ++ getLoweringTransforms(MidForm, outputForm) case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm) case LowForm => throwInternalError // should be caught by if above + case UnknownForm => throwInternalError // should be caught by if above } } } @@ -318,42 +375,9 @@ trait Compiler extends LazyLogging { def compile(state: CircuitState, customTransforms: Seq[Transform]): CircuitState = { val allTransforms = CompilerUtils.mergeTransforms(transforms, customTransforms) :+ emitter - val finalState = allTransforms.foldLeft(state) { (in, xform) => - logger.info(s"======== Starting Transform ${xform.name} ========") - val (timeMillis, result) = Utils.time { xform.execute(in) } - - logger.info(f"Time: $timeMillis%.1f ms") - - val newAnnotations = { - val inSet = in.annotations.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet - val resSet = result.annotations.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet - val deleted = (inSet -- resSet).map { - case DeletedAnnotation(xFormName, delAnno) => DeletedAnnotation(s"${xFormName}+${xform.name}", delAnno) - case anno => DeletedAnnotation(xform.name, anno) - } - val created = resSet -- inSet - val unchanged = resSet & inSet - (deleted ++ created ++ unchanged) - } - - // For each annotation, rename all annotations. - val renames = result.renames.getOrElse(RenameMap()).map - val remappedAnnotations: Seq[Annotation] = for { - anno <- newAnnotations.toSeq - newAnno <- anno.update(renames.getOrElse(anno.target, Seq(anno.target))) - } yield newAnno - - logger.info(s"Form: ${result.form}") - logger.debug(s"Annotations:") - remappedAnnotations.foreach { a => - logger.debug(a.serialize) - } - logger.debug(s"Circuit:\n${result.circuit.serialize}") - logger.info(s"======== Finished Transform ${xform.name} ========\n") - - CircuitState(result.circuit, result.form, Some(AnnotationMap(remappedAnnotations))) - } + val finalState = allTransforms.foldLeft(state) { (in, xform) => xform.runTransform(in) } finalState } + } diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 1153b1e6..933d98c5 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -139,7 +139,7 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em } } - def execute(state: CircuitState): CircuitState = { + override def execute(state: CircuitState): CircuitState = { val newAnnos = getMyAnnotations(state).flatMap { case EmitCircuitAnnotation() => Seq(EmittedFirrtlCircuitAnnotation.apply( @@ -161,8 +161,8 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em // ***** Start actual Emitters ***** class HighFirrtlEmitter extends FirrtlEmitter(HighForm) -class MiddleFirrtlEmitter extends FirrtlEmitter(HighForm) -class LowFirrtlEmitter extends FirrtlEmitter(HighForm) +class MiddleFirrtlEmitter extends FirrtlEmitter(MidForm) +class LowFirrtlEmitter extends FirrtlEmitter(LowForm) case class VRandom(width: BigInt) extends Expression { def tpe = UIntType(IntWidth(width)) @@ -174,10 +174,9 @@ case class VRandom(width: BigInt) extends Expression { def mapWidth(f: Width => Width): Expression = this } -class VerilogEmitter extends Transform with PassBased with Emitter { +class VerilogEmitter extends SeqTransform with Emitter { def inputForm = LowForm def outputForm = LowForm - val tab = " " def AND(e1: WrappedExpression, e2: WrappedExpression): Expression = { if (e1 == e2) e1.e1 @@ -744,7 +743,7 @@ class VerilogEmitter extends Transform with PassBased with Emitter { | |""".stripMargin - def passSeq = Seq( + def transforms = Seq( passes.VerilogModulusCleanup, passes.VerilogWrap, passes.VerilogRename, @@ -753,7 +752,7 @@ class VerilogEmitter extends Transform with PassBased with Emitter { def emit(state: CircuitState, writer: Writer): Unit = { writer.write(preamble) - val circuit = runPasses(state.circuit) + val circuit = runTransforms(state).circuit val moduleMap = circuit.modules.map(m => m.name -> m).toMap circuit.modules.foreach { case m: Module => emit_verilog(m, moduleMap)(writer) @@ -761,7 +760,7 @@ class VerilogEmitter extends Transform with PassBased with Emitter { } } - def execute(state: CircuitState): CircuitState = { + override def execute(state: CircuitState): CircuitState = { val newAnnos = getMyAnnotations(state).flatMap { case EmitCircuitAnnotation() => val writer = new java.io.StringWriter @@ -769,7 +768,7 @@ class VerilogEmitter extends Transform with PassBased with Emitter { Seq(EmittedVerilogCircuitAnnotation(EmittedVerilogCircuit(state.circuit.main, writer.toString))) case EmitAllModulesAnnotation() => - val circuit = runPasses(state.circuit) + val circuit = runTransforms(state).circuit val moduleMap = circuit.modules.map(m => m.name -> m).toMap circuit.modules flatMap { diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index c2238c2d..b5808e93 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -2,7 +2,7 @@ package firrtl -sealed abstract class CoreTransform extends PassBasedTransform +sealed abstract class CoreTransform extends SeqTransform /** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting * circuit has only IR nodes, not WIR. @@ -11,7 +11,7 @@ sealed abstract class CoreTransform extends PassBasedTransform class ChirrtlToHighFirrtl extends CoreTransform { def inputForm = ChirrtlForm def outputForm = HighForm - def passSeq = Seq( + def transforms = Seq( passes.CheckChirrtl, passes.CInferTypes, passes.CInferMDir, @@ -24,7 +24,7 @@ class ChirrtlToHighFirrtl extends CoreTransform { class IRToWorkingIR extends CoreTransform { def inputForm = HighForm def outputForm = HighForm - def passSeq = Seq(passes.ToWorkingIR) + def transforms = Seq(passes.ToWorkingIR) } /** Resolves types, kinds, and genders, and checks the circuit legality. @@ -33,7 +33,7 @@ class IRToWorkingIR extends CoreTransform { class ResolveAndCheck extends CoreTransform { def inputForm = HighForm def outputForm = HighForm - def passSeq = Seq( + def transforms = Seq( passes.CheckHighForm, passes.ResolveKinds, passes.InferTypes, @@ -55,7 +55,7 @@ class ResolveAndCheck extends CoreTransform { class HighFirrtlToMiddleFirrtl extends CoreTransform { def inputForm = HighForm def outputForm = MidForm - def passSeq = Seq( + def transforms = Seq( passes.PullMuxes, passes.ReplaceAccesses, passes.ExpandConnects, @@ -80,7 +80,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform { class MiddleFirrtlToLowFirrtl extends CoreTransform { def inputForm = MidForm def outputForm = LowForm - def passSeq = Seq( + def transforms = Seq( passes.LowerTypes, passes.ResolveKinds, passes.InferTypes, @@ -96,7 +96,7 @@ class MiddleFirrtlToLowFirrtl extends CoreTransform { class LowFirrtlOptimization extends CoreTransform { def inputForm = LowForm def outputForm = LowForm - def passSeq = Seq( + def transforms = Seq( passes.RemoveValidIf, passes.ConstProp, passes.PadWidths, diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index ef189c11..3722fd0d 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -8,7 +8,6 @@ import firrtl.Utils._ import firrtl.Mappers._ object CheckChirrtl extends Pass { - def name = "Chirrtl Check" type NameSet = collection.mutable.HashSet[String] class NotUniqueException(info: Info, mname: String, name: String) extends PassException( diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala index 84d6b448..4c392510 100644 --- a/src/main/scala/firrtl/passes/CheckInitialization.scala +++ b/src/main/scala/firrtl/passes/CheckInitialization.scala @@ -15,8 +15,6 @@ import annotation.tailrec * @note Assumes single connection (ie. no last connect semantics) */ object CheckInitialization extends Pass { - def name = "Check Initialization" - private case class VoidExpr(stmt: Statement, voidDeps: Seq[Expression]) class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) extends PassException( diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 4b0b1c0d..24735009 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -9,7 +9,6 @@ import firrtl.Mappers._ import firrtl.Utils._ object CheckWidths extends Pass { - def name = "Width Check" /** The maximum allowed width for any circuit element */ val MaxWidth = 1000000 val DshlMaxWidth = ceilLog2(MaxWidth + 1) diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index bd4c7f63..0bebcd18 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -10,7 +10,6 @@ import firrtl.Mappers._ import firrtl.WrappedType._ object CheckHighForm extends Pass { - def name = "High Form Check" type NameSet = collection.mutable.HashSet[String] // Custom Exceptions @@ -202,7 +201,6 @@ object CheckHighForm extends Pass { } object CheckTypes extends Pass { - def name = "Check Types" // Custom Exceptions class SubfieldNotInBundle(info: Info, mname: String, name: String) extends PassException( @@ -463,7 +461,6 @@ object CheckTypes extends Pass { } object CheckGenders extends Pass { - def name = "Check Genders" type GenderMap = collection.mutable.HashMap[String, Gender] implicit def toStr(g: Gender): String = g match { diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index 40d04d07..0abdaa36 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -10,8 +10,6 @@ import firrtl.Mappers._ import annotation.tailrec object CommonSubexpressionElimination extends Pass { - def name = "Common Subexpression Elimination" - private def cseOnce(s: Statement): (Statement, Long) = { var nEliminated = 0L val expressions = collection.mutable.HashMap[MemoizedHash[Expression], String]() diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index a5a8238e..8736ee31 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -11,8 +11,6 @@ import firrtl.PrimOps._ import annotation.tailrec object ConstProp extends Pass { - def name = "Constant Propagation" - private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 823fb7fb..2e151741 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -12,7 +12,6 @@ import firrtl.Utils.{sub_type, module_type, field_type, BoolType, max, min, pow_ /** Replaces FixedType with SIntType, and correctly aligns all binary points */ object ConvertFixedToSInt extends Pass { - def name = "Convert Fixed Types to SInt Types" def alignArg(e: Expression, point: BigInt): Expression = e.tpe match { case FixedType(IntWidth(w), IntWidth(p)) => // assert(point >= p) if((point - p) > 0) { diff --git a/src/main/scala/firrtl/passes/DeadCodeElimination.scala b/src/main/scala/firrtl/passes/DeadCodeElimination.scala index 6f37feae..9f249f35 100644 --- a/src/main/scala/firrtl/passes/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/passes/DeadCodeElimination.scala @@ -10,8 +10,6 @@ import firrtl.Mappers._ import annotation.tailrec object DeadCodeElimination extends Pass { - def name = "Dead Code Elimination" - private def dceOnce(s: Statement): (Statement, Long) = { val referenced = collection.mutable.HashSet[String]() var nEliminated = 0L diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index a2845f43..1f093dd1 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -25,7 +25,6 @@ import collection.immutable.ListSet * @note Assumes all references are declared */ object ExpandWhens extends Pass { - def name = "Expand Whens" type NodeMap = mutable.HashMap[MemoizedHash[Expression], String] type Netlist = mutable.LinkedHashMap[WrappedExpression, Expression] type Simlist = mutable.ArrayBuffer[Statement] diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 0e503115..2de2a76e 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -8,7 +8,6 @@ import firrtl.Utils._ import firrtl.Mappers._ object InferTypes extends Pass { - def name = "Infer Types" type TypeMap = collection.mutable.LinkedHashMap[String, Type] def run(c: Circuit): Circuit = { @@ -76,7 +75,6 @@ object InferTypes extends Pass { } object CInferTypes extends Pass { - def name = "CInfer Types" type TypeMap = collection.mutable.LinkedHashMap[String, Type] def run(c: Circuit): Circuit = { diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index f3b77ec5..11b819ce 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -12,7 +12,6 @@ import firrtl.Utils._ import firrtl.Mappers._ object InferWidths extends Pass { - def name = "Infer Widths" type ConstraintMap = collection.mutable.LinkedHashMap[String, Width] def solve_constraints(l: Seq[WGeq]): ConstraintMap = { diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index f4556733..2e15f09c 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -27,7 +27,6 @@ class InlineInstances extends Transform { def inputForm = LowForm def outputForm = LowForm val inlineDelim = "$" - override def name = "Inline Instances" private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) { diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 23518d14..5826f56e 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -21,8 +21,6 @@ import firrtl.Mappers._ * }}} */ object LowerTypes extends Pass { - def name = "Lower Types" - /** Delimiter used in lowering names */ val delim = "_" /** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 398cc6d7..c9aa1539 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -9,7 +9,6 @@ import firrtl.Mappers._ // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { - def name = "Pad Widths" private def width(t: Type): Int = bitWidth(t).toInt private def width(e: Expression): Int = width(e.tpe) // Returns an expression with the correct integer width diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index c595727e..68f278a9 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -12,9 +12,23 @@ import firrtl.PrimOps._ import scala.collection.mutable -trait Pass extends LazyLogging { - def name: String +/** [[Pass]] is simple transform that is generally part of a larger [[Transform]] + * Has an [[UnknownForm]], because larger [[Transform]] should specify form + */ +trait Pass extends Transform { + def inputForm: CircuitForm = UnknownForm + def outputForm: CircuitForm = UnknownForm def run(c: Circuit): Circuit + def execute(state: CircuitState): CircuitState = { + val result = (state.form, inputForm) match { + case (_, UnknownForm) => run(state.circuit) + case (UnknownForm, _) => run(state.circuit) + case (x, y) if x > y => + error(s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") + case _ => run(state.circuit) + } + CircuitState(result, outputForm, state.annotations, state.renames) + } } // Error handling @@ -34,8 +48,6 @@ class Errors { // These should be distributed into separate files object ToWorkingIR extends Pass { - def name = "Working IR" - def toExp(e: Expression): Expression = e map toExp match { case ex: Reference => WRef(ex.name, ex.tpe, NodeKind, UNKNOWNGENDER) case ex: SubField => WSubField(ex.expr, ex.name, ex.tpe, UNKNOWNGENDER) @@ -54,7 +66,6 @@ object ToWorkingIR extends Pass { } object PullMuxes extends Pass { - def name = "Pull Muxes" def run(c: Circuit): Circuit = { def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match { case ex: WSubField => ex.exp match { @@ -93,7 +104,6 @@ object PullMuxes extends Pass { } object ExpandConnects extends Pass { - def name = "Expand Connects" def run(c: Circuit): Circuit = { def expand_connects(m: Module): Module = { val genders = collection.mutable.LinkedHashMap[String,Gender]() @@ -171,7 +181,6 @@ object ExpandConnects extends Pass { // Replace shr by amount >= arg width with 0 for UInts and MSB for SInts // TODO replace UInt with zero-width wire instead object Legalize extends Pass { - def name = "Legalize" private def legalizeShiftRight(e: DoPrim): Expression = { require(e.op == Shr) val amount = e.consts.head.toInt @@ -244,7 +253,6 @@ object Legalize extends Pass { } object VerilogWrap extends Pass { - def name = "Verilog Wrap" def vWrapE(e: Expression): Expression = e map vWrapE match { case e: DoPrim => e.op match { case Tail => e.args.head match { @@ -271,7 +279,6 @@ object VerilogWrap extends Pass { } object VerilogRename extends Pass { - def name = "Verilog Rename" def verilogRenameN(n: String): String = if (v_keywords(n)) "%s$".format(n) else n @@ -301,7 +308,6 @@ object VerilogRename extends Pass { * @note The result of this pass is NOT legal Firrtl */ object VerilogPrep extends Pass { - def name = "Verilog Prep" type AttachSourceMap = Map[WrappedExpression, Expression] diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index a8bc9fb2..5d74d5ba 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -14,8 +14,6 @@ import scala.collection.mutable /** Removes all [[firrtl.WSubAccess]] from circuit */ object RemoveAccesses extends Pass { - def name = "Remove Accesses" - private def AND(e1: Expression, e2: Expression) = DoPrim(And, Seq(e1, e2), Nil, BoolType) diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index aae4ca80..b072dfa0 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -15,8 +15,6 @@ case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], read case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean) object RemoveCHIRRTL extends Pass { - def name = "Remove CHIRRTL" - val ut = UnknownType type MPortMap = collection.mutable.LinkedHashMap[String, MPorts] type SeqMemSet = collection.mutable.HashSet[String] diff --git a/src/main/scala/firrtl/passes/RemoveEmpty.scala b/src/main/scala/firrtl/passes/RemoveEmpty.scala index 0fdfc4d9..97c86dda 100644 --- a/src/main/scala/firrtl/passes/RemoveEmpty.scala +++ b/src/main/scala/firrtl/passes/RemoveEmpty.scala @@ -8,7 +8,6 @@ import firrtl.Mappers._ import firrtl.ir._ object RemoveEmpty extends Pass { - def name = "Remove Empty Statements" private def onModule(m: DefModule): DefModule = { m match { case m: Module => Module(m.info, m.name, m.ports, Utils.squashEmpty(m.body)) diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 7769eac2..865143a5 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -7,7 +7,6 @@ import firrtl.ir._ // Removes ValidIf as an optimization object RemoveValidIf extends Pass { - def name = "Remove ValidIfs" // Recursive. Removes ValidIf's private def onExp(e: Expression): Expression = { e map onExp match { diff --git a/src/main/scala/firrtl/passes/ReplaceAccesses.scala b/src/main/scala/firrtl/passes/ReplaceAccesses.scala index 13562717..c3a5bd4c 100644 --- a/src/main/scala/firrtl/passes/ReplaceAccesses.scala +++ b/src/main/scala/firrtl/passes/ReplaceAccesses.scala @@ -15,8 +15,6 @@ import scala.collection.mutable * TODO Fold in to High Firrtl Const Prop */ object ReplaceAccesses extends Pass { - def name = "Replace Accesses" - def run(c: Circuit): Circuit = { def onStmt(s: Statement): Statement = s map onStmt map onExp def onExp(e: Expression): Expression = e match { diff --git a/src/main/scala/firrtl/passes/Resolves.scala b/src/main/scala/firrtl/passes/Resolves.scala index e60e0478..c8ba43bf 100644 --- a/src/main/scala/firrtl/passes/Resolves.scala +++ b/src/main/scala/firrtl/passes/Resolves.scala @@ -7,7 +7,6 @@ import firrtl.ir._ import firrtl.Mappers._ object ResolveKinds extends Pass { - def name = "Resolve Kinds" type KindMap = collection.mutable.LinkedHashMap[String, Kind] def find_port(kinds: KindMap)(p: Port): Port = { @@ -46,7 +45,6 @@ object ResolveKinds extends Pass { } object ResolveGenders extends Pass { - def name = "Resolve Genders" def resolve_e(g: Gender)(e: Expression): Expression = e match { case ex: WRef => ex copy (gender = g) case WSubField(exp, name, tpe, _) => WSubField( @@ -79,7 +77,6 @@ object ResolveGenders extends Pass { } object CInferMDir extends Pass { - def name = "CInfer MDir" type MPortDirMap = collection.mutable.LinkedHashMap[String, MPortDir] def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = e match { diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 797292dc..a32f5366 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -13,7 +13,6 @@ import scala.collection.mutable // Splits compound expressions into simple expressions // and named intermediate nodes object SplitExpressions extends Pass { - def name = "Split Expressions" private def onModule(m: Module): Module = { val namespace = Namespace(m) def onStmt(s: Statement): Statement = { diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 40783e21..deddb93e 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -32,8 +32,6 @@ import MemPortUtils.memType * to rename a */ object Uniquify extends Pass { - def name = "Uniquify Identifiers" - private case class UniquifyException(msg: String) extends FIRRTLException(msg) private def error(msg: String)(implicit sinfo: Info, mname: String) = throw new UniquifyException(s"$sinfo: [module $mname] $msg") diff --git a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala index b4df534f..330ca497 100644 --- a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala +++ b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala @@ -24,7 +24,6 @@ import scala.collection.mutable * to emit correct verilog without needing to add temporary nodes */ object VerilogModulusCleanup extends Pass { - def name = "Add temporary nodes with verilog widths for modulus" private def onModule(m: Module): Module = { val namespace = Namespace(m) diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala index a50fdc16..520075fe 100644 --- a/src/main/scala/firrtl/passes/ZeroWidth.scala +++ b/src/main/scala/firrtl/passes/ZeroWidth.scala @@ -11,7 +11,6 @@ import firrtl.Utils.throwInternalError object ZeroWidth extends Pass { - def name = this.getClass.getName private val ZERO = BigInt(0) private def removeZero(t: Type): Option[Type] = t match { case GroundType(IntWidth(ZERO)) => None diff --git a/src/main/scala/firrtl/passes/clocklist/ClockList.scala b/src/main/scala/firrtl/passes/clocklist/ClockList.scala index 66139c49..bd2536ab 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockList.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockList.scala @@ -20,7 +20,6 @@ import Mappers._ * Write the result to writer. */ class ClockList(top: String, writer: Writer) extends Pass { - def name = this.getClass.getSimpleName def run(c: Circuit): Circuit = { // Build useful datastructures val childrenMap = getChildrenMap(c) diff --git a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala index feb7f42e..53787b1d 100644 --- a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala +++ b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala @@ -20,7 +20,6 @@ import Mappers._ * expressions do not relate to ground types. */ object RemoveAllButClocks extends Pass { - def name = this.getClass.getSimpleName def onStmt(s: Statement): Statement = (s map onStmt) match { case DefWire(i, n, ClockType) => s case DefNode(i, n, value) if value.tpe == ClockType => s diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala index 668bc2e5..e48dc8c2 100644 --- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala +++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala @@ -10,7 +10,6 @@ import wiring._ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform { def inputForm = MidForm def outputForm = MidForm - override def name = "Create Memory Annotations" def execute(state: CircuitState): CircuitState = reader match { case None => state case Some(r) => diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 9bd6a4ab..73fec1ee 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -29,7 +29,6 @@ object InferReadWriteAnnotation { // of any product term of the enable signal of the write port, then the readwrite // port is inferred. object InferReadWritePass extends Pass { - def name = "Infer ReadWrite Ports" type Netlist = collection.mutable.HashMap[String, Expression] type Statements = collection.mutable.ArrayBuffer[Statement] @@ -150,10 +149,10 @@ object InferReadWritePass extends Pass { // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite extends Transform with PassBased { +class InferReadWrite extends Transform with SeqTransformBased { def inputForm = MidForm def outputForm = MidForm - def passSeq = Seq( + def transforms = Seq( InferReadWritePass, CheckInitialization, InferTypes, @@ -163,6 +162,7 @@ class InferReadWrite extends Transform with PassBased { def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { case Nil => state case Seq(InferReadWriteAnnotation(CircuitName(state.circuit.main))) => - state.copy(circuit = runPasses(state.circuit)) + val ret = runTransforms(state) + CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames) } } diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala index 57c301b1..9debff7a 100644 --- a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala +++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala @@ -15,9 +15,6 @@ import MemTransformUtils._ /** Changes memory port names to standard port names (i.e. RW0 instead T_408) */ object RenameAnnotatedMemoryPorts extends Pass { - - def name = "Rename Annotated Memory Ports" - /** Renames memory ports to a standard naming scheme: * - R0, R1, ... for each read port * - W0, W1, ... for each write port diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index af6761fd..b18ed289 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -33,7 +33,6 @@ object PinAnnotation { * Creates the minimum # of black boxes needed by the design. */ class ReplaceMemMacros(writer: ConfWriter) extends Transform { - override def name = "Replace Memory Macros" def inputForm = MidForm def outputForm = MidForm @@ -227,11 +226,14 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform { case Seq(PinAnnotation(CircuitName(c), pins)) => pins case _ => throwInternalError } - val annos = pins.foldLeft(Seq[Annotation]()) { (seq, pin) => + val annos = (pins.foldLeft(Seq[Annotation]()) { (seq, pin) => seq ++ memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), pin) } - } + }) ++ (state.annotations match { + case None => Seq.empty + case Some(a) => a.annotations + }) CircuitState(c.copy(modules = modules ++ memMods), inputForm, Some(AnnotationMap(annos))) } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index 0c12d2aa..caaf430b 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -116,10 +116,10 @@ class SimpleTransform(p: Pass, form: CircuitForm) extends Transform { class SimpleMidTransform(p: Pass) extends SimpleTransform(p, MidForm) // SimpleRun instead of PassBased because of the arguments to passSeq -class ReplSeqMem extends Transform with SimpleRun { +class ReplSeqMem extends Transform { def inputForm = MidForm def outputForm = MidForm - def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = + def transforms(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = Seq(new SimpleMidTransform(Legalize), new SimpleMidTransform(ToMemIR), new SimpleMidTransform(ResolveMaskGranularity), @@ -134,31 +134,19 @@ class ReplSeqMem extends Transform with SimpleRun { new SimpleMidTransform(Uniquify), new SimpleMidTransform(ResolveKinds), new SimpleMidTransform(ResolveGenders)) - def run(state: CircuitState, xForms: Seq[Transform]): CircuitState = { - (xForms.foldLeft(state) { case (curState: CircuitState, xForm: Transform) => - val res = xForm.execute(curState) - val newAnnotations = res.annotations match { - case None => curState.annotations - case Some(ann) => - Some(AnnotationMap(ann.annotations ++ curState.annotations.get.annotations)) - } - CircuitState(res.circuit, res.form, newAnnotations) - }) - } - def execute(state: CircuitState): CircuitState = - getMyAnnotations(state) match { - case Nil => state // Do nothing if there are no annotations - case p => (p.collectFirst { case a if (a.target == CircuitName(state.circuit.main)) => a }) match { - case Some(ReplSeqMemAnnotation(target, inputFileName, outputConfig)) => - val inConfigFile = { - if (inputFileName.isEmpty) None - else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) - else error("Input configuration file does not exist!") - } - val outConfigFile = new ConfWriter(outputConfig) - run(state, passSeq(inConfigFile, outConfigFile)) - case _ => error("Unexpected transform annotation") - } + def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { + case Nil => state // Do nothing if there are no annotations + case p => (p.collectFirst { case a if (a.target == CircuitName(state.circuit.main)) => a }) match { + case Some(ReplSeqMemAnnotation(target, inputFileName, outputConfig)) => + val inConfigFile = { + if (inputFileName.isEmpty) None + else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) + else error("Input configuration file does not exist!") + } + val outConfigFile = new ConfWriter(outputConfig) + transforms(inConfigFile, outConfigFile).foldLeft(state) { (in, xform) => xform.runTransform(in) } + case _ => error("Unexpected transform annotation") } + } } diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index 956bdd3c..79ecd9cd 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -90,7 +90,6 @@ object AnalysisUtils { * TODO(shunshou): Add floorplan info? */ object ResolveMaskGranularity extends Pass { - def name = "Resolve Mask Granularity" /** Returns the number of mask bits, if used */ diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala index df555e57..e132e369 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala @@ -51,6 +51,6 @@ class ResolveMemoryReference extends Transform { case annos => annos.collect { case NoDedupMemAnnotation(ComponentName(cn, _)) => cn } } - CircuitState(run(state.circuit, noDedups), state.form) + state.copy(circuit=run(state.circuit, noDedups)) } } diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala index eb9d0859..feb6ae59 100644 --- a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala @@ -13,8 +13,6 @@ import firrtl.ir._ * - zero or one read port */ object ToMemIR extends Pass { - def name = "To Memory IR" - /** Only annotate memories that are candidates for memory macro replacements * i.e. rw, w + r (read, write 1 cycle delay) */ diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index fc126b74..6eefb69e 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -12,7 +12,6 @@ import MemPortUtils._ /** This pass generates delay reigsters for memories for verilog */ object VerilogMemDelays extends Pass { - def name = "Verilog Memory Delays" val ug = UNKNOWNGENDER type Netlist = collection.mutable.HashMap[String, Expression] implicit def expToString(e: Expression): String = e.serialize diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala index f5da4c06..9656abb2 100644 --- a/src/main/scala/firrtl/passes/wiring/Wiring.scala +++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala @@ -17,7 +17,6 @@ case class WiringException(msg: String) extends PassException(msg) case class WiringInfo(source: String, comp: String, sinks: Set[String], pin: String, top: String) class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { - def name = this.getClass.getSimpleName def run(c: Circuit): Circuit = { wiSeq.foldLeft(c) { (circuit, wi) => wire(circuit, wi) } } diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index 2c122943..a8ef5f58 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -60,10 +60,10 @@ object TopAnnotation { * Notes: * - No module uniquification occurs (due to imposed restrictions) */ -class WiringTransform extends Transform with SimpleRun { +class WiringTransform extends Transform { def inputForm = MidForm def outputForm = MidForm - def passSeq(wis: Seq[WiringInfo]) = + def transforms(wis: Seq[WiringInfo]) = Seq(new Wiring(wis), InferTypes, ResolveKinds, @@ -89,7 +89,7 @@ class WiringTransform extends Transform with SimpleRun { val wis = tops.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, top)) => seq :+ WiringInfo(sources(pin), comp(pin), sinks(pin), pin, top) } - state.copy(circuit = runPasses(state.circuit, passSeq(wis))) + transforms(wis).foldLeft(state) { (in, xform) => xform.runTransform(in) } case _ => error("Wrong number of sources, tops, or sinks!") } } |
