diff options
| author | Schuyler Eldridge | 2019-12-17 18:29:47 -0500 |
|---|---|---|
| committer | Schuyler Eldridge | 2020-03-11 14:01:31 -0400 |
| commit | abf226471249a1cbb8de33d0c4bc8526f9aafa70 (patch) | |
| tree | 0537dff3091db3da167c0fffc3388a5966c46204 /src/main/scala/firrtl/passes | |
| parent | 646c91e71b8bfb1b0d0f22e81ca113147637ce71 (diff) | |
Migrate to DependencyAPI
Co-authored-by: Schuyler Eldridge <schuyler.eldridge@ibm.com>
Co-authored-by: Albert Magyar <albert.magyar@gmail.com>
Signed-off-by: Schuyler Eldridge <schuyler.eldridge@ibm.com>
Diffstat (limited to 'src/main/scala/firrtl/passes')
25 files changed, 439 insertions, 109 deletions
diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index 08237ab2..08c127da 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -2,8 +2,16 @@ package firrtl.passes +import firrtl.Transform import firrtl.ir._ +import firrtl.options.{Dependency, PreservesAll} + +object CheckChirrtl extends Pass with CheckHighFormLike with PreservesAll[Transform] { + + override val dependents = firrtl.stage.Forms.ChirrtlForm ++ + Seq( Dependency(CInferTypes), + Dependency(CInferMDir), + Dependency(RemoveCHIRRTL) ) -object CheckChirrtl extends Pass with CheckHighFormLike { def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] = None } diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala index 9fbf3eeb..63790564 100644 --- a/src/main/scala/firrtl/passes/CheckInitialization.scala +++ b/src/main/scala/firrtl/passes/CheckInitialization.scala @@ -6,6 +6,7 @@ import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.traversals.Foreachers._ +import firrtl.options.PreservesAll import annotation.tailrec @@ -14,7 +15,10 @@ import annotation.tailrec * @note This pass looks for [[firrtl.WVoid]]s left behind by [[ExpandWhens]] * @note Assumes single connection (ie. no last connect semantics) */ -object CheckInitialization extends Pass { +object CheckInitialization extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.Resolved + private case class VoidExpr(stmt: Statement, voidDeps: Seq[Expression]) class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) extends PassException( diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 6ceac032..b750196a 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -9,8 +9,14 @@ import firrtl.traversals.Foreachers._ import firrtl.Utils._ import firrtl.constraint.IsKnown import firrtl.annotations.{CircuitTarget, ModuleTarget, Target, TargetToken} +import firrtl.options.{Dependency, PreservesAll} + +object CheckWidths extends Pass with PreservesAll[Transform] { + + override val prerequisites = Dependency[passes.InferWidths] +: firrtl.stage.Forms.WorkingIR + + override val dependents = Seq(Dependency[transforms.InferResets]) -object CheckWidths extends Pass { /** The maximum allowed width for any circuit element */ val MaxWidth = 1000000 val DshlMaxWidth = getUIntWidth(MaxWidth) diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index a5f66a55..e176bcc4 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -9,8 +9,9 @@ import firrtl.Utils._ import firrtl.traversals.Foreachers._ import firrtl.WrappedType._ import firrtl.constraint.{Constraint, IsKnown} +import firrtl.options.{Dependency, PreservesAll} -trait CheckHighFormLike { +trait CheckHighFormLike { this: Pass => type NameSet = collection.mutable.HashSet[String] // Custom Exceptions @@ -267,7 +268,18 @@ trait CheckHighFormLike { } } -object CheckHighForm extends Pass with CheckHighFormLike { +object CheckHighForm extends Pass with CheckHighFormLike with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.WorkingIR + + override val dependents = + Seq( Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.Uniquify), + Dependency(passes.ResolveFlows), + Dependency[passes.InferWidths], + Dependency[transforms.InferResets] ) + class IllegalChirrtlMemException(info: Info, mname: String, name: String) extends PassException( s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.") @@ -279,7 +291,17 @@ object CheckHighForm extends Pass with CheckHighFormLike { Some(new IllegalChirrtlMemException(info, mname, memName)) } } -object CheckTypes extends Pass { + +object CheckTypes extends Pass with PreservesAll[Transform] { + + override val prerequisites = Dependency(InferTypes) +: firrtl.stage.Forms.WorkingIR + + override val dependents = + Seq( Dependency(passes.Uniquify), + Dependency(passes.ResolveFlows), + Dependency(passes.CheckFlows), + Dependency[passes.InferWidths], + Dependency(passes.CheckWidths) ) // Custom Exceptions class SubfieldNotInBundle(info: Info, mname: String, name: String) extends PassException( @@ -583,7 +605,16 @@ object CheckTypes extends Pass { } } -object CheckFlows extends Pass { +object CheckFlows extends Pass with PreservesAll[Transform] { + + override val prerequisites = Dependency(passes.ResolveFlows) +: firrtl.stage.Forms.WorkingIR + + override val dependents = + Seq( Dependency[passes.InferBinaryPoints], + Dependency[passes.TrimIntervals], + Dependency[passes.InferWidths], + Dependency[transforms.InferResets] ) + type FlowMap = collection.mutable.HashMap[String, Flow] implicit def toStr(g: Flow): String = g match { diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index 3ba12b2d..d54d8088 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -5,9 +5,21 @@ package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} +object CommonSubexpressionElimination extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.LowForm ++ + Seq( Dependency(firrtl.passes.RemoveValidIf), + Dependency[firrtl.transforms.ConstantPropagation], + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[firrtl.transforms.CombineCats] ) + + override val dependents = + Seq( Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] ) -object CommonSubexpressionElimination extends Pass { private def cse(s: Statement): Statement = { val expressions = collection.mutable.HashMap[MemoizedHash[Expression], String]() val nodes = collection.mutable.HashMap[String, Expression]() diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 05a000c5..7e65bdd1 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -8,10 +8,20 @@ import firrtl.ir._ import firrtl._ import firrtl.Mappers._ import firrtl.Utils.{sub_type, module_type, field_type, max, throwInternalError} +import firrtl.options.{Dependency, PreservesAll} /** Replaces FixedType with SIntType, and correctly aligns all binary points */ -object ConvertFixedToSInt extends Pass { +object ConvertFixedToSInt extends Pass with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency[ExpandWhensAndCheck], + Dependency[RemoveIntervals] ) ++ firrtl.stage.Forms.Deduped + def alignArg(e: Expression, point: BigInt): Expression = e.tpe match { case FixedType(IntWidth(w), IntWidth(p)) => // assert(point >= p) if((point - p) > 0) { @@ -83,29 +93,29 @@ object ConvertFixedToSInt extends Pass { types(name) = newType newStmt case WDefInstance(info, name, module, tpe) => - val newType = moduleTypes(module) + val newType = moduleTypes(module) types(name) = newType WDefInstance(info, name, module, newType) - case Connect(info, loc, exp) => + case Connect(info, loc, exp) => val point = calcPoint(Seq(loc)) val newExp = alignArg(exp, point) Connect(info, loc, newExp) map updateExpType - case PartialConnect(info, loc, exp) => + case PartialConnect(info, loc, exp) => val point = calcPoint(Seq(loc)) val newExp = alignArg(exp, point) PartialConnect(info, loc, newExp) map updateExpType // check Connect case, need to shl case s => (s map updateStmtType) map updateExpType } - + m.ports.foreach(p => types(p.name) = p.tpe) m match { case Module(info, name, ports, body) => Module(info,name,ports,updateStmtType(body)) case m:ExtModule => m } } - - val newModules = for(m <- c.modules) yield { + + val newModules = for(m <- c.modules) yield { val newPorts = m.ports.map(p => Port(p.info,p.name,p.direction,toSIntType(p.tpe))) m match { case Module(info, name, ports, body) => Module(info,name,newPorts,body) @@ -113,8 +123,13 @@ object ConvertFixedToSInt extends Pass { } } newModules.foreach(m => moduleTypes(m.name) = module_type(m)) - firrtl.passes.InferTypes.run(Circuit(c.info, newModules.map(onModule(_)), c.main )) + + /* @todo This should be moved outside */ + (firrtl.passes.InferTypes).run(Circuit(c.info, newModules.map(onModule(_)), c.main )) } } + + + // vim: set ts=4 sw=4 et: diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 12aa9690..aaf3d9b4 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -8,6 +8,7 @@ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.PrimOps._ import firrtl.WrappedExpression._ +import firrtl.options.Dependency import annotation.tailrec import collection.mutable @@ -24,6 +25,19 @@ import collection.mutable * @note Assumes all references are declared */ object ExpandWhens extends Pass { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency(Uniquify) ) ++ firrtl.stage.Forms.Resolved + + override def invalidates(a: Transform): Boolean = a match { + case CheckInitialization | ResolveKinds | InferTypes => true + case _ => false + } + /** Returns circuit with when and last connection semantics resolved */ def run(c: Circuit): Circuit = { val modulesx = c.modules map { @@ -287,3 +301,24 @@ object ExpandWhens extends Pass { private def NOT(e: Expression) = DoPrim(Eq, Seq(e, zero), Nil, BoolType) } + +class ExpandWhensAndCheck extends SeqTransform { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency(Uniquify) ) ++ firrtl.stage.Forms.Deduped + + override def invalidates(a: Transform): Boolean = a match { + case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true + case _ => false + } + + override def inputForm = UnknownForm + override def outputForm = UnknownForm + + override val transforms = Seq(ExpandWhens, CheckInitialization) + +} diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala index 258c9697..86bc36fc 100644 --- a/src/main/scala/firrtl/passes/InferBinaryPoints.scala +++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala @@ -7,8 +7,19 @@ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target} import firrtl.constraint.ConstraintSolver +import firrtl.Transform +import firrtl.options.{Dependency, PreservesAll} + +class InferBinaryPoints extends Pass with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(ResolveKinds), + Dependency(InferTypes), + Dependency(Uniquify), + Dependency(ResolveFlows) ) + + override val dependents = Seq.empty -class InferBinaryPoints extends Pass { private val constraintSolver = new ConstraintSolver() private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match { @@ -71,14 +82,14 @@ class InferBinaryPoints extends Pass { case _ => sys.error("Shouldn't be here") } private def fixType(t: Type): Type = t map fixType map fixWidth match { - case IntervalType(l, u, p) => + case IntervalType(l, u, p) => val px = constraintSolver.get(p) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) case None => p case _ => sys.error("Shouldn't be here") } IntervalType(l, u, px) - case FixedType(w, p) => + case FixedType(w, p) => val px = constraintSolver.get(p) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) case None => p diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 3c5cf7fb..d625b626 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -6,8 +6,12 @@ import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} + +object InferTypes extends Pass with PreservesAll[Transform] { + + override val prerequisites = Dependency(ResolveKinds) +: firrtl.stage.Forms.WorkingIR -object InferTypes extends Pass { type TypeMap = collection.mutable.LinkedHashMap[String, Type] def run(c: Circuit): Circuit = { @@ -79,12 +83,15 @@ object InferTypes extends Pass { val types = new TypeMap m map infer_types_p(types) map infer_types_s(types) } - + c copy (modules = c.modules map infer_types) } } -object CInferTypes extends Pass { +object CInferTypes extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.ChirrtlForm + type TypeMap = collection.mutable.LinkedHashMap[String, Type] def run(c: Circuit): Circuit = { @@ -133,12 +140,12 @@ object CInferTypes extends Pass { types(p.name) = p.tpe p } - + def infer_types(m: DefModule): DefModule = { val types = new TypeMap m map infer_types_p(types) map infer_types_s(types) } - + c copy (modules = c.modules map infer_types) } } diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 2211d238..29936ca0 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -11,6 +11,8 @@ import firrtl.Mappers._ import firrtl.Implicits.width2constraint import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target} import firrtl.constraint.{ConstraintSolver, IsMax} +import firrtl.options.{Dependency, PreservesAll} +import firrtl.traversals.Foreachers._ object InferWidths { def apply(): InferWidths = new InferWidths() @@ -60,7 +62,16 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg * * Uses firrtl.constraint package to infer widths */ -class InferWidths extends Transform with ResolvedAnnotationPaths { +class InferWidths extends Transform with ResolvedAnnotationPaths with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.Uniquify), + Dependency(passes.ResolveFlows), + Dependency[passes.InferBinaryPoints], + Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR + def inputForm: CircuitForm = UnknownForm def outputForm: CircuitForm = UnknownForm @@ -108,12 +119,12 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { val n = get_size(c.loc.tpe) val locs = create_exps(c.loc) val exps = create_exps(c.expr) - (locs zip exps).foreach { case (loc, exp) => - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } - } + (locs zip exps).foreach { case (loc, exp) => + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } + } c case pc: PartialConnect => val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default) @@ -142,8 +153,8 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { } a case c: Conditionally => - addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1))) - c map addStmtConstraints(mt) + addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1))) + c map addStmtConstraints(mt) case x => x map addStmtConstraints(mt) } private def fixWidth(w: Width): Width = constraintSolver.get(w) match { @@ -152,7 +163,7 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { case _ => sys.error("Shouldn't be here") } private def fixType(t: Type): Type = t map fixType map fixWidth match { - case IntervalType(l, u, p) => + case IntervalType(l, u, p) => val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match { case (Some(x: Bound), Some(y: Bound)) => (x, y) case (None, None) => (l, u) @@ -174,8 +185,8 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name))) constraintSolver.solve() val ret = InferTypes.run(c.copy(modules = c.modules map (_ - map fixPort - map fixStmt))) + map fixPort + map fixStmt))) constraintSolver.clear() ret } @@ -212,11 +223,11 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal => val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target => val baseType = typeMap.getOrElse(target.copy(component = Seq.empty), - throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint())) + throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint())) val leafType = target.componentType(baseType) if (leafType.isInstanceOf[AggregateType]) { throw new Exception(s"Target below is an AggregateType, which " + - "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) + "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) } leafType diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index f52e1e6b..73ef8a22 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -26,6 +26,15 @@ object LowerTypes extends Transform { def inputForm = UnknownForm def outputForm = UnknownForm + override val prerequisites = firrtl.stage.Forms.MidForm + + override val dependents = Seq.empty + + override def invalidates(a: Transform): Boolean = a match { + case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true + case _ => false + } + /** Delimiter used in lowering names */ val delim = "_" /** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index cbd8250a..0b318511 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -6,9 +6,30 @@ package passes import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Mappers._ +import firrtl.options.Dependency + +import scala.collection.mutable // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { + + override val prerequisites = + ((new mutable.LinkedHashSet()) + ++ firrtl.stage.Forms.LowForm + - Dependency(firrtl.passes.Legalize) + + Dependency(firrtl.passes.RemoveValidIf) + + Dependency[firrtl.transforms.ConstantPropagation]).toSeq + + override val dependents = + Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] ) + + override def invalidates(a: Transform): Boolean = a match { + case _: firrtl.transforms.ConstantPropagation | Legalize => true + case _ => false + } + private def width(t: Type): Int = bitWidth(t).toInt private def width(e: Expression): Int = width(e.tpe) // Returns an expression with the correct integer width diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 9a644dc8..a8d37758 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -7,6 +7,7 @@ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.PrimOps._ +import firrtl.options.{Dependency, PreservesAll} import firrtl.transforms.ConstantPropagation import scala.collection.mutable @@ -46,7 +47,10 @@ class Errors { } // These should be distributed into separate files -object ToWorkingIR extends Pass { +object ToWorkingIR extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.MinimalHighForm + def toExp(e: Expression): Expression = e map toExp match { case ex: Reference => WRef(ex.name, ex.tpe, UnknownKind, UnknownFlow) case ex: SubField => WSubField(ex.expr, ex.name, ex.tpe, UnknownFlow) @@ -64,8 +68,11 @@ object ToWorkingIR extends Pass { c copy (modules = c.modules map (_ map toStmt)) } -object PullMuxes extends Pass { - def run(c: Circuit): Circuit = { +object PullMuxes extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.Deduped + + def run(c: Circuit): Circuit = { def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match { case ex: WSubField => ex.expr match { case exx: Mux => Mux(exx.cond, @@ -102,7 +109,12 @@ object PullMuxes extends Pass { } } -object ExpandConnects extends Pass { +object ExpandConnects extends Pass with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped + def run(c: Circuit): Circuit = { def expand_connects(m: Module): Module = { val flows = collection.mutable.LinkedHashMap[String,Flow]() @@ -179,7 +191,14 @@ object ExpandConnects extends Pass { // Replace shr by amount >= arg width with 0 for UInts and MSB for SInts // TODO replace UInt with zero-width wire instead -object Legalize extends Pass { +object Legalize extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes) + + override val optionalPrerequisites = Seq.empty + + override val dependents = Seq.empty + private def legalizeShiftRight(e: DoPrim): Expression = { require(e.op == Shr) e.args.head match { @@ -260,7 +279,22 @@ object Legalize extends Pass { * * @note The result of this pass is NOT legal Firrtl */ -object VerilogPrep extends Pass { +object VerilogPrep extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename] ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty type AttachSourceMap = Map[WrappedExpression, Expression] diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index 1c2dc096..ac5d8a4e 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -2,18 +2,30 @@ package firrtl.passes -import firrtl.{WRef, WSubAccess, WSubIndex, WSubField, Namespace} +import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubIndex, WSubField} import firrtl.PrimOps.{And, Eq} import firrtl.ir._ import firrtl.Mappers._ import firrtl.Utils._ import firrtl.WrappedExpression._ -import scala.collection.mutable +import firrtl.options.Dependency +import scala.collection.mutable /** Removes all [[firrtl.WSubAccess]] from circuit */ -class RemoveAccesses extends Pass { +object RemoveAccesses extends Pass { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects) ) ++ firrtl.stage.Forms.Deduped + + override def invalidates(a: Transform): Boolean = a match { + case Uniquify => true + case _ => false + } + private def AND(e1: Expression, e2: Expression) = if(e1 == one) e2 else if(e2 == one) e1 @@ -166,14 +178,3 @@ class RemoveAccesses extends Pass { }) } } - -object RemoveAccesses extends Pass { - def apply: Pass = { - new RemoveAccesses() - } - - def run(c: Circuit): Circuit = { - val t = new RemoveAccesses - t.run(c) - } -} diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index 921ec3c7..05dd8bd9 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -8,12 +8,18 @@ import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} case class MPort(name: String, clk: Expression) case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort]) case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean) -object RemoveCHIRRTL extends Transform { +object RemoveCHIRRTL extends Transform with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.ChirrtlForm ++ + Seq( Dependency(passes.CInferTypes), + Dependency(passes.CInferMDir) ) + def inputForm: CircuitForm = UnknownForm def outputForm: CircuitForm = UnknownForm val ut = UnknownType diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala index 73f59b59..cf3d2ff2 100644 --- a/src/main/scala/firrtl/passes/RemoveIntervals.scala +++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala @@ -8,6 +8,7 @@ import firrtl._ import firrtl.Mappers._ import Implicits.{bigint2WInt} import firrtl.constraint.IsKnown +import firrtl.options.{Dependency, PreservesAll} import scala.math.BigDecimal.RoundingMode._ @@ -35,7 +36,14 @@ class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim) * c. replace with SIntType * 3) Run InferTypes */ -class RemoveIntervals extends Pass { +class RemoveIntervals extends Pass with PreservesAll[Transform] { + + override val prerequisites: Seq[Dependency[Transform]] = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency[ExpandWhensAndCheck] ) ++ firrtl.stage.Forms.Deduped def run(c: Circuit): Circuit = { val alignedCircuit = c diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 42eae7e5..3b5499ac 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -2,9 +2,11 @@ package firrtl package passes + import firrtl.Mappers._ import firrtl.ir._ import Utils.throwInternalError +import firrtl.options.Dependency /** Remove [[firrtl.ir.ValidIf ValidIf]] and replace [[firrtl.ir.IsInvalid IsInvalid]] with a connection to zero */ object RemoveValidIf extends Pass { @@ -27,6 +29,17 @@ object RemoveValidIf extends Pass { case other => throwInternalError(s"Unexpected type $other") } + override val prerequisites = firrtl.stage.Forms.LowForm + + override val dependents = + Seq( Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] ) + + override def invalidates(a: Transform): Boolean = a match { + case Legalize | _: firrtl.transforms.ConstantPropagation => true + case _ => false + } + // Recursive. Removes ValidIfs private def onExp(e: Expression): Expression = { e map onExp match { diff --git a/src/main/scala/firrtl/passes/ReplaceAccesses.scala b/src/main/scala/firrtl/passes/ReplaceAccesses.scala index 2ec035f3..75cca77a 100644 --- a/src/main/scala/firrtl/passes/ReplaceAccesses.scala +++ b/src/main/scala/firrtl/passes/ReplaceAccesses.scala @@ -2,22 +2,26 @@ package firrtl.passes +import firrtl.Transform import firrtl.ir._ import firrtl.{WSubAccess, WSubIndex} import firrtl.Mappers._ - +import firrtl.options.{Dependency, PreservesAll} /** Replaces constant [[firrtl.WSubAccess]] with [[firrtl.WSubIndex]] * TODO Fold in to High Firrtl Const Prop */ -object ReplaceAccesses extends Pass { +object ReplaceAccesses extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.Deduped :+ Dependency(PullMuxes) + def run(c: Circuit): Circuit = { def onStmt(s: Statement): Statement = s map onStmt map onExp def onExp(e: Expression): Expression = e match { case WSubAccess(ex, UIntLiteral(value, width), t, g) => WSubIndex(onExp(ex), value.toInt, t, g) case _ => e map onExp } - + c copy (modules = c.modules map (_ map onStmt)) } } diff --git a/src/main/scala/firrtl/passes/Resolves.scala b/src/main/scala/firrtl/passes/Resolves.scala index 97cc4bb3..15750b76 100644 --- a/src/main/scala/firrtl/passes/Resolves.scala +++ b/src/main/scala/firrtl/passes/Resolves.scala @@ -5,9 +5,14 @@ package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} import Utils.throwInternalError -object ResolveKinds extends Pass { + +object ResolveKinds extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.WorkingIR + type KindMap = collection.mutable.LinkedHashMap[String, Kind] def find_port(kinds: KindMap)(p: Port): Port = { @@ -45,7 +50,13 @@ object ResolveKinds extends Pass { c copy (modules = c.modules map resolve_kinds) } -object ResolveFlows extends Pass { +object ResolveFlows extends Pass with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.Uniquify) ) ++ firrtl.stage.Forms.WorkingIR + def resolve_e(g: Flow)(e: Expression): Expression = e match { case ex: WRef => ex copy (flow = g) case WSubField(exp, name, tpe, _) => WSubField( @@ -88,7 +99,10 @@ object ResolveGenders extends Pass { } -object CInferMDir extends Pass { +object CInferMDir extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.ChirrtlForm :+ Dependency(CInferTypes) + type MPortDirMap = collection.mutable.LinkedHashMap[String, MPortDir] def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = e match { diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index de955c9a..43d0ed34 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -3,7 +3,9 @@ package firrtl package passes +import firrtl.{SystemVerilogEmitter, VerilogEmitter} import firrtl.ir._ +import firrtl.options.{Dependency, PreservesAll} import firrtl.Mappers._ import firrtl.Utils.{kind, flow, get_info} @@ -12,7 +14,16 @@ import scala.collection.mutable // Splits compound expressions into simple expressions // and named intermediate nodes -object SplitExpressions extends Pass { +object SplitExpressions extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.LowForm ++ + Seq( Dependency(firrtl.passes.RemoveValidIf), + Dependency(firrtl.passes.memlib.VerilogMemDelays) ) + + override val dependents = + Seq( Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] ) + private def onModule(m: Module): Module = { val namespace = Namespace(m) def onStmt(s: Statement): Statement = { diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala index f659815e..4e558e2a 100644 --- a/src/main/scala/firrtl/passes/TrimIntervals.scala +++ b/src/main/scala/firrtl/passes/TrimIntervals.scala @@ -6,6 +6,8 @@ import firrtl.PrimOps._ import firrtl.ir._ import firrtl.Mappers._ import firrtl.constraint.{IsFloor, IsKnown, IsMul} +import firrtl.options.{Dependency, PreservesAll} +import firrtl.Transform /** Replaces IntervalType with SIntType, three AST walks: * 1) Align binary points @@ -18,7 +20,17 @@ import firrtl.constraint.{IsFloor, IsKnown, IsMul} * c. replace with SIntType * 3) Run InferTypes */ -class TrimIntervals extends Pass { +class TrimIntervals extends Pass with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(ResolveKinds), + Dependency(InferTypes), + Dependency(Uniquify), + Dependency(ResolveFlows), + Dependency[InferBinaryPoints] ) + + override val dependents = Seq.empty + def run(c: Circuit): Circuit = { // Open -> closed val firstPass = InferTypes.run(c map replaceModuleInterval) @@ -80,7 +92,7 @@ class TrimIntervals extends Pass { val shiftMul = Closed(BigDecimal(1) / shiftGain) val bpGain = BigDecimal(BigInt(1) << current.toInt) // BP is inferred at this point - // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt) + // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt) val newBPRes = Closed(shiftGain / bpGain) val bpResInv = Closed(bpGain) val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), newBPRes) diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 978ccc66..1268cac2 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -8,8 +8,9 @@ import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ -import MemPortUtils.memType +import firrtl.options.Dependency +import MemPortUtils.memType /** Resolve name collisions that would occur in [[LowerTypes]] * @@ -32,6 +33,16 @@ import MemPortUtils.memType * to rename a */ object Uniquify extends Transform { + + override val prerequisites = + Seq( Dependency(ResolveKinds), + Dependency(InferTypes) ) ++ firrtl.stage.Forms.WorkingIR + + override def invalidates(a: Transform): Boolean = a match { + case ResolveKinds | InferTypes => true + case _ => false + } + def inputForm = UnknownForm def outputForm = UnknownForm private case class UniquifyException(msg: String) extends FirrtlInternalException(msg) @@ -41,9 +52,9 @@ object Uniquify extends Transform { // For creation of rename map private case class NameMapNode(name: String, elts: Map[String, NameMapNode]) - // Appends delim to prefix until no collisions of prefix + elts in names - // We don't add an _ in the collision check because elts could be Seq("") - // In this case, we're just really checking if prefix itself collides + /** Appends delim to prefix until no collisions of prefix + elts in names We don't add an _ in the collision check + * because elts could be Seq("") In this case, we're just really checking if prefix itself collides + */ @tailrec def findValidPrefix( prefix: String, @@ -55,10 +66,12 @@ object Uniquify extends Transform { } } - // Enumerates all possible names for a given type - // eg. foo : { bar : { a, b }[2], c } - // => foo, foo bar, foo bar 0, foo bar 1, foo bar 0 a, foo bar 0 b, - // foo bar 1 a, foo bar 1 b, foo c + /** Enumerates all possible names for a given type. For example: + * {{{ + * foo : { bar : { a, b }[2], c } + * => foo, foo bar, foo bar 0, foo bar 1, foo bar 0 a, foo bar 0 b, foo bar 1 a, foo bar 1 b, foo c + * }}} + */ private [firrtl] def enumerateNames(tpe: Type): Seq[Seq[String]] = tpe match { case t: BundleType => t.fields flatMap { f => @@ -72,6 +85,36 @@ object Uniquify extends Transform { case _ => Seq() } + /** Creates a Bundle Type from a Stmt */ + def stmtToType(s: Statement)(implicit sinfo: Info, mname: String): BundleType = { + // Recursive helper + def recStmtToType(s: Statement): Seq[Field] = s match { + case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe)) + case sx: DefRegister => Seq(Field(sx.name, Default, sx.tpe)) + case sx: WDefInstance => Seq(Field(sx.name, Default, sx.tpe)) + case sx: DefMemory => sx.dataType match { + case (_: UIntType | _: SIntType | _: FixedType) => + Seq(Field(sx.name, Default, memType(sx))) + case tpe: BundleType => + val newFields = tpe.fields map ( f => + DefMemory(sx.info, f.name, f.tpe, sx.depth, sx.writeLatency, + sx.readLatency, sx.readers, sx.writers, sx.readwriters) + ) flatMap recStmtToType + Seq(Field(sx.name, Default, BundleType(newFields))) + case tpe: VectorType => + val newFields = (0 until tpe.size) map ( i => + sx.copy(name = i.toString, dataType = tpe.tpe) + ) flatMap recStmtToType + Seq(Field(sx.name, Default, BundleType(newFields))) + } + case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe)) + case sx: Conditionally => recStmtToType(sx.conseq) ++ recStmtToType(sx.alt) + case sx: Block => (sx.stmts map recStmtToType).flatten + case sx => Seq() + } + BundleType(recStmtToType(s)) + } + // Accepts a Type and an initial namespace // Returns new Type with uniquified names private def uniquifyNames( @@ -202,36 +245,6 @@ object Uniquify extends Transform { case t => t } - // Creates a Bundle Type from a Stmt - def stmtToType(s: Statement)(implicit sinfo: Info, mname: String): BundleType = { - // Recursive helper - def recStmtToType(s: Statement): Seq[Field] = s match { - case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe)) - case sx: DefRegister => Seq(Field(sx.name, Default, sx.tpe)) - case sx: WDefInstance => Seq(Field(sx.name, Default, sx.tpe)) - case sx: DefMemory => sx.dataType match { - case (_: UIntType | _: SIntType | _: FixedType) => - Seq(Field(sx.name, Default, memType(sx))) - case tpe: BundleType => - val newFields = tpe.fields map ( f => - DefMemory(sx.info, f.name, f.tpe, sx.depth, sx.writeLatency, - sx.readLatency, sx.readers, sx.writers, sx.readwriters) - ) flatMap recStmtToType - Seq(Field(sx.name, Default, BundleType(newFields))) - case tpe: VectorType => - val newFields = (0 until tpe.size) map ( i => - sx.copy(name = i.toString, dataType = tpe.tpe) - ) flatMap recStmtToType - Seq(Field(sx.name, Default, BundleType(newFields))) - } - case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe)) - case sx: Conditionally => recStmtToType(sx.conseq) ++ recStmtToType(sx.alt) - case sx: Block => (sx.stmts map recStmtToType).flatten - case sx => Seq() - } - BundleType(recStmtToType(s)) - } - // Everything wrapped in run so that it's thread safe def execute(state: CircuitState): CircuitState = { val c = state.circuit diff --git a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala index fdc81797..f47ddfbd 100644 --- a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala +++ b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala @@ -7,6 +7,7 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.PrimOps.{Bits, Rem} import firrtl.Utils._ +import firrtl.options.{Dependency, PreservesAll} import scala.collection.mutable @@ -23,7 +24,20 @@ import scala.collection.mutable * This is technically incorrect firrtl, but allows the verilog emitter * to emit correct verilog without needing to add temporary nodes */ -object VerilogModulusCleanup extends Pass { +object VerilogModulusCleanup extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate] ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty private def onModule(m: Module): Module = { val namespace = Namespace(m) diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala index e01cfffc..e60d76d1 100644 --- a/src/main/scala/firrtl/passes/ZeroWidth.scala +++ b/src/main/scala/firrtl/passes/ZeroWidth.scala @@ -6,8 +6,24 @@ import firrtl.PrimOps._ import firrtl.ir._ import firrtl._ import firrtl.Mappers._ +import firrtl.options.Dependency object ZeroWidth extends Transform { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency(Uniquify), + Dependency[ExpandWhensAndCheck], + Dependency(ConvertFixedToSInt) ) ++ firrtl.stage.Forms.Deduped + + override def invalidates(a: Transform): Boolean = a match { + case InferTypes => true + case _ => false + } + def inputForm: CircuitForm = UnknownForm def outputForm: CircuitForm = UnknownForm diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index 80b5cbb8..e5e6d6d4 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -8,6 +8,8 @@ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ +import firrtl.transforms +import firrtl.options.Dependency import MemPortUtils._ import WrappedExpression._ @@ -69,7 +71,7 @@ object MemDelayAndReadwriteTransformer { * read and write ports while simultaneously compiling memory latencies to combinational-read * memories with delay pipelines. It is represented as a class that takes a module as a constructor * argument, as it encapsulates the mutable state required to analyze and transform one module. - * + * * @note The final transformed module is found in the (sole public) field [[transformed]] */ class MemDelayAndReadwriteTransformer(m: DefModule) { @@ -165,6 +167,18 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { } object VerilogMemDelays extends Pass { + + override val prerequisites = firrtl.stage.Forms.LowForm :+ Dependency(firrtl.passes.RemoveValidIf) + + override val dependents = + Seq( Dependency[VerilogEmitter], + Dependency[SystemVerilogEmitter] ) + + override def invalidates(a: Transform): Boolean = a match { + case _: transforms.ConstantPropagation => true + case _ => false + } + def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) } |
