From 67eb4e2de6166b8f1eb5190215640117b82e8c48 Mon Sep 17 00:00:00 2001 From: Adam Izraelevitz Date: Thu, 23 Mar 2017 16:16:24 -0700 Subject: Pass now subclasses Transform (#477) --- src/main/scala/firrtl/passes/CheckChirrtl.scala | 1 - .../scala/firrtl/passes/CheckInitialization.scala | 2 -- src/main/scala/firrtl/passes/CheckWidths.scala | 1 - src/main/scala/firrtl/passes/Checks.scala | 3 -- .../passes/CommonSubexpressionElimination.scala | 2 -- src/main/scala/firrtl/passes/ConstProp.scala | 2 -- .../scala/firrtl/passes/ConvertFixedToSInt.scala | 1 - .../scala/firrtl/passes/DeadCodeElimination.scala | 2 -- src/main/scala/firrtl/passes/ExpandWhens.scala | 1 - src/main/scala/firrtl/passes/InferTypes.scala | 2 -- src/main/scala/firrtl/passes/InferWidths.scala | 1 - src/main/scala/firrtl/passes/Inline.scala | 1 - src/main/scala/firrtl/passes/LowerTypes.scala | 2 -- src/main/scala/firrtl/passes/PadWidths.scala | 1 - src/main/scala/firrtl/passes/Passes.scala | 26 ++++++++------ src/main/scala/firrtl/passes/RemoveAccesses.scala | 2 -- src/main/scala/firrtl/passes/RemoveCHIRRTL.scala | 2 -- src/main/scala/firrtl/passes/RemoveEmpty.scala | 1 - src/main/scala/firrtl/passes/RemoveValidIf.scala | 1 - src/main/scala/firrtl/passes/ReplaceAccesses.scala | 2 -- src/main/scala/firrtl/passes/Resolves.scala | 3 -- .../scala/firrtl/passes/SplitExpressions.scala | 1 - src/main/scala/firrtl/passes/Uniquify.scala | 2 -- .../firrtl/passes/VerilogModulusCleanup.scala | 1 - src/main/scala/firrtl/passes/ZeroWidth.scala | 1 - .../scala/firrtl/passes/clocklist/ClockList.scala | 1 - .../passes/clocklist/RemoveAllButClocks.scala | 1 - .../scala/firrtl/passes/memlib/DecorateMems.scala | 1 - .../firrtl/passes/memlib/InferReadWrite.scala | 8 ++--- .../passes/memlib/RenameAnnotatedMemoryPorts.scala | 3 -- .../firrtl/passes/memlib/ReplaceMemMacros.scala | 8 +++-- .../firrtl/passes/memlib/ReplaceMemTransform.scala | 42 ++++++++-------------- .../passes/memlib/ResolveMaskGranularity.scala | 1 - .../passes/memlib/ResolveMemoryReference.scala | 2 +- src/main/scala/firrtl/passes/memlib/ToMemIR.scala | 2 -- .../firrtl/passes/memlib/VerilogMemDelays.scala | 1 - src/main/scala/firrtl/passes/wiring/Wiring.scala | 1 - .../firrtl/passes/wiring/WiringTransform.scala | 6 ++-- 38 files changed, 44 insertions(+), 97 deletions(-) (limited to 'src/main/scala/firrtl/passes') diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index ef189c11..3722fd0d 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -8,7 +8,6 @@ import firrtl.Utils._ import firrtl.Mappers._ object CheckChirrtl extends Pass { - def name = "Chirrtl Check" type NameSet = collection.mutable.HashSet[String] class NotUniqueException(info: Info, mname: String, name: String) extends PassException( diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala index 84d6b448..4c392510 100644 --- a/src/main/scala/firrtl/passes/CheckInitialization.scala +++ b/src/main/scala/firrtl/passes/CheckInitialization.scala @@ -15,8 +15,6 @@ import annotation.tailrec * @note Assumes single connection (ie. no last connect semantics) */ object CheckInitialization extends Pass { - def name = "Check Initialization" - private case class VoidExpr(stmt: Statement, voidDeps: Seq[Expression]) class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) extends PassException( diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 4b0b1c0d..24735009 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -9,7 +9,6 @@ import firrtl.Mappers._ import firrtl.Utils._ object CheckWidths extends Pass { - def name = "Width Check" /** The maximum allowed width for any circuit element */ val MaxWidth = 1000000 val DshlMaxWidth = ceilLog2(MaxWidth + 1) diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index bd4c7f63..0bebcd18 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -10,7 +10,6 @@ import firrtl.Mappers._ import firrtl.WrappedType._ object CheckHighForm extends Pass { - def name = "High Form Check" type NameSet = collection.mutable.HashSet[String] // Custom Exceptions @@ -202,7 +201,6 @@ object CheckHighForm extends Pass { } object CheckTypes extends Pass { - def name = "Check Types" // Custom Exceptions class SubfieldNotInBundle(info: Info, mname: String, name: String) extends PassException( @@ -463,7 +461,6 @@ object CheckTypes extends Pass { } object CheckGenders extends Pass { - def name = "Check Genders" type GenderMap = collection.mutable.HashMap[String, Gender] implicit def toStr(g: Gender): String = g match { diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index 40d04d07..0abdaa36 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -10,8 +10,6 @@ import firrtl.Mappers._ import annotation.tailrec object CommonSubexpressionElimination extends Pass { - def name = "Common Subexpression Elimination" - private def cseOnce(s: Statement): (Statement, Long) = { var nEliminated = 0L val expressions = collection.mutable.HashMap[MemoizedHash[Expression], String]() diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index a5a8238e..8736ee31 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -11,8 +11,6 @@ import firrtl.PrimOps._ import annotation.tailrec object ConstProp extends Pass { - def name = "Constant Propagation" - private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 823fb7fb..2e151741 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -12,7 +12,6 @@ import firrtl.Utils.{sub_type, module_type, field_type, BoolType, max, min, pow_ /** Replaces FixedType with SIntType, and correctly aligns all binary points */ object ConvertFixedToSInt extends Pass { - def name = "Convert Fixed Types to SInt Types" def alignArg(e: Expression, point: BigInt): Expression = e.tpe match { case FixedType(IntWidth(w), IntWidth(p)) => // assert(point >= p) if((point - p) > 0) { diff --git a/src/main/scala/firrtl/passes/DeadCodeElimination.scala b/src/main/scala/firrtl/passes/DeadCodeElimination.scala index 6f37feae..9f249f35 100644 --- a/src/main/scala/firrtl/passes/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/passes/DeadCodeElimination.scala @@ -10,8 +10,6 @@ import firrtl.Mappers._ import annotation.tailrec object DeadCodeElimination extends Pass { - def name = "Dead Code Elimination" - private def dceOnce(s: Statement): (Statement, Long) = { val referenced = collection.mutable.HashSet[String]() var nEliminated = 0L diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index a2845f43..1f093dd1 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -25,7 +25,6 @@ import collection.immutable.ListSet * @note Assumes all references are declared */ object ExpandWhens extends Pass { - def name = "Expand Whens" type NodeMap = mutable.HashMap[MemoizedHash[Expression], String] type Netlist = mutable.LinkedHashMap[WrappedExpression, Expression] type Simlist = mutable.ArrayBuffer[Statement] diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 0e503115..2de2a76e 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -8,7 +8,6 @@ import firrtl.Utils._ import firrtl.Mappers._ object InferTypes extends Pass { - def name = "Infer Types" type TypeMap = collection.mutable.LinkedHashMap[String, Type] def run(c: Circuit): Circuit = { @@ -76,7 +75,6 @@ object InferTypes extends Pass { } object CInferTypes extends Pass { - def name = "CInfer Types" type TypeMap = collection.mutable.LinkedHashMap[String, Type] def run(c: Circuit): Circuit = { diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index f3b77ec5..11b819ce 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -12,7 +12,6 @@ import firrtl.Utils._ import firrtl.Mappers._ object InferWidths extends Pass { - def name = "Infer Widths" type ConstraintMap = collection.mutable.LinkedHashMap[String, Width] def solve_constraints(l: Seq[WGeq]): ConstraintMap = { diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index f4556733..2e15f09c 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -27,7 +27,6 @@ class InlineInstances extends Transform { def inputForm = LowForm def outputForm = LowForm val inlineDelim = "$" - override def name = "Inline Instances" private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) { diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 23518d14..5826f56e 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -21,8 +21,6 @@ import firrtl.Mappers._ * }}} */ object LowerTypes extends Pass { - def name = "Lower Types" - /** Delimiter used in lowering names */ val delim = "_" /** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 398cc6d7..c9aa1539 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -9,7 +9,6 @@ import firrtl.Mappers._ // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { - def name = "Pad Widths" private def width(t: Type): Int = bitWidth(t).toInt private def width(e: Expression): Int = width(e.tpe) // Returns an expression with the correct integer width diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index c595727e..68f278a9 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -12,9 +12,23 @@ import firrtl.PrimOps._ import scala.collection.mutable -trait Pass extends LazyLogging { - def name: String +/** [[Pass]] is simple transform that is generally part of a larger [[Transform]] + * Has an [[UnknownForm]], because larger [[Transform]] should specify form + */ +trait Pass extends Transform { + def inputForm: CircuitForm = UnknownForm + def outputForm: CircuitForm = UnknownForm def run(c: Circuit): Circuit + def execute(state: CircuitState): CircuitState = { + val result = (state.form, inputForm) match { + case (_, UnknownForm) => run(state.circuit) + case (UnknownForm, _) => run(state.circuit) + case (x, y) if x > y => + error(s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") + case _ => run(state.circuit) + } + CircuitState(result, outputForm, state.annotations, state.renames) + } } // Error handling @@ -34,8 +48,6 @@ class Errors { // These should be distributed into separate files object ToWorkingIR extends Pass { - def name = "Working IR" - def toExp(e: Expression): Expression = e map toExp match { case ex: Reference => WRef(ex.name, ex.tpe, NodeKind, UNKNOWNGENDER) case ex: SubField => WSubField(ex.expr, ex.name, ex.tpe, UNKNOWNGENDER) @@ -54,7 +66,6 @@ object ToWorkingIR extends Pass { } object PullMuxes extends Pass { - def name = "Pull Muxes" def run(c: Circuit): Circuit = { def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match { case ex: WSubField => ex.exp match { @@ -93,7 +104,6 @@ object PullMuxes extends Pass { } object ExpandConnects extends Pass { - def name = "Expand Connects" def run(c: Circuit): Circuit = { def expand_connects(m: Module): Module = { val genders = collection.mutable.LinkedHashMap[String,Gender]() @@ -171,7 +181,6 @@ object ExpandConnects extends Pass { // Replace shr by amount >= arg width with 0 for UInts and MSB for SInts // TODO replace UInt with zero-width wire instead object Legalize extends Pass { - def name = "Legalize" private def legalizeShiftRight(e: DoPrim): Expression = { require(e.op == Shr) val amount = e.consts.head.toInt @@ -244,7 +253,6 @@ object Legalize extends Pass { } object VerilogWrap extends Pass { - def name = "Verilog Wrap" def vWrapE(e: Expression): Expression = e map vWrapE match { case e: DoPrim => e.op match { case Tail => e.args.head match { @@ -271,7 +279,6 @@ object VerilogWrap extends Pass { } object VerilogRename extends Pass { - def name = "Verilog Rename" def verilogRenameN(n: String): String = if (v_keywords(n)) "%s$".format(n) else n @@ -301,7 +308,6 @@ object VerilogRename extends Pass { * @note The result of this pass is NOT legal Firrtl */ object VerilogPrep extends Pass { - def name = "Verilog Prep" type AttachSourceMap = Map[WrappedExpression, Expression] diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index a8bc9fb2..5d74d5ba 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -14,8 +14,6 @@ import scala.collection.mutable /** Removes all [[firrtl.WSubAccess]] from circuit */ object RemoveAccesses extends Pass { - def name = "Remove Accesses" - private def AND(e1: Expression, e2: Expression) = DoPrim(And, Seq(e1, e2), Nil, BoolType) diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index aae4ca80..b072dfa0 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -15,8 +15,6 @@ case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], read case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean) object RemoveCHIRRTL extends Pass { - def name = "Remove CHIRRTL" - val ut = UnknownType type MPortMap = collection.mutable.LinkedHashMap[String, MPorts] type SeqMemSet = collection.mutable.HashSet[String] diff --git a/src/main/scala/firrtl/passes/RemoveEmpty.scala b/src/main/scala/firrtl/passes/RemoveEmpty.scala index 0fdfc4d9..97c86dda 100644 --- a/src/main/scala/firrtl/passes/RemoveEmpty.scala +++ b/src/main/scala/firrtl/passes/RemoveEmpty.scala @@ -8,7 +8,6 @@ import firrtl.Mappers._ import firrtl.ir._ object RemoveEmpty extends Pass { - def name = "Remove Empty Statements" private def onModule(m: DefModule): DefModule = { m match { case m: Module => Module(m.info, m.name, m.ports, Utils.squashEmpty(m.body)) diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 7769eac2..865143a5 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -7,7 +7,6 @@ import firrtl.ir._ // Removes ValidIf as an optimization object RemoveValidIf extends Pass { - def name = "Remove ValidIfs" // Recursive. Removes ValidIf's private def onExp(e: Expression): Expression = { e map onExp match { diff --git a/src/main/scala/firrtl/passes/ReplaceAccesses.scala b/src/main/scala/firrtl/passes/ReplaceAccesses.scala index 13562717..c3a5bd4c 100644 --- a/src/main/scala/firrtl/passes/ReplaceAccesses.scala +++ b/src/main/scala/firrtl/passes/ReplaceAccesses.scala @@ -15,8 +15,6 @@ import scala.collection.mutable * TODO Fold in to High Firrtl Const Prop */ object ReplaceAccesses extends Pass { - def name = "Replace Accesses" - def run(c: Circuit): Circuit = { def onStmt(s: Statement): Statement = s map onStmt map onExp def onExp(e: Expression): Expression = e match { diff --git a/src/main/scala/firrtl/passes/Resolves.scala b/src/main/scala/firrtl/passes/Resolves.scala index e60e0478..c8ba43bf 100644 --- a/src/main/scala/firrtl/passes/Resolves.scala +++ b/src/main/scala/firrtl/passes/Resolves.scala @@ -7,7 +7,6 @@ import firrtl.ir._ import firrtl.Mappers._ object ResolveKinds extends Pass { - def name = "Resolve Kinds" type KindMap = collection.mutable.LinkedHashMap[String, Kind] def find_port(kinds: KindMap)(p: Port): Port = { @@ -46,7 +45,6 @@ object ResolveKinds extends Pass { } object ResolveGenders extends Pass { - def name = "Resolve Genders" def resolve_e(g: Gender)(e: Expression): Expression = e match { case ex: WRef => ex copy (gender = g) case WSubField(exp, name, tpe, _) => WSubField( @@ -79,7 +77,6 @@ object ResolveGenders extends Pass { } object CInferMDir extends Pass { - def name = "CInfer MDir" type MPortDirMap = collection.mutable.LinkedHashMap[String, MPortDir] def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = e match { diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 797292dc..a32f5366 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -13,7 +13,6 @@ import scala.collection.mutable // Splits compound expressions into simple expressions // and named intermediate nodes object SplitExpressions extends Pass { - def name = "Split Expressions" private def onModule(m: Module): Module = { val namespace = Namespace(m) def onStmt(s: Statement): Statement = { diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 40783e21..deddb93e 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -32,8 +32,6 @@ import MemPortUtils.memType * to rename a */ object Uniquify extends Pass { - def name = "Uniquify Identifiers" - private case class UniquifyException(msg: String) extends FIRRTLException(msg) private def error(msg: String)(implicit sinfo: Info, mname: String) = throw new UniquifyException(s"$sinfo: [module $mname] $msg") diff --git a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala index b4df534f..330ca497 100644 --- a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala +++ b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala @@ -24,7 +24,6 @@ import scala.collection.mutable * to emit correct verilog without needing to add temporary nodes */ object VerilogModulusCleanup extends Pass { - def name = "Add temporary nodes with verilog widths for modulus" private def onModule(m: Module): Module = { val namespace = Namespace(m) diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala index a50fdc16..520075fe 100644 --- a/src/main/scala/firrtl/passes/ZeroWidth.scala +++ b/src/main/scala/firrtl/passes/ZeroWidth.scala @@ -11,7 +11,6 @@ import firrtl.Utils.throwInternalError object ZeroWidth extends Pass { - def name = this.getClass.getName private val ZERO = BigInt(0) private def removeZero(t: Type): Option[Type] = t match { case GroundType(IntWidth(ZERO)) => None diff --git a/src/main/scala/firrtl/passes/clocklist/ClockList.scala b/src/main/scala/firrtl/passes/clocklist/ClockList.scala index 66139c49..bd2536ab 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockList.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockList.scala @@ -20,7 +20,6 @@ import Mappers._ * Write the result to writer. */ class ClockList(top: String, writer: Writer) extends Pass { - def name = this.getClass.getSimpleName def run(c: Circuit): Circuit = { // Build useful datastructures val childrenMap = getChildrenMap(c) diff --git a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala index feb7f42e..53787b1d 100644 --- a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala +++ b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala @@ -20,7 +20,6 @@ import Mappers._ * expressions do not relate to ground types. */ object RemoveAllButClocks extends Pass { - def name = this.getClass.getSimpleName def onStmt(s: Statement): Statement = (s map onStmt) match { case DefWire(i, n, ClockType) => s case DefNode(i, n, value) if value.tpe == ClockType => s diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala index 668bc2e5..e48dc8c2 100644 --- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala +++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala @@ -10,7 +10,6 @@ import wiring._ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform { def inputForm = MidForm def outputForm = MidForm - override def name = "Create Memory Annotations" def execute(state: CircuitState): CircuitState = reader match { case None => state case Some(r) => diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 9bd6a4ab..73fec1ee 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -29,7 +29,6 @@ object InferReadWriteAnnotation { // of any product term of the enable signal of the write port, then the readwrite // port is inferred. object InferReadWritePass extends Pass { - def name = "Infer ReadWrite Ports" type Netlist = collection.mutable.HashMap[String, Expression] type Statements = collection.mutable.ArrayBuffer[Statement] @@ -150,10 +149,10 @@ object InferReadWritePass extends Pass { // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite extends Transform with PassBased { +class InferReadWrite extends Transform with SeqTransformBased { def inputForm = MidForm def outputForm = MidForm - def passSeq = Seq( + def transforms = Seq( InferReadWritePass, CheckInitialization, InferTypes, @@ -163,6 +162,7 @@ class InferReadWrite extends Transform with PassBased { def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { case Nil => state case Seq(InferReadWriteAnnotation(CircuitName(state.circuit.main))) => - state.copy(circuit = runPasses(state.circuit)) + val ret = runTransforms(state) + CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames) } } diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala index 57c301b1..9debff7a 100644 --- a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala +++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala @@ -15,9 +15,6 @@ import MemTransformUtils._ /** Changes memory port names to standard port names (i.e. RW0 instead T_408) */ object RenameAnnotatedMemoryPorts extends Pass { - - def name = "Rename Annotated Memory Ports" - /** Renames memory ports to a standard naming scheme: * - R0, R1, ... for each read port * - W0, W1, ... for each write port diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index af6761fd..b18ed289 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -33,7 +33,6 @@ object PinAnnotation { * Creates the minimum # of black boxes needed by the design. */ class ReplaceMemMacros(writer: ConfWriter) extends Transform { - override def name = "Replace Memory Macros" def inputForm = MidForm def outputForm = MidForm @@ -227,11 +226,14 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform { case Seq(PinAnnotation(CircuitName(c), pins)) => pins case _ => throwInternalError } - val annos = pins.foldLeft(Seq[Annotation]()) { (seq, pin) => + val annos = (pins.foldLeft(Seq[Annotation]()) { (seq, pin) => seq ++ memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), pin) } - } + }) ++ (state.annotations match { + case None => Seq.empty + case Some(a) => a.annotations + }) CircuitState(c.copy(modules = modules ++ memMods), inputForm, Some(AnnotationMap(annos))) } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index 0c12d2aa..caaf430b 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -116,10 +116,10 @@ class SimpleTransform(p: Pass, form: CircuitForm) extends Transform { class SimpleMidTransform(p: Pass) extends SimpleTransform(p, MidForm) // SimpleRun instead of PassBased because of the arguments to passSeq -class ReplSeqMem extends Transform with SimpleRun { +class ReplSeqMem extends Transform { def inputForm = MidForm def outputForm = MidForm - def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = + def transforms(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = Seq(new SimpleMidTransform(Legalize), new SimpleMidTransform(ToMemIR), new SimpleMidTransform(ResolveMaskGranularity), @@ -134,31 +134,19 @@ class ReplSeqMem extends Transform with SimpleRun { new SimpleMidTransform(Uniquify), new SimpleMidTransform(ResolveKinds), new SimpleMidTransform(ResolveGenders)) - def run(state: CircuitState, xForms: Seq[Transform]): CircuitState = { - (xForms.foldLeft(state) { case (curState: CircuitState, xForm: Transform) => - val res = xForm.execute(curState) - val newAnnotations = res.annotations match { - case None => curState.annotations - case Some(ann) => - Some(AnnotationMap(ann.annotations ++ curState.annotations.get.annotations)) - } - CircuitState(res.circuit, res.form, newAnnotations) - }) - } - def execute(state: CircuitState): CircuitState = - getMyAnnotations(state) match { - case Nil => state // Do nothing if there are no annotations - case p => (p.collectFirst { case a if (a.target == CircuitName(state.circuit.main)) => a }) match { - case Some(ReplSeqMemAnnotation(target, inputFileName, outputConfig)) => - val inConfigFile = { - if (inputFileName.isEmpty) None - else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) - else error("Input configuration file does not exist!") - } - val outConfigFile = new ConfWriter(outputConfig) - run(state, passSeq(inConfigFile, outConfigFile)) - case _ => error("Unexpected transform annotation") - } + def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { + case Nil => state // Do nothing if there are no annotations + case p => (p.collectFirst { case a if (a.target == CircuitName(state.circuit.main)) => a }) match { + case Some(ReplSeqMemAnnotation(target, inputFileName, outputConfig)) => + val inConfigFile = { + if (inputFileName.isEmpty) None + else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) + else error("Input configuration file does not exist!") + } + val outConfigFile = new ConfWriter(outputConfig) + transforms(inConfigFile, outConfigFile).foldLeft(state) { (in, xform) => xform.runTransform(in) } + case _ => error("Unexpected transform annotation") } + } } diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index 956bdd3c..79ecd9cd 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -90,7 +90,6 @@ object AnalysisUtils { * TODO(shunshou): Add floorplan info? */ object ResolveMaskGranularity extends Pass { - def name = "Resolve Mask Granularity" /** Returns the number of mask bits, if used */ diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala index df555e57..e132e369 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala @@ -51,6 +51,6 @@ class ResolveMemoryReference extends Transform { case annos => annos.collect { case NoDedupMemAnnotation(ComponentName(cn, _)) => cn } } - CircuitState(run(state.circuit, noDedups), state.form) + state.copy(circuit=run(state.circuit, noDedups)) } } diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala index eb9d0859..feb6ae59 100644 --- a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala @@ -13,8 +13,6 @@ import firrtl.ir._ * - zero or one read port */ object ToMemIR extends Pass { - def name = "To Memory IR" - /** Only annotate memories that are candidates for memory macro replacements * i.e. rw, w + r (read, write 1 cycle delay) */ diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index fc126b74..6eefb69e 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -12,7 +12,6 @@ import MemPortUtils._ /** This pass generates delay reigsters for memories for verilog */ object VerilogMemDelays extends Pass { - def name = "Verilog Memory Delays" val ug = UNKNOWNGENDER type Netlist = collection.mutable.HashMap[String, Expression] implicit def expToString(e: Expression): String = e.serialize diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala index f5da4c06..9656abb2 100644 --- a/src/main/scala/firrtl/passes/wiring/Wiring.scala +++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala @@ -17,7 +17,6 @@ case class WiringException(msg: String) extends PassException(msg) case class WiringInfo(source: String, comp: String, sinks: Set[String], pin: String, top: String) class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { - def name = this.getClass.getSimpleName def run(c: Circuit): Circuit = { wiSeq.foldLeft(c) { (circuit, wi) => wire(circuit, wi) } } diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index 2c122943..a8ef5f58 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -60,10 +60,10 @@ object TopAnnotation { * Notes: * - No module uniquification occurs (due to imposed restrictions) */ -class WiringTransform extends Transform with SimpleRun { +class WiringTransform extends Transform { def inputForm = MidForm def outputForm = MidForm - def passSeq(wis: Seq[WiringInfo]) = + def transforms(wis: Seq[WiringInfo]) = Seq(new Wiring(wis), InferTypes, ResolveKinds, @@ -89,7 +89,7 @@ class WiringTransform extends Transform with SimpleRun { val wis = tops.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, top)) => seq :+ WiringInfo(sources(pin), comp(pin), sinks(pin), pin, top) } - state.copy(circuit = runPasses(state.circuit, passSeq(wis))) + transforms(wis).foldLeft(state) { (in, xform) => xform.runTransform(in) } case _ => error("Wrong number of sources, tops, or sinks!") } } -- cgit v1.2.3